[rllib] Tracing for eager tensorflow policies with tf.function (#5705)

* Added tracing of eager policies with `tf.function`

* lint

* add config option

* add docs

* wip

* tracing now works with a3c

* typo

* none

* file doc

* returns

* syntax error

* syntax error
gehring 2019-09-17 04:44:20 -04:00 committed by Eric Liang
parent f74aaf2619
commit 8903bcd0c3
11 changed files with 204 additions and 37 deletions

@@ -418,9 +418,9 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
Building Policies in TensorFlow Eager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
You can also selectively leverage eager operations within graph mode execution with `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager ops embedded `within a loss function <https://github.com/ray-project/ray/blob/master/rllib/examples/eager_execution.py>`__.
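For orientation, here is a minimal standalone sketch of that pattern (illustrative only, not taken from the linked example; it assumes a TF build with ``tf.function`` and eager execution available):

.. code-block:: python

    import tensorflow as tf

    def debug_print(t):
        # Runs eagerly even when invoked from graph mode, so plain Python
        # tools such as print() can inspect the tensor value.
        print("intermediate value:", t.numpy())
        return t

    @tf.function  # traced, graph-mode function
    def loss_fn(x):
        loss = tf.reduce_mean(tf.square(x))
        # Embed the eager op inside the traced graph via tf.py_function.
        loss = tf.py_function(debug_print, [loss], Tout=loss.dtype)
        return loss

    print(loss_fn(tf.constant([1.0, 2.0, 3.0])))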

@@ -14,7 +14,7 @@ You can train a simple DQN trainer with the following command:
.. code-block:: bash
rllib train --run DQN --env CartPole-v0 # add --eager for eager execution
rllib train --run DQN --env CartPole-v0 # --eager [--trace] for eager execution
By default, the results will be logged to a subdirectory of ``~/ray_results``.
This subdirectory will contain a file ``params.json`` which contains the
@@ -544,9 +544,9 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res
Eager Mode
~~~~~~~~~~
Policies built with ``build_tf_policy`` can be also run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
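As an illustrative sketch, the same options can also be set programmatically through Tune (the ``"eager"`` and ``"eager_tracing"`` keys are the ones introduced by this change; the rest of the snippet is an assumed minimal setup):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        stop={"training_iteration": 2},
        config={
            "env": "CartPole-v0",
            "eager": True,          # execute the TF policy eagerly
            "eager_tracing": True,  # wrap policy methods in tf.function
        },
    )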
Episode Traces
~~~~~~~~~~~~~~

@@ -25,7 +25,7 @@ Then, you can try out training in the following equivalent ways:
.. code-block:: bash
rllib train --run=PPO --env=CartPole-v0 # add --eager for eager execution
rllib train --run=PPO --env=CartPole-v0 # --eager [--trace] for eager execution
.. code-block:: python

@@ -70,8 +70,12 @@ COMMON_CONFIG = {
"ignore_worker_failures": False,
# Log system resource metrics to results.
"log_sys_usage": True,
# Enable TF eager execution (TF policies only)
# Enable TF eager execution (TF policies only).
"eager": False,
# Enable tracing in eager mode. This greatly improves performance, but
# makes it slightly harder to debug since Python code won't be evaluated
# after the initial eager pass.
"eager_tracing": False,
# Disable eager execution on workers (but allow it on the driver). This
# only has an effect if eager is enabled.
"no_eager_on_workers": False,
@@ -333,7 +337,8 @@ class Trainer(Trainable):
if tf and config.get("eager"):
tf.enable_eager_execution()
logger.info("Executing eagerly")
logger.info("Executing eagerly, with eager_tracing={}".format(
"True" if config.get("eager_tracing") else "False"))
if tf and not tf.executing_eagerly():
logger.info("Tip: set 'eager': true or the --eager flag to enable "

@@ -752,6 +752,8 @@ class RolloutWorker(EvaluatorInterface):
if tf and tf.executing_eagerly():
if hasattr(cls, "as_eager"):
cls = cls.as_eager()
if policy_config["eager_tracing"]:
cls = cls.with_tracing()
elif not issubclass(cls, TFPolicy):
pass # could be some other type of policy
else:

@@ -21,14 +21,14 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
return -tf.reduce_mean(
action_dist.logp(train_batch["actions"]) * train_batch["advantages"])
action_dist.logp(train_batch["actions"]) * train_batch["returns"])
def calculate_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
return sample_batch
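For context, ``discount`` in this example turns the reward sequence into discounted returns, i.e. ``returns[t] = rewards[t] + gamma * returns[t+1]``. RLlib ships its own helper for this; the version below is only an illustrative sketch of what such a function computes (it assumes SciPy is available):

.. code-block:: python

    import scipy.signal

    def discount(rewards, gamma):
        # returns[t] = rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...
        return scipy.signal.lfilter(
            [1], [1, -gamma], rewards[::-1], axis=0)[::-1]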

@@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Graph mode TF policy built using build_tf_policy()."""
from collections import OrderedDict
import logging

@@ -1,8 +1,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Eager mode TF policy built using build_tf_policy().
It supports both traced and non-traced eager execution modes."""
import logging
import functools
import numpy as np
from ray.rllib.evaluation.episode import _flatten_action
@@ -19,6 +23,56 @@ tf = try_import_tf()
logger = logging.getLogger(__name__)
def _convert_to_tf(x):
if isinstance(x, SampleBatch):
x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
return tf.nest.map_structure(_convert_to_tf, x)
if isinstance(x, Policy):
return x
if x is not None:
x = tf.nest.map_structure(tf.convert_to_tensor, x)
return x
def _convert_to_numpy(x):
if x is None:
return None
try:
return x.numpy()
except AttributeError:
raise TypeError(
("Object of type {} has no method to convert to numpy.").format(
type(x)))
def convert_eager_inputs(func):
@functools.wraps(func)
def _func(*args, **kwargs):
if tf.executing_eagerly():
args = [_convert_to_tf(x) for x in args]
# TODO(gehring): find a way to remove specific hacks
kwargs = {
k: _convert_to_tf(v)
for k, v in kwargs.items()
if k not in {"info_batch", "episodes"}
}
return func(*args, **kwargs)
return _func
def convert_eager_outputs(func):
@functools.wraps(func)
def _func(*args, **kwargs):
out = func(*args, **kwargs)
if tf.executing_eagerly():
out = tf.nest.map_structure(_convert_to_numpy, out)
return out
return _func
def _disallow_var_creation(next_creator, **kw):
v = next_creator(**kw)
raise ValueError("Detected a variable being created during an eager "
@@ -26,6 +80,86 @@ def _disallow_var_creation(next_creator, **kw):
"model initialization: {}".format(v.name))
def traced_eager_policy(eager_policy_cls):
"""Wrapper that enables tracing for all eager policy methods.
This is enabled by the --trace / "eager_tracing" config."""
class TracedEagerPolicy(eager_policy_cls):
def __init__(self, *args, **kwargs):
self._traced_learn_on_batch = None
self._traced_compute_actions = None
self._traced_compute_gradients = None
self._traced_apply_gradients = None
super(TracedEagerPolicy, self).__init__(*args, **kwargs)
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
if self._traced_learn_on_batch is None:
self._traced_learn_on_batch = tf.function(
super(TracedEagerPolicy, self).learn_on_batch,
autograph=False)
return self._traced_learn_on_batch(samples)
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def compute_actions(self,
obs_batch,
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
obs_batch = tf.convert_to_tensor(obs_batch)
state_batches = _convert_to_tf(state_batches)
prev_action_batch = _convert_to_tf(prev_action_batch)
prev_reward_batch = _convert_to_tf(prev_reward_batch)
if self._traced_compute_actions is None:
self._traced_compute_actions = tf.function(
super(TracedEagerPolicy, self).compute_actions,
autograph=False)
return self._traced_compute_actions(
obs_batch, state_batches, prev_action_batch, prev_reward_batch,
info_batch, episodes, **kwargs)
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
if self._traced_compute_gradients is None:
self._traced_compute_gradients = tf.function(
super(TracedEagerPolicy, self).compute_gradients,
autograph=False)
return self._traced_compute_gradients(samples)
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def apply_gradients(self, grads):
if self._traced_apply_gradients is None:
self._traced_apply_gradients = tf.function(
super(TracedEagerPolicy, self).apply_gradients,
autograph=False)
return self._traced_apply_gradients(grads)
TracedEagerPolicy.__name__ = eager_policy_cls.__name__
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
return TracedEagerPolicy
def build_eager_tf_policy(name,
loss_fn,
get_default_config=None,
@@ -133,6 +267,8 @@ def build_eager_tf_policy(name,
return samples
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def learn_on_batch(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
@@ -140,14 +276,17 @@ def build_eager_tf_policy(name,
return stats
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def compute_gradients(self, samples):
with tf.variable_creator_scope(_disallow_var_creation):
grads_and_vars, stats = self._compute_gradients(samples)
grads = [g for g, v in grads_and_vars]
grads = [(g.numpy() if g is not None else None) for g in grads]
return grads, stats
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
def compute_actions(self,
obs_batch,
state_batches,
@@ -157,41 +296,46 @@ def build_eager_tf_policy(name,
episodes=None,
**kwargs):
assert tf.executing_eagerly()
# TODO: remove python side effect to cull sources of bugs.
self._is_training = False
self._state_in = state_batches
self._seq_lens = tf.ones(len(obs_batch))
self._input_dict = {
if tf.executing_eagerly():
n = len(obs_batch)
else:
n = obs_batch.shape[0]
seq_lens = tf.ones(n)
input_dict = {
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
"is_training": tf.convert_to_tensor(False),
"is_training": tf.constant(False),
}
if obs_include_prev_action_reward:
self._input_dict.update({
input_dict.update({
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
prev_action_batch),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
prev_reward_batch),
})
self._state_in = state_batches
with tf.variable_creator_scope(_disallow_var_creation):
model_out, state_out = self.model(
self._input_dict, state_batches, self._seq_lens)
model_out, state_out = self.model(input_dict, state_batches,
seq_lens)
if self.dist_class:
action_dist = self.dist_class(model_out, self.model)
action = action_dist.sample().numpy()
action = action_dist.sample()
logp = action_dist.sampled_action_logp()
else:
action, logp = action_sampler_fn(
self, self.model, self._input_dict, self.observation_space,
self, self.model, input_dict, self.observation_space,
self.action_space, self.config)
action = action.numpy()
fetches = {}
if logp is not None:
fetches.update({
ACTION_PROB: tf.exp(logp).numpy(),
ACTION_LOGP: logp.numpy(),
ACTION_PROB: tf.exp(logp),
ACTION_LOGP: logp,
})
if extra_action_fetches_fn:
fetches.update(extra_action_fetches_fn(self))
@@ -248,14 +392,9 @@ def build_eager_tf_policy(name,
self._is_training = True
samples = {
k: tf.convert_to_tensor(v)
for k, v in samples.items() if v.dtype != np.object
}
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
# TODO: set seq len and state in properly
self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
self._seq_lens = tf.ones(samples[SampleBatch.CUR_OBS].shape[0])
self._state_in = []
model_out, _ = self.model(samples, self._state_in,
self._seq_lens)
@@ -288,23 +427,22 @@ def build_eager_tf_policy(name,
return grads_and_vars, stats
def _stats(self, outputs, samples, grads):
assert tf.executing_eagerly()
fetches = {}
if stats_fn:
fetches[LEARNER_STATS_KEY] = {
k: v.numpy()
k: v
for k, v in stats_fn(outputs, samples).items()
}
else:
fetches[LEARNER_STATS_KEY] = {}
if extra_learn_fetches_fn:
fetches.update({
k: v.numpy()
for k, v in extra_learn_fetches_fn(self).items()
})
fetches.update(
{k: v
for k, v in extra_learn_fetches_fn(self).items()})
if grad_stats_fn:
fetches.update({
k: v.numpy()
k: v
for k, v in grad_stats_fn(self, samples, grads).items()
})
return fetches
@@ -380,6 +518,10 @@ def build_eager_tf_policy(name,
if stats_fn:
stats_fn(self, postprocessed_batch)
@classmethod
def with_tracing(cls):
return traced_eager_policy(cls)
eager_policy_cls.__name__ = name + "_eager"
eager_policy_cls.__qualname__ = name + "_eager"
return eager_policy_cls
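The ``with_tracing()`` hook above returns ``traced_eager_policy(cls)``, whose methods build a ``tf.function`` wrapper on first call and reuse it afterwards, so the Python method body only runs during the initial trace. A minimal standalone sketch of that lazy-tracing pattern (illustrative only, not RLlib code):

.. code-block:: python

    import tensorflow as tf

    class LazyTraced(object):
        def __init__(self, fn):
            self._fn = fn
            self._traced = None  # built lazily, mirroring the _traced_* attributes above

        def __call__(self, *args):
            if self._traced is None:
                self._traced = tf.function(self._fn, autograph=False)
            return self._traced(*args)

    square = LazyTraced(lambda x: x * x)
    # First call traces the function; later calls reuse the compiled graph.
    print(square(tf.constant(2.0)))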

@@ -184,6 +184,10 @@ class SampleBatch(object):
def items(self):
return self.data.items()
@PublicAPI
def get(self, key):
return self.data.get(key)
@PublicAPI
def __getitem__(self, key):
return self.data[key]
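A short usage sketch of the new accessor (the import path is assumed; ``get`` mirrors ``dict.get`` and returns ``None`` for missing keys, whereas ``__getitem__`` raises ``KeyError``):

.. code-block:: python

    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({"obs": [1, 2, 3], "rewards": [0.0, 0.0, 1.0]})
    print(batch["obs"])             # the stored "obs" column
    print(batch.get("advantages"))  # None, since the key is absent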

@@ -12,6 +12,11 @@ def check_support(alg, config):
else:
config["env"] = "CartPole-v0"
a = get_agent_class(alg)
config["eager_tracing"] = False
tune.run(a, config=config, stop={"training_iteration": 0})
config["eager_tracing"] = True
tune.run(a, config=config, stop={"training_iteration": 0})

@@ -96,6 +96,10 @@ def create_parser(parser_creator=None):
"--eager",
action="store_true",
help="Whether to attempt to enable TF eager execution.")
parser.add_argument(
"--trace",
action="store_true",
help="Whether to attempt to enable tracing for eager mode.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
@@ -146,6 +150,10 @@ def run(args, parser):
parser.error("the following arguments are required: --env")
if args.eager:
exp["config"]["eager"] = True
if args.trace:
if not exp["config"].get("eager"):
raise ValueError("Must enable --eager to enable tracing.")
exp["config"]["eager_tracing"] = True
if args.ray_num_nodes:
cluster = Cluster()