# ray/rllib/connectors/agent/env_to_agent.py


from typing import Any, List

from ray.rllib.connectors.connector import (
    ConnectorContext,
    AgentConnector,
    register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentConnectorDataType


@DeveloperAPI
class EnvToAgentDataConnector(AgentConnector):
    """Converts per environment multi-agent obs into per agent SampleBatches."""

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)
        self._view_requirements = ctx.view_requirements

    def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
        if ac_data.agent_id:
            # data is already for a single agent.
            return [ac_data]

        assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, (
            "EnvToPerAgentDataConnector expects a tuple of "
            + "(obs, rewards, dones, infos, episode_infos)."
        )
        # episode_infos contains additional training related data bits
        # for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
        # SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
        # and is usually empty in inference mode.
        obs, rewards, dones, infos, training_episode_infos = ac_data.data
        for var, name in zip(
            (obs, rewards, dones, infos, training_episode_infos),
            ("obs", "rewards", "dones", "infos", "training_episode_infos"),
        ):
            assert isinstance(var, dict), (
                f"EnvToPerAgentDataConnector expects {name} "
                + "to be a MultiAgentDict."
            )

        env_id = ac_data.env_id
        per_agent_data = []
        for agent_id, obs in obs.items():
            input_dict = {
                SampleBatch.ENV_ID: env_id,
                SampleBatch.REWARDS: rewards[agent_id],
                # SampleBatch.DONES may be overridden by data from
                # training_episode_infos next.
                SampleBatch.DONES: dones[agent_id],
                SampleBatch.NEXT_OBS: obs,
            }
            if SampleBatch.INFOS in self._view_requirements:
                input_dict[SampleBatch.INFOS] = infos[agent_id]
            if agent_id in training_episode_infos:
                input_dict.update(training_episode_infos[agent_id])
            per_agent_data.append(
                AgentConnectorDataType(env_id, agent_id, input_dict)
            )

        return per_agent_data

    def to_config(self):
        return EnvToAgentDataConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        return EnvToAgentDataConnector(ctx)


register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)
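

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module).
#
# A minimal example of how the connector fans a single multi-agent env step
# out into per-agent AgentConnectorDataType records. The `SimpleNamespace`
# context below is a hypothetical stand-in: this connector only reads
# `ctx.view_requirements`, whereas a real ConnectorContext would normally be
# built from a policy. Guarded by `__main__` so importing the module stays
# side-effect free.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical stand-in for a ConnectorContext (assumption: nothing
    # beyond `view_requirements` is needed for this sketch).
    fake_ctx = SimpleNamespace(view_requirements={})
    connector = EnvToAgentDataConnector(fake_ctx)

    # One environment step for two agents, in the 5-tuple layout asserted
    # above: (obs, rewards, dones, infos, training_episode_infos).
    env_step = AgentConnectorDataType(
        0,  # env_id
        None,  # empty agent_id -> data is still per-environment
        (
            {"agent_0": [0.1], "agent_1": [0.2]},  # obs
            {"agent_0": 1.0, "agent_1": -1.0},  # rewards
            {"agent_0": False, "agent_1": False},  # dones
            {"agent_0": {}, "agent_1": {}},  # infos
            {},  # training_episode_infos (usually empty in inference mode)
        ),
    )

    # Yields one AgentConnectorDataType per agent, each carrying a
    # SampleBatch-style input dict (ENV_ID, REWARDS, DONES, NEXT_OBS, ...).
    for per_agent in connector(env_step):
        print(per_agent.agent_id, per_agent.data)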