mirror of
https://github.com/vale981/ray
synced 2025-03-10 13:26:39 -04:00
72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
from typing import Any, List
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
ConnectorContext,
|
|
AgentConnector,
|
|
register_connector,
|
|
)
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.annotations import DeveloperAPI
|
|
from ray.rllib.utils.typing import AgentConnectorDataType
|
|
|
|
|
|
@DeveloperAPI
|
|
class EnvToAgentDataConnector(AgentConnector):
|
|
"""Converts per environment multi-agent obs into per agent SampleBatches."""
|
|
|
|
def __init__(self, ctx: ConnectorContext):
|
|
super().__init__(ctx)
|
|
self._view_requirements = ctx.view_requirements
|
|
|
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
|
if ac_data.agent_id:
|
|
# data is already for a single agent.
|
|
return [ac_data]
|
|
|
|
assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, (
|
|
"EnvToPerAgentDataConnector expects a tuple of "
|
|
+ "(obs, rewards, dones, infos, episode_infos)."
|
|
)
|
|
# episode_infos contains additional training related data bits
|
|
# for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
|
|
# SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
|
|
# and is usually empty in inference mode.
|
|
obs, rewards, dones, infos, training_episode_infos = ac_data.data
|
|
for var, name in zip(
|
|
(obs, rewards, dones, infos, training_episode_infos),
|
|
("obs", "rewards", "dones", "infos", "training_episode_infos"),
|
|
):
|
|
assert isinstance(var, dict), (
|
|
f"EnvToPerAgentDataConnector expects {name} "
|
|
+ "to be a MultiAgentDict."
|
|
)
|
|
|
|
env_id = ac_data.env_id
|
|
per_agent_data = []
|
|
for agent_id, obs in obs.items():
|
|
input_dict = {
|
|
SampleBatch.ENV_ID: env_id,
|
|
SampleBatch.REWARDS: rewards[agent_id],
|
|
# SampleBatch.DONES may be overridden by data from
|
|
# training_episode_infos next.
|
|
SampleBatch.DONES: dones[agent_id],
|
|
SampleBatch.NEXT_OBS: obs,
|
|
}
|
|
if SampleBatch.INFOS in self._view_requirements:
|
|
input_dict[SampleBatch.INFOS] = infos[agent_id]
|
|
if agent_id in training_episode_infos:
|
|
input_dict.update(training_episode_infos[agent_id])
|
|
|
|
per_agent_data.append(AgentConnectorDataType(env_id, agent_id, input_dict))
|
|
|
|
return per_agent_data
|
|
|
|
def to_config(self):
|
|
return EnvToAgentDataConnector.__name__, None
|
|
|
|
@staticmethod
|
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
|
return EnvToAgentDataConnector(ctx)
|
|
|
|
|
|
register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)
|