mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)

This reverts commit e78ec370a9.

parent 54d66ac637
commit bd3cbfc56a

36 changed files with 218 additions and 906 deletions
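Note on the revert: most of the hunks below swap the two-argument policy mapping signature, policy_mapping_fn(agent_id, episode, **kwargs), back to the older one-argument form, policy_mapping_fn(agent_id). A minimal sketch of the two forms as they appear throughout the diff (the policy-ID strings here are placeholders, not taken from the commit):

# Signature removed by this revert: also receives the episode object.
def policy_mapping_fn_new(agent_id, episode, **kwargs):
    return "main" if agent_id == 0 else "opponent"

# Signature restored by this revert: agent ID only.
def policy_mapping_fn_old(agent_id):
    return "main" if agent_id == 0 else "opponent"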
18  rllib/BUILD
@@ -2366,24 +2366,6 @@ py_test(
args = ["--as-test", "--framework=torch"],
)

# py_test(
# name = "examples/self_play_with_open_spiel_connect_4",
# main = "examples/self_play_with_open_spiel_connect_4.py",
# tags = ["examples", "examples_S"],
# size = "medium",
# srcs = ["examples/self_play_with_open_spiel_connect_4.py"],
# args = ["--framework=tf", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
# )

# py_test(
# name = "examples/self_play_with_open_spiel_connect_4_torch",
# main = "examples/self_play_with_open_spiel_connect_4.py",
# tags = ["examples", "examples_S"],
# size = "medium",
# srcs = ["examples/self_play_with_open_spiel_connect_4.py"],
# args = ["--framework=torch", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
# )

py_test(
name = "examples/trajectory_view_api_tf",
main = "examples/trajectory_view_api.py",
@@ -1,7 +1,6 @@
import copy
from datetime import datetime
import functools
import gym
import logging
import math
import numpy as np

@@ -32,8 +31,8 @@ from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
from ray.rllib.utils.typing import TrainerConfigDict, \
PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.resources import Resources

@@ -907,7 +906,7 @@ class Trainer(Trainable):
"""Sync "main" weights to given WorkerSet or list of workers."""
assert worker_set is not None
# Broadcast the new policy weights to all evaluation workers.
logger.info("Synchronizing weights to workers.")
logger.info("Synchronizing weights to evaluation workers.")
weights = ray.put(self.workers.local_worker().save())
worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))

@@ -1071,7 +1070,7 @@ class Trainer(Trainable):
"""Return policy for the specified id, or None.

Args:
policy_id (PolicyID): ID of the policy to return.
policy_id (str): id of policy to return.
"""
return self.workers.local_worker().get_policy(policy_id)

@@ -1094,101 +1093,6 @@ class Trainer(Trainable):
"""
self.workers.local_worker().set_weights(weights)

@PublicAPI
def add_policy(
self,
policy_id: PolicyID,
policy_cls: Type[Policy],
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID],
PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> Policy:
"""Adds a new policy to this Trainer.

Args:
policy_id (PolicyID): ID of the policy to add.
policy_cls (Type[Policy]): The Policy class to use for
constructing the new Policy.
observation_space (Optional[gym.spaces.Space]): The observation
space of the policy to add.
action_space (Optional[gym.spaces.Space]): The action space
of the policy to add.
config (Optional[PartialTrainerConfigDict]): The config overrides
for the policy to add.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies, whose IDs are not in the list will not be
updated.

Returns:
Policy: The newly added policy (the copy that got added to the
local worker).
"""

def fn(worker):
# `foreach_worker` function: Adds the policy the the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(
policy_id=policy_id,
policy_cls=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)

# Run foreach_worker fn on all workers (incl. evaluation workers).
self.workers.foreach_worker(fn)
if self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)

# Return newly added policy (from the local rollout worker).
return self.get_policy(policy_id)

@PublicAPI
def remove_policy(
self,
policy_id: PolicyID = DEFAULT_POLICY_ID,
*,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
) -> None:
"""Removes a new policy from this Trainer.

Args:
policy_id (Optional[PolicyID]): ID of the policy to be removed.
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
optional (updated) policy mapping function to use from here on.
Note that already ongoing episodes will not change their
mapping but will use the old mapping till the end of the
episode.
policies_to_train (Optional[List[PolicyID]]): An optional list of
policy IDs to be trained. If None, will keep the existing list
in place. Policies, whose IDs are not in the list will not be
updated.
"""

def fn(worker):
worker.remove_policy(
policy_id=policy_id,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)

self.workers.foreach_worker(fn)
if self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)

@DeveloperAPI
def export_policy_model(self,
export_dir: str,
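The hunks above drop the Trainer.add_policy() / Trainer.remove_policy() helpers again. For context, a minimal usage sketch based only on the signatures visible in the removed code; the trainer object, policy class and policy IDs are hypothetical, not taken from the commit:

# Hypothetical trainer and policy class; only the add_policy/remove_policy
# calls mirror the signatures removed above.
new_pol = trainer.add_policy(
    policy_id="main_v1",
    policy_cls=MyPolicyClass,
    policy_mapping_fn=lambda agent_id, episode, **kwargs: "main_v1",
    policies_to_train=["main_v1"],
)
trainer.remove_policy(
    policy_id="main_v1",
    policy_mapping_fn=lambda agent_id: "main",
)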
61  rllib/env/base_env.py (vendored)
@@ -423,6 +423,10 @@ class _MultiAgentEnvToBaseEnv(BaseEnv):
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
if set(obs.keys()) != set(rewards.keys()):
raise ValueError(
"Key set for obs and rewards must be the same: "
"{} vs {}".format(obs.keys(), rewards.keys()))
if set(infos).difference(set(obs)):
raise ValueError("Key set for infos must be a subset of obs: "
"{} vs {}".format(infos.keys(), obs.keys()))

@@ -468,52 +472,31 @@ class _MultiAgentEnvState:
if not self.initialized:
self.reset()
self.initialized = True

observations = self.last_obs
rewards = {}
dones = {"__all__": self.last_dones["__all__"]}
infos = {}

# If episode is done, release everything we have.
if dones["__all__"]:
rewards = self.last_rewards
self.last_rewards = {}
dones = self.last_dones
self.last_dones = {}
self.last_obs = {}
# Only release those agents' rewards/dones/infos, whose
# observations we have.
else:
for ag in observations.keys():
if ag in self.last_rewards:
rewards[ag] = self.last_rewards[ag]
del self.last_rewards[ag]
if ag in self.last_dones:
dones[ag] = self.last_dones[ag]
del self.last_dones[ag]

self.last_dones["__all__"] = False
obs, rew, dones, info = (self.last_obs, self.last_rewards,
self.last_dones, self.last_infos)
self.last_obs = {}
self.last_rewards = {}
self.last_dones = {"__all__": False}
self.last_infos = {}
return observations, rewards, dones, infos
return obs, rew, dones, info

def observe(self, obs: MultiAgentDict, rewards: MultiAgentDict,
dones: MultiAgentDict, infos: MultiAgentDict):
self.last_obs = obs
for ag, r in rewards.items():
if ag in self.last_rewards:
self.last_rewards[ag] += r
else:
self.last_rewards[ag] = r
for ag, d in dones.items():
if ag in self.last_dones:
self.last_dones[ag] = self.last_dones[ag] or d
else:
self.last_dones[ag] = d
self.last_rewards = rewards
self.last_dones = dones
self.last_infos = infos

def reset(self) -> MultiAgentDict:
self.last_obs = self.env.reset()
self.last_rewards = {}
self.last_dones = {"__all__": False}
self.last_infos = {}
self.last_rewards = {
agent_id: None
for agent_id in self.last_obs.keys()
}
self.last_dones = {
agent_id: False
for agent_id in self.last_obs.keys()
}
self.last_infos = {agent_id: {} for agent_id in self.last_obs.keys()}
self.last_dones["__all__"] = False
return self.last_obs
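The checks added in the first hunk above spell out the multi-agent step contract this file relies on: rewards must carry the same agent keys as obs, infos may only cover agents that also appear in obs, and dones additionally carries the "__all__" key. A small illustrative return value that satisfies those checks (agent IDs are made up):

obs = {"agent_0": 0.0, "agent_1": 1.0}
rewards = {"agent_0": 0.5, "agent_1": -0.5}   # same key set as obs
dones = {"agent_0": False, "agent_1": False, "__all__": False}
infos = {"agent_0": {}}                        # subset of obs keys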
75  rllib/env/wrappers/open_spiel.py (vendored)
@@ -1,75 +0,0 @@
from gym.spaces import Box, Discrete
import numpy as np
import pyspiel

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class OpenSpielEnv(MultiAgentEnv):
def __init__(self, env):
self.env = env

# Agent IDs are ints, starting from 0.
self.num_agents = self.env.num_players()
# Store the open-spiel game type.
self.type = self.env.get_type()
# Stores the current open-spiel game state.
self.state = None

# Extract observation- and action spaces from game.
self.observation_space = Box(
float("-inf"), float("inf"),
(self.env.observation_tensor_size(), ))
self.action_space = Discrete(self.env.num_distinct_actions())

def reset(self):
self.state = self.env.new_initial_state()
return self._get_obs()

def step(self, action):
# Sequential game:
if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
curr_player = self.state.current_player()
assert curr_player in action
penalty = None
try:
self.state.apply_action(action[curr_player])
# TODO: (sven) resolve this hack by publishing legal actions
# with each step.
except pyspiel.SpielError:
self.state.apply_action(
np.random.choice(self.state.legal_actions()))
penalty = -0.1

# Are we done?
is_done = self.state.is_terminal()
dones = dict({ag: is_done
for ag in range(self.num_agents)},
**{"__all__": is_done})

# Compile rewards dict.
rewards = {ag: r for ag, r in enumerate(self.state.returns())}
if penalty:
rewards[curr_player] += penalty
# Simultaneous game.
else:
raise NotImplementedError

return self._get_obs(), rewards, dones, {}

def _get_obs(self):
curr_player = self.state.current_player()
# Sequential game:
if str(self.type.dynamics) == "Dynamics.SEQUENTIAL":
if self.state.is_terminal():
return {}
return {
curr_player: np.reshape(self.state.observation_tensor(), [-1])
}
# Simultaneous game.
else:
raise NotImplementedError

def render(self, mode=None) -> None:
if mode == "human":
print(self.state)
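The entire OpenSpielEnv wrapper above is deleted by this revert. For reference, a sketch of how such a wrapper is typically hooked up, mirroring the imports of the deleted self-play example at the end of this diff; the registered env name string is arbitrary:

import pyspiel
from ray.tune import register_env
# The wrapper deleted above:
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv

register_env("open_spiel_connect_four",
             lambda _: OpenSpielEnv(pyspiel.load_game("connect_four")))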
4  rllib/env/wrappers/unity3d_env.py (vendored)
@@ -313,7 +313,7 @@ class Unity3DEnv(MultiAgentEnv):
action_spaces["Striker"], {}),
}

def policy_mapping_fn(agent_id, episode, **kwargs):
def policy_mapping_fn(agent_id):
return "Striker" if "Striker" in agent_id else "Goalie"

else:

@@ -322,7 +322,7 @@ class Unity3DEnv(MultiAgentEnv):
action_spaces[game_name], {}),
}

def policy_mapping_fn(agent_id, episode, **kwargs):
def policy_mapping_fn(agent_id):
return game_name

return policies, policy_mapping_fn
@@ -627,7 +627,7 @@ class SimpleListCollector(SampleCollector):
if is_done and check_dones and \
not pre_batch[SampleBatch.DONES][-1]:
raise ValueError(
"Episode {} terminated for all agents, but we still "
"Episode {} terminated for all agents, but we still"
"don't have a last observation for agent {} (policy "
"{}). ".format(
episode_id, agent_id, self.agent_key_to_policy_id[(

@@ -674,16 +674,8 @@ class SimpleListCollector(SampleCollector):
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches)

# Add the postprocessed SampleBatch to the policy collectors for
# training.
# PID may be a newly added policy. Just confirm we have it in our
# policy map before proceeding with adding a new _PolicyCollector()
# to the group.
if pid not in policy_collector_group.policy_collectors:
assert pid in self.policy_map
policy_collector_group.policy_collectors[
pid] = _PolicyCollector(policy)
policy_collector_group.policy_collectors[
pid].add_postprocessed_batch_for_training(
post_batch, policy.view_requirements)

@@ -788,18 +780,9 @@ class SimpleListCollector(SampleCollector):
vectorized environments).
"""
pid = self.agent_key_to_policy_id[agent_key]

# PID may be a newly added policy. Just confirm we have it in our
# policy map before proceeding with forward_pass_size=0.
if pid not in self.forward_pass_size:
assert pid in self.policy_map
self.forward_pass_size[pid] = 0
self.forward_pass_agent_keys[pid] = []

idx = self.forward_pass_size[pid]
if idx == 0:
self.forward_pass_agent_keys[pid].clear()

self.forward_pass_agent_keys[pid].append(agent_key)
self.forward_pass_size[pid] += 1
@@ -6,11 +6,9 @@ from typing import List, Dict, Callable, Any, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once

if TYPE_CHECKING:
from ray.rllib.evaluation.sample_batch_builder import \

@@ -50,8 +48,7 @@ class MultiAgentEpisode:
"""

def __init__(self, policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID],
policy_mapping_fn: Callable[[AgentID], PolicyID],
batch_builder_factory: Callable[
[], "MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],

@@ -71,17 +68,15 @@ class MultiAgentEpisode:
self.user_data: Dict[str, Any] = {}
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self.policy_map: Dict[PolicyID, Policy] = policies
self._policies = self.policy_map # backward compatibility
self._policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID] = policy_mapping_fn
self._policies: Dict[PolicyID, Policy] = policies
self._policy_mapping_fn: Callable[[AgentID], PolicyID] = \
policy_mapping_fn
self._next_agent_index: int = 0
self._agent_to_index: Dict[AgentID, int] = {}
self._agent_to_policy: Dict[AgentID, PolicyID] = {}
self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
self._agent_to_last_done: Dict[AgentID, bool] = {}
self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
self._agent_to_last_pi_info: Dict[AgentID, dict] = {}

@@ -117,29 +112,8 @@ class MultiAgentEpisode:
"""

if agent_id not in self._agent_to_policy:
# Try new API: pass in agent_id and episode as named args.
# New signature should be: (agent_id, episode, **kwargs)
try:
policy_id = self._agent_to_policy[agent_id] = \
self._policy_mapping_fn(agent_id, self)
except TypeError as e:
if "positional argument" in e.args[0] or \
"unexpected keyword argument" in e.args[0]:
if log_once("policy_mapping_new_signature"):
deprecation_warning(
old="policy_mapping_fn(agent_id)",
new="policy_mapping_fn(agent_id, episode, "
"**kwargs)")
policy_id = self._agent_to_policy[agent_id] = \
self._policy_mapping_fn(agent_id)
else:
raise e
else:
policy_id = self._agent_to_policy[agent_id]
if policy_id not in self.policy_map:
raise KeyError("policy_mapping_fn returned invalid policy id "
f"'{policy_id}'!")
return policy_id
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]

@DeveloperAPI
def last_observation_for(

@@ -171,8 +145,7 @@ class MultiAgentEpisode:
return flatten_to_single_ndarray(
self._agent_to_last_action[agent_id])
else:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
policy = self._policies[self.policy_for(agent_id)]
flat = flatten_to_single_ndarray(policy.action_space.sample())
if hasattr(policy.action_space, "dtype"):
return np.zeros_like(flat, dtype=policy.action_space.dtype)

@@ -206,34 +179,16 @@ class MultiAgentEpisode:
"""Returns the last RNN state for the specified agent."""

if agent_id not in self._agent_to_rnn_state:
policy_id = self.policy_for(agent_id)
policy = self.policy_map[policy_id]
policy = self._policies[self.policy_for(agent_id)]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]

@DeveloperAPI
def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
"""Returns the last done flag received for the specified agent."""
if agent_id not in self._agent_to_last_done:
self._agent_to_last_done[agent_id] = False
return self._agent_to_last_done[agent_id]

@DeveloperAPI
def last_pi_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> dict:
"""Returns the last info object for the specified agent."""

return self._agent_to_last_pi_info[agent_id]

@DeveloperAPI
def get_agents(self) -> List[AgentID]:
"""Returns list of agent IDs that have appeared in this episode.

Returns:
List[AgentID]: The list of all agents that have appeared so
far in this episode.
"""
return list(self._agent_to_index.keys())

def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
for agent_id, reward in reward_dict.items():
if reward is not None:

@@ -251,9 +206,6 @@ class MultiAgentEpisode:
def _set_last_raw_obs(self, agent_id, obs):
self._agent_to_last_raw_obs[agent_id] = obs

def _set_last_done(self, agent_id, done):
self._agent_to_last_done[agent_id] = done

def _set_last_info(self, agent_id, info):
self._agent_to_last_info[agent_id] = info
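The policy_for() hunk above removes the fallback that let old one-argument mapping functions keep working alongside the new signature. Schematically, the removed logic amounted to the following dispatch (a simplified sketch, not the exact code; the deprecation warning and cache handling are omitted):

def _map_agent(policy_mapping_fn, agent_id, episode):
    try:
        # Preferred (new) signature: (agent_id, episode, **kwargs).
        return policy_mapping_fn(agent_id, episode)
    except TypeError:
        # Old signature: (agent_id) only -- fall back after warning once.
        return policy_mapping_fn(agent_id)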
|
@ -47,7 +47,6 @@ from ray.util.debug import log_once, disable_log_once_globally, \
|
|||
from ray.util.iter import ParallelIteratorWorker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.observation_function import ObservationFunction
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks # noqa
|
||||
|
||||
|
@ -107,7 +106,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
... "traffic_light_policy":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {}),
|
||||
... },
|
||||
... policy_mapping_fn=lambda agent_id, episode, **kwargs:
|
||||
... policy_mapping_fn=lambda agent_id:
|
||||
... random.choice(["car_policy1", "car_policy2"])
|
||||
... if agent_id.startswith("car_") else "traffic_light_policy")
|
||||
>>> print(worker.sample())
|
||||
|
@ -142,8 +141,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
policy_spec: Union[type, Dict[
|
||||
str, Tuple[Optional[type], gym.Space, gym.Space,
|
||||
PartialTrainerConfigDict]]] = None,
|
||||
policy_mapping_fn: Optional[Callable[
|
||||
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
|
||||
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
|
||||
policies_to_train: Optional[List[PolicyID]] = None,
|
||||
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
|
||||
rollout_fragment_length: int = 100,
|
||||
|
@ -202,12 +200,12 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
dict is specified, then we are in multi-agent mode and a
|
||||
policy_mapping_fn can also be set (if not, will map all agents
|
||||
to DEFAULT_POLICY_ID).
|
||||
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
|
||||
PolicyID]]): A callable that maps agent ids to policy ids in
|
||||
multi-agent mode. This function will be called each time a new
|
||||
agent appears in an episode, to bind that agent to a policy
|
||||
for the duration of the episode. If not provided, will map all
|
||||
agents to DEFAULT_POLICY_ID.
|
||||
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
|
||||
callable that maps agent ids to policy ids in multi-agent mode.
|
||||
This function will be called each time a new agent appears in
|
||||
an episode, to bind that agent to a policy for the duration of
|
||||
the episode. If not provided, will map all agents to
|
||||
DEFAULT_POLICY_ID.
|
||||
policies_to_train (Optional[List[PolicyID]]): Optional list of
|
||||
policies to train, or None for all policies.
|
||||
tf_session_creator (Optional[Callable[[], tf1.Session]]): A
|
||||
|
@ -363,19 +361,16 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.worker_index: int = worker_index
|
||||
self.num_workers: int = num_workers
|
||||
model_config: ModelConfigDict = model_config or {}
|
||||
|
||||
# Default policy mapping fn is to always return DEFAULT_POLICY_ID,
|
||||
# independent on the agent ID and the episode passed in.
|
||||
self.policy_mapping_fn = lambda aid, ep, **kwargs: DEFAULT_POLICY_ID
|
||||
self.set_policy_mapping_fn(policy_mapping_fn)
|
||||
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
if not callable(policy_mapping_fn):
|
||||
raise ValueError("Policy mapping function not callable?")
|
||||
self.env_creator: Callable[[EnvContext], EnvType] = env_creator
|
||||
self.rollout_fragment_length: int = rollout_fragment_length * num_envs
|
||||
self.count_steps_by: str = count_steps_by
|
||||
self.batch_mode: str = batch_mode
|
||||
self.compress_observations: bool = compress_observations
|
||||
self.preprocessing_enabled: bool = True
|
||||
self.observation_filter = observation_filter
|
||||
self.last_batch: SampleBatchType = None
|
||||
self.global_vars: dict = None
|
||||
self.fake_sampler: bool = fake_sampler
|
||||
|
@ -482,11 +477,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.tf_sess = None
|
||||
policy_dict = _validate_and_canonicalize(
|
||||
policy_spec, self.env, spaces=spaces)
|
||||
# List of IDs of those policies, which should be trained.
|
||||
# By default, these are all policies found in the policy_dict.
|
||||
self.policies_to_train: List[PolicyID] = list(policy_dict.keys())
|
||||
self.set_policies_to_train(policies_to_train)
|
||||
|
||||
self.policies_to_train: List[PolicyID] = policies_to_train or list(
|
||||
policy_dict.keys())
|
||||
self.policy_map: Dict[PolicyID, Policy] = None
|
||||
self.preprocessors: Dict[PolicyID, Preprocessor] = None
|
||||
|
||||
|
@ -583,7 +575,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
"ExternalMultiAgentEnv?".format(self.env))
|
||||
|
||||
self.filters: Dict[PolicyID, Filter] = {
|
||||
policy_id: get_filter(self.observation_filter,
|
||||
policy_id: get_filter(observation_filter,
|
||||
policy.observation_space.shape)
|
||||
for (policy_id, policy) in self.policy_map.items()
|
||||
}
|
||||
|
@ -654,6 +646,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.sampler = AsyncSampler(
|
||||
worker=self,
|
||||
env=self.async_env,
|
||||
policies=self.policy_map,
|
||||
policy_mapping_fn=policy_mapping_fn,
|
||||
preprocessors=self.preprocessors,
|
||||
obs_filters=self.filters,
|
||||
clip_rewards=clip_rewards,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
count_steps_by=count_steps_by,
|
||||
|
@ -676,6 +672,10 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.sampler = SyncSampler(
|
||||
worker=self,
|
||||
env=self.async_env,
|
||||
policies=self.policy_map,
|
||||
policy_mapping_fn=policy_mapping_fn,
|
||||
preprocessors=self.preprocessors,
|
||||
obs_filters=self.filters,
|
||||
clip_rewards=clip_rewards,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
count_steps_by=count_steps_by,
|
||||
|
@ -1014,123 +1014,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
|
||||
return self.policy_map.get(policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def add_policy(
|
||||
self,
|
||||
*,
|
||||
policy_id: PolicyID,
|
||||
policy_cls: Type[Policy],
|
||||
observation_space: Optional[gym.spaces.Space] = None,
|
||||
action_space: Optional[gym.spaces.Space] = None,
|
||||
config: Optional[PartialTrainerConfigDict] = None,
|
||||
policy_mapping_fn: Optional[Callable[
|
||||
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
|
||||
policies_to_train: Optional[List[PolicyID]] = None,
|
||||
) -> Policy:
|
||||
"""Adds a new policy to this RolloutWorker.
|
||||
|
||||
Args:
|
||||
policy_id (Optional[PolicyID]): ID of the policy to add.
|
||||
policy_cls (Type[Policy]): The Policy class to use for
|
||||
constructing the new Policy.
|
||||
observation_space (Optional[gym.spaces.Space]): The observation
|
||||
space of the policy to add.
|
||||
action_space (Optional[gym.spaces.Space]): The action space
|
||||
of the policy to add.
|
||||
config (Optional[PartialTrainerConfigDict]): The config overrides
|
||||
for the policy to add.
|
||||
policy_mapping_fn (Optional[Callable[[AgentID, MultiAgentEpisode],
|
||||
PolicyID]]): An optional (updated) policy mapping function to
|
||||
use from here on. Note that already ongoing episodes will not
|
||||
change their mapping but will use the old mapping till the
|
||||
end of the episode.
|
||||
policies_to_train (Optional[List[PolicyID]]): An optional list of
|
||||
policy IDs to be trained. If None, will keep the existing list
|
||||
in place. Policies, whose IDs are not in the list will not be
|
||||
updated.
|
||||
|
||||
Returns:
|
||||
Policy: The newly added policy (the copy that got added to the
|
||||
local worker).
|
||||
"""
|
||||
if policy_id in self.policy_map:
|
||||
raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
|
||||
policy_dict = {
|
||||
policy_id: (policy_cls, observation_space, action_space, config)
|
||||
}
|
||||
add_map, add_prep = self._build_policy_map(policy_dict,
|
||||
self.policy_config)
|
||||
new_policy = add_map[policy_id]
|
||||
|
||||
self.policy_map.update(add_map)
|
||||
self.preprocessors.update(add_prep)
|
||||
self.filters[policy_id] = get_filter(
|
||||
self.observation_filter, new_policy.observation_space.shape)
|
||||
|
||||
self.set_policy_mapping_fn(policy_mapping_fn)
|
||||
self.set_policies_to_train(policies_to_train)
|
||||
|
||||
return new_policy
|
||||
|
||||
@DeveloperAPI
|
||||
def remove_policy(
|
||||
self,
|
||||
*,
|
||||
policy_id: PolicyID = DEFAULT_POLICY_ID,
|
||||
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
|
||||
policies_to_train: Optional[List[PolicyID]] = None,
|
||||
):
|
||||
"""Removes a policy from this RolloutWorker.
|
||||
|
||||
Args:
|
||||
policy_id (Optional[PolicyID]): ID of the policy to be removed.
|
||||
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): An
|
||||
optional (updated) policy mapping function to use from here on.
|
||||
Note that already ongoing episodes will not change their
|
||||
mapping but will use the old mapping till the end of the
|
||||
episode.
|
||||
policies_to_train (Optional[List[PolicyID]]): An optional list of
|
||||
policy IDs to be trained. If None, will keep the existing list
|
||||
in place. Policies, whose IDs are not in the list will not be
|
||||
updated.
|
||||
"""
|
||||
if policy_id not in self.policy_map:
|
||||
raise ValueError(f"Policy ID '{policy_id}' not in policy map!")
|
||||
del self.policy_map[policy_id]
|
||||
del self.preprocessors[policy_id]
|
||||
self.set_policy_mapping_fn(policy_mapping_fn)
|
||||
self.set_policies_to_train(policies_to_train)
|
||||
|
||||
@DeveloperAPI
|
||||
def set_policy_mapping_fn(
|
||||
self,
|
||||
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
|
||||
):
|
||||
"""Sets `self.policy_mapping_fn` to a new callable (if provided).
|
||||
|
||||
Args:
|
||||
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): The
|
||||
new mapping function to use. If None, will keep the existing
|
||||
mapping function in place.
|
||||
"""
|
||||
if policy_mapping_fn is not None:
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
if not callable(self.policy_mapping_fn):
|
||||
raise ValueError("`policy_mapping_fn` must be a callable!")
|
||||
|
||||
@DeveloperAPI
|
||||
def set_policies_to_train(
|
||||
self, policies_to_train: Optional[List[PolicyID]] = None):
|
||||
"""Sets `self.policies_to_train` to a new list of PolicyIDs.
|
||||
|
||||
Args:
|
||||
policies_to_train (Optional[List[PolicyID]]): The new
|
||||
list of policy IDs to train with. If None, will keep the
|
||||
existing list in place.
|
||||
"""
|
||||
if policies_to_train is not None:
|
||||
self.policies_to_train = policies_to_train
|
||||
|
||||
@DeveloperAPI
|
||||
def for_policy(self,
|
||||
func: Callable[[Policy], T],
|
||||
|
@ -1213,12 +1096,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
objs = pickle.loads(objs)
|
||||
self.sync_filters(objs["filters"])
|
||||
for pid, state in objs["state"].items():
|
||||
if pid not in self.policy_map:
|
||||
logger.warning(
|
||||
f"pid={pid} not found in policy_map! It was probably added"
|
||||
" on-the-fly and is not part of the static `config."
|
||||
"multiagent.policies` dict. Ignoring it for now.")
|
||||
continue
|
||||
self.policy_map[pid].set_state(state)
|
||||
|
||||
@DeveloperAPI
|
||||
|
@ -1312,9 +1189,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
else:
|
||||
raise ValueError("This policy does not support eager "
|
||||
"execution: {}".format(cls))
|
||||
scope = name + (("_wk" + str(self.worker_index))
|
||||
if self.worker_index else "")
|
||||
with tf1.variable_scope(scope):
|
||||
with tf1.variable_scope(name):
|
||||
policy_map[name] = cls(obs_space, act_space, merged_conf)
|
||||
# non-tf.
|
||||
else:
|
||||
|
@ -1368,7 +1243,7 @@ def _validate_and_canonicalize(
|
|||
_validate_multiagent_config(policy)
|
||||
return policy
|
||||
elif not issubclass(policy, Policy):
|
||||
raise ValueError(f"`policy` ({policy}) must be a rllib.Policy class!")
|
||||
raise ValueError("policy must be a rllib.Policy class")
|
||||
else:
|
||||
if (isinstance(env, MultiAgentEnv)
|
||||
and not hasattr(env, "observation_space")):
|
||||
|
|
|
@ -25,7 +25,6 @@ from ray.rllib.offline import InputReader
|
|||
from ray.rllib.policy.policy import clip_action, Policy
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.filter import Filter
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.spaces.space_utils import unbatch
|
||||
|
@ -130,6 +129,10 @@ class SyncSampler(SamplerInput):
|
|||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
|
@ -143,11 +146,6 @@ class SyncSampler(SamplerInput):
|
|||
observation_fn: "ObservationFunction" = None,
|
||||
sample_collector_class: Optional[Type[SampleCollector]] = None,
|
||||
render: bool = False,
|
||||
# Obsolete.
|
||||
policies=None,
|
||||
policy_mapping_fn=None,
|
||||
preprocessors=None,
|
||||
obs_filters=None,
|
||||
):
|
||||
"""Initializes a SyncSampler object.
|
||||
|
||||
|
@ -155,6 +153,13 @@ class SyncSampler(SamplerInput):
|
|||
worker (RolloutWorker): The RolloutWorker that will use this
|
||||
Sampler for sampling.
|
||||
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
|
||||
policies (Dict[str,Policy]): Mapping from policy ID to Policy obj.
|
||||
policy_mapping_fn (callable): Callable that takes an agent ID and
|
||||
returns a Policy object.
|
||||
preprocessors (Dict[str,Preprocessor]): Mapping from policy ID to
|
||||
Preprocessor object for the observations prior to filtering.
|
||||
obs_filters (Dict[str,Filter]): Mapping from policy ID to
|
||||
env Filter object.
|
||||
clip_rewards (Union[bool,float]): True for +/-1.0 clipping, actual
|
||||
float value for +/- value clipping. False for no clipping.
|
||||
rollout_fragment_length (int): The length of a fragment to collect
|
||||
|
@ -184,27 +189,20 @@ class SyncSampler(SamplerInput):
|
|||
render (bool): Whether to try to render the environment after each
|
||||
step.
|
||||
"""
|
||||
# All of the following arguments are deprecated. They will instead be
|
||||
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
|
||||
if log_once("deprecated_sync_sampler_args"):
|
||||
if policies is not None:
|
||||
deprecation_warning(old="policies")
|
||||
if policy_mapping_fn is not None:
|
||||
deprecation_warning(old="policy_mapping_fn")
|
||||
if preprocessors is not None:
|
||||
deprecation_warning(old="preprocessors")
|
||||
if obs_filters is not None:
|
||||
deprecation_warning(old="obs_filters")
|
||||
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self.preprocessors = preprocessors
|
||||
self.obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.perf_stats = _PerfStats()
|
||||
if not sample_collector_class:
|
||||
sample_collector_class = SimpleListCollector
|
||||
self.sample_collector = sample_collector_class(
|
||||
worker.policy_map,
|
||||
policies,
|
||||
clip_rewards,
|
||||
callbacks,
|
||||
multiple_episodes_in_batch,
|
||||
|
@ -214,10 +212,11 @@ class SyncSampler(SamplerInput):
|
|||
|
||||
# Create the rollout generator to use for calls to `get_data()`.
|
||||
self.rollout_provider = _env_runner(
|
||||
worker, self.base_env, self.extra_batches.put,
|
||||
self.rollout_fragment_length, self.horizon, clip_rewards,
|
||||
clip_actions, multiple_episodes_in_batch, callbacks, tf_sess,
|
||||
self.perf_stats, soft_horizon, no_done_at_end, observation_fn,
|
||||
worker, self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
multiple_episodes_in_batch, callbacks, tf_sess, self.perf_stats,
|
||||
soft_horizon, no_done_at_end, observation_fn,
|
||||
self.sample_collector, self.render)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
|
@ -265,6 +264,10 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
*,
|
||||
worker: "RolloutWorker",
|
||||
env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
rollout_fragment_length: int,
|
||||
count_steps_by: str = "env_steps",
|
||||
|
@ -279,11 +282,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
observation_fn: "ObservationFunction" = None,
|
||||
sample_collector_class: Optional[Type[SampleCollector]] = None,
|
||||
render: bool = False,
|
||||
# Obsolete.
|
||||
policies=None,
|
||||
policy_mapping_fn=None,
|
||||
preprocessors=None,
|
||||
obs_filters=None,
|
||||
):
|
||||
"""Initializes a AsyncSampler object.
|
||||
|
||||
|
@ -291,6 +289,13 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
worker (RolloutWorker): The RolloutWorker that will use this
|
||||
Sampler for sampling.
|
||||
env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
|
||||
policies (Dict[str, Policy]): Mapping from policy ID to Policy obj.
|
||||
policy_mapping_fn (callable): Callable that takes an agent ID and
|
||||
returns a Policy object.
|
||||
preprocessors (Dict[str, Preprocessor]): Mapping from policy ID to
|
||||
Preprocessor object for the observations prior to filtering.
|
||||
obs_filters (Dict[str, Filter]): Mapping from policy ID to
|
||||
env Filter object.
|
||||
clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
|
||||
float value for +/- value clipping. False for no clipping.
|
||||
rollout_fragment_length (int): The length of a fragment to collect
|
||||
|
@ -324,24 +329,10 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
render (bool): Whether to try to render the environment after each
|
||||
step.
|
||||
"""
|
||||
# All of the following arguments are deprecated. They will instead be
|
||||
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
|
||||
if log_once("deprecated_async_sampler_args"):
|
||||
if policies is not None:
|
||||
deprecation_warning(old="policies")
|
||||
if policy_mapping_fn is not None:
|
||||
deprecation_warning(old="policy_mapping_fn")
|
||||
if preprocessors is not None:
|
||||
deprecation_warning(old="preprocessors")
|
||||
if obs_filters is not None:
|
||||
deprecation_warning(old="obs_filters")
|
||||
|
||||
self.worker = worker
|
||||
|
||||
for _, f in worker.filters.items():
|
||||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
|
||||
self.worker = worker
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
|
@ -349,6 +340,10 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
self.metrics_queue = queue.Queue()
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self.preprocessors = preprocessors
|
||||
self.obs_filters = obs_filters
|
||||
self.clip_rewards = clip_rewards
|
||||
self.daemon = True
|
||||
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
||||
|
@ -365,7 +360,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
if not sample_collector_class:
|
||||
sample_collector_class = SimpleListCollector
|
||||
self.sample_collector = sample_collector_class(
|
||||
worker.policy_map,
|
||||
policies,
|
||||
clip_rewards,
|
||||
callbacks,
|
||||
multiple_episodes_in_batch,
|
||||
|
@ -389,8 +384,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
extra_batches_putter = (
|
||||
lambda x: self.extra_batches.put(x, timeout=600.0))
|
||||
rollout_provider = _env_runner(
|
||||
self.worker, self.base_env, extra_batches_putter,
|
||||
self.rollout_fragment_length, self.horizon, self.clip_rewards,
|
||||
self.worker, self.base_env, extra_batches_putter, self.policies,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.multiple_episodes_in_batch, self.callbacks,
|
||||
self.tf_sess, self.perf_stats, self.soft_horizon,
|
||||
self.no_done_at_end, self.observation_fn, self.sample_collector,
|
||||
|
@ -443,8 +439,12 @@ def _env_runner(
|
|||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
extra_batch_callback: Callable[[SampleBatchType], None],
|
||||
policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID], PolicyID],
|
||||
rollout_fragment_length: int,
|
||||
horizon: int,
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
clip_rewards: bool,
|
||||
clip_actions: bool,
|
||||
multiple_episodes_in_batch: bool,
|
||||
|
@ -463,10 +463,19 @@ def _env_runner(
|
|||
worker (RolloutWorker): Reference to the current rollout worker.
|
||||
base_env (BaseEnv): Env implementing BaseEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
policies (Dict[PolicyID, Policy]): Map of policy ids to Policy
|
||||
instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
agent is then "bound" to the returned policy for the episode.
|
||||
rollout_fragment_length (int): Number of episode steps before
|
||||
`SampleBatch` is yielded. Set to infinity to yield complete
|
||||
episodes.
|
||||
horizon (int): Horizon of the episode.
|
||||
preprocessors (dict): Map of policy id to preprocessor for the
|
||||
observations prior to filtering.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be exactly
|
||||
|
@ -542,14 +551,14 @@ def _env_runner(
|
|||
|
||||
def new_episode(env_id):
|
||||
episode = MultiAgentEpisode(
|
||||
worker.policy_map,
|
||||
worker.policy_mapping_fn,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
get_batch_builder,
|
||||
extra_batch_callback,
|
||||
env_id=env_id)
|
||||
# Call each policy's Exploration.on_episode_start method.
|
||||
# type: Policy
|
||||
for p in worker.policy_map.values():
|
||||
for p in policies.values():
|
||||
if getattr(p, "exploration", None) is not None:
|
||||
p.exploration.on_episode_start(
|
||||
policy=p,
|
||||
|
@ -559,7 +568,7 @@ def _env_runner(
|
|||
callbacks.on_episode_start(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_id,
|
||||
)
|
||||
|
@ -590,12 +599,15 @@ def _env_runner(
|
|||
_process_observations(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=policies,
|
||||
active_episodes=active_episodes,
|
||||
unfiltered_obs=unfiltered_obs,
|
||||
rewards=rewards,
|
||||
dones=dones,
|
||||
infos=infos,
|
||||
horizon=horizon,
|
||||
preprocessors=preprocessors,
|
||||
obs_filters=obs_filters,
|
||||
multiple_episodes_in_batch=multiple_episodes_in_batch,
|
||||
callbacks=callbacks,
|
||||
soft_horizon=soft_horizon,
|
||||
|
@ -612,7 +624,7 @@ def _env_runner(
|
|||
# type: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
|
||||
eval_results = _do_policy_eval(
|
||||
to_eval=to_eval,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
sample_collector=sample_collector,
|
||||
active_episodes=active_episodes,
|
||||
tf_sess=tf_sess,
|
||||
|
@ -628,7 +640,7 @@ def _env_runner(
|
|||
active_episodes=active_episodes,
|
||||
active_envs=active_envs,
|
||||
off_policy_actions=off_policy_actions,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
clip_actions=clip_actions,
|
||||
)
|
||||
perf_stats.action_processing_time += time.time() - t3
|
||||
|
@ -667,12 +679,15 @@ def _process_observations(
|
|||
*,
|
||||
worker: "RolloutWorker",
|
||||
base_env: BaseEnv,
|
||||
policies: Dict[PolicyID, Policy],
|
||||
active_episodes: Dict[str, MultiAgentEpisode],
|
||||
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
|
||||
rewards: Dict[EnvID, Dict[AgentID, float]],
|
||||
dones: Dict[EnvID, Dict[AgentID, bool]],
|
||||
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
|
||||
horizon: int,
|
||||
preprocessors: Dict[PolicyID, Preprocessor],
|
||||
obs_filters: Dict[PolicyID, Filter],
|
||||
multiple_episodes_in_batch: bool,
|
||||
callbacks: "DefaultCallbacks",
|
||||
soft_horizon: bool,
|
||||
|
@ -686,6 +701,7 @@ def _process_observations(
|
|||
Args:
|
||||
worker (RolloutWorker): Reference to the current rollout worker.
|
||||
base_env (BaseEnv): Env implementing BaseEnv.
|
||||
policies (dict): Map of policy ids to Policy instances.
|
||||
batch_builder_pool (List[SampleBatchBuilder]): List of pooled
|
||||
SampleBatchBuilder object for recycling.
|
||||
active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
|
||||
|
@ -700,6 +716,10 @@ def _process_observations(
|
|||
infos (dict): Doubly keyed dict of env-ids -> agent ids ->
|
||||
info dicts, returned by a `BaseEnv.poll()` call.
|
||||
horizon (int): Horizon of the episode.
|
||||
preprocessors (dict): Map of policy id to preprocessor for the
|
||||
observations prior to filtering.
|
||||
obs_filters (dict): Map of policy id to filter used to process
|
||||
observations for the policy.
|
||||
rollout_fragment_length (int): Number of episode steps before
|
||||
`SampleBatch` is yielded. Set to infinity to yield complete
|
||||
episodes.
|
||||
|
@ -755,17 +775,6 @@ def _process_observations(
|
|||
dict(episode.agent_rewards),
|
||||
episode.custom_metrics, {},
|
||||
episode.hist_data, episode.media))
|
||||
# Check whether we have to create a fake-last observation
|
||||
# for some agents (the environment is not required to do so if
|
||||
# dones[__all__]=True).
|
||||
for ag_id in episode.get_agents():
|
||||
if not episode.last_done_for(
|
||||
ag_id) and ag_id not in all_agents_obs:
|
||||
# Create a fake (all-0s) observation.
|
||||
obs_sp = worker.policy_map[episode.policy_for(
|
||||
ag_id)].observation_space
|
||||
obs_sp = getattr(obs_sp, "original_space", obs_sp)
|
||||
all_agents_obs[ag_id] = np.zeros_like(obs_sp.sample())
|
||||
else:
|
||||
hit_horizon = False
|
||||
all_agents_done = False
|
||||
|
@ -777,7 +786,7 @@ def _process_observations(
|
|||
agent_obs=all_agents_obs,
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
episode=episode)
|
||||
if not isinstance(all_agents_obs, dict):
|
||||
raise ValueError(
|
||||
|
@ -798,18 +807,17 @@ def _process_observations(
|
|||
|
||||
policy_id: PolicyID = episode.policy_for(agent_id)
|
||||
|
||||
prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
|
||||
prep_obs: EnvObsType = _get_or_raise(preprocessors,
|
||||
policy_id).transform(raw_obs)
|
||||
if log_once("prep_obs"):
|
||||
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
|
||||
filtered_obs: EnvObsType = _get_or_raise(worker.filters,
|
||||
filtered_obs: EnvObsType = _get_or_raise(obs_filters,
|
||||
policy_id)(prep_obs)
|
||||
if log_once("filtered_obs"):
|
||||
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
|
||||
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_raw_obs(agent_id, raw_obs)
|
||||
episode._set_last_done(agent_id, agent_done)
|
||||
# Infos from the environment.
|
||||
agent_infos = infos[env_id].get(agent_id, {})
|
||||
episode._set_last_info(agent_id, agent_infos)
|
||||
|
@ -828,7 +836,7 @@ def _process_observations(
|
|||
# Action (slot 0) taken at timestep t.
|
||||
"actions": episode.last_action_for(agent_id),
|
||||
# Reward received after taking a at timestep t.
|
||||
"rewards": rewards[env_id].get(agent_id, 0.0),
|
||||
"rewards": rewards[env_id][agent_id],
|
||||
# After taking action=a, did we reach terminal?
|
||||
"dones": (False if (no_done_at_end
|
||||
or (hit_horizon and soft_horizon)) else
|
||||
|
@ -837,7 +845,7 @@ def _process_observations(
|
|||
"new_obs": filtered_obs,
|
||||
}
|
||||
# Add extra-action-fetches to collectors.
|
||||
pol = worker.policy_map[policy_id]
|
||||
pol = policies[policy_id]
|
||||
for key, value in episode.last_pi_info_for(agent_id).items():
|
||||
if key in pol.view_requirements:
|
||||
values_dict[key] = value
|
||||
|
@ -854,8 +862,8 @@ def _process_observations(
|
|||
if last_observation is None else
|
||||
episode.rnn_state_for(agent_id), None
|
||||
if last_observation is None else
|
||||
episode.last_action_for(agent_id), rewards[env_id].get(
|
||||
agent_id, 0.0))
|
||||
episode.last_action_for(agent_id),
|
||||
rewards[env_id][agent_id] or 0.0)
|
||||
to_eval[policy_id].append(item)
|
||||
|
||||
# Invoke the `on_episode_step` callback after the step is logged
|
||||
|
@ -890,7 +898,7 @@ def _process_observations(
|
|||
outputs.append(ma_sample_batch)
|
||||
|
||||
# Call each policy's Exploration.on_episode_end method.
|
||||
for p in worker.policy_map.values():
|
||||
for p in policies.values():
|
||||
if getattr(p, "exploration", None) is not None:
|
||||
p.exploration.on_episode_end(
|
||||
policy=p,
|
||||
|
@ -901,7 +909,7 @@ def _process_observations(
|
|||
callbacks.on_episode_end(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
episode=episode,
|
||||
env_index=env_id,
|
||||
)
|
||||
|
@ -928,15 +936,15 @@ def _process_observations(
|
|||
agent_obs=resetted_obs,
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
policies=worker.policy_map,
|
||||
policies=policies,
|
||||
episode=new_episode)
|
||||
# type: AgentID, EnvObsType
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id: PolicyID = new_episode.policy_for(agent_id)
|
||||
prep_obs: EnvObsType = _get_or_raise(
|
||||
worker.preprocessors, policy_id).transform(raw_obs)
|
||||
preprocessors, policy_id).transform(raw_obs)
|
||||
filtered_obs: EnvObsType = _get_or_raise(
|
||||
worker.filters, policy_id)(prep_obs)
|
||||
obs_filters, policy_id)(prep_obs)
|
||||
new_episode._set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Add initial obs to buffer.
|
||||
|
|
|
@ -514,8 +514,7 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
"pol0": (MockPolicy, obs_space, action_space, {}),
|
||||
"pol1": (MockPolicy, obs_space, action_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id, episode, **kwargs:
|
||||
"pol0" if agent_id == 0 else "pol1",
|
||||
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
|
||||
rollout_fragment_length=301,
|
||||
count_steps_by="env_steps",
|
||||
batch_mode="truncate_episodes",
|
||||
|
@ -532,8 +531,7 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
"pol0": (MockPolicy, obs_space, action_space, {}),
|
||||
"pol1": (MockPolicy, obs_space, action_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id, episode, **kwargs:
|
||||
"pol0" if agent_id == 0 else "pol1",
|
||||
policy_mapping_fn=lambda ag: "pol0" if ag == 0 else "pol1",
|
||||
rollout_fragment_length=301,
|
||||
count_steps_by="agent_steps",
|
||||
batch_mode="truncate_episodes")
|
||||
|
|
|
@ -218,7 +218,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
|
||||
}
|
||||
|
||||
def policy_fn(agent_id, episode, **kwargs):
|
||||
def policy_fn(agent_id):
|
||||
return "pol0"
|
||||
|
||||
config = {
|
||||
|
@ -269,7 +269,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
{}),
|
||||
}
|
||||
|
||||
def policy_fn(agent_id, episode, **kwargs):
|
||||
def policy_fn(agent_id):
|
||||
return "pol0"
|
||||
|
||||
config = {
|
||||
|
@ -309,7 +309,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
"p0": (None, obs_space, action_space, {}),
|
||||
"p1": (None, obs_space, action_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
|
||||
"policy_mapping_fn": lambda aid: "p{}".format(aid),
|
||||
"count_steps_by": "agent_steps",
|
||||
}
|
||||
tune.register_env(
|
||||
|
|
|
@ -109,7 +109,7 @@ class WorkerSet:
|
|||
return self._remote_workers
|
||||
|
||||
def sync_weights(self) -> None:
|
||||
"""Syncs weights from the local worker to all remote workers."""
|
||||
"""Syncs weights of remote workers with the local worker."""
|
||||
if self.remote_workers():
|
||||
weights = ray.put(self.local_worker().get_weights())
|
||||
for e in self.remote_workers():
|
||||
|
|
|
@ -239,8 +239,7 @@ if __name__ == "__main__":
|
|||
"framework": args.framework,
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
},
|
||||
"model": {
|
||||
"custom_model": "cc_model",
|
||||
|
|
|
@ -116,8 +116,7 @@ if __name__ == "__main__":
|
|||
"pol1": (None, observer_space, action_space, {}),
|
||||
"pol2": (None, observer_space, action_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
"observation_fn": central_critic_observer,
|
||||
},
|
||||
"model": {
|
||||
|
|
|
@ -45,7 +45,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
|
|||
None, AsymCoinGame(env_config).OBSERVATION_SPACE,
|
||||
AsymCoinGame.ACTION_SPACE, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
|
||||
"policy_mapping_fn": lambda agent_id: agent_id,
|
||||
},
|
||||
# Size of batches collected from each worker.
|
||||
"rollout_fragment_length": 20,
|
||||
|
|
|
@ -87,7 +87,7 @@ if __name__ == "__main__":
|
|||
else:
|
||||
maze = WindyMazeEnv(None)
|
||||
|
||||
def policy_mapping_fn(agent_id, episode, **kwargs):
|
||||
def policy_mapping_fn(agent_id):
|
||||
if agent_id.startswith("low_level_"):
|
||||
return "low_level_policy"
|
||||
else:
|
||||
|
|
|
@ -61,7 +61,7 @@ def get_rllib_config(seeds, debug=False, stop_iters=200, tf=False):
|
|||
None, IteratedPrisonersDilemma.OBSERVATION_SPACE,
|
||||
IteratedPrisonersDilemma.ACTION_SPACE, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
|
||||
"policy_mapping_fn": lambda agent_id: agent_id,
|
||||
},
|
||||
"seed": tune.grid_search(seeds),
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
|
|
|
@ -95,7 +95,7 @@ if __name__ == "__main__":
|
|||
}
|
||||
policy_ids = list(policies.keys())
|
||||
|
||||
def policy_mapping_fn(agent_id, episode, **kwargs):
|
||||
def policy_mapping_fn(agent_id):
|
||||
pol_id = random.choice(policy_ids)
|
||||
return pol_id
|
||||
|
||||
|
|
|
@@ -78,7 +78,7 @@ if __name__ == "__main__":
 "random": (RandomPolicy, obs_space, act_space, {}),
 },
 "policy_mapping_fn": (
-lambda aid, **kwargs: ["pg_policy", "random"][aid % 2]),
+lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
 },
 "framework": args.framework,
 # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
@@ -32,8 +32,7 @@ if __name__ == "__main__":
 # Method specific
 "multiagent": {
 "policies": policies,
-"policy_mapping_fn": (
-lambda agent_id, episode, **kwargs: agent_id),
+"policy_mapping_fn": (lambda agent_id: agent_id),
 },
 },
 )
@@ -51,8 +51,7 @@ if __name__ == "__main__":
 # Method specific
 "multiagent": {
 "policies": policies,
-"policy_mapping_fn": (
-lambda agent_id, episode, **kwargs: "shared_policy"),
+"policy_mapping_fn": (lambda agent_id: "shared_policy"),
 },
 },
 )
@@ -68,7 +68,7 @@ if __name__ == "__main__":
 DQNTFPolicy, obs_space, act_space, {}),
 }

-def policy_mapping_fn(agent_id, episode, **kwargs):
+def policy_mapping_fn(agent_id):
 if agent_id % 2 == 0:
 return "ppo_policy"
 else:
@@ -54,7 +54,7 @@ if __name__ == "__main__":
 # the first tuple value is None -> uses default policy
 "av": (None, obs_space, act_space, {}),
 },
-"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
+"policy_mapping_fn": lambda agent_id: "av"
 }

 # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
@@ -72,7 +72,7 @@ if __name__ == "__main__":
 }
 policy_ids = list(policies.keys())

-def policy_mapping_fn(agent_id, episode, **kwargs):
+def policy_mapping_fn(agent_id):
 pol_id = random.choice(policy_ids)
 return pol_id
@@ -75,7 +75,7 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
 beat_last heuristics.
 """

-def select_policy(agent_id, episode, **kwargs):
+def select_policy(agent_id):
 if agent_id == "player1":
 return "learned"
 else:
@@ -1,271 +0,0 @@
-"""Example showing how one can implement a simple self-play training workflow.
-
-Uses the open spiel adapter of RLlib with the "connect_four" game and
-a multi-agent setup with a "main" policy and n "main_v[x]" policies
-(x=version number), which are all at-some-point-frozen copies of
-"main". At the very beginning, "main" plays against RandomPolicy.
-
-Checks for the training progress after each training update via a custom
-callback. We simply measure the win rate of "main" vs the opponent
-("main_v[x]" or RandomPolicy at the beginning) by looking through the
-achieved rewards in the episodes in the train batch. If this win rate
-reaches some configurable threshold, we add a new policy to
-the policy map (a frozen copy of the current "main" one) and change the
-policy_mapping_fn to make new matches of "main" vs the just added one.
-
-After training for n iterations, a configurable number of episodes can
-be played by the user against the "main" agent on the command line.
-"""
-
-import argparse
-import numpy as np
-import os
-import pyspiel
-from open_spiel.python.rl_environment import Environment
-import sys
-
-import ray
-from ray import tune
-from ray.rllib.agents.callbacks import DefaultCallbacks
-from ray.rllib.agents.ppo import PPOTrainer
-from ray.rllib.examples.policy.random_policy import RandomPolicy
-from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
-from ray.tune import register_env
-
-OBS_SPACE = ACTION_SPACE = None
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-"--framework",
-choices=["tf", "tf2", "tfe", "torch"],
-default="tf",
-help="The DL framework specifier.")
-parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument(
-"--from-checkpoint",
-type=str,
-default=None,
-help="Full path to a checkpoint file for restoring a previously saved "
-"Trainer state.")
-parser.add_argument(
-"--stop-iters",
-type=int,
-default=200,
-help="Number of iterations to train.")
-parser.add_argument(
-"--stop-timesteps",
-type=int,
-default=1000000,
-help="Number of timesteps to train.")
-parser.add_argument(
-"--win-rate-threshold",
-type=float,
-default=0.95,
-help="Win-rate at which we setup another opponent by freezing the "
-"current main policy and playing against a uniform distribution "
-"of previously frozen 'main's from here on.")
-parser.add_argument(
-"--num-episodes-human-play",
-type=int,
-default=2,
-help="How many episodes to play against the user on the command "
-"line after training has finished.")
-args = parser.parse_args()
-
-
-def ask_user_for_action(time_step):
-"""Asks the user for a valid action on the command line and returns it.
-
-Re-queries the user until she picks a valid one.
-
-Args:
-time_step: The open spiel Environment time-step object.
-"""
-pid = time_step.observations["current_player"]
-legal_moves = time_step.observations["legal_actions"][pid]
-choice = -1
-while choice not in legal_moves:
-print("Choose an action from {}:".format(legal_moves))
-sys.stdout.flush()
-choice_str = input()
-try:
-choice = int(choice_str)
-except ValueError:
-continue
-return choice
-
-
-class SelfPlayCallback(DefaultCallbacks):
-def __init__(self):
-super().__init__()
-# 0=RandomPolicy, 1=1st main policy snapshot,
-# 2=2nd main policy snapshot, etc..
-self.current_opponent = 0
-
-def on_train_result(self, *, trainer, result, **kwargs):
-# Get the win rate for the train batch.
-# Note that normally, one should set up a proper evaluation config,
-# such that evaluation always happens on the already updated policy,
-# instead of on the already used train_batch.
-main_rew = result["hist_stats"].pop("policy_main_reward")
-opponent_rew = list(result["hist_stats"].values())[0]
-assert len(main_rew) == len(opponent_rew)
-won = 0
-for r_main, r_opponent in zip(main_rew, opponent_rew):
-if r_main > r_opponent:
-won += 1
-win_rate = won / len(main_rew)
-print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")
-# If win rate is good -> Snapshot current policy and play against
-# it next, keeping the snapshot fixed and only improving the "main"
-# policy.
-if win_rate > args.win_rate_threshold:
-self.current_opponent += 1
-new_pol_id = f"main_v{self.current_opponent}"
-print(f"adding new opponent to the mix ({new_pol_id}).")
-
-# Re-define the mapping function, such that "main" is forced
-# to play against any of the previously played policies
-# (excluding "random").
-def policy_mapping_fn(agent_id, episode, **kwargs):
-# agent_id = [0|1] -> policy depends on episode ID
-# This way, we make sure that both policies sometimes play
-# (start player) and sometimes agent1 (player to move 2nd).
-return "main" if episode.episode_id % 2 == agent_id \
-else "main_v{}".format(np.random.choice(
-list(range(1, self.current_opponent + 1))))
-
-new_policy = trainer.add_policy(
-policy_id=new_pol_id,
-policy_cls=type(trainer.get_policy("main")),
-observation_space=OBS_SPACE,
-action_space=ACTION_SPACE,
-config={},
-policy_mapping_fn=policy_mapping_fn,
-)
-
-# Set the weights of the new policy to the main policy.
-# We'll keep training the main policy, whereas `new_pol_id` will
-# remain fixed.
-main_state = trainer.get_policy("main").get_state()
-new_policy.set_state(main_state)
-# We need to sync the just copied local weights (from main policy)
-# to all the remote workers as well.
-trainer.workers.sync_weights()
-else:
-print("not good enough; will keep learning")
-
-
-if __name__ == "__main__":
-ray.init(num_cpus=args.num_cpus or None, include_dashboard=False)
-
-dummy_env = OpenSpielEnv(pyspiel.load_game("connect_four"))
-OBS_SPACE = dummy_env.observation_space
-ACTION_SPACE = dummy_env.action_space
-
-register_env("connect_four",
-lambda _: OpenSpielEnv(pyspiel.load_game("connect_four")))
-
-def policy_mapping_fn(agent_id, episode, **kwargs):
-# agent_id = [0|1] -> policy depends on episode ID
-# This way, we make sure that both policies sometimes play agent0
-# (start player) and sometimes agent1 (player to move 2nd).
-return "main" if episode.episode_id % 2 == agent_id else "random"
-
-config = {
-"env": "connect_four",
-"callbacks": SelfPlayCallback,
-"model": {
-"fcnet_hiddens": [512, 512],
-},
-"num_sgd_iter": 20,
-"num_envs_per_worker": 5,
-"multiagent": {
-# Initial policy map: Random and PPO. This will be expanded
-# to more policy snapshots taken from "main" against which "main"
-# will then play (instead of "random"). This is done in the
-# custom callback defined above (`SelfPlayCallback`).
-"policies": {
-# Our main policy, we'd like to optimize.
-"main": (None, OBS_SPACE, ACTION_SPACE, {}),
-# An initial random opponent to play against.
-"random": (RandomPolicy, OBS_SPACE, ACTION_SPACE, {}),
-},
-# Assign agent 0 and 1 randomly to the "main" policy or
-# to the opponent ("random" at first). Make sure (via episode_id)
-# that "main" always plays against "random" (and not against
-# another "main").
-"policy_mapping_fn": policy_mapping_fn,
-# Always just train the "main" policy.
-"policies_to_train": ["main"],
-},
-# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
-"framework": args.framework,
-}
-
-stop = {
-"timesteps_total": args.stop_timesteps,
-"training_iteration": args.stop_iters,
-}
-
-# Train the "main" policy to play really well using self-play.
-results = None
-if not args.from_checkpoint:
-results = tune.run(
-"PPO",
-config=config,
-stop=stop,
-checkpoint_at_end=True,
-checkpoint_freq=10,
-verbose=1)
-
-# Restore trained trainer (set to non-explore behavior) and play against
-# human on command line.
-if args.num_episodes_human_play > 0:
-num_episodes = 0
-trainer = PPOTrainer(config=dict(config, **{"explore": False}))
-if args.from_checkpoint:
-trainer.restore(args.from_checkpoint)
-else:
-trainer.restore(results.get_last_checkpoint())
-
-# Play from the command line against the trained agent
-# in an actual (non-RLlib-wrapped) open-spiel env.
-human_player = 1
-env = Environment("connect_four")
-
-while num_episodes < args.num_episodes_human_play:
-print("You play as {}".format("o" if human_player else "x"))
-time_step = env.reset()
-while not time_step.last():
-player_id = time_step.observations["current_player"]
-if player_id == human_player:
-action = ask_user_for_action(time_step)
-else:
-obs = np.array(
-time_step.observations["info_state"][player_id])
-action = trainer.compute_action(obs, policy_id="main")
-# In case computer chooses an invalid action, pick a
-# random one.
-legal = time_step.observations["legal_actions"][player_id]
-if action not in legal:
-action = np.random.choice(legal)
-time_step = env.step([action])
-print(f"\n{env.get_state}")
-
-print(f"\n{env.get_state}")
-
-print("End of game!")
-if time_step.rewards[human_player] > 0:
-print("You win")
-elif time_step.rewards[human_player] < 0:
-print("You lose")
-else:
-print("Draw")
-# Switch order of players
-human_player = 1 - human_player
-
-num_episodes += 1
-
-ray.shutdown()
@@ -140,8 +140,7 @@ if __name__ == "__main__":
 marl_env.get_action_space(agent),
 agent_policy_params)
 config["multiagent"]["policies"] = policies
-config["multiagent"][
-"policy_mapping_fn"] = lambda agent_id, episode, **kwargs: agent_id
+config["multiagent"]["policy_mapping_fn"] = lambda agent_id: agent_id
 config["multiagent"]["policies_to_train"] = ["ppo_policy"]

 config["env"] = "sumo_test_env"
@@ -100,8 +100,7 @@ if __name__ == "__main__":
 "agent_id": 1,
 }),
 },
-"policy_mapping_fn": (
-lambda aid, **kwargs: "pol2" if aid else "pol1"),
+"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
 },
 "framework": args.framework,
 # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
@@ -137,7 +137,7 @@ if __name__ == "__main__":
 obs_space, act_space, DQN_CONFIG),
 }

-def policy_mapping_fn(agent_id, episode, **kwargs):
+def policy_mapping_fn(agent_id):
 if agent_id % 2 == 0:
 return "ppo_policy"
 else:
@@ -48,8 +48,7 @@ class TrainOneStep:
 num_sgd_iter: int = 1,
 sgd_minibatch_size: int = 0):
 self.workers = workers
-self.local_worker = workers.local_worker()
-self.policies = policies
+self.policies = policies or workers.local_worker().policies_to_train
 self.num_sgd_iter = num_sgd_iter
 self.sgd_minibatch_size = sgd_minibatch_size
@@ -62,11 +61,9 @@ class TrainOneStep:
 if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
 lw = self.workers.local_worker()
 info = do_minibatch_sgd(
-batch, {
-pid: lw.get_policy(pid)
-for pid in self.policies
-or self.local_worker.policies_to_train
-}, lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
+batch, {pid: lw.get_policy(pid)
+for pid in self.policies}, lw, self.num_sgd_iter,
+self.sgd_minibatch_size, [])
 # TODO(ekl) shouldn't be returning learner stats directly here
 # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
 # callback (shouldn't).
@@ -87,7 +84,7 @@ class TrainOneStep:
 if self.workers.remote_workers():
 with metrics.timers[WORKER_UPDATE_TIMER]:
 weights = ray.put(self.workers.local_worker().get_weights(
-self.policies or self.local_worker.policies_to_train))
+self.policies))
 for e in self.workers.remote_workers():
 e.set_weights.remote(weights, _get_global_vars())
 # Also update global vars of the local worker.
@@ -122,8 +119,7 @@ class TrainTFMultiGPU:
 _fake_gpus: bool = False,
 framework: str = "tf"):
 self.workers = workers
-self.local_worker = workers.local_worker()
-self.policies = policies
+self.policies = policies or workers.local_worker().policies_to_train
 self.num_sgd_iter = num_sgd_iter
 self.sgd_minibatch_size = sgd_minibatch_size
 self.shuffle_sequences = shuffle_sequences
@@ -154,8 +150,7 @@ class TrainTFMultiGPU:
 self.optimizers = {}
 with self.workers.local_worker().tf_sess.graph.as_default():
 with self.workers.local_worker().tf_sess.as_default():
-for policy_id in (self.policies
-or self.local_worker.policies_to_train):
+for policy_id in self.policies:
 policy = self.workers.local_worker().get_policy(policy_id)
 with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
 if policy._state_inputs:
@@ -178,7 +173,7 @@ class TrainTFMultiGPU:
 samples: SampleBatchType) -> (SampleBatchType, List[dict]):
 _check_sample_batch_type(samples)

-# Handle everything as if multi agent
+# Handle everything as if multiagent
 if isinstance(samples, SampleBatch):
 samples = MultiAgentBatch({
 DEFAULT_POLICY_ID: samples
@@ -192,8 +187,7 @@ class TrainTFMultiGPU:
 num_loaded_tuples = {}
 for policy_id, batch in samples.policy_batches.items():
 # Not a policy-to-train.
-if policy_id not in (self.policies
-or self.local_worker.policies_to_train):
+if policy_id not in self.policies:
 continue

 # Decompress SampleBatch, in case some columns are compressed.
@@ -251,7 +245,7 @@ class TrainTFMultiGPU:
 if self.workers.remote_workers():
 with metrics.timers[WORKER_UPDATE_TIMER]:
 weights = ray.put(self.workers.local_worker().get_weights(
-self.policies or self.local_worker.policies_to_train))
+self.policies))
 for e in self.workers.remote_workers():
 e.set_weights.remote(weights, _get_global_vars())
 # Also update global vars of the local worker.
@@ -321,8 +315,7 @@ class ApplyGradients:
 currently processing (i.e., A3C style).
 """
 self.workers = workers
-self.local_worker = workers.local_worker()
-self.policies = policies
+self.policies = policies or workers.local_worker().policies_to_train
 self.update_all = update_all

 def __call__(self, item: Tuple[ModelGradients, int]) -> None:
@@ -346,7 +339,7 @@ class ApplyGradients:
 if self.workers.remote_workers():
 with metrics.timers[WORKER_UPDATE_TIMER]:
 weights = ray.put(self.workers.local_worker().get_weights(
-self.policies or self.local_worker.policies_to_train))
+self.policies))
 for e in self.workers.remote_workers():
 e.set_weights.remote(weights, _get_global_vars())
 else:
@@ -357,7 +350,7 @@ class ApplyGradients:
 "in the iterator context.")
 with metrics.timers[WORKER_UPDATE_TIMER]:
 weights = self.workers.local_worker().get_weights(
-self.policies or self.local_worker.policies_to_train)
+self.policies)
 metrics.current_actor.set_weights.remote(
 weights, _get_global_vars())
@@ -414,9 +407,8 @@ class UpdateTargetNetwork:
 by_steps_trained: bool = False,
 policies: List[PolicyID] = frozenset([])):
 self.workers = workers
-self.local_worker = workers.local_worker()
 self.target_update_freq = target_update_freq
-self.policies = policies
+self.policies = (policies or workers.local_worker().policies_to_train)
 if by_steps_trained:
 self.metric = STEPS_TRAINED_COUNTER
 else:
@@ -427,7 +419,7 @@ class UpdateTargetNetwork:
 cur_ts = metrics.counters[self.metric]
 last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
 if cur_ts - last_update > self.target_update_freq:
-to_update = self.policies or self.local_worker.policies_to_train
+to_update = self.policies
 self.workers.local_worker().foreach_trainable_policy(
 lambda p, p_id: p_id in to_update and p.update_target())
 metrics.counters[NUM_TARGET_UPDATES] += 1
@@ -55,7 +55,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2),
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 rollout_fragment_length=50)
 batch = ev.sample()
 self.assertEqual(batch.count, 50)
@@ -196,7 +196,7 @@ class AgentIOTest(unittest.TestCase):
 "policy_2": gen_policy(),
 },
 "policy_mapping_fn": (
-lambda aid, **kwargs: random.choice(
+lambda agent_id: random.choice(
 ["policy_1", "policy_2"])),
 },
 "framework": fw,
@@ -218,7 +218,7 @@ class AgentIOTest(unittest.TestCase):
 "policy_2": gen_policy(),
 },
 "policy_mapping_fn": (
-lambda aid, **kwargs: random.choice(
+lambda agent_id: random.choice(
 ["policy_1", "policy_2"])),
 },
 "framework": fw,
@@ -77,15 +77,20 @@ class TestMultiAgentEnv(unittest.TestCase):
 env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 2)
 obs, rew, dones, _, _ = env.poll()
 self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
-self.assertEqual(rew, {0: {}, 1: {}})
-self.assertEqual(dones, {
-0: {
-"__all__": False
-},
-1: {
-"__all__": False
-},
-})
+self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
+self.assertEqual(
+dones, {
+0: {
+0: False,
+1: False,
+"__all__": False
+},
+1: {
+0: False,
+1: False,
+"__all__": False
+}
+})
 for _ in range(24):
 env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
 obs, rew, dones, _, _ = env.poll()
@@ -156,7 +161,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 env = _MultiAgentEnvToBaseEnv(lambda v: RoundRobinMultiAgent(2), [], 2)
 obs, rew, dones, _, _ = env.poll()
 self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
-self.assertEqual(rew, {0: {}, 1: {}})
+self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
 env.send_actions({0: {0: 0}, 1: {0: 0}})
 obs, rew, dones, _, _ = env.poll()
 self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
@@ -167,17 +172,13 @@ class TestMultiAgentEnv(unittest.TestCase):
 def test_multi_agent_sample(self):
 act_space = gym.spaces.Discrete(2)
 obs_space = gym.spaces.Discrete(2)
-
-def policy_mapping_fn(agent_id, episode, **kwargs):
-return "p{}".format(agent_id % 2)
-
 ev = RolloutWorker(
 env_creator=lambda _: BasicMultiAgent(5),
 policy_spec={
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=policy_mapping_fn,
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 rollout_fragment_length=50)
 batch = ev.sample()
 self.assertEqual(batch.count, 50)
@@ -197,10 +198,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-# This signature will raise a soft-deprecation warning due
-# to the new signature we are using (agent_id, episode, **kwargs),
-# but should not break this test.
-policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 rollout_fragment_length=50,
 num_envs=4,
 remote_worker_envs=True,
@@ -219,7 +217,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 rollout_fragment_length=50,
 num_envs=4,
 remote_worker_envs=True)
@@ -235,7 +233,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=(lambda aid, **kwarg: "p{}".format(aid % 2)),
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 episode_horizon=10,  # test with episode horizon set
 rollout_fragment_length=50)
 batch = ev.sample()
@@ -250,23 +248,12 @@ class TestMultiAgentEnv(unittest.TestCase):
 "p0": (MockPolicy, obs_space, act_space, {}),
 "p1": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
+policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
 batch_mode="complete_episodes",
 rollout_fragment_length=1)
-# This used to raise an Error due to the EarlyDoneMultiAgent
-# terminating at e.g. agent0 w/o publishing the observation for
-# agent1 anymore. This limitation is fixed and an env may
-# terminate at any time (as well as return rewards for any agent
-# at any time, even when that agent doesn't have an obs returned
-# in the same call to `step()`).
-ma_batch = ev.sample()
-# Make sure that agents took the correct (alternating timesteps)
-# path. Except for the last timestep, where both agents got
-# terminated.
-ag0_ts = ma_batch.policy_batches["p0"]["t"]
-ag1_ts = ma_batch.policy_batches["p1"]["t"]
-self.assertTrue(np.all(np.abs(ag0_ts[:-1] - ag1_ts[:-1]) == 1.0))
-self.assertTrue(ag0_ts[-1] == ag1_ts[-1])
+self.assertRaisesRegexp(ValueError,
+".*don't have a last observation.*",
+lambda: ev.sample())

 def test_multi_agent_with_flex_agents(self):
 register_env("flex_agents_multi_agent_cartpole",
@@ -290,7 +277,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 policy_spec={
 "p0": (MockPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+policy_mapping_fn=lambda agent_id: "p0",
 rollout_fragment_length=50)
 batch = ev.sample()
 self.assertEqual(batch.count, 50)
@@ -353,7 +340,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 # the extra trajectory.
 env_id = episodes[0].env_id
 fake_eps = MultiAgentEpisode(
-episodes[0].policy_map, episodes[0]._policy_mapping_fn,
+episodes[0]._policies, episodes[0]._policy_mapping_fn,
 lambda: None, lambda x: None, env_id)
 builder = get_global_worker().sampler.sample_collector
 agent_id = "extra_0"
@@ -390,7 +377,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 "p0": (ModelBasedPolicy, obs_space, act_space, {}),
 "p1": (ModelBasedPolicy, obs_space, act_space, {}),
 },
-policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+policy_mapping_fn=lambda agent_id: "p0",
 rollout_fragment_length=5)
 batch = ev.sample()
 # 5 environment steps (rollout_fragment_length).
@@ -443,7 +430,7 @@ class TestMultiAgentEnv(unittest.TestCase):
 "policy_1": gen_policy(),
 "policy_2": gen_policy(),
 },
-"policy_mapping_fn": lambda aid, **kwargs: "policy_1",
+"policy_mapping_fn": lambda agent_id: "policy_1",
 },
 "framework": "tf",
 })
@@ -457,9 +457,9 @@ class NestedSpacesTest(unittest.TestCase):
 PGTFPolicy, DICT_SPACE, act_space,
 {"model": {"custom_model": "dict_spy"}}),
 },
-"policy_mapping_fn": lambda aid, **kwargs: {
+"policy_mapping_fn": lambda a: {
 "tuple_agent": "tuple_policy",
-"dict_agent": "dict_policy"}[aid],
+"dict_agent": "dict_policy"}[a],
 },
 "framework": "tf",
 })
@@ -34,7 +34,7 @@ class TestPettingZooEnv(unittest.TestCase):
 # the first tuple value is None -> uses default policy
 "av": (None, obs_space, act_space, {}),
 },
-"policy_mapping_fn": lambda agent_id, episode, **kwargs: "av"
+"policy_mapping_fn": lambda agent_id: "av"
 }

 config["log_level"] = "DEBUG"
@@ -147,8 +147,8 @@ def learn_test_multi_agent_plus_rollout(algo):
 print("RLlib dir = {}\nexists={}".format(rllib_dir,
 os.path.exists(rllib_dir)))

-def policy_fn(agent_id, episode, **kwargs):
-return "pol{}".format(agent_id)
+def policy_fn(agent):
+return "pol{}".format(agent)

 observation_space = Box(float("-inf"), float("inf"), (4, ))
 action_space = Discrete(2)