From f18213712f036276a25482186da459cdb906f76b Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 17 Aug 2021 18:13:35 +0200 Subject: [PATCH] [RLlib] Redo: "fix self play example scripts" PR (17566) (#17895) * wip. * wip. * wip. * wip. * wip. * wip. * wip. * wip. * wip. --- rllib/BUILD | 7 -- rllib/agents/tests/test_trainer.py | 71 +++++++++++++++------ rllib/agents/trainer.py | 27 ++++++-- rllib/evaluation/rollout_worker.py | 59 +++++++++-------- rllib/evaluation/sampler.py | 18 ++++-- rllib/examples/self_play_with_open_spiel.py | 2 + rllib/policy/policy_map.py | 31 +++++---- rllib/tests/test_multi_agent_env.py | 3 +- rllib/tests/test_trainer.py | 29 --------- 9 files changed, 144 insertions(+), 103 deletions(-) delete mode 100644 rllib/tests/test_trainer.py diff --git a/rllib/BUILD b/rllib/BUILD index bad5634a9..ec0daed54 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1614,13 +1614,6 @@ py_test( srcs = ["tests/test_timesteps.py"] ) -py_test( - name = "tests/test_trainer", - tags = ["tests_dir", "tests_dir_T"], - size = "small", - srcs = ["tests/test_trainer.py"] -) - # -------------------------------------------------------------------- # examples/ directory # diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index ab2cdef5d..1054bf201 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -1,4 +1,4 @@ -import gym +import copy from random import choice import unittest @@ -6,6 +6,7 @@ import ray import ray.rllib.agents.a3c as a3c import ray.rllib.agents.dqn as dqn import ray.rllib.agents.pg as pg +from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.utils.test_utils import framework_iterator @@ -19,9 +20,24 @@ class TestTrainer(unittest.TestCase): def tearDownClass(cls): ray.shutdown() - def test_add_delete_policy(self): - env = gym.make("CartPole-v0") + def test_validate_config_idempotent(self): + """ + Asserts that validate_config run multiple + times on COMMON_CONFIG will be idempotent + """ + # Given: + standard_config = copy.deepcopy(COMMON_CONFIG) + # When (we validate config 2 times), ... + Trainer._validate_config(standard_config) + config_v1 = copy.deepcopy(standard_config) + Trainer._validate_config(standard_config) + config_v2 = copy.deepcopy(standard_config) + + # ... then ... + self.assertEqual(config_v1, config_v2) + + def test_add_delete_policy(self): config = pg.DEFAULT_CONFIG.copy() config.update({ "env": MultiAgentCartPole, @@ -30,34 +46,38 @@ class TestTrainer(unittest.TestCase): "num_agents": 4, }, }, + "num_workers": 2, # Test on remote workers as well. + "model": { + "fcnet_hiddens": [5], + "fcnet_activation": "linear", + }, + "train_batch_size": 100, + "rollout_fragment_length": 50, "multiagent": { # Start with a single policy. - "policies": { - "p0": (None, env.observation_space, env.action_space, {}), - }, + "policies": {"p0"}, "policy_mapping_fn": lambda aid, episode, **kwargs: "p0", + # And only two policies that can be stored in memory at a + # time. "policy_map_capacity": 2, }, }) - # TODO: (sven) this will work for tf, once we have the DynamicTFPolicy - # refactor PR merged. 
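# Illustrative sketch of the shorthand policy spec used in the config above:
# the set form {"p0"} is equivalent to an explicit PolicySpec whose entries
# are all None/empty, meaning "use the Trainer's default policy class and
# infer the spaces from the env". Values here are examples only.
from ray.rllib.policy.policy import PolicySpec

explicit_policies = {
    # PolicySpec(policy_class, observation_space, action_space, config)
    "p0": PolicySpec(None, None, None, {}),
}
shorthand_policies = {"p0"}  # Equivalent: everything inferred.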
- for _ in framework_iterator(config, frameworks=("tf2", "torch")): + for _ in framework_iterator(config): trainer = pg.PGTrainer(config=config) r = trainer.train() - self.assertTrue("p0" in r["policy_reward_min"]) - for i in range(1, 4): + self.assertTrue("p0" in r["info"]["learner"]) + checkpoints = [] + for i in range(1, 3): def new_mapping_fn(agent_id, episode, **kwargs): return f"p{choice([i, i - 1])}" # Add a new policy. + pid = f"p{i}" new_pol = trainer.add_policy( - f"p{i}", + pid, trainer._policy_class, - observation_space=env.observation_space, - action_space=env.action_space, - config={}, # Test changing the mapping fn. policy_mapping_fn=new_mapping_fn, # Change the list of policies to train. @@ -65,14 +85,27 @@ class TestTrainer(unittest.TestCase): ) pol_map = trainer.workers.local_worker().policy_map self.assertTrue(new_pol is not trainer.get_policy("p0")) - for j in range(i): + for j in range(i + 1): self.assertTrue(f"p{j}" in pol_map) self.assertTrue(len(pol_map) == i + 1) r = trainer.train() - self.assertTrue("p1" in r["policy_reward_min"]) + self.assertTrue("p1" in r["info"]["learner"]) + checkpoints.append(trainer.save()) + + # Test restoring from the checkpoint (which has more policies + # than what's defined in the config dict). + test = pg.PGTrainer(config=config) + test.restore(checkpoints[-1]) + test.train() + # Test creating an action with the added (and restored) policy. + a = test.compute_single_action( + test.get_policy("p0").observation_space.sample(), + policy_id=pid) + self.assertTrue(test.get_policy("p0").action_space.contains(a)) + test.stop() # Delete all added policies again from trainer. - for i in range(3, 0, -1): + for i in range(2, 0, -1): trainer.remove_policy( f"p{i}", policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}", @@ -130,7 +163,7 @@ class TestTrainer(unittest.TestCase): # Try again using `create_env_on_driver=True`. # This force-adds the env on the local-worker, so this Trainer - # can `evaluate` even though, it doesn't have an evaluation-worker + # can `evaluate` even though it doesn't have an evaluation-worker # set. config["create_env_on_driver"] = True trainer_w_env_on_driver = a3c.A3CTrainer(config=config) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e81d0139a..ef5b4fe06 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -951,7 +951,7 @@ class Trainer(Trainable): unsquash_actions: Optional[bool] = None, clip_actions: Optional[bool] = None, ) -> TensorStructType: - """Computes an action for the specified policy on the local Worker. + """Computes an action for the specified policy on the local worker. Note that you can also access the policy object through self.get_policy(policy_id) and call compute_single_action() on it @@ -982,17 +982,31 @@ class Trainer(Trainable): any: The computed action if full_fetch=False, or tuple: The full output of policy.compute_actions() if full_fetch=True or we have an RNN-based Policy. + + Raises: + KeyError: If the `policy_id` cannot be found in this Trainer's + local worker. """ + policy = self.get_policy(policy_id) + if policy is None: + raise KeyError( + f"PolicyID '{policy_id}' not found in PolicyMap of the " + f"Trainer's local worker!") + + local_worker = self.workers.local_worker() + if state is None: state = [] + # Check the preprocessor and preprocess, if necessary. 
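# A rough usage sketch of the stricter behavior added above: asking the
# Trainer for an action from a PolicyID that is not in the local worker's
# PolicyMap now raises a descriptive KeyError. `trainer` and "p0" are assumed
# to come from a multi-agent setup like the test above.
def try_single_action(trainer, obs, policy_id="p0"):
    try:
        return trainer.compute_single_action(obs, policy_id=policy_id)
    except KeyError:
        # Unknown policy_id: surface it to the caller instead of crashing.
        return None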
- pp = self.workers.local_worker().preprocessors[policy_id] + pp = local_worker.preprocessors[policy_id] if type(pp).__name__ != "NoPreprocessor": observation = pp.transform(observation) - filtered_observation = self.workers.local_worker().filters[policy_id]( + filtered_observation = local_worker.filters[policy_id]( observation, update=False) - result = self.get_policy(policy_id).compute_single_action( + # Compute the action. + result = policy.compute_single_action( filtered_observation, state, prev_action, @@ -1002,10 +1016,12 @@ class Trainer(Trainable): clip_actions=clip_actions, explore=explore) + # Return 3-Tuple: Action, states, and extra-action fetches. if state or full_fetch: return result + # Ensure backward compatibility. else: - return result[0] # backwards compatibility + return result[0] @Deprecated(new="compute_single_action", error=False) def compute_action(self, *args, **kwargs): @@ -1193,7 +1209,6 @@ class Trainer(Trainable): observation_space=observation_space, action_space=action_space, config=config, - policy_config=self.config, policy_mapping_fn=policy_mapping_fn, policies_to_train=policies_to_train, ) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 3f23174a4..96982de20 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -2,13 +2,13 @@ import random import numpy as np import gym import logging -import pickle import platform import os from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \ TYPE_CHECKING, Union import ray +from ray import cloudpickle as pickle from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext from ray.rllib.env.external_env import ExternalEnv @@ -29,7 +29,6 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.policy_map import PolicyMap -from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import force_list, merge_dicts from ray.rllib.utils.annotations import DeveloperAPI @@ -1055,7 +1054,6 @@ class RolloutWorker(ParallelIteratorWorker): observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, config: Optional[PartialTrainerConfigDict] = None, - policy_config: Optional[TrainerConfigDict] = None, policy_mapping_fn: Optional[Callable[ [AgentID, "MultiAgentEpisode"], PolicyID]] = None, policies_to_train: Optional[List[PolicyID]] = None, @@ -1093,14 +1091,16 @@ class RolloutWorker(ParallelIteratorWorker): policy_dict = _determine_spaces_for_multi_agent_dict( { policy_id: PolicySpec(policy_cls, observation_space, - action_space, config) + action_space, config or {}) }, self.env, spaces=self.spaces, - policy_config=policy_config) - + policy_config=self.policy_config, + ) self._build_policy_map( - policy_dict, policy_config, seed=policy_config.get("seed")) + policy_dict, + self.policy_config, + seed=self.policy_config.get("seed")) new_policy = self.policy_map[policy_id] self.filters[policy_id] = get_filter( @@ -1242,11 +1242,16 @@ class RolloutWorker(ParallelIteratorWorker): @DeveloperAPI def save(self) -> bytes: filters = self.get_filters(flush_after=True) - state = { - pid: self.policy_map[pid].get_state() - for pid in self.policy_map - } - return pickle.dumps({"filters": filters, "state": state}) + state = {} + 
policy_specs = {} + for pid in self.policy_map: + state[pid] = self.policy_map[pid].get_state() + policy_specs[pid] = self.policy_map.policy_specs[pid] + return pickle.dumps({ + "filters": filters, + "state": state, + "policy_specs": policy_specs, + }) @DeveloperAPI def restore(self, objs: bytes) -> None: @@ -1254,12 +1259,23 @@ class RolloutWorker(ParallelIteratorWorker): self.sync_filters(objs["filters"]) for pid, state in objs["state"].items(): if pid not in self.policy_map: - logger.warning( - f"pid={pid} not found in policy_map! It was probably added" - " on-the-fly and is not part of the static `config." - "multiagent.policies` dict. Ignoring it for now.") - continue - self.policy_map[pid].set_state(state) + pol_spec = objs.get("policy_specs", {}).get(pid) + if not pol_spec: + logger.warning( + f"PolicyID '{pid}' was probably added on-the-fly (not" + " part of the static `multagent.policies` config) and" + " no PolicySpec objects found in the pickled policy " + "state. Will not add `{pid}`, but ignore it for now.") + else: + self.add_policy( + policy_id=pid, + policy_cls=pol_spec.policy_class, + observation_space=pol_spec.observation_space, + action_space=pol_spec.action_space, + config=pol_spec.config, + ) + else: + self.policy_map[pid].set_state(state) @DeveloperAPI def set_global_vars(self, global_vars: dict) -> None: @@ -1478,10 +1494,3 @@ def _validate_env(env: Any) -> EnvType: "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator " "function returned {} ({}).".format(env, type(env))) return env - - -def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool: - for policy, _, _, _ in policy_dict.values(): - if issubclass(policy, TFPolicy): - return True - return False diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index c9f604a0e..5737c7bd2 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -555,8 +555,12 @@ def _env_runner( extra_batch_callback, env_id=env_id) # Call each policy's Exploration.on_episode_start method. - # types: Policy - for p in worker.policy_map.values(): + # Note: This may break the exploration (e.g. ParameterNoise) of + # policies in the `policy_map` that have not been recently used + # (and are therefore stashed to disk). However, we certainly do not + # want to loop through all (even stashed) policies here as that + # would counter the purpose of the LRU policy caching. + for p in worker.policy_map.cache.values(): if getattr(p, "exploration", None) is not None: p.exploration.on_episode_start( policy=p, @@ -904,8 +908,14 @@ def _process_observations( if ma_sample_batch: outputs.append(ma_sample_batch) - # Call each policy's Exploration.on_episode_end method. - for p in worker.policy_map.values(): + # Call each (in-memory) policy's Exploration.on_episode_end + # method. + # Note: This may break the exploration (e.g. ParameterNoise) of + # policies in the `policy_map` that have not been recently used + # (and are therefore stashed to disk). However, we certainly do not + # want to loop through all (even stashed) policies here as that + # would counter the purpose of the LRU policy caching. 
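# A small sketch of the distinction relied on above, assuming the PolicyMap
# semantics introduced in this patch: `policy_map.cache` holds only the
# policies currently in memory, while asking the map for all of its values
# would also deserialize every stashed policy back from disk, defeating the
# purpose of the LRU cache.
def in_memory_policies(policy_map):
    # Only the policies currently held in memory; never touches the disk stash.
    return list(policy_map.cache.values())

def all_policies(policy_map):
    # Also re-loads every stashed policy back into memory.
    return list(policy_map.values())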
+ for p in worker.policy_map.cache.values(): if getattr(p, "exploration", None) is not None: p.exploration.on_episode_end( policy=p, diff --git a/rllib/examples/self_play_with_open_spiel.py b/rllib/examples/self_play_with_open_spiel.py index 5e292d040..86eca7f8b 100644 --- a/rllib/examples/self_play_with_open_spiel.py +++ b/rllib/examples/self_play_with_open_spiel.py @@ -41,6 +41,7 @@ parser.add_argument( default="tf", help="The DL framework specifier.") parser.add_argument("--num-cpus", type=int, default=0) +parser.add_argument("--num-workers", type=int, default=2) parser.add_argument( "--from-checkpoint", type=str, @@ -197,6 +198,7 @@ if __name__ == "__main__": # Always just train the "main" policy. "policies_to_train": ["main"], }, + "num_workers": args.num_workers, # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "framework": args.framework, diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index 5a8a960dc..5a0457126 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -2,7 +2,7 @@ from collections import deque import gym import os import pickle -from typing import Callable, Dict, Optional, Type, TYPE_CHECKING +from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.annotations import override @@ -39,10 +39,19 @@ class PolicyMap(dict): """Initializes a PolicyMap instance. Args: - maxlen (int): The maximum number of policies to hold in memory. + worker_index (int): The worker index of the RolloutWorker this map + resides in. + num_workers (int): The total number of remote workers in the + WorkerSet to which this map's RolloutWorker belongs to. + capacity (int): The maximum number of policies to hold in memory. The least used ones are written to disk/S3 and retrieved when needed. - path (str): + path (str): The path to store the policy pickle files to. Files + will have the name: [policy_id].[worker idx].policy.pkl. + policy_config (TrainerConfigDict): The Trainer's base config dict. + session_creator (Optional[Callable[[], tf1.Session]): An optional + tf1.Session creation callable. + seed (int): An optional seed (used to seed tf policies). """ super().__init__() @@ -53,22 +62,22 @@ class PolicyMap(dict): # The file extension for stashed policies (that are no longer available # in-memory but can be reinstated any time from storage). - self.extension = ".policy.pkl" + self.extension = f".{self.worker_index}.policy.pkl" # Dictionary of keys that may be looked up (cached or not). - self.valid_keys = set() + self.valid_keys: Set[str] = set() # The actual cache with the in-memory policy objects. - self.cache = {} + self.cache: Dict[str, Policy] = {} # The doubly-linked list holding the currently in-memory objects. self.deque = deque(maxlen=capacity or 10) # The file path where to store overflowing policies. self.path = path or "." # The core config to use. Each single policy's config override is # added on top of this. - self.policy_config = policy_config or {} + self.policy_config: TrainerConfigDict = policy_config or {} # The orig classes/obs+act spaces, and config overrides of the # Policies. 
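# A small sketch of how a stashed policy's file name is composed under the
# per-worker extension set above, e.g. "p3.2.policy.pkl" for policy "p3"
# stashed by the RolloutWorker with worker_index 2; `path` defaults to "."
# as in this class. The helper name is illustrative only.
import os

def stash_file(path: str, policy_id: str, worker_index: int) -> str:
    return os.path.join(path or ".", f"{policy_id}.{worker_index}.policy.pkl")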
- self.policy_specs = {} # type: Dict[PolicyID, PolicySpec] + self.policy_specs: Dict[PolicyID, PolicySpec] = {} def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"], observation_space: gym.Space, action_space: gym.Space, @@ -140,7 +149,7 @@ class PolicyMap(dict): def __getitem__(self, item): # Never seen this key -> Error. if item not in self.valid_keys: - raise KeyError(f"'{item}' not a valid key!") + raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!") # Item already in cache -> Rearrange deque (least recently used) and # return. @@ -250,9 +259,7 @@ class PolicyMap(dict): policy_state = policy.get_state() # Closes policy's tf session, if any. self._close_session(policy) - # Remove from memory. - # TODO: (sven) This should clear the tf Graph as well, if the Trainer - # would not hold parts of the graph (e.g. in tf multi-GPU setups). + # Remove from memory. This will clear the tf Graph as well. del self.cache[delkey] # Write state to disk. with open(self.path + "/" + delkey + self.extension, "wb") as f: diff --git a/rllib/tests/test_multi_agent_env.py b/rllib/tests/test_multi_agent_env.py index ae2c63b76..c19bc12d4 100644 --- a/rllib/tests/test_multi_agent_env.py +++ b/rllib/tests/test_multi_agent_env.py @@ -457,8 +457,9 @@ class TestMultiAgentEnv(unittest.TestCase): self.assertTrue( pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in [0, 1]) - self.assertRaises( + self.assertRaisesRegex( KeyError, + "not found in PolicyMap", lambda: pg.compute_single_action( [0, 0, 0, 0], policy_id="policy_3")) diff --git a/rllib/tests/test_trainer.py b/rllib/tests/test_trainer.py deleted file mode 100644 index 605b4cac8..000000000 --- a/rllib/tests/test_trainer.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Testing for trainer class""" -import copy -import unittest -from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG - - -class TestTrainer(unittest.TestCase): - def test_validate_config_idempotent(self): - """ - Asserts that validate_config run multiple - times on COMMON_CONFIG will be idempotent - """ - # Given - standard_config = copy.deepcopy(COMMON_CONFIG) - - # When (we validate config 2 times) - Trainer._validate_config(standard_config) - config_v1 = copy.deepcopy(standard_config) - Trainer._validate_config(standard_config) - config_v2 = copy.deepcopy(standard_config) - - # Then - self.assertEqual(config_v1, config_v2) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__]))
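Taken together, the changes above enable the workflow exercised by the new
test_add_delete_policy test. The sketch below mirrors that test; `config` and
`checkpoint_path` are assumed to come from such a run (a multi-agent PG config
declaring only "p0", and a checkpoint taken after "p1"/"p2" were added on the
fly), not fixed values.

import ray.rllib.agents.pg as pg

def restore_with_dynamic_policies(config, checkpoint_path):
    # Build a fresh Trainer whose config only declares policy "p0" ...
    trainer = pg.PGTrainer(config=config)
    # ... and restore a checkpoint taken after extra policies were added
    # on the fly; the stored PolicySpecs let the workers re-create them.
    trainer.restore(checkpoint_path)
    trainer.train()
    # The restored Trainer can compute actions with its policies as usual.
    obs = trainer.get_policy("p0").observation_space.sample()
    action = trainer.compute_single_action(obs, policy_id="p0")
    trainer.stop()
    return action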