[rllib] ExternalMultiAgentEnv (#4200)

ctombumila37 2019-04-07 04:58:14 +02:00 committed by Eric Liang
parent 991b911e1d
commit 7746d20d30
7 changed files with 372 additions and 49 deletions

@@ -304,6 +304,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_external_env.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_external_multi_agent_env.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/parametric_action_cartpole.py --run=PG --stop=50

@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override, PublicAPI
@@ -102,6 +103,12 @@ class BaseEnv(object):
make_env=make_env,
existing_envs=[env],
num_envs=num_envs)
elif isinstance(env, ExternalMultiAgentEnv):
if num_envs != 1:
raise ValueError(
"ExternalMultiAgentEnv does not currently support "
"num_envs > 1.")
env = _ExternalEnvToBaseEnv(env, multiagent=True)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
@@ -203,9 +210,10 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
class _ExternalEnvToBaseEnv(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""
def __init__(self, external_env, preprocessor=None):
def __init__(self, external_env, preprocessor=None, multiagent=False):
self.external_env = external_env
self.prep = preprocessor
self.multiagent = multiagent
self.action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
@@ -230,16 +238,22 @@ class _ExternalEnvToBaseEnv(BaseEnv):
@override(BaseEnv)
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.external_env._episodes[eid].action_queue.put(
action[_DUMMY_AGENT_ID])
if self.multiagent:
for env_id, actions in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(actions)
else:
for env_id, action in action_dict.items():
self.external_env._episodes[env_id].action_queue.put(
action[_DUMMY_AGENT_ID])
def _poll(self):
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
off_policy_actions = {}
for eid, episode in self.external_env._episodes.copy().items():
data = episode.get_data()
if episode.cur_done:
cur_done = episode.cur_done_dict[
"__all__"] if self.multiagent else episode.cur_done
if cur_done:
del self.external_env._episodes[eid]
if data:
if self.prep:
@@ -251,11 +265,27 @@ class _ExternalEnvToBaseEnv(BaseEnv):
all_infos[eid] = data["info"]
if "off_policy_action" in data:
off_policy_actions[eid] = data["off_policy_action"]
return _with_dummy_agent_id(all_obs), \
_with_dummy_agent_id(all_rewards), \
_with_dummy_agent_id(all_dones, "__all__"), \
_with_dummy_agent_id(all_infos), \
_with_dummy_agent_id(off_policy_actions)
if self.multiagent:
# ensure a consistent set of agent keys across the returned dicts;
# for now, rely on all_obs containing all possible agent ids
for eid, eid_dict in all_obs.items():
for agent_id in eid_dict.keys():
def fix(d, zero_val):
if agent_id not in d[eid]:
d[eid][agent_id] = zero_val
fix(all_rewards, 0.0)
fix(all_dones, False)
fix(all_infos, {})
return (all_obs, all_rewards, all_dones, all_infos,
off_policy_actions)
else:
return _with_dummy_agent_id(all_obs), \
_with_dummy_agent_id(all_rewards), \
_with_dummy_agent_id(all_dones, "__all__"), \
_with_dummy_agent_id(all_infos), \
_with_dummy_agent_id(off_policy_actions)
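
For reference, a minimal sketch of the dict shapes poll() yields from this adapter in the multi-agent case; the env id, agent ids, and values below are invented for illustration. In the single-agent path the same nesting is produced by _with_dummy_agent_id(), with _DUMMY_AGENT_ID as the only agent key.

# Hedged sketch: illustrative poll() output of _ExternalEnvToBaseEnv
# when multiagent=True (ids and values are purely illustrative).
obs = {"episode_1": {"agent_0": [0.1, 0.2], "agent_1": [0.3, 0.4]}}
rewards = {"episode_1": {"agent_0": 1.0, "agent_1": 0.0}}  # missing agents padded with 0.0
dones = {"episode_1": {"agent_0": False, "agent_1": False, "__all__": False}}
infos = {"episode_1": {"agent_0": {}, "agent_1": {}}}      # missing agents padded with {}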
class _VectorEnvToBaseEnv(BaseEnv):

@@ -184,17 +184,29 @@ class ExternalEnv(threading.Thread):
class _ExternalEnvEpisode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition, training_enabled):
def __init__(self,
episode_id,
results_avail_condition,
training_enabled,
multiagent=False):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.multiagent = multiagent
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
if multiagent:
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
self.cur_done_dict = {"__all__": False}
self.cur_info_dict = {}
else:
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
def get_data(self):
if self.data_queue.empty():
@@ -202,35 +214,59 @@ class _ExternalEnvEpisode(object):
return self.data_queue.get_nowait()
def log_action(self, observation, action):
self.new_observation = observation
self.new_action = action
if self.multiagent:
self.new_observation_dict = observation
self.new_action_dict = action
else:
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)
def wait_for_action(self, observation):
self.new_observation = observation
if self.multiagent:
self.new_observation_dict = observation
else:
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=60.0)
def done(self, observation):
self.new_observation = observation
self.cur_done = True
if self.multiagent:
self.new_observation_dict = observation
self.cur_done_dict = {"__all__": True}
else:
self.new_observation = observation
self.cur_done = True
self._send()
def _send(self):
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
if self.multiagent:
item = {
"obs": self.new_observation_dict,
"reward": self.cur_reward_dict,
"done": self.cur_done_dict,
"info": self.cur_info_dict,
}
if self.new_action_dict is not None:
item["off_policy_action"] = self.new_action_dict
self.new_observation_dict = None
self.new_action_dict = None
self.cur_reward_dict = {}
else:
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()

@@ -0,0 +1,149 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import uuid
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.env.external_env import ExternalEnv, _ExternalEnvEpisode
@PublicAPI
class ExternalMultiAgentEnv(ExternalEnv):
"""This is the multi-agent version of ExternalEnv."""
@PublicAPI
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize a multi-agent external env.
ExternalMultiAgentEnv subclasses must call this during their __init__.
Arguments:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
max_concurrent (int): Max number of active episodes to allow at
once. Exceeding this limit raises an error.
"""
ExternalEnv.__init__(self, action_space, observation_space,
max_concurrent)
# we need to know all agents' spaces
if isinstance(self.action_space, dict) or isinstance(
self.observation_space, dict):
if not (self.action_space.keys() == self.observation_space.keys()):
raise ValueError("Agent ids disagree for action space and obs "
"space dict: {} {}".format(
self.action_space.keys(),
self.observation_space.keys()))
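
As an illustration of the dict-space case checked above, a hedged sketch of per-agent spaces (the agent ids and spaces are invented):

import gym

# Hedged sketch: per-agent space dicts keyed by agent id. Both dicts must
# share the same agent ids, otherwise the check above raises a ValueError.
agent_ids = ["agent_0", "agent_1"]  # illustrative ids
action_space = {aid: gym.spaces.Discrete(2) for aid in agent_ids}
observation_space = {
    aid: gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)) for aid in agent_ids
}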
@PublicAPI
def run(self):
"""Override this to implement the multi-agent run loop.
Your loop should continuously:
1. Call self.start_episode(episode_id)
2. Call self.get_action(episode_id, obs_dict)
-or-
self.log_action(episode_id, obs_dict, action_dict)
3. Call self.log_returns(episode_id, reward_dict)
4. Call self.end_episode(episode_id, obs_dict)
5. Wait if nothing to do.
Multiple episodes may be started at the same time.
"""
raise NotImplementedError
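
A minimal sketch of such a run loop, assuming the subclass wraps an in-process MultiAgentEnv; the class and attribute names here are hypothetical.

from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv

class MyServingEnv(ExternalMultiAgentEnv):  # hypothetical subclass
    def __init__(self, inner_env, action_space, observation_space):
        ExternalMultiAgentEnv.__init__(self, action_space, observation_space)
        self.inner_env = inner_env  # an in-process MultiAgentEnv, for illustration

    def run(self):
        episode_id = self.start_episode()
        obs_dict = self.inner_env.reset()
        while True:
            # ask the policy for actions for all agents in the current step
            action_dict = self.get_action(episode_id, obs_dict)
            obs_dict, reward_dict, done_dict, _ = self.inner_env.step(action_dict)
            # report per-agent rewards for the action just taken
            self.log_returns(episode_id, reward_dict)
            if done_dict["__all__"]:
                self.end_episode(episode_id, obs_dict)
                obs_dict = self.inner_env.reset()
                episode_id = self.start_episode()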
@PublicAPI
@override(ExternalEnv)
def start_episode(self, episode_id=None, training_enabled=True):
if episode_id is None:
episode_id = uuid.uuid4().hex
if episode_id in self._finished:
raise ValueError(
"Episode {} has already completed.".format(episode_id))
if episode_id in self._episodes:
raise ValueError(
"Episode {} is already started".format(episode_id))
self._episodes[episode_id] = _ExternalEnvEpisode(
episode_id,
self._results_avail_condition,
training_enabled,
multiagent=True)
return episode_id
@PublicAPI
@override(ExternalEnv)
def get_action(self, episode_id, observation_dict):
"""Record an observation and get the on-policy action.
observation_dict is expected to contain the observation
of all agents acting in this episode step.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.
Returns:
action (dict): Action from the env action space.
"""
episode = self._get(episode_id)
return episode.wait_for_action(observation_dict)
@PublicAPI
@override(ExternalEnv)
def log_action(self, episode_id, observation_dict, action_dict):
"""Record an observation and (off-policy) action taken.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.
action_dict (dict): Action for the observation.
"""
episode = self._get(episode_id)
episode.log_action(observation_dict, action_dict)
@PublicAPI
@override(ExternalEnv)
def log_returns(self, episode_id, reward_dict, info_dict=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.
Arguments:
episode_id (str): Episode id returned from start_episode().
reward_dict (dict): Reward from the environment agents.
info_dict (dict): Optional info dict.
"""
episode = self._get(episode_id)
# accumulate rewards by agent
# for agents that already have a reward pending this step, add to it
for agent, rew in reward_dict.items():
if agent in episode.cur_reward_dict:
episode.cur_reward_dict[agent] += rew
else:
episode.cur_reward_dict[agent] = rew
if info_dict:
episode.cur_info_dict = info_dict or {}
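
A small illustration of the per-agent accumulation above (the env handle, episode id, and values are hypothetical):

# Hedged sketch: rewards logged between two actions accumulate per agent.
# Assumes `serving_env` is a running ExternalMultiAgentEnv subclass and
# `episode_id` was returned by its start_episode().
serving_env.log_returns(episode_id, {"agent_0": 1.0, "agent_1": 0.5})
serving_env.log_returns(episode_id, {"agent_0": 0.25})
# cur_reward_dict for this episode is now {"agent_0": 1.25, "agent_1": 0.5};
# it is cleared when the episode sends its next data item to the sampler.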
@PublicAPI
@override(ExternalEnv)
def end_episode(self, episode_id, observation_dict):
"""Record the end of an episode.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation_dict (dict): Current environment observation.
"""
episode = self._get(episode_id)
self._finished.add(episode.episode_id)
episode.done(observation_dict)

@@ -13,6 +13,7 @@ from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
@@ -308,12 +309,14 @@ class PolicyEvaluator(EvaluatorInterface):
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
if not (isinstance(self.env, MultiAgentEnv)
if not ((isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):
raise ValueError(
"Have multiple policy graphs {}, but the env ".format(
self.policy_map) +
"{} is not a subclass of MultiAgentEnv?".format(self.env))
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
"ExternalMultiAgentEnv?".format(self.env))
self.filters = {
policy_id: get_filter(observation_filter,

@@ -18,22 +18,32 @@ from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph,
from ray.tune.registry import register_env
class SimpleServing(ExternalEnv):
def __init__(self, env):
ExternalEnv.__init__(self, env.action_space, env.observation_space)
self.env = env
def make_simple_serving(multiagent, superclass):
class SimpleServing(superclass):
def __init__(self, env):
superclass.__init__(self, env.action_space, env.observation_space)
self.env = env
def run(self):
eid = self.start_episode()
obs = self.env.reset()
while True:
action = self.get_action(eid, obs)
obs, reward, done, info = self.env.step(action)
self.log_returns(eid, reward, info=info)
if done:
self.end_episode(eid, obs)
obs = self.env.reset()
eid = self.start_episode()
def run(self):
eid = self.start_episode()
obs = self.env.reset()
while True:
action = self.get_action(eid, obs)
obs, reward, done, info = self.env.step(action)
if multiagent:
self.log_returns(eid, reward)
else:
self.log_returns(eid, reward, info=info)
if done:
self.end_episode(eid, obs)
obs = self.env.reset()
eid = self.start_episode()
return SimpleServing
# generate & register SimpleServing class
SimpleServing = make_simple_serving(False, ExternalEnv)
class PartOffPolicyServing(ExternalEnv):

@@ -0,0 +1,92 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import numpy as np
import random
import unittest
import ray
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
from ray.rllib.tests.test_external_env import make_simple_serving
from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole
from ray.rllib.evaluation.metrics import collect_metrics
SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
class TestExternalMultiAgentEnv(unittest.TestCase):
def testExternalMultiAgentEnvCompleteEpisodes(self):
agents = 4
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph=MockPolicyGraph,
batch_steps=40,
batch_mode="complete_episodes")
for _ in range(3):
batch = ev.sample()
self.assertEqual(batch.count, 40)
self.assertEqual(len(np.unique(batch["agent_index"])), agents)
def testExternalMultiAgentEnvTruncateEpisodes(self):
agents = 4
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph=MockPolicyGraph,
batch_steps=40,
batch_mode="truncate_episodes")
for _ in range(3):
batch = ev.sample()
self.assertEqual(batch.count, 160)
self.assertEqual(len(np.unique(batch["agent_index"])), agents)
def testExternalMultiAgentEnvSample(self):
agents = 2
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)
def testTrainExternalMultiCartpoleManyPolicies(self):
n = 20
single_env = gym.make("CartPole-v0")
act_space = single_env.action_space
obs_space = single_env.observation_space
policies = {}
for i in range(20):
policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
{})
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [], {})
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
print("Iteration {}, rew {}".format(i,
result["policy_reward_mean"]))
print("Total reward", result["episode_reward_mean"])
if result["episode_reward_mean"] >= 25 * n:
return
raise Exception("failed to improve reward")
if __name__ == "__main__":
ray.init()
unittest.main(verbosity=2)