[rllib] Tracing for eager tensorflow policies with tf.function (#5705)
* Added tracing of eager policies with `tf.function`
* lint
* add config option
* add docs
* wip
* tracing now works with a3c
* typo
* none
* file doc
* returns
* syntax error
* syntax error
This commit is contained in:
parent f74aaf2619
commit 8903bcd0c3
11 changed files with 204 additions and 37 deletions
@@ -418,9 +418,9 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor

Building Policies in TensorFlow Eager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.

Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
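For reference, the config-based equivalent of the CLI flags above looks roughly like this (a minimal sketch; the choice of PPO and CartPole-v0 is illustrative and not part of this change):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        config={
            "env": "CartPole-v0",
            # Run TF policies eagerly; "eager_tracing" additionally wraps the
            # eager policy methods in tf.function for speed.
            "eager": True,
            "eager_tracing": True,
        },
        stop={"training_iteration": 1},
    )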

You can also selectively leverage eager operations within graph mode execution with `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager ops embedded `within a loss function <https://github.com/ray-project/ray/blob/master/rllib/examples/eager_execution.py>`__.
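A hedged sketch of the ``tf.py_function`` pattern mentioned above (this toy loss is illustrative and is not the linked example):

.. code-block:: python

    import tensorflow as tf

    def loss_with_eager_print(labels, predictions):
        loss = tf.reduce_mean(tf.square(labels - predictions))

        def eager_hook(x):
            # Runs eagerly even inside a graph, so a plain print() works here.
            print("current loss:", x.numpy())
            return x

        # Embed the eager hook in graph-mode execution.
        return tf.py_function(eager_hook, [loss], tf.float32)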
@@ -14,7 +14,7 @@ You can train a simple DQN trainer with the following command:

.. code-block:: bash

    rllib train --run DQN --env CartPole-v0  # add --eager for eager execution
    rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution

By default, the results will be logged to a subdirectory of ``~/ray_results``.
This subdirectory will contain a file ``params.json`` which contains the
@@ -544,9 +544,9 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res

Eager Mode
~~~~~~~~~~

Policies built with ``build_tf_policy`` can be also run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.

Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
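A minimal sketch of the debugging benefit described above (the policy name and loss function are made up for illustration; ``build_tf_policy`` is the real RLlib helper):

.. code-block:: python

    import tensorflow as tf
    from ray.rllib.policy.tf_policy_template import build_tf_policy

    def debug_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        loss = -tf.reduce_mean(
            action_dist.logp(train_batch["actions"]) *
            train_batch["advantages"])
        if tf.executing_eagerly():
            # With "eager": True this prints a concrete value on every update.
            print("loss:", loss.numpy())
        return loss

    MyTFPolicy = build_tf_policy(name="MyTFPolicy", loss_fn=debug_loss)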

Episode Traces
~~~~~~~~~~~~~~
@@ -25,7 +25,7 @@ Then, you can try out training in the following equivalent ways:

.. code-block:: bash

    rllib train --run=PPO --env=CartPole-v0  # add --eager for eager execution
    rllib train --run=PPO --env=CartPole-v0  # --eager [--trace] for eager execution

.. code-block:: python
@@ -70,8 +70,12 @@ COMMON_CONFIG = {
    "ignore_worker_failures": False,
    # Log system resource metrics to results.
    "log_sys_usage": True,
    # Enable TF eager execution (TF policies only)
    # Enable TF eager execution (TF policies only).
    "eager": False,
    # Enable tracing in eager mode. This greatly improves performance, but
    # makes it slightly harder to debug since Python code won't be evaluated
    # after the initial eager pass.
    "eager_tracing": False,
    # Disable eager execution on workers (but allow it on the driver). This
    # only has an effect if eager is enabled.
    "no_eager_on_workers": False,
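A hedged sketch of setting these new config keys on a concrete trainer (PPOTrainer and CartPole-v0 are used only as examples here):

.. code-block:: python

    from ray.rllib.agents.ppo import PPOTrainer

    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "eager": True,
            # Trace eager policy methods with tf.function for speed.
            "eager_tracing": True,
            # Keep remote workers in graph mode; only the driver runs eagerly.
            "no_eager_on_workers": True,
        })
    print(trainer.train()["episode_reward_mean"])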
@@ -333,7 +337,8 @@ class Trainer(Trainable):

        if tf and config.get("eager"):
            tf.enable_eager_execution()
            logger.info("Executing eagerly")
            logger.info("Executing eagerly, with eager_tracing={}".format(
                "True" if config.get("eager_tracing") else "False"))

        if tf and not tf.executing_eagerly():
            logger.info("Tip: set 'eager': true or the --eager flag to enable "
@@ -752,6 +752,8 @@ class RolloutWorker(EvaluatorInterface):
            if tf and tf.executing_eagerly():
                if hasattr(cls, "as_eager"):
                    cls = cls.as_eager()
                    if policy_config["eager_tracing"]:
                        cls = cls.with_tracing()
                elif not issubclass(cls, TFPolicy):
                    pass  # could be some other type of policy
                else:
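To make the control flow above easier to follow, here is a hedged sketch of the same class-swapping idea in isolation (``MyTFPolicy``, ``policy_config``, and the spaces passed to the constructor are placeholders; ``as_eager()`` is assumed to be present, which is what the ``hasattr`` check above guards):

.. code-block:: python

    cls = MyTFPolicy
    if tf.executing_eagerly() and hasattr(cls, "as_eager"):
        cls = cls.as_eager()              # eager re-implementation of the policy
        if policy_config.get("eager_tracing"):
            cls = cls.with_tracing()      # wrap key methods in tf.function
    policy = cls(observation_space, action_space, policy_config)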
@@ -21,14 +21,14 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(
        action_dist.logp(train_batch["actions"]) * train_batch["advantages"])
        action_dist.logp(train_batch["actions"]) * train_batch["returns"])


def calculate_advantages(policy,
                         sample_batch,
                         other_agent_batches=None,
                         episode=None):
    sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
    sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
    return sample_batch
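For context, a hedged standalone sketch of the discounted-return computation that ``discount`` performs in this example (this helper is illustrative, not the example's own implementation):

.. code-block:: python

    import numpy as np

    def discount(rewards, gamma):
        # returns[t] = sum over k >= t of gamma**(k - t) * rewards[k]
        returns = np.zeros(len(rewards), dtype=np.float32)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    print(discount([1.0, 1.0, 1.0], 0.99))  # approximately [2.9701, 1.99, 1.0]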
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Graph mode TF policy built using build_tf_policy()."""

from collections import OrderedDict
import logging
@@ -1,8 +1,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Eager mode TF policy built using build_tf_policy().

It supports both traced and non-traced eager execution modes."""

import logging
import functools
import numpy as np

from ray.rllib.evaluation.episode import _flatten_action
@@ -19,6 +23,56 @@ tf = try_import_tf()

logger = logging.getLogger(__name__)


def _convert_to_tf(x):
    if isinstance(x, SampleBatch):
        x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tf.nest.map_structure(_convert_to_tf, x)
    if isinstance(x, Policy):
        return x

    if x is not None:
        x = tf.nest.map_structure(tf.convert_to_tensor, x)
    return x


def _convert_to_numpy(x):
    if x is None:
        return None
    try:
        return x.numpy()
    except AttributeError:
        raise TypeError(
            ("Object of type {} has no method to convert to numpy.").format(
                type(x)))


def convert_eager_inputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        if tf.executing_eagerly():
            args = [_convert_to_tf(x) for x in args]
            # TODO(gehring): find a way to remove specific hacks
            kwargs = {
                k: _convert_to_tf(v)
                for k, v in kwargs.items()
                if k not in {"info_batch", "episodes"}
            }
        return func(*args, **kwargs)

    return _func


def convert_eager_outputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        out = func(*args, **kwargs)
        if tf.executing_eagerly():
            out = tf.nest.map_structure(_convert_to_numpy, out)
        return out

    return _func
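A hedged toy usage of the two decorators above, assuming eager execution is enabled and that the decorators and the module-level ``tf`` are in scope (``square_batch`` is made up for illustration):

.. code-block:: python

    import numpy as np

    @convert_eager_inputs
    @convert_eager_outputs
    def square_batch(x):
        # convert_eager_inputs has already turned x into a tf.Tensor here.
        return tf.square(x)

    out = square_batch(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    # convert_eager_outputs maps the result back to numpy: [1., 4., 9.]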


def _disallow_var_creation(next_creator, **kw):
    v = next_creator(**kw)
    raise ValueError("Detected a variable being created during an eager "
@@ -26,6 +80,86 @@ def _disallow_var_creation(next_creator, **kw):
                     "model initialization: {}".format(v.name))


def traced_eager_policy(eager_policy_cls):
    """Wrapper that enables tracing for all eager policy methods.

    This is enabled by the --trace / "eager_tracing" config."""

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            self._traced_learn_on_batch = None
            self._traced_compute_actions = None
            self._traced_compute_gradients = None
            self._traced_apply_gradients = None
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def learn_on_batch(self, samples):

            if self._traced_learn_on_batch is None:
                self._traced_learn_on_batch = tf.function(
                    super(TracedEagerPolicy, self).learn_on_batch,
                    autograph=False)

            return self._traced_learn_on_batch(samples)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):

            obs_batch = tf.convert_to_tensor(obs_batch)
            state_batches = _convert_to_tf(state_batches)
            prev_action_batch = _convert_to_tf(prev_action_batch)
            prev_reward_batch = _convert_to_tf(prev_reward_batch)

            if self._traced_compute_actions is None:
                self._traced_compute_actions = tf.function(
                    super(TracedEagerPolicy, self).compute_actions,
                    autograph=False)

            return self._traced_compute_actions(
                obs_batch, state_batches, prev_action_batch, prev_reward_batch,
                info_batch, episodes, **kwargs)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_gradients(self, samples):

            if self._traced_compute_gradients is None:
                self._traced_compute_gradients = tf.function(
                    super(TracedEagerPolicy, self).compute_gradients,
                    autograph=False)

            return self._traced_compute_gradients(samples)

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def apply_gradients(self, grads):

            if self._traced_apply_gradients is None:
                self._traced_apply_gradients = tf.function(
                    super(TracedEagerPolicy, self).apply_gradients,
                    autograph=False)

            return self._traced_apply_gradients(grads)

    TracedEagerPolicy.__name__ = eager_policy_cls.__name__
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
    return TracedEagerPolicy
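The wrapper above lazily compiles each policy method on first use. A hedged minimal sketch of that lazy ``tf.function`` caching pattern in isolation:

.. code-block:: python

    import tensorflow as tf

    class TracedStep(object):
        """Trace a method once, then reuse the compiled function."""

        def __init__(self):
            self._traced_step = None

        def _step(self, x):
            return x * 2.0

        def step(self, x):
            if self._traced_step is None:
                # The first call traces the Python function into a graph;
                # later calls reuse it, which is where the speedup comes from.
                self._traced_step = tf.function(self._step, autograph=False)
            return self._traced_step(x)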


def build_eager_tf_policy(name,
                          loss_fn,
                          get_default_config=None,
@@ -133,6 +267,8 @@ def build_eager_tf_policy(name,
            return samples

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def learn_on_batch(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
@@ -140,14 +276,17 @@ def build_eager_tf_policy(name,
            return stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_gradients(self, samples):
            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, stats = self._compute_gradients(samples)
            grads = [g for g, v in grads_and_vars]
            grads = [(g.numpy() if g is not None else None) for g in grads]
            return grads, stats

        @override(Policy)
        @convert_eager_inputs
        @convert_eager_outputs
        def compute_actions(self,
                            obs_batch,
                            state_batches,
@@ -157,41 +296,46 @@ def build_eager_tf_policy(name,
                            episodes=None,
                            **kwargs):

            assert tf.executing_eagerly()
            # TODO: remove python side effect to cull sources of bugs.
            self._is_training = False
            self._state_in = state_batches

            self._seq_lens = tf.ones(len(obs_batch))
            self._input_dict = {
            if tf.executing_eagerly():
                n = len(obs_batch)
            else:
                n = obs_batch.shape[0]

            seq_lens = tf.ones(n)
            input_dict = {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                "is_training": tf.convert_to_tensor(False),
                "is_training": tf.constant(False),
            }
            if obs_include_prev_action_reward:
                self._input_dict.update({
                input_dict.update({
                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                        prev_action_batch),
                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                        prev_reward_batch),
                })
            self._state_in = state_batches

            with tf.variable_creator_scope(_disallow_var_creation):
                model_out, state_out = self.model(
                    self._input_dict, state_batches, self._seq_lens)
                model_out, state_out = self.model(input_dict, state_batches,
                                                  seq_lens)

            if self.dist_class:
                action_dist = self.dist_class(model_out, self.model)
                action = action_dist.sample().numpy()
                action = action_dist.sample()
                logp = action_dist.sampled_action_logp()
            else:
                action, logp = action_sampler_fn(
                    self, self.model, self._input_dict, self.observation_space,
                    self, self.model, input_dict, self.observation_space,
                    self.action_space, self.config)
                action = action.numpy()

            fetches = {}
            if logp is not None:
                fetches.update({
                    ACTION_PROB: tf.exp(logp).numpy(),
                    ACTION_LOGP: logp.numpy(),
                    ACTION_PROB: tf.exp(logp),
                    ACTION_LOGP: logp,
                })
            if extra_action_fetches_fn:
                fetches.update(extra_action_fetches_fn(self))
@@ -248,14 +392,9 @@ def build_eager_tf_policy(name,

            self._is_training = True

            samples = {
                k: tf.convert_to_tensor(v)
                for k, v in samples.items() if v.dtype != np.object
            }

            with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                # TODO: set seq len and state in properly
                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
                self._seq_lens = tf.ones(samples[SampleBatch.CUR_OBS].shape[0])
                self._state_in = []
                model_out, _ = self.model(samples, self._state_in,
                                          self._seq_lens)
@@ -288,23 +427,22 @@ def build_eager_tf_policy(name,
            return grads_and_vars, stats

        def _stats(self, outputs, samples, grads):
            assert tf.executing_eagerly()

            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v.numpy()
                    k: v
                    for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update({
                    k: v.numpy()
                    for k, v in extra_learn_fetches_fn(self).items()
                })
                fetches.update(
                    {k: v
                     for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update({
                    k: v.numpy()
                    k: v
                    for k, v in grad_stats_fn(self, samples, grads).items()
                })
            return fetches
@@ -380,6 +518,10 @@ def build_eager_tf_policy(name,
            if stats_fn:
                stats_fn(self, postprocessed_batch)

        @classmethod
        def with_tracing(cls):
            return traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls
@@ -184,6 +184,10 @@ class SampleBatch(object):
    def items(self):
        return self.data.items()

    @PublicAPI
    def get(self, key):
        return self.data.get(key)

    @PublicAPI
    def __getitem__(self, key):
        return self.data[key]
@@ -12,6 +12,11 @@ def check_support(alg, config):
    else:
        config["env"] = "CartPole-v0"
    a = get_agent_class(alg)

    config["eager_tracing"] = False
    tune.run(a, config=config, stop={"training_iteration": 0})

    config["eager_tracing"] = True
    tune.run(a, config=config, stop={"training_iteration": 0})
@@ -96,6 +96,10 @@ def create_parser(parser_creator=None):
        "--eager",
        action="store_true",
        help="Whether to attempt to enable TF eager execution.")
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Whether to attempt to enable tracing for eager mode.")
    parser.add_argument(
        "--env", default=None, type=str, help="The gym environment to use.")
    parser.add_argument(
@@ -146,6 +150,10 @@ def run(args, parser):
        parser.error("the following arguments are required: --env")
    if args.eager:
        exp["config"]["eager"] = True
    if args.trace:
        if not exp["config"].get("eager"):
            raise ValueError("Must enable --eager to enable tracing.")
        exp["config"]["eager_tracing"] = True

    if args.ray_num_nodes:
        cluster = Cluster()