diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py
index ab2cdef5d..54babf148 100644
--- a/rllib/agents/tests/test_trainer.py
+++ b/rllib/agents/tests/test_trainer.py
@@ -1,4 +1,3 @@
-import gym
 from random import choice
 import unittest
 
@@ -20,8 +19,6 @@ class TestTrainer(unittest.TestCase):
         ray.shutdown()
 
     def test_add_delete_policy(self):
-        env = gym.make("CartPole-v0")
-
         config = pg.DEFAULT_CONFIG.copy()
         config.update({
             "env": MultiAgentCartPole,
@@ -30,34 +27,30 @@ class TestTrainer(unittest.TestCase):
                     "num_agents": 4,
                 },
             },
+            "num_workers": 2,  # Test on remote workers as well.
             "multiagent": {
                 # Start with a single policy.
-                "policies": {
-                    "p0": (None, env.observation_space, env.action_space, {}),
-                },
+                "policies": {"p0"},
                 "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
                 "policy_map_capacity": 2,
             },
         })
 
-        # TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
-        # refactor PR merged.
-        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
+        for _ in framework_iterator(config):
            trainer = pg.PGTrainer(config=config)
            r = trainer.train()
            self.assertTrue("p0" in r["policy_reward_min"])
-            for i in range(1, 4):
+            checkpoints = []
+            for i in range(1, 3):
 
                 def new_mapping_fn(agent_id, episode, **kwargs):
                     return f"p{choice([i, i - 1])}"
 
                 # Add a new policy.
+                pid = f"p{i}"
                 new_pol = trainer.add_policy(
-                    f"p{i}",
+                    pid,
                     trainer._policy_class,
-                    observation_space=env.observation_space,
-                    action_space=env.action_space,
-                    config={},
                     # Test changing the mapping fn.
                     policy_mapping_fn=new_mapping_fn,
                     # Change the list of policies to train.
@@ -70,9 +63,22 @@ class TestTrainer(unittest.TestCase):
                 self.assertTrue(len(pol_map) == i + 1)
                 r = trainer.train()
                 self.assertTrue("p1" in r["policy_reward_min"])
+                checkpoints.append(trainer.save())
+
+            # Test restoring from the checkpoint (which has more policies
+            # than what's defined in the config dict).
+            test = pg.PGTrainer(config=config)
+            test.restore(checkpoints[-1])
+            test.train()
+            # Test creating an action with the added (and restored) policy.
+            a = test.compute_single_action(
+                test.get_policy("p0").observation_space.sample(),
+                policy_id=pid)
+            self.assertTrue(test.get_policy("p0").action_space.contains(a))
+            test.stop()
 
             # Delete all added policies again from trainer.
-            for i in range(3, 0, -1):
+            for i in range(2, 0, -1):
                 trainer.remove_policy(
                     f"p{i}",
                     policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
@@ -130,7 +136,7 @@ class TestTrainer(unittest.TestCase):
 
         # Try again using `create_env_on_driver=True`.
         # This force-adds the env on the local-worker, so this Trainer
-        # can `evaluate` even though, it doesn't have an evaluation-worker
+        # can `evaluate` even though it doesn't have an evaluation-worker
         # set.
         config["create_env_on_driver"] = True
         trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index e81d0139a..ef5b4fe06 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -951,7 +951,7 @@ class Trainer(Trainable):
             unsquash_actions: Optional[bool] = None,
             clip_actions: Optional[bool] = None,
     ) -> TensorStructType:
-        """Computes an action for the specified policy on the local Worker.
+        """Computes an action for the specified policy on the local worker.
 
         Note that you can also access the policy object through
         self.get_policy(policy_id) and call compute_single_action() on it
@@ -982,17 +982,31 @@ class Trainer(Trainable):
             any: The computed action if full_fetch=False, or
             tuple: The full output of policy.compute_actions() if
                 full_fetch=True or we have an RNN-based Policy.
+
+        Raises:
+            KeyError: If the `policy_id` cannot be found in this Trainer's
+                local worker.
         """
+        policy = self.get_policy(policy_id)
+        if policy is None:
+            raise KeyError(
+                f"PolicyID '{policy_id}' not found in PolicyMap of the "
+                f"Trainer's local worker!")
+
+        local_worker = self.workers.local_worker()
+
         if state is None:
             state = []
 
+        # Check the preprocessor and preprocess, if necessary.
-        pp = self.workers.local_worker().preprocessors[policy_id]
+        pp = local_worker.preprocessors[policy_id]
         if type(pp).__name__ != "NoPreprocessor":
             observation = pp.transform(observation)
-        filtered_observation = self.workers.local_worker().filters[policy_id](
+        filtered_observation = local_worker.filters[policy_id](
             observation, update=False)
 
-        result = self.get_policy(policy_id).compute_single_action(
+        # Compute the action.
+        result = policy.compute_single_action(
             filtered_observation,
             state,
             prev_action,
@@ -1002,10 +1016,12 @@ class Trainer(Trainable):
             clip_actions=clip_actions,
             explore=explore)
 
+        # Return 3-Tuple: Action, states, and extra-action fetches.
         if state or full_fetch:
             return result
+        # Ensure backward compatibility.
         else:
-            return result[0]  # backwards compatibility
+            return result[0]
 
     @Deprecated(new="compute_single_action", error=False)
     def compute_action(self, *args, **kwargs):
@@ -1193,7 +1209,6 @@ class Trainer(Trainable):
                 observation_space=observation_space,
                 action_space=action_space,
                 config=config,
-                policy_config=self.config,
                 policy_mapping_fn=policy_mapping_fn,
                 policies_to_train=policies_to_train,
             )
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 38313dac5..2db6b2ffa 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -2,13 +2,13 @@ import random
 import numpy as np
 import gym
 import logging
-import pickle
 import platform
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
     TYPE_CHECKING, Union
 
 import ray
+from ray import cloudpickle as pickle
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.external_env import ExternalEnv
@@ -29,7 +29,6 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
 from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.policy_map import PolicyMap
-from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -1057,7 +1056,6 @@ class RolloutWorker(ParallelIteratorWorker):
             observation_space: Optional[gym.spaces.Space] = None,
             action_space: Optional[gym.spaces.Space] = None,
             config: Optional[PartialTrainerConfigDict] = None,
-            policy_config: Optional[TrainerConfigDict] = None,
             policy_mapping_fn: Optional[Callable[
                 [AgentID, "MultiAgentEpisode"], PolicyID]] = None,
             policies_to_train: Optional[List[PolicyID]] = None,
@@ -1095,14 +1093,16 @@ class RolloutWorker(ParallelIteratorWorker):
         policy_dict = _determine_spaces_for_multi_agent_dict(
             {
                 policy_id: PolicySpec(policy_cls, observation_space,
-                                      action_space, config)
+                                      action_space, config or {})
             },
             self.env,
             spaces=self.spaces,
-            policy_config=policy_config)
-
+            policy_config=self.policy_config,
+        )
         self._build_policy_map(
-            policy_dict, policy_config, seed=policy_config.get("seed"))
+            policy_dict,
+            self.policy_config,
+            seed=self.policy_config.get("seed"))
         new_policy = self.policy_map[policy_id]
 
         self.filters[policy_id] = get_filter(
@@ -1244,11 +1244,16 @@ class RolloutWorker(ParallelIteratorWorker):
     @DeveloperAPI
     def save(self) -> bytes:
         filters = self.get_filters(flush_after=True)
-        state = {
-            pid: self.policy_map[pid].get_state()
-            for pid in self.policy_map
-        }
-        return pickle.dumps({"filters": filters, "state": state})
+        state = {}
+        policy_specs = {}
+        for pid in self.policy_map:
+            state[pid] = self.policy_map[pid].get_state()
+            policy_specs[pid] = self.policy_map.policy_specs[pid]
+        return pickle.dumps({
+            "filters": filters,
+            "state": state,
+            "policy_specs": policy_specs,
+        })
 
     @DeveloperAPI
     def restore(self, objs: bytes) -> None:
@@ -1256,12 +1261,23 @@
         self.sync_filters(objs["filters"])
         for pid, state in objs["state"].items():
             if pid not in self.policy_map:
-                logger.warning(
-                    f"pid={pid} not found in policy_map! It was probably added"
-                    " on-the-fly and is not part of the static `config."
-                    "multiagent.policies` dict. Ignoring it for now.")
-                continue
-            self.policy_map[pid].set_state(state)
+                pol_spec = objs.get("policy_specs", {}).get(pid)
+                if not pol_spec:
+                    logger.warning(
+                        f"PolicyID '{pid}' was probably added on-the-fly (not"
+                        " part of the static `multiagent.policies` config) and"
+                        " no PolicySpec objects were found in the pickled "
+                        f"policy state. Will not add `{pid}`, but ignore it "
+                        "for now.")
+                else:
+                    self.add_policy(
+                        policy_id=pid,
+                        policy_cls=pol_spec.policy_class,
+                        observation_space=pol_spec.observation_space,
+                        action_space=pol_spec.action_space,
+                        config=pol_spec.config,
+                    )
+            else:
+                self.policy_map[pid].set_state(state)
 
     @DeveloperAPI
     def set_global_vars(self, global_vars: dict) -> None:
@@ -1480,10 +1496,3 @@
             "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
             "function returned {} ({}).".format(env, type(env)))
     return env
-
-
-def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool:
-    for policy, _, _, _ in policy_dict.values():
-        if issubclass(policy, TFPolicy):
-            return True
-    return False
diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py
index 38073abe1..d5ebb6e43 100644
--- a/rllib/evaluation/sampler.py
+++ b/rllib/evaluation/sampler.py
@@ -554,8 +554,12 @@ def _env_runner(
             extra_batch_callback,
             env_id=env_id)
         # Call each policy's Exploration.on_episode_start method.
-        # types: Policy
-        for p in worker.policy_map.values():
+        # Note: This may break the exploration (e.g. ParameterNoise) of
+        # policies in the `policy_map` that have not been recently used
+        # (and are therefore stashed to disk). However, we certainly do not
+        # want to loop through all (even stashed) policies here as that
+        # would defeat the purpose of the LRU policy caching.
+        for p in worker.policy_map.cache.values():
             if getattr(p, "exploration", None) is not None:
                 p.exploration.on_episode_start(
                     policy=p,
@@ -902,8 +906,14 @@ def _process_observations(
             if ma_sample_batch:
                 outputs.append(ma_sample_batch)
 
-            # Call each policy's Exploration.on_episode_end method.
-            for p in worker.policy_map.values():
+            # Call each (in-memory) policy's Exploration.on_episode_end
+            # method.
+            # Note: This may break the exploration (e.g. ParameterNoise) of
+            # policies in the `policy_map` that have not been recently used
+            # (and are therefore stashed to disk). However, we certainly do not
+            # want to loop through all (even stashed) policies here as that
+            # would defeat the purpose of the LRU policy caching.
+            for p in worker.policy_map.cache.values():
                 if getattr(p, "exploration", None) is not None:
                     p.exploration.on_episode_end(
                         policy=p,
diff --git a/rllib/examples/self_play_with_open_spiel.py b/rllib/examples/self_play_with_open_spiel.py
index 5e292d040..86eca7f8b 100644
--- a/rllib/examples/self_play_with_open_spiel.py
+++ b/rllib/examples/self_play_with_open_spiel.py
@@ -41,6 +41,7 @@ parser.add_argument(
     default="tf",
     help="The DL framework specifier.")
 parser.add_argument("--num-cpus", type=int, default=0)
+parser.add_argument("--num-workers", type=int, default=2)
 parser.add_argument(
     "--from-checkpoint",
     type=str,
@@ -197,6 +198,7 @@ if __name__ == "__main__":
             # Always just train the "main" policy.
             "policies_to_train": ["main"],
         },
+        "num_workers": args.num_workers,
         # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
         "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "framework": args.framework,
diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py
index 5a8a960dc..5a0457126 100644
--- a/rllib/policy/policy_map.py
+++ b/rllib/policy/policy_map.py
@@ -2,7 +2,7 @@ from collections import deque
 import gym
 import os
 import pickle
-from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
+from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING
 
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.annotations import override
@@ -39,10 +39,19 @@ class PolicyMap(dict):
         """Initializes a PolicyMap instance.
 
         Args:
-            maxlen (int): The maximum number of policies to hold in memory.
+            worker_index (int): The worker index of the RolloutWorker this map
+                resides in.
+            num_workers (int): The total number of remote workers in the
+                WorkerSet to which this map's RolloutWorker belongs.
+            capacity (int): The maximum number of policies to hold in memory.
                 The least used ones are written to disk/S3 and retrieved
                 when needed.
-            path (str):
+            path (str): The path under which to store the policy pickle files.
+                Files will have the name: [policy_id].[worker idx].policy.pkl.
+            policy_config (TrainerConfigDict): The Trainer's base config dict.
+            session_creator (Optional[Callable[[], tf1.Session]]): An optional
+                tf1.Session creation callable.
+            seed (int): An optional seed (used to seed tf policies).
         """
 
         super().__init__()
@@ -53,22 +62,22 @@ class PolicyMap(dict):
 
         # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
-        self.extension = ".policy.pkl"
+        self.extension = f".{self.worker_index}.policy.pkl"
 
         # Dictionary of keys that may be looked up (cached or not).
-        self.valid_keys = set()
+        self.valid_keys: Set[str] = set()
         # The actual cache with the in-memory policy objects.
-        self.cache = {}
+        self.cache: Dict[str, Policy] = {}
         # The doubly-linked list holding the currently in-memory objects.
         self.deque = deque(maxlen=capacity or 10)
         # The file path where to store overflowing policies.
         self.path = path or "."
         # The core config to use. Each single policy's config override is
         # added on top of this.
-        self.policy_config = policy_config or {}
+        self.policy_config: TrainerConfigDict = policy_config or {}
         # The orig classes/obs+act spaces, and config overrides of the
         # Policies.
-        self.policy_specs = {}  # type: Dict[PolicyID, PolicySpec]
+        self.policy_specs: Dict[PolicyID, PolicySpec] = {}
 
     def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                       observation_space: gym.Space, action_space: gym.Space,
@@ -140,7 +149,7 @@ class PolicyMap(dict):
     def __getitem__(self, item):
         # Never seen this key -> Error.
         if item not in self.valid_keys:
-            raise KeyError(f"'{item}' not a valid key!")
+            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")
 
         # Item already in cache -> Rearrange deque (least recently used) and
         # return.
@@ -250,9 +259,7 @@ class PolicyMap(dict):
         policy_state = policy.get_state()
         # Closes policy's tf session, if any.
         self._close_session(policy)
-        # Remove from memory.
-        # TODO: (sven) This should clear the tf Graph as well, if the Trainer
-        # would not hold parts of the graph (e.g. in tf multi-GPU setups).
+        # Remove from memory. This will clear the tf Graph as well.
         del self.cache[delkey]
         # Write state to disk.
         with open(self.path + "/" + delkey + self.extension, "wb") as f:
diff --git a/rllib/tests/test_multi_agent_env.py b/rllib/tests/test_multi_agent_env.py
index ae2c63b76..c19bc12d4 100644
--- a/rllib/tests/test_multi_agent_env.py
+++ b/rllib/tests/test_multi_agent_env.py
@@ -457,8 +457,9 @@ class TestMultiAgentEnv(unittest.TestCase):
         self.assertTrue(
             pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
             [0, 1])
-        self.assertRaises(
+        self.assertRaisesRegex(
             KeyError,
+            "not found in PolicyMap",
             lambda: pg.compute_single_action(
                 [0, 0, 0, 0], policy_id="policy_3"))
 
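A minimal usage sketch of the workflow this patch enables (adding a policy on the fly, checkpointing, and restoring it into a Trainer that was configured with fewer policies). It is illustrative only and mirrors test_add_delete_policy above; the MultiAgentCartPole import path, the integer agent IDs, and the 2-agent/1-worker setup are assumptions rather than details taken from the diff.

# Illustrative sketch only; mirrors the new test above. The import path for
# MultiAgentCartPole and the integer agent IDs are assumptions.
import ray
import ray.rllib.agents.pg as pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

ray.init()

config = pg.DEFAULT_CONFIG.copy()
config.update({
    "env": MultiAgentCartPole,
    "env_config": {"config": {"num_agents": 2}},
    "num_workers": 1,
    "multiagent": {
        # Start with a single policy; its spaces are inferred from the env.
        "policies": {"p0"},
        "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
    },
})

trainer = pg.PGTrainer(config=config)
trainer.train()

# Add a second policy on the fly; no spaces or config dict are required.
trainer.add_policy(
    "p1",
    trainer._policy_class,
    policy_mapping_fn=lambda aid, episode, **kwargs: f"p{aid % 2}",
    policies_to_train=["p0", "p1"],
)
checkpoint = trainer.save()

# A fresh Trainer that only knows "p0" from its config can restore the
# checkpoint; "p1" is re-created from the PolicySpec stored by
# RolloutWorker.save().
restored = pg.PGTrainer(config=config)
restored.restore(checkpoint)
obs = restored.get_policy("p1").observation_space.sample()
action = restored.compute_single_action(obs, policy_id="p1")
assert restored.get_policy("p1").action_space.contains(action)

ray.shutdown()

The relevant change is that RolloutWorker.restore() no longer has to skip policies missing from the static multiagent config; it re-creates them from the pickled PolicySpec entries and only falls back to a warning when no spec was saved.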