# Mirror of https://github.com/vale981/ray
# Synced 2025-03-05 10:01:43 -05:00
# 150 lines, 5.5 KiB, Python
from collections import OrderedDict
from typing import Dict, List, Optional

import gym

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID
# Key in a group's info dict under which the individual member rewards
# are preserved, e.g.:
# info: {
#     group_1: {
#         _group_rewards: [5, -1, 1],  # 3 agents in this group
#     }
# }
GROUP_REWARDS = "_group_rewards"

# Key in a group's info dict under which the individual member infos
# are preserved, e.g.:
# info: {
#     group_1: {
#         _group_infos: [{"foo": ...}, {}],  # 2 agents in this group
#     }
# }
GROUP_INFO = "_group_info"
|
|
|
|
|
|
@DeveloperAPI
class GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    See multi_agent_env.py for the specification of groups.

    This API is experimental.
    """

    def __init__(
        self,
        env: MultiAgentEnv,
        groups: Dict[str, List[AgentID]],
        obs_space: Optional[gym.Space] = None,
        act_space: Optional[gym.Space] = None,
    ):
        """Wrap an existing MultiAgentEnv to group agent ID together.

        See `MultiAgentEnv.with_agent_groups()` for more detailed usage info.

        Args:
            env: The env to wrap and whose agent IDs to group into new
                agents.
            groups: Mapping from group id to a list of the agent ids of
                group members. If an agent id is not present in any group
                value, it will be left ungrouped. The group id becomes a
                new agent ID in the final environment.
            obs_space: Optional observation space for the grouped env.
                Must be a tuple space. If not provided, will infer this
                to be a Tuple of n individual agents spaces (n=num agents
                in a group).
            act_space: Optional action space for the grouped env. Must be
                a tuple space. If not provided, will infer this to be a
                Tuple of n individual agents spaces (n=num agents in a
                group).

        Raises:
            ValueError: If any agent id appears in more than one group.
        """
        super().__init__()
        self.env = env
        # Inherit wrapped env's `_skip_env_checking` flag.
        if hasattr(self.env, "_skip_env_checking"):
            self._skip_env_checking = self.env._skip_env_checking
        self.groups = groups
        # Reverse mapping: member agent id -> its group id. Membership in
        # more than one group is an error.
        self.agent_id_to_group = {}
        for group_id, members in groups.items():
            for member_id in members:
                if member_id in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups".format(member_id)
                    )
                self.agent_id_to_group[member_id] = group_id
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space
        # Each group id acts as a (new) agent ID of this wrapper env.
        for group_id in groups:
            self._agent_ids.add(group_id)

    def seed(self, seed=None):
        # Silently do nothing when the wrapped env cannot be seeded;
        # OpenAI gyms also fail silently here.
        if hasattr(self.env, "seed"):
            self.env.seed(seed)

    def reset(self):
        # Reset the wrapped env, then fold per-agent obs into group tuples.
        return self._group_items(self.env.reset())

    def step(self, action_dict):
        """Step the wrapped env with ungrouped actions; regroup outputs."""
        # Split each group's action tuple back into per-agent actions.
        per_agent_actions = self._ungroup_items(action_dict)
        obs, rewards, dones, infos = self.env.step(per_agent_actions)

        # Fold the per-agent outputs back into per-group entries.
        grouped_obs = self._group_items(obs)
        grouped_rewards = self._group_items(
            rewards, agg_fn=lambda gvals: list(gvals.values())
        )
        grouped_dones = self._group_items(
            dones, agg_fn=lambda gvals: all(gvals.values())
        )
        grouped_infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())}
        )

        # Sum each group's member rewards into a single scalar, but keep
        # the original per-member values in the group's info dict.
        for agent_id, rew in grouped_rewards.items():
            if isinstance(rew, list):
                grouped_rewards[agent_id] = sum(rew)
                if agent_id not in grouped_infos:
                    grouped_infos[agent_id] = {}
                grouped_infos[agent_id][GROUP_REWARDS] = rew

        return grouped_obs, grouped_rewards, grouped_dones, grouped_infos

    def _ungroup_items(self, items):
        # Expand each group entry (a sequence ordered like the group's
        # member list) into individual per-agent entries; non-group
        # entries pass through unchanged.
        out = {}
        for agent_id, value in items.items():
            if agent_id not in self.groups:
                out[agent_id] = value
                continue
            members = self.groups[agent_id]
            assert len(value) == len(members), (
                agent_id,
                value,
                self.groups,
            )
            out.update(zip(members, value))
        return out

    def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
        # Collapse per-agent entries into one entry per group, aggregated
        # via `agg_fn`; agents not belonging to any group pass through.
        grouped = {}
        for agent_id, item in items.items():
            if agent_id not in self.agent_id_to_group:
                grouped[agent_id] = item
                continue
            group_id = self.agent_id_to_group[agent_id]
            if group_id in grouped:
                # Group already assembled when an earlier member was seen.
                continue
            # Collect all member values in the group's declared order; a
            # missing member is an error since groups act atomically.
            member_values = OrderedDict()
            for member_id in self.groups[group_id]:
                if member_id not in items:
                    raise ValueError(
                        "Missing member of group {}: {}: {}".format(
                            group_id, member_id, items
                        )
                    )
                member_values[member_id] = items[member_id]
            grouped[group_id] = agg_fn(member_values)
        return grouped
|