[rllib] Tracing for eager tensorflow policies with tf.function (#5705)

* Added tracing of eager policies with `tf.function`
* lint
* add config option
* add docs
* wip
* tracing now works with a3c
* typo
* none
* file doc
* returns
* syntax error
* syntax error
parent f74aaf2619
commit 8903bcd0c3

11 changed files with 204 additions and 37 deletions
@@ -418,9 +418,9 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
 
 Building Policies in TensorFlow Eager
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 You can also selectively leverage eager operations within graph mode execution with `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager ops embedded `within a loss function <https://github.com/ray-project/ray/blob/master/rllib/examples/eager_execution.py>`__.
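A usage sketch for the updated docs above (not part of the diff): the same two config keys can be set from Python instead of the CLI. The ``PPO``/``CartPole-v0`` choice and the stopping criterion are arbitrary.

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    tune.run(
        "PPO",
        stop={"training_iteration": 2},
        config={
            "env": "CartPole-v0",
            "eager": True,          # run the TF policy eagerly
            "eager_tracing": True,  # wrap policy methods in tf.function for speed
        },
    )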
@@ -14,7 +14,7 @@ You can train a simple DQN trainer with the following command:
 
 .. code-block:: bash
 
-    rllib train --run DQN --env CartPole-v0  # add --eager for eager execution
+    rllib train --run DQN --env CartPole-v0  # --eager [--trace] for eager execution
 
 By default, the results will be logged to a subdirectory of ``~/ray_results``.
 This subdirectory will contain a file ``params.json`` which contains the
@@ -544,9 +544,9 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res
 Eager Mode
 ~~~~~~~~~~
 
-Policies built with ``build_tf_policy`` can be also run in eager mode by setting the ``"eager": True`` config option or using ``rllib train --eager``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
+Policies built with ``build_tf_policy`` (most of the reference algorithms are) can be run in eager mode by setting the ``"eager": True`` / ``"eager_tracing": True`` config options or using ``rllib train --eager [--trace]``. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
 
-Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it is slower than graph mode.
+Eager mode makes debugging much easier, since you can now use normal Python functions such as ``print()`` to inspect intermediate tensor values. However, it can be slower than graph mode unless tracing is enabled.
 
 Episode Traces
 ~~~~~~~~~~~~~~
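To make the debugging claim concrete, here is a hedged sketch of a custom loss that prints an intermediate tensor; ``debug_loss`` is a made-up name, and the ``advantages`` column is assumed to come from a postprocessing function as in the PR's own example file. With ``"eager": True`` the print shows concrete values; with ``"eager_tracing": True`` it only fires during the initial trace.

.. code-block:: python

    import tensorflow as tf

    def debug_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        logp = action_dist.logp(train_batch["actions"])
        # In eager mode this prints real tensor values, not graph placeholders.
        print("sampled logp:", logp)
        return -tf.reduce_mean(logp * train_batch["advantages"])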
@@ -25,7 +25,7 @@ Then, you can try out training in the following equivalent ways:
 
 .. code-block:: bash
 
-    rllib train --run=PPO --env=CartPole-v0  # add --eager for eager execution
+    rllib train --run=PPO --env=CartPole-v0  # --eager [--trace] for eager execution
 
 .. code-block:: python
 
@@ -70,8 +70,12 @@ COMMON_CONFIG = {
     "ignore_worker_failures": False,
     # Log system resource metrics to results.
     "log_sys_usage": True,
-    # Enable TF eager execution (TF policies only)
+    # Enable TF eager execution (TF policies only).
     "eager": False,
+    # Enable tracing in eager mode. This greatly improves performance, but
+    # makes it slightly harder to debug since Python code won't be evaluated
+    # after the initial eager pass.
+    "eager_tracing": False,
     # Disable eager execution on workers (but allow it on the driver). This
     # only has an effect is eager is enabled.
     "no_eager_on_workers": False,
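A sketch of how the three eager-related keys compose when building a trainer directly; ``PPOTrainer`` and the particular values here are illustrative choices, not part of the diff. ``no_eager_on_workers`` keeps rollout workers on the graph path while the driver stays eager and debuggable.

.. code-block:: python

    from ray.rllib.agents.ppo import PPOTrainer

    trainer = PPOTrainer(env="CartPole-v0", config={
        "eager": True,                # driver policy runs eagerly
        "eager_tracing": False,       # keep print()-style debugging usable
        "no_eager_on_workers": True,  # workers keep the faster graph path
    })
    print(trainer.train()["episode_reward_mean"])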
@@ -333,7 +337,8 @@ class Trainer(Trainable):
 
         if tf and config.get("eager"):
             tf.enable_eager_execution()
-            logger.info("Executing eagerly")
+            logger.info("Executing eagerly, with eager_tracing={}".format(
+                "True" if config.get("eager_tracing") else "False"))
 
         if tf and not tf.executing_eagerly():
             logger.info("Tip: set 'eager': true or the --eager flag to enable "
@@ -752,6 +752,8 @@ class RolloutWorker(EvaluatorInterface):
             if tf and tf.executing_eagerly():
                 if hasattr(cls, "as_eager"):
                     cls = cls.as_eager()
+                    if policy_config["eager_tracing"]:
+                        cls = cls.with_tracing()
                 elif not issubclass(cls, TFPolicy):
                     pass  # could be some other type of policy
                 else:
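Conceptually, the worker's policy-class selection now follows the pattern below; this condensed helper (``resolve_policy_class``) is a paraphrase for illustration only, not code from the diff.

.. code-block:: python

    def resolve_policy_class(cls, policy_config, eager_enabled):
        if eager_enabled and hasattr(cls, "as_eager"):
            cls = cls.as_eager()                  # graph-built class -> eager variant
            if policy_config.get("eager_tracing"):
                cls = cls.with_tracing()          # wrap eager methods in tf.function
        return cls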
@@ -21,14 +21,14 @@ def policy_gradient_loss(policy, model, dist_class, train_batch):
     logits, _ = model.from_batch(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
-        action_dist.logp(train_batch["actions"]) * train_batch["advantages"])
+        action_dist.logp(train_batch["actions"]) * train_batch["returns"])
 
 
 def calculate_advantages(policy,
                          sample_batch,
                          other_agent_batches=None,
                          episode=None):
-    sample_batch["advantages"] = discount(sample_batch["rewards"], 0.99)
+    sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
     return sample_batch
 
 
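The concepts doc links to this example file for mixing eager ops into a graph-mode loss. A rough, hypothetical sketch of that ``tf.py_function`` pattern (not the file's actual contents; ``eager_inspect`` and ``loss_with_eager_debug`` are made-up names):

.. code-block:: python

    import tensorflow as tf

    def eager_inspect(logp, returns):
        # Executes with eager semantics even when the loss is built in graph mode.
        print("mean logp:", logp.numpy().mean(), "mean return:", returns.numpy().mean())
        return logp

    def loss_with_eager_debug(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        logp = action_dist.logp(train_batch["actions"])
        logp = tf.py_function(eager_inspect, [logp, train_batch["returns"]],
                              Tout=tf.float32)
        return -tf.reduce_mean(logp * train_batch["returns"])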
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+"""Graph mode TF policy built using build_tf_policy()."""
 
 from collections import OrderedDict
 import logging
@@ -1,8 +1,12 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+"""Eager mode TF policy built using build_tf_policy().
+
+It supports both traced and non-traced eager execution modes."""
 
 import logging
+import functools
 import numpy as np
 
 from ray.rllib.evaluation.episode import _flatten_action
@@ -19,6 +23,56 @@ tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
 
+def _convert_to_tf(x):
+    if isinstance(x, SampleBatch):
+        x = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
+        return tf.nest.map_structure(_convert_to_tf, x)
+    if isinstance(x, Policy):
+        return x
+
+    if x is not None:
+        x = tf.nest.map_structure(tf.convert_to_tensor, x)
+    return x
+
+
+def _convert_to_numpy(x):
+    if x is None:
+        return None
+    try:
+        return x.numpy()
+    except AttributeError:
+        raise TypeError(
+            ("Object of type {} has no method to convert to numpy.").format(
+                type(x)))
+
+
+def convert_eager_inputs(func):
+    @functools.wraps(func)
+    def _func(*args, **kwargs):
+        if tf.executing_eagerly():
+            args = [_convert_to_tf(x) for x in args]
+            # TODO(gehring): find a way to remove specific hacks
+            kwargs = {
+                k: _convert_to_tf(v)
+                for k, v in kwargs.items()
+                if k not in {"info_batch", "episodes"}
+            }
+        return func(*args, **kwargs)
+
+    return _func
+
+
+def convert_eager_outputs(func):
+    @functools.wraps(func)
+    def _func(*args, **kwargs):
+        out = func(*args, **kwargs)
+        if tf.executing_eagerly():
+            out = tf.nest.map_structure(_convert_to_numpy, out)
+        return out
+
+    return _func
+
+
 def _disallow_var_creation(next_creator, **kw):
     v = next_creator(**kw)
     raise ValueError("Detected a variable being created during an eager "
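As a standalone illustration of the pattern these helpers implement (plain TF 2.x-style eager code, not RLlib's): structures are converted to tensors on the way in with ``tf.nest.map_structure`` and back to numpy on the way out.

.. code-block:: python

    import functools
    import numpy as np
    import tensorflow as tf

    def to_numpy_outputs(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            out = func(*args, **kwargs)
            # Mirror of the convert_eager_outputs idea: eager tensors -> numpy.
            return tf.nest.map_structure(
                lambda x: x.numpy() if hasattr(x, "numpy") else x, out)
        return wrapper

    @to_numpy_outputs
    def scale(batch):
        batch = tf.nest.map_structure(tf.convert_to_tensor, batch)
        return {k: v * 2.0 for k, v in batch.items()}

    print(scale({"obs": np.ones(3, dtype=np.float32)}))  # {'obs': array([2., 2., 2.], ...)}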
@@ -26,6 +80,86 @@ def _disallow_var_creation(next_creator, **kw):
                      "model initialization: {}".format(v.name))
 
 
+def traced_eager_policy(eager_policy_cls):
+    """Wrapper that enables tracing for all eager policy methods.
+
+    This is enabled by the --trace / "eager_tracing" config."""
+
+    class TracedEagerPolicy(eager_policy_cls):
+        def __init__(self, *args, **kwargs):
+            self._traced_learn_on_batch = None
+            self._traced_compute_actions = None
+            self._traced_compute_gradients = None
+            self._traced_apply_gradients = None
+            super(TracedEagerPolicy, self).__init__(*args, **kwargs)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def learn_on_batch(self, samples):
+
+            if self._traced_learn_on_batch is None:
+                self._traced_learn_on_batch = tf.function(
+                    super(TracedEagerPolicy, self).learn_on_batch,
+                    autograph=False)
+
+            return self._traced_learn_on_batch(samples)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def compute_actions(self,
+                            obs_batch,
+                            state_batches,
+                            prev_action_batch=None,
+                            prev_reward_batch=None,
+                            info_batch=None,
+                            episodes=None,
+                            **kwargs):
+
+            obs_batch = tf.convert_to_tensor(obs_batch)
+            state_batches = _convert_to_tf(state_batches)
+            prev_action_batch = _convert_to_tf(prev_action_batch)
+            prev_reward_batch = _convert_to_tf(prev_reward_batch)
+
+            if self._traced_compute_actions is None:
+                self._traced_compute_actions = tf.function(
+                    super(TracedEagerPolicy, self).compute_actions,
+                    autograph=False)
+
+            return self._traced_compute_actions(
+                obs_batch, state_batches, prev_action_batch, prev_reward_batch,
+                info_batch, episodes, **kwargs)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def compute_gradients(self, samples):
+
+            if self._traced_compute_gradients is None:
+                self._traced_compute_gradients = tf.function(
+                    super(TracedEagerPolicy, self).compute_gradients,
+                    autograph=False)
+
+            return self._traced_compute_gradients(samples)
+
+        @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
+        def apply_gradients(self, grads):
+
+            if self._traced_apply_gradients is None:
+                self._traced_apply_gradients = tf.function(
+                    super(TracedEagerPolicy, self).apply_gradients,
+                    autograph=False)
+
+            return self._traced_apply_gradients(grads)
+
+    TracedEagerPolicy.__name__ = eager_policy_cls.__name__
+    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
+    return TracedEagerPolicy
+
+
 def build_eager_tf_policy(name,
                           loss_fn,
                           get_default_config=None,
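For intuition about why the wrapper caches one ``tf.function`` per method, a small standalone sketch (not RLlib code): the Python body, including ``print()``, runs only while tracing, after which the compiled graph is reused, which is exactly why ``eager_tracing`` trades debuggability for speed.

.. code-block:: python

    import tensorflow as tf

    @tf.function(autograph=False)
    def squared_sum(x):
        print("tracing squared_sum")  # Python side effect: runs only during tracing
        return tf.reduce_sum(x * x)

    x = tf.constant([1.0, 2.0, 3.0])
    print(squared_sum(x).numpy())  # traces once, then prints 14.0
    print(squared_sum(x).numpy())  # reuses the traced graph; no "tracing" message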
@@ -133,6 +267,8 @@ def build_eager_tf_policy(name,
             return samples
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def learn_on_batch(self, samples):
             with tf.variable_creator_scope(_disallow_var_creation):
                 grads_and_vars, stats = self._compute_gradients(samples)
@@ -140,14 +276,17 @@ def build_eager_tf_policy(name,
             return stats
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def compute_gradients(self, samples):
             with tf.variable_creator_scope(_disallow_var_creation):
                 grads_and_vars, stats = self._compute_gradients(samples)
             grads = [g for g, v in grads_and_vars]
-            grads = [(g.numpy() if g is not None else None) for g in grads]
             return grads, stats
 
         @override(Policy)
+        @convert_eager_inputs
+        @convert_eager_outputs
         def compute_actions(self,
                             obs_batch,
                             state_batches,
@@ -157,41 +296,46 @@ def build_eager_tf_policy(name,
                             episodes=None,
                             **kwargs):
 
-            assert tf.executing_eagerly()
+            # TODO: remove python side effect to cull sources of bugs.
             self._is_training = False
-            self._seq_lens = tf.ones(len(obs_batch))
-            self._input_dict = {
+            self._state_in = state_batches
+
+            if tf.executing_eagerly():
+                n = len(obs_batch)
+            else:
+                n = obs_batch.shape[0]
+
+            seq_lens = tf.ones(n)
+            input_dict = {
                 SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
-                "is_training": tf.convert_to_tensor(False),
+                "is_training": tf.constant(False),
             }
             if obs_include_prev_action_reward:
-                self._input_dict.update({
+                input_dict.update({
                     SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
                         prev_action_batch),
                     SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
                         prev_reward_batch),
                 })
-            self._state_in = state_batches
             with tf.variable_creator_scope(_disallow_var_creation):
-                model_out, state_out = self.model(
-                    self._input_dict, state_batches, self._seq_lens)
+                model_out, state_out = self.model(input_dict, state_batches,
+                                                  seq_lens)
 
             if self.dist_class:
                 action_dist = self.dist_class(model_out, self.model)
-                action = action_dist.sample().numpy()
+                action = action_dist.sample()
                 logp = action_dist.sampled_action_logp()
             else:
                 action, logp = action_sampler_fn(
-                    self, self.model, self._input_dict, self.observation_space,
+                    self, self.model, input_dict, self.observation_space,
                     self.action_space, self.config)
-                action = action.numpy()
 
             fetches = {}
             if logp is not None:
                 fetches.update({
-                    ACTION_PROB: tf.exp(logp).numpy(),
-                    ACTION_LOGP: logp.numpy(),
+                    ACTION_PROB: tf.exp(logp),
+                    ACTION_LOGP: logp,
                 })
             if extra_action_fetches_fn:
                 fetches.update(extra_action_fetches_fn(self))
@@ -248,14 +392,9 @@ def build_eager_tf_policy(name,
 
             self._is_training = True
 
-            samples = {
-                k: tf.convert_to_tensor(v)
-                for k, v in samples.items() if v.dtype != np.object
-            }
-
             with tf.GradientTape(persistent=gradients_fn is not None) as tape:
                 # TODO: set seq len and state in properly
-                self._seq_lens = tf.ones(len(samples[SampleBatch.CUR_OBS]))
+                self._seq_lens = tf.ones(samples[SampleBatch.CUR_OBS].shape[0])
                 self._state_in = []
                 model_out, _ = self.model(samples, self._state_in,
                                           self._seq_lens)
@@ -288,23 +427,22 @@ def build_eager_tf_policy(name,
             return grads_and_vars, stats
 
         def _stats(self, outputs, samples, grads):
-            assert tf.executing_eagerly()
+
             fetches = {}
             if stats_fn:
                 fetches[LEARNER_STATS_KEY] = {
-                    k: v.numpy()
+                    k: v
                     for k, v in stats_fn(outputs, samples).items()
                 }
             else:
                 fetches[LEARNER_STATS_KEY] = {}
             if extra_learn_fetches_fn:
-                fetches.update({
-                    k: v.numpy()
-                    for k, v in extra_learn_fetches_fn(self).items()
-                })
+                fetches.update(
+                    {k: v
+                     for k, v in extra_learn_fetches_fn(self).items()})
             if grad_stats_fn:
                 fetches.update({
-                    k: v.numpy()
+                    k: v
                     for k, v in grad_stats_fn(self, samples, grads).items()
                 })
             return fetches
@@ -380,6 +518,10 @@
             if stats_fn:
                 stats_fn(self, postprocessed_batch)
 
+        @classmethod
+        def with_tracing(cls):
+            return traced_eager_policy(cls)
+
     eager_policy_cls.__name__ = name + "_eager"
     eager_policy_cls.__qualname__ = name + "_eager"
     return eager_policy_cls
@@ -184,6 +184,10 @@ class SampleBatch(object):
     def items(self):
         return self.data.items()
 
+    @PublicAPI
+    def get(self, key):
+        return self.data.get(key)
+
     @PublicAPI
     def __getitem__(self, key):
         return self.data[key]
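A quick illustration of the new accessor (the import path is assumed for this Ray version and may differ): it mirrors ``dict.get`` and returns ``None`` for a missing column instead of raising.

.. code-block:: python

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({"obs": np.zeros((2, 4)), "actions": np.array([0, 1])})
    print(batch.get("actions"))     # [0 1]
    print(batch.get("advantages"))  # None rather than a KeyError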
@@ -12,6 +12,11 @@ def check_support(alg, config):
     else:
         config["env"] = "CartPole-v0"
     a = get_agent_class(alg)
+
+    config["eager_tracing"] = False
+    tune.run(a, config=config, stop={"training_iteration": 0})
+
+    config["eager_tracing"] = True
     tune.run(a, config=config, stop={"training_iteration": 0})
 
 
@@ -96,6 +96,10 @@ def create_parser(parser_creator=None):
         "--eager",
         action="store_true",
         help="Whether to attempt to enable TF eager execution.")
+    parser.add_argument(
+        "--trace",
+        action="store_true",
+        help="Whether to attempt to enable tracing for eager mode.")
     parser.add_argument(
         "--env", default=None, type=str, help="The gym environment to use.")
     parser.add_argument(
@@ -146,6 +150,10 @@ def run(args, parser):
             parser.error("the following arguments are required: --env")
         if args.eager:
            exp["config"]["eager"] = True
+        if args.trace:
+            if not exp["config"].get("eager"):
+                raise ValueError("Must enable --eager to enable tracing.")
+            exp["config"]["eager_tracing"] = True
 
     if args.ray_num_nodes:
         cluster = Cluster()
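Putting the CLI changes together, usage would look like this (algorithm and environment chosen arbitrarily):

.. code-block:: bash

    # Eager execution with tf.function tracing of the policy methods:
    rllib train --run PPO --env CartPole-v0 --eager --trace

    # Passing --trace without --eager fails with
    # "Must enable --eager to enable tracing."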