ray/rllib/env/wrappers/group_agents_wrapper.py

150 lines
5.5 KiB
Python

from collections import OrderedDict
import gym
from typing import Dict, List, Optional
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID
# Info key under which the individual (per-member) rewards of a grouped agent
# are stored, in the order of the group's member list. The reward reported
# for the group itself is the sum of these values. For example:
# info: {
#     group_1: {
#         _group_rewards: [5, -1, 1],  # 3 agents in this group
#     }
# }
GROUP_REWARDS = "_group_rewards"
# Info key under which the individual (per-member) infos of a grouped agent
# are stored, as a list in the order of the group's member list. For example:
# info: {
#     group_1: {
#         _group_info: [{"foo": ...}, {}],  # 2 agents in this group
#     }
# }
GROUP_INFO = "_group_info"
@DeveloperAPI
class GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    See multi_agent_env.py for the specification of groups.

    Each group is exposed to the outside as a single agent whose ID is the
    group ID. Per-agent values coming out of the wrapped env (observations,
    rewards, dones, infos) are packed into one value per group, and a
    group's action (one entry per member) is unpacked into per-member
    actions before being passed to the wrapped env.

    This API is experimental.
    """

    def __init__(
        self,
        env: MultiAgentEnv,
        groups: Dict[str, List[AgentID]],
        obs_space: Optional[gym.Space] = None,
        act_space: Optional[gym.Space] = None,
    ):
        """Wrap an existing MultiAgentEnv to group agent IDs together.

        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.

        Args:
            env: The env to wrap and whose agent IDs to group into new agents.
            groups: Mapping from group ID to the list of agent IDs of the
                group's members. An agent ID not present in any group is
                left ungrouped. Each group ID becomes a new agent ID in the
                wrapped environment. The order of a member list defines the
                order of per-member values in grouped outputs and the
                expected order of per-member entries in a group's action.
            obs_space: Optional observation space for the grouped env.
                Expected to be a Tuple space. If None, this wrapper does not
                set `observation_space` itself (no space is inferred here).
            act_space: Optional action space for the grouped env. Expected
                to be a Tuple space. If None, this wrapper does not set
                `action_space` itself (no space is inferred here).

        Raises:
            ValueError: If an agent ID appears in more than one group.
        """
        super().__init__()
        self.env = env
        # Inherit the wrapped env's `_skip_env_checking` flag, so that the
        # wrapper is checked (or not) the same way as the wrapped env.
        if hasattr(self.env, "_skip_env_checking"):
            self._skip_env_checking = self.env._skip_env_checking
        self.groups = groups
        # Reverse lookup: member agent ID -> ID of the group containing it.
        self.agent_id_to_group = {}
        for group_id, agent_ids in groups.items():
            for agent_id in agent_ids:
                if agent_id in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups".format(agent_id)
                    )
                self.agent_id_to_group[agent_id] = group_id
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space
        # `_agent_ids` is presumably a set created by `MultiAgentEnv.__init__`
        # (not visible here). NOTE(review): only group IDs are registered;
        # ungrouped agent IDs of the wrapped env are not added — confirm.
        for group_id in groups.keys():
            self._agent_ids.add(group_id)

    def seed(self, seed=None):
        """Seed the wrapped env, if it supports seeding.

        Args:
            seed: Seed forwarded unchanged to the wrapped env's `seed()`.

        Returns:
            Whatever the wrapped env's `seed()` returns, or None if the
            wrapped env has no `seed` method.
        """
        if not hasattr(self.env, "seed"):
            # This is a silent fail. However, OpenAI gyms also silently fail
            # here.
            return None
        return self.env.seed(seed)

    def reset(self):
        """Reset the wrapped env and return its observations grouped.

        Returns:
            Dict keyed by group ID (value: list of member observations in
            member order) for grouped agents, and by the original agent ID
            for ungrouped agents.

        Raises:
            ValueError: If the wrapped env returns observations for only some
                members of a group.
        """
        obs = self.env.reset()
        return self._group_items(obs)

    def step(self, action_dict):
        """Step the wrapped env with ungrouped actions and group its outputs.

        Args:
            action_dict: Mapping from group/agent ID to action. A group's
                action must be a sequence with one entry per member, in the
                order given in `groups`.

        Returns:
            Tuple (obs, rewards, dones, infos) of dicts keyed by group ID for
            grouped agents and by the original agent ID otherwise:
            - obs: list of member observations per group.
            - rewards: sum of member rewards per group; the individual member
              rewards are kept in `infos[group_id][GROUP_REWARDS]`.
            - dones: True for a group only if all of its members are done.
            - infos: `{GROUP_INFO: [member infos...]}` per group.

        Raises:
            ValueError: If a group action has the wrong number of entries, or
                if the wrapped env returns values for only some members of a
                group.
        """
        # Ungroup and send actions.
        action_dict = self._ungroup_items(action_dict)
        obs, rewards, dones, infos = self.env.step(action_dict)

        # Apply grouping transforms to the env outputs.
        obs = self._group_items(obs)
        rewards = self._group_items(rewards, agg_fn=lambda gvals: list(gvals.values()))
        dones = self._group_items(dones, agg_fn=lambda gvals: all(gvals.values()))
        infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}
        )

        # Aggregate group rewards, but preserve the original values in infos.
        # Decide by group membership rather than by `isinstance(rew, list)`:
        # an ungrouped agent whose reward happens to be a list must pass
        # through untouched.
        for agent_id, rew in rewards.items():
            if agent_id in self.groups:
                rewards[agent_id] = sum(rew)
                infos.setdefault(agent_id, {})[GROUP_REWARDS] = rew
        return obs, rewards, dones, infos

    def _ungroup_items(self, items):
        """Expand per-group values into per-member values.

        Args:
            items: Mapping from group/agent ID to value. A group's value must
                be a sequence with exactly one entry per member of the group,
                in the order given in `groups`.

        Returns:
            A new dict keyed by the original (ungrouped) agent IDs.

        Raises:
            ValueError: If a group's value does not have exactly one entry
                per member.
        """
        out = {}
        for agent_id, value in items.items():
            if agent_id in self.groups:
                members = self.groups[agent_id]
                # Explicit check instead of `assert`: asserts are stripped
                # under `python -O`, and `zip` would then silently drop the
                # surplus entries of a mismatched value.
                if len(value) != len(members):
                    raise ValueError(
                        "Group {} has {} members {}, but got {} values: "
                        "{}".format(
                            agent_id, len(members), members, len(value), value
                        )
                    )
                for a, v in zip(members, value):
                    out[a] = v
            else:
                out[agent_id] = value
        return out

    def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
        """Collapse per-member values into one value per group.

        Args:
            items: Mapping from original agent ID to value, as returned by the
                wrapped env.
            agg_fn: Called once per group with an OrderedDict mapping member
                agent ID -> value (in the order given in `groups`); its
                return value becomes the group's entry. Defaults to the list
                of member values.

        Returns:
            A new dict in which each group appears once under its group ID
            (inserted where its first member occurs in `items`), and keys
            not belonging to any group are copied through unchanged.

        Raises:
            ValueError: If `items` contains some, but not all, members of a
                group.
        """
        grouped_items = {}
        for agent_id, item in items.items():
            if agent_id not in self.agent_id_to_group:
                grouped_items[agent_id] = item
                continue
            group_id = self.agent_id_to_group[agent_id]
            if group_id in grouped_items:
                continue  # Already added when its first member was seen.
            group_out = OrderedDict()
            for a in self.groups[group_id]:
                if a not in items:
                    raise ValueError(
                        "Missing member of group {}: {}: {}".format(
                            group_id, a, items
                        )
                    )
                group_out[a] = items[a]
            grouped_items[group_id] = agg_fn(group_out)
        return grouped_items