[RLlib] Re-do: Trainer: Support add and delete Policies. (#16569)

Sven Mika 2021-06-21 13:46:01 +02:00 committed by GitHub
parent 4da69174c8
commit be6db06485
41 changed files with 1076 additions and 313 deletions


@ -495,6 +495,14 @@ py_test(
# Tag: agents_dir
# --------------------------------------------------------------------
# Generic (all Trainers)
py_test(
name = "test_trainer",
tags = ["agents_dir"],
size = "medium",
srcs = ["agents/tests/test_trainer.py"]
)
# A2/3CTrainer
py_test(
name = "test_a2c",
@ -1228,13 +1236,6 @@ py_test(
# Tag: evaluation
# --------------------------------------------------------------------
py_test(
name = "evaluation/tests/test_evaluation",
tags = ["evaluation"],
size = "medium",
srcs = ["evaluation/tests/test_evaluation.py"]
)
py_test(
name = "evaluation/tests/test_rollout_worker",
tags = ["evaluation"],
@ -2366,6 +2367,26 @@ py_test(
args = ["--as-test", "--framework=torch"],
)
# Deactivated for now due to open-spiel's dependency on an outdated
# tensorflow-probability version.
# py_test(
# name = "examples/self_play_with_open_spiel_connect_4",
# main = "examples/self_play_with_open_spiel_connect_4.py",
# tags = ["examples", "examples_S"],
# size = "medium",
# srcs = ["examples/self_play_with_open_spiel_connect_4.py"],
# args = ["--framework=tf", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
# )
# py_test(
# name = "examples/self_play_with_open_spiel_connect_4_torch",
# main = "examples/self_play_with_open_spiel_connect_4.py",
# tags = ["examples", "examples_S"],
# size = "medium",
# srcs = ["examples/self_play_with_open_spiel_connect_4.py"],
# args = ["--framework=torch", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
# )
py_test(
name = "examples/trajectory_view_api_tf",
main = "examples/trajectory_view_api.py",


@ -23,8 +23,6 @@ torch, _ = try_import_torch()
class TestDDPG(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
np.random.seed(42)
torch.manual_seed(42)
ray.init()
@classmethod
@ -34,6 +32,7 @@ class TestDDPG(unittest.TestCase):
def test_ddpg_compilation(self):
"""Test whether a DDPGTrainer can be built with both frameworks."""
config = ddpg.DEFAULT_CONFIG.copy()
config["seed"] = 42
config["num_workers"] = 1
config["num_envs_per_worker"] = 2
config["learning_starts"] = 0


@ -0,0 +1,143 @@
import gym
import unittest
import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.pg as pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import framework_iterator
class TestTrainer(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_add_delete_policy(self):
env = gym.make("CartPole-v0")
config = pg.DEFAULT_CONFIG.copy()
config.update({
"env": MultiAgentCartPole,
"env_config": {
"config": {
"num_agents": 4,
},
},
"multiagent": {
# Start with a single policy.
"policies": {
"p0": (None, env.observation_space, env.action_space, {}),
},
"policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
},
})
# TODO: (sven): Fix TrainTFMultiGPU to be flexible wrt adding policies
# on-the-fly.
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
trainer = pg.PGTrainer(config=config)
# Train once with the initial, single policy ("p0").
r0 = trainer.train()
self.assertTrue("p0" in r0["policy_reward_min"])
for i in range(1, 4):
# Add a new policy.
new_pol = trainer.add_policy(
f"p{i}",
trainer._policy_class,
observation_space=env.observation_space,
action_space=env.action_space,
config={},
# Test changing the mapping fn.
policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i}",
# Change the list of policies to train.
policies_to_train=[f"p{i}"],
)
pol_map = trainer.workers.local_worker().policy_map
self.assertTrue(new_pol is not trainer.get_policy("p0"))
self.assertTrue("p0" in pol_map)
self.assertTrue("p1" in pol_map)
self.assertTrue(len(pol_map) == i + 1)
r = trainer.train()
self.assertTrue("p1" in r["policy_reward_min"])
# Delete all added policies again from trainer.
for i in range(3, 0, -1):
trainer.remove_policy(
f"p{i}",
policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
policies_to_train=[f"p{i - 1}"])
trainer.stop()
def test_evaluation_option(self):
config = dqn.DEFAULT_CONFIG.copy()
config.update({
"env": "CartPole-v0",
"evaluation_interval": 2,
"evaluation_num_episodes": 2,
"evaluation_config": {
"gamma": 0.98,
}
})
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQNTrainer(config=config)
# Given evaluation_interval=2, r0, r2, r4 should not contain
# evaluation metrics, while r1, r3 should.
r0 = trainer.train()
print(r0)
r1 = trainer.train()
print(r1)
r2 = trainer.train()
print(r2)
r3 = trainer.train()
print(r3)
trainer.stop()
self.assertFalse("evaluation" in r0)
self.assertTrue("evaluation" in r1)
self.assertFalse("evaluation" in r2)
self.assertTrue("evaluation" in r3)
self.assertTrue("episode_reward_mean" in r1["evaluation"])
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
def test_evaluation_wo_evaluation_worker_set(self):
config = a3c.DEFAULT_CONFIG.copy()
config.update({
"env": "CartPole-v0",
# Switch off evaluation (this should already be the default).
"evaluation_interval": None,
})
for _ in framework_iterator(frameworks=("tf", "torch")):
# Setup trainer w/o evaluation worker set and still call
# evaluate() -> Expect error.
trainer_wo_env_on_driver = a3c.A3CTrainer(config=config)
self.assertRaisesRegexp(
ValueError, "Cannot evaluate w/o an evaluation worker set",
trainer_wo_env_on_driver.evaluate)
trainer_wo_env_on_driver.stop()
# Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer
# can `evaluate` even though it doesn't have an evaluation-worker
# set.
config["create_env_on_driver"] = True
trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
results = trainer_w_env_on_driver.evaluate()
assert "evaluation" in results
assert "episode_reward_mean" in results["evaluation"]
trainer_w_env_on_driver.stop()
config["create_env_on_driver"] = False
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -1,6 +1,7 @@
import copy
from datetime import datetime
import functools
import gym
import logging
import math
import numpy as np
@ -31,8 +32,8 @@ from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import TrainerConfigDict, \
PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID
from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.resources import Resources
@ -905,7 +906,7 @@ class Trainer(Trainable):
"""Sync "main" weights to given WorkerSet or list of workers."""
assert worker_set is not None
# Broadcast the new policy weights to all evaluation workers.
logger.info("Synchronizing weights to evaluation workers.")
logger.info("Synchronizing weights to workers.")
weights = ray.put(self.workers.local_worker().save())
worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))
@ -1069,7 +1070,7 @@ class Trainer(Trainable):
"""Return policy for the specified id, or None.
Args:
policy_id (str): id of policy to return.
policy_id (PolicyID): ID of the policy to return.
"""
return self.workers.local_worker().get_policy(policy_id)
@ -1092,6 +1093,101 @@ class Trainer(Trainable):
"""
self.workers.local_worker().set_weights(weights)
@PublicAPI
def add_policy(
self,
policy_id: PolicyID,
policy_cls: Type[Policy],
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID],
PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> Policy:
"""Adds a new policy to this Trainer.
Args:
policy_id (PolicyID): ID of the policy to add.
policy_cls (Type[Policy]): The Policy class to use for
constructing the new Policy.
observation_space (Optional[gym.spaces.Space]): The observation
space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space
of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config overrides
for the policy to add.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies whose IDs are not in the list will not be
updated.
Returns:
Policy: The newly added policy (the copy that got added to the
local worker).
"""
def fn(worker):
# `foreach_worker` function: Adds the policy to the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(
policy_id=policy_id,
policy_cls=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
# Run foreach_worker fn on all workers (incl. evaluation workers).
self.workers.foreach_worker(fn)
if self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)
# Return newly added policy (from the local rollout worker).
return self.get_policy(policy_id)
@PublicAPI
def remove_policy(
self,
policy_id: PolicyID = DEFAULT_POLICY_ID,
*,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> None:
"""Removes a new policy from this Trainer.
Args:
policy_id (Optional[PolicyID]): ID of the policy to be removed.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies whose IDs are not in the list will not be
updated.
"""
def fn(worker):
worker.remove_policy(
policy_id=policy_id,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
self.workers.foreach_worker(fn)
if self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)
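For reference (not part of this diff), a condensed usage sketch of the new Trainer API, closely mirroring the new agents/tests/test_trainer.py above; PGTrainer and MultiAgentCartPole are simply the setup that test uses:

import gym
import ray
import ray.rllib.agents.pg as pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

ray.init()
# Only used to grab observation/action spaces for the policy spec.
env = gym.make("CartPole-v0")
config = pg.DEFAULT_CONFIG.copy()
config.update({
    "env": MultiAgentCartPole,
    "env_config": {"config": {"num_agents": 4}},
    # See the TODO in the new test above: use torch (or tf2) for
    # on-the-fly policy changes.
    "framework": "torch",
    "multiagent": {
        "policies": {
            "p0": (None, env.observation_space, env.action_space, {}),
        },
        "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
    },
})
trainer = pg.PGTrainer(config=config)

# Add a second policy on the fly; new episodes map all agents to it.
trainer.add_policy(
    "p1",
    trainer._policy_class,
    observation_space=env.observation_space,
    action_space=env.action_space,
    config={},
    policy_mapping_fn=lambda aid, episode, **kwargs: "p1",
    policies_to_train=["p1"],
)
trainer.train()

# Remove it again and fall back to "p0".
trainer.remove_policy(
    "p1",
    policy_mapping_fn=lambda aid, episode, **kwargs: "p0",
    policies_to_train=["p0"],
)
trainer.stop()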
@DeveloperAPI
def export_policy_model(self,
export_dir: str,

rllib/env/base_env.py

@ -423,10 +423,6 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
if set(obs.keys()) != set(rewards.keys()):
raise ValueError(
"Key set for obs and rewards must be the same: "
"{} vs {}".format(obs.keys(), rewards.keys()))
if set(infos).difference(set(obs)):
raise ValueError("Key set for infos must be a subset of obs: "
"{} vs {}".format(infos.keys(), obs.keys()))
@ -470,31 +466,52 @@ class _MultiAgentEnvState:
if not self.initialized:
self.reset()
self.initialized = True
obs, rew, dones, info = (self.last_obs, self.last_rewards,
self.last_dones, self.last_infos)
self.last_obs = {}
self.last_rewards = {}
self.last_dones = {"__all__": False}
observations = self.last_obs
rewards = {}
dones = {"__all__": self.last_dones["__all__"]}
infos = {}
# If episode is done, release everything we have.
if dones["__all__"]:
rewards = self.last_rewards
self.last_rewards = {}
dones = self.last_dones
self.last_dones = {}
self.last_obs = {}
# Only release those agents' rewards/dones/infos whose
# observations we have.
else:
for ag in observations.keys():
if ag in self.last_rewards:
rewards[ag] = self.last_rewards[ag]
del self.last_rewards[ag]
if ag in self.last_dones:
dones[ag] = self.last_dones[ag]
del self.last_dones[ag]
self.last_dones["__all__"] = False
self.last_infos = {}
return obs, rew, dones, info
return observations, rewards, dones, infos
def observe(self, obs: MultiAgentDict, rewards: MultiAgentDict,
dones: MultiAgentDict, infos: MultiAgentDict):
self.last_obs = obs
self.last_rewards = rewards
self.last_dones = dones
for ag, r in rewards.items():
if ag in self.last_rewards:
self.last_rewards[ag] += r
else:
self.last_rewards[ag] = r
for ag, d in dones.items():
if ag in self.last_dones:
self.last_dones[ag] = self.last_dones[ag] or d
else:
self.last_dones[ag] = d
self.last_infos = infos
def reset(self) -> MultiAgentDict:
self.last_obs = self.env.reset()
self.last_rewards = {
agent_id: None
for agent_id in self.last_obs.keys()
}
self.last_dones = {
agent_id: False
for agent_id in self.last_obs.keys()
}
self.last_infos = {agent_id: {} for agent_id in self.last_obs.keys()}
self.last_dones["__all__"] = False
self.last_rewards = {}
self.last_dones = {"__all__": False}
self.last_infos = {}
return self.last_obs
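To make the intent of the new buffering logic concrete, here is a stand-alone sketch (plain dicts, not RLlib API, simplified to the non-terminal path): rewards and done flags reported for an agent that has not returned an observation yet are accumulated and only released together with that agent's next observation.

class _Buffer:
    """Sketch of _MultiAgentEnvState's buffering semantics (simplified)."""

    def __init__(self):
        self.last_obs = {}
        self.last_rewards = {}
        self.last_dones = {"__all__": False}

    def observe(self, obs, rewards, dones):
        self.last_obs = obs
        for ag, r in rewards.items():
            # Accumulate rewards until the agent's next observation arrives.
            self.last_rewards[ag] = self.last_rewards.get(ag, 0.0) + r
        for ag, d in dones.items():
            self.last_dones[ag] = self.last_dones.get(ag, False) or d

    def poll(self):
        obs = self.last_obs
        rewards, dones = {}, {"__all__": self.last_dones["__all__"]}
        # Only release rewards/dones of agents for which we have an obs.
        for ag in obs:
            if ag in self.last_rewards:
                rewards[ag] = self.last_rewards.pop(ag)
            if ag in self.last_dones:
                dones[ag] = self.last_dones.pop(ag)
        self.last_obs = {}
        return obs, rewards, dones


buf = _Buffer()
buf.observe({"a": 1}, {"a": 0.5, "b": 1.0}, {})
buf.observe({"a": 2}, {"a": 0.5, "b": 1.0}, {})
# -> ({'a': 2}, {'a': 1.0}, {'__all__': False}); agent b's accumulated 2.0
#    stays buffered until b returns an observation.
print(buf.poll())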

rllib/env/wrappers/open_spiel.py (new file)

@ -0,0 +1,75 @@
from gym.spaces import Box, Discrete
import numpy as np
import pyspiel
from ray.rllib.env.multi_agent_env import MultiAgentEnv
class OpenSpielEnv(MultiAgentEnv):
def __init__(self, env):
self.env = env
# Agent IDs are ints, starting from 0.
self.num_agents = self.env.num_players()
# Store the open-spiel game type.
self.type = self.env.get_type()
# Stores the current open-spiel game state.
self.state = None
# Extract observation- and action spaces from game.
self.observation_space = Box(
float("-inf"), float("inf"),
(self.env.observation_tensor_size(), ))
self.action_space = Discrete(self.env.num_distinct_actions())
def reset(self):
self.state = self.env.new_initial_state()
return self._get_obs()
def step(self, action):
# Sequential game:
if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
curr_player = self.state.current_player()
assert curr_player in action
penalty = None
try:
self.state.apply_action(action[curr_player])
# TODO: (sven) resolve this hack by publishing legal actions
# with each step.
except pyspiel.SpielError:
self.state.apply_action(
np.random.choice(self.state.legal_actions()))
penalty = -0.1
# Are we done?
is_done = self.state.is_terminal()
dones = dict({ag: is_done
for ag in range(self.num_agents)},
**{"__all__": is_done})
# Compile rewards dict.
rewards = {ag: r for ag, r in enumerate(self.state.returns())}
if penalty:
rewards[curr_player] += penalty
# Simultaneous game.
else:
raise NotImplementedError
return self._get_obs(), rewards, dones, {}
def _get_obs(self):
curr_player = self.state.current_player()
# Sequential game:
if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
if self.state.is_terminal():
return {}
return {
curr_player: np.reshape(self.state.observation_tensor(), [-1])
}
# Simultaneous game.
else:
raise NotImplementedError
def render(self, mode=None) -> None:
if mode == "human":
print(self.state)
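A minimal usage sketch for the new wrapper (not part of this diff; assumes open_spiel is installed and uses "connect_four", the game of the new self-play example further below):

import numpy as np
import pyspiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv

env = OpenSpielEnv(pyspiel.load_game("connect_four"))
obs = env.reset()  # {current_player_id: flat observation tensor}
dones = {"__all__": False}
while not dones["__all__"]:
    # Sequential game: exactly one player is expected to act each step.
    player = next(iter(obs))
    action = {player: np.random.choice(env.state.legal_actions())}
    obs, rewards, dones, infos = env.step(action)
print(rewards)  # e.g. {0: 1.0, 1: -1.0}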


@ -313,7 +313,7 @@ class Unity3DEnv(MultiAgentEnv):
action_spaces["Striker"], {}),
}
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
return "Striker" if "Striker" in agent_id else "Goalie"
else:
@ -322,7 +322,7 @@ class Unity3DEnv(MultiAgentEnv):
action_spaces[game_name], {}),
}
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
return game_name
return policies, policy_mapping_fn


@ -627,7 +627,7 @@ class SimpleListCollector(SampleCollector):
if is_done and check_dones and \
not pre_batch[SampleBatch.DONES][-1]:
raise ValueError(
"Episode {} terminated for all agents, but we still"
"Episode {} terminated for all agents, but we still "
"don't have a last observation for agent {} (policy "
"{}). ".format(
episode_id, agent_id, self.agent_key_to_policy_id[(
@ -674,8 +674,16 @@ class SimpleListCollector(SampleCollector):
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches)
# Add the postprocessed SampleBatch to the policy collectors for
# training.
# PID may be a newly added policy. Just confirm we have it in our
# policy map before proceeding with adding a new _PolicyCollector()
# to the group.
if pid not in policy_collector_group.policy_collectors:
assert pid in self.policy_map
policy_collector_group.policy_collectors[
pid] = _PolicyCollector(policy)
policy_collector_group.policy_collectors[
pid].add_postprocessed_batch_for_training(
post_batch, policy.view_requirements)
@ -780,9 +788,18 @@ class SimpleListCollector(SampleCollector):
vectorized environments).
"""
pid = self.agent_key_to_policy_id[agent_key]
# PID may be a newly added policy. Just confirm we have it in our
# policy map before proceeding with forward_pass_size=0.
if pid not in self.forward_pass_size:
assert pid in self.policy_map
self.forward_pass_size[pid] = 0
self.forward_pass_agent_keys[pid] = []
idx = self.forward_pass_size[pid]
if idx == 0:
self.forward_pass_agent_keys[pid].clear()
self.forward_pass_agent_keys[pid].append(agent_key)
self.forward_pass_size[pid] += 1


@ -6,9 +6,11 @@ from typing import List, Dict, Callable, Any, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once
if TYPE_CHECKING:
from ray.rllib.evaluation.sample_batch_builder import \
@ -48,7 +50,8 @@ class MultiAgentEpisode:
"""
def __init__(self, policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID],
batch_builder_factory: Callable[
[], "MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],
@ -68,15 +71,17 @@ class MultiAgentEpisode:
self.user_data: Dict[str, Any] = {}
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self._policies: Dict[PolicyID, Policy] = policies
self._policy_mapping_fn: Callable[[AgentID], PolicyID] = \
policy_mapping_fn
self.policy_map: Dict[PolicyID, Policy] = policies
self._policies = self.policy_map # backward compatibility
self._policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID] = policy_mapping_fn
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_done: Dict[AgentID, bool] = {}
self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
self._agent_to_last_pi_info: Dict[AgentID, dict] = {}
@ -112,8 +117,29 @@ class MultiAgentEpisode:
"""
if agent_id not in self._agent_to_policy:
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
# Try new API: pass in agent_id and episode as named args.
# New signature should be: (agent_id, episode, **kwargs)
try:
policy_id = self._agent_to_policy[agent_id] = \
self._policy_mapping_fn(agent_id, self)
except TypeError as e:
if "positional argument" in e.args[0] or \
"unexpected keyword argument" in e.args[0]:
if log_once("policy_mapping_new_signature"):
deprecation_warning(
old="policy_mapping_fn(agent_id)",
new="policy_mapping_fn(agent_id, episode, "
"**kwargs)")
policy_id = self._agent_to_policy[agent_id] = \
self._policy_mapping_fn(agent_id)
else:
raise e
else:
policy_id = self._agent_to_policy[agent_id]
if policy_id not in self.policy_map:
raise KeyError("policy_mapping_fn returned invalid policy id "
f"'{policy_id}'!")
return policy_id
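For users, the upshot of this new lookup is that both mapping-function signatures keep working; a sketch ("main"/"opponent" are made-up policy IDs):

# Old 1-arg signature: still accepted, but now logs a deprecation warning.
def old_mapping_fn(agent_id):
    return "main" if agent_id == 0 else "opponent"

# New signature: also receives the MultiAgentEpisode (plus **kwargs for
# forward compatibility), so the mapping can depend on the episode itself.
def new_mapping_fn(agent_id, episode, **kwargs):
    return "main" if episode.episode_id % 2 == agent_id else "opponent"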
@DeveloperAPI
def last_observation_for(
@ -145,7 +171,8 @@ class MultiAgentEpisode:
return flatten_to_single_ndarray(
self._agent_to_last_action[agent_id])
else:
policy = self._policies[self.policy_for(agent_id)]
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
flat = flatten_to_single_ndarray(policy.action_space.sample())
if hasattr(policy.action_space, "dtype"):
return np.zeros_like(flat, dtype=policy.action_space.dtype)
@ -179,16 +206,34 @@ class MultiAgentEpisode:
"""Returns the last RNN state for the specified agent."""
if agent_id not in self._agent_to_rnn_state:
policy = self._policies[self.policy_for(agent_id)]
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
"""Returns the last done flag received for the specified agent."""
if agent_id not in self._agent_to_last_done:
self._agent_to_last_done[agent_id] = False
return self._agent_to_last_done[agent_id]
@DeveloperAPI
def last_pi_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> dict:
"""Returns the last info object for the specified agent."""
return self._agent_to_last_pi_info[agent_id]
@DeveloperAPI
def get_agents(self) -> List[AgentID]:
"""Returns list of agent IDs that have appeared in this episode.
Returns:
List[AgentID]: The list of all agents that have appeared so
far in this episode.
"""
return list(self._agent_to_index.keys())
def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
for agent_id, reward in reward_dict.items():
if reward is not None:
@ -206,6 +251,9 @@ class MultiAgentEpisode:
def _set_last_raw_obs(self, agent_id, obs):
self._agent_to_last_raw_obs[agent_id] = obs
def _set_last_done(self, agent_id, done):
self._agent_to_last_done[agent_id] = done
def _set_last_info(self, agent_id, info):
self._agent_to_last_info[agent_id] = info


@ -47,6 +47,7 @@ from ray.util.debug import log_once, disable_log_once_globally, \
from ray.util.iter import ParallelIteratorWorker
if TYPE_CHECKING:
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.observation_function import ObservationFunction
from ray.rllib.agents.callbacks import DefaultCallbacks # noqa
@ -106,7 +107,7 @@ class RolloutWorker(ParallelIteratorWorker):
... "traffic_light_policy":
... (PGTFPolicy, Box(...), Discrete(...), {}),
... },
... policy_mapping_fn=lambda agent_id:
... policy_mapping_fn=lambda agent_id, episode, **kwargs:
... random.choice(["car_policy1", "car_policy2"])
... if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(worker.sample())
@ -141,7 +142,8 @@ class RolloutWorker(ParallelIteratorWorker):
policy_spec: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
rollout_fragment_length: int = 100,
@ -200,12 +202,12 @@ class RolloutWorker(ParallelIteratorWorker):
dict is specified, then we are in multi-agent mode and a
policy_mapping_fn can also be set (if not, will map all agents
to DEFAULT_POLICY_ID).
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
callable that maps agent ids to policy ids in multi-agent mode.
This function will be called each time a new agent appears in
an episode, to bind that agent to a policy for the duration of
the episode. If not provided, will map all agents to
DEFAULT_POLICY_ID.
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
PolicyID]]): A callable that maps agent ids to policy ids in
multi-agent mode. This function will be called each time a new
agent appears in an episode, to bind that agent to a policy
for the duration of the episode. If not provided, will map all
agents to DEFAULT_POLICY_ID.
policies_to_train (Optional[List[PolicyID]]): Optional list of
policies to train, or None for all policies.
tf_session_creator (Optional[Callable[[], tf1.Session]]): A
@ -361,16 +363,19 @@ class RolloutWorker(ParallelIteratorWorker):
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
policy_mapping_fn = (policy_mapping_fn
or (lambda agent_id: DEFAULT_POLICY_ID))
if not callable(policy_mapping_fn):
raise ValueError("Policy mapping function not callable?")
# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
# independent of the agent ID and the episode passed in.
self.policy_mapping_fn = lambda aid, ep, **kwargs: DEFAULT_POLICY_ID
self.set_policy_mapping_fn(policy_mapping_fn)
self.env_creator: Callable[[EnvContext], EnvType] = env_creator
self.rollout_fragment_length: int = rollout_fragment_length * num_envs
self.count_steps_by: str = count_steps_by
self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = True
self.observation_filter = observation_filter
self.last_batch: SampleBatchType = None
self.global_vars: dict = None
self.fake_sampler: bool = fake_sampler
@ -472,8 +477,11 @@ class RolloutWorker(ParallelIteratorWorker):
self.tf_sess = None
policy_dict = _validate_and_canonicalize(
policy_spec, self.env, spaces=spaces)
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
# List of IDs of those policies, which should be trained.
# By default, these are all policies found in the policy_dict.
self.policies_to_train: List[PolicyID] = list(policy_dict.keys())
self.set_policies_to_train(policies_to_train)
self.policy_map: Dict[PolicyID, Policy] = None
self.preprocessors: Dict[PolicyID, Preprocessor] = None
@ -570,7 +578,7 @@ class RolloutWorker(ParallelIteratorWorker):
"ExternalMultiAgentEnv?".format(self.env))
self.filters: Dict[PolicyID, Filter] = {
policy_id: get_filter(observation_filter,
policy_id: get_filter(self.observation_filter,
policy.observation_space.shape)
for (policy_id, policy) in self.policy_map.items()
}
@ -641,10 +649,6 @@ class RolloutWorker(ParallelIteratorWorker):
self.sampler = AsyncSampler(
worker=self,
env=self.async_env,
policies=self.policy_map,
policy_mapping_fn=policy_mapping_fn,
preprocessors=self.preprocessors,
obs_filters=self.filters,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
@ -667,10 +671,6 @@ class RolloutWorker(ParallelIteratorWorker):
self.sampler = SyncSampler(
worker=self,
env=self.async_env,
policies=self.policy_map,
policy_mapping_fn=policy_mapping_fn,
preprocessors=self.preprocessors,
obs_filters=self.filters,
clip_rewards=clip_rewards,
rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
@ -1009,6 +1009,124 @@ class RolloutWorker(ParallelIteratorWorker):
return self.policy_map.get(policy_id)
@DeveloperAPI
def add_policy(
self,
*,
policy_id: PolicyID,
policy_cls: Type[Policy],
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> Policy:
"""Adds a new policy to this RolloutWorker.
Args:
policy_id (PolicyID): ID of the policy to add.
policy_cls (Type[Policy]): The Policy class to use for
constructing the new Policy.
observation_space (Optional[gym.spaces.Space]): The observation
space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space
of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config overrides
for the policy to add.
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
PolicyID]]): An optional (updated) policy mapping function to
use from here on. Note that already ongoing episodes will not
change their mapping but will use the old mapping till the
end of the episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies whose IDs are not in the list will not be
updated.
Returns:
Policy: The newly added policy.
"""
if policy_id in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
policy_dict = {
policy_id: (policy_cls, observation_space, action_space, config)
}
add_map, add_prep = self._build_policy_map(policy_dict,
self.policy_config)
new_policy = add_map[policy_id]
self.policy_map.update(add_map)
self.preprocessors.update(add_prep)
self.filters[policy_id] = get_filter(
self.observation_filter, new_policy.observation_space.shape)
self.set_policy_mapping_fn(policy_mapping_fn)
self.set_policies_to_train(policies_to_train)
return new_policy
@DeveloperAPI
def remove_policy(
self,
*,
policy_id: PolicyID = DEFAULT_POLICY_ID,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
):
"""Removes a policy from this RolloutWorker.
Args:
policy_id (Optional[PolicyID]): ID of the policy to be removed.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies whose IDs are not in the list will not be
updated.
"""
if policy_id not in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
del self.policy_map[policy_id]
del self.preprocessors[policy_id]
self.set_policy_mapping_fn(policy_mapping_fn)
self.set_policies_to_train(policies_to_train)
@DeveloperAPI
def set_policy_mapping_fn(
self,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
):
"""Sets `self.policy_mapping_fn` to a new callable (if provided).
Args:
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): The
new mapping function to use. If None, will keep the existing
mapping function in place.
"""
if policy_mapping_fn is not None:
self.policy_mapping_fn = policy_mapping_fn
if not callable(self.policy_mapping_fn):
raise ValueError("`policy_mapping_fn` must be a callable!")
@DeveloperAPI
def set_policies_to_train(
self, policies_to_train: Optional[List[PolicyID]] = None):
"""Sets `self.policies_to_train` to a new list of PolicyIDs.
Args:
policies_to_train (Optional[List[PolicyID]]): The new
list of policy IDs to train with. If None, will keep the
existing list in place.
"""
if policies_to_train is not None:
self.policies_to_train = policies_to_train
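These two setters can also be called on their own, e.g. from a Trainer, to re-route agents or change the trained set on all workers without adding or removing policies. A sketch ("main"/"opponent" are hypothetical policy IDs):

# Re-route every agent to "opponent" and only keep training "main",
# applied to the local worker and all remote workers.
def update_fn(worker):
    worker.set_policy_mapping_fn(lambda aid, episode, **kwargs: "opponent")
    worker.set_policies_to_train(["main"])

trainer.workers.foreach_worker(update_fn)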
@DeveloperAPI
def for_policy(self,
func: Callable[[Policy], T],
@ -1091,6 +1209,12 @@ class RolloutWorker(ParallelIteratorWorker):
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
if pid not in self.policy_map:
logger.warning(
f"pid={pid} not found in policy_map! It was probably added"
" on-the-fly and is not part of the static `config."
"multiagent.policies` dict. Ignoring it for now.")
continue
self.policy_map[pid].set_state(state)
@DeveloperAPI
@ -1184,7 +1308,9 @@ class RolloutWorker(ParallelIteratorWorker):
else:
raise ValueError("This policy does not support eager "
"execution: {}".format(cls))
with tf1.variable_scope(name):
scope = name + (("_wk" + str(self.worker_index))
if self.worker_index else "")
with tf1.variable_scope(scope):
policy_map[name] = cls(obs_space, act_space, merged_conf)
# non-tf.
else:
@ -1238,7 +1364,7 @@ def _validate_and_canonicalize(
_validate_multiagent_config(policy)
return policy
elif not issubclass(policy, Policy):
raise ValueError("policy must be a rllib.Policy class")
raise ValueError(f"`policy` ({policy}) must be a rllib.Policy class!")
else:
if (isinstance(env, MultiAgentEnv)
and not hasattr(env, "observation_space")):


@ -25,6 +25,7 @@ from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import clip_action, Policy
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import unbatch
@ -129,10 +130,6 @@ class SyncSampler(SamplerInput):
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
@ -146,6 +143,11 @@ class SyncSampler(SamplerInput):
observation_fn: "ObservationFunction" = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
):
"""Initializes a SyncSampler object.
@ -153,13 +155,6 @@ class SyncSampler(SamplerInput):
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
policy_mapping_fn (callable): Callable that takes an agent ID and
returns a Policy object.
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
Preprocessor object for the observations prior to filtering.
obs_filters (Dict[str,Filter]): Mapping from policy ID to
env Filter object.
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
float value for +/- value clipping. False for no clipping.
rollout_fragment_length (int): The length of a fragment to collect
@ -189,20 +184,27 @@ class SyncSampler(SamplerInput):
render (bool): Whether to try to render the environment after each
step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
if log_once("deprecated_sync_sampler_args"):
if policies is not None:
deprecation_warning(old="policies")
if policy_mapping_fn is not None:
deprecation_warning(old="policy_mapping_fn")
if preprocessors is not None:
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
self.base_env = BaseEnv.to_base_env(env)
self.rollout_fragment_length = rollout_fragment_length
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self.preprocessors = preprocessors
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats()
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
policies,
worker.policy_map,
clip_rewards,
callbacks,
multiple_episodes_in_batch,
@ -212,11 +214,10 @@ class SyncSampler(SamplerInput):
# Create the rollout generator to use for calls to `get_data()`.
self.rollout_provider = _env_runner(
worker, self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
multiple_episodes_in_batch, callbacks, tf_sess, self.perf_stats,
soft_horizon, no_done_at_end, observation_fn,
worker, self.base_env, self.extra_batches.put,
self.rollout_fragment_length, self.horizon, clip_rewards,
clip_actions, multiple_episodes_in_batch, callbacks, tf_sess,
self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
self.sample_collector, self.render)
self.metrics_queue = queue.Queue()
@ -264,10 +265,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
*,
worker: "RolloutWorker",
env: BaseEnv,
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
@ -282,6 +279,11 @@ class AsyncSampler(threading.Thread, SamplerInput):
observation_fn: "ObservationFunction" = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
):
"""Initializes a AsyncSampler object.
@ -289,13 +291,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
worker (RolloutWorker): The RolloutWorker that will use this
Sampler for sampling.
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
policies (Dict[str, Policy]): Mapping from policy ID to Policy obj.
policy_mapping_fn (callable): Callable that takes an agent ID and
returns a Policy object.
preprocessors (Dict[str, Preprocessor]): Mapping from policy ID to
Preprocessor object for the observations prior to filtering.
obs_filters (Dict[str, Filter]): Mapping from policy ID to
env Filter object.
clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
float value for +/- value clipping. False for no clipping.
rollout_fragment_length (int): The length of a fragment to collect
@ -329,10 +324,24 @@ class AsyncSampler(threading.Thread, SamplerInput):
render (bool): Whether to try to render the environment after each
step.
"""
for _, f in obs_filters.items():
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
if log_once("deprecated_async_sampler_args"):
if policies is not None:
deprecation_warning(old="policies")
if policy_mapping_fn is not None:
deprecation_warning(old="policy_mapping_fn")
if preprocessors is not None:
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
self.worker = worker
for _, f in worker.filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.worker = worker
self.base_env = BaseEnv.to_base_env(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
@ -340,10 +349,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.metrics_queue = queue.Queue()
self.rollout_fragment_length = rollout_fragment_length
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self.preprocessors = preprocessors
self.obs_filters = obs_filters
self.clip_rewards = clip_rewards
self.daemon = True
self.multiple_episodes_in_batch = multiple_episodes_in_batch
@ -360,7 +365,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
policies,
worker.policy_map,
clip_rewards,
callbacks,
multiple_episodes_in_batch,
@ -384,9 +389,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
extra_batches_putter = (
lambda x: self.extra_batches.put(x, timeout=600.0))
rollout_provider = _env_runner(
self.worker, self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.worker, self.base_env, extra_batches_putter,
self.rollout_fragment_length, self.horizon, self.clip_rewards,
self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
self.tf_sess, self.perf_stats, self.soft_horizon,
self.no_done_at_end, self.observation_fn, self.sample_collector,
@ -439,12 +443,8 @@ def _env_runner(
worker: "RolloutWorker",
base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None],
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID], PolicyID],
rollout_fragment_length: int,
horizon: int,
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
clip_rewards: bool,
clip_actions: bool,
multiple_episodes_in_batch: bool,
@ -463,19 +463,10 @@ def _env_runner(
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
rollout_fragment_length (int): Number of episode steps before
`SampleBatch` is yielded. Set to infinity to yield complete
episodes.
horizon (int): Horizon of the episode.
preprocessors (dict): Map of policy id to preprocessor for the
observations prior to filtering.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
clip_rewards (bool): Whether to clip rewards before postprocessing.
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
@ -551,14 +542,14 @@ def _env_runner(
def new_episode(env_id):
episode = MultiAgentEpisode(
policies,
policy_mapping_fn,
worker.policy_map,
worker.policy_mapping_fn,
get_batch_builder,
extra_batch_callback,
env_id=env_id)
# Call each policy's Exploration.on_episode_start method.
# type: Policy
for p in policies.values():
for p in worker.policy_map.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
@ -568,7 +559,7 @@ def _env_runner(
callbacks.on_episode_start(
worker=worker,
base_env=base_env,
policies=policies,
policies=worker.policy_map,
episode=episode,
env_index=env_id,
)
@ -599,15 +590,12 @@ def _env_runner(
_process_observations(
worker=worker,
base_env=base_env,
policies=policies,
active_episodes=active_episodes,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
dones=dones,
infos=infos,
horizon=horizon,
preprocessors=preprocessors,
obs_filters=obs_filters,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
soft_horizon=soft_horizon,
@ -624,7 +612,8 @@ def _env_runner(
# type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = _do_policy_eval(
to_eval=to_eval,
policies=policies,
policies=worker.policy_map,
policy_mapping_fn=worker.policy_mapping_fn,
sample_collector=sample_collector,
active_episodes=active_episodes,
tf_sess=tf_sess,
@ -640,7 +629,7 @@ def _env_runner(
active_episodes=active_episodes,
active_envs=active_envs,
off_policy_actions=off_policy_actions,
policies=policies,
policies=worker.policy_map,
clip_actions=clip_actions,
)
perf_stats.action_processing_time += time.time() - t3
@ -685,15 +674,12 @@ def _process_observations(
*,
worker: "RolloutWorker",
base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
active_episodes: Dict[str, MultiAgentEpisode],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]],
dones: Dict[EnvID, Dict[AgentID, bool]],
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
horizon: int,
preprocessors: Dict[PolicyID, Preprocessor],
obs_filters: Dict[PolicyID, Filter],
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
soft_horizon: bool,
@ -707,7 +693,6 @@ def _process_observations(
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): Env implementing BaseEnv.
policies (dict): Map of policy ids to Policy instances.
batch_builder_pool (List[SampleBatchBuilder]): List of pooled
SampleBatchBuilder object for recycling.
active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
@ -722,10 +707,6 @@ def _process_observations(
infos (dict): Doubly keyed dict of env-ids -> agent ids ->
info dicts, returned by a `BaseEnv.poll()` call.
horizon (int): Horizon of the episode.
preprocessors (dict): Map of policy id to preprocessor for the
observations prior to filtering.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
rollout_fragment_length (int): Number of episode steps before
`SampleBatch` is yielded. Set to infinity to yield complete
episodes.
@ -781,6 +762,17 @@ def _process_observations(
dict(episode.agent_rewards),
episode.custom_metrics, {},
episode.hist_data, episode.media))
# Check whether we have to create a fake-last observation
# for some agents (the environment is not required to do so if
# dones[__all__]=True).
for ag_id in episode.get_agents():
if not episode.last_done_for(
ag_id) and ag_id not in all_agents_obs:
# Create a fake (all-0s) observation.
obs_sp = worker.policy_map[episode.policy_for(
ag_id)].observation_space
obs_sp = getattr(obs_sp, "original_space", obs_sp)
all_agents_obs[ag_id] = np.zeros_like(obs_sp.sample())
else:
hit_horizon = False
all_agents_done = False
@ -792,7 +784,7 @@ def _process_observations(
agent_obs=all_agents_obs,
worker=worker,
base_env=base_env,
policies=policies,
policies=worker.policy_map,
episode=episode)
if not isinstance(all_agents_obs, dict):
raise ValueError(
@ -813,17 +805,18 @@ def _process_observations(
policy_id: PolicyID = episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(preprocessors,
prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
policy_id).transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
filtered_obs: EnvObsType = _get_or_raise(obs_filters,
filtered_obs: EnvObsType = _get_or_raise(worker.filters,
policy_id)(prep_obs)
if log_once("filtered_obs"):
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_done(agent_id, agent_done)
# Infos from the environment.
agent_infos = infos[env_id].get(agent_id, {})
episode._set_last_info(agent_id, agent_infos)
@ -842,7 +835,7 @@ def _process_observations(
# Action (slot 0) taken at timestep t.
"actions": episode.last_action_for(agent_id),
# Reward received after taking a at timestep t.
"rewards": rewards[env_id][agent_id],
"rewards": rewards[env_id].get(agent_id, 0.0),
# After taking action=a, did we reach terminal?
"dones": (False if (no_done_at_end
or (hit_horizon and soft_horizon)) else
@ -851,7 +844,7 @@ def _process_observations(
"new_obs": filtered_obs,
}
# Add extra-action-fetches to collectors.
pol = policies[policy_id]
pol = worker.policy_map[policy_id]
for key, value in episode.last_pi_info_for(agent_id).items():
if key in pol.view_requirements:
values_dict[key] = value
@ -868,8 +861,8 @@ def _process_observations(
if last_observation is None else
episode.rnn_state_for(agent_id), None
if last_observation is None else
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0)
episode.last_action_for(agent_id), rewards[env_id].get(
agent_id, 0.0))
to_eval[policy_id].append(item)
# Invoke the `on_episode_step` callback after the step is logged
@ -904,7 +897,7 @@ def _process_observations(
outputs.append(ma_sample_batch)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
for p in worker.policy_map.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
@ -915,7 +908,7 @@ def _process_observations(
callbacks.on_episode_end(
worker=worker,
base_env=base_env,
policies=policies,
policies=worker.policy_map,
episode=episode,
env_index=env_id,
)
@ -942,15 +935,15 @@ def _process_observations(
agent_obs=resetted_obs,
worker=worker,
base_env=base_env,
policies=policies,
policies=worker.policy_map,
episode=new_episode)
# type: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(
preprocessors, policy_id).transform(raw_obs)
worker.preprocessors, policy_id).transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(
obs_filters, policy_id)(prep_obs)
worker.filters, policy_id)(prep_obs)
new_episode._set_last_observation(agent_id, filtered_obs)
# Add initial obs to buffer.
@ -978,6 +971,7 @@ def _do_policy_eval(
*,
to_eval: Dict[PolicyID, List[PolicyEvalData]],
policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
sample_collector,
active_episodes: Dict[str, MultiAgentEpisode],
tf_sess: Optional["tf.Session"] = None,
@ -1011,7 +1005,15 @@ def _do_policy_eval(
summarize(to_eval)))
for policy_id, eval_data in to_eval.items():
policy: Policy = _get_or_raise(policies, policy_id)
# In case the policyID has been removed from this worker, we need to
# re-assign policy_id and re-lookup the Policy object to use.
try:
policy: Policy = _get_or_raise(policies, policy_id)
except ValueError:
policy_id = policy_mapping_fn(eval_data[0].agent_id,
active_episodes[eval_data[0].env_id])
policy: Policy = _get_or_raise(policies, policy_id)
input_dict = sample_collector.get_inference_input_dict(policy_id)
eval_results[policy_id] = \
policy.compute_actions_from_input_dict(
@ -1163,6 +1165,7 @@ def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
"""
if policy_id not in mapping:
raise ValueError(
"Could not find policy for agent: agent policy id `{}` not "
"in policy map keys {}.".format(policy_id, mapping.keys()))
"Could not find policy for agent: PolicyID `{}` not found "
"in policy map, whose keys are `{}`.".format(
policy_id, mapping.keys()))
return mapping[policy_id]


@ -1,82 +0,0 @@
import unittest
import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator
class TestTrainerEvaluation(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_evaluation_option(self):
config = dqn.DEFAULT_CONFIG.copy()
config.update({
"env": "CartPole-v0",
"evaluation_interval": 2,
"evaluation_num_episodes": 2,
"evaluation_config": {
"gamma": 0.98,
}
})
for _ in framework_iterator(config, frameworks=("tf", "torch")):
agent = dqn.DQNTrainer(config=config)
# Given evaluation_interval=2, r0, r2, r4 should not contain
# evaluation metrics, while r1, r3 should.
r0 = agent.train()
print(r0)
r1 = agent.train()
print(r1)
r2 = agent.train()
print(r2)
r3 = agent.train()
print(r3)
agent.stop()
self.assertFalse("evaluation" in r0)
self.assertTrue("evaluation" in r1)
self.assertFalse("evaluation" in r2)
self.assertTrue("evaluation" in r3)
self.assertTrue("episode_reward_mean" in r1["evaluation"])
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
def test_evaluation_wo_evaluation_worker_set(self):
config = a3c.DEFAULT_CONFIG.copy()
config.update({
"env": "CartPole-v0",
# Switch off evaluation (this should already be the default).
"evaluation_interval": None,
})
for _ in framework_iterator(frameworks=("tf", "torch")):
# Setup trainer w/o evaluation worker set and still call
# evaluate() -> Expect error.
agent_wo_env_on_driver = a3c.A3CTrainer(config=config)
self.assertRaisesRegexp(
ValueError, "Cannot evaluate w/o an evaluation worker set",
agent_wo_env_on_driver.evaluate)
agent_wo_env_on_driver.stop()
# Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer
# can `evaluate` even though, it doesn't have an evaluation-worker
# set.
config["create_env_on_driver"] = True
agent_w_env_on_driver = a3c.A3CTrainer(config=config)
results = agent_w_env_on_driver.evaluate()
assert "evaluation" in results
assert "episode_reward_mean" in results["evaluation"]
agent_w_env_on_driver.stop()
config["create_env_on_driver"] = False
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -514,7 +514,8 @@ class TestRolloutWorker(unittest.TestCase):
"pol0": (MockPolicy, obs_space, action_space, {}),
"pol1": (MockPolicy, obs_space, action_space, {}),
},
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
policy_mapping_fn=lambda agent_id, episode, **kwargs:
"pol0" if agent_id == 0 else "pol1",
rollout_fragment_length=301,
count_steps_by="env_steps",
batch_mode="truncate_episodes",
@ -531,7 +532,8 @@ class TestRolloutWorker(unittest.TestCase):
"pol0": (MockPolicy, obs_space, action_space, {}),
"pol1": (MockPolicy, obs_space, action_space, {}),
},
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
policy_mapping_fn=lambda agent_id, episode, **kwargs:
"pol0" if agent_id == 0 else "pol1",
rollout_fragment_length=301,
count_steps_by="agent_steps",
batch_mode="truncate_episodes")


@ -218,7 +218,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
}
def policy_fn(agent_id):
def policy_fn(agent_id, episode, **kwargs):
return "pol0"
config = {
@ -269,7 +269,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
{}),
}
def policy_fn(agent_id):
def policy_fn(agent_id, episode, **kwargs):
return "pol0"
config = {
@ -309,7 +309,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
"p0": (None, obs_space, action_space, {}),
"p1": (None, obs_space, action_space, {}),
},
"policy_mapping_fn": lambda aid: "p{}".format(aid),
"policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
"count_steps_by": "agent_steps",
}
tune.register_env(


@ -109,7 +109,7 @@ class WorkerSet:
return self._remote_workers
def sync_weights(self) -> None:
"""Syncs weights of remote workers with the local worker."""
"""Syncs weights from the local worker to all remote workers."""
if self.remote_workers():
weights = ray.put(self.local_worker().get_weights())
for e in self.remote_workers():


@ -239,7 +239,8 @@ if __name__ == "__main__":
"framework": args.framework,
}),
},
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
"policy_mapping_fn": (
lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
},
"model": {
"custom_model": "cc_model",


@ -116,7 +116,8 @@ if __name__ == "__main__":
"pol1": (None, observer_space, action_space, {}),
"pol2": (None, observer_space, action_space, {}),
},
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
"policy_mapping_fn": (
lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
"observation_fn": central_critic_observer,
},
"model": {


@ -45,7 +45,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
None, AsymCoinGame(env_config).OBSERVATION_SPACE,
AsymCoinGame.ACTION_SPACE, {}),
},
"policy_mapping_fn": lambda agent_id: agent_id,
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
},
# Size of batches collected from each worker.
"rollout_fragment_length": 20,


@ -87,7 +87,7 @@ if __name__ == "__main__":
else:
maze = WindyMazeEnv(None)
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
if agent_id.startswith("low_level_"):
return "low_level_policy"
else:


@ -61,7 +61,7 @@ def get_rllib_config(seeds, debug=False, stop_iters=200, tf=False):
None, IteratedPrisonersDilemma.OBSERVATION_SPACE,
IteratedPrisonersDilemma.ACTION_SPACE, {}),
},
"policy_mapping_fn": lambda agent_id: agent_id,
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
},
"seed": tune.grid_search(seeds),
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),


@ -95,7 +95,7 @@ if __name__ == "__main__":
}
policy_ids = list(policies.keys())
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
pol_id = random.choice(policy_ids)
return pol_id


@ -78,7 +78,7 @@ if __name__ == "__main__":
"random": (RandomPolicy, obs_space, act_space, {}),
},
"policy_mapping_fn": (
lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
lambda aid, **kwargs: ["pg_policy", "random"][aid % 2]),
},
"framework": args.framework,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.


@ -32,7 +32,8 @@ if __name__ == "__main__":
# Method specific
"multiagent": {
"policies": policies,
"policy_mapping_fn": (lambda agent_id: agent_id),
"policy_mapping_fn": (
lambda agent_id, episode, **kwargs: agent_id),
},
},
)


@ -51,7 +51,8 @@ if __name__ == "__main__":
# Method specific
"multiagent": {
"policies": policies,
"policy_mapping_fn": (lambda agent_id: "shared_policy"),
"policy_mapping_fn": (
lambda agent_id, episode, **kwargs: "shared_policy"),
},
},
)

View file

@ -68,7 +68,7 @@ if __name__ == "__main__":
DQNTFPolicy, obs_space, act_space, {}),
}
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
if agent_id % 2 == 0:
return "ppo_policy"
else:

View file

@ -54,7 +54,7 @@ if __name__ == "__main__":
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id: "av"
"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
}
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.

View file

@ -72,7 +72,7 @@ if __name__ == "__main__":
}
policy_ids = list(policies.keys())
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
pol_id = random.choice(policy_ids)
return pol_id

View file

@ -75,7 +75,7 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
beat_last heuristics.
"""
def select_policy(agent_id):
def select_policy(agent_id, episode, **kwargs):
if agent_id == "player1":
return "learned"
else:

View file

@ -0,0 +1,271 @@
"""Example showing how one can implement a simple self-play training workflow.
Uses the open spiel adapter of RLlib with the "connect_four" game and
a multi-agent setup with a "main" policy and n "main_v[x]" policies
(x=version number), which are all at-some-point-frozen copies of
"main". At the very beginning, "main" plays against RandomPolicy.
Checks for the training progress after each training update via a custom
callback. We simply measure the win rate of "main" vs the opponent
("main_v[x]" or RandomPolicy at the beginning) by looking through the
achieved rewards in the episodes in the train batch. If this win rate
reaches some configurable threshold, we add a new policy to
the policy map (a frozen copy of the current "main" one) and change the
policy_mapping_fn to make new matches of "main" vs the just added one.
After training for n iterations, a configurable number of episodes can
be played by the user against the "main" agent on the command line.
"""
import argparse
import numpy as np
import os
import pyspiel
from open_spiel.python.rl_environment import Environment
import sys
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.tune import register_env
OBS_SPACE = ACTION_SPACE = None
parser = argparse.ArgumentParser()
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--from-checkpoint",
type=str,
default=None,
help="Full path to a checkpoint file for restoring a previously saved "
"Trainer state.")
parser.add_argument(
"--stop-iters",
type=int,
default=200,
help="Number of iterations to train.")
parser.add_argument(
"--stop-timesteps",
type=int,
default=1000000,
help="Number of timesteps to train.")
parser.add_argument(
"--win-rate-threshold",
type=float,
default=0.95,
help="Win-rate at which we setup another opponent by freezing the "
"current main policy and playing against a uniform distribution "
"of previously frozen 'main's from here on.")
parser.add_argument(
"--num-episodes-human-play",
type=int,
default=2,
help="How many episodes to play against the user on the command "
"line after training has finished.")
args = parser.parse_args()
def ask_user_for_action(time_step):
"""Asks the user for a valid action on the command line and returns it.
Re-queries the user until she picks a valid one.
Args:
time_step: The open spiel Environment time-step object.
"""
pid = time_step.observations["current_player"]
legal_moves = time_step.observations["legal_actions"][pid]
choice = -1
while choice not in legal_moves:
print("Choose an action from {}:".format(legal_moves))
sys.stdout.flush()
choice_str = input()
try:
choice = int(choice_str)
except ValueError:
continue
return choice
class SelfPlayCallback(DefaultCallbacks):
def __init__(self):
super().__init__()
        # 0=RandomPolicy, 1=1st main policy snapshot,
        # 2=2nd main policy snapshot, etc.
self.current_opponent = 0
def on_train_result(self, *, trainer, result, **kwargs):
# Get the win rate for the train batch.
        # Note that normally, one should set up a proper evaluation config
        # such that evaluation always happens on the already updated policy,
        # instead of on the train batch that was just used for the update.
main_rew = result["hist_stats"].pop("policy_main_reward")
opponent_rew = list(result["hist_stats"].values())[0]
assert len(main_rew) == len(opponent_rew)
won = 0
for r_main, r_opponent in zip(main_rew, opponent_rew):
if r_main > r_opponent:
won += 1
win_rate = won / len(main_rew)
print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")
# If win rate is good -> Snapshot current policy and play against
# it next, keeping the snapshot fixed and only improving the "main"
# policy.
if win_rate > args.win_rate_threshold:
self.current_opponent += 1
new_pol_id = f"main_v{self.current_opponent}"
print(f"adding new opponent to the mix ({new_pol_id}).")
            # Re-define the mapping function, such that "main" is forced
            # to play against any of its previously frozen snapshots
            # (excluding "random").
def policy_mapping_fn(agent_id, episode, **kwargs):
                # agent_id = [0|1] -> policy depends on episode ID.
                # This way, we make sure that both policies sometimes
                # play agent0 (start player) and sometimes agent1
                # (player to move 2nd).
return "main" if episode.episode_id % 2 == agent_id \
else "main_v{}".format(np.random.choice(
list(range(1, self.current_opponent + 1))))
new_policy = trainer.add_policy(
policy_id=new_pol_id,
policy_cls=type(trainer.get_policy("main")),
observation_space=OBS_SPACE,
action_space=ACTION_SPACE,
config={},
policy_mapping_fn=policy_mapping_fn,
)
            # Set the new policy's weights to those of the "main" policy.
            # We'll keep training "main", whereas `new_pol_id` will remain
            # fixed (a frozen snapshot).
main_state = trainer.get_policy("main").get_state()
new_policy.set_state(main_state)
# We need to sync the just copied local weights (from main policy)
# to all the remote workers as well.
trainer.workers.sync_weights()
else:
print("not good enough; will keep learning")
if __name__ == "__main__":
ray.init(num_cpus=args.num_cpus or None, include_dashboard=False)
dummy_env = OpenSpielEnv(pyspiel.load_game("connect_four"))
OBS_SPACE = dummy_env.observation_space
ACTION_SPACE = dummy_env.action_space
register_env("connect_four",
lambda _: OpenSpielEnv(pyspiel.load_game("connect_four")))
def policy_mapping_fn(agent_id, episode, **kwargs):
# agent_id = [0|1] -> policy depends on episode ID
# This way, we make sure that both policies sometimes play agent0
# (start player) and sometimes agent1 (player to move 2nd).
return "main" if episode.episode_id % 2 == agent_id else "random"
config = {
"env": "connect_four",
"callbacks": SelfPlayCallback,
"model": {
"fcnet_hiddens": [512, 512],
},
"num_sgd_iter": 20,
"num_envs_per_worker": 5,
"multiagent": {
# Initial policy map: Random and PPO. This will be expanded
# to more policy snapshots taken from "main" against which "main"
# will then play (instead of "random"). This is done in the
# custom callback defined above (`SelfPlayCallback`).
"policies": {
                # Our main policy, which we'd like to optimize.
"main": (None, OBS_SPACE, ACTION_SPACE, {}),
# An initial random opponent to play against.
"random": (RandomPolicy, OBS_SPACE, ACTION_SPACE, {}),
},
            # Assign agents 0 and 1 to either the "main" policy or the
            # opponent ("random" at first), depending on the episode ID.
            # This ensures that "main" always plays against an opponent
            # (and never against another "main").
"policy_mapping_fn": policy_mapping_fn,
# Always just train the "main" policy.
"policies_to_train": ["main"],
},
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": args.framework,
}
stop = {
"timesteps_total": args.stop_timesteps,
"training_iteration": args.stop_iters,
}
# Train the "main" policy to play really well using self-play.
results = None
if not args.from_checkpoint:
results = tune.run(
"PPO",
config=config,
stop=stop,
checkpoint_at_end=True,
checkpoint_freq=10,
verbose=1)
# Restore trained trainer (set to non-explore behavior) and play against
# human on command line.
if args.num_episodes_human_play > 0:
num_episodes = 0
trainer = PPOTrainer(config=dict(config, **{"explore": False}))
if args.from_checkpoint:
trainer.restore(args.from_checkpoint)
else:
trainer.restore(results.get_last_checkpoint())
# Play from the command line against the trained agent
# in an actual (non-RLlib-wrapped) open-spiel env.
human_player = 1
env = Environment("connect_four")
while num_episodes < args.num_episodes_human_play:
print("You play as {}".format("o" if human_player else "x"))
time_step = env.reset()
while not time_step.last():
player_id = time_step.observations["current_player"]
if player_id == human_player:
action = ask_user_for_action(time_step)
else:
obs = np.array(
time_step.observations["info_state"][player_id])
action = trainer.compute_action(obs, policy_id="main")
                # In case the computer chooses an invalid action, pick a
                # random one.
legal = time_step.observations["legal_actions"][player_id]
if action not in legal:
action = np.random.choice(legal)
time_step = env.step([action])
print(f"\n{env.get_state}")
print(f"\n{env.get_state}")
print("End of game!")
if time_step.rewards[human_player] > 0:
print("You win")
elif time_step.rewards[human_player] < 0:
print("You lose")
else:
print("Draw")
# Switch order of players
human_player = 1 - human_player
num_episodes += 1
ray.shutdown()

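The example above only exercises Trainer.add_policy; the commit also adds support for deleting policies. As a rough sketch only (the remove_policy keyword arguments are assumed to mirror the add_policy call above, not taken from this example), an old frozen opponent could later be dropped and the opponent seat re-pointed at a newer snapshot:

# Assumed-usage sketch: remove the oldest frozen opponent ("main_v1") and
# route the opponent seat to a newer snapshot ("main_v2") from here on.
def updated_mapping_fn(agent_id, episode, **kwargs):
    # "main" keeps occupying one seat per episode; the other seat now
    # always gets the newer frozen snapshot.
    return "main" if episode.episode_id % 2 == agent_id else "main_v2"

trainer.remove_policy(
    policy_id="main_v1",
    policy_mapping_fn=updated_mapping_fn,
)
# As after add_policy, push the updated local state out to the workers.
trainer.workers.sync_weights()
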
View file

@ -140,7 +140,8 @@ if __name__ == "__main__":
marl_env.get_action_space(agent),
agent_policy_params)
config["multiagent"]["policies"] = policies
config["multiagent"]["policy_mapping_fn"] = lambda agent_id: agent_id
config["multiagent"][
"policy_mapping_fn"] = lambda agent_id, episode, **kwargs: agent_id
config["multiagent"]["policies_to_train"] = ["ppo_policy"]
config["env"] = "sumo_test_env"

View file

@ -102,7 +102,8 @@ if __name__ == "__main__":
"agent_id": 1,
}),
},
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
"policy_mapping_fn": (
lambda aid, **kwargs: "pol2" if aid else "pol1"),
},
"framework": args.framework,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.

View file

@ -137,7 +137,7 @@ if __name__ == "__main__":
obs_space, act_space, DQN_CONFIG),
}
def policy_mapping_fn(agent_id):
def policy_mapping_fn(agent_id, episode, **kwargs):
if agent_id % 2 == 0:
return "ppo_policy"
else:

View file

@ -48,7 +48,8 @@ class TrainOneStep:
num_sgd_iter: int = 1,
sgd_minibatch_size: int = 0):
self.workers = workers
self.policies = policies or workers.local_worker().policies_to_train
self.local_worker = workers.local_worker()
self.policies = policies
self.num_sgd_iter = num_sgd_iter
self.sgd_minibatch_size = sgd_minibatch_size
@ -61,9 +62,11 @@ class TrainOneStep:
if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
lw = self.workers.local_worker()
info = do_minibatch_sgd(
batch, {pid: lw.get_policy(pid)
for pid in self.policies}, lw, self.num_sgd_iter,
self.sgd_minibatch_size, [])
batch, {
pid: lw.get_policy(pid)
for pid in self.policies
or self.local_worker.policies_to_train
}, lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
# TODO(ekl) shouldn't be returning learner stats directly here
# TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
# callback (shouldn't).
@ -84,7 +87,7 @@ class TrainOneStep:
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(self.workers.local_worker().get_weights(
self.policies))
self.policies or self.local_worker.policies_to_train))
for e in self.workers.remote_workers():
e.set_weights.remote(weights, _get_global_vars())
# Also update global vars of the local worker.
@ -119,7 +122,8 @@ class TrainTFMultiGPU:
_fake_gpus: bool = False,
framework: str = "tf"):
self.workers = workers
self.policies = policies or workers.local_worker().policies_to_train
self.local_worker = workers.local_worker()
self.policies = policies
self.num_sgd_iter = num_sgd_iter
self.sgd_minibatch_size = sgd_minibatch_size
self.shuffle_sequences = shuffle_sequences
@ -150,7 +154,8 @@ class TrainTFMultiGPU:
self.optimizers = {}
with self.workers.local_worker().tf_sess.graph.as_default():
with self.workers.local_worker().tf_sess.as_default():
for policy_id in self.policies:
for policy_id in (self.policies
or self.local_worker.policies_to_train):
policy = self.workers.local_worker().get_policy(policy_id)
with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
if policy._state_inputs:
@ -173,7 +178,7 @@ class TrainTFMultiGPU:
samples: SampleBatchType) -> (SampleBatchType, List[dict]):
_check_sample_batch_type(samples)
# Handle everything as if multiagent
        # Handle everything as if multi-agent.
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({
DEFAULT_POLICY_ID: samples
@ -187,7 +192,8 @@ class TrainTFMultiGPU:
num_loaded_tuples = {}
for policy_id, batch in samples.policy_batches.items():
# Not a policy-to-train.
if policy_id not in self.policies:
if policy_id not in (self.policies
or self.local_worker.policies_to_train):
continue
# Decompress SampleBatch, in case some columns are compressed.
@ -245,7 +251,7 @@ class TrainTFMultiGPU:
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(self.workers.local_worker().get_weights(
self.policies))
self.policies or self.local_worker.policies_to_train))
for e in self.workers.remote_workers():
e.set_weights.remote(weights, _get_global_vars())
# Also update global vars of the local worker.
@ -315,7 +321,8 @@ class ApplyGradients:
currently processing (i.e., A3C style).
"""
self.workers = workers
self.policies = policies or workers.local_worker().policies_to_train
self.local_worker = workers.local_worker()
self.policies = policies
self.update_all = update_all
def __call__(self, item: Tuple[ModelGradients, int]) -> None:
@ -339,7 +346,7 @@ class ApplyGradients:
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(self.workers.local_worker().get_weights(
self.policies))
self.policies or self.local_worker.policies_to_train))
for e in self.workers.remote_workers():
e.set_weights.remote(weights, _get_global_vars())
else:
@ -350,7 +357,7 @@ class ApplyGradients:
"in the iterator context.")
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = self.workers.local_worker().get_weights(
self.policies)
self.policies or self.local_worker.policies_to_train)
metrics.current_actor.set_weights.remote(
weights, _get_global_vars())
@ -407,8 +414,9 @@ class UpdateTargetNetwork:
by_steps_trained: bool = False,
policies: List[PolicyID] = frozenset([])):
self.workers = workers
self.local_worker = workers.local_worker()
self.target_update_freq = target_update_freq
self.policies = (policies or workers.local_worker().policies_to_train)
self.policies = policies
if by_steps_trained:
self.metric = STEPS_TRAINED_COUNTER
else:
@ -419,7 +427,7 @@ class UpdateTargetNetwork:
cur_ts = metrics.counters[self.metric]
last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update > self.target_update_freq:
to_update = self.policies
to_update = self.policies or self.local_worker.policies_to_train
self.workers.local_worker().foreach_trainable_policy(
lambda p, p_id: p_id in to_update and p.update_target())
metrics.counters[NUM_TARGET_UPDATES] += 1

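The common thread in the TrainOneStep / TrainTFMultiGPU / ApplyGradients / UpdateTargetNetwork changes above: the list of trainable policies is no longer snapshotted in the constructor, but resolved lazily via `self.policies or self.local_worker.policies_to_train` at call time, so that policies added (or removed) after the op was built are still honored. A stripped-down sketch of the pattern (class and method names are illustrative, not the RLlib originals):

from typing import List, Optional

class BroadcastWeightsOp:
    """Illustrative op resolving the trainable-policy list lazily."""

    def __init__(self, workers, policies: Optional[List[str]] = None):
        self.workers = workers
        self.local_worker = workers.local_worker()
        # Do NOT resolve `policies_to_train` here: the policy map may
        # still grow or shrink via Trainer.add_policy / remove_policy.
        self.policies = policies

    def __call__(self):
        # Resolve at call time so newly added policies are picked up.
        to_update = self.policies or self.local_worker.policies_to_train
        weights = self.local_worker.get_weights(to_update)
        for w in self.workers.remote_workers():
            w.set_weights.remote(weights)
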
View file

@ -46,7 +46,7 @@ class DynamicTFPolicy(TFPolicy):
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
config (dict): config of the policy
model (TorchModel): TF model instance
model (ModelV2): TF model instance
dist_class (type): TF action distribution class
"""

View file

@ -55,7 +55,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2),
rollout_fragment_length=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)

View file

@ -196,7 +196,7 @@ class AgentIOTest(unittest.TestCase):
"policy_2": gen_policy(),
},
"policy_mapping_fn": (
lambda agent_id: random.choice(
lambda aid, **kwargs: random.choice(
["policy_1", "policy_2"])),
},
"framework": fw,
@ -218,7 +218,7 @@ class AgentIOTest(unittest.TestCase):
"policy_2": gen_policy(),
},
"policy_mapping_fn": (
lambda agent_id: random.choice(
lambda aid, **kwargs: random.choice(
["policy_1", "policy_2"])),
},
"framework": fw,

View file

@ -77,20 +77,15 @@ class TestMultiAgentEnv(unittest.TestCase):
env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
self.assertEqual(
dones, {
0: {
0: False,
1: False,
"__all__": False
},
1: {
0: False,
1: False,
"__all__": False
}
})
self.assertEqual(rew, {0: {}, 1: {}})
self.assertEqual(dones, {
0: {
"__all__": False
},
1: {
"__all__": False
},
})
for _ in range(24):
env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
obs, rew, dones, _, _ = env.poll()
@ -161,7 +156,7 @@ class TestMultiAgentEnv(unittest.TestCase):
env = _MultiAgentEnvToBaseEnv(lambda v: RoundRobinMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
self.assertEqual(rew, {0: {}, 1: {}})
env.send_actions({0: {0: 0}, 1: {0: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
@ -172,13 +167,17 @@ class TestMultiAgentEnv(unittest.TestCase):
def test_multi_agent_sample(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
def policy_mapping_fn(agent_id, episode, **kwargs):
return "p{}".format(agent_id % 2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
policy_mapping_fn=policy_mapping_fn,
rollout_fragment_length=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)
@ -198,7 +197,10 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
            # This old-style signature will raise a soft-deprecation
            # warning (the new signature is (agent_id, episode, **kwargs)),
            # but it should not break this test.
policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
rollout_fragment_length=50,
num_envs=4,
remote_worker_envs=True,
@ -217,7 +219,7 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
rollout_fragment_length=50,
num_envs=4,
remote_worker_envs=True)
@ -233,7 +235,7 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
episode_horizon=10, # test with episode horizon set
rollout_fragment_length=50)
batch = ev.sample()
@ -248,12 +250,23 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
batch_mode="complete_episodes",
rollout_fragment_length=1)
self.assertRaisesRegexp(ValueError,
".*don't have a last observation.*",
lambda: ev.sample())
        # This used to raise an error because EarlyDoneMultiAgent would
        # terminate at e.g. agent0 without ever publishing a (final)
        # observation for agent1. This limitation has been lifted: An env
        # may now terminate at any time (and return rewards for any agent
        # at any time), even when that agent doesn't get an obs returned
        # in the same call to `step()`.
ma_batch = ev.sample()
        # Make sure the agents took the correct (alternating-timesteps)
        # path, except for the last timestep, where both agents got
        # terminated.
ag0_ts = ma_batch.policy_batches["p0"]["t"]
ag1_ts = ma_batch.policy_batches["p1"]["t"]
self.assertTrue(np.all(np.abs(ag0_ts[:-1] - ag1_ts[:-1]) == 1.0))
self.assertTrue(ag0_ts[-1] == ag1_ts[-1])
def test_multi_agent_with_flex_agents(self):
register_env("flex_agents_multi_agent_cartpole",
@ -277,7 +290,7 @@ class TestMultiAgentEnv(unittest.TestCase):
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
rollout_fragment_length=50)
batch = ev.sample()
self.assertEqual(batch.count, 50)
@ -340,7 +353,7 @@ class TestMultiAgentEnv(unittest.TestCase):
# the extra trajectory.
env_id = episodes[0].env_id
fake_eps = MultiAgentEpisode(
episodes[0]._policies, episodes[0]._policy_mapping_fn,
episodes[0].policy_map, episodes[0]._policy_mapping_fn,
lambda: None, lambda x: None, env_id)
builder = get_global_worker().sampler.sample_collector
agent_id = "extra_0"
@ -377,7 +390,7 @@ class TestMultiAgentEnv(unittest.TestCase):
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
rollout_fragment_length=5)
batch = ev.sample()
# 5 environment steps (rollout_fragment_length).
@ -430,7 +443,7 @@ class TestMultiAgentEnv(unittest.TestCase):
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
"policy_mapping_fn": lambda agent_id: "policy_1",
"policy_mapping_fn": lambda aid, **kwargs: "policy_1",
},
"framework": "tf",
})

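One test above deliberately keeps an old-style, single-argument mapping function and only expects a soft-deprecation warning rather than a failure. A hypothetical sketch of how such a signature check could work (an illustration of the idea, not RLlib's actual internal helper):

import inspect
import warnings

def maybe_wrap_old_mapping_fn(policy_mapping_fn):
    """Warn about and adapt old-style, (agent_id)-only mapping functions."""
    params = inspect.signature(policy_mapping_fn).parameters
    has_var_kwargs = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
    if len(params) == 1 and not has_var_kwargs:
        warnings.warn(
            "policy_mapping_fn(agent_id) is soft-deprecated; use "
            "policy_mapping_fn(agent_id, episode, **kwargs) instead.",
            DeprecationWarning)
        # Adapter: swallow the extra args before calling the old function.
        return (lambda agent_id, episode=None, **kwargs:
                policy_mapping_fn(agent_id))
    return policy_mapping_fn
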
View file

@ -457,9 +457,9 @@ class NestedSpacesTest(unittest.TestCase):
PGTFPolicy, DICT_SPACE, act_space,
{"model": {"custom_model": "dict_spy"}}),
},
"policy_mapping_fn": lambda a: {
"policy_mapping_fn": lambda aid, **kwargs: {
"tuple_agent": "tuple_policy",
"dict_agent": "dict_policy"}[a],
"dict_agent": "dict_policy"}[aid],
},
"framework": "tf",
})

View file

@ -34,7 +34,7 @@ class TestPettingZooEnv(unittest.TestCase):
# the first tuple value is None -> uses default policy
"av": (None, obs_space, act_space, {}),
},
"policy_mapping_fn": lambda agent_id: "av"
"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
}
config["log_level"] = "DEBUG"

View file

@ -147,8 +147,8 @@ def learn_test_multi_agent_plus_rollout(algo):
print("RLlib dir = {}\nexists={}".format(rllib_dir,
os.path.exists(rllib_dir)))
def policy_fn(agent):
return "pol{}".format(agent)
def policy_fn(agent_id, episode, **kwargs):
return "pol{}".format(agent_id)
observation_space = Box(float("-inf"), float("inf"), (4, ))
action_space = Discrete(2)