Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Redo: "fix self play example scripts" PR (17566) (#17895)
* wip. * wip. * wip. * wip. * wip. * wip. * wip. * wip. * wip.
parent 2b7d907762
commit f18213712f
9 changed files with 144 additions and 103 deletions
@@ -1614,13 +1614,6 @@ py_test(
    srcs = ["tests/test_timesteps.py"]
)

py_test(
    name = "tests/test_trainer",
    tags = ["tests_dir", "tests_dir_T"],
    size = "small",
    srcs = ["tests/test_trainer.py"]
)

# --------------------------------------------------------------------
# examples/ directory
#
@@ -1,4 +1,4 @@
import gym
import copy
from random import choice
import unittest

@@ -6,6 +6,7 @@ import ray
import ray.rllib.agents.a3c as a3c
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.pg as pg
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import framework_iterator

@@ -19,9 +20,24 @@ class TestTrainer(unittest.TestCase):
    def tearDownClass(cls):
        ray.shutdown()

    def test_add_delete_policy(self):
        env = gym.make("CartPole-v0")
    def test_validate_config_idempotent(self):
        """
        Asserts that validate_config run multiple
        times on COMMON_CONFIG will be idempotent
        """
        # Given:
        standard_config = copy.deepcopy(COMMON_CONFIG)

        # When (we validate config 2 times), ...
        Trainer._validate_config(standard_config)
        config_v1 = copy.deepcopy(standard_config)
        Trainer._validate_config(standard_config)
        config_v2 = copy.deepcopy(standard_config)

        # ... then ...
        self.assertEqual(config_v1, config_v2)

    def test_add_delete_policy(self):
        config = pg.DEFAULT_CONFIG.copy()
        config.update({
            "env": MultiAgentCartPole,

@@ -30,34 +46,38 @@ class TestTrainer(unittest.TestCase):
                    "num_agents": 4,
                },
            },
            "num_workers": 2,  # Test on remote workers as well.
            "model": {
                "fcnet_hiddens": [5],
                "fcnet_activation": "linear",
            },
            "train_batch_size": 100,
            "rollout_fragment_length": 50,
            "multiagent": {
                # Start with a single policy.
                "policies": {
                    "p0": (None, env.observation_space, env.action_space, {}),
                },
                "policies": {"p0"},
                "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
                # And only two policies that can be stored in memory at a
                # time.
                "policy_map_capacity": 2,
            },
        })

        # TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
        # refactor PR merged.
        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
        for _ in framework_iterator(config):
            trainer = pg.PGTrainer(config=config)
            r = trainer.train()
            self.assertTrue("p0" in r["policy_reward_min"])
            for i in range(1, 4):
            self.assertTrue("p0" in r["info"]["learner"])
            checkpoints = []
            for i in range(1, 3):

                def new_mapping_fn(agent_id, episode, **kwargs):
                    return f"p{choice([i, i - 1])}"

                # Add a new policy.
                pid = f"p{i}"
                new_pol = trainer.add_policy(
                    f"p{i}",
                    pid,
                    trainer._policy_class,
                    observation_space=env.observation_space,
                    action_space=env.action_space,
                    config={},
                    # Test changing the mapping fn.
                    policy_mapping_fn=new_mapping_fn,
                    # Change the list of policies to train.

@@ -65,14 +85,27 @@ class TestTrainer(unittest.TestCase):
                )
                pol_map = trainer.workers.local_worker().policy_map
                self.assertTrue(new_pol is not trainer.get_policy("p0"))
                for j in range(i):
                for j in range(i + 1):
                    self.assertTrue(f"p{j}" in pol_map)
                self.assertTrue(len(pol_map) == i + 1)
                r = trainer.train()
                self.assertTrue("p1" in r["policy_reward_min"])
                self.assertTrue("p1" in r["info"]["learner"])
                checkpoints.append(trainer.save())

                # Test restoring from the checkpoint (which has more policies
                # than what's defined in the config dict).
                test = pg.PGTrainer(config=config)
                test.restore(checkpoints[-1])
                test.train()
                # Test creating an action with the added (and restored) policy.
                a = test.compute_single_action(
                    test.get_policy("p0").observation_space.sample(),
                    policy_id=pid)
                self.assertTrue(test.get_policy("p0").action_space.contains(a))
                test.stop()

            # Delete all added policies again from trainer.
            for i in range(3, 0, -1):
            for i in range(2, 0, -1):
                trainer.remove_policy(
                    f"p{i}",
                    policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",

@@ -130,7 +163,7 @@ class TestTrainer(unittest.TestCase):

        # Try again using `create_env_on_driver=True`.
        # This force-adds the env on the local-worker, so this Trainer
        # can `evaluate` even though, it doesn't have an evaluation-worker
        # can `evaluate` even though it doesn't have an evaluation-worker
        # set.
        config["create_env_on_driver"] = True
        trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
@@ -951,7 +951,7 @@ class Trainer(Trainable):
            unsquash_actions: Optional[bool] = None,
            clip_actions: Optional[bool] = None,
    ) -> TensorStructType:
        """Computes an action for the specified policy on the local Worker.
        """Computes an action for the specified policy on the local worker.

        Note that you can also access the policy object through
        self.get_policy(policy_id) and call compute_single_action() on it

@@ -982,17 +982,31 @@ class Trainer(Trainable):
            any: The computed action if full_fetch=False, or
            tuple: The full output of policy.compute_actions() if
                full_fetch=True or we have an RNN-based Policy.

        Raises:
            KeyError: If the `policy_id` cannot be found in this Trainer's
                local worker.
        """
        policy = self.get_policy(policy_id)
        if policy is None:
            raise KeyError(
                f"PolicyID '{policy_id}' not found in PolicyMap of the "
                f"Trainer's local worker!")

        local_worker = self.workers.local_worker()

        if state is None:
            state = []

        # Check the preprocessor and preprocess, if necessary.
        pp = self.workers.local_worker().preprocessors[policy_id]
        pp = local_worker.preprocessors[policy_id]
        if type(pp).__name__ != "NoPreprocessor":
            observation = pp.transform(observation)
        filtered_observation = self.workers.local_worker().filters[policy_id](
        filtered_observation = local_worker.filters[policy_id](
            observation, update=False)

        result = self.get_policy(policy_id).compute_single_action(
        # Compute the action.
        result = policy.compute_single_action(
            filtered_observation,
            state,
            prev_action,

@@ -1002,10 +1016,12 @@ class Trainer(Trainable):
            clip_actions=clip_actions,
            explore=explore)

        # Return 3-Tuple: Action, states, and extra-action fetches.
        if state or full_fetch:
            return result
        # Ensure backward compatibility.
        else:
            return result[0]  # backwards compatibility
            return result[0]

    @Deprecated(new="compute_single_action", error=False)
    def compute_action(self, *args, **kwargs):

@@ -1193,7 +1209,6 @@ class Trainer(Trainable):
            observation_space=observation_space,
            action_space=action_space,
            config=config,
            policy_config=self.config,
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=policies_to_train,
        )
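
The `Trainer.compute_single_action()` changes above look the policy up once via `get_policy(policy_id)` and raise a descriptive `KeyError` for unknown policy IDs. Below is a minimal usage sketch of that behavior; it is not part of the commit, and the CartPole/PG setup only mirrors the tests above as an assumption.

# Hedged usage sketch (not part of this commit): how the updated
# Trainer.compute_single_action() behaves for known vs. unknown policy IDs.
import ray
import ray.rllib.agents.pg as pg

ray.init(ignore_reinit_error=True)
trainer = pg.PGTrainer(config={"env": "CartPole-v0", "framework": "torch"})
obs = trainer.get_policy("default_policy").observation_space.sample()

# Known policy ID -> a single action (no RNN state, full_fetch=False).
action = trainer.compute_single_action(obs, policy_id="default_policy")

# Unknown policy ID -> KeyError from the new up-front get_policy() check.
try:
    trainer.compute_single_action(obs, policy_id="does_not_exist")
except KeyError as e:
    print("Expected error:", e)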
@@ -2,13 +2,13 @@ import random
import numpy as np
import gym
import logging
import pickle
import platform
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
    TYPE_CHECKING, Union

import ray
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv

@@ -29,7 +29,6 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import force_list, merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI

@@ -1055,7 +1054,6 @@ class RolloutWorker(ParallelIteratorWorker):
            observation_space: Optional[gym.spaces.Space] = None,
            action_space: Optional[gym.spaces.Space] = None,
            config: Optional[PartialTrainerConfigDict] = None,
            policy_config: Optional[TrainerConfigDict] = None,
            policy_mapping_fn: Optional[Callable[
                [AgentID, "MultiAgentEpisode"], PolicyID]] = None,
            policies_to_train: Optional[List[PolicyID]] = None,

@@ -1093,14 +1091,16 @@ class RolloutWorker(ParallelIteratorWorker):
        policy_dict = _determine_spaces_for_multi_agent_dict(
            {
                policy_id: PolicySpec(policy_cls, observation_space,
                                      action_space, config)
                                      action_space, config or {})
            },
            self.env,
            spaces=self.spaces,
            policy_config=policy_config)

            policy_config=self.policy_config,
        )
        self._build_policy_map(
            policy_dict, policy_config, seed=policy_config.get("seed"))
            policy_dict,
            self.policy_config,
            seed=self.policy_config.get("seed"))
        new_policy = self.policy_map[policy_id]

        self.filters[policy_id] = get_filter(

@@ -1242,11 +1242,16 @@ class RolloutWorker(ParallelIteratorWorker):
    @DeveloperAPI
    def save(self) -> bytes:
        filters = self.get_filters(flush_after=True)
        state = {
            pid: self.policy_map[pid].get_state()
            for pid in self.policy_map
        }
        return pickle.dumps({"filters": filters, "state": state})
        state = {}
        policy_specs = {}
        for pid in self.policy_map:
            state[pid] = self.policy_map[pid].get_state()
            policy_specs[pid] = self.policy_map.policy_specs[pid]
        return pickle.dumps({
            "filters": filters,
            "state": state,
            "policy_specs": policy_specs,
        })

    @DeveloperAPI
    def restore(self, objs: bytes) -> None:

@@ -1254,12 +1259,23 @@ class RolloutWorker(ParallelIteratorWorker):
        self.sync_filters(objs["filters"])
        for pid, state in objs["state"].items():
            if pid not in self.policy_map:
                logger.warning(
                    f"pid={pid} not found in policy_map! It was probably added"
                    " on-the-fly and is not part of the static `config."
                    "multiagent.policies` dict. Ignoring it for now.")
                continue
            self.policy_map[pid].set_state(state)
                pol_spec = objs.get("policy_specs", {}).get(pid)
                if not pol_spec:
                    logger.warning(
                        f"PolicyID '{pid}' was probably added on-the-fly (not"
                        " part of the static `multagent.policies` config) and"
                        " no PolicySpec objects found in the pickled policy "
                        "state. Will not add `{pid}`, but ignore it for now.")
                else:
                    self.add_policy(
                        policy_id=pid,
                        policy_cls=pol_spec.policy_class,
                        observation_space=pol_spec.observation_space,
                        action_space=pol_spec.action_space,
                        config=pol_spec.config,
                    )
            else:
                self.policy_map[pid].set_state(state)

    @DeveloperAPI
    def set_global_vars(self, global_vars: dict) -> None:

@@ -1478,10 +1494,3 @@ def _validate_env(env: Any) -> EnvType:
            "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
            "function returned {} ({}).".format(env, type(env)))
    return env


def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool:
    for policy, _, _, _ in policy_dict.values():
        if issubclass(policy, TFPolicy):
            return True
    return False
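
The `save()`/`restore()` changes above extend the worker checkpoint with a `policy_specs` entry, so policies that were added on the fly (e.g. during self-play) can be re-created on restore via `add_policy()` instead of being skipped with a warning. A rough sketch of the new payload shape follows; the field names come from the diff, everything else is an illustrative placeholder.

# Illustrative sketch of the checkpoint payload produced by the updated
# RolloutWorker.save(); field names are from the diff, values are placeholders.
import pickle  # the worker itself uses ray.cloudpickle (see the import change above)

payload = {
    "filters": {"default_policy": None},           # per-policy observation filters
    "state": {"default_policy": {"weights": []}},  # per-policy Policy.get_state() dicts
    "policy_specs": {},                            # NEW: per-policy PolicySpec objects
}
blob = pickle.dumps(payload)

# restore(): a pid found in "state" but missing from the worker's policy_map is
# re-added via add_policy() using its stored PolicySpec (if one was pickled).
restored = pickle.loads(blob)
assert set(restored) == {"filters", "state", "policy_specs"}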
@@ -555,8 +555,12 @@ def _env_runner(
                extra_batch_callback,
                env_id=env_id)
            # Call each policy's Exploration.on_episode_start method.
            # types: Policy
            for p in worker.policy_map.values():
            # Note: This may break the exploration (e.g. ParameterNoise) of
            # policies in the `policy_map` that have not been recently used
            # (and are therefore stashed to disk). However, we certainly do not
            # want to loop through all (even stashed) policies here as that
            # would counter the purpose of the LRU policy caching.
            for p in worker.policy_map.cache.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_start(
                        policy=p,

@@ -904,8 +908,14 @@ def _process_observations(
        if ma_sample_batch:
            outputs.append(ma_sample_batch)

        # Call each policy's Exploration.on_episode_end method.
        for p in worker.policy_map.values():
        # Call each (in-memory) policy's Exploration.on_episode_end
        # method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_end(
                    policy=p,
@@ -41,6 +41,7 @@ parser.add_argument(
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument(
    "--from-checkpoint",
    type=str,

@@ -197,6 +198,7 @@ if __name__ == "__main__":
            # Always just train the "main" policy.
            "policies_to_train": ["main"],
        },
        "num_workers": args.num_workers,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
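
The example config above trains only the "main" policy (`"policies_to_train": ["main"]`); frozen snapshots of "main" act as opponents and are assigned to agents by the policy mapping function. The snippet below is a hedged illustration of that pattern, not code from the example script; the snapshot IDs and the episode-parity rule are assumptions.

# Hedged illustration (not from the example script): a self-play style
# policy_mapping_fn where the trained "main" policy plays one side of each
# episode and a frozen snapshot plays the other.
import random

FROZEN_SNAPSHOTS = ["main_v1", "main_v2"]  # assumed IDs of frozen copies of "main"


def policy_mapping_fn(agent_id, episode, **kwargs):
    # Alternate which side "main" plays based on the episode ID, so it sees
    # both roles over time; the opponent side samples a frozen snapshot.
    if hash((episode.episode_id, agent_id)) % 2 == 0:
        return "main"
    return random.choice(FROZEN_SNAPSHOTS)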
@@ -2,7 +2,7 @@ from collections import deque
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override

@@ -39,10 +39,19 @@ class PolicyMap(dict):
        """Initializes a PolicyMap instance.

        Args:
            maxlen (int): The maximum number of policies to hold in memory.
            worker_index (int): The worker index of the RolloutWorker this map
                resides in.
            num_workers (int): The total number of remote workers in the
                WorkerSet to which this map's RolloutWorker belongs to.
            capacity (int): The maximum number of policies to hold in memory.
                The least used ones are written to disk/S3 and retrieved
                when needed.
            path (str):
            path (str): The path to store the policy pickle files to. Files
                will have the name: [policy_id].[worker idx].policy.pkl.
            policy_config (TrainerConfigDict): The Trainer's base config dict.
            session_creator (Optional[Callable[[], tf1.Session]): An optional
                tf1.Session creation callable.
            seed (int): An optional seed (used to seed tf policies).
        """
        super().__init__()

@@ -53,22 +62,22 @@ class PolicyMap(dict):

        # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
        self.extension = ".policy.pkl"
        self.extension = f".{self.worker_index}.policy.pkl"

        # Dictionary of keys that may be looked up (cached or not).
        self.valid_keys = set()
        self.valid_keys: Set[str] = set()
        # The actual cache with the in-memory policy objects.
        self.cache = {}
        self.cache: Dict[str, Policy] = {}
        # The doubly-linked list holding the currently in-memory objects.
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config = policy_config or {}
        self.policy_config: TrainerConfigDict = policy_config or {}
        # The orig classes/obs+act spaces, and config overrides of the
        # Policies.
        self.policy_specs = {}  # type: Dict[PolicyID, PolicySpec]
        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

    def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                      observation_space: gym.Space, action_space: gym.Space,

@@ -140,7 +149,7 @@ class PolicyMap(dict):
    def __getitem__(self, item):
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"'{item}' not a valid key!")
            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.

@@ -250,9 +259,7 @@ class PolicyMap(dict):
        policy_state = policy.get_state()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory.
        # TODO: (sven) This should clear the tf Graph as well, if the Trainer
        # would not hold parts of the graph (e.g. in tf multi-GPU setups).
        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[delkey]
        # Write state to disk.
        with open(self.path + "/" + delkey + self.extension, "wb") as f:
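
The PolicyMap changes above document the LRU behavior: at most `capacity` policies stay in memory, the least recently used ones are pickled to `<path>/<policy_id>.<worker_index>.policy.pkl`, and they are reinstated from disk on the next lookup. The following is a minimal, self-contained sketch of that stash-to-disk idea, not RLlib's actual PolicyMap; the class and attribute names are assumptions chosen to mirror the diff.

# Minimal sketch of the "stash least-recently-used entries to disk" idea
# described in the PolicyMap docstring above. NOT RLlib's implementation.
import os
import pickle
from collections import deque


class TinyLRUStash:
    """Toy LRU map that writes evicted entries to disk and reloads them on access."""

    def __init__(self, capacity: int = 2, path: str = ".", worker_index: int = 0):
        self.capacity = capacity
        self.path = path
        # Mirrors the diff's per-worker file suffix: f".{worker_index}.policy.pkl".
        self.extension = f".{worker_index}.policy.pkl"
        self.cache = {}          # in-memory objects
        self.deque = deque()     # in-memory keys, least recently used first
        self.valid_keys = set()  # every key ever stored (cached or stashed)

    def _file(self, key):
        return os.path.join(self.path, key + self.extension)

    def __setitem__(self, key, value):
        self.valid_keys.add(key)
        if key in self.cache:
            self.deque.remove(key)
        self.cache[key] = value
        self.deque.append(key)
        if len(self.cache) > self.capacity:
            # Evict the least recently used entry to disk.
            victim = self.deque.popleft()
            with open(self._file(victim), "wb") as f:
                pickle.dump(self.cache.pop(victim), f)

    def __getitem__(self, key):
        if key not in self.valid_keys:
            raise KeyError(f"'{key}' not found in this map!")
        if key not in self.cache:
            # Reinstate a stashed entry (may in turn evict another one).
            with open(self._file(key), "rb") as f:
                self[key] = pickle.load(f)
        else:
            # Mark as most recently used.
            self.deque.remove(key)
            self.deque.append(key)
        return self.cache[key]

With `capacity=2`, inserting a third entry writes the least recently used one to disk under the `[policy_id].[worker_index].policy.pkl` naming scheme from the docstring, and looking it up again loads it back into memory.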
@@ -457,8 +457,9 @@ class TestMultiAgentEnv(unittest.TestCase):
        self.assertTrue(
            pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
            [0, 1])
        self.assertRaises(
        self.assertRaisesRegex(
            KeyError,
            "not found in PolicyMap",
            lambda: pg.compute_single_action(
                [0, 0, 0, 0], policy_id="policy_3"))
@@ -1,29 +0,0 @@
"""Testing for trainer class"""
import copy
import unittest
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG


class TestTrainer(unittest.TestCase):
    def test_validate_config_idempotent(self):
        """
        Asserts that validate_config run multiple
        times on COMMON_CONFIG will be idempotent
        """
        # Given
        standard_config = copy.deepcopy(COMMON_CONFIG)

        # When (we validate config 2 times)
        Trainer._validate_config(standard_config)
        config_v1 = copy.deepcopy(standard_config)
        Trainer._validate_config(standard_config)
        config_v2 = copy.deepcopy(standard_config)

        # Then
        self.assertEqual(config_v1, config_v2)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))