[RLlib] Implement policy_maps (multi-agent case) in RolloutWorkers as LRU caches. (#17031)
Parent: e0640ad0dc
Commit: 18d173b172
22 changed files with 503 additions and 208 deletions
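For orientation before the per-file hunks: the new multi-agent options `policy_map_capacity` and `policy_map_cache` (added to COMMON_CONFIG below) bound how many policies a worker keeps in memory; the rest are stashed to disk and reloaded on access. A minimal, hedged usage sketch modeled on the `test_trainer.py` change in this diff; `MultiAgentCartPole` and `PGTrainer` are taken from that test, everything else is illustrative:

import ray
from ray.rllib.agents import pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

ray.init()

# Spaces are read off an env instance, as in the test below.
env = MultiAgentCartPole({"num_agents": 4})

config = {
    "env": MultiAgentCartPole,
    "env_config": {"num_agents": 4},
    "multiagent": {
        "policies": {
            "p0": (None, env.observation_space, env.action_space, {}),
        },
        "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
        # New in this PR: keep at most 2 policies in memory (LRU); the
        # least recently used ones are written to disk (or to the
        # location given by "policy_map_cache") and reloaded on access.
        "policy_map_capacity": 2,
        "policy_map_cache": None,
    },
}

trainer = pg.PGTrainer(config=config)
result = trainer.train()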
|
@ -66,7 +66,7 @@ class TestIMPALA(unittest.TestCase):
|
|||
|
||||
try:
|
||||
if fw == "tf":
|
||||
check(policy._sess.run(policy.cur_lr), 0.0005)
|
||||
check(policy.get_session().run(policy.cur_lr), 0.0005)
|
||||
else:
|
||||
check(policy.cur_lr, 0.0005)
|
||||
r1 = trainer.train()
|
||||
|
|
|
@ -36,10 +36,13 @@ class TestTrainer(unittest.TestCase):
            "p0": (None, env.observation_space, env.action_space, {}),
        },
        "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
        "policy_map_capacity": 2,
    },
})

for _ in framework_iterator(config):
# TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
# refactor PR merged.
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
    trainer = pg.PGTrainer(config=config)
    r = trainer.train()
    self.assertTrue("p0" in r["policy_reward_min"])

@ -62,8 +65,8 @@ class TestTrainer(unittest.TestCase):
    )
    pol_map = trainer.workers.local_worker().policy_map
    self.assertTrue(new_pol is not trainer.get_policy("p0"))
    self.assertTrue("p0" in pol_map)
    self.assertTrue("p1" in pol_map)
    for j in range(i):
        self.assertTrue(f"p{j}" in pol_map)
    self.assertTrue(len(pol_map) == i + 1)
    r = trainer.train()
    self.assertTrue("p1" in r["policy_reward_min"])

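The second hunk above grows the number of policies at runtime and checks that the capacity-bounded policy_map still answers membership and length queries for every policy ever added. A hedged sketch of that calling pattern from user code, continuing the config example above; the exact keyword names of `Trainer.add_policy()` are assumptions inferred from this test, not verified API:

for i in range(1, 5):
    def mapping_fn(agent_id, episode, i=i, **kwargs):
        # Map every agent to the newest policy (illustrative only).
        return f"p{i}"

    new_pol = trainer.add_policy(
        policy_id=f"p{i}",
        policy_cls=type(trainer.get_policy("p0")),
        policy_mapping_fn=mapping_fn,
    )

    # With "policy_map_capacity": 2, older policies have been stashed to
    # disk, yet they still count towards the map's keys and length.
    pol_map = trainer.workers.local_worker().policy_map
    assert f"p{i}" in pol_map and len(pol_map) == i + 1

    result = trainer.train()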
@ -432,6 +432,13 @@ COMMON_CONFIG: TrainerConfigDict = {
        # of (policy_cls, obs_space, act_space, config). This defines the
        # observation and action spaces of the policies and any extra config.
        "policies": {},
        # Keep this many policies in the "policy_map" (before writing
        # least-recently used ones to disk/S3).
        "policy_map_capacity": 100,
        # Where to store overflowing (least-recently used) policies?
        # Could be a directory (str) or an S3 location. None for using
        # the default output dir.
        "policy_map_cache": None,
        # Function mapping agent ids to policy ids.
        "policy_mapping_fn": None,
        # Optional list of policies to train, or None for all policies.

@ -1181,7 +1188,7 @@ class Trainer(Trainable):
|
|||
local worker).
|
||||
"""
|
||||
|
||||
def fn(worker):
|
||||
def fn(worker: RolloutWorker):
|
||||
# `foreach_worker` function: Adds the policy to the worker (and
|
||||
# maybe changes its policy_mapping_fn - if provided here).
|
||||
worker.add_policy(
|
||||
|
|
|
@ -135,13 +135,10 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
|
|||
if "new_obs" in k:
|
||||
new_obs_n.append(v)
|
||||
|
||||
target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
|
||||
feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
|
||||
|
||||
new_act_n = p.sess.run(target_act_sampler_n, feed_dict)
|
||||
samples.update(
|
||||
{"new_actions_%d" % i: new_act
|
||||
for i, new_act in enumerate(new_act_n)})
|
||||
for i, p in enumerate(policies.values()):
|
||||
feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
|
||||
new_act = p.get_session().run(p.target_act_sampler, feed_dict)
|
||||
samples.update({"new_actions_%d" % i: new_act})
|
||||
|
||||
# Share samples among agents.
|
||||
policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
|
||||
|
|
|
@ -230,7 +230,8 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
|
||||
# _____ TensorFlow Initialization
|
||||
|
||||
self.sess = tf1.get_default_session()
|
||||
sess = tf1.get_default_session()
|
||||
assert sess
|
||||
|
||||
def _make_loss_inputs(placeholders):
|
||||
return [(ph.name.split("/")[-1].split(":")[0], ph)
|
||||
|
@ -244,7 +245,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
obs_space,
|
||||
act_space,
|
||||
config=config,
|
||||
sess=self.sess,
|
||||
sess=sess,
|
||||
obs_input=obs_ph_n[agent_id],
|
||||
sampled_action=act_sampler,
|
||||
loss=actor_loss + critic_loss,
|
||||
|
@ -254,7 +255,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
del self.view_requirements["prev_actions"]
|
||||
del self.view_requirements["prev_rewards"]
|
||||
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
self.get_session().run(tf1.global_variables_initializer())
|
||||
|
||||
# Hard initial update
|
||||
self.update_target(1.0)
|
||||
|
@ -297,11 +298,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
var_list = []
|
||||
for var in self.vars.values():
|
||||
var_list += var
|
||||
return {"_state": self.sess.run(var_list)}
|
||||
return {"_state": self.get_session().run(var_list)}
|
||||
|
||||
@override(TFPolicy)
|
||||
def set_weights(self, weights):
|
||||
self.sess.run(
|
||||
self.get_session().run(
|
||||
self.update_vars,
|
||||
feed_dict=dict(zip(self.vars_ph, weights["_state"])))
|
||||
|
||||
|
@ -377,6 +378,6 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
|
|||
|
||||
def update_target(self, tau=None):
|
||||
if tau is not None:
|
||||
self.sess.run(self.update_target_vars, {self.tau: tau})
|
||||
self.get_session().run(self.update_target_vars, {self.tau: tau})
|
||||
else:
|
||||
self.sess.run(self.update_target_vars)
|
||||
self.get_session().run(self.update_target_vars)
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_map import PolicyMap
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
|
@ -30,7 +30,7 @@ class SampleCollector(metaclass=ABCMeta):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy_map: Dict[PolicyID, Policy],
|
||||
policy_map: PolicyMap,
|
||||
clip_rewards: Union[bool, float],
|
||||
callbacks: "DefaultCallbacks",
|
||||
multiple_episodes_in_batch: bool = True,
|
||||
|
@ -39,8 +39,7 @@ class SampleCollector(metaclass=ABCMeta):
|
|||
"""Initializes a SampleCollector instance.
|
||||
|
||||
Args:
|
||||
policy_map (Dict[str, Policy]): Maps policy ids to policy
|
||||
instances.
|
||||
policy_map (PolicyMap): Maps policy ids to policy instances.
|
||||
clip_rewards (Union[bool, float]): Whether to clip rewards before
|
||||
postprocessing (at +/-1.0) or the actual value to +/- clip.
|
||||
callbacks (DefaultCallbacks): RLlib callbacks.
|
||||
|
|
|
@ -9,6 +9,7 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
|||
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_map import PolicyMap
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
|
@ -292,7 +293,7 @@ class _PolicyCollector:
|
|||
appended to this policy's buffers.
|
||||
"""
|
||||
|
||||
def __init__(self, policy):
|
||||
def __init__(self, policy: Policy):
|
||||
"""Initializes a _PolicyCollector instance.
|
||||
|
||||
Args:
|
||||
|
@ -382,7 +383,7 @@ class SimpleListCollector(SampleCollector):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy_map: Dict[PolicyID, Policy],
|
||||
policy_map: PolicyMap,
|
||||
clip_rewards: Union[bool, float],
|
||||
callbacks: "DefaultCallbacks",
|
||||
multiple_episodes_in_batch: bool = True,
|
||||
|
@ -650,8 +651,7 @@ class SimpleListCollector(SampleCollector):
|
|||
post_batches[agent_id] = pre_batch
|
||||
if getattr(policy, "exploration", None) is not None:
|
||||
policy.exploration.postprocess_trajectory(
|
||||
policy, post_batches[agent_id],
|
||||
getattr(policy, "_sess", None))
|
||||
policy, post_batches[agent_id], policy.get_session())
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
post_batches[agent_id], other_batches, episode)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import random
|
|||
from typing import List, Dict, Callable, Any, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_map import PolicyMap
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
|
||||
|
@ -49,9 +49,8 @@ class MultiAgentEpisode:
|
|||
>>> episode.extra_batches.add(batch.build_and_reset())
|
||||
"""
|
||||
|
||||
def __init__(self, policies: Dict[PolicyID, Policy],
|
||||
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
|
||||
PolicyID],
|
||||
def __init__(self, policies: PolicyMap, policy_mapping_fn: Callable[
|
||||
[AgentID, "MultiAgentEpisode"], PolicyID],
|
||||
batch_builder_factory: Callable[
|
||||
[], "MultiAgentSampleBatchBuilder"],
|
||||
extra_batch_callback: Callable[[SampleBatchType], None],
|
||||
|
@ -71,7 +70,7 @@ class MultiAgentEpisode:
|
|||
self.user_data: Dict[str, Any] = {}
|
||||
self.hist_data: Dict[str, List[float]] = {}
|
||||
self.media: Dict[str, Any] = {}
|
||||
self.policy_map: Dict[PolicyID, Policy] = policies
|
||||
self.policy_map: PolicyMap = policies
|
||||
self._policies = self.policy_map # backward compatibility
|
||||
self._policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
|
||||
PolicyID] = policy_mapping_fn
|
||||
|
|
|
@ -28,6 +28,7 @@ from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
|||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.policy import Policy, PolicySpec
|
||||
from ray.rllib.policy.policy_map import PolicyMap
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
@ -489,7 +490,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
|
||||
self.make_env_fn = make_env
|
||||
|
||||
self.tf_sess = None
|
||||
policy_dict = _determine_spaces_for_multi_agent_dict(
|
||||
policy_spec, self.env, spaces=spaces, policy_config=policy_config)
|
||||
# List of IDs of those policies, which should be trained.
|
||||
|
@ -498,7 +498,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
policy_dict.keys())
|
||||
self.set_policies_to_train(self.policies_to_train)
|
||||
|
||||
self.policy_map: Dict[PolicyID, Policy] = None
|
||||
self.policy_map: PolicyMap = None
|
||||
self.preprocessors: Dict[PolicyID, Preprocessor] = None
|
||||
|
||||
# Set Python random, numpy, env, and torch/tf seeds.
|
||||
|
@ -541,26 +541,11 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
elif tf1 and policy_config.get("framework") == "tfe":
|
||||
tf1.set_random_seed(seed)
|
||||
|
||||
if _has_tensorflow_graph(policy_dict) and not (
|
||||
tf1 and tf1.executing_eagerly()):
|
||||
if not tf1:
|
||||
raise ImportError("Could not import tensorflow")
|
||||
with tf1.Graph().as_default():
|
||||
if tf_session_creator:
|
||||
self.tf_sess = tf_session_creator()
|
||||
else:
|
||||
self.tf_sess = tf1.Session(
|
||||
config=tf1.ConfigProto(
|
||||
gpu_options=tf1.GPUOptions(allow_growth=True)))
|
||||
with self.tf_sess.as_default():
|
||||
# set graph-level seed
|
||||
if seed is not None:
|
||||
tf1.set_random_seed(seed)
|
||||
self.policy_map, self.preprocessors = \
|
||||
self._build_policy_map(policy_dict, policy_config)
|
||||
else:
|
||||
self.policy_map, self.preprocessors = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
self._build_policy_map(
|
||||
policy_dict,
|
||||
policy_config,
|
||||
session_creator=tf_session_creator,
|
||||
seed=seed)
|
||||
|
||||
# Update Policy's view requirements from Model, only if Policy directly
|
||||
# inherited from base `Policy` class. At this point here, the Policy
|
||||
|
@ -591,14 +576,13 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
self.multiagent: bool = set(
|
||||
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
|
||||
if self.multiagent and self.env is not None:
|
||||
if not ((isinstance(self.env, MultiAgentEnv)
|
||||
or isinstance(self.env, ExternalMultiAgentEnv))
|
||||
or isinstance(self.env, BaseEnv)):
|
||||
if not isinstance(self.env,
|
||||
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
|
||||
ray.actor.ActorHandle)):
|
||||
raise ValueError(
|
||||
"Have multiple policies {}, but the env ".format(
|
||||
self.policy_map) +
|
||||
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
|
||||
"ExternalMultiAgentEnv?".format(self.env))
|
||||
f"Have multiple policies {self.policy_map}, but the "
|
||||
f"env {self.env} is not a subclass of BaseEnv, "
|
||||
f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")
|
||||
|
||||
self.filters: Dict[PolicyID, Filter] = {
|
||||
policy_id: get_filter(self.observation_filter,
|
||||
|
@ -678,7 +662,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
callbacks=self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
multiple_episodes_in_batch=pack,
|
||||
tf_sess=self.tf_sess,
|
||||
normalize_actions=normalize_actions,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs="simulation" in input_evaluation,
|
||||
|
@ -701,7 +684,6 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
callbacks=self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
multiple_episodes_in_batch=pack,
|
||||
tf_sess=self.tf_sess,
|
||||
normalize_actions=normalize_actions,
|
||||
clip_actions=clip_actions,
|
||||
soft_horizon=soft_horizon,
|
||||
|
@ -923,23 +905,24 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
summarize(samples)))
|
||||
if isinstance(samples, MultiAgentBatch):
|
||||
info_out = {}
|
||||
builders = {}
|
||||
to_fetch = {}
|
||||
if self.tf_sess is not None:
|
||||
builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
|
||||
else:
|
||||
builder = None
|
||||
for pid, batch in samples.policy_batches.items():
|
||||
if pid not in self.policies_to_train:
|
||||
continue
|
||||
# Decompress SampleBatch, in case some columns are compressed.
|
||||
batch.decompress_if_needed()
|
||||
policy = self.policy_map[pid]
|
||||
if builder and hasattr(policy, "_build_learn_on_batch"):
|
||||
tf_session = policy.get_session()
|
||||
if tf_session and hasattr(policy, "_build_learn_on_batch"):
|
||||
builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
|
||||
to_fetch[pid] = policy._build_learn_on_batch(
|
||||
builder, batch)
|
||||
builders[pid], batch)
|
||||
else:
|
||||
info_out[pid] = policy.learn_on_batch(batch)
|
||||
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
|
||||
info_out.update(
|
||||
{pid: builders[pid].get(v)
|
||||
for pid, v in to_fetch.items()})
|
||||
else:
|
||||
info_out = {
|
||||
DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
|
||||
|
@ -1024,12 +1007,15 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
return ret
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(
|
||||
self, policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID) -> Policy:
|
||||
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
|
||||
"""Return policy for the specified id, or None.
|
||||
|
||||
Args:
|
||||
policy_id (str): id of policy to return.
|
||||
policy_id (PolicyID): ID of the policy to return.
|
||||
|
||||
Returns:
|
||||
Optional[Policy]: The policy under the given ID (or None if not
|
||||
found).
|
||||
"""
|
||||
|
||||
return self.policy_map.get(policy_id)
|
||||
|
@ -1078,18 +1064,12 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
policy_dict = {
|
||||
policy_id: (policy_cls, observation_space, action_space, config)
|
||||
}
|
||||
if self.tf_sess is not None:
|
||||
with self.tf_sess.graph.as_default():
|
||||
with self.tf_sess.as_default():
|
||||
add_map, add_prep = self._build_policy_map(
|
||||
policy_dict, self.policy_config)
|
||||
else:
|
||||
add_map, add_prep = self._build_policy_map(policy_dict,
|
||||
self.policy_config)
|
||||
new_policy = add_map[policy_id]
|
||||
self._build_policy_map(
|
||||
policy_dict,
|
||||
self.policy_config,
|
||||
seed=self.policy_config.get("seed"))
|
||||
new_policy = self.policy_map[policy_id]
|
||||
|
||||
self.policy_map.update(add_map)
|
||||
self.preprocessors.update(add_prep)
|
||||
self.filters[policy_id] = get_filter(
|
||||
self.observation_filter, new_policy.observation_space.shape)
|
||||
|
||||
|
@ -1301,12 +1281,27 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
return func(self, *args)
|
||||
|
||||
def _build_policy_map(
|
||||
self, policy_dict: MultiAgentPolicyConfigDict,
|
||||
policy_config: TrainerConfigDict
|
||||
self,
|
||||
policy_dict: MultiAgentPolicyConfigDict,
|
||||
policy_config: TrainerConfigDict,
|
||||
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
|
||||
policy_map = {}
|
||||
preprocessors = {}
|
||||
for name, (cls, obs_space, act_space,
|
||||
|
||||
ma_config = policy_config.get("multiagent", {})
|
||||
|
||||
self.policy_map = self.policy_map or PolicyMap(
|
||||
worker_index=self.worker_index,
|
||||
num_workers=self.num_workers,
|
||||
capacity=ma_config.get("policy_map_capacity"),
|
||||
path=ma_config.get("policy_map_cache"),
|
||||
policy_config=policy_config,
|
||||
session_creator=session_creator,
|
||||
seed=seed,
|
||||
)
|
||||
self.preprocessors = self.preprocessors or {}
|
||||
|
||||
for name, (orig_cls, obs_space, act_space,
|
||||
conf) in sorted(policy_dict.items()):
|
||||
logger.debug("Creating policy for {}".format(name))
|
||||
merged_conf = merge_dicts(policy_config, conf or {})
|
||||
|
@ -1315,43 +1310,23 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
if self.preprocessing_enabled:
|
||||
preprocessor = ModelCatalog.get_preprocessor_for_space(
|
||||
obs_space, merged_conf.get("model"))
|
||||
preprocessors[name] = preprocessor
|
||||
self.preprocessors[name] = preprocessor
|
||||
obs_space = preprocessor.observation_space
|
||||
else:
|
||||
preprocessors[name] = NoPreprocessor(obs_space)
|
||||
self.preprocessors[name] = NoPreprocessor(obs_space)
|
||||
|
||||
if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
|
||||
raise ValueError(
|
||||
"Found raw Tuple|Dict space as input to policy. "
|
||||
"Please preprocess these observations with a "
|
||||
"Tuple|DictFlatteningPreprocessor.")
|
||||
# Tf.
|
||||
framework = policy_config.get("framework", "tf")
|
||||
if framework in ["tf2", "tf", "tfe"]:
|
||||
assert tf1
|
||||
if framework in ["tf2", "tfe"]:
|
||||
assert tf1.executing_eagerly()
|
||||
if hasattr(cls, "as_eager"):
|
||||
cls = cls.as_eager()
|
||||
if policy_config.get("eager_tracing"):
|
||||
cls = cls.with_tracing()
|
||||
elif not issubclass(cls, TFPolicy):
|
||||
pass # could be some other type of policy
|
||||
else:
|
||||
raise ValueError("This policy does not support eager "
|
||||
"execution: {}".format(cls))
|
||||
scope = name + (("_wk" + str(self.worker_index))
|
||||
if self.worker_index else "")
|
||||
with tf1.variable_scope(scope):
|
||||
policy_map[name] = cls(obs_space, act_space, merged_conf)
|
||||
# non-tf.
|
||||
else:
|
||||
policy_map[name] = cls(obs_space, act_space, merged_conf)
|
||||
|
||||
self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
|
||||
conf, merged_conf)
|
||||
|
||||
if self.worker_index == 0:
|
||||
logger.info("Built policy map: {}".format(policy_map))
|
||||
logger.info("Built preprocessor map: {}".format(preprocessors))
|
||||
return policy_map, preprocessors
|
||||
logger.info(f"Built policy map: {self.policy_map}")
|
||||
logger.info(f"Built preprocessor map: {self.preprocessors}")
|
||||
|
||||
def setup_torch_data_parallel(self, url: str, world_rank: int,
|
||||
world_size: int, backend: str) -> None:
|
||||
|
|
|
@ -205,8 +205,7 @@ class MultiAgentSampleBatchBuilder:
|
|||
post_batches[agent_id] = pre_batch
|
||||
if getattr(policy, "exploration", None) is not None:
|
||||
policy.exploration.postprocess_trajectory(
|
||||
policy, post_batches[agent_id],
|
||||
getattr(policy, "_sess", None))
|
||||
policy, post_batches[agent_id], policy.get_session())
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
post_batches[agent_id], other_batches, episode)
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ from ray.rllib.utils.filter import Filter
|
|||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.spaces.space_utils import clip_action, \
|
||||
unsquash_action, unbatch
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
|
||||
EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
|
||||
TensorStructType
|
||||
|
@ -137,7 +136,6 @@ class SyncSampler(SamplerInput):
|
|||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
normalize_actions: bool = True,
|
||||
clip_actions: bool = False,
|
||||
soft_horizon: bool = False,
|
||||
|
@ -150,6 +148,7 @@ class SyncSampler(SamplerInput):
|
|||
policy_mapping_fn=None,
|
||||
preprocessors=None,
|
||||
obs_filters=None,
|
||||
tf_sess=None,
|
||||
):
|
||||
"""Initializes a SyncSampler object.
|
||||
|
||||
|
@ -168,8 +167,6 @@ class SyncSampler(SamplerInput):
|
|||
multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be
|
||||
exactly `rollout_fragment_length` in size.
|
||||
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
|
||||
framework=tf).
|
||||
normalize_actions (bool): Whether to normalize actions to the
|
||||
action space's bounds.
|
||||
clip_actions (bool): Whether to clip actions according to the
|
||||
|
@ -199,6 +196,8 @@ class SyncSampler(SamplerInput):
|
|||
deprecation_warning(old="preprocessors")
|
||||
if obs_filters is not None:
|
||||
deprecation_warning(old="obs_filters")
|
||||
if tf_sess is not None:
|
||||
deprecation_warning(old="tf_sess")
|
||||
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
|
@ -221,7 +220,7 @@ class SyncSampler(SamplerInput):
|
|||
worker, self.base_env, self.extra_batches.put,
|
||||
self.rollout_fragment_length, self.horizon, clip_rewards,
|
||||
normalize_actions, clip_actions, multiple_episodes_in_batch,
|
||||
callbacks, tf_sess, self.perf_stats, soft_horizon, no_done_at_end,
|
||||
callbacks, self.perf_stats, soft_horizon, no_done_at_end,
|
||||
observation_fn, self.sample_collector, self.render)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
|
@ -275,7 +274,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
callbacks: "DefaultCallbacks",
|
||||
horizon: int = None,
|
||||
multiple_episodes_in_batch: bool = False,
|
||||
tf_sess=None,
|
||||
normalize_actions: bool = True,
|
||||
clip_actions: bool = False,
|
||||
blackhole_outputs: bool = False,
|
||||
|
@ -289,6 +287,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
policy_mapping_fn=None,
|
||||
preprocessors=None,
|
||||
obs_filters=None,
|
||||
tf_sess=None,
|
||||
):
|
||||
"""Initializes a AsyncSampler object.
|
||||
|
||||
|
@ -309,8 +308,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
multiple_episodes_in_batch (bool): Whether to pack multiple
|
||||
episodes into each batch. This guarantees batches will be
|
||||
exactly `rollout_fragment_length` in size.
|
||||
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
|
||||
framework=tf).
|
||||
normalize_actions (bool): Whether to normalize actions to the
|
||||
action space's bounds.
|
||||
clip_actions (bool): Whether to clip actions according to the
|
||||
|
@ -342,6 +339,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
deprecation_warning(old="preprocessors")
|
||||
if obs_filters is not None:
|
||||
deprecation_warning(old="obs_filters")
|
||||
if tf_sess is not None:
|
||||
deprecation_warning(old="tf_sess")
|
||||
|
||||
self.worker = worker
|
||||
|
||||
|
@ -359,7 +358,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
self.clip_rewards = clip_rewards
|
||||
self.daemon = True
|
||||
self.multiple_episodes_in_batch = multiple_episodes_in_batch
|
||||
self.tf_sess = tf_sess
|
||||
self.callbacks = callbacks
|
||||
self.normalize_actions = normalize_actions
|
||||
self.clip_actions = clip_actions
|
||||
|
@ -400,9 +398,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
|||
self.worker, self.base_env, extra_batches_putter,
|
||||
self.rollout_fragment_length, self.horizon, self.clip_rewards,
|
||||
self.normalize_actions, self.clip_actions,
|
||||
self.multiple_episodes_in_batch, self.callbacks, self.tf_sess,
|
||||
self.perf_stats, self.soft_horizon, self.no_done_at_end,
|
||||
self.observation_fn, self.sample_collector, self.render)
|
||||
self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
|
||||
self.soft_horizon, self.no_done_at_end, self.observation_fn,
|
||||
self.sample_collector, self.render)
|
||||
while not self.shutdown:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
|
@ -458,7 +456,6 @@ def _env_runner(
|
|||
clip_actions: bool,
|
||||
multiple_episodes_in_batch: bool,
|
||||
callbacks: "DefaultCallbacks",
|
||||
tf_sess: Optional["tf.Session"],
|
||||
perf_stats: _PerfStats,
|
||||
soft_horizon: bool,
|
||||
no_done_at_end: bool,
|
||||
|
@ -484,8 +481,6 @@ def _env_runner(
|
|||
space's bounds.
|
||||
clip_actions (bool): Whether to clip actions to the space range.
|
||||
callbacks (DefaultCallbacks): User callbacks to run on episode events.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
TF policy evaluations.
|
||||
perf_stats (_PerfStats): Record perf stats into this object.
|
||||
soft_horizon (bool): Calculate rewards but don't reset the
|
||||
environment when the horizon is hit.
|
||||
|
@ -566,7 +561,7 @@ def _env_runner(
|
|||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
tf_sess=p.get_session())
|
||||
callbacks.on_episode_start(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
|
@ -627,7 +622,6 @@ def _env_runner(
|
|||
policy_mapping_fn=worker.policy_mapping_fn,
|
||||
sample_collector=sample_collector,
|
||||
active_episodes=active_episodes,
|
||||
tf_sess=tf_sess,
|
||||
)
|
||||
perf_stats.inference_time += time.time() - t2
|
||||
|
||||
|
@ -915,7 +909,7 @@ def _process_observations(
|
|||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
tf_sess=p.get_session())
|
||||
# Call custom on_episode_end callback.
|
||||
callbacks.on_episode_end(
|
||||
worker=worker,
|
||||
|
@ -986,7 +980,6 @@ def _do_policy_eval(
|
|||
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
|
||||
sample_collector,
|
||||
active_episodes: Dict[str, MultiAgentEpisode],
|
||||
tf_sess: Optional["tf.Session"] = None,
|
||||
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
|
||||
"""Call compute_actions on collected episode/model data to get next action.
|
||||
|
||||
|
@ -997,8 +990,6 @@ def _do_policy_eval(
|
|||
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
|
||||
obj.
|
||||
sample_collector (SampleCollector): The SampleCollector object to use.
|
||||
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
|
||||
batching TF policy evaluations.
|
||||
|
||||
Returns:
|
||||
eval_results: dict of policy to compute_action() outputs.
|
||||
|
@ -1006,12 +997,6 @@ def _do_policy_eval(
|
|||
|
||||
eval_results: Dict[PolicyID, TensorStructType] = {}
|
||||
|
||||
if tf_sess:
|
||||
builder = TFRunBuilder(tf_sess, "policy_eval")
|
||||
pending_fetches: Dict[PolicyID, Any] = {}
|
||||
else:
|
||||
builder = None
|
||||
|
||||
if log_once("compute_actions_input"):
|
||||
logger.info("Inputs to compute_actions():\n\n{}\n".format(
|
||||
summarize(to_eval)))
|
||||
|
@ -1033,11 +1018,6 @@ def _do_policy_eval(
|
|||
timestep=policy.global_timestep,
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
|
||||
if builder:
|
||||
# types: PolicyID, Tuple[TensorStructType, StateBatch, dict]
|
||||
for pid, v in pending_fetches.items():
|
||||
eval_results[pid] = builder.get(v)
|
||||
|
||||
if log_once("compute_actions_result"):
|
||||
logger.info("Outputs of compute_actions():\n\n{}\n".format(
|
||||
summarize(eval_results)))
|
||||
|
|
|
@ -182,7 +182,7 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
|
||||
lr = policy.cur_lr
|
||||
if fw == "tf":
|
||||
lr = policy._sess.run(lr)
|
||||
lr = policy.get_session().run(lr)
|
||||
check(lr, expected_lr, rtol=0.05)
|
||||
agent.stop()
|
||||
|
||||
|
|
|
@ -78,13 +78,14 @@ class TFMultiGPULearner(LearnerThread):
|
|||
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
|
||||
raise NotImplementedError("Multi-gpu mode for multi-agent")
|
||||
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
|
||||
tf_session = self.policy.get_session()
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.par_opt = []
|
||||
with self.local_worker.tf_sess.graph.as_default():
|
||||
with self.local_worker.tf_sess.as_default():
|
||||
with tf_session.graph.as_default():
|
||||
with tf_session.as_default():
|
||||
with tf1.variable_scope(
|
||||
DEFAULT_POLICY_ID, reuse=tf1.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
|
@ -106,7 +107,7 @@ class TFMultiGPULearner(LearnerThread):
|
|||
999999, # it will get rounded down
|
||||
self.policy.copy))
|
||||
|
||||
self.sess = self.local_worker.tf_sess
|
||||
self.sess = tf_session
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
|
||||
self.idle_optimizers = queue.Queue()
|
||||
|
|
|
@ -148,14 +148,9 @@ class TrainTFMultiGPU:
|
|||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.optimizers = {}
|
||||
with self.workers.local_worker().tf_sess.graph.as_default():
|
||||
with self.workers.local_worker().tf_sess.as_default():
|
||||
for policy_id in (self.policies
|
||||
or self.local_worker.policies_to_train):
|
||||
self.add_optimizer(policy_id)
|
||||
|
||||
self.sess = self.workers.local_worker().tf_sess
|
||||
self.sess.run(tf1.global_variables_initializer())
|
||||
for policy_id in (self.policies
|
||||
or self.local_worker.policies_to_train):
|
||||
self.add_optimizer(policy_id)
|
||||
|
||||
def __call__(self,
|
||||
samples: SampleBatchType) -> (SampleBatchType, List[dict]):
|
||||
|
@ -181,10 +176,7 @@ class TrainTFMultiGPU:
|
|||
# Policy seems to be new and doesn't have an optimizer yet.
|
||||
# Add it here and continue.
|
||||
elif policy_id not in self.optimizers:
|
||||
with self.workers.local_worker().tf_sess.graph.as_default(
|
||||
):
|
||||
with self.workers.local_worker().tf_sess.as_default():
|
||||
self.add_optimizer(policy_id)
|
||||
self.add_optimizer(policy_id)
|
||||
|
||||
# Decompress SampleBatch, in case some columns are compressed.
|
||||
batch.decompress_if_needed()
|
||||
|
@ -200,13 +192,14 @@ class TrainTFMultiGPU:
|
|||
state_keys = []
|
||||
num_loaded_tuples[policy_id] = (
|
||||
self.optimizers[policy_id].load_data(
|
||||
self.sess, [tuples[k] for k in data_keys],
|
||||
policy.get_session(), [tuples[k] for k in data_keys],
|
||||
[tuples[k] for k in state_keys]))
|
||||
|
||||
# Execute minibatch SGD on loaded data.
|
||||
with learn_timer:
|
||||
fetches = {}
|
||||
for policy_id, tuples_per_device in num_loaded_tuples.items():
|
||||
policy = self.workers.local_worker().get_policy(policy_id)
|
||||
optimizer = self.optimizers[policy_id]
|
||||
num_batches = max(
|
||||
1,
|
||||
|
@ -217,7 +210,7 @@ class TrainTFMultiGPU:
|
|||
batch_fetches_all_towers = []
|
||||
for batch_index in range(num_batches):
|
||||
batch_fetches = optimizer.optimize(
|
||||
self.sess, permutation[batch_index] *
|
||||
policy.get_session(), permutation[batch_index] *
|
||||
self.per_device_batch_size)
|
||||
|
||||
batch_fetches_all_towers.append(
|
||||
|
@ -250,15 +243,20 @@ class TrainTFMultiGPU:
|
|||
|
||||
def add_optimizer(self, policy_id):
|
||||
policy = self.workers.local_worker().get_policy(policy_id)
|
||||
with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
|
||||
if policy._state_inputs:
|
||||
rnn_inputs = policy._state_inputs + [policy._seq_lens]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
|
||||
policy._optimizer, self.devices,
|
||||
list(policy._loss_input_dict_no_rnn.values()), rnn_inputs,
|
||||
self.per_device_batch_size, policy.copy))
|
||||
tf_session = policy.get_session()
|
||||
with tf_session.graph.as_default():
|
||||
with tf_session.as_default():
|
||||
with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
|
||||
if policy._state_inputs:
|
||||
rnn_inputs = policy._state_inputs + [policy._seq_lens]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
|
||||
policy._optimizer, self.devices,
|
||||
list(policy._loss_input_dict_no_rnn.values()),
|
||||
rnn_inputs, self.per_device_batch_size, policy.copy))
|
||||
|
||||
tf_session.run(tf1.global_variables_initializer())
|
||||
|
||||
|
||||
def all_tower_reduce(path, *tower_data):
|
||||
|
|
|
@ -166,8 +166,7 @@ class JsonReader(InputReader):
|
|||
self.ioctx.worker.policy_map[pid].action_space_struct)
|
||||
# Re-normalize actions (from env's bounds to 0.0 centered), if
|
||||
# necessary.
|
||||
if "actions_in_input_normalized" in cfg and \
|
||||
cfg["actions_in_input_normalized"] is False:
|
||||
if cfg.get("actions_in_input_normalized") is False:
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch[SampleBatch.ACTIONS] = normalize_action(
|
||||
batch[SampleBatch.ACTIONS], self.ioctx.worker.policy_map[
|
||||
|
|
|
@ -467,7 +467,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
self._optimizer = self.optimizer()
|
||||
|
||||
# Test calls depend on variable init, so initialize model first.
|
||||
self._sess.run(tf1.global_variables_initializer())
|
||||
self.get_session().run(tf1.global_variables_initializer())
|
||||
|
||||
logger.info("Testing `compute_actions` w/ dummy batch.")
|
||||
actions, state_outs, extra_fetches = \
|
||||
|
@ -486,7 +486,8 @@ class DynamicTFPolicy(TFPolicy):
|
|||
dummy_batch = self._dummy_batch
|
||||
|
||||
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
|
||||
self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
|
||||
self.exploration.postprocess_trajectory(self, dummy_batch,
|
||||
self.get_session())
|
||||
_ = self.postprocess_trajectory(dummy_batch)
|
||||
# Add new columns automatically to (loss) input_dict.
|
||||
for key in dummy_batch.added_keys:
|
||||
|
@ -592,7 +593,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
}
|
||||
|
||||
# Initialize again after loss init.
|
||||
self._sess.run(tf1.global_variables_initializer())
|
||||
self.get_session().run(tf1.global_variables_initializer())
|
||||
|
||||
def _do_loss_init(self, train_batch: SampleBatch):
|
||||
loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
|
||||
|
|
|
@ -255,7 +255,6 @@ def build_eager_tf_policy(
|
|||
|
||||
self._is_training = False
|
||||
self._loss_initialized = False
|
||||
self._sess = None
|
||||
|
||||
self._loss = loss_fn
|
||||
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
|
||||
|
|
|
@ -461,6 +461,9 @@ class Policy(metaclass=ABCMeta):
|
|||
def get_weights(self) -> ModelWeights:
|
||||
"""Returns model weights.
|
||||
|
||||
Note: The return value of this method will reside under the "weights"
|
||||
key in the return value of Policy.get_state().
|
||||
|
||||
Returns:
|
||||
ModelWeights: Serializable copy or view of model weights.
|
||||
"""
|
||||
|
@ -468,7 +471,7 @@ class Policy(metaclass=ABCMeta):
|
|||
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights: ModelWeights) -> None:
|
||||
"""Sets model weights.
|
||||
"""Sets this Policy's model's weights.
|
||||
|
||||
Args:
|
||||
weights (ModelWeights): Serializable copy or view of model weights.
|
||||
|
@ -523,12 +526,16 @@ class Policy(metaclass=ABCMeta):
|
|||
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
|
||||
"""Returns all local state.
|
||||
|
||||
Note: Not to be confused with an RNN model's internal state.
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, TensorType], List[TensorType]]: Serialized local
|
||||
state.
|
||||
"""
|
||||
state = {
|
||||
# All the policy's weights.
|
||||
"weights": self.get_weights(),
|
||||
# The current global timestep.
|
||||
"global_timestep": self.global_timestep,
|
||||
}
|
||||
return state
|
||||
|
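The state dict returned above (model weights plus `global_timestep`) is exactly what the new PolicyMap pickles when it stashes a least recently used policy. A tiny sketch of that round trip, assuming `policy` is an existing Policy instance; the file path is illustrative:

import pickle

# Stash: serialize the policy's restorable state.
state = policy.get_state()  # {"weights": ..., "global_timestep": ...}
with open("/tmp/p0.policy.pkl", "wb") as f:
    pickle.dump(state, f)

# Restore: rebuild the policy from its spec elsewhere, then re-apply.
with open("/tmp/p0.policy.pkl", "rb") as f:
    policy.set_state(pickle.load(f))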
@ -581,6 +588,16 @@ class Policy(metaclass=ABCMeta):
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_session(self) -> Optional["tf1.Session"]:
        """Returns tf.Session object to use for computing actions or None.

        Returns:
            Optional[tf1.Session]: The tf Session to use for computing actions
                and losses with this policy.
        """
        return None

    def _create_exploration(self) -> Exploration:
        """Creates the Policy's Exploration object.

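This new base-class hook is what lets the scattered `policy._sess` and `getattr(policy, "_sess", None)` call sites elsewhere in this diff be replaced uniformly: TF static-graph policies return their session, torch and tf-eager policies return None. A minimal sketch of the resulting calling pattern; the helper itself is illustrative, not part of the PR:

from typing import Any
from ray.rllib.policy.policy import Policy


def maybe_run_in_session(policy: Policy, fetches: Any) -> Any:
    # TF static-graph policies expose a session; run the fetches there.
    sess = policy.get_session()
    if sess is not None:
        return sess.run(fetches)
    # Torch / tf-eager: values are assumed to be concrete already.
    return fetches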
rllib/policy/policy_map.py (new file, 292 lines)
|
@ -0,0 +1,292 @@
|
|||
from collections import deque
|
||||
import gym
|
||||
import os
|
||||
import pickle
|
||||
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary
|
||||
from ray.rllib.utils.typing import PartialTrainerConfigDict, \
|
||||
PolicyID, TrainerConfigDict
|
||||
from ray.tune.utils.util import merge_dicts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.policy.policy import Policy
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
class PolicyMap(dict):
|
||||
"""Maps policy IDs to Policy objects.
|
||||
|
||||
Keeps n policies in memory and, when capacity is reached, writes the
least recently used ones to disk. This allows adding 100s of policies
to a Trainer for league-based setups w/o running out of memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_index: int,
|
||||
num_workers: int,
|
||||
capacity: Optional[int] = None,
|
||||
path: Optional[str] = None,
|
||||
policy_config: Optional[TrainerConfigDict] = None,
|
||||
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
"""Initializes a PolicyMap instance.
|
||||
|
||||
Args:
|
||||
capacity (int): The maximum number of policies to hold in memory.
    The least recently used ones are written to disk/S3 and
    retrieved when needed.
path (str): The directory path or S3 location under which the
    least recently used policies are stashed.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.worker_index = worker_index
|
||||
self.num_workers = num_workers
|
||||
self.session_creator = session_creator
|
||||
self.seed = seed
|
||||
|
||||
# The file extension for stashed policies (that are no longer available
|
||||
# in-memory but can be reinstated any time from storage).
|
||||
self.extension = ".policy.pkl"
|
||||
|
||||
# Dictionary of keys that may be looked up (cached or not).
|
||||
self.valid_keys = set()
|
||||
# The actual cache with the in-memory policy objects.
|
||||
self.cache = {}
|
||||
# The doubly-linked list holding the currently in-memory objects.
|
||||
self.deque = deque(maxlen=capacity or 10)
|
||||
# The file path where to store overflowing policies.
|
||||
self.path = path or "."
|
||||
# The core config to use. Each single policy's config override is
|
||||
# added on top of this.
|
||||
self.policy_config = policy_config or {}
|
||||
# The orig classes/obs+act spaces, and config overrides of the
|
||||
# Policies.
|
||||
self.policy_specs = {} # type: Dict[PolicyID, PolicySpec]
|
||||
|
||||
def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
|
||||
observation_space: gym.Space, action_space: gym.Space,
|
||||
config_override: PartialTrainerConfigDict,
|
||||
merged_config: TrainerConfigDict) -> None:
|
||||
"""Creates a new policy and stores it to the cache.
|
||||
|
||||
Args:
|
||||
policy_id (PolicyID): The policy ID. This is the key under which
|
||||
the created policy will be stored in this map.
|
||||
policy_cls (Type[Policy]): The (original) policy class to use.
|
||||
This may still be altered in case tf-eager (and tracing)
|
||||
is used.
|
||||
observation_space (gym.Space): The observation space of the
|
||||
policy.
|
||||
action_space (gym.Space): The action space of the policy.
|
||||
config_override (PartialTrainerConfigDict): The config override
|
||||
dict for this policy. This is the partial dict provided by
|
||||
the user.
|
||||
merged_config (TrainerConfigDict): The entire config (merged
|
||||
default config + `config_override`).
|
||||
"""
|
||||
framework = merged_config.get("framework", "tf")
|
||||
class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)
|
||||
|
||||
# Tf.
|
||||
if framework in ["tf2", "tf", "tfe"]:
|
||||
var_scope = policy_id + (("_wk" + str(self.worker_index))
|
||||
if self.worker_index else "")
|
||||
|
||||
# For tf static graph, build every policy in its own graph
|
||||
# and create a new session for it.
|
||||
if framework == "tf":
|
||||
with tf1.Graph().as_default():
|
||||
if self.session_creator:
|
||||
sess = self.session_creator()
|
||||
else:
|
||||
sess = tf1.Session(
|
||||
config=tf1.ConfigProto(
|
||||
gpu_options=tf1.GPUOptions(allow_growth=True)))
|
||||
with sess.as_default():
|
||||
# Set graph-level seed.
|
||||
if self.seed is not None:
|
||||
tf1.set_random_seed(self.seed)
|
||||
with tf1.variable_scope(var_scope):
|
||||
self[policy_id] = class_(
|
||||
observation_space, action_space, merged_config)
|
||||
# For tf-eager: no graph, no session.
|
||||
else:
|
||||
with tf1.variable_scope(var_scope):
|
||||
self[policy_id] = \
|
||||
class_(observation_space, action_space, merged_config)
|
||||
# Non-tf: No graph, no session.
|
||||
else:
|
||||
class_ = policy_cls
|
||||
self[policy_id] = class_(observation_space, action_space,
|
||||
merged_config)
|
||||
|
||||
# Store spec (class, obs-space, act-space, and config overrides) such
|
||||
# that the map will be able to reproduce on-the-fly added policies
|
||||
# from disk.
|
||||
self.policy_specs[policy_id] = PolicySpec(
|
||||
policy_class=policy_cls,
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
config=config_override)
|
||||
|
||||
@override(dict)
|
||||
def __getitem__(self, item):
|
||||
# Never seen this key -> Error.
|
||||
if item not in self.valid_keys:
|
||||
raise KeyError(f"'{item}' not a valid key!")
|
||||
|
||||
# Item already in cache -> Rearrange deque (least recently used) and
|
||||
# return.
|
||||
if item in self.cache:
|
||||
self.deque.remove(item)
|
||||
self.deque.append(item)
|
||||
# Item not currently in cache -> Get from disk and - if at capacity -
|
||||
# remove leftmost one.
|
||||
else:
|
||||
self._read_from_disk(policy_id=item)
|
||||
|
||||
return self.cache[item]
|
||||
|
||||
@override(dict)
|
||||
def __setitem__(self, key, value):
|
||||
# Item already in cache -> Rearrange deque (least recently used).
|
||||
if key in self.cache:
|
||||
self.deque.remove(key)
|
||||
self.deque.append(key)
|
||||
self.cache[key] = value
|
||||
# Item not currently in cache -> store new value and - if at capacity -
|
||||
# remove leftmost one.
|
||||
else:
|
||||
# Cache at capacity -> Drop leftmost item.
|
||||
if len(self.deque) == self.deque.maxlen:
|
||||
self._stash_to_disk()
|
||||
self.deque.append(key)
|
||||
self.cache[key] = value
|
||||
self.valid_keys.add(key)
|
||||
|
||||
@override(dict)
|
||||
def __delitem__(self, key):
|
||||
# Make key invalid.
|
||||
self.valid_keys.remove(key)
|
||||
# Remove policy from memory if currently cached.
|
||||
if key in self.cache:
|
||||
policy = self.cache[key]
|
||||
self._close_session(policy)
|
||||
del self.cache[key]
|
||||
# Remove file associated with the policy, if it exists.
|
||||
filename = self.path + "/" + key + self.extension
|
||||
if os.path.isfile(filename):
|
||||
os.remove(filename)
|
||||
|
||||
@override(dict)
|
||||
def __iter__(self):
|
||||
return self.keys()
|
||||
|
||||
@override(dict)
|
||||
def items(self):
|
||||
"""Iterates over all policies, even the stashed-to-disk ones."""
|
||||
|
||||
def gen():
|
||||
for key in self.valid_keys:
|
||||
yield (key, self[key])
|
||||
|
||||
return gen()
|
||||
|
||||
@override(dict)
|
||||
def keys(self):
|
||||
def gen():
|
||||
for key in self.valid_keys:
|
||||
yield key
|
||||
|
||||
return gen()
|
||||
|
||||
@override(dict)
|
||||
def values(self):
|
||||
def gen():
|
||||
for key in self.valid_keys:
|
||||
yield self[key]
|
||||
|
||||
return gen()
|
||||
|
||||
@override(dict)
|
||||
def update(self, __m, **kwargs):
|
||||
for k, v in __m.items():
|
||||
self[k] = v
|
||||
for k, v in kwargs.items():
|
||||
self[k] = v
|
||||
|
||||
@override(dict)
|
||||
def get(self, key):
|
||||
if key not in self.valid_keys:
|
||||
return None
|
||||
return self[key]
|
||||
|
||||
@override(dict)
|
||||
def __len__(self):
|
||||
"""Returns number of all policies, including the stashed-to-disk ones.
|
||||
"""
|
||||
return len(self.valid_keys)
|
||||
|
||||
@override(dict)
|
||||
def __contains__(self, item):
|
||||
return item in self.valid_keys
|
||||
|
||||
def _stash_to_disk(self):
|
||||
"""Writes the least-recently used policy to disk and rearranges cache.
|
||||
|
||||
Also closes the session - if applicable - of the stashed policy.
|
||||
"""
|
||||
# Get least recently used policy (all the way on the left in deque).
|
||||
delkey = self.deque.popleft()
|
||||
policy = self.cache[delkey]
|
||||
# Get its state for writing to disk.
|
||||
policy_state = policy.get_state()
|
||||
# Closes policy's tf session, if any.
|
||||
self._close_session(policy)
|
||||
# Remove from memory.
|
||||
# TODO: (sven) This should clear the tf Graph as well, if the Trainer
|
||||
# would not hold parts of the graph (e.g. in tf multi-GPU setups).
|
||||
del self.cache[delkey]
|
||||
# Write state to disk.
|
||||
with open(self.path + "/" + delkey + self.extension, "wb") as f:
|
||||
pickle.dump(policy_state, file=f)
|
||||
|
||||
def _read_from_disk(self, policy_id):
|
||||
"""Reads a policy ID from disk and re-adds it to the cache.
|
||||
"""
|
||||
# Make sure this policy ID is not in the cache right now.
|
||||
assert policy_id not in self.cache
|
||||
# Read policy state from disk.
|
||||
with open(self.path + "/" + policy_id + self.extension, "rb") as f:
|
||||
policy_state = pickle.load(f)
|
||||
|
||||
# Get class and config override.
|
||||
merged_conf = merge_dicts(self.policy_config,
|
||||
self.policy_specs[policy_id].config)
|
||||
|
||||
# Create policy object (from its spec: cls, obs-space, act-space,
|
||||
# config).
|
||||
self.create_policy(
|
||||
policy_id,
|
||||
self.policy_specs[policy_id].policy_class,
|
||||
self.policy_specs[policy_id].observation_space,
|
||||
self.policy_specs[policy_id].action_space,
|
||||
self.policy_specs[policy_id].config,
|
||||
merged_conf,
|
||||
)
|
||||
# Restore policy's state.
|
||||
policy = self[policy_id]
|
||||
policy.set_state(policy_state)
|
||||
|
||||
def _close_session(self, policy):
|
||||
sess = policy.get_session()
|
||||
# Closes the tf session, if any.
|
||||
if sess is not None:
|
||||
sess.close()
|
|
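To make the caching scheme in PolicyMap above easier to follow in isolation, here is a stripped-down sketch of the same deque-plus-dict LRU-to-disk pattern, with plain picklable values standing in for Policy objects; the class and file naming are illustrative, not RLlib API:

import os
import pickle
from collections import deque


class LRUDiskCache(dict):
    """Keeps at most `capacity` values in memory; least recently used
    entries are pickled under `path` and reloaded on access."""

    def __init__(self, capacity=2, path="."):
        super().__init__()
        self.capacity = capacity
        self.path = path
        self.cache = {}          # in-memory values
        self.order = deque()     # usage order (leftmost = least recent)
        self.valid_keys = set()  # all keys ever stored, cached or stashed

    def _file(self, key):
        return os.path.join(self.path, f"{key}.pkl")

    def __setitem__(self, key, value):
        if key in self.cache:
            self.order.remove(key)
        elif len(self.order) >= self.capacity:
            # Evict the least recently used entry to disk.
            old = self.order.popleft()
            with open(self._file(old), "wb") as f:
                pickle.dump(self.cache.pop(old), f)
        self.order.append(key)
        self.cache[key] = value
        self.valid_keys.add(key)

    def __getitem__(self, key):
        if key not in self.valid_keys:
            raise KeyError(key)
        if key in self.cache:
            # Mark as most recently used.
            self.order.remove(key)
            self.order.append(key)
        else:
            # Reload from disk; this may evict another entry.
            with open(self._file(key), "rb") as f:
                self[key] = pickle.load(f)
        return self.cache[key]


# Example: with capacity=2, storing a third value stashes the oldest one.
cache = LRUDiskCache(capacity=2, path=".")
cache["a"], cache["b"], cache["c"] = 1, 2, 3
assert cache["a"] == 1 and len(cache.valid_keys) == 3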
@ -263,7 +263,8 @@ class TFPolicy(Policy):
|
|||
"`get_placeholder()` can be called"
|
||||
return self._loss_input_dict[name]
|
||||
|
||||
def get_session(self) -> "tf1.Session":
|
||||
@override(Policy)
|
||||
def get_session(self) -> Optional["tf1.Session"]:
|
||||
"""Returns a reference to the TF session for this policy."""
|
||||
return self._sess
|
||||
|
||||
|
@ -305,7 +306,7 @@ class TFPolicy(Policy):
|
|||
|
||||
if self.model:
|
||||
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
[], self._sess, self.variables())
|
||||
[], self.get_session(), self.variables())
|
||||
|
||||
# gather update ops for any batch norm layers
|
||||
if not self._update_ops:
|
||||
|
@ -323,12 +324,12 @@ class TFPolicy(Policy):
|
|||
"These tensors were used in the loss_fn:\n\n{}\n".format(
|
||||
summarize(self._loss_input_dict)))
|
||||
|
||||
self._sess.run(tf1.global_variables_initializer())
|
||||
self.get_session().run(tf1.global_variables_initializer())
|
||||
self._optimizer_variables = None
|
||||
if self._optimizer:
|
||||
self._optimizer_variables = \
|
||||
ray.experimental.tf_utils.TensorFlowVariables(
|
||||
self._optimizer.variables(), self._sess)
|
||||
self._optimizer.variables(), self.get_session())
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(
|
||||
|
@ -346,7 +347,7 @@ class TFPolicy(Policy):
|
|||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
builder = TFRunBuilder(self.get_session(), "compute_actions")
|
||||
to_fetch = self._build_compute_actions(
|
||||
builder,
|
||||
obs_batch=obs_batch,
|
||||
|
@ -378,7 +379,8 @@ class TFPolicy(Policy):
|
|||
explore = explore if explore is not None else self.config["explore"]
|
||||
timestep = timestep if timestep is not None else self.global_timestep
|
||||
|
||||
builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
|
||||
builder = TFRunBuilder(self.get_session(),
|
||||
"compute_actions_from_input_dict")
|
||||
obs_batch = input_dict[SampleBatch.OBS]
|
||||
to_fetch = self._build_compute_actions(
|
||||
builder, input_dict=input_dict, explore=explore, timestep=timestep)
|
||||
|
@ -413,7 +415,7 @@ class TFPolicy(Policy):
|
|||
self.exploration.before_compute_actions(
|
||||
explore=False, tf_sess=self.get_session())
|
||||
|
||||
builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
|
||||
builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods")
|
||||
|
||||
# Normalize actions if necessary.
|
||||
if actions_normalized is False and self.config["normalize_actions"]:
|
||||
|
@ -451,7 +453,7 @@ class TFPolicy(Policy):
|
|||
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
assert self.loss_initialized()
|
||||
|
||||
builder = TFRunBuilder(self._sess, "learn_on_batch")
|
||||
builder = TFRunBuilder(self.get_session(), "learn_on_batch")
|
||||
|
||||
# Callback handling.
|
||||
learn_stats = {}
|
||||
|
@ -470,7 +472,7 @@ class TFPolicy(Policy):
|
|||
postprocessed_batch: SampleBatch) -> \
|
||||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "compute_gradients")
|
||||
builder = TFRunBuilder(self.get_session(), "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
|
@ -478,7 +480,7 @@ class TFPolicy(Policy):
|
|||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients: ModelGradients) -> None:
|
||||
assert self.loss_initialized()
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
builder = TFRunBuilder(self.get_session(), "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
builder.get(fetches)
|
||||
|
||||
|
@ -510,7 +512,7 @@ class TFPolicy(Policy):
|
|||
if self._optimizer_variables and \
|
||||
len(self._optimizer_variables.variables) > 0:
|
||||
state["_optimizer_variables"] = \
|
||||
self._sess.run(self._optimizer_variables.variables)
|
||||
self.get_session().run(self._optimizer_variables.variables)
|
||||
# Add exploration state.
|
||||
state["_exploration_state"] = \
|
||||
self.exploration.get_state(self.get_session())
|
||||
|
@ -546,7 +548,7 @@ class TFPolicy(Policy):
|
|||
"`tf2onnx` to be installed. Install with "
|
||||
"`pip install tf2onnx`.") from e
|
||||
|
||||
with self._sess.graph.as_default():
|
||||
with self.get_session().graph.as_default():
|
||||
signature_def_map = self._build_signature_def()
|
||||
|
||||
sd = signature_def_map[tf1.saved_model.signature_constants.
|
||||
|
@ -558,7 +560,7 @@ class TFPolicy(Policy):
|
|||
frozen_graph_def = tf_loader.freeze_session(
|
||||
self._sess, input_names=inputs, output_names=outputs)
|
||||
|
||||
with tf.compat.v1.Session(graph=tf.Graph()) as session:
|
||||
with tf1.Session(graph=tf.Graph()) as session:
|
||||
tf.import_graph_def(frozen_graph_def, name="")
|
||||
|
||||
g = tf2onnx.tfonnx.process_tf_graph(
|
||||
|
@ -574,14 +576,15 @@ class TFPolicy(Policy):
|
|||
feed_dict={},
|
||||
model_proto=model_proto)
|
||||
else:
|
||||
with self._sess.graph.as_default():
|
||||
with self.get_session().graph.as_default():
|
||||
signature_def_map = self._build_signature_def()
|
||||
builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
|
||||
builder.add_meta_graph_and_variables(
|
||||
self._sess, [tf1.saved_model.tag_constants.SERVING],
|
||||
self.get_session(),
|
||||
[tf1.saved_model.tag_constants.SERVING],
|
||||
signature_def_map=signature_def_map,
|
||||
saver=tf1.summary.FileWriter(export_dir).add_graph(
|
||||
graph=self._sess.graph))
|
||||
graph=self.get_session().graph))
|
||||
builder.save()
|
||||
|
||||
# TODO: (sven) Deprecate this in favor of `save()`.
|
||||
|
@ -599,17 +602,17 @@ class TFPolicy(Policy):
|
|||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
save_path = os.path.join(export_dir, filename_prefix)
|
||||
with self._sess.graph.as_default():
|
||||
with self.get_session().graph.as_default():
|
||||
saver = tf1.train.Saver()
|
||||
saver.save(self._sess, save_path)
|
||||
saver.save(self.get_session(), save_path)
|
||||
|
||||
@override(Policy)
|
||||
@DeveloperAPI
|
||||
def import_model_from_h5(self, import_file: str) -> None:
|
||||
"""Imports weights into tf model."""
|
||||
# Make sure the session is the right one (see issue #7046).
|
||||
with self._sess.graph.as_default():
|
||||
with self._sess.as_default():
|
||||
with self.get_session().graph.as_default():
|
||||
with self.get_session().as_default():
|
||||
return self.model.import_from_h5(import_file)
|
||||
|
||||
@DeveloperAPI
|
||||
|
@ -1026,7 +1029,7 @@ class LearningRateSchedule:
|
|||
if self._lr_schedule is not None:
|
||||
new_val = self._lr_schedule.value(global_vars["timestep"])
|
||||
if self.framework == "tf":
|
||||
self._sess.run(
|
||||
self.get_session().run(
|
||||
self._lr_update, feed_dict={self._lr_placeholder: new_val})
|
||||
else:
|
||||
self.cur_lr.assign(new_val, read_value=False)
|
||||
|
@ -1082,7 +1085,7 @@ class EntropyCoeffSchedule:
|
|||
new_val = self._entropy_coeff_schedule.value(
|
||||
global_vars["timestep"])
|
||||
if self.framework == "tf":
|
||||
self._sess.run(
|
||||
self.get_session().run(
|
||||
self._entropy_coeff_update,
|
||||
feed_dict={self._entropy_coeff_placeholder: new_val})
|
||||
else:
|
||||
|
|
|
@ -45,7 +45,7 @@ class TestParameterNoise(unittest.TestCase):
|
|||
|
||||
trainer = trainer_cls(config=config, env=env)
|
||||
policy = trainer.get_policy()
|
||||
pol_sess = getattr(policy, "_sess", None)
|
||||
pol_sess = policy.get_session()
|
||||
# Remove noise that has been added during policy initialization
|
||||
# (exploration.postprocess_trajectory does add noise to measure
|
||||
# the delta).
|
||||
|
@ -110,7 +110,7 @@ class TestParameterNoise(unittest.TestCase):
|
|||
config["explore"] = False
|
||||
trainer = trainer_cls(config=config, env=env)
|
||||
policy = trainer.get_policy()
|
||||
pol_sess = getattr(policy, "_sess", None)
|
||||
pol_sess = policy.get_session()
|
||||
# Remove noise that has been added during policy initialization
|
||||
# (exploration.postprocess_trajectory does add noise to measure
|
||||
# the delta).
|
||||
|
|
|
@ -60,6 +60,31 @@ def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
    )


def get_tf_eager_cls_if_necessary(orig_cls, config):
    cls = orig_cls
    framework = config.get("framework", "tf")
    if framework in ["tf2", "tf", "tfe"]:
        if not tf1:
            raise ImportError("Could not import tensorflow!")
        if framework in ["tf2", "tfe"]:
            assert tf1.executing_eagerly()

            from ray.rllib.policy.tf_policy import TFPolicy

            # Create eager-class.
            if hasattr(orig_cls, "as_eager"):
                cls = orig_cls.as_eager()
                if config.get("eager_tracing"):
                    cls = cls.with_tracing()
            # Could be some other type of policy.
            elif not issubclass(orig_cls, TFPolicy):
                pass
            else:
                raise ValueError("This policy does not support eager "
                                 "execution: {}".format(orig_cls))
    return cls


def huber_loss(x, delta=1.0):
    """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
    return tf.where(
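A brief usage sketch for the helper added above; eager execution is assumed for the "tf2" setting, and the concrete policy class here is just a stand-in (TF policies created via `build_tf_policy` are assumed to expose `as_eager()`):

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary

config = {"framework": "tf2", "eager_tracing": True}
# Returns the eager (and, with eager_tracing, traced) variant of the
# policy class; for framework="tf" or non-TF policies the original
# class is returned unchanged.
cls = get_tf_eager_cls_if_necessary(PPOTFPolicy, config)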