ray/rllib/env/group_agents_wrapper.py

from collections import OrderedDict

from ray.rllib.env.constants import GROUP_REWARDS, GROUP_INFO
from ray.rllib.env.multi_agent_env import MultiAgentEnv


# TODO(ekl) we should add some unit tests for this
class _GroupAgentsWrapper(MultiAgentEnv):
"""Wraps a MultiAgentEnv environment with agents grouped as specified.
See multi_agent_env.py for the specification of groups.
This API is experimental.
"""
def __init__(self, env, groups, obs_space=None, act_space=None):
"""Wrap an existing multi-agent env to group agents together.
See MultiAgentEnv.with_agent_groups() for usage info.
Arguments:
env (MultiAgentEnv): env to wrap
groups (dict): Grouping spec as documented in MultiAgentEnv
obs_space (Space): Optional observation space for the grouped
env. Must be a tuple space.
act_space (Space): Optional action space for the grouped env.
Must be a tuple space.
"""
self.env = env
self.groups = groups
self.agent_id_to_group = {}
for group_id, agent_ids in groups.items():
for agent_id in agent_ids:
if agent_id in self.agent_id_to_group:
                    raise ValueError(
                        "Agent id {} is in multiple groups: {}".format(
                            agent_id, groups))
self.agent_id_to_group[agent_id] = group_id
if obs_space is not None:
self.observation_space = obs_space
if act_space is not None:
self.action_space = act_space

    def reset(self):
        obs = self.env.reset()
        return self._group_items(obs)

    def step(self, action_dict):
        # Ungroup and send actions
        action_dict = self._ungroup_items(action_dict)
        obs, rewards, dones, infos = self.env.step(action_dict)

        # Apply grouping transforms to the env outputs
obs = self._group_items(obs)
rewards = self._group_items(
rewards, agg_fn=lambda gvals: list(gvals.values()))
dones = self._group_items(
dones, agg_fn=lambda gvals: all(gvals.values()))
infos = self._group_items(
infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})
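        # Illustrative values, assuming a single group "group1": ["a1", "a2"]:
        # rewards {"a1": 1.0, "a2": 2.0} become {"group1": [1.0, 2.0]},
        # dones {"a1": True, "a2": False} become {"group1": False}, and
        # infos {"a1": {}, "a2": {}} become {"group1": {GROUP_INFO: [{}, {}]}}.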

        # Aggregate rewards, but preserve the original values in infos
for agent_id, rew in rewards.items():
if isinstance(rew, list):
rewards[agent_id] = sum(rew)
if agent_id not in infos:
infos[agent_id] = {}
infos[agent_id][GROUP_REWARDS] = rew

        return obs, rewards, dones, infos

    def _ungroup_items(self, items):
        # Split each per-group value back into per-agent entries, following
        # the group's member order; non-grouped keys pass through unchanged.
out = {}
for agent_id, value in items.items():
if agent_id in self.groups:
assert len(value) == len(self.groups[agent_id]), \
(agent_id, value, self.groups)
for a, v in zip(self.groups[agent_id], value):
out[a] = v
else:
out[agent_id] = value
return out

    def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
        # Collapse per-agent entries into one entry per group, combining the
        # member values with agg_fn; non-grouped keys pass through unchanged.
grouped_items = {}
for agent_id, item in items.items():
if agent_id in self.agent_id_to_group:
group_id = self.agent_id_to_group[agent_id]
if group_id in grouped_items:
continue # already added
group_out = OrderedDict()
for a in self.groups[group_id]:
if a in items:
group_out[a] = items[a]
else:
raise ValueError(
"Missing member of group {}: {}: {}".format(
group_id, a, items))
grouped_items[group_id] = agg_fn(group_out)
else:
grouped_items[agent_id] = item
return grouped_items
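

# End-to-end sketch of the data flow (illustrative only; the env, agent ids,
# and group id below are hypothetical):
#
#     grouped = _GroupAgentsWrapper(env, {"group1": ["a1", "a2"]})
#     obs = grouped.reset()  # -> {"group1": [obs_a1, obs_a2]}
#     obs, rewards, dones, infos = grouped.step(
#         {"group1": [act_a1, act_a2]})
#     # rewards["group1"] is the summed group reward; the per-agent rewards
#     # are preserved under infos["group1"][GROUP_REWARDS].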