[RLlib] Redo: "fix self play example scripts" PR (17566) (#17895)

* wip.

* wip.

* wip.

* wip.

* wip.

* wip.

* wip.

* wip.

* wip.
Sven Mika 2021-08-17 18:13:35 +02:00 committed by GitHub
parent 2b7d907762
commit f18213712f
9 changed files with 144 additions and 103 deletions

View file

@@ -1614,13 +1614,6 @@ py_test(
srcs = ["tests/test_timesteps.py"]
)
py_test(
name = "tests/test_trainer",
tags = ["tests_dir", "tests_dir_T"],
size = "small",
srcs = ["tests/test_trainer.py"]
)
# --------------------------------------------------------------------
# examples/ directory
#

View file

@@ -1,4 +1,4 @@
import gym
import copy
from random import choice
import unittest
@@ -6,6 +6,7 @@ import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.pg as pg
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import framework_iterator
@@ -19,9 +20,24 @@ class TestTrainer(unittest.TestCase):
def tearDownClass(cls):
ray.shutdown()
def test_add_delete_policy(self):
env = gym.make("CartPole-v0")
def test_validate_config_idempotent(self):
"""
Asserts that validate_config run multiple
times on COMMON_CONFIG will be idempotent
"""
# Given:
standard_config = copy.deepcopy(COMMON_CONFIG)
# When (we validate config 2 times), ...
Trainer._validate_config(standard_config)
config_v1 = copy.deepcopy(standard_config)
Trainer._validate_config(standard_config)
config_v2 = copy.deepcopy(standard_config)
# ... then ...
self.assertEqual(config_v1, config_v2)
def test_add_delete_policy(self):
config = pg.DEFAULT_CONFIG.copy()
config.update({
"env": MultiAgentCartPole,
@@ -30,34 +46,38 @@ class TestTrainer(unittest.TestCase):
"num_agents": 4,
},
},
"num_workers": 2, # Test on remote workers as well.
"model": {
"fcnet_hiddens": [5],
"fcnet_activation": "linear",
},
"train_batch_size": 100,
"rollout_fragment_length": 50,
"multiagent": {
# Start with a single policy.
"policies": {
"p0": (None, env.observation_space, env.action_space, {}),
},
"policies": {"p0"},
"policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
# Allow only two policies to be held in memory at a time.
"policy_map_capacity": 2,
},
})
# TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
# refactor PR merged.
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
for _ in framework_iterator(config):
trainer = pg.PGTrainer(config=config)
r = trainer.train()
self.assertTrue("p0" in r["policy_reward_min"])
for i in range(1, 4):
self.assertTrue("p0" in r["info"]["learner"])
checkpoints = []
for i in range(1, 3):
def new_mapping_fn(agent_id, episode, **kwargs):
return f"p{choice([i, i - 1])}"
# Add a new policy.
pid = f"p{i}"
new_pol = trainer.add_policy(
f"p{i}",
pid,
trainer._policy_class,
observation_space=env.observation_space,
action_space=env.action_space,
config={},
# Test changing the mapping fn.
policy_mapping_fn=new_mapping_fn,
# Change the list of policies to train.
@@ -65,14 +85,27 @@ class TestTrainer(unittest.TestCase):
)
pol_map = trainer.workers.local_worker().policy_map
self.assertTrue(new_pol is not trainer.get_policy("p0"))
for j in range(i):
for j in range(i + 1):
self.assertTrue(f"p{j}" in pol_map)
self.assertTrue(len(pol_map) == i + 1)
r = trainer.train()
self.assertTrue("p1" in r["policy_reward_min"])
self.assertTrue("p1" in r["info"]["learner"])
checkpoints.append(trainer.save())
# Test restoring from the checkpoint (which has more policies
# than what's defined in the config dict).
test = pg.PGTrainer(config=config)
test.restore(checkpoints[-1])
test.train()
# Test creating an action with the added (and restored) policy.
a = test.compute_single_action(
test.get_policy("p0").observation_space.sample(),
policy_id=pid)
self.assertTrue(test.get_policy("p0").action_space.contains(a))
test.stop()
# Delete all added policies again from trainer.
for i in range(3, 0, -1):
for i in range(2, 0, -1):
trainer.remove_policy(
f"p{i}",
policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
@@ -130,7 +163,7 @@ class TestTrainer(unittest.TestCase):
# Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer
# can `evaluate` even though, it doesn't have an evaluation-worker
# can `evaluate` even though it doesn't have an evaluation-worker
# set.
config["create_env_on_driver"] = True
trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
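The test above exercises the runtime add/remove-policy API this PR fixes up. As a quick reference, here is a minimal, hedged sketch of that flow outside the test harness; it assumes `config` is the multi-agent PG config assembled in the test above (MultiAgentCartPole with a single starting policy "p0"), and it relies on the worker inferring the observation/action spaces, which is exactly what allowed the explicit space arguments to be removed.

# Hedged sketch, not part of the diff: add/remove a policy on a live Trainer.
# Assumes `config` is the multi-agent PG config built in the test above.
import ray
import ray.rllib.agents.pg as pg

ray.init()
trainer = pg.PGTrainer(config=config)
trainer.train()

# Add a new policy on the fly; its obs/action spaces are inferred.
trainer.add_policy(
    "p1",
    trainer._policy_class,
    policy_mapping_fn=lambda agent_id, episode, **kwargs: "p1",
)
trainer.train()

# Remove it again, re-mapping all agents back to "p0".
trainer.remove_policy(
    "p1",
    policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
)
trainer.stop()
ray.shutdown()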

View file

@@ -951,7 +951,7 @@ class Trainer(Trainable):
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
) -> TensorStructType:
"""Computes an action for the specified policy on the local Worker.
"""Computes an action for the specified policy on the local worker.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_single_action() on it
@@ -982,17 +982,31 @@ class Trainer(Trainable):
any: The computed action if full_fetch=False, or
tuple: The full output of policy.compute_actions() if
full_fetch=True or we have an RNN-based Policy.
Raises:
KeyError: If the `policy_id` cannot be found in this Trainer's
local worker.
"""
policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Trainer's local worker!")
local_worker = self.workers.local_worker()
if state is None:
state = []
# Check the preprocessor and preprocess, if necessary.
pp = self.workers.local_worker().preprocessors[policy_id]
pp = local_worker.preprocessors[policy_id]
if type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
filtered_observation = self.workers.local_worker().filters[policy_id](
filtered_observation = local_worker.filters[policy_id](
observation, update=False)
result = self.get_policy(policy_id).compute_single_action(
# Compute the action.
result = policy.compute_single_action(
filtered_observation,
state,
prev_action,
@@ -1002,10 +1016,12 @@ class Trainer(Trainable):
clip_actions=clip_actions,
explore=explore)
# Return 3-Tuple: Action, states, and extra-action fetches.
if state or full_fetch:
return result
# Ensure backward compatibility.
else:
return result[0] # backwards compatibility
return result[0]
@Deprecated(new="compute_single_action", error=False)
def compute_action(self, *args, **kwargs):
@@ -1193,7 +1209,6 @@ class Trainer(Trainable):
observation_space=observation_space,
action_space=action_space,
config=config,
policy_config=self.config,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
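For context on the compute_single_action() changes in this file, a hedged usage sketch (not part of the diff) follows. It assumes `trainer` is a multi-agent Trainer that owns a policy "p0", as in the test above; the point is that an unknown policy_id now fails fast with a KeyError.

# Hedged sketch of the fail-fast behavior of Trainer.compute_single_action().
# Assumes `trainer` holds a policy with ID "p0" (see the test above).
obs = trainer.get_policy("p0").observation_space.sample()

# Existing policy: compute a single action as before.
action = trainer.compute_single_action(obs, policy_id="p0")
assert trainer.get_policy("p0").action_space.contains(action)

# Unknown PolicyID: raises a KeyError immediately instead of failing later.
try:
    trainer.compute_single_action(obs, policy_id="does_not_exist")
except KeyError as e:
    print(e)  # PolicyID 'does_not_exist' not found in PolicyMap of the Trainer's local worker!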

View file

@@ -2,13 +2,13 @@ import random
import numpy as np
import gym
import logging
import pickle
import platform
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
TYPE_CHECKING, Union
import ray
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
@@ -29,7 +29,6 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
@@ -1055,7 +1054,6 @@ class RolloutWorker(ParallelIteratorWorker):
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
config: Optional[PartialTrainerConfigDict] = None,
policy_config: Optional[TrainerConfigDict] = None,
policy_mapping_fn: Optional[Callable[
[AgentID, "MultiAgentEpisode"], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
@@ -1093,14 +1091,16 @@ class RolloutWorker(ParallelIteratorWorker):
policy_dict = _determine_spaces_for_multi_agent_dict(
{
policy_id: PolicySpec(policy_cls, observation_space,
action_space, config)
action_space, config or {})
},
self.env,
spaces=self.spaces,
policy_config=policy_config)
policy_config=self.policy_config,
)
self._build_policy_map(
policy_dict, policy_config, seed=policy_config.get("seed"))
policy_dict,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]
self.filters[policy_id] = get_filter(
@@ -1242,11 +1242,16 @@ class RolloutWorker(ParallelIteratorWorker):
@DeveloperAPI
def save(self) -> bytes:
filters = self.get_filters(flush_after=True)
state = {
pid: self.policy_map[pid].get_state()
for pid in self.policy_map
}
return pickle.dumps({"filters": filters, "state": state})
state = {}
policy_specs = {}
for pid in self.policy_map:
state[pid] = self.policy_map[pid].get_state()
policy_specs[pid] = self.policy_map.policy_specs[pid]
return pickle.dumps({
"filters": filters,
"state": state,
"policy_specs": policy_specs,
})
@DeveloperAPI
def restore(self, objs: bytes) -> None:
@@ -1254,12 +1259,23 @@ class RolloutWorker(ParallelIteratorWorker):
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
if pid not in self.policy_map:
logger.warning(
f"pid={pid} not found in policy_map! It was probably added"
" on-the-fly and is not part of the static `config."
"multiagent.policies` dict. Ignoring it for now.")
continue
self.policy_map[pid].set_state(state)
pol_spec = objs.get("policy_specs", {}).get(pid)
if not pol_spec:
logger.warning(
f"PolicyID '{pid}' was probably added on-the-fly (not"
" part of the static `multagent.policies` config) and"
" no PolicySpec objects found in the pickled policy "
"state. Will not add `{pid}`, but ignore it for now.")
else:
self.add_policy(
policy_id=pid,
policy_cls=pol_spec.policy_class,
observation_space=pol_spec.observation_space,
action_space=pol_spec.action_space,
config=pol_spec.config,
)
else:
self.policy_map[pid].set_state(state)
@DeveloperAPI
def set_global_vars(self, global_vars: dict) -> None:
@@ -1478,10 +1494,3 @@ def _validate_env(env: Any) -> EnvType:
"ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
"function returned {} ({}).".format(env, type(env)))
return env
def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool:
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicy):
return True
return False
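The save()/restore() changes above mean a worker checkpoint now also carries each policy's PolicySpec, so policies that were added on-the-fly (and are therefore missing from `multiagent.policies`) can be re-created on restore instead of being dropped. A hedged sketch of the serialized layout, assuming `worker` is a RolloutWorker in scope:

# Hedged sketch of what RolloutWorker.save() now serializes (layout per the
# diff above; `worker` is an assumed RolloutWorker instance).
from ray import cloudpickle as pickle

objs = pickle.loads(worker.save())
# "filters":      observation filters
# "state":        per-policy weights/state, keyed by PolicyID
# "policy_specs": per-policy PolicySpec (class, spaces, config override), which
#                 restore() uses to re-create on-the-fly-added policies.
assert set(objs) == {"filters", "state", "policy_specs"}
assert set(objs["state"]) == set(objs["policy_specs"])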

View file

@@ -555,8 +555,12 @@ def _env_runner(
extra_batch_callback,
env_id=env_id)
# Call each policy's Exploration.on_episode_start method.
# types: Policy
for p in worker.policy_map.values():
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would defeat the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
@@ -904,8 +908,14 @@ def _process_observations(
if ma_sample_batch:
outputs.append(ma_sample_batch)
# Call each policy's Exploration.on_episode_end method.
for p in worker.policy_map.values():
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would defeat the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
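The comments above hinge on the difference between all policies a PolicyMap knows about and the subset currently held in memory. A hedged sketch of that distinction (assuming `policy_map` is a PolicyMap holding more policies than its capacity allows in memory):

# Hedged sketch: why the sampler iterates `policy_map.cache`, not the full map.
# Assumes `policy_map` is a PolicyMap with more policies than its capacity.
for pid in policy_map.valid_keys:
    policy = policy_map[pid]  # May pull a stashed policy back from disk (LRU churn).

for policy in policy_map.cache.values():
    pass  # Touches only policies already in memory; no disk access.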

View file

@@ -41,6 +41,7 @@ parser.add_argument(
default="tf",
help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument(
"--from-checkpoint",
type=str,
@@ -197,6 +198,7 @@ if __name__ == "__main__":
# Always just train the "main" policy.
"policies_to_train": ["main"],
},
"num_workers": args.num_workers,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"framework": args.framework,

View file

@@ -2,7 +2,7 @@ from collections import deque
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override
@@ -39,10 +39,19 @@ class PolicyMap(dict):
"""Initializes a PolicyMap instance.
Args:
maxlen (int): The maximum number of policies to hold in memory.
worker_index (int): The worker index of the RolloutWorker this map
resides in.
num_workers (int): The total number of remote workers in the
WorkerSet to which this map's RolloutWorker belongs.
capacity (int): The maximum number of policies to hold in memory.
The least used ones are written to disk/S3 and retrieved
when needed.
path (str):
path (str): The path under which to store the policy pickle files.
Files will be named: [policy_id].[worker_index].policy.pkl.
policy_config (TrainerConfigDict): The Trainer's base config dict.
session_creator (Optional[Callable[[], tf1.Session]): An optional
tf1.Session creation callable.
seed (int): An optional seed (used to seed tf policies).
"""
super().__init__()
@@ -53,22 +62,22 @@ class PolicyMap(dict):
# The file extension for stashed policies (that are no longer available
# in-memory but can be reinstated any time from storage).
self.extension = ".policy.pkl"
self.extension = f".{self.worker_index}.policy.pkl"
# Dictionary of keys that may be looked up (cached or not).
self.valid_keys = set()
self.valid_keys: Set[str] = set()
# The actual cache with the in-memory policy objects.
self.cache = {}
self.cache: Dict[str, Policy] = {}
# The doubly-linked list holding the currently in-memory objects.
self.deque = deque(maxlen=capacity or 10)
# The file path where to store overflowing policies.
self.path = path or "."
# The core config to use. Each single policy's config override is
# added on top of this.
self.policy_config = policy_config or {}
self.policy_config: TrainerConfigDict = policy_config or {}
# The orig classes/obs+act spaces, and config overrides of the
# Policies.
self.policy_specs = {} # type: Dict[PolicyID, PolicySpec]
self.policy_specs: Dict[PolicyID, PolicySpec] = {}
def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
observation_space: gym.Space, action_space: gym.Space,
@@ -140,7 +149,7 @@ class PolicyMap(dict):
def __getitem__(self, item):
# Never seen this key -> Error.
if item not in self.valid_keys:
raise KeyError(f"'{item}' not a valid key!")
raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")
# Item already in cache -> Rearrange deque (least recently used) and
# return.
@@ -250,9 +259,7 @@ class PolicyMap(dict):
policy_state = policy.get_state()
# Closes policy's tf session, if any.
self._close_session(policy)
# Remove from memory.
# TODO: (sven) This should clear the tf Graph as well, if the Trainer
# would not hold parts of the graph (e.g. in tf multi-GPU setups).
# Remove from memory. This will clear the tf Graph as well.
del self.cache[delkey]
# Write state to disk.
with open(self.path + "/" + delkey + self.extension, "wb") as f:
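The docstring above describes the per-worker stash files that back the LRU cache. A hedged sketch of how that plays out on disk (constructor arguments follow the docstring; the path and policy IDs are made up):

# Hedged sketch of the LRU stash-to-disk naming documented above.
# Constructor arguments follow the docstring; the path and IDs are made up.
from ray.rllib.policy.policy_map import PolicyMap

pm = PolicyMap(worker_index=3, num_workers=4, capacity=2, path="/tmp/policies")
# Once a third policy lives in this map, the least-recently-used one is written
# to "/tmp/policies/<policy_id>.3.policy.pkl" (i.e.
# [policy_id].[worker_index].policy.pkl, so different workers no longer clobber
# each other's stash files) and is transparently re-loaded on its next access,
# e.g. pm["<policy_id>"].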

View file

@@ -457,8 +457,9 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertTrue(
pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
[0, 1])
self.assertRaises(
self.assertRaisesRegex(
KeyError,
"not found in PolicyMap",
lambda: pg.compute_single_action(
[0, 0, 0, 0], policy_id="policy_3"))

View file

@@ -1,29 +0,0 @@
"""Testing for trainer class"""
import copy
import unittest
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
class TestTrainer(unittest.TestCase):
def test_validate_config_idempotent(self):
"""
Asserts that validate_config run multiple
times on COMMON_CONFIG will be idempotent
"""
# Given
standard_config = copy.deepcopy(COMMON_CONFIG)
# When (we validate config 2 times)
Trainer._validate_config(standard_config)
config_v1 = copy.deepcopy(standard_config)
Trainer._validate_config(standard_config)
config_v2 = copy.deepcopy(standard_config)
# Then
self.assertEqual(config_v1, config_v2)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))