From 8903bcd0c325f76f2642eb542140bdde5a94f7ac Mon Sep 17 00:00:00 2001
From: gehring
Date: Tue, 17 Sep 2019 04:44:20 -0400
Subject: [PATCH] [rllib] Tracing for eager tensorflow policies with `tf.function` (#5705)

* Added tracing of eager policies with `tf.function`

* lint

* add config option

* add docs

* wip

* tracing now works with a3c

* typo

* none

* file doc

* returns

* syntax error

* syntax error
---
 doc/source/rllib-concepts.rst      |   4 +-
 doc/source/rllib-training.rst      |   6 +-
 doc/source/rllib.rst               |   2 +-
 rllib/agents/trainer.py            |   9 +-
 rllib/evaluation/rollout_worker.py |   2 +
 rllib/examples/custom_tf_policy.py |   4 +-
 rllib/policy/dynamic_tf_policy.py  |   1 +
 rllib/policy/eager_tf_policy.py    | 196 +++++++++++++++++++++++++----
 rllib/policy/sample_batch.py       |   4 +
 rllib/tests/test_eager_support.py  |   5 +
 rllib/train.py                     |   8 ++
 11 files changed, 204 insertions(+), 37 deletions(-)

diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst
index fc18ae056..47a13ec33 100644
--- a/doc/source/rllib-concepts.rst
+++ b/doc/source/rllib-concepts.rst
@@ -418,9 +418,9 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
 Building Policies in TensorFlow Eager
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 You can also selectively leverage eager operations within graph mode execution with `tf.py_function `__. Here's an example of using eager ops embedded `within a loss function `__.
 
diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst
index ad8d5201e..973aaa889 100644
--- a/doc/source/rllib-training.rst
+++ b/doc/source/rllib-training.rst
@@ -14,7 +14,7 @@ You can train a simple DQN trainer with the following command:
 
 .. code-block:: bash
 
-    rllib train --run DQN --env CartPole-v0  # add --eager for eager execution
+    rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution
 
 By default, the results will be logged to a subdirectory of ``~/ray_results``.
 This subdirectory will contain a file ``params.json`` which contains the
@@ -544,9 +544,9 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res
 Eager Mode
 ~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` can be also run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 Episode Traces
 ~~~~~~~~~~~~~~
diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst
index 6dedc4a67..bdc47a108 100644
--- a/doc/source/rllib.rst
+++ b/doc/source/rllib.rst
@@ -25,7 +25,7 @@ Then, you can try out training in the following equivalent ways:
 
 .. code-block:: bash
 
-    rllib train --run=PPO --env=CartPole-v0  # add --eager for eager execution
+    rllib train --run=PPO --env=CartPole-v0  # --eager [--trace] for eager execution
 
 .. code-block:: python
 
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 9d5cf84f0..a73938bbd 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -70,8 +70,12 @@ COMMON_CONFIG = {
     "ignore_worker_failures": False,
     # Log system resource metrics to results.
     "log_sys_usage": True,
-    # Enable TF eager execution (TF policies only)
+    # Enable TF eager execution (TF policies only).
     "eager": False,
+    # Enable tracing in eager mode. This greatly improves performance, but
+    # makes it slightly harder to debug since Python code won't be evaluated
+    # after the initial eager pass.
+    "eager_tracing": False,
     # Disable eager execution on workers (but allow it on the driver). This
     # only has an effect if eager is enabled.
"no_eager_on_workers": False, @@ -333,7 +337,8 @@ class Trainer(Trainable): if tf and config.get("eager"): tf.enable_eager_execution() - logger.info("Executing eagerly") + logger.info("Executing eagerly, with eager_tracing={}".format( + "True" if config.get("eager_tracing") else "False")) if tf and not tf.executing_eagerly(): logger.info("Tip: set 'eager': true or the --eager flag to enable " diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index d42fb0d3b..aa04ae274 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -752,6 +752,8 @@ class RolloutWorker(EvaluatorInterface): if tf and tf.executing_eagerly(): if hasattr(cls, "as_eager"): cls = cls.as_eager() + if policy_config["eager_tracing"]: + cls = cls.with_tracing() elif not issubclass(cls, TFPolicy): pass # could be some other type of policy else: diff --git a/rllib/examples/custom_tf_policy.py b/rllib/examples/custom_tf_policy.py index fbde9201f..abeddd187 100644 --- a/rllib/examples/custom_tf_policy.py +++ b/rllib/examples/custom_tf_policy.py @@ -21,14 +21,14 @@ def policy_gradient_loss(policy, model, dist_class, train_batch): logits, _ = model.from_batch(train_batch) action_dist = dist_class(logits, model) return -tf.reduce_mean( - action_dist.logp(train_batch["actions"]) * train_batch["advantages"]) + action_dist.logp(train_batch["actions"]) * train_batch["returns"]) def calculate_advantages(policy, sample_batch, other_agent_batches=None, episode=None): - sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99) + sample_batch["returns"] = discount(sample_batch["rewards"], 0.99) return sample_batch diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 7911a5ae6..3a98e08f1 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +"""Graph mode TF policy built using build_tf_policy().""" from collections import OrderedDict import logging diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index b8150d3c4..63b96378e 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -1,8 +1,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +"""Eager mode TF policy built using build_tf_policy(). 
+
+It supports both traced and non-traced eager execution modes."""
 
 import logging
+import functools
 import numpy as np
 
 from ray.rllib.evaluation.episode import _flatten_action
@@ -19,6 +23,56 @@ tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
 
+def _convert_to_tf(x):
+    if isinstance(x, SampleBatch):
+        x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
+        return tf.nest.map_structure(_convert_to_tf, x)
+    if isinstance(x, Policy):
+        return x
+
+    if x is not None:
+        x = tf.nest.map_structure(tf.convert_to_tensor, x)
+    return x
+
+
+def _convert_to_numpy(x):
+    if x is None:
+        return None
+    try:
+        return x.numpy()
+    except AttributeError:
+        raise TypeError(
+            ("Object of type {} has no method to convert to numpy.").format(
+                type(x)))
+
+
+def convert_eager_inputs(func):
+    @functools.wraps(func)
+    def _func(*args, **kwargs):
+        if tf.executing_eagerly():
+            args = [_convert_to_tf(x) for x in args]
+            # TODO(gehring): find a way to remove specific hacks
+            kwargs = {
+                k: _convert_to_tf(v)
+                for k, v in kwargs.items()
+                if k not in {"info_batch", "episodes"}
+            }
+        return func(*args, **kwargs)
+
+    return _func
+
+
+def convert_eager_outputs(func):
+    @functools.wraps(func)
+    def _func(*args, **kwargs):
+        out = func(*args, **kwargs)
+        if tf.executing_eagerly():
+            out = tf.nest.map_structure(_convert_to_numpy, out)
+        return out
+
+    return _func
+
+
 def _disallow_var_creation(next_creator, **kw):
     v = next_creator(**kw)
     raise ValueError("Detected a variable being created during an eager "
@@ -26,6 +80,86 @@ def _disallow_var_creation(next_creator, **kw):
                      "model initialization: {}".format(v.name))
 
 
+def traced_eager_policy(eager_policy_cls):
+    """Wrapper that enables tracing for all eager policy methods.
+
+    This is enabled by the --trace / "eager_tracing" config."""
+
+    class TracedEagerPolicy(eager_policy_cls):
+        def __init__(self, *args, **kwargs):
+            self._traced_learn_on_batch = None
+            self._traced_compute_actions = None
+            self._traced_compute_gradients = None
+            self._traced_apply_gradients = None
+            super(TracedEagerPolicy, self).__init__(*args, **kwargs)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def learn_on_batch(self, samples):
+
+            if self._traced_learn_on_batch is None:
+                self._traced_learn_on_batch = tf.function(
+                    super(TracedEagerPolicy, self).learn_on_batch,
+                    autograph=False)
+
+            return self._traced_learn_on_batch(samples)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def compute_actions(self,
+                            obs_batch,
+                            state_batches,
+                            prev_action_batch=None,
+                            prev_reward_batch=None,
+                            info_batch=None,
+                            episodes=None,
+                            **kwargs):
+
+            obs_batch = tf.convert_to_tensor(obs_batch)
+            state_batches = _convert_to_tf(state_batches)
+            prev_action_batch = _convert_to_tf(prev_action_batch)
+            prev_reward_batch = _convert_to_tf(prev_reward_batch)
+
+            if self._traced_compute_actions is None:
+                self._traced_compute_actions = tf.function(
+                    super(TracedEagerPolicy, self).compute_actions,
+                    autograph=False)
+
+            return self._traced_compute_actions(
+                obs_batch, state_batches, prev_action_batch, prev_reward_batch,
+                info_batch, episodes, **kwargs)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def compute_gradients(self, samples):
+
+            if self._traced_compute_gradients is None:
+                self._traced_compute_gradients = tf.function(
+                    super(TracedEagerPolicy, self).compute_gradients,
+                    autograph=False)
+
+            return self._traced_compute_gradients(samples)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def apply_gradients(self, grads):
+
+            if self._traced_apply_gradients is None:
+                self._traced_apply_gradients = tf.function(
+                    super(TracedEagerPolicy, self).apply_gradients,
+                    autograph=False)
+
+            return self._traced_apply_gradients(grads)
+
+    TracedEagerPolicy.__name__ = eager_policy_cls.__name__
+    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
+    return TracedEagerPolicy
+
+
 def build_eager_tf_policy(name,
                           loss_fn,
                           get_default_config=None,
@@ -133,6 +267,8 @@ def build_eager_tf_policy(name,
             return samples
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def learn_on_batch(self, samples):
             with tf.variable_creator_scope(_disallow_var_creation):
                 grads_and_vars, stats = self._compute_gradients(samples)
@@ -140,14 +276,17 @@ def build_eager_tf_policy(name,
             return stats
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def compute_gradients(self, samples):
             with tf.variable_creator_scope(_disallow_var_creation):
                 grads_and_vars, stats = self._compute_gradients(samples)
             grads = [g for g, v in grads_and_vars]
-            grads = [(g.numpy() if g is not None else None) for g in grads]
             return grads, stats
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def compute_actions(self,
                             obs_batch,
                             state_batches,
@@ -157,41 +296,46 @@ def build_eager_tf_policy(name,
                             episodes=None,
                             **kwargs):
 
-            assert tf.executing_eagerly()
+            # TODO: remove python side effect to cull sources of bugs.
             self._is_training = False
+            self._state_in = state_batches
 
-            self._seq_lens = tf.ones(len(obs_batch))
-            self._input_dict = {
+            if tf.executing_eagerly():
+                n = len(obs_batch)
+            else:
+                n = obs_batch.shape[0]
+
+            seq_lens = tf.ones(n)
+            input_dict = {
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
-                "is_training": tf.convert_to_tensor(False),
+                "is_training": tf.constant(False),
             }
             if obs_include_prev_action_reward:
-                self._input_dict.update({
+                input_dict.update({
                     SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                         prev_action_batch),
                     SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                         prev_reward_batch),
                 })
-            self._state_in = state_batches
+
             with tf.variable_creator_scope(_disallow_var_creation):
-                model_out, state_out = self.model(
-                    self._input_dict, state_batches, self._seq_lens)
+                model_out, state_out = self.model(input_dict, state_batches,
+                                                  seq_lens)
 
             if self.dist_class:
                 action_dist = self.dist_class(model_out, self.model)
-                action = action_dist.sample().numpy()
+                action = action_dist.sample()
                 logp = action_dist.sampled_action_logp()
             else:
                 action, logp = action_sampler_fn(
-                    self, self.model, self._input_dict, self.observation_space,
+                    self, self.model, input_dict, self.observation_space,
                     self.action_space, self.config)
-                action = action.numpy()
 
             fetches = {}
             if logp is not None:
                 fetches.update({
-                    ACTION_PROB: tf.exp(logp).numpy(),
-                    ACTION_LOGP: logp.numpy(),
+                    ACTION_PROB: tf.exp(logp),
+                    ACTION_LOGP: logp,
                 })
             if extra_action_fetches_fn:
                 fetches.update(extra_action_fetches_fn(self))
@@ -248,14 +392,9 @@ def build_eager_tf_policy(name,
 
             self._is_training = True
 
-            samples = {
-                k: tf.convert_to_tensor(v)
-                for k, v in samples.items() if v.dtype != np.object
-            }
-
             with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                 # TODO: set seq len and state in properly
-                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
+                self._seq_lens = tf.ones(samples[SampleBatch.CUR_OBS].shape[0])
                 self._state_in = []
                 model_out, _ = self.model(samples, self._state_in,
                                           self._seq_lens)
@@ -288,23 +427,22 @@ def build_eager_tf_policy(name,
             return grads_and_vars, stats
 
         def _stats(self, outputs, samples, grads):
-            assert tf.executing_eagerly()
+
             fetches = {}
             if stats_fn:
                 fetches[LEARNER_STATS_KEY] = {
-                    k: v.numpy()
+                    k: v
                     for k, v in stats_fn(outputs, samples).items()
                 }
             else:
                 fetches[LEARNER_STATS_KEY] = {}
             if extra_learn_fetches_fn:
-                fetches.update({
-                    k: v.numpy()
-                    for k, v in extra_learn_fetches_fn(self).items()
-                })
+                fetches.update(
+                    {k: v
+                     for k, v in extra_learn_fetches_fn(self).items()})
             if grad_stats_fn:
                 fetches.update({
-                    k: v.numpy()
+                    k: v
                     for k, v in grad_stats_fn(self, samples, grads).items()
                 })
             return fetches
@@ -380,6 +518,10 @@ def build_eager_tf_policy(name,
             if stats_fn:
                 stats_fn(self, postprocessed_batch)
 
+        @classmethod
+        def with_tracing(cls):
+            return traced_eager_policy(cls)
+
     eager_policy_cls.__name__ = name + "_eager"
     eager_policy_cls.__qualname__ = name + "_eager"
     return eager_policy_cls
diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py
index 1f36ed6a2..fbd4b7e37 100644
--- a/rllib/policy/sample_batch.py
+++ b/rllib/policy/sample_batch.py
@@ -184,6 +184,10 @@ class SampleBatch(object):
     def items(self):
         return self.data.items()
 
+    @PublicAPI
+    def get(self, key):
+        return self.data.get(key)
+
     @PublicAPI
     def __getitem__(self, key):
         return self.data[key]
diff --git a/rllib/tests/test_eager_support.py b/rllib/tests/test_eager_support.py
index 359926a2a..c7d0641d5 100644
--- a/rllib/tests/test_eager_support.py
+++ b/rllib/tests/test_eager_support.py
@@ -12,6 +12,11 @@ def check_support(alg, config):
     else:
         config["env"] = "CartPole-v0"
     a = get_agent_class(alg)
+
+    config["eager_tracing"] = False
+    tune.run(a, config=config, stop={"training_iteration": 0})
+
+    config["eager_tracing"] = True
     tune.run(a, config=config, stop={"training_iteration": 0})
 
 
diff --git a/rllib/train.py b/rllib/train.py
index 8e6b2b3bb..2e717d863 100755
--- a/rllib/train.py
+++ b/rllib/train.py
@@ -96,6 +96,10 @@ def create_parser(parser_creator=None):
         "--eager",
         action="store_true",
         help="Whether to attempt to enable TF eager execution.")
+    parser.add_argument(
+        "--trace",
+        action="store_true",
+        help="Whether to attempt to enable tracing for eager mode.")
     parser.add_argument(
         "--env", default=None, type=str, help="The gym environment to use.")
     parser.add_argument(
@@ -146,6 +150,10 @@ def run(args, parser):
         parser.error("the following arguments are required: --env")
     if args.eager:
         exp["config"]["eager"] = True
+    if args.trace:
+        if not exp["config"].get("eager"):
+            raise ValueError("Must enable --eager to enable tracing.")
+        exp["config"]["eager_tracing"] = True
     if args.ray_num_nodes:
         cluster = Cluster()
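
The following usage sketch is not part of the patch; it mirrors the ``tune.run`` pattern used in ``rllib/tests/test_eager_support.py`` above and shows the new config options end to end. The algorithm, environment, and stopping criterion are arbitrary illustration choices; the CLI equivalent is ``rllib train --run PPO --env CartPole-v0 --eager [--trace]``.

.. code-block:: python

    import ray
    from ray import tune

    if __name__ == "__main__":
        ray.init()

        # Plain eager execution: easiest to debug (print() works on
        # intermediate tensors), but slower since nothing is compiled.
        tune.run(
            "PPO",
            config={"env": "CartPole-v0", "eager": True,
                    "eager_tracing": False},
            stop={"training_iteration": 1})

        # Eager execution with tracing: policy methods are wrapped in
        # tf.function on first use (see traced_eager_policy above), which
        # recovers most of graph-mode speed but means Python side effects
        # such as print() only run during the initial trace.
        tune.run(
            "PPO",
            config={"env": "CartPole-v0", "eager": True,
                    "eager_tracing": True},
            stop={"training_iteration": 1})

Programmatically, setting ``eager_tracing`` makes ``RolloutWorker`` call ``cls.with_tracing()`` on the eager policy class, i.e. it applies the ``traced_eager_policy`` wrapper defined in this patch.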