ray/rllib/env/wrappers/group_agents_wrapper.py
Amog Kamsetty ebc44c3d76
[CI] Upgrade flake8 to 3.9.1 (#15527)
* formatting

* format util

* format release

* format rllib/agents

* format rllib/env

* format rllib/execution

* format rllib/evaluation

* format rllib/examples

* format rllib/policy

* format rllib utils and tests

* format streaming

* more formatting

* update requirements files

* fix rllib type checking

* updates

* update

* fix circular import

* Update python/ray/tests/test_runtime_env.py

* noqa
2021-05-03 14:23:28 -07:00

116 lines
4.1 KiB
Python

from collections import OrderedDict
from ray.rllib.env.multi_agent_env import MultiAgentEnv
# Info key under which a group's per-agent rewards are preserved after
# aggregation, for example:
# info: {
#   group_1: {
#     _group_rewards: [5, -1, 1],  # 3 agents in this group
#   }
# }
GROUP_REWARDS = "_group_rewards"

# Info key under which a group's per-agent info dicts are collected,
# for example:
# info: {
#   group_1: {
#     _group_info: [{"foo": ...}, {}],  # 2 agents in this group
#   }
# }
GROUP_INFO = "_group_info"
class GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    See multi_agent_env.py for the specification of groups.

    This API is experimental.
    """

    def __init__(self, env, groups, obs_space=None, act_space=None):
        """Wrap an existing multi-agent env to group agents together.

        See MultiAgentEnv.with_agent_groups() for usage info.

        Args:
            env (MultiAgentEnv): env to wrap
            groups (dict): Grouping spec as documented in MultiAgentEnv.
            obs_space (Space): Optional observation space for the grouped
                env. Must be a tuple space.
            act_space (Space): Optional action space for the grouped env.
                Must be a tuple space.
        """
        self.env = env
        self.groups = groups
        # Reverse lookup: member agent id -> id of the group it belongs to.
        self.agent_id_to_group = {}
        for group_id, members in groups.items():
            for member in members:
                if member in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups".format(member))
                self.agent_id_to_group[member] = group_id
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space

    def reset(self):
        # Group the wrapped env's per-agent observations before returning.
        return self._group_items(self.env.reset())

    def step(self, action_dict):
        # Split grouped actions into per-agent actions and step the env.
        per_agent_actions = self._ungroup_items(action_dict)
        obs, rewards, dones, infos = self.env.step(per_agent_actions)

        # Regroup each of the env outputs with a suitable aggregator.
        obs = self._group_items(obs)
        rewards = self._group_items(
            rewards, agg_fn=lambda gvals: list(gvals.values()))
        dones = self._group_items(
            dones, agg_fn=lambda gvals: all(gvals.values()))
        infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})

        # Sum each group's reward list, but keep the per-agent values
        # available under the GROUP_REWARDS info key.
        for gid, rew in rewards.items():
            if isinstance(rew, list):
                rewards[gid] = sum(rew)
                infos.setdefault(gid, {})[GROUP_REWARDS] = rew
        return obs, rewards, dones, infos

    def _ungroup_items(self, items):
        # Expand grouped entries (one tuple/list per group) back into one
        # entry per member agent; non-group entries pass through unchanged.
        out = {}
        for item_id, value in items.items():
            if item_id not in self.groups:
                out[item_id] = value
                continue
            members = self.groups[item_id]
            assert len(value) == len(members), \
                (item_id, value, self.groups)
            out.update(zip(members, value))
        return out

    def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
        # Collapse per-agent entries into one aggregated entry per group;
        # entries for ungrouped ids pass through unchanged.
        grouped = {}
        for item_id, item in items.items():
            if item_id not in self.agent_id_to_group:
                grouped[item_id] = item
                continue
            group_id = self.agent_id_to_group[item_id]
            if group_id in grouped:
                continue  # already added
            member_vals = OrderedDict()
            for member in self.groups[group_id]:
                if member not in items:
                    raise ValueError(
                        "Missing member of group {}: {}: {}".format(
                            group_id, member, items))
                member_vals[member] = items[member]
            grouped[group_id] = agg_fn(member_vals)
        return grouped