Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Fix Trainer.add_policy for num_workers>0 (self-play example scripts). (#17566)
This commit is contained in:
parent 0eb0e0ff58
commit 3b447265d8
7 changed files with 114 additions and 64 deletions
@@ -1,4 +1,3 @@
 import gym
 from random import choice
 import unittest
@@ -20,8 +19,6 @@ class TestTrainer(unittest.TestCase):
         ray.shutdown()

     def test_add_delete_policy(self):
-        env = gym.make("CartPole-v0")
-
         config = pg.DEFAULT_CONFIG.copy()
         config.update({
             "env": MultiAgentCartPole,
@@ -30,34 +27,30 @@ class TestTrainer(unittest.TestCase):
                     "num_agents": 4,
                 },
             },
             "num_workers": 2,  # Test on remote workers as well.
             "multiagent": {
                 # Start with a single policy.
-                "policies": {
-                    "p0": (None, env.observation_space, env.action_space, {}),
-                },
+                "policies": {"p0"},
                 "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
                 "policy_map_capacity": 2,
             },
         })

-        # TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
-        # refactor PR merged.
-        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
+        for _ in framework_iterator(config):
             trainer = pg.PGTrainer(config=config)
             r = trainer.train()
             self.assertTrue("p0" in r["policy_reward_min"])
-            for i in range(1, 4):
+            checkpoints = []
+            for i in range(1, 3):

                 def new_mapping_fn(agent_id, episode, **kwargs):
                     return f"p{choice([i, i - 1])}"

                 # Add a new policy.
+                pid = f"p{i}"
                 new_pol = trainer.add_policy(
-                    f"p{i}",
+                    pid,
                     trainer._policy_class,
-                    observation_space=env.observation_space,
-                    action_space=env.action_space,
-                    config={},
                     # Test changing the mapping fn.
                     policy_mapping_fn=new_mapping_fn,
                     # Change the list of policies to train.
@@ -70,9 +63,22 @@ class TestTrainer(unittest.TestCase):
                 self.assertTrue(len(pol_map) == i + 1)
                 r = trainer.train()
                 self.assertTrue("p1" in r["policy_reward_min"])
+                checkpoints.append(trainer.save())
+
+            # Test restoring from the checkpoint (which has more policies
+            # than what's defined in the config dict).
+            test = pg.PGTrainer(config=config)
+            test.restore(checkpoints[-1])
+            test.train()
+            # Test creating an action with the added (and restored) policy.
+            a = test.compute_single_action(
+                test.get_policy("p0").observation_space.sample(),
+                policy_id=pid)
+            self.assertTrue(test.get_policy("p0").action_space.contains(a))
+            test.stop()

             # Delete all added policies again from trainer.
-            for i in range(3, 0, -1):
+            for i in range(2, 0, -1):
                 trainer.remove_policy(
                     f"p{i}",
                     policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i - 1}",
@@ -130,7 +136,7 @@ class TestTrainer(unittest.TestCase):

         # Try again using `create_env_on_driver=True`.
         # This force-adds the env on the local-worker, so this Trainer
-        # can `evaluate` even though, it doesn't have an evaluation-worker
+        # can `evaluate` even though it doesn't have an evaluation-worker
         # set.
         config["create_env_on_driver"] = True
         trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
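A usage sketch of the API the test above exercises (illustrative, not part of this commit): it assumes RLlib at roughly this commit's version, the MultiAgentCartPole example env, and the add_policy()/remove_policy() signatures shown in the test; the env_config values and policy IDs are arbitrary.

    import ray
    from ray.rllib.agents import pg
    from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

    ray.init()

    config = pg.DEFAULT_CONFIG.copy()
    config.update({
        "env": MultiAgentCartPole,
        "env_config": {"num_agents": 2},
        "num_workers": 2,  # add_policy() must reach the remote workers, too.
        "multiagent": {
            "policies": {"p0"},
            "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
            "policy_map_capacity": 2,
        },
    })

    trainer = pg.PGTrainer(config=config)
    trainer.train()

    # Add a second policy on the local and all remote workers (same policy
    # class as the existing one, as in the test above).
    trainer.add_policy(
        "p1",
        trainer._policy_class,
        policy_mapping_fn=lambda aid, episode, **kwargs: "p1",
        policies_to_train=["p0", "p1"],
    )
    trainer.train()
    checkpoint = trainer.save()

    # Remove it again, mapping all agents back to "p0".
    trainer.remove_policy(
        "p1", policy_mapping_fn=lambda aid, episode, **kwargs: "p0")

    trainer.stop()
    ray.shutdown()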
@@ -951,7 +951,7 @@ class Trainer(Trainable):
             unsquash_actions: Optional[bool] = None,
             clip_actions: Optional[bool] = None,
     ) -> TensorStructType:
-        """Computes an action for the specified policy on the local Worker.
+        """Computes an action for the specified policy on the local worker.

         Note that you can also access the policy object through
         self.get_policy(policy_id) and call compute_single_action() on it
@@ -982,17 +982,31 @@ class Trainer(Trainable):
             any: The computed action if full_fetch=False, or
             tuple: The full output of policy.compute_actions() if
                 full_fetch=True or we have an RNN-based Policy.
+
+        Raises:
+            KeyError: If the `policy_id` cannot be found in this Trainer's
+                local worker.
         """
+        policy = self.get_policy(policy_id)
+        if policy is None:
+            raise KeyError(
+                f"PolicyID '{policy_id}' not found in PolicyMap of the "
+                f"Trainer's local worker!")
+
+        local_worker = self.workers.local_worker()
+
         if state is None:
             state = []

-        pp = self.workers.local_worker().preprocessors[policy_id]
+        # Check the preprocessor and preprocess, if necessary.
+        pp = local_worker.preprocessors[policy_id]
         if type(pp).__name__ != "NoPreprocessor":
             observation = pp.transform(observation)
-        filtered_observation = self.workers.local_worker().filters[policy_id](
+        filtered_observation = local_worker.filters[policy_id](
             observation, update=False)

-        result = self.get_policy(policy_id).compute_single_action(
+        # Compute the action.
+        result = policy.compute_single_action(
             filtered_observation,
             state,
             prev_action,
@@ -1002,10 +1016,12 @@ class Trainer(Trainable):
             clip_actions=clip_actions,
             explore=explore)

+        # Return 3-Tuple: Action, states, and extra-action fetches.
         if state or full_fetch:
             return result
+        # Ensure backward compatibility.
         else:
-            return result[0]  # backwards compatibility
+            return result[0]

     @Deprecated(new="compute_single_action", error=False)
     def compute_action(self, *args, **kwargs):
@@ -1193,7 +1209,6 @@ class Trainer(Trainable):
                 observation_space=observation_space,
                 action_space=action_space,
                 config=config,
-                policy_config=self.config,
                 policy_mapping_fn=policy_mapping_fn,
                 policies_to_train=policies_to_train,
             )
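The reworked compute_single_action() now fails fast with a KeyError when the requested policy is not in the local worker's PolicyMap. A minimal calling sketch (illustrative, not part of this commit; assumes RLlib at this commit's version and the default policy ID "default_policy"):

    import ray
    from ray.rllib.agents import pg

    ray.init()
    trainer = pg.PGTrainer(config={"env": "CartPole-v0", "framework": "torch"})

    obs = trainer.get_policy().observation_space.sample()

    # Known policy ID -> returns just the action (full_fetch=False).
    action = trainer.compute_single_action(obs, policy_id="default_policy")

    # Unknown policy ID -> an explicit KeyError instead of an opaque failure
    # further down the call chain.
    try:
        trainer.compute_single_action(obs, policy_id="does_not_exist")
    except KeyError as e:
        print("As expected:", e)

    trainer.stop()
    ray.shutdown()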
@@ -2,13 +2,13 @@ import random
 import numpy as np
 import gym
 import logging
-import pickle
 import platform
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
     TYPE_CHECKING, Union

 import ray
+from ray import cloudpickle as pickle
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.external_env import ExternalEnv
@@ -29,7 +29,6 @@ from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
 from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -1057,7 +1056,6 @@ class RolloutWorker(ParallelIteratorWorker):
             observation_space: Optional[gym.spaces.Space] = None,
             action_space: Optional[gym.spaces.Space] = None,
             config: Optional[PartialTrainerConfigDict] = None,
-            policy_config: Optional[TrainerConfigDict] = None,
             policy_mapping_fn: Optional[Callable[
                 [AgentID, "MultiAgentEpisode"], PolicyID]] = None,
             policies_to_train: Optional[List[PolicyID]] = None,
@@ -1095,14 +1093,16 @@ class RolloutWorker(ParallelIteratorWorker):
         policy_dict = _determine_spaces_for_multi_agent_dict(
             {
                 policy_id: PolicySpec(policy_cls, observation_space,
-                                      action_space, config)
+                                      action_space, config or {})
             },
             self.env,
             spaces=self.spaces,
-            policy_config=policy_config)
-
+            policy_config=self.policy_config,
+        )
         self._build_policy_map(
-            policy_dict, policy_config, seed=policy_config.get("seed"))
+            policy_dict,
+            self.policy_config,
+            seed=self.policy_config.get("seed"))
         new_policy = self.policy_map[policy_id]

         self.filters[policy_id] = get_filter(
@@ -1244,11 +1244,16 @@ class RolloutWorker(ParallelIteratorWorker):
     @DeveloperAPI
     def save(self) -> bytes:
         filters = self.get_filters(flush_after=True)
-        state = {
-            pid: self.policy_map[pid].get_state()
-            for pid in self.policy_map
-        }
-        return pickle.dumps({"filters": filters, "state": state})
+        state = {}
+        policy_specs = {}
+        for pid in self.policy_map:
+            state[pid] = self.policy_map[pid].get_state()
+            policy_specs[pid] = self.policy_map.policy_specs[pid]
+        return pickle.dumps({
+            "filters": filters,
+            "state": state,
+            "policy_specs": policy_specs,
+        })

     @DeveloperAPI
     def restore(self, objs: bytes) -> None:
@@ -1256,12 +1261,23 @@ class RolloutWorker(ParallelIteratorWorker):
         self.sync_filters(objs["filters"])
         for pid, state in objs["state"].items():
             if pid not in self.policy_map:
-                logger.warning(
-                    f"pid={pid} not found in policy_map! It was probably added"
-                    " on-the-fly and is not part of the static `config."
-                    "multiagent.policies` dict. Ignoring it for now.")
-                continue
-            self.policy_map[pid].set_state(state)
+                pol_spec = objs.get("policy_specs", {}).get(pid)
+                if not pol_spec:
+                    logger.warning(
+                        f"PolicyID '{pid}' was probably added on-the-fly (not"
+                        " part of the static `multiagent.policies` config) and"
+                        " no PolicySpec objects found in the pickled policy "
+                        "state. Will not add `{pid}`, but ignore it for now.")
+                else:
+                    self.add_policy(
+                        policy_id=pid,
+                        policy_cls=pol_spec.policy_class,
+                        observation_space=pol_spec.observation_space,
+                        action_space=pol_spec.action_space,
+                        config=pol_spec.config,
+                    )
+            else:
+                self.policy_map[pid].set_state(state)

     @DeveloperAPI
     def set_global_vars(self, global_vars: dict) -> None:
@@ -1480,10 +1496,3 @@ def _validate_env(env: Any) -> EnvType:
             "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
             "function returned {} ({}).".format(env, type(env)))
         return env
-
-
-def _has_tensorflow_graph(policy_dict: MultiAgentPolicyConfigDict) -> bool:
-    for policy, _, _, _ in policy_dict.values():
-        if issubclass(policy, TFPolicy):
-            return True
-    return False
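The worker's save()/restore() contract now carries a PolicySpec per policy, so a restoring worker can re-create policies that were added after construction instead of only warning and skipping them. A framework-free sketch of that decision logic (DummyPolicy/DummySpec and the dict shapes are illustrative stand-ins, not RLlib classes):

    import pickle
    from dataclasses import dataclass, field
    from typing import Any, Dict


    @dataclass
    class DummySpec:
        """Stand-in for a PolicySpec: just enough to re-create a policy."""
        policy_class: type
        config: Dict[str, Any] = field(default_factory=dict)


    class DummyPolicy:
        def __init__(self, config=None):
            self.config = config or {}
            self.weights = {}

        def get_state(self):
            return dict(self.weights)

        def set_state(self, state):
            self.weights = dict(state)


    def save(policy_map: Dict[str, DummyPolicy],
             specs: Dict[str, DummySpec]) -> bytes:
        return pickle.dumps({
            "state": {pid: p.get_state() for pid, p in policy_map.items()},
            "policy_specs": specs,
        })


    def restore(blob: bytes, policy_map: Dict[str, DummyPolicy],
                specs: Dict[str, DummySpec]) -> None:
        objs = pickle.loads(blob)
        for pid, state in objs["state"].items():
            if pid not in policy_map:
                spec = objs.get("policy_specs", {}).get(pid)
                if spec is None:
                    print(f"Skipping unknown policy '{pid}' (no spec saved).")
                    continue
                # Re-create the policy that was added on-the-fly ...
                policy_map[pid] = spec.policy_class(spec.config)
                specs[pid] = spec
            # ... then load its weights, as for statically configured ones.
            policy_map[pid].set_state(state)


    # Round trip: "p1" was added after construction, yet survives restore.
    old_map = {"p0": DummyPolicy(), "p1": DummyPolicy()}
    old_specs = {pid: DummySpec(DummyPolicy) for pid in old_map}
    blob = save(old_map, old_specs)

    new_map = {"p0": DummyPolicy()}            # A fresh worker only knows "p0".
    new_specs = {"p0": DummySpec(DummyPolicy)}
    restore(blob, new_map, new_specs)
    assert set(new_map) == {"p0", "p1"}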
@@ -554,8 +554,12 @@ def _env_runner(
                 extra_batch_callback,
                 env_id=env_id)
             # Call each policy's Exploration.on_episode_start method.
-            # types: Policy
-            for p in worker.policy_map.values():
+            # Note: This may break the exploration (e.g. ParameterNoise) of
+            # policies in the `policy_map` that have not been recently used
+            # (and are therefore stashed to disk). However, we certainly do not
+            # want to loop through all (even stashed) policies here as that
+            # would counter the purpose of the LRU policy caching.
+            for p in worker.policy_map.cache.values():
                 if getattr(p, "exploration", None) is not None:
                     p.exploration.on_episode_start(
                         policy=p,
@@ -902,8 +906,14 @@ def _process_observations(
             if ma_sample_batch:
                 outputs.append(ma_sample_batch)

-            # Call each policy's Exploration.on_episode_end method.
-            for p in worker.policy_map.values():
+            # Call each (in-memory) policy's Exploration.on_episode_end
+            # method.
+            # Note: This may break the exploration (e.g. ParameterNoise) of
+            # policies in the `policy_map` that have not been recently used
+            # (and are therefore stashed to disk). However, we certainly do not
+            # want to loop through all (even stashed) policies here as that
+            # would counter the purpose of the LRU policy caching.
+            for p in worker.policy_map.cache.values():
                 if getattr(p, "exploration", None) is not None:
                     p.exploration.on_episode_end(
                         policy=p,
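With the LRU PolicyMap, the sampler only touches policies that are currently in the in-memory cache, so stashed policies are not pulled back from disk just to fire an exploration callback. A framework-free sketch of the pattern (class names are stand-ins, not RLlib classes):

    class ParamNoiseExploration:
        """Stand-in for an Exploration object with episode hooks."""
        def __init__(self):
            self.episodes_started = 0

        def on_episode_start(self, policy=None):
            self.episodes_started += 1


    class StubPolicy:
        def __init__(self, with_exploration: bool):
            self.exploration = (
                ParamNoiseExploration() if with_exploration else None)


    # Pretend "p0"/"p1" are in memory while "p2" is stashed to disk and thus
    # absent from the cache; iterating the cache leaves "p2" untouched.
    cache = {"p0": StubPolicy(True), "p1": StubPolicy(False)}

    for p in cache.values():
        if getattr(p, "exploration", None) is not None:
            p.exploration.on_episode_start(policy=p)

    assert cache["p0"].exploration.episodes_started == 1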
@@ -41,6 +41,7 @@ parser.add_argument(
     default="tf",
     help="The DL framework specifier.")
 parser.add_argument("--num-cpus", type=int, default=0)
+parser.add_argument("--num-workers", type=int, default=2)
 parser.add_argument(
     "--from-checkpoint",
     type=str,
@@ -197,6 +198,7 @@ if __name__ == "__main__":
             # Always just train the "main" policy.
             "policies_to_train": ["main"],
         },
+        "num_workers": args.num_workers,
         # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
         "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
         "framework": args.framework,
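The self-play example scripts gain a --num-workers flag and forward it into the config, so the add_policy()-on-remote-workers path is actually exercised. A hedged sketch of how such a flag flows into the config (argparse only; the config keys mirror the hunk above, everything else is illustrative):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--framework", type=str, default="tf",
                        help="The DL framework specifier.")
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=2)
    args = parser.parse_args([])  # Defaults here; pass real argv in a script.

    config = {
        "multiagent": {
            # Only ever train the "main" policy; league opponents get added
            # later via trainer.add_policy() and stay frozen.
            "policies_to_train": ["main"],
        },
        # With num_workers > 0, add_policy() must propagate to the remote
        # workers, which is exactly what this commit fixes.
        "num_workers": args.num_workers,
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
    }
    print(config)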
@@ -2,7 +2,7 @@ from collections import deque
 import gym
 import os
 import pickle
-from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
+from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING

 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.annotations import override
@@ -39,10 +39,19 @@ class PolicyMap(dict):
         """Initializes a PolicyMap instance.

         Args:
-            maxlen (int): The maximum number of policies to hold in memory.
+            worker_index (int): The worker index of the RolloutWorker this map
+                resides in.
+            num_workers (int): The total number of remote workers in the
+                WorkerSet to which this map's RolloutWorker belongs to.
+            capacity (int): The maximum number of policies to hold in memory.
+                The least used ones are written to disk/S3 and retrieved
+                when needed.
-            path (str):
+            path (str): The path to store the policy pickle files to. Files
+                will have the name: [policy_id].[worker idx].policy.pkl.
             policy_config (TrainerConfigDict): The Trainer's base config dict.
             session_creator (Optional[Callable[[], tf1.Session]): An optional
                 tf1.Session creation callable.
             seed (int): An optional seed (used to seed tf policies).
         """
         super().__init__()

@@ -53,22 +62,22 @@ class PolicyMap(dict):

         # The file extension for stashed policies (that are no longer available
         # in-memory but can be reinstated any time from storage).
-        self.extension = ".policy.pkl"
+        self.extension = f".{self.worker_index}.policy.pkl"

         # Dictionary of keys that may be looked up (cached or not).
-        self.valid_keys = set()
+        self.valid_keys: Set[str] = set()
         # The actual cache with the in-memory policy objects.
-        self.cache = {}
+        self.cache: Dict[str, Policy] = {}
         # The doubly-linked list holding the currently in-memory objects.
         self.deque = deque(maxlen=capacity or 10)
         # The file path where to store overflowing policies.
         self.path = path or "."
         # The core config to use. Each single policy's config override is
         # added on top of this.
-        self.policy_config = policy_config or {}
+        self.policy_config: TrainerConfigDict = policy_config or {}
         # The orig classes/obs+act spaces, and config overrides of the
         # Policies.
-        self.policy_specs = {}  # type: Dict[PolicyID, PolicySpec]
+        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

     def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                       observation_space: gym.Space, action_space: gym.Space,
@@ -140,7 +149,7 @@ class PolicyMap(dict):
     def __getitem__(self, item):
         # Never seen this key -> Error.
         if item not in self.valid_keys:
-            raise KeyError(f"'{item}' not a valid key!")
+            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

         # Item already in cache -> Rearrange deque (least recently used) and
         # return.
@@ -250,9 +259,7 @@ class PolicyMap(dict):
         policy_state = policy.get_state()
         # Closes policy's tf session, if any.
         self._close_session(policy)
-        # Remove from memory.
-        # TODO: (sven) This should clear the tf Graph as well, if the Trainer
-        # would not hold parts of the graph (e.g. in tf multi-GPU setups).
+        # Remove from memory. This will clear the tf Graph as well.
         del self.cache[delkey]
         # Write state to disk.
         with open(self.path + "/" + delkey + self.extension, "wb") as f:
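PolicyMap behaves like a dict but keeps at most `capacity` policies in memory, pickling the least recently used ones to disk; this commit also puts the worker index into the stash file name so that workers sharing a directory do not overwrite each other's files. A compact, self-contained sketch of that caching idea (plain picklable values instead of Policy objects; not the RLlib class):

    import os
    import pickle
    import tempfile
    from collections import deque


    class LRUStashMap:
        """Dict-like map that keeps at most `capacity` values in memory.

        Overflowing entries are pickled to `path`; the worker index goes into
        the file name, mirroring PolicyMap's "[key].[worker idx].policy.pkl"
        naming scheme.
        """

        def __init__(self, capacity=2, path=".", worker_index=0):
            self.capacity = capacity
            self.path = path
            self.extension = f".{worker_index}.stash.pkl"
            self.valid_keys = set()  # Every key ever stored (cached or not).
            self.cache = {}          # The in-memory entries only.
            self.deque = deque()     # LRU order of cached keys, MRU rightmost.

        def _file(self, key):
            return os.path.join(self.path, key + self.extension)

        def _touch(self, key):
            if key in self.deque:
                self.deque.remove(key)
            self.deque.append(key)

        def _evict_if_needed(self):
            while len(self.cache) > self.capacity:
                victim = self.deque.popleft()
                # Write the least recently used entry to disk and drop it.
                with open(self._file(victim), "wb") as f:
                    pickle.dump(self.cache.pop(victim), f)

        def __setitem__(self, key, value):
            self.valid_keys.add(key)
            self.cache[key] = value
            self._touch(key)
            self._evict_if_needed()

        def __getitem__(self, key):
            if key not in self.valid_keys:
                raise KeyError(f"'{key}' not found in this map!")
            if key not in self.cache:
                # Reinstate a stashed entry from disk.
                with open(self._file(key), "rb") as f:
                    self.cache[key] = pickle.load(f)
                os.remove(self._file(key))
            self._touch(key)
            self._evict_if_needed()
            return self.cache[key]


    with tempfile.TemporaryDirectory() as tmp:
        m = LRUStashMap(capacity=2, path=tmp, worker_index=3)
        m["p0"], m["p1"], m["p2"] = "w0", "w1", "w2"  # "p0" gets stashed.
        assert sorted(m.cache) == ["p1", "p2"]
        assert m["p0"] == "w0"        # Reinstated from disk; "p1" stashed.
        assert sorted(m.cache) == ["p0", "p2"]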
@@ -457,8 +457,9 @@ class TestMultiAgentEnv(unittest.TestCase):
             self.assertTrue(
                 pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
                 [0, 1])
-            self.assertRaises(
+            self.assertRaisesRegex(
                 KeyError,
+                "not found in PolicyMap",
                 lambda: pg.compute_single_action(
                     [0, 0, 0, 0], policy_id="policy_3"))
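The multi-agent test now pins down the error message as well as the exception type via assertRaisesRegex. A standalone illustration of the same pattern (plain unittest, no RLlib):

    import unittest


    class FakePolicyMap(dict):
        def __getitem__(self, item):
            if item not in self:
                raise KeyError(f"PolicyID '{item}' not found in PolicyMap!")
            return super().__getitem__(item)


    class TestKeyErrorMessage(unittest.TestCase):
        def test_missing_policy_id(self):
            pmap = FakePolicyMap(policy_2=object())
            # Assert both the exception type and (part of) its message.
            self.assertRaisesRegex(KeyError, "not found in PolicyMap",
                                   lambda: pmap["policy_3"])


    if __name__ == "__main__":
        unittest.main()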