ray/rllib/connectors/agent/view_requirement.py
2022-07-26 21:52:14 -07:00

132 lines
4.7 KiB
Python

from collections import defaultdict
from typing import Any, List
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
AgentConnectorDataType,
AgentConnectorsOutput,
)
from ray.util.annotations import PublicAPI
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
@PublicAPI(stability="alpha")
class ViewRequirementAgentConnector(AgentConnector):
"""This connector does 2 things:
1. It filters data columns based on view_requirements for training and inference.
2. It buffers the right amount of history for computing the sample batch for
action computation.
The output of this connector is AgentConnectorsOut, which basically is
a tuple of 2 things:
{
"for_training": {"obs": ...}
"for_action": SampleBatch
}
The "for_training" dict, which contains data for the latest time slice,
can be used to construct a complete episode by Sampler for training purpose.
The "for_action" SampleBatch can be used to directly call the policy.
"""
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._view_requirements = ctx.view_requirements
# a dict of env_id to a dict of agent_id to a list of agent_collector objects
env_default = defaultdict(
lambda: AgentCollector(
self._view_requirements,
max_seq_len=ctx.config["model"]["max_seq_len"],
intial_states=ctx.initial_states,
disable_action_flattening=ctx.config.get(
"_disable_action_flattening", False
),
is_policy_recurrent=ctx.is_policy_recurrent,
)
)
self.agent_collectors = defaultdict(lambda: env_default)
def reset(self, env_id: str):
if env_id in self.agent_collectors:
del self.agent_collectors[env_id]
def _get_sample_batch_for_action(
self, view_requirements, agent_batch
) -> SampleBatch:
# TODO(jungong) : actually support buildling input sample batch with all the
# view shift requirements, etc.
# For now, we only support last elemen (no shift).
input_dict = {}
for col, req in view_requirements.items():
if not req.used_for_compute_actions:
continue
if col not in agent_batch:
continue
input_dict[col] = agent_batch[col][-1]
return SampleBatch(input_dict, is_training=False)
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
env_id = ac_data.env_id
agent_id = ac_data.agent_id
# TODO: we don't keep episode_id around so use env_id as episode_id ?
episode_id = env_id if SampleBatch.EPS_ID not in d else d[SampleBatch.EPS_ID]
assert env_id is not None and agent_id is not None, (
f"ViewRequirementAgentConnector requires env_id({env_id}) "
"and agent_id({agent_id})"
)
vr = self._view_requirements
assert vr, "ViewRequirements required by ViewRequirementConnector"
training_dict = None
# Return full training_dict for env runner to construct episodes.
if self._is_training:
# Note(jungong) : we need to keep the entire input dict here.
# A column may be used by postprocessing (GAE) even if its
# iew_requirement.used_for_training is False.
training_dict = d
agent_collector = self.agent_collectors[env_id][agent_id]
if SampleBatch.NEXT_OBS not in d:
raise ValueError(f"connector data {d} should contain next_obs.")
if agent_collector.is_empty():
agent_collector.add_init_obs(
episode_id=episode_id,
agent_index=agent_id,
env_id=env_id,
t=-1,
init_obs=d[SampleBatch.NEXT_OBS],
)
else:
agent_collector.add_action_reward_next_obs(d)
sample_batch = agent_collector.build_for_inference()
return_data = AgentConnectorDataType(
env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
)
return return_data
def to_config(self):
return ViewRequirementAgentConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ViewRequirementAgentConnector(ctx)
register_connector(
ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector
)