mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Remove all __future__ imports from RLlib. * Remove (object) again from tf_run_builder.py::TFRunBuilder. * Fix 2xLINT warnings. * Fix broken appo_policy import (must be appo_tf_policy) * Remove future imports from all other ray files (not just RLlib). * Remove future imports from all other ray files (not just RLlib). * Remove future import blocks that contain `unicode_literals` as well. Revert appo_tf_policy.py to appo_policy.py (belongs to another PR). * Add two empty lines before Schedule class. * Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
103 lines
3.9 KiB
Python
103 lines
3.9 KiB
Python
from collections import OrderedDict
|
|
|
|
from ray.rllib.env.constants import GROUP_REWARDS, GROUP_INFO
|
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
|
|
|
|
|
# TODO(ekl) we should add some unit tests for this
|
|
class _GroupAgentsWrapper(MultiAgentEnv):
    """Wraps a MultiAgentEnv environment with agents grouped as specified.

    See multi_agent_env.py for the specification of groups.

    This API is experimental.
    """

    def __init__(self, env, groups, obs_space=None, act_space=None):
        """Wrap an existing multi-agent env to group agents together.

        See MultiAgentEnv.with_agent_groups() for usage info.

        Arguments:
            env (MultiAgentEnv): env to wrap
            groups (dict): Grouping spec as documented in MultiAgentEnv
            obs_space (Space): Optional observation space for the grouped
                env. Must be a tuple space.
            act_space (Space): Optional action space for the grouped env.
                Must be a tuple space.

        Raises:
            ValueError: If an agent id is listed in more than one group.
        """
        self.env = env
        self.groups = groups

        # Reverse index (agent id -> group id), built once so that the
        # per-step grouping does not have to scan every group.
        self.agent_id_to_group = {}
        for group_id, agent_ids in groups.items():
            for agent_id in agent_ids:
                if agent_id in self.agent_id_to_group:
                    # Include the full spec in the message; it was passed to
                    # format() before but had no placeholder to land in.
                    raise ValueError(
                        "Agent id {} is in multiple groups: {}".format(
                            agent_id, groups))
                self.agent_id_to_group[agent_id] = group_id

        # The spaces are only set when explicitly given.
        if obs_space is not None:
            self.observation_space = obs_space
        if act_space is not None:
            self.action_space = act_space

    def reset(self):
        """Reset the wrapped env and return its observations, grouped."""
        obs = self.env.reset()
        return self._group_items(obs)

    def step(self, action_dict):
        """Step the wrapped env with grouped actions.

        Arguments:
            action_dict (dict): Maps group ids to a sequence of per-member
                actions (ordered like the group's member list). Keys that
                are not group ids are passed through unchanged.

        Returns:
            Tuple of (obs, rewards, dones, infos) dicts keyed by group id
            for grouped agents and by the original key otherwise. A group's
            reward is the sum of its members' rewards, its done flag is
            true only if all members are done, and its info is a dict
            holding the members' infos under GROUP_INFO and the members'
            individual rewards under GROUP_REWARDS.
        """
        # Ungroup and send actions
        action_dict = self._ungroup_items(action_dict)
        obs, rewards, dones, infos = self.env.step(action_dict)

        # Apply grouping transforms to the env outputs
        obs = self._group_items(obs)
        rewards = self._group_items(
            rewards, agg_fn=lambda gvals: list(gvals.values()))
        dones = self._group_items(
            dones, agg_fn=lambda gvals: all(gvals.values()))
        infos = self._group_items(
            infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})

        # Aggregate rewards, but preserve the original values in infos.
        # Only values of existing keys are replaced while iterating, so the
        # dict does not change size during the loop.
        for agent_id, rew in rewards.items():
            if isinstance(rew, list):
                rewards[agent_id] = sum(rew)
                if agent_id not in infos:
                    infos[agent_id] = {}
                infos[agent_id][GROUP_REWARDS] = rew

        return obs, rewards, dones, infos

    def _ungroup_items(self, items):
        """Expand group-keyed entries into one entry per member agent.

        Arguments:
            items (dict): Maps group ids to a sequence with one value per
                group member; any other key is copied through as-is.

        Returns:
            dict keyed by individual agent ids (plus any passed-through
            keys).
        """
        out = {}
        for agent_id, value in items.items():
            if agent_id in self.groups:
                # One value per member is required, since zip() would
                # otherwise silently truncate to the shorter sequence.
                assert len(value) == len(self.groups[agent_id]), \
                    (agent_id, value, self.groups)
                for a, v in zip(self.groups[agent_id], value):
                    out[a] = v
            else:
                out[agent_id] = value
        return out

    def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
        """Collapse per-agent entries into one entry per group.

        Arguments:
            items (dict): Maps agent ids to values. Keys that belong to no
                group are copied through unchanged.
            agg_fn (callable): Receives an OrderedDict of
                {member agent id: value}, in the group's member order, and
                returns the value to store for the group. Defaults to the
                list of member values.

        Returns:
            dict keyed by group id for grouped agents and by the original
            key otherwise.

        Raises:
            ValueError: If some but not all members of a group are present
                in `items`.
        """
        grouped_items = {}
        for agent_id, item in items.items():
            if agent_id in self.agent_id_to_group:
                group_id = self.agent_id_to_group[agent_id]
                if group_id in grouped_items:
                    continue  # already added
                # Collect values in the group's declared member order, not
                # in the iteration order of `items`.
                group_out = OrderedDict()
                for a in self.groups[group_id]:
                    if a in items:
                        group_out[a] = items[a]
                    else:
                        raise ValueError(
                            "Missing member of group {}: {}: {}".format(
                                group_id, a, items))
                grouped_items[group_id] = agg_fn(group_out)
            else:
                grouped_items[agent_id] = item
        return grouped_items
|