[RLlib] Tf2 + eager-tracing same speed as framework=tf; Add more test coverage for tf2+tracing. (#19981)

Sven Mika 2021-11-05 16:10:00 +01:00 committed by GitHub
parent 1341bb59bf
commit a931076f59
25 changed files with 519 additions and 386 deletions

View file

@ -148,7 +148,7 @@ py_test(
)
py_test(
name = "run_regression_tests_frozenlake_appo",
name = "learning_frozenlake_appo",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
size = "large",

View file

@ -24,8 +24,8 @@ class TestA2C(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config):
for env in ["PongDeterministic-v0"]:
for _ in framework_iterator(config, with_eager_tracing=True):
for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]:
trainer = a3c.A2CTrainer(config=config, env=env)
for i in range(num_iterations):
results = trainer.train()
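
The one-line change above — passing `with_eager_tracing=True` to `framework_iterator` — is the pattern repeated across the test files below. A minimal, self-contained sketch of that pattern, assuming RLlib's 1.x import paths (not code from this commit):

import unittest

import ray
import ray.rllib.agents.a3c as a3c
from ray.rllib.utils.test_utils import framework_iterator


class TestA2CWithTracing(unittest.TestCase):
    def test_a2c_compilation(self):
        ray.init(ignore_reinit_error=True)
        config = {"num_workers": 0}
        # Iterates tf, tf2 (eager and eager-traced), tfe and torch.
        for _ in framework_iterator(config, with_eager_tracing=True):
            trainer = a3c.A2CTrainer(config=config, env="CartPole-v0")
            results = trainer.train()
            self.assertIn("episode_reward_mean", results)
            trainer.stop()
        ray.shutdown()


if __name__ == "__main__":
    unittest.main()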

View file

@ -27,7 +27,7 @@ class TestA3C(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
for env in ["CartPole-v1", "Pendulum-v1", "PongDeterministic-v0"]:
print("env={}".format(env))
config["model"]["use_lstm"] = env == "CartPole-v1"

View file

@ -66,7 +66,7 @@ class TestCQL(unittest.TestCase):
num_iterations = 4
# Test for tf/torch frameworks.
for fw in framework_iterator(config):
for fw in framework_iterator(config, with_eager_tracing=True):
trainer = cql.CQLTrainer(config=config)
for i in range(num_iterations):
results = trainer.train()

View file

@ -24,7 +24,7 @@ class TestApexDDPG(unittest.TestCase):
config["learning_starts"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
num_iterations = 1
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
plain_config = config.copy()
trainer = apex_ddpg.ApexDDPGTrainer(
config=plain_config, env="Pendulum-v1")

View file

@ -41,7 +41,7 @@ class TestDDPG(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
for i in range(num_iterations):
results = trainer.train()

View file

@ -44,7 +44,7 @@ class TestApexDQN(unittest.TestCase):
config["min_iter_time_s"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
plain_config = config.copy()
trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")

View file

@ -34,7 +34,7 @@ class TestSimpleQ(unittest.TestCase):
num_iterations = 2
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
rw = trainer.workers.local_worker()
for i in range(num_iterations):

View file

@ -30,7 +30,7 @@ class TestIMPALA(unittest.TestCase):
num_iterations = 1
env = "CartPole-v0"
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
local_cfg = config.copy()
for lstm in [False, True]:
local_cfg["num_aggregation_workers"] = 0 if not lstm else 1

View file

@ -24,7 +24,7 @@ class TestAPPO(unittest.TestCase):
config["num_workers"] = 1
num_iterations = 2
for _ in framework_iterator(config):
for _ in framework_iterator(config, with_eager_tracing=True):
print("w/o v-trace")
_config = config.copy()
_config["vtrace"] = False

View file

@ -106,7 +106,7 @@ class TestPPO(unittest.TestCase):
config["compress_observations"] = True
num_iterations = 2
for fw in framework_iterator(config):
for fw in framework_iterator(config, with_eager_tracing=True):
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
print("Env={}".format(env))
for lstm in [True, False]:

View file

@ -230,14 +230,23 @@ COMMON_CONFIG: TrainerConfigDict = {
# === Deep Learning Framework Settings ===
# tf: TensorFlow (static-graph)
# tf2: TensorFlow 2.x (eager)
# tfe: TensorFlow eager
# tf2: TensorFlow 2.x (eager or traced, if eager_tracing=True)
# tfe: TensorFlow eager (or traced, if eager_tracing=True)
# torch: PyTorch
"framework": "tf",
# Enable tracing in eager mode. This greatly improves performance, but
# makes it slightly harder to debug since Python code won't be evaluated
# after the initial eager pass. Only possible if framework=tfe.
# Enable tracing in eager mode. This greatly improves performance
# (speedup ~2x), but makes it slightly harder to debug since Python
# code won't be evaluated after the initial eager pass.
# Only possible if framework=[tf2|tfe].
"eager_tracing": False,
# Maximum number of tf.function re-traces before a runtime error is raised.
# This is to prevent unnoticed retraces of methods inside the
# `..._eager_traced` Policy, which could slow down execution by a
# factor of 4, without the user noticing the root cause of this
# slowdown.
# Only necessary for framework=[tf2|tfe].
# Set to None to ignore the re-trace count and never throw an error.
"eager_max_retraces": 20,
# === Exploration Settings ===
# Default exploration behavior, iff `explore`=None is passed into
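
A minimal usage sketch of these three settings together (trainer class and env are only examples, not part of this commit):

import ray
from ray.rllib.agents.ppo import PPOTrainer

config = {
    "env": "CartPole-v0",
    "num_workers": 0,
    "framework": "tf2",        # run TF2 in eager mode ...
    "eager_tracing": True,     # ... but trace helper methods for ~2x speedup
    "eager_max_retraces": 20,  # error out on excessive tf.function re-traces
}

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(config=config)
print(trainer.train()["episode_reward_mean"])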

View file

@ -10,8 +10,8 @@ class RandomEnv(gym.Env):
Can be instantiated with arbitrary action-, observation-, and reward
spaces. Observations and rewards are generated by simply sampling from the
observation/reward spaces. The probability of a `done=True` can be
configured as well.
observation/reward spaces. The probability of a `done=True` after each
action can be configured, as well as the max episode length.
"""
def __init__(self, config=None):
@ -26,8 +26,13 @@ class RandomEnv(gym.Env):
"reward_space",
gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32))
# Chance that an episode ends at any step.
# Note that a max episode length can be specified via
# `max_episode_len`.
self.p_done = config.get("p_done", 0.1)
# A max episode length.
# A max episode length. Even if the `p_done` sampling does not lead
# to a termination, the episode will end after at most this many
# timesteps.
# Set to 0 or None to impose no limit on the episode length.
self.max_episode_len = config.get("max_episode_len", None)
# Whether to check action bounds.
self.check_action_bounds = config.get("check_action_bounds", False)
@ -49,11 +54,10 @@ class RandomEnv(gym.Env):
self.steps += 1
done = False
# We are done as per our max-episode-len.
if self.max_episode_len is not None and \
self.steps >= self.max_episode_len:
# We are `done` as per our max-episode-len.
if self.max_episode_len and self.steps >= self.max_episode_len:
done = True
# Max not reached yet -> Sample done via p_done.
# Max episode length not reached yet -> Sample `done` via `p_done`.
elif self.p_done > 0.0:
done = bool(
np.random.choice(
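
A quick usage sketch of how `p_done` and the new `max_episode_len` interact (assuming the usual `ray.rllib.examples.env.random_env` import path; not code from this commit):

from ray.rllib.examples.env.random_env import RandomEnv

env = RandomEnv(config={
    "p_done": 0.0,          # never terminate via random sampling ...
    "max_episode_len": 5,   # ... but force done=True after 5 steps
})
obs = env.reset()
done, num_steps = False, 0
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
    num_steps += 1
assert num_steps == 5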

View file

@ -196,7 +196,8 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
def forward(self, input_dict, state, seq_lens):
# Set the correct train-mode for our hidden module (only important
# b/c we have some batch-norm layers).
self._hidden_layers.train(mode=input_dict.get("is_training", False))
self._hidden_layers.train(
mode=bool(input_dict.get("is_training", False)))
self._hidden_out = self._hidden_layers(input_dict["obs"])
logits = self._logits(self._hidden_out)
return logits, []

View file

@ -1,5 +1,11 @@
from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning("LocalSyncParallelOptimizer", "TFMultiGPUTowerStack")
# Backward compatibility.
deprecation_warning(
old="ray.rllib.execution.multi_gpu_impl.LocalSyncParallelOptimizer",
new="ray.rllib.policy.dynamic_tf_policy.TFMultiGPUTowerStack",
error=False,
)
# Old name.
LocalSyncParallelOptimizer = TFMultiGPUTowerStack
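
The same backward-compatibility shim pattern in general form, with hypothetical module and class names (not code from this commit):

# Hypothetical new home of the class (would normally live in its own module).
class NewClass:
    pass

# old_module.py -- kept only so legacy imports keep working.
from ray.rllib.utils.deprecation import deprecation_warning

deprecation_warning(
    old="old_pkg.old_module.OldClass",   # hypothetical old location
    new="new_pkg.new_module.NewClass",   # hypothetical new location
    error=False,                         # warn only; True would raise instead
)

# Old name kept as an alias so existing imports continue to work.
OldClass = NewClass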

View file

@ -2,6 +2,12 @@ from ray.rllib.execution.multi_gpu_learner_thread import \
MultiGPULearnerThread, _MultiGPULoaderThread
from ray.rllib.utils.deprecation import deprecation_warning
deprecation_warning("multi_gpu_learner.py", "multi_gpu_learner_thread.py")
# Backward compatibility.
deprecation_warning(
old="ray.rllib.execution.multi_gpu_learner.py",
new="ray.rllib.execution.multi_gpu_learner_thread.py",
error=False,
)
# Old names.
TFMultiGPULearner = MultiGPULearnerThread
_LoaderThread = _MultiGPULoaderThread

View file

@ -243,6 +243,11 @@ class ModelV2:
with self.context():
res = self.forward(restored, state or [], seq_lens)
if isinstance(input_dict, SampleBatch):
input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}
input_dict.deleted_keys = restored.deleted_keys
input_dict.added_keys = restored.added_keys - {"obs_flat"}
if ((not isinstance(res, list) and not isinstance(res, tuple))
or len(res) != 2):
raise ValueError(
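
A self-contained, non-RLlib illustration of the key-tracking sets (`accessed_keys`, `added_keys`, `deleted_keys`) that the hunk above copies from `restored` back to `input_dict`:

class TrackingDict(dict):
    """Dict that records which keys were read, added, or deleted."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        if key not in self:
            self.added_keys.add(key)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        self.deleted_keys.add(key)
        super().__delitem__(key)


batch = TrackingDict(obs=[1, 2, 3], prev_rewards=[0.0, 1.0, 0.0])
_ = batch["obs"]  # a model forward pass that only reads "obs"
assert batch.accessed_keys == {"obs"}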

View file

@ -250,7 +250,7 @@ class DynamicTFPolicy(TFPolicy):
True, (), name="is_exploring")
# Placeholder for `is_training` flag.
self._input_dict.is_training = self._get_is_training_placeholder()
self._input_dict.set_training(self._get_is_training_placeholder())
# Multi-GPU towers do not need any action computing/exploration
# graphs.
@ -464,7 +464,7 @@ class DynamicTFPolicy(TFPolicy):
buffer_index: int = 0,
) -> int:
# Set the is_training flag of the batch.
batch.is_training = True
batch.set_training(True)
# Shortcut for 1 CPU only: Store batch in
# `self._loaded_single_cpu_batch`.

View file

@ -44,9 +44,11 @@ def _convert_to_tf(x, dtype=None):
if x is not None:
d = dtype
x = tf.nest.map_structure(
return tf.nest.map_structure(
lambda f: _convert_to_tf(f, d) if isinstance(f, RepeatedValues)
else tf.convert_to_tensor(f, d) if f is not None else None, x)
else tf.convert_to_tensor(f, d) if
f is not None and not tf.is_tensor(f) else f, x)
return x
@ -101,32 +103,39 @@ def _disallow_var_creation(next_creator, **kw):
"model initialization: {}".format(v.name))
def traced_eager_policy(eager_policy_cls):
"""Wrapper that enables tracing for all eager policy methods.
def check_too_many_retraces(obj):
"""Asserts that a given number of re-traces is not breached."""
This is enabled by the --trace / "eager_tracing" config."""
def _func(self_, *args, **kwargs):
if self_.config.get("eager_max_retraces") is not None and \
self_._re_trace_counter > self_.config["eager_max_retraces"]:
raise RuntimeError(
"Too many tf-eager re-traces detected! This could lead to"
" significant slow-downs (even slower than running in "
"tf-eager mode w/ `eager_tracing=False`). To switch off "
"these re-trace counting checks, set `eager_max_retraces`"
" in your config to None.")
return obj(self_, *args, **kwargs)
return _func
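
A standalone sketch of the re-trace counting idea behind `check_too_many_retraces` (not code from this commit): Python side effects inside a `tf.function` body run only while tracing, so a plain counter incremented there counts traces, not calls.

import tensorflow as tf

MAX_RETRACES = 20
retrace_counter = {"count": 0}


@tf.function(autograph=False, experimental_relax_shapes=True)
def helper(x):
    # Runs only while TF traces the function, i.e. once per new input signature.
    retrace_counter["count"] += 1
    return x * 2.0


def guarded_helper(x):
    if retrace_counter["count"] > MAX_RETRACES:
        raise RuntimeError("Too many tf.function re-traces detected!")
    return helper(x)


print(guarded_helper(tf.constant(1.0)))  # traces once
print(guarded_helper(tf.constant(2.0)))  # same signature -> no re-trace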
def traced_eager_policy(eager_policy_cls):
"""Wrapper class that enables tracing for all eager policy methods.
This is enabled by the `--trace`/`eager_tracing=True` config when
framework=[tf2|tfe].
"""
class TracedEagerPolicy(eager_policy_cls):
def __init__(self, *args, **kwargs):
self._traced_learn_on_batch = None
self._traced_compute_action_helper = False
self._traced_compute_gradients = None
self._traced_apply_gradients = None
self._traced_learn_on_batch_helper = False
self._traced_compute_actions_helper = False
self._traced_compute_gradients_helper = False
self._traced_apply_gradients_helper = False
super(TracedEagerPolicy, self).__init__(*args, **kwargs)
@override(eager_policy_cls)
@convert_eager_inputs
@convert_eager_outputs
def _learn_on_batch_eager(self, samples):
if self._traced_learn_on_batch is None:
self._traced_learn_on_batch = tf.function(
super(TracedEagerPolicy, self)._learn_on_batch_eager,
autograph=False,
experimental_relax_shapes=True)
return self._traced_learn_on_batch(samples)
@check_too_many_retraces
@override(Policy)
def compute_actions_from_input_dict(
self,
@ -136,16 +145,20 @@ def traced_eager_policy(eager_policy_cls):
episodes: Optional[List[Episode]] = None,
**kwargs
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
"""Traced version of Policy.compute_actions_from_input_dict."""
# Create a traced version of `self._compute_action_helper`.
if self._traced_compute_action_helper is False:
self._compute_action_helper = convert_eager_inputs(
# Create a traced version of `self._compute_actions_helper`.
if self._traced_compute_actions_helper is False and \
not self._no_tracing:
self._compute_actions_helper = convert_eager_inputs(
tf.function(
super(TracedEagerPolicy, self)._compute_action_helper,
super(TracedEagerPolicy, self)._compute_actions_helper,
autograph=False,
experimental_relax_shapes=True))
self._traced_compute_action_helper = True
self._traced_compute_actions_helper = True
# Now that the helper method is traced, call super's
# compute_actions_from_input_dict (which will call the traced helper).
return super(TracedEagerPolicy, self).\
compute_actions_from_input_dict(
input_dict=input_dict,
@ -155,55 +168,67 @@ def traced_eager_policy(eager_policy_cls):
**kwargs,
)
@check_too_many_retraces
@override(eager_policy_cls)
@convert_eager_outputs
def _compute_gradients_eager(self, samples: SampleBatch) -> \
ModelGradients:
"""Traced version of EagerTFPolicy's `_compute_gradients_eager`.
def learn_on_batch(self, samples):
"""Traced version of Policy.learn_on_batch."""
Note that `samples` is already zero-padded and has the is_training
flag set. We convert this sample batch into tensors here and do the
actual computation.
Args:
samples: The SampleBatch to compute gradients for.
Returns:
The computed model gradients.
"""
# Have we traced the `_compute_gradients_eager` function yet?
if self._traced_compute_gradients is None:
self._traced_compute_gradients = convert_eager_inputs(
# Create a traced version of `self._learn_on_batch_helper`.
if self._traced_learn_on_batch_helper is False and \
not self._no_tracing:
self._learn_on_batch_helper = convert_eager_inputs(
tf.function(
super()._compute_gradients_eager,
super(TracedEagerPolicy, self)._learn_on_batch_helper,
autograph=False,
experimental_relax_shapes=True))
self._traced_learn_on_batch_helper = True
# Call the only-once compiled traced function with the SampleBatch
# (will be converted to tensors beforehand).
return self._traced_compute_gradients(samples)
# Now that the helper method is traced, call super's
# learn_on_batch (which will call the traced helper).
return super(TracedEagerPolicy, self).learn_on_batch(samples)
@check_too_many_retraces
@override(eager_policy_cls)
def compute_gradients(self, samples: SampleBatch) -> \
ModelGradients:
"""Traced version of Policy.compute_gradients."""
# Create a traced version of `self._compute_gradients_helper`.
if self._traced_compute_gradients_helper is False and \
not self._no_tracing:
self._compute_gradients_helper = convert_eager_inputs(
tf.function(
super(TracedEagerPolicy,
self)._compute_gradients_helper,
autograph=False,
experimental_relax_shapes=True))
self._traced_compute_gradients_helper = True
# Now that the helper method is traced, call super's
# compute_gradients (which will call the traced helper).
return super(TracedEagerPolicy, self).compute_gradients(samples)
@check_too_many_retraces
@override(Policy)
@convert_eager_outputs
def apply_gradients(self, grads: ModelGradients) -> None:
"""Traced version of EagerTFPolicy's `apply_gradients` method.
"""Traced version of Policy.apply_gradients."""
Args:
grads: The ModelGradients to apply using our optimizer(s).
"""
# Have we traced the `apply_gradients` function yet?
if self._traced_apply_gradients is None:
self._traced_apply_gradients = tf.function(
super().apply_gradients,
autograph=False,
experimental_relax_shapes=True)
# Create a traced version of `self._apply_gradients_helper`.
if self._traced_apply_gradients_helper is False and \
not self._no_tracing:
self._apply_gradients_helper = convert_eager_inputs(
tf.function(
super(TracedEagerPolicy, self)._apply_gradients_helper,
autograph=False,
experimental_relax_shapes=True))
self._traced_apply_gradients_helper = True
# Call the only-once compiled traced function with the grads
# (will be converted to tensors beforehand).
return self._traced_apply_gradients(grads)
# Now that the helper method is traced, call super's
# apply_gradients (which will call the traced helper).
return super(TracedEagerPolicy, self).apply_gradients(grads)
TracedEagerPolicy.__name__ = eager_policy_cls.__name__
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
return TracedEagerPolicy
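
The lazy-tracing pattern used by `TracedEagerPolicy`, distilled into a minimal non-RLlib example: trace the private helper with `tf.function` on first use, then keep routing the public method through the traced version.

import tensorflow as tf


class Worker:
    def __init__(self):
        self._traced_helper = None

    def _helper(self, x):
        return x * x

    def compute(self, x):
        # Trace the helper exactly once, on first call.
        if self._traced_helper is None:
            self._traced_helper = tf.function(
                self._helper,
                autograph=False,
                experimental_relax_shapes=True)
        return self._traced_helper(x)


w = Worker()
print(w.compute(tf.constant(3.0)))  # builds the trace
print(w.compute(tf.constant(4.0)))  # reuses the traced graph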
@ -289,9 +314,20 @@ def build_eager_tf_policy(
worker_idx if worker_idx > 0 else "local"))
self._is_training = False
self._loss_initialized = False
# Only for `config.eager_tracing=True`: A counter to keep track of
# how many times an eager-traced method (e.g.
# `self._compute_actions_helper`) has been re-traced by tensorflow.
# We will raise an error if more than n re-tracings have been
# detected, since this would considerably slow down execution.
# The variable below should only get incremented during the
# tf.function trace operations, never when calling the already
# traced function after that.
self._re_trace_counter = 0
self._loss_initialized = False
self._loss = loss_fn
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
callable(get_batch_divisibility_req) else \
(get_batch_divisibility_req or 1)
@ -376,107 +412,6 @@ def build_eager_tf_policy(
# Got to reset global_timestep again after fake run-throughs.
self.global_timestep = 0
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
assert tf.executing_eagerly()
# Call super's postprocess_trajectory first.
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
if postprocess_fn:
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
return sample_batch
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
# Callback handling.
learn_stats = {}
self.callbacks.on_learn_on_batch(
policy=self,
train_batch=postprocessed_batch,
result=learn_stats)
if not isinstance(postprocessed_batch, SampleBatch) or \
not postprocessed_batch.zero_padded:
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
max_seq_len=self._max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
self._is_training = True
postprocessed_batch.is_training = True
stats = self._learn_on_batch_eager(postprocessed_batch)
stats.update({"custom_metrics": learn_stats})
return stats
@convert_eager_inputs
@convert_eager_outputs
def _learn_on_batch_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
self._apply_gradients(grads_and_vars)
return stats
@override(Policy)
def compute_gradients(self, samples):
pad_batch_to_sequences_of_same_size(
samples,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
self._is_training = True
samples.is_training = True
return self._compute_gradients_eager(samples)
@convert_eager_inputs
@convert_eager_outputs
def _compute_gradients_eager(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
grads = [g for g, v in grads_and_vars]
return grads, stats
@override(Policy)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
# Create input dict to simply pass the entire call to
# self.compute_actions_from_input_dict().
input_dict = SampleBatch(
{
SampleBatch.CUR_OBS: obs_batch,
},
_is_training=tf.constant(False))
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
return self.compute_actions_from_input_dict(
input_dict=input_dict,
explore=explore,
timestep=timestep,
episodes=episodes,
**kwargs,
)
@override(Policy)
def compute_actions_from_input_dict(
self,
@ -502,7 +437,8 @@ def build_eager_tf_policy(
# Pass lazy (eager) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
input_dict.is_training = False
input_dict.set_training(False)
# Pack internal state inputs into (separate) list.
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
@ -514,7 +450,7 @@ def build_eager_tf_policy(
self.exploration.before_compute_actions(
timestep=timestep, explore=explore, tf_sess=self.get_session())
ret = self._compute_action_helper(
ret = self._compute_actions_helper(
input_dict,
state_batches,
# TODO: Passing episodes into a traced method does not work.
@ -525,9 +461,254 @@ def build_eager_tf_policy(
self.global_timestep += int(tree.flatten(ret[0])[0].shape[0])
return convert_to_numpy(ret)
@override(Policy)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
# Create input dict to simply pass the entire call to
# self.compute_actions_from_input_dict().
input_dict = SampleBatch(
{
SampleBatch.CUR_OBS: obs_batch,
},
_is_training=tf.constant(False))
if state_batches is not None:
for i, s in enumerate(state_batches):
input_dict[f"state_in_{i}"] = s
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
if info_batch is not None:
input_dict[SampleBatch.INFOS] = info_batch
return self.compute_actions_from_input_dict(
input_dict=input_dict,
explore=explore,
timestep=timestep,
episodes=episodes,
**kwargs,
)
@with_lock
def _compute_action_helper(self, input_dict, state_batches, episodes,
explore, timestep):
@override(Policy)
def compute_log_likelihoods(
self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
actions_normalized=True,
):
if action_sampler_fn and action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
"`action_distribution_fn` and a provided "
"`action_sampler_fn`!")
seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
input_batch = SampleBatch(
{
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)
},
_is_training=False)
if prev_action_batch is not None:
input_batch[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
if prev_reward_batch is not None:
input_batch[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
# Exploration hook before each forward pass.
self.exploration.before_compute_actions(explore=False)
# Action dist class and inputs are generated via custom function.
if action_distribution_fn:
dist_inputs, dist_class, _ = action_distribution_fn(
self,
self.model,
input_batch,
explore=False,
is_training=False)
# Default log-likelihood calculation.
else:
dist_inputs, _ = self.model(input_batch, state_batches,
seq_lens)
dist_class = self.dist_class
action_dist = dist_class(dist_inputs, self.model)
# Normalize actions if necessary.
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
return log_likelihoods
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
assert tf.executing_eagerly()
# Call super's postprocess_trajectory first.
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
if postprocess_fn:
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
return sample_batch
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
# Callback handling.
learn_stats = {}
self.callbacks.on_learn_on_batch(
policy=self,
train_batch=postprocessed_batch,
result=learn_stats)
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
max_seq_len=self._max_seq_len,
shuffle=False,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
self._is_training = True
postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
postprocessed_batch.set_training(True)
stats = self._learn_on_batch_helper(postprocessed_batch)
stats.update({"custom_metrics": learn_stats})
return convert_to_numpy(stats)
@override(Policy)
def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
Tuple[ModelGradients, Dict[str, TensorType]]:
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
shuffle=False,
max_seq_len=self._max_seq_len,
batch_divisibility_req=self.batch_divisibility_req,
view_requirements=self.view_requirements,
)
self._is_training = True
self._lazy_tensor_dict(postprocessed_batch)
postprocessed_batch.set_training(True)
grads_and_vars, grads, stats = self._compute_gradients_helper(
postprocessed_batch)
return convert_to_numpy((grads, stats))
@override(Policy)
def apply_gradients(self, gradients: ModelGradients) -> None:
self._apply_gradients_helper(
list(
zip([(tf.convert_to_tensor(g)
if g is not None else None) for g in gradients],
self.model.trainable_variables())))
@override(Policy)
def get_weights(self, as_dict=False):
variables = self.variables()
if as_dict:
return {v.name: v.numpy() for v in variables}
return [v.numpy() for v in variables]
@override(Policy)
def set_weights(self, weights):
variables = self.variables()
assert len(weights) == len(variables), (len(weights),
len(variables))
for v, w in zip(variables, weights):
v.assign(w)
@override(Policy)
def get_exploration_state(self):
return convert_to_numpy(self.exploration.get_state())
@override(Policy)
def is_recurrent(self):
return self._is_recurrent
@override(Policy)
def num_state_tensors(self):
return len(self._state_inputs)
@override(Policy)
def get_initial_state(self):
if hasattr(self, "model"):
return self.model.get_initial_state()
return []
@override(Policy)
def get_state(self):
state = super().get_state()
if self._optimizer and \
len(self._optimizer.variables()) > 0:
state["_optimizer_variables"] = \
self._optimizer.variables()
# Add exploration state.
state["_exploration_state"] = self.exploration.get_state()
return state
@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.get("_optimizer_variables", None)
if optimizer_vars and self._optimizer.variables():
logger.warning(
"Cannot restore an optimizer's state for tf eager! Keras "
"is not able to save the v1.x optimizers (from "
"tf.compat.v1.train) since they aren't compatible with "
"checkpoints.")
for opt_var, value in zip(self._optimizer.variables(),
optimizer_vars):
opt_var.assign(value)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
self.exploration.set_state(state=state["_exploration_state"])
# Then the Policy's (NN) weights.
super().set_state(state)
@override(Policy)
def export_checkpoint(self, export_dir):
raise NotImplementedError # TODO: implement this
@override(Policy)
def export_model(self, export_dir):
raise NotImplementedError # TODO: implement this
def variables(self):
"""Return the list of all savable variables for this policy."""
if isinstance(self.model, tf.keras.Model):
return self.model.variables
else:
return self.model.variables()
def loss_initialized(self):
return self._loss_initialized
@with_lock
def _compute_actions_helper(self, input_dict, state_batches, episodes,
explore, timestep):
# Increase the tracing counter to make sure we don't re-trace too
# often. If eager_tracing=True, this counter should only get
# incremented during the @tf.function trace operations, never when
# calling the already traced function after that.
self._re_trace_counter += 1
# Calculate RNN sequence lengths.
batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
@ -612,166 +793,32 @@ def build_eager_tf_policy(
return actions, state_out, extra_fetches
@with_lock
@override(Policy)
def compute_log_likelihoods(
self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
actions_normalized=True,
):
if action_sampler_fn and action_distribution_fn is None:
raise ValueError("Cannot compute log-prob/likelihood w/o an "
"`action_distribution_fn` and a provided "
"`action_sampler_fn`!")
def _learn_on_batch_helper(self, samples):
# Increase the tracing counter to make sure we don't re-trace too
# often. If eager_tracing=True, this counter should only get
# incremented during the @tf.function trace operations, never when
# calling the already traced function after that.
self._re_trace_counter += 1
seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
input_batch = SampleBatch(
{
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)
},
_is_training=False)
if prev_action_batch is not None:
input_batch[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
if prev_reward_batch is not None:
input_batch[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
# Exploration hook before each forward pass.
self.exploration.before_compute_actions(explore=False)
# Action dist class and inputs are generated via custom function.
if action_distribution_fn:
dist_inputs, dist_class, _ = action_distribution_fn(
self,
self.model,
input_batch,
explore=False,
is_training=False)
# Default log-likelihood calculation.
else:
dist_inputs, _ = self.model(input_batch, state_batches,
seq_lens)
dist_class = self.dist_class
action_dist = dist_class(dist_inputs, self.model)
# Normalize actions if necessary.
if not actions_normalized and self.config["normalize_actions"]:
actions = normalize_action(actions, self.action_space_struct)
log_likelihoods = action_dist.logp(actions)
return log_likelihoods
@override(Policy)
def apply_gradients(self, gradients):
self._apply_gradients(
zip([(tf.convert_to_tensor(g) if g is not None else None)
for g in gradients], self.model.trainable_variables()))
@override(Policy)
def get_exploration_state(self):
return convert_to_numpy(self.exploration.get_state())
@override(Policy)
def get_weights(self, as_dict=False):
variables = self.variables()
if as_dict:
return {v.name: v.numpy() for v in variables}
return [v.numpy() for v in variables]
@override(Policy)
def set_weights(self, weights):
variables = self.variables()
assert len(weights) == len(variables), (len(weights),
len(variables))
for v, w in zip(variables, weights):
v.assign(w)
@override(Policy)
def get_state(self):
state = super().get_state()
if self._optimizer and \
len(self._optimizer.variables()) > 0:
state["_optimizer_variables"] = \
self._optimizer.variables()
# Add exploration state.
state["_exploration_state"] = self.exploration.get_state()
return state
@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.get("_optimizer_variables", None)
if optimizer_vars and self._optimizer.variables():
logger.warning(
"Cannot restore an optimizer's state for tf eager! Keras "
"is not able to save the v1.x optimizers (from "
"tf.compat.v1.train) since they aren't compatible with "
"checkpoints.")
for opt_var, value in zip(self._optimizer.variables(),
optimizer_vars):
opt_var.assign(value)
# Set exploration's state.
if hasattr(self, "exploration") and "_exploration_state" in state:
self.exploration.set_state(state=state["_exploration_state"])
# Then the Policy's (NN) weights.
super().set_state(state)
def variables(self):
"""Return the list of all savable variables for this policy."""
if isinstance(self.model, tf.keras.Model):
return self.model.variables
else:
return self.model.variables()
@override(Policy)
def is_recurrent(self):
return self._is_recurrent
@override(Policy)
def num_state_tensors(self):
return len(self._state_inputs)
@override(Policy)
def get_initial_state(self):
if hasattr(self, "model"):
return self.model.get_initial_state()
return []
def get_session(self):
return None # None implies eager
def get_placeholder(self, ph):
raise ValueError(
"get_placeholder() is not allowed in eager mode. Try using "
"rllib.utils.tf_utils.make_tf_callable() to write "
"functions that work in both graph and eager mode.")
def loss_initialized(self):
return self._loss_initialized
@override(Policy)
def export_model(self, export_dir):
pass
@override(Policy)
def export_checkpoint(self, export_dir):
pass
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, _, stats = self._compute_gradients_helper(
samples)
self._apply_gradients_helper(grads_and_vars)
return stats
def _get_is_training_placeholder(self):
return tf.convert_to_tensor(self._is_training)
@with_lock
def _compute_gradients(self, samples):
def _compute_gradients_helper(self, samples):
"""Computes and returns grads as eager tensors."""
# Increase the tracing counter to make sure we don't re-trace too
# often. If eager_tracing=True, this counter should only get
# incremented during the @tf.function trace operations, never when
# calling the already traced function after that.
self._re_trace_counter += 1
# Gather all variables for which to calculate losses.
if isinstance(self.model, tf.keras.Model):
variables = self.model.trainable_variables
@ -823,9 +870,15 @@ def build_eager_tf_policy(
grads = [g for g, _ in grads_and_vars]
stats = self._stats(self, samples, grads)
return grads_and_vars, stats
return grads_and_vars, grads, stats
def _apply_gradients_helper(self, grads_and_vars):
# Increase the tracing counter to make sure we don't re-trace too
# often. If eager_tracing=True, this counter should only get
# incremented during the @tf.function trace operations, never when
# calling the already traced function after that.
self._re_trace_counter += 1
def _apply_gradients(self, grads_and_vars):
if apply_gradients_fn:
if self.config["_tf_policy_handles_more_than_one_loss"]:
apply_gradients_fn(self, self._optimizers, grads_and_vars)

View file

@ -764,6 +764,12 @@ class Policy(metaclass=ABCMeta):
TensorType]]]): An optional stats function to be called after
the loss.
"""
# Signal Policy that currently we do not want to eager/jit trace
# any function calls. This allows us to track which columns
# in the dummy batch are accessed by the different functions (e.g.
# loss) so that we can then adjust our view requirements.
self._no_tracing = True
sample_batch_size = max(self.batch_divisibility_req * 4, 32)
self._dummy_batch = self._get_dummy_batch_from_view_requirements(
sample_batch_size)
@ -771,6 +777,9 @@ class Policy(metaclass=ABCMeta):
actions, state_outs, extra_outs = \
self.compute_actions_from_input_dict(
self._dummy_batch, explore=False)
for key, view_req in self.view_requirements.items():
if key not in self._dummy_batch.accessed_keys:
view_req.used_for_compute_actions = False
# Add all extra action outputs to view requirements (these may be
# filtered out later again, if not needed for postprocessing or loss).
for key, value in extra_outs.items():
@ -806,7 +815,7 @@ class Policy(metaclass=ABCMeta):
# Switch on lazy to-tensor conversion on `postprocessed_batch`.
train_batch = self._lazy_tensor_dict(postprocessed_batch)
# Calling loss, so set `is_training` to True.
train_batch.is_training = True
train_batch.set_training(True)
if seq_lens is not None:
train_batch[SampleBatch.SEQ_LENS] = seq_lens
train_batch.count = self._dummy_batch.count
@ -817,6 +826,9 @@ class Policy(metaclass=ABCMeta):
if stats_fn is not None:
stats_fn(self, train_batch)
# Re-enable tracing.
self._no_tracing = False
# Add new columns automatically to view-reqs.
if auto_remove_unneeded_view_reqs:
# Add those needed for postprocessing and training.
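
A tiny, non-RLlib illustration of the pruning step introduced above: after the un-traced dummy pass, any view-requirement column that was never accessed is marked as unused for action computation.

class ViewReq:
    def __init__(self):
        self.used_for_compute_actions = True


view_requirements = {"obs": ViewReq(), "prev_rewards": ViewReq()}
accessed_keys = {"obs"}  # e.g. collected on the dummy batch during the pass

for key, view_req in view_requirements.items():
    if key not in accessed_keys:
        view_req.used_for_compute_actions = False

assert view_requirements["prev_rewards"].used_for_compute_actions is False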

View file

@ -46,21 +46,22 @@ def pad_batch_to_sequences_of_same_size(
Padding depends on episodes found in batch and `max_seq_len`.
Args:
batch (SampleBatch): The SampleBatch object. All values in here have
batch: The SampleBatch object. All values in here have
the shape [B, ...].
max_seq_len (int): The max. sequence length to use for chopping.
shuffle (bool): Whether to shuffle batch sequences. Shuffle may
max_seq_len: The max. sequence length to use for chopping.
shuffle: Whether to shuffle batch sequences. Shuffle may
be done in-place. This only makes sense if you're further
applying minibatch SGD after getting the outputs.
batch_divisibility_req (int): The int by which the batch dimension
batch_divisibility_req: The int by which the batch dimension
must be divisible.
feature_keys (Optional[List[str]]): An optional list of keys to apply
sequence-chopping to. If None, use all keys in batch that are not
feature_keys: An optional list of keys to apply sequence-chopping
to. If None, use all keys in batch that are not
"state_in/out_"-type keys.
view_requirements (Optional[ViewRequirementsDict]): An optional
Policy ViewRequirements dict to be able to infer whether
e.g. dynamic max'ing should be applied over the seq_lens.
view_requirements: An optional Policy ViewRequirements dict to
be able to infer whether e.g. dynamic max'ing should be
applied over the seq_lens.
"""
# If already zero-padded, skip.
if batch.zero_padded:
return

View file

@ -91,7 +91,7 @@ class SampleBatch(dict):
# Is already right-zero-padded?
self.zero_padded = kwargs.pop("_zero_padded", False)
# Whether this batch is used for training (vs inference).
self.is_training = kwargs.pop("_is_training", None)
self._is_training = kwargs.pop("_is_training", None)
# Call super constructor. This will make the actual data accessible
# by column name (str) via e.g. self["some-col"].
@ -118,8 +118,8 @@ class SampleBatch(dict):
len(seq_lens_) > 0:
self.max_seq_len = max(seq_lens_)
if self.is_training is None:
self.is_training = self.pop("is_training", False)
if self._is_training is None:
self._is_training = self.pop("is_training", False)
lengths = []
copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
@ -271,6 +271,9 @@ class SampleBatch(dict):
)
copy_ = SampleBatch(data)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_
@PublicAPI
@ -501,6 +504,7 @@ class SampleBatch(dict):
return SampleBatch(
data,
seq_lens=seq_lens,
_is_training=self.is_training,
_time_major=self.time_major,
)
else:
@ -731,7 +735,7 @@ class SampleBatch(dict):
old="SampleBatch['is_training']",
new="SampleBatch.is_training",
error=False)
self.is_training = item
self._is_training = item
return
if key not in self:
@ -741,6 +745,20 @@ class SampleBatch(dict):
if key in self.intercepted_values:
self.intercepted_values[key] = item
@property
def is_training(self):
if self.get_interceptor is not None and \
isinstance(self._is_training, bool):
if "_is_training" not in self.intercepted_values:
self.intercepted_values["_is_training"] = \
self.get_interceptor(self._is_training)
return self.intercepted_values["_is_training"]
return self._is_training
def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
self._is_training = training
self.intercepted_values.pop("_is_training", None)
@PublicAPI
def __delitem__(self, key):
self.deleted_keys.add(key)
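
A usage sketch of the new setter vs. the previous direct assignment (SampleBatch import path as in RLlib; not code from this commit):

from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": [[0.1, 0.2], [0.3, 0.4]]})

# `is_training` is now a read-only property backed by `_is_training`;
# use the setter instead of the previous `batch.is_training = True`.
batch.set_training(True)
assert batch.is_training is True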

View file

@ -403,7 +403,7 @@ class TFPolicy(Policy):
assert self.loss_initialized()
# Switch on is_training flag in our batch.
postprocessed_batch.is_training = True
postprocessed_batch.set_training(True)
builder = TFRunBuilder(self.get_session(), "learn_on_batch")
@ -425,7 +425,7 @@ class TFPolicy(Policy):
Tuple[ModelGradients, Dict[str, TensorType]]:
assert self.loss_initialized()
# Switch on is_training flag in our batch.
postprocessed_batch.is_training = True
postprocessed_batch.set_training(True)
builder = TFRunBuilder(self.get_session(), "compute_gradients")
fetches = self._build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
@ -1124,7 +1124,7 @@ class TFPolicy(Policy):
# Mark the batch as "is_training" so the Model can use this
# information.
train_batch.is_training = True
train_batch.set_training(True)
# Build the feed dict from the batch.
feed_dict = {}

View file

@ -256,7 +256,7 @@ class TorchPolicy(Policy):
with torch.no_grad():
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
input_dict.is_training = False
input_dict.set_training(True)
# Pack internal state inputs into (separate) list.
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
@ -424,7 +424,7 @@ class TorchPolicy(Policy):
buffer_index: int = 0,
) -> int:
# Set the is_training flag of the batch.
batch.is_training = True
batch.set_training(True)
# Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
if len(self.devices) == 1 and self.devices[0].type == "cpu":
@ -572,7 +572,7 @@ class TorchPolicy(Policy):
view_requirements=self.view_requirements,
)
postprocessed_batch.is_training = True
postprocessed_batch.set_training(True)
self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
# Do the (maybe parallelized) gradient calculation step.

View file

@ -7,12 +7,13 @@ import random
import re
import time
import tree # pip install dm_tree
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml
import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.tune import run_experiments
jax, _ = try_import_jax()
@ -29,32 +30,37 @@ torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
def framework_iterator(config=None,
frameworks=("tf2", "tf", "tfe", "torch"),
session=False,
with_eager_tracing=False):
def framework_iterator(
config: Optional[PartialTrainerConfigDict] = None,
frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
session: bool = False,
with_eager_tracing: bool = False,
time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
"""An generator that allows for looping through n frameworks for testing.
Provides the correct config entries ("framework") as well
as the correct eager/non-eager contexts for tfe/tf.
Args:
config (Optional[dict]): An optional config dict to alter in place
depending on the iteration.
frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
config: An optional config dict to alter in place depending on the
iteration.
frameworks: A list/tuple of the frameworks to be tested.
Allowed are: "tf2", "tf", "tfe", "torch", and None.
session (bool): If True and only in the tf-case: Enter a tf.Session()
session: If True and only in the tf-case: Enter a tf.Session()
and yield that as second return value (otherwise yield (fw, None)).
Also sets a seed (42) on the session to make the test
deterministic.
with_eager_tracing: Include `eager_tracing=True` in the returned
configs, when framework=[tfe|tf2].
time_iterations: If provided, will write to the given dict (by
framework key) the times in seconds that each (framework's)
iteration takes.
Yields:
str: If enter_session is False:
The current framework ("tf2", "tf", "tfe", "torch") used.
Tuple(str, Union[None,tf.Session]: If enter_session is True:
A tuple of the current fw and the tf.Session if fw="tf".
If `session` is False: The current framework [tf2|tf|tfe|torch] used.
If `session` is True: A tuple consisting of the current framework
string and the tf1.Session (if fw="tf", otherwise None).
"""
config = config or {}
frameworks = [frameworks] if isinstance(frameworks, str) else \
@ -111,12 +117,24 @@ def framework_iterator(config=None,
for tracing in [True, False]:
config["eager_tracing"] = tracing
print(f"framework={fw} (eager-tracing={tracing})")
time_started = time.time()
yield fw if session is False else (fw, sess)
if time_iterations is not None:
time_total = time.time() - time_started
time_iterations[fw + ("+tracing" if tracing else "")] = \
time_total
print(f".. took {time_total}sec")
config["eager_tracing"] = False
# Yield current framework + tf-session (if necessary).
else:
print(f"framework={fw}")
time_started = time.time()
yield fw if session is False else (fw, sess)
if time_iterations is not None:
time_total = time.time() - time_started
time_iterations[fw] = time_total
print(f".. took {time_total}sec")
# Exit any context we may have entered.
if eager_ctx:
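
A usage sketch of the extended `framework_iterator` signature, with `time_iterations` collecting per-framework wall times (not code from this commit):

from ray.rllib.utils.test_utils import framework_iterator

config = {"num_workers": 0}
timings = {}
for fw in framework_iterator(
        config,
        frameworks=("tf2", "tf", "torch"),
        with_eager_tracing=True,
        time_iterations=timings):
    # `config["framework"]` (and, for tf2, `config["eager_tracing"]`) are
    # set by the iterator before each yield.
    print(fw, config.get("eager_tracing", False))
print(timings)  # e.g. {"tf2+tracing": ..., "tf2": ..., "tf": ..., "torch": ...}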