mirror of
https://github.com/vale981/ray
synced 2025-03-10 05:16:49 -04:00
132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
from collections import defaultdict
|
|
from typing import Any, List
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
AgentConnector,
|
|
ConnectorContext,
|
|
register_connector,
|
|
)
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.typing import (
|
|
AgentConnectorDataType,
|
|
AgentConnectorsOutput,
|
|
)
|
|
from ray.util.annotations import PublicAPI
|
|
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
|
|
|
|
|
@PublicAPI(stability="alpha")
class ViewRequirementAgentConnector(AgentConnector):
    """This connector does 2 things:

    1. It filters data columns based on view_requirements for training and
       inference.
    2. It buffers the right amount of history for computing the sample batch
       for action computation.

    The output of this connector is AgentConnectorsOutput, which basically is
    a tuple of 2 things:
    {
        "for_training": {"obs": ...}
        "for_action": SampleBatch
    }
    The "for_training" dict, which contains data for the latest time slice,
    can be used to construct a complete episode by Sampler for training purpose.
    The "for_action" SampleBatch can be used to directly call the policy.
    """

    def __init__(self, ctx: ConnectorContext):
        """Set up per-environment, per-agent collectors from the context.

        Args:
            ctx: Connector context. This constructor reads its
                ``view_requirements``, ``config``, ``initial_states`` and
                ``is_policy_recurrent`` attributes.
        """
        super().__init__(ctx)

        self._view_requirements = ctx.view_requirements

        def new_agent_collector() -> AgentCollector:
            # NOTE: `intial_states` (sic) is the keyword name used by this
            # call site; it is part of the runtime interface, so it is left
            # untouched here.
            return AgentCollector(
                self._view_requirements,
                max_seq_len=ctx.config["model"]["max_seq_len"],
                intial_states=ctx.initial_states,
                disable_action_flattening=ctx.config.get(
                    "_disable_action_flattening", False
                ),
                is_policy_recurrent=ctx.is_policy_recurrent,
            )

        # A dict of env_id to a dict of agent_id to one AgentCollector.
        # Every env_id must get its OWN inner dict: handing out a single shared
        # inner dict would mix the history of agents with the same id living in
        # different envs, and would make `reset(env_id)` a no-op because the
        # shared dict would outlive the deleted entry.
        self.agent_collectors = defaultdict(
            lambda: defaultdict(new_agent_collector)
        )

    def reset(self, env_id: str):
        """Drop all agent collectors buffered for the given environment.

        Args:
            env_id: Id of the environment whose collectors are discarded.
                Unknown ids are ignored.
        """
        # `pop` with a default neither raises for unknown ids nor triggers the
        # defaultdict factory.
        self.agent_collectors.pop(env_id, None)

    def _get_sample_batch_for_action(
        self, view_requirements, agent_batch
    ) -> SampleBatch:
        """Build an inference SampleBatch from the last entry of each column.

        Args:
            view_requirements: Mapping of column name to a requirement object
                exposing ``used_for_compute_actions``.
            agent_batch: Mapping of column name to an indexable sequence of
                values.

        Returns:
            A SampleBatch (``is_training=False``) holding ``agent_batch[col][-1]``
            for every column that is used for computing actions and is present
            in ``agent_batch``.
        """
        # TODO(jungong) : actually support building input sample batch with all
        #  the view shift requirements, etc.
        # For now, we only support last element (no shift).
        input_dict = {
            col: agent_batch[col][-1]
            for col, req in view_requirements.items()
            if req.used_for_compute_actions and col in agent_batch
        }
        return SampleBatch(input_dict, is_training=False)

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Buffer one agent step and build the batch for action computation.

        Args:
            ac_data: Single-agent connector data. ``ac_data.data`` must be a
                plain dict containing ``SampleBatch.NEXT_OBS``; ``env_id`` and
                ``agent_id`` must not be None.

        Returns:
            AgentConnectorDataType whose data is an
            ``AgentConnectorsOutput(training_dict, sample_batch)``.
            ``training_dict`` is the unmodified input dict in training mode and
            None otherwise.

        Raises:
            AssertionError: If the data is not a plain dict, if env_id or
                agent_id is None, or if no view requirements are available.
            ValueError: If the data does not contain ``SampleBatch.NEXT_OBS``.
        """
        d = ac_data.data
        # Exact type check on purpose (dict subclasses are rejected), same as
        # before.
        assert (
            type(d) is dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        # TODO: we don't keep episode_id around so use env_id as episode_id ?
        episode_id = d[SampleBatch.EPS_ID] if SampleBatch.EPS_ID in d else env_id

        assert env_id is not None and agent_id is not None, (
            f"ViewRequirementAgentConnector requires env_id({env_id}) "
            f"and agent_id({agent_id})"
        )

        vr = self._view_requirements
        assert vr, "ViewRequirements required by ViewRequirementConnector"

        # Validate before touching the collector map so that invalid input does
        # not create a collector as a side effect.
        if SampleBatch.NEXT_OBS not in d:
            raise ValueError(f"connector data {d} should contain next_obs.")

        # Return full training_dict for env runner to construct episodes.
        # Note(jungong) : we need to keep the entire input dict here.
        # A column may be used by postprocessing (GAE) even if its
        # view_requirement.used_for_training is False.
        training_dict = d if self._is_training else None

        agent_collector = self.agent_collectors[env_id][agent_id]

        if agent_collector.is_empty():
            # First data seen for this (env, agent): seed the collector with
            # the observation only.
            agent_collector.add_init_obs(
                episode_id=episode_id,
                agent_index=agent_id,
                env_id=env_id,
                t=-1,
                init_obs=d[SampleBatch.NEXT_OBS],
            )
        else:
            agent_collector.add_action_reward_next_obs(d)
        sample_batch = agent_collector.build_for_inference()

        return AgentConnectorDataType(
            env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
        )

    def to_config(self):
        """Return ``(class name, None)``; this connector has no parameters."""
        return ViewRequirementAgentConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Create a connector from ``ctx``; ``params`` is accepted but unused."""
        return ViewRequirementAgentConnector(ctx)
|
|
|
|
|
|
# Register the connector class under its own class name, i.e. the same name
# that `ViewRequirementAgentConnector.to_config()` emits.
register_connector(
    ViewRequirementAgentConnector.__name__,
    ViewRequirementAgentConnector,
)
|