[RLlib] Implement policy_maps (multi-agent case) in RolloutWorkers as LRU caches. (#17031)

Sven Mika 2021-07-19 13:16:03 -04:00 committed by GitHub
parent e0640ad0dc
commit 18d173b172
22 changed files with 503 additions and 208 deletions


@ -66,7 +66,7 @@ class TestIMPALA(unittest.TestCase):
try:
if fw == "tf":
check(policy._sess.run(policy.cur_lr), 0.0005)
check(policy.get_session().run(policy.cur_lr), 0.0005)
else:
check(policy.cur_lr, 0.0005)
r1 = trainer.train()
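
This first hunk illustrates the pattern applied throughout the PR: direct reads of the private `policy._sess` attribute are replaced by the new public `Policy.get_session()` hook (added later in this commit), which returns None for non-TF policies. A minimal sketch of the resulting calling pattern, assuming a `policy` object with a `cur_lr` attribute as in the test above:

sess = policy.get_session()  # tf1.Session for static-graph TF policies, else None.
if sess is not None:
    # Static-graph TF: `cur_lr` is a tf.Tensor that must be evaluated in the session.
    lr = sess.run(policy.cur_lr)
else:
    # Eager/torch policies: `cur_lr` is already a concrete value.
    lr = policy.cur_lr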


@ -36,10 +36,13 @@ class TestTrainer(unittest.TestCase):
"p0": (None, env.observation_space, env.action_space, {}),
},
"policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
"policy_map_capacity": 2,
},
})
for _ in framework_iterator(config):
# TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
# refactor PR merged.
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
trainer = pg.PGTrainer(config=config)
r = trainer.train()
self.assertTrue("p0" in r["policy_reward_min"])
@ -62,8 +65,8 @@ class TestTrainer(unittest.TestCase):
)
pol_map = trainer.workers.local_worker().policy_map
self.assertTrue(new_pol is not trainer.get_policy("p0"))
self.assertTrue("p0" in pol_map)
self.assertTrue("p1" in pol_map)
for j in range(i):
self.assertTrue(f"p{j}" in pol_map)
self.assertTrue(len(pol_map) == i + 1)
r = trainer.train()
self.assertTrue("p1" in r["policy_reward_min"])


@ -432,6 +432,13 @@ COMMON_CONFIG: TrainerConfigDict = {
# of (policy_cls, obs_space, act_space, config). This defines the
# observation and action spaces of the policies and any extra config.
"policies": {},
# Keep this many policies in the "policy_map" (before writing
# least-recently used ones to disk/S3).
"policy_map_capacity": 100,
# Where to store overflowing (least-recently used) policies?
# Could be a directory (str) or an S3 location. None for using
# the default output dir.
"policy_map_cache": None,
# Function mapping agent ids to policy ids.
"policy_mapping_fn": None,
# Optional list of policies to train, or None for all policies.
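
For illustration (not part of the diff), a multi-agent config exercising the two new keys could look like the sketch below; the env name, policy IDs, spaces, and capacity value are made up for the example:

from ray.rllib.agents import pg

# `obs_space` and `act_space` are assumed to be valid gym spaces for the env.
config = {
    "env": "my_multi_agent_env",  # assumed to be a registered MultiAgentEnv
    "multiagent": {
        "policies": {
            # None -> use the Trainer's default policy class.
            f"p{i}": (None, obs_space, act_space, {})
            for i in range(200)
        },
        "policy_mapping_fn": lambda aid, episode, **kwargs: f"p{hash(aid) % 200}",
        # Keep at most 20 policies in memory; stash the rest to disk.
        "policy_map_capacity": 20,
        # Directory (or S3 location) for stashed policies; None means the
        # default output dir.
        "policy_map_cache": "/tmp/policy_cache",
    },
}
trainer = pg.PGTrainer(config=config)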
@ -1181,7 +1188,7 @@ class Trainer(Trainable):
local worker).
"""
def fn(worker):
def fn(worker: RolloutWorker):
# `foreach_worker` function: Adds the policy to the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(


@ -135,13 +135,10 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
if "new_obs" in k:
new_obs_n.append(v)
target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
new_act_n = p.sess.run(target_act_sampler_n, feed_dict)
samples.update(
{"new_actions_%d" % i: new_act
for i, new_act in enumerate(new_act_n)})
for i, p in enumerate(policies.values()):
feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
new_act = p.get_session().run(p.target_act_sampler, feed_dict)
samples.update({"new_actions_%d" % i: new_act})
# Share samples among agents.
policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}


@ -230,7 +230,8 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
# _____ TensorFlow Initialization
self.sess = tf1.get_default_session()
sess = tf1.get_default_session()
assert sess
def _make_loss_inputs(placeholders):
return [(ph.name.split("/")[-1].split(":")[0], ph)
@ -244,7 +245,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
obs_space,
act_space,
config=config,
sess=self.sess,
sess=sess,
obs_input=obs_ph_n[agent_id],
sampled_action=act_sampler,
loss=actor_loss + critic_loss,
@ -254,7 +255,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
del self.view_requirements["prev_actions"]
del self.view_requirements["prev_rewards"]
self.sess.run(tf1.global_variables_initializer())
self.get_session().run(tf1.global_variables_initializer())
# Hard initial update
self.update_target(1.0)
@ -297,11 +298,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
var_list = []
for var in self.vars.values():
var_list += var
return {"_state": self.sess.run(var_list)}
return {"_state": self.get_session().run(var_list)}
@override(TFPolicy)
def set_weights(self, weights):
self.sess.run(
self.get_session().run(
self.update_vars,
feed_dict=dict(zip(self.vars_ph, weights["_state"])))
@ -377,6 +378,6 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
def update_target(self, tau=None):
if tau is not None:
self.sess.run(self.update_target_vars, {self.tau: tau})
self.get_session().run(self.update_target_vars, {self.tau: tau})
else:
self.sess.run(self.update_target_vars)
self.get_session().run(self.update_target_vars)


@ -3,7 +3,7 @@ import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
@ -30,7 +30,7 @@ class SampleCollector(metaclass=ABCMeta):
"""
def __init__(self,
policy_map: Dict[PolicyID, Policy],
policy_map: PolicyMap,
clip_rewards: Union[bool, float],
callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = True,
@ -39,8 +39,7 @@ class SampleCollector(metaclass=ABCMeta):
"""Initializes a SampleCollector instance.
Args:
policy_map (Dict[str, Policy]): Maps policy ids to policy
instances.
policy_map (PolicyMap): Maps policy ids to policy instances.
clip_rewards (Union[bool, float]): Whether to clip rewards before
postprocessing (at +/-1.0) or the actual value to +/- clip.
callbacks (DefaultCallbacks): RLlib callbacks.


@ -9,6 +9,7 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
@ -292,7 +293,7 @@ class _PolicyCollector:
appended to this policy's buffers.
"""
def __init__(self, policy):
def __init__(self, policy: Policy):
"""Initializes a _PolicyCollector instance.
Args:
@ -382,7 +383,7 @@ class SimpleListCollector(SampleCollector):
"""
def __init__(self,
policy_map: Dict[PolicyID, Policy],
policy_map: PolicyMap,
clip_rewards: Union[bool, float],
callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = True,
@ -650,8 +651,7 @@ class SimpleListCollector(SampleCollector):
post_batches[agent_id] = pre_batch
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id],
getattr(policy, "_sess", None))
policy, post_batches[agent_id], policy.get_session())
post_batches[agent_id] = policy.postprocess_trajectory(
post_batches[agent_id], other_batches, episode)


@ -4,7 +4,7 @@ import random
from typing import List, Dict, Callable, Any, TYPE_CHECKING
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
@ -49,9 +49,8 @@ class MultiAgentEpisode:
>>> episode.extra_batches.add(batch.build_and_reset())
"""
def __init__(self, policies: Dict[PolicyID, Policy],
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID],
def __init__(self, policies: PolicyMap, policy_mapping_fn: Callable[
[AgentID, "MultiAgentEpisode"], PolicyID],
batch_builder_factory: Callable[
[], "MultiAgentSampleBatchBuilder"],
extra_batch_callback: Callable[[SampleBatchType], None],
@ -71,7 +70,7 @@ class MultiAgentEpisode:
self.user_data: Dict[str, Any] = {}
self.hist_data: Dict[str, List[float]] = {}
self.media: Dict[str, Any] = {}
self.policy_map: Dict[PolicyID, Policy] = policies
self.policy_map: PolicyMap = policies
self._policies = self.policy_map # backward compatibility
self._policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
PolicyID] = policy_mapping_fn


@ -28,6 +28,7 @@ from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
@ -489,7 +490,6 @@ class RolloutWorker(ParallelIteratorWorker):
self.make_env_fn = make_env
self.tf_sess = None
policy_dict = _determine_spaces_for_multi_agent_dict(
policy_spec, self.env, spaces=spaces, policy_config=policy_config)
# List of IDs of those policies, which should be trained.
@ -498,7 +498,7 @@ class RolloutWorker(ParallelIteratorWorker):
policy_dict.keys())
self.set_policies_to_train(self.policies_to_train)
self.policy_map: Dict[PolicyID, Policy] = None
self.policy_map: PolicyMap = None
self.preprocessors: Dict[PolicyID, Preprocessor] = None
# Set Python random, numpy, env, and torch/tf seeds.
@ -541,26 +541,11 @@ class RolloutWorker(ParallelIteratorWorker):
elif tf1 and policy_config.get("framework") == "tfe":
tf1.set_random_seed(seed)
if _has_tensorflow_graph(policy_dict) and not (
tf1 and tf1.executing_eagerly()):
if not tf1:
raise ImportError("Could not import tensorflow")
with tf1.Graph().as_default():
if tf_session_creator:
self.tf_sess = tf_session_creator()
else:
self.tf_sess = tf1.Session(
config=tf1.ConfigProto(
gpu_options=tf1.GPUOptions(allow_growth=True)))
with self.tf_sess.as_default():
# set graph-level seed
if seed is not None:
tf1.set_random_seed(seed)
self.policy_map, self.preprocessors = \
self._build_policy_map(policy_dict, policy_config)
else:
self.policy_map, self.preprocessors = self._build_policy_map(
policy_dict, policy_config)
self._build_policy_map(
policy_dict,
policy_config,
session_creator=tf_session_creator,
seed=seed)
# Update Policy's view requirements from Model, only if Policy directly
# inherited from base `Policy` class. At this point here, the Policy
@ -591,14 +576,13 @@ class RolloutWorker(ParallelIteratorWorker):
self.multiagent: bool = set(
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent and self.env is not None:
if not ((isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):
if not isinstance(self.env,
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
ray.actor.ActorHandle)):
raise ValueError(
"Have multiple policies {}, but the env ".format(
self.policy_map) +
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
"ExternalMultiAgentEnv?".format(self.env))
f"Have multiple policies {self.policy_map}, but the "
f"env {self.env} is not a subclass of BaseEnv, "
f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")
self.filters: Dict[PolicyID, Filter] = {
policy_id: get_filter(self.observation_filter,
@ -678,7 +662,6 @@ class RolloutWorker(ParallelIteratorWorker):
callbacks=self.callbacks,
horizon=episode_horizon,
multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation,
@ -701,7 +684,6 @@ class RolloutWorker(ParallelIteratorWorker):
callbacks=self.callbacks,
horizon=episode_horizon,
multiple_episodes_in_batch=pack,
tf_sess=self.tf_sess,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
soft_horizon=soft_horizon,
@ -923,23 +905,24 @@ class RolloutWorker(ParallelIteratorWorker):
summarize(samples)))
if isinstance(samples, MultiAgentBatch):
info_out = {}
builders = {}
to_fetch = {}
if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
else:
builder = None
for pid, batch in samples.policy_batches.items():
if pid not in self.policies_to_train:
continue
# Decompress SampleBatch, in case some columns are compressed.
batch.decompress_if_needed()
policy = self.policy_map[pid]
if builder and hasattr(policy, "_build_learn_on_batch"):
tf_session = policy.get_session()
if tf_session and hasattr(policy, "_build_learn_on_batch"):
builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
to_fetch[pid] = policy._build_learn_on_batch(
builder, batch)
builders[pid], batch)
else:
info_out[pid] = policy.learn_on_batch(batch)
info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
info_out.update(
{pid: builders[pid].get(v)
for pid, v in to_fetch.items()})
else:
info_out = {
DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]
@ -1024,12 +1007,15 @@ class RolloutWorker(ParallelIteratorWorker):
return ret
@DeveloperAPI
def get_policy(
self, policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID) -> Policy:
def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
"""Return policy for the specified id, or None.
Args:
policy_id (str): id of policy to return.
policy_id (PolicyID): ID of the policy to return.
Returns:
Optional[Policy]: The policy under the given ID (or None if not
found).
"""
return self.policy_map.get(policy_id)
@ -1078,18 +1064,12 @@ class RolloutWorker(ParallelIteratorWorker):
policy_dict = {
policy_id: (policy_cls, observation_space, action_space, config)
}
if self.tf_sess is not None:
with self.tf_sess.graph.as_default():
with self.tf_sess.as_default():
add_map, add_prep = self._build_policy_map(
policy_dict, self.policy_config)
else:
add_map, add_prep = self._build_policy_map(policy_dict,
self.policy_config)
new_policy = add_map[policy_id]
self._build_policy_map(
policy_dict,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]
self.policy_map.update(add_map)
self.preprocessors.update(add_prep)
self.filters[policy_id] = get_filter(
self.observation_filter, new_policy.observation_space.shape)
@ -1301,12 +1281,27 @@ class RolloutWorker(ParallelIteratorWorker):
return func(self, *args)
def _build_policy_map(
self, policy_dict: MultiAgentPolicyConfigDict,
policy_config: TrainerConfigDict
self,
policy_dict: MultiAgentPolicyConfigDict,
policy_config: TrainerConfigDict,
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
seed: Optional[int] = None,
) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
policy_map = {}
preprocessors = {}
for name, (cls, obs_space, act_space,
ma_config = policy_config.get("multiagent", {})
self.policy_map = self.policy_map or PolicyMap(
worker_index=self.worker_index,
num_workers=self.num_workers,
capacity=ma_config.get("policy_map_capacity"),
path=ma_config.get("policy_map_cache"),
policy_config=policy_config,
session_creator=session_creator,
seed=seed,
)
self.preprocessors = self.preprocessors or {}
for name, (orig_cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
logger.debug("Creating policy for {}".format(name))
merged_conf = merge_dicts(policy_config, conf or {})
@ -1315,43 +1310,23 @@ class RolloutWorker(ParallelIteratorWorker):
if self.preprocessing_enabled:
preprocessor = ModelCatalog.get_preprocessor_for_space(
obs_space, merged_conf.get("model"))
preprocessors[name] = preprocessor
self.preprocessors[name] = preprocessor
obs_space = preprocessor.observation_space
else:
preprocessors[name] = NoPreprocessor(obs_space)
self.preprocessors[name] = NoPreprocessor(obs_space)
if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
raise ValueError(
"Found raw Tuple|Dict space as input to policy. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
# Tf.
framework = policy_config.get("framework", "tf")
if framework in ["tf2", "tf", "tfe"]:
assert tf1
if framework in ["tf2", "tfe"]:
assert tf1.executing_eagerly()
if hasattr(cls, "as_eager"):
cls = cls.as_eager()
if policy_config.get("eager_tracing"):
cls = cls.with_tracing()
elif not issubclass(cls, TFPolicy):
pass # could be some other type of policy
else:
raise ValueError("This policy does not support eager "
"execution: {}".format(cls))
scope = name + (("_wk" + str(self.worker_index))
if self.worker_index else "")
with tf1.variable_scope(scope):
policy_map[name] = cls(obs_space, act_space, merged_conf)
# non-tf.
else:
policy_map[name] = cls(obs_space, act_space, merged_conf)
self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
conf, merged_conf)
if self.worker_index == 0:
logger.info("Built policy map: {}".format(policy_map))
logger.info("Built preprocessor map: {}".format(preprocessors))
return policy_map, preprocessors
logger.info(f"Built policy map: {self.policy_map}")
logger.info(f"Built preprocessor map: {self.preprocessors}")
def setup_torch_data_parallel(self, url: str, world_rank: int,
world_size: int, backend: str) -> None:


@ -205,8 +205,7 @@ class MultiAgentSampleBatchBuilder:
post_batches[agent_id] = pre_batch
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id],
getattr(policy, "_sess", None))
policy, post_batches[agent_id], policy.get_session())
post_batches[agent_id] = policy.postprocess_trajectory(
post_batches[agent_id], other_batches, episode)


@ -30,7 +30,6 @@ from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import clip_action, \
unsquash_action, unbatch
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
TensorStructType
@ -137,7 +136,6 @@ class SyncSampler(SamplerInput):
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
normalize_actions: bool = True,
clip_actions: bool = False,
soft_horizon: bool = False,
@ -150,6 +148,7 @@ class SyncSampler(SamplerInput):
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
tf_sess=None,
):
"""Initializes a SyncSampler object.
@ -168,8 +167,6 @@ class SyncSampler(SamplerInput):
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
framework=tf).
normalize_actions (bool): Whether to normalize actions to the
action space's bounds.
clip_actions (bool): Whether to clip actions according to the
@ -199,6 +196,8 @@ class SyncSampler(SamplerInput):
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
if tf_sess is not None:
deprecation_warning(old="tf_sess")
self.base_env = BaseEnv.to_base_env(env)
self.rollout_fragment_length = rollout_fragment_length
@ -221,7 +220,7 @@ class SyncSampler(SamplerInput):
worker, self.base_env, self.extra_batches.put,
self.rollout_fragment_length, self.horizon, clip_rewards,
normalize_actions, clip_actions, multiple_episodes_in_batch,
callbacks, tf_sess, self.perf_stats, soft_horizon, no_done_at_end,
callbacks, self.perf_stats, soft_horizon, no_done_at_end,
observation_fn, self.sample_collector, self.render)
self.metrics_queue = queue.Queue()
@ -275,7 +274,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
callbacks: "DefaultCallbacks",
horizon: int = None,
multiple_episodes_in_batch: bool = False,
tf_sess=None,
normalize_actions: bool = True,
clip_actions: bool = False,
blackhole_outputs: bool = False,
@ -289,6 +287,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
tf_sess=None,
):
"""Initializes a AsyncSampler object.
@ -309,8 +308,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
multiple_episodes_in_batch (bool): Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
framework=tf).
normalize_actions (bool): Whether to normalize actions to the
action space's bounds.
clip_actions (bool): Whether to clip actions according to the
@ -342,6 +339,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
if tf_sess is not None:
deprecation_warning(old="tf_sess")
self.worker = worker
@ -359,7 +358,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.clip_rewards = clip_rewards
self.daemon = True
self.multiple_episodes_in_batch = multiple_episodes_in_batch
self.tf_sess = tf_sess
self.callbacks = callbacks
self.normalize_actions = normalize_actions
self.clip_actions = clip_actions
@ -400,9 +398,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.worker, self.base_env, extra_batches_putter,
self.rollout_fragment_length, self.horizon, self.clip_rewards,
self.normalize_actions, self.clip_actions,
self.multiple_episodes_in_batch, self.callbacks, self.tf_sess,
self.perf_stats, self.soft_horizon, self.no_done_at_end,
self.observation_fn, self.sample_collector, self.render)
self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
self.soft_horizon, self.no_done_at_end, self.observation_fn,
self.sample_collector, self.render)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@ -458,7 +456,6 @@ def _env_runner(
clip_actions: bool,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
tf_sess: Optional["tf.Session"],
perf_stats: _PerfStats,
soft_horizon: bool,
no_done_at_end: bool,
@ -484,8 +481,6 @@ def _env_runner(
space's bounds.
clip_actions (bool): Whether to clip actions to the space range.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
perf_stats (_PerfStats): Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
@ -566,7 +561,7 @@ def _env_runner(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
tf_sess=p.get_session())
callbacks.on_episode_start(
worker=worker,
base_env=base_env,
@ -627,7 +622,6 @@ def _env_runner(
policy_mapping_fn=worker.policy_mapping_fn,
sample_collector=sample_collector,
active_episodes=active_episodes,
tf_sess=tf_sess,
)
perf_stats.inference_time += time.time() - t2
@ -915,7 +909,7 @@ def _process_observations(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
tf_sess=p.get_session())
# Call custom on_episode_end callback.
callbacks.on_episode_end(
worker=worker,
@ -986,7 +980,6 @@ def _do_policy_eval(
policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
sample_collector,
active_episodes: Dict[str, MultiAgentEpisode],
tf_sess: Optional["tf.Session"] = None,
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
@ -997,8 +990,6 @@ def _do_policy_eval(
policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
obj.
sample_collector (SampleCollector): The SampleCollector object to use.
tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
batching TF policy evaluations.
Returns:
eval_results: dict of policy to compute_action() outputs.
@ -1006,12 +997,6 @@ def _do_policy_eval(
eval_results: Dict[PolicyID, TensorStructType] = {}
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
pending_fetches: Dict[PolicyID, Any] = {}
else:
builder = None
if log_once("compute_actions_input"):
logger.info("Inputs to compute_actions():\n\n{}\n".format(
summarize(to_eval)))
@ -1033,11 +1018,6 @@ def _do_policy_eval(
timestep=policy.global_timestep,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
# types: PolicyID, Tuple[TensorStructType, StateBatch, dict]
for pid, v in pending_fetches.items():
eval_results[pid] = builder.get(v)
if log_once("compute_actions_result"):
logger.info("Outputs of compute_actions():\n\n{}\n".format(
summarize(eval_results)))


@ -182,7 +182,7 @@ class TestRolloutWorker(unittest.TestCase):
0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
lr = policy.cur_lr
if fw == "tf":
lr = policy._sess.run(lr)
lr = policy.get_session().run(lr)
check(lr, expected_lr, rtol=0.05)
agent.stop()


@ -78,13 +78,14 @@ class TFMultiGPULearner(LearnerThread):
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
raise NotImplementedError("Multi-gpu mode for multi-agent")
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
tf_session = self.policy.get_session()
# per-GPU graph copies created below must share vars with the policy
# reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
self.par_opt = []
with self.local_worker.tf_sess.graph.as_default():
with self.local_worker.tf_sess.as_default():
with tf_session.graph.as_default():
with tf_session.as_default():
with tf1.variable_scope(
DEFAULT_POLICY_ID, reuse=tf1.AUTO_REUSE):
if self.policy._state_inputs:
@ -106,7 +107,7 @@ class TFMultiGPULearner(LearnerThread):
999999, # it will get rounded down
self.policy.copy))
self.sess = self.local_worker.tf_sess
self.sess = tf_session
self.sess.run(tf1.global_variables_initializer())
self.idle_optimizers = queue.Queue()


@ -148,14 +148,9 @@ class TrainTFMultiGPU:
# reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
self.optimizers = {}
with self.workers.local_worker().tf_sess.graph.as_default():
with self.workers.local_worker().tf_sess.as_default():
for policy_id in (self.policies
or self.local_worker.policies_to_train):
self.add_optimizer(policy_id)
self.sess = self.workers.local_worker().tf_sess
self.sess.run(tf1.global_variables_initializer())
for policy_id in (self.policies
or self.local_worker.policies_to_train):
self.add_optimizer(policy_id)
def __call__(self,
samples: SampleBatchType) -> (SampleBatchType, List[dict]):
@ -181,10 +176,7 @@ class TrainTFMultiGPU:
# Policy seems to be new and doesn't have an optimizer yet.
# Add it here and continue.
elif policy_id not in self.optimizers:
with self.workers.local_worker().tf_sess.graph.as_default(
):
with self.workers.local_worker().tf_sess.as_default():
self.add_optimizer(policy_id)
self.add_optimizer(policy_id)
# Decompress SampleBatch, in case some columns are compressed.
batch.decompress_if_needed()
@ -200,13 +192,14 @@ class TrainTFMultiGPU:
state_keys = []
num_loaded_tuples[policy_id] = (
self.optimizers[policy_id].load_data(
self.sess, [tuples[k] for k in data_keys],
policy.get_session(), [tuples[k] for k in data_keys],
[tuples[k] for k in state_keys]))
# Execute minibatch SGD on loaded data.
with learn_timer:
fetches = {}
for policy_id, tuples_per_device in num_loaded_tuples.items():
policy = self.workers.local_worker().get_policy(policy_id)
optimizer = self.optimizers[policy_id]
num_batches = max(
1,
@ -217,7 +210,7 @@ class TrainTFMultiGPU:
batch_fetches_all_towers = []
for batch_index in range(num_batches):
batch_fetches = optimizer.optimize(
self.sess, permutation[batch_index] *
policy.get_session(), permutation[batch_index] *
self.per_device_batch_size)
batch_fetches_all_towers.append(
@ -250,15 +243,20 @@ class TrainTFMultiGPU:
def add_optimizer(self, policy_id):
policy = self.workers.local_worker().get_policy(policy_id)
with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
if policy._state_inputs:
rnn_inputs = policy._state_inputs + [policy._seq_lens]
else:
rnn_inputs = []
self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
policy._optimizer, self.devices,
list(policy._loss_input_dict_no_rnn.values()), rnn_inputs,
self.per_device_batch_size, policy.copy))
tf_session = policy.get_session()
with tf_session.graph.as_default():
with tf_session.as_default():
with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
if policy._state_inputs:
rnn_inputs = policy._state_inputs + [policy._seq_lens]
else:
rnn_inputs = []
self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
policy._optimizer, self.devices,
list(policy._loss_input_dict_no_rnn.values()),
rnn_inputs, self.per_device_batch_size, policy.copy))
tf_session.run(tf1.global_variables_initializer())
def all_tower_reduce(path, *tower_data):


@ -166,8 +166,7 @@ class JsonReader(InputReader):
self.ioctx.worker.policy_map[pid].action_space_struct)
# Re-normalize actions (from env's bounds to 0.0 centered), if
# necessary.
if "actions_in_input_normalized" in cfg and \
cfg["actions_in_input_normalized"] is False:
if cfg.get("actions_in_input_normalized") is False:
if isinstance(batch, SampleBatch):
batch[SampleBatch.ACTIONS] = normalize_action(
batch[SampleBatch.ACTIONS], self.ioctx.worker.policy_map[


@ -467,7 +467,7 @@ class DynamicTFPolicy(TFPolicy):
self._optimizer = self.optimizer()
# Test calls depend on variable init, so initialize model first.
self._sess.run(tf1.global_variables_initializer())
self.get_session().run(tf1.global_variables_initializer())
logger.info("Testing `compute_actions` w/ dummy batch.")
actions, state_outs, extra_fetches = \
@ -486,7 +486,8 @@ class DynamicTFPolicy(TFPolicy):
dummy_batch = self._dummy_batch
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
self.exploration.postprocess_trajectory(self, dummy_batch,
self.get_session())
_ = self.postprocess_trajectory(dummy_batch)
# Add new columns automatically to (loss) input_dict.
for key in dummy_batch.added_keys:
@ -592,7 +593,7 @@ class DynamicTFPolicy(TFPolicy):
}
# Initialize again after loss init.
self._sess.run(tf1.global_variables_initializer())
self.get_session().run(tf1.global_variables_initializer())
def _do_loss_init(self, train_batch: SampleBatch):
loss = self._loss_fn(self, self.model, self.dist_class, train_batch)


@ -255,7 +255,6 @@ def build_eager_tf_policy(
self._is_training = False
self._loss_initialized = False
self._sess = None
self._loss = loss_fn
self.batch_divisibility_req = get_batch_divisibility_req(self) if \


@ -461,6 +461,9 @@ class Policy(metaclass=ABCMeta):
def get_weights(self) -> ModelWeights:
"""Returns model weights.
Note: The return value of this method will reside under the "weights"
key in the return value of Policy.get_state().
Returns:
ModelWeights: Serializable copy or view of model weights.
"""
@ -468,7 +471,7 @@ class Policy(metaclass=ABCMeta):
@DeveloperAPI
def set_weights(self, weights: ModelWeights) -> None:
"""Sets model weights.
"""Sets this Policy's model's weights.
Args:
weights (ModelWeights): Serializable copy or view of model weights.
@ -523,12 +526,16 @@ class Policy(metaclass=ABCMeta):
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
"""Returns all local state.
Note: Not to be confused with an RNN model's internal state.
Returns:
Union[Dict[str, TensorType], List[TensorType]]: Serialized local
state.
"""
state = {
# All the policy's weights.
"weights": self.get_weights(),
# The current global timestep.
"global_timestep": self.global_timestep,
}
return state
@ -581,6 +588,16 @@ class Policy(metaclass=ABCMeta):
"""
raise NotImplementedError
@DeveloperAPI
def get_session(self) -> Optional["tf1.Session"]:
"""Returns tf.Session object to use for computing actions or None.
Returns:
Optional[tf1.Session]: The tf Session to use for computing actions
and losses with this policy.
"""
return None
def _create_exploration(self) -> Exploration:
"""Creates the Policy's Exploration object.

rllib/policy/policy_map.py (new file, 292 lines)

@ -0,0 +1,292 @@
from collections import deque
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import PartialTrainerConfigDict, \
PolicyID, TrainerConfigDict
from ray.tune.utils.util import merge_dicts
if TYPE_CHECKING:
from ray.rllib.policy.policy import Policy
tf1, tf, tfv = try_import_tf()
class PolicyMap(dict):
"""Maps policy IDs to Policy objects.
Keeps up to `capacity` policies in memory and, when that capacity is
reached, writes the least recently used ones to disk. This allows adding
hundreds of policies to a Trainer for league-based setups without running
out of memory.
"""
def __init__(
self,
worker_index: int,
num_workers: int,
capacity: Optional[int] = None,
path: Optional[str] = None,
policy_config: Optional[TrainerConfigDict] = None,
session_creator: Optional[Callable[[], "tf1.Session"]] = None,
seed: Optional[int] = None,
):
"""Initializes a PolicyMap instance.
Args:
worker_index (int): The index of the RolloutWorker this map resides in.
num_workers (int): The total number of remote rollout workers.
capacity (int): The maximum number of policies to hold in memory.
The least recently used ones are written to disk/S3 and retrieved
when needed.
path (str): The directory (or S3 location) in which to store stashed
policies. Defaults to the current directory.
policy_config (TrainerConfigDict): The Trainer's base config, on top of
which each policy's config override is merged.
session_creator (Optional[Callable]): Optional tf1.Session creator for
static-graph TF policies.
seed (Optional[int]): Optional seed for the policies' TF graphs.
"""
super().__init__()
self.worker_index = worker_index
self.num_workers = num_workers
self.session_creator = session_creator
self.seed = seed
# The file extension for stashed policies (that are no longer available
# in-memory but can be reinstated any time from storage).
self.extension = ".policy.pkl"
# Dictionary of keys that may be looked up (cached or not).
self.valid_keys = set()
# The actual cache with the in-memory policy objects.
self.cache = {}
# The doubly-linked list holding the currently in-memory objects.
self.deque = deque(maxlen=capacity or 10)
# The file path where to store overflowing policies.
self.path = path or "."
# The core config to use. Each single policy's config override is
# added on top of this.
self.policy_config = policy_config or {}
# The orig classes/obs+act spaces, and config overrides of the
# Policies.
self.policy_specs = {} # type: Dict[PolicyID, PolicySpec]
def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
observation_space: gym.Space, action_space: gym.Space,
config_override: PartialTrainerConfigDict,
merged_config: TrainerConfigDict) -> None:
"""Creates a new policy and stores it to the cache.
Args:
policy_id (PolicyID): The policy ID. This is the key under which
the created policy will be stored in this map.
policy_cls (Type[Policy]): The (original) policy class to use.
This may still be altered in case tf-eager (and tracing)
is used.
observation_space (gym.Space): The observation space of the
policy.
action_space (gym.Space): The action space of the policy.
config_override (PartialTrainerConfigDict): The config override
dict for this policy. This is the partial dict provided by
the user.
merged_config (TrainerConfigDict): The entire config (merged
default config + `config_override`).
"""
framework = merged_config.get("framework", "tf")
class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)
# Tf.
if framework in ["tf2", "tf", "tfe"]:
var_scope = policy_id + (("_wk" + str(self.worker_index))
if self.worker_index else "")
# For tf static graph, build every policy in its own graph
# and create a new session for it.
if framework == "tf":
with tf1.Graph().as_default():
if self.session_creator:
sess = self.session_creator()
else:
sess = tf1.Session(
config=tf1.ConfigProto(
gpu_options=tf1.GPUOptions(allow_growth=True)))
with sess.as_default():
# Set graph-level seed.
if self.seed is not None:
tf1.set_random_seed(self.seed)
with tf1.variable_scope(var_scope):
self[policy_id] = class_(
observation_space, action_space, merged_config)
# For tf-eager: no graph, no session.
else:
with tf1.variable_scope(var_scope):
self[policy_id] = \
class_(observation_space, action_space, merged_config)
# Non-tf: No graph, no session.
else:
class_ = policy_cls
self[policy_id] = class_(observation_space, action_space,
merged_config)
# Store spec (class, obs-space, act-space, and config overrides) such
# that the map will be able to reproduce on-the-fly added policies
# from disk.
self.policy_specs[policy_id] = PolicySpec(
policy_class=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config_override)
@override(dict)
def __getitem__(self, item):
# Never seen this key -> Error.
if item not in self.valid_keys:
raise KeyError(f"'{item}' not a valid key!")
# Item already in cache -> Rearrange deque (least recently used) and
# return.
if item in self.cache:
self.deque.remove(item)
self.deque.append(item)
# Item not currently in cache -> Get from disk and - if at capacity -
# remove leftmost one.
else:
self._read_from_disk(policy_id=item)
return self.cache[item]
@override(dict)
def __setitem__(self, key, value):
# Item already in cache -> Rearrange deque (least recently used).
if key in self.cache:
self.deque.remove(key)
self.deque.append(key)
self.cache[key] = value
# Item not currently in cache -> store new value and - if at capacity -
# remove leftmost one.
else:
# Cache at capacity -> Drop leftmost item.
if len(self.deque) == self.deque.maxlen:
self._stash_to_disk()
self.deque.append(key)
self.cache[key] = value
self.valid_keys.add(key)
@override(dict)
def __delitem__(self, key):
# Make key invalid.
self.valid_keys.remove(key)
# Remove policy from memory if currently cached.
if key in self.cache:
policy = self.cache[key]
self._close_session(policy)
del self.cache[key]
# Remove file associated with the policy, if it exists.
filename = self.path + "/" + key + self.extension
if os.path.isfile(filename):
os.remove(filename)
@override(dict)
def __iter__(self):
return self.keys()
@override(dict)
def items(self):
"""Iterates over all policies, even the stashed-to-disk ones."""
def gen():
for key in self.valid_keys:
yield (key, self[key])
return gen()
@override(dict)
def keys(self):
def gen():
for key in self.valid_keys:
yield key
return gen()
@override(dict)
def values(self):
def gen():
for key in self.valid_keys:
yield self[key]
return gen()
@override(dict)
def update(self, __m, **kwargs):
for k, v in __m.items():
self[k] = v
for k, v in kwargs.items():
self[k] = v
@override(dict)
def get(self, key):
if key not in self.valid_keys:
return None
return self[key]
@override(dict)
def __len__(self):
"""Returns number of all policies, including the stashed-to-disk ones.
"""
return len(self.valid_keys)
@override(dict)
def __contains__(self, item):
return item in self.valid_keys
def _stash_to_disk(self):
"""Writes the least-recently used policy to disk and rearranges cache.
Also closes the session - if applicable - of the stashed policy.
"""
# Get least recently used policy (all the way on the left in deque).
delkey = self.deque.popleft()
policy = self.cache[delkey]
# Get its state for writing to disk.
policy_state = policy.get_state()
# Closes policy's tf session, if any.
self._close_session(policy)
# Remove from memory.
# TODO: (sven) This should clear the tf Graph as well, if the Trainer
# would not hold parts of the graph (e.g. in tf multi-GPU setups).
del self.cache[delkey]
# Write state to disk.
with open(self.path + "/" + delkey + self.extension, "wb") as f:
pickle.dump(policy_state, file=f)
def _read_from_disk(self, policy_id):
"""Reads a policy ID from disk and re-adds it to the cache.
"""
# Make sure this policy ID is not in the cache right now.
assert policy_id not in self.cache
# Read policy state from disk.
with open(self.path + "/" + policy_id + self.extension, "rb") as f:
policy_state = pickle.load(f)
# Get class and config override.
merged_conf = merge_dicts(self.policy_config,
self.policy_specs[policy_id].config)
# Create policy object (from its spec: cls, obs-space, act-space,
# config).
self.create_policy(
policy_id,
self.policy_specs[policy_id].policy_class,
self.policy_specs[policy_id].observation_space,
self.policy_specs[policy_id].action_space,
self.policy_specs[policy_id].config,
merged_conf,
)
# Restore policy's state.
policy = self[policy_id]
policy.set_state(policy_state)
def _close_session(self, policy):
sess = policy.get_session()
# Closes the tf session, if any.
if sess is not None:
sess.close()
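
To make the LRU mechanics of the new class concrete, here is a rough usage sketch (editorial illustration, not part of the PR; the DummyPolicy class, the capacity of 2, and the /tmp path are assumptions):

import gym
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap


class DummyPolicy(Policy):
    """Tiny policy whose entire state is a single dict of 'weights'."""

    def __init__(self, obs_space, action_space, config):
        super().__init__(obs_space, action_space, config)
        self.w = {"bias": 0.0}

    def compute_actions(self, obs_batch, **kwargs):
        # Random actions; never actually called in this sketch.
        return [self.action_space.sample() for _ in obs_batch], [], {}

    def get_weights(self):
        return self.w

    def set_weights(self, weights):
        self.w = weights


obs_space = gym.spaces.Box(-1.0, 1.0, (4, ))
act_space = gym.spaces.Discrete(2)

# Keep at most 2 policies in memory; older ones are pickled to /tmp.
pmap = PolicyMap(
    worker_index=0,
    num_workers=0,
    capacity=2,
    path="/tmp",
    policy_config={"framework": None})
for i in range(3):
    pmap.create_policy(f"p{i}", DummyPolicy, obs_space, act_space,
                       {}, {"framework": None})

assert len(pmap) == 3        # All three policies are known ...
assert len(pmap.cache) == 2  # ... but only two live in memory ("p1", "p2").
p0 = pmap["p0"]              # Re-loads "p0" from disk, evicting "p1".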


@ -263,7 +263,8 @@ class TFPolicy(Policy):
"`get_placeholder()` can be called"
return self._loss_input_dict[name]
def get_session(self) -> "tf1.Session":
@override(Policy)
def get_session(self) -> Optional["tf1.Session"]:
"""Returns a reference to the TF session for this policy."""
return self._sess
@ -305,7 +306,7 @@ class TFPolicy(Policy):
if self.model:
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
[], self._sess, self.variables())
[], self.get_session(), self.variables())
# gather update ops for any batch norm layers
if not self._update_ops:
@ -323,12 +324,12 @@ class TFPolicy(Policy):
"These tensors were used in the loss_fn:\n\n{}\n".format(
summarize(self._loss_input_dict)))
self._sess.run(tf1.global_variables_initializer())
self.get_session().run(tf1.global_variables_initializer())
self._optimizer_variables = None
if self._optimizer:
self._optimizer_variables = \
ray.experimental.tf_utils.TensorFlowVariables(
self._optimizer.variables(), self._sess)
self._optimizer.variables(), self.get_session())
@override(Policy)
def compute_actions(
@ -346,7 +347,7 @@ class TFPolicy(Policy):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
builder = TFRunBuilder(self._sess, "compute_actions")
builder = TFRunBuilder(self.get_session(), "compute_actions")
to_fetch = self._build_compute_actions(
builder,
obs_batch=obs_batch,
@ -378,7 +379,8 @@ class TFPolicy(Policy):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
builder = TFRunBuilder(self.get_session(),
"compute_actions_from_input_dict")
obs_batch = input_dict[SampleBatch.OBS]
to_fetch = self._build_compute_actions(
builder, input_dict=input_dict, explore=explore, timestep=timestep)
@ -413,7 +415,7 @@ class TFPolicy(Policy):
self.exploration.before_compute_actions(
explore=False, tf_sess=self.get_session())
builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods")
# Normalize actions if necessary.
if actions_normalized is False and self.config["normalize_actions"]:
@ -451,7 +453,7 @@ class TFPolicy(Policy):
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "learn_on_batch")
builder = TFRunBuilder(self.get_session(), "learn_on_batch")
# Callback handling.
learn_stats = {}
@ -470,7 +472,7 @@ class TFPolicy(Policy):
postprocessed_batch: SampleBatch) -> \
Tuple[ModelGradients, Dict[str, TensorType]]:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "compute_gradients")
builder = TFRunBuilder(self.get_session(), "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
@ -478,7 +480,7 @@ class TFPolicy(Policy):
@DeveloperAPI
def apply_gradients(self, gradients: ModelGradients) -> None:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "apply_gradients")
builder = TFRunBuilder(self.get_session(), "apply_gradients")
fetches = self._build_apply_gradients(builder, gradients)
builder.get(fetches)
@ -510,7 +512,7 @@ class TFPolicy(Policy):
if self._optimizer_variables and \
len(self._optimizer_variables.variables) > 0:
state["_optimizer_variables"] = \
self._sess.run(self._optimizer_variables.variables)
self.get_session().run(self._optimizer_variables.variables)
# Add exploration state.
state["_exploration_state"] = \
self.exploration.get_state(self.get_session())
@ -546,7 +548,7 @@ class TFPolicy(Policy):
"`tf2onnx` to be installed. Install with "
"`pip install tf2onnx`.") from e
with self._sess.graph.as_default():
with self.get_session().graph.as_default():
signature_def_map = self._build_signature_def()
sd = signature_def_map[tf1.saved_model.signature_constants.
@ -558,7 +560,7 @@ class TFPolicy(Policy):
frozen_graph_def = tf_loader.freeze_session(
self._sess, input_names=inputs, output_names=outputs)
with tf.compat.v1.Session(graph=tf.Graph()) as session:
with tf1.Session(graph=tf.Graph()) as session:
tf.import_graph_def(frozen_graph_def, name="")
g = tf2onnx.tfonnx.process_tf_graph(
@ -574,14 +576,15 @@ class TFPolicy(Policy):
feed_dict={},
model_proto=model_proto)
else:
with self._sess.graph.as_default():
with self.get_session().graph.as_default():
signature_def_map = self._build_signature_def()
builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
self._sess, [tf1.saved_model.tag_constants.SERVING],
self.get_session(),
[tf1.saved_model.tag_constants.SERVING],
signature_def_map=signature_def_map,
saver=tf1.summary.FileWriter(export_dir).add_graph(
graph=self._sess.graph))
graph=self.get_session().graph))
builder.save()
# TODO: (sven) Deprecate this in favor of `save()`.
@ -599,17 +602,17 @@ class TFPolicy(Policy):
if e.errno != errno.EEXIST:
raise
save_path = os.path.join(export_dir, filename_prefix)
with self._sess.graph.as_default():
with self.get_session().graph.as_default():
saver = tf1.train.Saver()
saver.save(self._sess, save_path)
saver.save(self.get_session(), save_path)
@override(Policy)
@DeveloperAPI
def import_model_from_h5(self, import_file: str) -> None:
"""Imports weights into tf model."""
# Make sure the session is the right one (see issue #7046).
with self._sess.graph.as_default():
with self._sess.as_default():
with self.get_session().graph.as_default():
with self.get_session().as_default():
return self.model.import_from_h5(import_file)
@DeveloperAPI
@ -1026,7 +1029,7 @@ class LearningRateSchedule:
if self._lr_schedule is not None:
new_val = self._lr_schedule.value(global_vars["timestep"])
if self.framework == "tf":
self._sess.run(
self.get_session().run(
self._lr_update, feed_dict={self._lr_placeholder: new_val})
else:
self.cur_lr.assign(new_val, read_value=False)
@ -1082,7 +1085,7 @@ class EntropyCoeffSchedule:
new_val = self._entropy_coeff_schedule.value(
global_vars["timestep"])
if self.framework == "tf":
self._sess.run(
self.get_session().run(
self._entropy_coeff_update,
feed_dict={self._entropy_coeff_placeholder: new_val})
else:


@ -45,7 +45,7 @@ class TestParameterNoise(unittest.TestCase):
trainer = trainer_cls(config=config, env=env)
policy = trainer.get_policy()
pol_sess = getattr(policy, "_sess", None)
pol_sess = policy.get_session()
# Remove noise that has been added during policy initialization
# (exploration.postprocess_trajectory does add noise to measure
# the delta).
@ -110,7 +110,7 @@ class TestParameterNoise(unittest.TestCase):
config["explore"] = False
trainer = trainer_cls(config=config, env=env)
policy = trainer.get_policy()
pol_sess = getattr(policy, "_sess", None)
pol_sess = policy.get_session()
# Remove noise that has been added during policy initialization
# (exploration.postprocess_trajectory does add noise to measure
# the delta).


@ -60,6 +60,31 @@ def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
)
def get_tf_eager_cls_if_necessary(orig_cls, config):
cls = orig_cls
framework = config.get("framework", "tf")
if framework in ["tf2", "tf", "tfe"]:
if not tf1:
raise ImportError("Could not import tensorflow!")
if framework in ["tf2", "tfe"]:
assert tf1.executing_eagerly()
from ray.rllib.policy.tf_policy import TFPolicy
# Create eager-class.
if hasattr(orig_cls, "as_eager"):
cls = orig_cls.as_eager()
if config.get("eager_tracing"):
cls = cls.with_tracing()
# Could be some other type of policy.
elif not issubclass(orig_cls, TFPolicy):
pass
else:
raise ValueError("This policy does not support eager "
"execution: {}".format(orig_cls))
return cls
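
The eager-class resolution that previously lived inline in RolloutWorker._build_policy_map is now this reusable helper. A small usage sketch (assumes TensorFlow is installed; PPOTFPolicy is just a convenient built-in TF policy for the example):

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary

# Static-graph TF: the original class is returned unchanged.
cls = get_tf_eager_cls_if_necessary(PPOTFPolicy, {"framework": "tf"})
assert cls is PPOTFPolicy
# Under {"framework": "tf2"} (eager enabled), the helper would instead
# return PPOTFPolicy.as_eager(), traced if "eager_tracing" is set.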
def huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(