[RLlib] Tf2 + eager-tracing same speed as framework=tf; Add more test coverage for tf2+tracing. (#19981)
Parent: 1341bb59bf
Commit: a931076f59
25 changed files with 519 additions and 386 deletions
|
@ -148,7 +148,7 @@ py_test(
|
|||
)
|
||||
|
||||
py_test(
|
||||
name = "run_regression_tests_frozenlake_appo",
|
||||
name = "learning_frozenlake_appo",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["team:ml", "learning_tests", "learning_tests_discrete"],
|
||||
size = "large",
|
||||
|
|
|
@ -24,8 +24,8 @@ class TestA2C(unittest.TestCase):
|
|||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config):
|
||||
for env in ["PongDeterministic-v0"]:
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]:
|
||||
trainer = a3c.A2CTrainer(config=config, env=env)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
|
|
|
@ -27,7 +27,7 @@ class TestA3C(unittest.TestCase):
|
|||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
for env in ["CartPole-v1", "Pendulum-v1", "PongDeterministic-v0"]:
|
||||
print("env={}".format(env))
|
||||
config["model"]["use_lstm"] = env == "CartPole-v1"
|
||||
|
|
|
@ -66,7 +66,7 @@ class TestCQL(unittest.TestCase):
|
|||
num_iterations = 4
|
||||
|
||||
# Test for tf/torch frameworks.
|
||||
for fw in framework_iterator(config):
|
||||
for fw in framework_iterator(config, with_eager_tracing=True):
|
||||
trainer = cql.CQLTrainer(config=config)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
|
|
|
@ -24,7 +24,7 @@ class TestApexDDPG(unittest.TestCase):
|
|||
config["learning_starts"] = 0
|
||||
config["optimizer"]["num_replay_buffer_shards"] = 1
|
||||
num_iterations = 1
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
plain_config = config.copy()
|
||||
trainer = apex_ddpg.ApexDDPGTrainer(
|
||||
config=plain_config, env="Pendulum-v1")
|
||||
|
|
|
@ -41,7 +41,7 @@ class TestDDPG(unittest.TestCase):
|
|||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
|
|
|
@ -44,7 +44,7 @@ class TestApexDQN(unittest.TestCase):
|
|||
config["min_iter_time_s"] = 1
|
||||
config["optimizer"]["num_replay_buffer_shards"] = 1
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
plain_config = config.copy()
|
||||
trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class TestSimpleQ(unittest.TestCase):
|
|||
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
|
||||
rw = trainer.workers.local_worker()
|
||||
for i in range(num_iterations):
|
||||
|
|
|
@ -30,7 +30,7 @@ class TestIMPALA(unittest.TestCase):
|
|||
num_iterations = 1
|
||||
env = "CartPole-v0"
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
local_cfg = config.copy()
|
||||
for lstm in [False, True]:
|
||||
local_cfg["num_aggregation_workers"] = 0 if not lstm else 1
|
||||
|
|
|
@ -24,7 +24,7 @@ class TestAPPO(unittest.TestCase):
|
|||
config["num_workers"] = 1
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
print("w/o v-trace")
|
||||
_config = config.copy()
|
||||
_config["vtrace"] = False
|
||||
|
|
|
@ -106,7 +106,7 @@ class TestPPO(unittest.TestCase):
|
|||
config["compress_observations"] = True
|
||||
num_iterations = 2
|
||||
|
||||
for fw in framework_iterator(config):
|
||||
for fw in framework_iterator(config, with_eager_tracing=True):
|
||||
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
|
||||
print("Env={}".format(env))
|
||||
for lstm in [True, False]:
|
||||
|
|
|
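The test-file hunks above all converge on the same pattern: loop over framework_iterator(config, with_eager_tracing=True) so that the tf2/tfe configurations are exercised both with and without eager tracing. A self-contained sketch of that pattern (import paths and default-config names are assumptions based on the RLlib 1.x agents API, not part of this diff):

import unittest

import ray
import ray.rllib.agents.a3c as a3c
from ray.rllib.utils.test_utils import framework_iterator


class TestA2CTracing(unittest.TestCase):
    def test_a2c_compilation(self):
        ray.init(num_cpus=4, ignore_reinit_error=True)
        config = a3c.A2C_DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        num_iterations = 1

        # For tf2/tfe, this yields each framework twice:
        # once with eager_tracing=True, once with eager_tracing=False.
        for _ in framework_iterator(config, with_eager_tracing=True):
            for env in ["CartPole-v0", "Pendulum-v1"]:
                trainer = a3c.A2CTrainer(config=config, env=env)
                for _ in range(num_iterations):
                    print(trainer.train())
                trainer.stop()
        ray.shutdown()


if __name__ == "__main__":
    unittest.main()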
@ -230,14 +230,23 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
|
||||
# === Deep Learning Framework Settings ===
|
||||
# tf: TensorFlow (static-graph)
|
||||
# tf2: TensorFlow 2.x (eager)
|
||||
# tfe: TensorFlow eager
|
||||
# tf2: TensorFlow 2.x (eager or traced, if eager_tracing=True)
|
||||
# tfe: TensorFlow eager (or traced, if eager_tracing=True)
|
||||
# torch: PyTorch
|
||||
"framework": "tf",
|
||||
# Enable tracing in eager mode. This greatly improves performance, but
|
||||
# makes it slightly harder to debug since Python code won't be evaluated
|
||||
# after the initial eager pass. Only possible if framework=tfe.
|
||||
# Enable tracing in eager mode. This greatly improves performance
|
||||
# (speedup ~2x), but makes it slightly harder to debug since Python
|
||||
# code won't be evaluated after the initial eager pass.
|
||||
# Only possible if framework=[tf2|tfe].
|
||||
"eager_tracing": False,
|
||||
# Maximum number of tf.function re-traces before a runtime error is raised.
|
||||
# This is to prevent unnoticed retraces of methods inside the
|
||||
# `..._eager_traced` Policy, which could slow down execution by a
|
||||
# factor of 4, without the user noticing what the root cause for this
|
||||
# slowdown could be.
|
||||
# Only necessary for framework=[tf2|tfe].
|
||||
# Set to None to ignore the re-trace count and never throw an error.
|
||||
"eager_max_retraces": 20,
|
||||
|
||||
# === Exploration Settings ===
|
||||
# Default exploration behavior, iff `explore`=None is passed into
|
||||
|
|
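As a usage sketch of the framework settings documented above: framework="tf2" with eager_tracing=True runs eagerly but traces the hot policy methods with tf.function, and eager_max_retraces bounds how often those methods may be re-traced before an error is raised. The trainer and config names below are illustrative, not part of this diff.

import ray
import ray.rllib.agents.ppo as ppo

ray.init(ignore_reinit_error=True)

config = ppo.DEFAULT_CONFIG.copy()
config["framework"] = "tf2"
# Trace eager policy methods for roughly framework="tf" speed.
config["eager_tracing"] = True
# Raise after 20 re-traces of a traced method; None disables the check.
config["eager_max_retraces"] = 20

trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
print(trainer.train())
trainer.stop()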
rllib/examples/env/random_env.py (vendored): 18 changes
|
@ -10,8 +10,8 @@ class RandomEnv(gym.Env):
|
|||
|
||||
Can be instantiated with arbitrary action-, observation-, and reward
|
||||
spaces. Observations and rewards are generated by simply sampling from the
|
||||
observation/reward spaces. The probability of a `done=True` can be
|
||||
configured as well.
|
||||
observation/reward spaces. The probability of a `done=True` after each
|
||||
action can be configured, as well as the max episode length.
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
|
@ -26,8 +26,13 @@ class RandomEnv(gym.Env):
|
|||
"reward_space",
|
||||
gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32))
|
||||
# Chance that an episode ends at any step.
|
||||
# Note that a max episode length can be specified via
|
||||
# `max_episode_len`.
|
||||
self.p_done = config.get("p_done", 0.1)
|
||||
# A max episode length.
|
||||
# A max episode length. Even if the `p_done` sampling does not lead
|
||||
# to a terminus, the episode will end after at most this many
|
||||
# timesteps.
|
||||
# Set to 0 or None for using no limit on the episode length.
|
||||
self.max_episode_len = config.get("max_episode_len", None)
|
||||
# Whether to check action bounds.
|
||||
self.check_action_bounds = config.get("check_action_bounds", False)
|
||||
|
@ -49,11 +54,10 @@ class RandomEnv(gym.Env):
|
|||
|
||||
self.steps += 1
|
||||
done = False
|
||||
# We are done as per our max-episode-len.
|
||||
if self.max_episode_len is not None and \
|
||||
self.steps >= self.max_episode_len:
|
||||
# We are `done` as per our max-episode-len.
|
||||
if self.max_episode_len and self.steps >= self.max_episode_len:
|
||||
done = True
|
||||
# Max not reached yet -> Sample done via p_done.
|
||||
# Max episode length not reached yet -> Sample `done` via `p_done`.
|
||||
elif self.p_done > 0.0:
|
||||
done = bool(
|
||||
np.random.choice(
|
||||
|
|
|
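A quick usage sketch of the two termination knobs described in the hunks above (p_done and max_episode_len); the import path follows the file shown here and the old gym 4-tuple step API is assumed:

from ray.rllib.examples.env.random_env import RandomEnv

# Episodes end either stochastically (10% chance per step) or after at
# most 200 timesteps, whichever happens first.
env = RandomEnv(config={"p_done": 0.1, "max_episode_len": 200})

obs = env.reset()
done, steps = False, 0
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
    steps += 1
print(f"Episode ended after {steps} step(s).")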
@ -196,7 +196,8 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
|
|||
def forward(self, input_dict, state, seq_lens):
|
||||
# Set the correct train-mode for our hidden module (only important
|
||||
# b/c we have some batch-norm layers).
|
||||
self._hidden_layers.train(mode=input_dict.get("is_training", False))
|
||||
self._hidden_layers.train(
|
||||
mode=bool(input_dict.get("is_training", False)))
|
||||
self._hidden_out = self._hidden_layers(input_dict["obs"])
|
||||
logits = self._logits(self._hidden_out)
|
||||
return logits, []
|
||||
|
|
|
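The bool() cast added above matters because is_training can arrive in the input dict as a framework tensor rather than a Python bool, while nn.Module.train() expects a plain bool. A minimal PyTorch sketch of the same pattern, outside RLlib:

import torch
import torch.nn as nn

hidden_layers = nn.Sequential(
    nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())

# `is_training` may show up as a (single-element) tensor.
is_training = torch.tensor(True)
# Cast explicitly so the batch-norm layers switch modes correctly.
hidden_layers.train(mode=bool(is_training))
out = hidden_layers(torch.randn(16, 4))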
@ -1,5 +1,11 @@
|
|||
from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning("LocalSyncParallelOptimizer", "TFMultiGPUTowerStack")
|
||||
# Backward compatibility.
|
||||
deprecation_warning(
|
||||
old="ray.rllib.execution.multi_gpu_impl.LocalSyncParallelOptimizer",
|
||||
new="ray.rllib.policy.dynamic_tf_policy.TFMultiGPUTowerStack",
|
||||
error=False,
|
||||
)
|
||||
# Old name.
|
||||
LocalSyncParallelOptimizer = TFMultiGPUTowerStack
|
||||
|
|
|
@ -2,6 +2,12 @@ from ray.rllib.execution.multi_gpu_learner_thread import \
|
|||
MultiGPULearnerThread, _MultiGPULoaderThread
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
|
||||
deprecation_warning("multi_gpu_learner.py", "multi_gpu_learner_thread.py")
|
||||
# Backward compatibility.
|
||||
deprecation_warning(
|
||||
old="ray.rllib.execution.multi_gpu_learner.py",
|
||||
new="ray.rllib.execution.multi_gpu_learner_thread.py",
|
||||
error=False,
|
||||
)
|
||||
# Old names.
|
||||
TFMultiGPULearner = MultiGPULearnerThread
|
||||
_LoaderThread = _MultiGPULoaderThread
|
||||
|
|
|
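Both back-compat modules above follow the same recipe: import the new symbol, emit a deprecation warning once at import time, and keep the old name as an alias. A generic, self-contained sketch of that recipe using only the standard library (RLlib itself uses its own deprecation_warning util, as the diff shows; the class below is a stand-in, not the real implementation):

import warnings


class TFMultiGPUTowerStack:
    """Stand-in for the new implementation (illustrative only)."""


# Backward compatibility: warn once when the old module is imported,
# then alias the old name to the new class so existing imports keep
# working.
warnings.warn(
    "`LocalSyncParallelOptimizer` has been renamed; "
    "use `TFMultiGPUTowerStack` instead.",
    DeprecationWarning,
    stacklevel=2)
LocalSyncParallelOptimizer = TFMultiGPUTowerStack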
@ -243,6 +243,11 @@ class ModelV2:
|
|||
with self.context():
|
||||
res = self.forward(restored, state or [], seq_lens)
|
||||
|
||||
if isinstance(input_dict, SampleBatch):
|
||||
input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}
|
||||
input_dict.deleted_keys = restored.deleted_keys
|
||||
input_dict.added_keys = restored.added_keys - {"obs_flat"}
|
||||
|
||||
if ((not isinstance(res, list) and not isinstance(res, tuple))
|
||||
or len(res) != 2):
|
||||
raise ValueError(
|
||||
|
|
|
@ -250,7 +250,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
True, (), name="is_exploring")
|
||||
|
||||
# Placeholder for `is_training` flag.
|
||||
self._input_dict.is_training = self._get_is_training_placeholder()
|
||||
self._input_dict.set_training(self._get_is_training_placeholder())
|
||||
|
||||
# Multi-GPU towers do not need any action computing/exploration
|
||||
# graphs.
|
||||
|
@ -464,7 +464,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
buffer_index: int = 0,
|
||||
) -> int:
|
||||
# Set the is_training flag of the batch.
|
||||
batch.is_training = True
|
||||
batch.set_training(True)
|
||||
|
||||
# Shortcut for 1 CPU only: Store batch in
|
||||
# `self._loaded_single_cpu_batch`.
|
||||
|
|
|
@ -44,9 +44,11 @@ def _convert_to_tf(x, dtype=None):
|
|||
|
||||
if x is not None:
|
||||
d = dtype
|
||||
x = tf.nest.map_structure(
|
||||
return tf.nest.map_structure(
|
||||
lambda f: _convert_to_tf(f, d) if isinstance(f, RepeatedValues)
|
||||
else tf.convert_to_tensor(f, d) if f is not None else None, x)
|
||||
else tf.convert_to_tensor(f, d) if
|
||||
f is not None and not tf.is_tensor(f) else f, x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
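The _convert_to_tf change above does two things: it returns the result of tf.nest.map_structure instead of an unused local, and it skips leaves that are already tensors. The underlying pattern, as a standalone sketch:

import numpy as np
import tensorflow as tf

batch = {
    "obs": np.zeros((4, 8), np.float32),
    "prev_actions": tf.constant([0, 1, 0, 1]),  # already a tensor
    "infos": None,                              # leave None alone
}

converted = tf.nest.map_structure(
    lambda f: f if f is None or tf.is_tensor(f)
    else tf.convert_to_tensor(f),
    batch)
print({k: type(v).__name__ for k, v in converted.items()})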
@ -101,32 +103,39 @@ def _disallow_var_creation(next_creator, **kw):
|
|||
"model initialization: {}".format(v.name))
|
||||
|
||||
|
||||
def traced_eager_policy(eager_policy_cls):
|
||||
"""Wrapper that enables tracing for all eager policy methods.
|
||||
def check_too_many_retraces(obj):
|
||||
"""Asserts that a given number of re-traces is not breached."""
|
||||
|
||||
This is enabled by the --trace / "eager_tracing" config."""
|
||||
def _func(self_, *args, **kwargs):
|
||||
if self_.config.get("eager_max_retraces") is not None and \
|
||||
self_._re_trace_counter > self_.config["eager_max_retraces"]:
|
||||
raise RuntimeError(
|
||||
"Too many tf-eager re-traces detected! This could lead to"
|
||||
" significant slow-downs (even slower than running in "
|
||||
"tf-eager mode w/ `eager_tracing=False`). To switch off "
|
||||
"these re-trace counting checks, set `eager_max_retraces`"
|
||||
" in your config to None.")
|
||||
return obj(self_, *args, **kwargs)
|
||||
|
||||
return _func
|
||||
|
||||
|
||||
def traced_eager_policy(eager_policy_cls):
|
||||
"""Wrapper class that enables tracing for all eager policy methods.
|
||||
|
||||
This is enabled by the `--trace`/`eager_tracing=True` config when
|
||||
framework=[tf2|tfe].
|
||||
"""
|
||||
|
||||
class TracedEagerPolicy(eager_policy_cls):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._traced_learn_on_batch = None
|
||||
self._traced_compute_action_helper = False
|
||||
self._traced_compute_gradients = None
|
||||
self._traced_apply_gradients = None
|
||||
self._traced_learn_on_batch_helper = False
|
||||
self._traced_compute_actions_helper = False
|
||||
self._traced_compute_gradients_helper = False
|
||||
self._traced_apply_gradients_helper = False
|
||||
super(TracedEagerPolicy, self).__init__(*args, **kwargs)
|
||||
|
||||
@override(eager_policy_cls)
|
||||
@convert_eager_inputs
|
||||
@convert_eager_outputs
|
||||
def _learn_on_batch_eager(self, samples):
|
||||
|
||||
if self._traced_learn_on_batch is None:
|
||||
self._traced_learn_on_batch = tf.function(
|
||||
super(TracedEagerPolicy, self)._learn_on_batch_eager,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True)
|
||||
|
||||
return self._traced_learn_on_batch(samples)
|
||||
|
||||
@check_too_many_retraces
|
||||
@override(Policy)
|
||||
def compute_actions_from_input_dict(
|
||||
self,
|
||||
|
@ -136,16 +145,20 @@ def traced_eager_policy(eager_policy_cls):
|
|||
episodes: Optional[List[Episode]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
|
||||
"""Traced version of Policy.compute_actions_from_input_dict."""
|
||||
|
||||
# Create a traced version of `self._compute_action_helper`.
|
||||
if self._traced_compute_action_helper is False:
|
||||
self._compute_action_helper = convert_eager_inputs(
|
||||
# Create a traced version of `self._compute_actions_helper`.
|
||||
if self._traced_compute_actions_helper is False and \
|
||||
not self._no_tracing:
|
||||
self._compute_actions_helper = convert_eager_inputs(
|
||||
tf.function(
|
||||
super(TracedEagerPolicy, self)._compute_action_helper,
|
||||
super(TracedEagerPolicy, self)._compute_actions_helper,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True))
|
||||
self._traced_compute_action_helper = True
|
||||
self._traced_compute_actions_helper = True
|
||||
|
||||
# Now that the helper method is traced, call super's
# compute_actions_from_input_dict() (which will call the traced helper).
|
||||
return super(TracedEagerPolicy, self).\
|
||||
compute_actions_from_input_dict(
|
||||
input_dict=input_dict,
|
||||
|
@ -155,55 +168,67 @@ def traced_eager_policy(eager_policy_cls):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
@check_too_many_retraces
|
||||
@override(eager_policy_cls)
|
||||
@convert_eager_outputs
|
||||
def _compute_gradients_eager(self, samples: SampleBatch) -> \
|
||||
ModelGradients:
|
||||
"""Traced version of EagerTFPolicy's `_compute_gradients_eager`.
|
||||
def learn_on_batch(self, samples):
|
||||
"""Traced version of Policy.learn_on_batch."""
|
||||
|
||||
Note that `samples` is already zero-padded and has the is_training
|
||||
flag set. We convert this sample batch into tensors here and do the
|
||||
actual computation.
|
||||
|
||||
Args:
|
||||
samples: The SampleBatch to compute gradients for.
|
||||
|
||||
Returns:
|
||||
The computed model gradients.
|
||||
"""
|
||||
# Have we traced the `_compute_gradients_eager` function yet?
|
||||
if self._traced_compute_gradients is None:
|
||||
self._traced_compute_gradients = convert_eager_inputs(
|
||||
# Create a traced version of `self._learn_on_batch_helper`.
|
||||
if self._traced_learn_on_batch_helper is False and \
|
||||
not self._no_tracing:
|
||||
self._learn_on_batch_helper = convert_eager_inputs(
|
||||
tf.function(
|
||||
super()._compute_gradients_eager,
|
||||
super(TracedEagerPolicy, self)._learn_on_batch_helper,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True))
|
||||
self._traced_learn_on_batch_helper = True
|
||||
|
||||
# Call the only-once compiled traced function with the SampleBatch
|
||||
# (will be converted to tensors beforehand).
|
||||
return self._traced_compute_gradients(samples)
|
||||
# Now that the helper method is traced, call super's
# learn_on_batch() (which will call the traced helper).
|
||||
return super(TracedEagerPolicy, self).learn_on_batch(samples)
|
||||
|
||||
@check_too_many_retraces
|
||||
@override(eager_policy_cls)
|
||||
def compute_gradients(self, samples: SampleBatch) -> \
|
||||
ModelGradients:
|
||||
"""Traced version of Policy.compute_gradients."""
|
||||
|
||||
# Create a traced version of `self._compute_gradients_helper`.
|
||||
if self._traced_compute_gradients_helper is False and \
|
||||
not self._no_tracing:
|
||||
self._compute_gradients_helper = convert_eager_inputs(
|
||||
tf.function(
|
||||
super(TracedEagerPolicy,
|
||||
self)._compute_gradients_helper,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True))
|
||||
self._traced_compute_gradients_helper = True
|
||||
|
||||
# Now that the helper method is traced, call super's
# compute_gradients() (which will call the traced helper).
|
||||
return super(TracedEagerPolicy, self).compute_gradients(samples)
|
||||
|
||||
@check_too_many_retraces
|
||||
@override(Policy)
|
||||
@convert_eager_outputs
|
||||
def apply_gradients(self, grads: ModelGradients) -> None:
|
||||
"""Traced version of EagerTFPolicy's `apply_gradients` method.
|
||||
"""Traced version of Policy.apply_gradients."""
|
||||
|
||||
Args:
|
||||
grads: The ModelGradients to apply using our optimizer(s).
|
||||
"""
|
||||
# Have we traced the `apply_gradients` function yet?
|
||||
if self._traced_apply_gradients is None:
|
||||
self._traced_apply_gradients = tf.function(
|
||||
super().apply_gradients,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True)
|
||||
# Create a traced version of `self._apply_gradients_helper`.
|
||||
if self._traced_apply_gradients_helper is False and \
|
||||
not self._no_tracing:
|
||||
self._apply_gradients_helper = convert_eager_inputs(
|
||||
tf.function(
|
||||
super(TracedEagerPolicy, self)._apply_gradients_helper,
|
||||
autograph=False,
|
||||
experimental_relax_shapes=True))
|
||||
self._traced_apply_gradients_helper = True
|
||||
|
||||
# Call the only-once compiled traced function with the grads
|
||||
# (will be converted to tensors beforehand).
|
||||
return self._traced_apply_gradients(grads)
|
||||
# Now that the helper method is traced, call super's
|
||||
# apply_gradients (which will call the traced helper).
|
||||
return super(TracedEagerPolicy, self).apply_gradients(grads)
|
||||
|
||||
TracedEagerPolicy.__name__ = eager_policy_cls.__name__
|
||||
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__
|
||||
TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
|
||||
TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
|
||||
return TracedEagerPolicy
|
||||
|
||||
|
||||
|
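The TracedEagerPolicy wrapper above boils down to a lazy-tracing pattern: on first use, wrap the eager helper in tf.function (autograph off, relaxed shapes), cache it, and route all later calls through the traced version. A stripped-down sketch of that pattern outside RLlib:

import tensorflow as tf


class TracedHelper:
    def __init__(self):
        self._traced_helper = None

    def _helper(self, x):
        # Plain eager implementation (a policy method in RLlib's case).
        return x * 2.0 + 1.0

    def __call__(self, x):
        # Trace once, then always call the cached tf.function.
        if self._traced_helper is None:
            self._traced_helper = tf.function(
                self._helper,
                autograph=False,
                experimental_relax_shapes=True)
        return self._traced_helper(x)


helper = TracedHelper()
print(helper(tf.constant([1.0, 2.0])))   # traces on the first call
print(helper(tf.constant([3.0, 4.0])))   # reuses the traced graph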
@ -289,9 +314,20 @@ def build_eager_tf_policy(
|
|||
worker_idx if worker_idx > 0 else "local"))
|
||||
|
||||
self._is_training = False
|
||||
self._loss_initialized = False
|
||||
|
||||
# Only for `config.eager_tracing=True`: A counter to keep track of
|
||||
# how many times an eager-traced method (e.g.
|
||||
# `self._compute_actions_helper`) has been re-traced by tensorflow.
|
||||
# We will raise an error if more than n re-tracings have been
|
||||
# detected, since this would considerably slow down execution.
|
||||
# The variable below should only get incremented during the
|
||||
# tf.function trace operations, never when calling the already
|
||||
# traced function after that.
|
||||
self._re_trace_counter = 0
|
||||
|
||||
self._loss_initialized = False
|
||||
self._loss = loss_fn
|
||||
|
||||
self.batch_divisibility_req = get_batch_divisibility_req(self) if \
|
||||
callable(get_batch_divisibility_req) else \
|
||||
(get_batch_divisibility_req or 1)
|
||||
|
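The new _re_trace_counter works because the Python body of a tf.function only executes while TensorFlow is tracing; a counter incremented inside that body therefore counts traces, not calls. A minimal standalone illustration:

import tensorflow as tf

num_traces = 0


@tf.function
def add_one(x):
    global num_traces
    # Runs only during tracing, never on calls that reuse a trace.
    num_traces += 1
    return x + 1


add_one(tf.constant(1))    # first trace (int32 signature)
add_one(tf.constant(2))    # same signature -> no re-trace
add_one(tf.constant(1.0))  # new float32 signature -> re-trace
print(num_traces)          # -> 2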
@ -376,107 +412,6 @@ def build_eager_tf_policy(
|
|||
# Got to reset global_timestep again after fake run-throughs.
|
||||
self.global_timestep = 0
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
assert tf.executing_eagerly()
|
||||
# Call super's postprocess_trajectory first.
|
||||
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
|
||||
if postprocess_fn:
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
return sample_batch
|
||||
|
||||
@with_lock
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
# Callback handling.
|
||||
learn_stats = {}
|
||||
self.callbacks.on_learn_on_batch(
|
||||
policy=self,
|
||||
train_batch=postprocessed_batch,
|
||||
result=learn_stats)
|
||||
|
||||
if not isinstance(postprocessed_batch, SampleBatch) or \
|
||||
not postprocessed_batch.zero_padded:
|
||||
pad_batch_to_sequences_of_same_size(
|
||||
postprocessed_batch,
|
||||
max_seq_len=self._max_seq_len,
|
||||
shuffle=False,
|
||||
batch_divisibility_req=self.batch_divisibility_req,
|
||||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
self._is_training = True
|
||||
postprocessed_batch.is_training = True
|
||||
stats = self._learn_on_batch_eager(postprocessed_batch)
|
||||
stats.update({"custom_metrics": learn_stats})
|
||||
return stats
|
||||
|
||||
@convert_eager_inputs
|
||||
@convert_eager_outputs
|
||||
def _learn_on_batch_eager(self, samples):
|
||||
with tf.variable_creator_scope(_disallow_var_creation):
|
||||
grads_and_vars, stats = self._compute_gradients(samples)
|
||||
self._apply_gradients(grads_and_vars)
|
||||
return stats
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, samples):
|
||||
pad_batch_to_sequences_of_same_size(
|
||||
samples,
|
||||
shuffle=False,
|
||||
max_seq_len=self._max_seq_len,
|
||||
batch_divisibility_req=self.batch_divisibility_req,
|
||||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
self._is_training = True
|
||||
samples.is_training = True
|
||||
return self._compute_gradients_eager(samples)
|
||||
|
||||
@convert_eager_inputs
|
||||
@convert_eager_outputs
|
||||
def _compute_gradients_eager(self, samples):
|
||||
with tf.variable_creator_scope(_disallow_var_creation):
|
||||
grads_and_vars, stats = self._compute_gradients(samples)
|
||||
grads = [g for g, v in grads_and_vars]
|
||||
return grads, stats
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
|
||||
# Create input dict to simply pass the entire call to
|
||||
# self.compute_actions_from_input_dict().
|
||||
input_dict = SampleBatch(
|
||||
{
|
||||
SampleBatch.CUR_OBS: obs_batch,
|
||||
},
|
||||
_is_training=tf.constant(False))
|
||||
if prev_action_batch is not None:
|
||||
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
||||
if prev_reward_batch is not None:
|
||||
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
||||
|
||||
return self.compute_actions_from_input_dict(
|
||||
input_dict=input_dict,
|
||||
explore=explore,
|
||||
timestep=timestep,
|
||||
episodes=episodes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions_from_input_dict(
|
||||
self,
|
||||
|
@ -502,7 +437,8 @@ def build_eager_tf_policy(
|
|||
|
||||
# Pass lazy (eager) tensor dict to Model as `input_dict`.
|
||||
input_dict = self._lazy_tensor_dict(input_dict)
|
||||
input_dict.is_training = False
|
||||
input_dict.set_training(False)
|
||||
|
||||
# Pack internal state inputs into (separate) list.
|
||||
state_batches = [
|
||||
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
||||
|
@ -514,7 +450,7 @@ def build_eager_tf_policy(
|
|||
self.exploration.before_compute_actions(
|
||||
timestep=timestep, explore=explore, tf_sess=self.get_session())
|
||||
|
||||
ret = self._compute_action_helper(
|
||||
ret = self._compute_actions_helper(
|
||||
input_dict,
|
||||
state_batches,
|
||||
# TODO: Passing episodes into a traced method does not work.
|
||||
|
@ -525,9 +461,254 @@ def build_eager_tf_policy(
|
|||
self.global_timestep += int(tree.flatten(ret[0])[0].shape[0])
|
||||
return convert_to_numpy(ret)
|
||||
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
explore=None,
|
||||
timestep=None,
|
||||
**kwargs):
|
||||
|
||||
# Create input dict to simply pass the entire call to
|
||||
# self.compute_actions_from_input_dict().
|
||||
input_dict = SampleBatch(
|
||||
{
|
||||
SampleBatch.CUR_OBS: obs_batch,
|
||||
},
|
||||
_is_training=tf.constant(False))
|
||||
if state_batches is not None:
|
||||
for i, s in enumerate(state_batches):
input_dict[f"state_in_{i}"] = s
|
||||
if prev_action_batch is not None:
|
||||
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
||||
if prev_reward_batch is not None:
|
||||
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
||||
if info_batch is not None:
|
||||
input_dict[SampleBatch.INFOS] = info_batch
|
||||
|
||||
return self.compute_actions_from_input_dict(
|
||||
input_dict=input_dict,
|
||||
explore=explore,
|
||||
timestep=timestep,
|
||||
episodes=episodes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@with_lock
|
||||
def _compute_action_helper(self, input_dict, state_batches, episodes,
|
||||
explore, timestep):
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
actions,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
actions_normalized=True,
|
||||
):
|
||||
if action_sampler_fn and action_distribution_fn is None:
|
||||
raise ValueError("Cannot compute log-prob/likelihood w/o an "
|
||||
"`action_distribution_fn` and a provided "
|
||||
"`action_sampler_fn`!")
|
||||
|
||||
seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
|
||||
input_batch = SampleBatch(
|
||||
{
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)
|
||||
},
|
||||
_is_training=False)
|
||||
if prev_action_batch is not None:
|
||||
input_batch[SampleBatch.PREV_ACTIONS] = \
|
||||
tf.convert_to_tensor(prev_action_batch)
|
||||
if prev_reward_batch is not None:
|
||||
input_batch[SampleBatch.PREV_REWARDS] = \
|
||||
tf.convert_to_tensor(prev_reward_batch)
|
||||
|
||||
# Exploration hook before each forward pass.
|
||||
self.exploration.before_compute_actions(explore=False)
|
||||
|
||||
# Action dist class and inputs are generated via custom function.
|
||||
if action_distribution_fn:
|
||||
dist_inputs, dist_class, _ = action_distribution_fn(
|
||||
self,
|
||||
self.model,
|
||||
input_batch,
|
||||
explore=False,
|
||||
is_training=False)
|
||||
# Default log-likelihood calculation.
|
||||
else:
|
||||
dist_inputs, _ = self.model(input_batch, state_batches,
|
||||
seq_lens)
|
||||
dist_class = self.dist_class
|
||||
|
||||
action_dist = dist_class(dist_inputs, self.model)
|
||||
|
||||
# Normalize actions if necessary.
|
||||
if not actions_normalized and self.config["normalize_actions"]:
|
||||
actions = normalize_action(actions, self.action_space_struct)
|
||||
|
||||
log_likelihoods = action_dist.logp(actions)
|
||||
|
||||
return log_likelihoods
|
||||
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
assert tf.executing_eagerly()
|
||||
# Call super's postprocess_trajectory first.
|
||||
sample_batch = Policy.postprocess_trajectory(self, sample_batch)
|
||||
if postprocess_fn:
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
return sample_batch
|
||||
|
||||
@with_lock
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
# Callback handling.
|
||||
learn_stats = {}
|
||||
self.callbacks.on_learn_on_batch(
|
||||
policy=self,
|
||||
train_batch=postprocessed_batch,
|
||||
result=learn_stats)
|
||||
|
||||
pad_batch_to_sequences_of_same_size(
|
||||
postprocessed_batch,
|
||||
max_seq_len=self._max_seq_len,
|
||||
shuffle=False,
|
||||
batch_divisibility_req=self.batch_divisibility_req,
|
||||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
self._is_training = True
|
||||
postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
postprocessed_batch.set_training(True)
|
||||
stats = self._learn_on_batch_helper(postprocessed_batch)
|
||||
stats.update({"custom_metrics": learn_stats})
|
||||
return convert_to_numpy(stats)
|
||||
|
||||
@override(Policy)
|
||||
def compute_gradients(self, postprocessed_batch: SampleBatch) -> \
|
||||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
|
||||
pad_batch_to_sequences_of_same_size(
|
||||
postprocessed_batch,
|
||||
shuffle=False,
|
||||
max_seq_len=self._max_seq_len,
|
||||
batch_divisibility_req=self.batch_divisibility_req,
|
||||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
self._is_training = True
|
||||
self._lazy_tensor_dict(postprocessed_batch)
|
||||
postprocessed_batch.set_training(True)
|
||||
grads_and_vars, grads, stats = self._compute_gradients_helper(
|
||||
postprocessed_batch)
|
||||
return convert_to_numpy((grads, stats))
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients: ModelGradients) -> None:
|
||||
self._apply_gradients_helper(
|
||||
list(
|
||||
zip([(tf.convert_to_tensor(g)
|
||||
if g is not None else None) for g in gradients],
|
||||
self.model.trainable_variables())))
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self, as_dict=False):
|
||||
variables = self.variables()
|
||||
if as_dict:
|
||||
return {v.name: v.numpy() for v in variables}
|
||||
return [v.numpy() for v in variables]
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
variables = self.variables()
|
||||
assert len(weights) == len(variables), (len(weights),
|
||||
len(variables))
|
||||
for v, w in zip(variables, weights):
|
||||
v.assign(w)
|
||||
|
||||
@override(Policy)
|
||||
def get_exploration_state(self):
|
||||
return convert_to_numpy(self.exploration.get_state())
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
return self._is_recurrent
|
||||
|
||||
@override(Policy)
|
||||
def num_state_tensors(self):
|
||||
return len(self._state_inputs)
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
if hasattr(self, "model"):
|
||||
return self.model.get_initial_state()
|
||||
return []
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
state = super().get_state()
|
||||
if self._optimizer and \
|
||||
len(self._optimizer.variables()) > 0:
|
||||
state["_optimizer_variables"] = \
|
||||
self._optimizer.variables()
|
||||
# Add exploration state.
|
||||
state["_exploration_state"] = self.exploration.get_state()
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
state = state.copy() # shallow copy
|
||||
# Set optimizer vars first.
|
||||
optimizer_vars = state.get("_optimizer_variables", None)
|
||||
if optimizer_vars and self._optimizer.variables():
|
||||
logger.warning(
|
||||
"Cannot restore an optimizer's state for tf eager! Keras "
|
||||
"is not able to save the v1.x optimizers (from "
|
||||
"tf.compat.v1.train) since they aren't compatible with "
|
||||
"checkpoints.")
|
||||
for opt_var, value in zip(self._optimizer.variables(),
|
||||
optimizer_vars):
|
||||
opt_var.assign(value)
|
||||
# Set exploration's state.
|
||||
if hasattr(self, "exploration") and "_exploration_state" in state:
|
||||
self.exploration.set_state(state=state["_exploration_state"])
|
||||
# Then the Policy's (NN) weights.
|
||||
super().set_state(state)
|
||||
|
||||
@override(Policy)
|
||||
def export_checkpoint(self, export_dir):
|
||||
raise NotImplementedError # TODO: implement this
|
||||
|
||||
@override(Policy)
|
||||
def export_model(self, export_dir):
|
||||
raise NotImplementedError # TODO: implement this
|
||||
|
||||
def variables(self):
|
||||
"""Return the list of all savable variables for this policy."""
|
||||
if isinstance(self.model, tf.keras.Model):
|
||||
return self.model.variables
|
||||
else:
|
||||
return self.model.variables()
|
||||
|
||||
def loss_initialized(self):
|
||||
return self._loss_initialized
|
||||
|
||||
@with_lock
|
||||
def _compute_actions_helper(self, input_dict, state_batches, episodes,
|
||||
explore, timestep):
|
||||
# Increase the tracing counter to make sure we don't re-trace too
|
||||
# often. If eager_tracing=True, this counter should only get
|
||||
# incremented during the @tf.function trace operations, never when
|
||||
# calling the already traced function after that.
|
||||
self._re_trace_counter += 1
|
||||
|
||||
# Calculate RNN sequence lengths.
|
||||
batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
|
||||
|
@ -612,166 +793,32 @@ def build_eager_tf_policy(
|
|||
|
||||
return actions, state_out, extra_fetches
|
||||
|
||||
@with_lock
|
||||
@override(Policy)
|
||||
def compute_log_likelihoods(
|
||||
self,
|
||||
actions,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
actions_normalized=True,
|
||||
):
|
||||
if action_sampler_fn and action_distribution_fn is None:
|
||||
raise ValueError("Cannot compute log-prob/likelihood w/o an "
|
||||
"`action_distribution_fn` and a provided "
|
||||
"`action_sampler_fn`!")
|
||||
def _learn_on_batch_helper(self, samples):
|
||||
# Increase the tracing counter to make sure we don't re-trace too
|
||||
# often. If eager_tracing=True, this counter should only get
|
||||
# incremented during the @tf.function trace operations, never when
|
||||
# calling the already traced function after that.
|
||||
self._re_trace_counter += 1
|
||||
|
||||
seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
|
||||
input_batch = SampleBatch(
|
||||
{
|
||||
SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)
|
||||
},
|
||||
_is_training=False)
|
||||
if prev_action_batch is not None:
|
||||
input_batch[SampleBatch.PREV_ACTIONS] = \
|
||||
tf.convert_to_tensor(prev_action_batch)
|
||||
if prev_reward_batch is not None:
|
||||
input_batch[SampleBatch.PREV_REWARDS] = \
|
||||
tf.convert_to_tensor(prev_reward_batch)
|
||||
|
||||
# Exploration hook before each forward pass.
|
||||
self.exploration.before_compute_actions(explore=False)
|
||||
|
||||
# Action dist class and inputs are generated via custom function.
|
||||
if action_distribution_fn:
|
||||
dist_inputs, dist_class, _ = action_distribution_fn(
|
||||
self,
|
||||
self.model,
|
||||
input_batch,
|
||||
explore=False,
|
||||
is_training=False)
|
||||
# Default log-likelihood calculation.
|
||||
else:
|
||||
dist_inputs, _ = self.model(input_batch, state_batches,
|
||||
seq_lens)
|
||||
dist_class = self.dist_class
|
||||
|
||||
action_dist = dist_class(dist_inputs, self.model)
|
||||
|
||||
# Normalize actions if necessary.
|
||||
if not actions_normalized and self.config["normalize_actions"]:
|
||||
actions = normalize_action(actions, self.action_space_struct)
|
||||
|
||||
log_likelihoods = action_dist.logp(actions)
|
||||
|
||||
return log_likelihoods
|
||||
|
||||
@override(Policy)
|
||||
def apply_gradients(self, gradients):
|
||||
self._apply_gradients(
|
||||
zip([(tf.convert_to_tensor(g) if g is not None else None)
|
||||
for g in gradients], self.model.trainable_variables()))
|
||||
|
||||
@override(Policy)
|
||||
def get_exploration_state(self):
|
||||
return convert_to_numpy(self.exploration.get_state())
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self, as_dict=False):
|
||||
variables = self.variables()
|
||||
if as_dict:
|
||||
return {v.name: v.numpy() for v in variables}
|
||||
return [v.numpy() for v in variables]
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
variables = self.variables()
|
||||
assert len(weights) == len(variables), (len(weights),
|
||||
len(variables))
|
||||
for v, w in zip(variables, weights):
|
||||
v.assign(w)
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
state = super().get_state()
|
||||
if self._optimizer and \
|
||||
len(self._optimizer.variables()) > 0:
|
||||
state["_optimizer_variables"] = \
|
||||
self._optimizer.variables()
|
||||
# Add exploration state.
|
||||
state["_exploration_state"] = self.exploration.get_state()
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
state = state.copy() # shallow copy
|
||||
# Set optimizer vars first.
|
||||
optimizer_vars = state.get("_optimizer_variables", None)
|
||||
if optimizer_vars and self._optimizer.variables():
|
||||
logger.warning(
|
||||
"Cannot restore an optimizer's state for tf eager! Keras "
|
||||
"is not able to save the v1.x optimizers (from "
|
||||
"tf.compat.v1.train) since they aren't compatible with "
|
||||
"checkpoints.")
|
||||
for opt_var, value in zip(self._optimizer.variables(),
|
||||
optimizer_vars):
|
||||
opt_var.assign(value)
|
||||
# Set exploration's state.
|
||||
if hasattr(self, "exploration") and "_exploration_state" in state:
|
||||
self.exploration.set_state(state=state["_exploration_state"])
|
||||
# Then the Policy's (NN) weights.
|
||||
super().set_state(state)
|
||||
|
||||
def variables(self):
|
||||
"""Return the list of all savable variables for this policy."""
|
||||
if isinstance(self.model, tf.keras.Model):
|
||||
return self.model.variables
|
||||
else:
|
||||
return self.model.variables()
|
||||
|
||||
@override(Policy)
|
||||
def is_recurrent(self):
|
||||
return self._is_recurrent
|
||||
|
||||
@override(Policy)
|
||||
def num_state_tensors(self):
|
||||
return len(self._state_inputs)
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
if hasattr(self, "model"):
|
||||
return self.model.get_initial_state()
|
||||
return []
|
||||
|
||||
def get_session(self):
|
||||
return None # None implies eager
|
||||
|
||||
def get_placeholder(self, ph):
|
||||
raise ValueError(
|
||||
"get_placeholder() is not allowed in eager mode. Try using "
|
||||
"rllib.utils.tf_utils.make_tf_callable() to write "
|
||||
"functions that work in both graph and eager mode.")
|
||||
|
||||
def loss_initialized(self):
|
||||
return self._loss_initialized
|
||||
|
||||
@override(Policy)
|
||||
def export_model(self, export_dir):
|
||||
pass
|
||||
|
||||
@override(Policy)
|
||||
def export_checkpoint(self, export_dir):
|
||||
pass
|
||||
with tf.variable_creator_scope(_disallow_var_creation):
|
||||
grads_and_vars, _, stats = self._compute_gradients_helper(
|
||||
samples)
|
||||
self._apply_gradients_helper(grads_and_vars)
|
||||
return stats
|
||||
|
||||
def _get_is_training_placeholder(self):
|
||||
return tf.convert_to_tensor(self._is_training)
|
||||
|
||||
@with_lock
|
||||
def _compute_gradients(self, samples):
|
||||
def _compute_gradients_helper(self, samples):
|
||||
"""Computes and returns grads as eager tensors."""
|
||||
|
||||
# Increase the tracing counter to make sure we don't re-trace too
|
||||
# often. If eager_tracing=True, this counter should only get
|
||||
# incremented during the @tf.function trace operations, never when
|
||||
# calling the already traced function after that.
|
||||
self._re_trace_counter += 1
|
||||
|
||||
# Gather all variables for which to calculate losses.
|
||||
if isinstance(self.model, tf.keras.Model):
|
||||
variables = self.model.trainable_variables
|
||||
|
@ -823,9 +870,15 @@ def build_eager_tf_policy(
|
|||
grads = [g for g, _ in grads_and_vars]
|
||||
|
||||
stats = self._stats(self, samples, grads)
|
||||
return grads_and_vars, stats
|
||||
return grads_and_vars, grads, stats
|
||||
|
||||
def _apply_gradients_helper(self, grads_and_vars):
|
||||
# Increase the tracing counter to make sure we don't re-trace too
|
||||
# often. If eager_tracing=True, this counter should only get
|
||||
# incremented during the @tf.function trace operations, never when
|
||||
# calling the already traced function after that.
|
||||
self._re_trace_counter += 1
|
||||
|
||||
def _apply_gradients(self, grads_and_vars):
|
||||
if apply_gradients_fn:
|
||||
if self.config["_tf_policy_handles_more_than_one_loss"]:
|
||||
apply_gradients_fn(self, self._optimizers, grads_and_vars)
|
||||
|
|
|
@ -764,6 +764,12 @@ class Policy(metaclass=ABCMeta):
|
|||
TensorType]]]): An optional stats function to be called after
|
||||
the loss.
|
||||
"""
|
||||
# Signal Policy that currently we do not want to eager/jit trace
# any function calls. This is to be able to track which columns
|
||||
# in the dummy batch are accessed by the different function (e.g.
|
||||
# loss) such that we can then adjust our view requirements.
|
||||
self._no_tracing = True
|
||||
|
||||
sample_batch_size = max(self.batch_divisibility_req * 4, 32)
|
||||
self._dummy_batch = self._get_dummy_batch_from_view_requirements(
|
||||
sample_batch_size)
|
||||
|
@ -771,6 +777,9 @@ class Policy(metaclass=ABCMeta):
|
|||
actions, state_outs, extra_outs = \
|
||||
self.compute_actions_from_input_dict(
|
||||
self._dummy_batch, explore=False)
|
||||
for key, view_req in self.view_requirements.items():
|
||||
if key not in self._dummy_batch.accessed_keys:
|
||||
view_req.used_for_compute_actions = False
|
||||
# Add all extra action outputs to view requirements (these may be
|
||||
# filtered out later again, if not needed for postprocessing or loss).
|
||||
for key, value in extra_outs.items():
|
||||
|
@ -806,7 +815,7 @@ class Policy(metaclass=ABCMeta):
|
|||
# Switch on lazy to-tensor conversion on `postprocessed_batch`.
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
# Calling loss, so set `is_training` to True.
|
||||
train_batch.is_training = True
|
||||
train_batch.set_training(True)
|
||||
if seq_lens is not None:
|
||||
train_batch[SampleBatch.SEQ_LENS] = seq_lens
|
||||
train_batch.count = self._dummy_batch.count
|
||||
|
@ -817,6 +826,9 @@ class Policy(metaclass=ABCMeta):
|
|||
if stats_fn is not None:
|
||||
stats_fn(self, train_batch)
|
||||
|
||||
# Re-enable tracing.
|
||||
self._no_tracing = False
|
||||
|
||||
# Add new columns automatically to view-reqs.
|
||||
if auto_remove_unneeded_view_reqs:
|
||||
# Add those needed for postprocessing and training.
|
||||
|
|
|
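The new _no_tracing flag makes the dummy forward/loss passes run as plain (untraced) Python so that the SampleBatch can record which columns each function touches (accessed_keys, added_keys, deleted_keys) and the Policy can prune its view requirements accordingly. The bookkeeping idea itself is just a dict that records key access, roughly:

class TrackingDict(dict):
    """Toy version of SampleBatch's key-access tracking."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()

    def __getitem__(self, key):
        self.accessed_keys.add(key)
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        if key not in self:
            self.added_keys.add(key)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        self.deleted_keys.add(key)
        super().__delitem__(key)


batch = TrackingDict(obs=[1, 2], rewards=[0.0, 1.0])
_ = batch["obs"]
batch["advantages"] = [0.5, 0.5]
print(batch.accessed_keys, batch.added_keys)  # {'obs'} {'advantages'}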
@ -46,21 +46,22 @@ def pad_batch_to_sequences_of_same_size(
|
|||
Padding depends on episodes found in batch and `max_seq_len`.
|
||||
|
||||
Args:
|
||||
batch (SampleBatch): The SampleBatch object. All values in here have
|
||||
batch: The SampleBatch object. All values in here have
|
||||
the shape [B, ...].
|
||||
max_seq_len (int): The max. sequence length to use for chopping.
|
||||
shuffle (bool): Whether to shuffle batch sequences. Shuffle may
|
||||
max_seq_len: The max. sequence length to use for chopping.
|
||||
shuffle: Whether to shuffle batch sequences. Shuffle may
|
||||
be done in-place. This only makes sense if you're further
|
||||
applying minibatch SGD after getting the outputs.
|
||||
batch_divisibility_req (int): The int by which the batch dimension
|
||||
batch_divisibility_req: The int by which the batch dimension
|
||||
must be dividable.
|
||||
feature_keys (Optional[List[str]]): An optional list of keys to apply
|
||||
sequence-chopping to. If None, use all keys in batch that are not
|
||||
feature_keys: An optional list of keys to apply sequence-chopping
|
||||
to. If None, use all keys in batch that are not
|
||||
"state_in/out_"-type keys.
|
||||
view_requirements (Optional[ViewRequirementsDict]): An optional
|
||||
Policy ViewRequirements dict to be able to infer whether
|
||||
e.g. dynamic max'ing should be applied over the seq_lens.
|
||||
view_requirements: An optional Policy ViewRequirements dict to
|
||||
be able to infer whether e.g. dynamic max'ing should be
|
||||
applied over the seq_lens.
|
||||
"""
|
||||
# If already zero-padded, skip.
|
||||
if batch.zero_padded:
|
||||
return
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ class SampleBatch(dict):
|
|||
# Is already right-zero-padded?
|
||||
self.zero_padded = kwargs.pop("_zero_padded", False)
|
||||
# Whether this batch is used for training (vs inference).
|
||||
self.is_training = kwargs.pop("_is_training", None)
|
||||
self._is_training = kwargs.pop("_is_training", None)
|
||||
|
||||
# Call super constructor. This will make the actual data accessible
|
||||
# by column name (str) via e.g. self["some-col"].
|
||||
|
@ -118,8 +118,8 @@ class SampleBatch(dict):
|
|||
len(seq_lens_) > 0:
|
||||
self.max_seq_len = max(seq_lens_)
|
||||
|
||||
if self.is_training is None:
|
||||
self.is_training = self.pop("is_training", False)
|
||||
if self._is_training is None:
|
||||
self._is_training = self.pop("is_training", False)
|
||||
|
||||
lengths = []
|
||||
copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
|
||||
|
@ -271,6 +271,9 @@ class SampleBatch(dict):
|
|||
)
|
||||
copy_ = SampleBatch(data)
|
||||
copy_.set_get_interceptor(self.get_interceptor)
|
||||
copy_.added_keys = self.added_keys
|
||||
copy_.deleted_keys = self.deleted_keys
|
||||
copy_.accessed_keys = self.accessed_keys
|
||||
return copy_
|
||||
|
||||
@PublicAPI
|
||||
|
@ -501,6 +504,7 @@ class SampleBatch(dict):
|
|||
return SampleBatch(
|
||||
data,
|
||||
seq_lens=seq_lens,
|
||||
_is_training=self.is_training,
|
||||
_time_major=self.time_major,
|
||||
)
|
||||
else:
|
||||
|
@ -731,7 +735,7 @@ class SampleBatch(dict):
|
|||
old="SampleBatch['is_training']",
|
||||
new="SampleBatch.is_training",
|
||||
error=False)
|
||||
self.is_training = item
|
||||
self._is_training = item
|
||||
return
|
||||
|
||||
if key not in self:
|
||||
|
@ -741,6 +745,20 @@ class SampleBatch(dict):
|
|||
if key in self.intercepted_values:
|
||||
self.intercepted_values[key] = item
|
||||
|
||||
@property
|
||||
def is_training(self):
|
||||
if self.get_interceptor is not None and \
|
||||
isinstance(self._is_training, bool):
|
||||
if "_is_training" not in self.intercepted_values:
|
||||
self.intercepted_values["_is_training"] = \
|
||||
self.get_interceptor(self._is_training)
|
||||
return self.intercepted_values["_is_training"]
|
||||
return self._is_training
|
||||
|
||||
def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
|
||||
self._is_training = training
|
||||
self.intercepted_values.pop("_is_training", None)
|
||||
|
||||
@PublicAPI
|
||||
def __delitem__(self, key):
|
||||
self.deleted_keys.add(key)
|
||||
|
|
|
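Call sites throughout this diff switch from assigning SampleBatch.is_training directly to the new set_training() setter; is_training is now a read-only property that also runs the value through the batch's get-interceptor when one is set. A small before/after sketch:

from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"obs": [[0.0, 1.0]], "actions": [0]})

# Before this PR: batch.is_training = True
# After this PR: use the explicit setter ...
batch.set_training(True)
# ... and read the flag back through the property.
assert batch.is_training is True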
@ -403,7 +403,7 @@ class TFPolicy(Policy):
|
|||
assert self.loss_initialized()
|
||||
|
||||
# Switch on is_training flag in our batch.
|
||||
postprocessed_batch.is_training = True
|
||||
postprocessed_batch.set_training(True)
|
||||
|
||||
builder = TFRunBuilder(self.get_session(), "learn_on_batch")
|
||||
|
||||
|
@ -425,7 +425,7 @@ class TFPolicy(Policy):
|
|||
Tuple[ModelGradients, Dict[str, TensorType]]:
|
||||
assert self.loss_initialized()
|
||||
# Switch on is_training flag in our batch.
|
||||
postprocessed_batch.is_training = True
|
||||
postprocessed_batch.set_training(True)
|
||||
builder = TFRunBuilder(self.get_session(), "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
@ -1124,7 +1124,7 @@ class TFPolicy(Policy):
|
|||
|
||||
# Mark the batch as "is_training" so the Model can use this
|
||||
# information.
|
||||
train_batch.is_training = True
|
||||
train_batch.set_training(True)
|
||||
|
||||
# Build the feed dict from the batch.
|
||||
feed_dict = {}
|
||||
|
|
|
@ -256,7 +256,7 @@ class TorchPolicy(Policy):
|
|||
with torch.no_grad():
|
||||
# Pass lazy (torch) tensor dict to Model as `input_dict`.
|
||||
input_dict = self._lazy_tensor_dict(input_dict)
|
||||
input_dict.is_training = False
|
||||
input_dict.set_training(True)
|
||||
# Pack internal state inputs into (separate) list.
|
||||
state_batches = [
|
||||
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
|
||||
|
@ -424,7 +424,7 @@ class TorchPolicy(Policy):
|
|||
buffer_index: int = 0,
|
||||
) -> int:
|
||||
# Set the is_training flag of the batch.
|
||||
batch.is_training = True
|
||||
batch.set_training(True)
|
||||
|
||||
# Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
|
||||
if len(self.devices) == 1 and self.devices[0].type == "cpu":
|
||||
|
@ -572,7 +572,7 @@ class TorchPolicy(Policy):
|
|||
view_requirements=self.view_requirements,
|
||||
)
|
||||
|
||||
postprocessed_batch.is_training = True
|
||||
postprocessed_batch.set_training(True)
|
||||
self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
|
||||
|
||||
# Do the (maybe parallelized) gradient calculation step.
|
||||
|
|
|
@ -7,12 +7,13 @@ import random
|
|||
import re
|
||||
import time
|
||||
import tree # pip install dm_tree
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
|
||||
try_import_torch
|
||||
from ray.rllib.utils.typing import PartialTrainerConfigDict
|
||||
from ray.tune import run_experiments
|
||||
|
||||
jax, _ = try_import_jax()
|
||||
|
@ -29,32 +30,37 @@ torch, _ = try_import_torch()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def framework_iterator(config=None,
|
||||
frameworks=("tf2", "tf", "tfe", "torch"),
|
||||
session=False,
|
||||
with_eager_tracing=False):
|
||||
def framework_iterator(
|
||||
config: Optional[PartialTrainerConfigDict] = None,
|
||||
frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
|
||||
session: bool = False,
|
||||
with_eager_tracing: bool = False,
|
||||
time_iterations: Optional[dict] = None,
|
||||
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
|
||||
"""An generator that allows for looping through n frameworks for testing.
|
||||
|
||||
Provides the correct config entries ("framework") as well
|
||||
as the correct eager/non-eager contexts for tfe/tf.
|
||||
|
||||
Args:
|
||||
config (Optional[dict]): An optional config dict to alter in place
|
||||
depending on the iteration.
|
||||
frameworks (Tuple[str]): A list/tuple of the frameworks to be tested.
|
||||
config: An optional config dict to alter in place depending on the
|
||||
iteration.
|
||||
frameworks: A list/tuple of the frameworks to be tested.
|
||||
Allowed are: "tf2", "tf", "tfe", "torch", and None.
|
||||
session (bool): If True and only in the tf-case: Enter a tf.Session()
|
||||
session: If True and only in the tf-case: Enter a tf.Session()
|
||||
and yield that as the second return value (otherwise yield (fw, None)).
|
||||
Also sets a seed (42) on the session to make the test
|
||||
deterministic.
|
||||
with_eager_tracing: Include `eager_tracing=True` in the returned
|
||||
configs, when framework=[tfe|tf2].
|
||||
time_iterations: If provided, will write to the given dict (by
|
||||
framework key) the times in seconds that each (framework's)
|
||||
iteration takes.
|
||||
|
||||
Yields:
|
||||
str: If enter_session is False:
|
||||
The current framework ("tf2", "tf", "tfe", "torch") used.
|
||||
Tuple(str, Union[None,tf.Session]: If enter_session is True:
|
||||
A tuple of the current fw and the tf.Session if fw="tf".
|
||||
If `session` is False: The current framework [tf2|tf|tfe|torch] used.
|
||||
If `session` is True: A tuple consisting of the current framework
|
||||
string and the tf1.Session (if fw="tf", otherwise None).
|
||||
"""
|
||||
config = config or {}
|
||||
frameworks = [frameworks] if isinstance(frameworks, str) else \
|
||||
|
@ -111,12 +117,24 @@ def framework_iterator(config=None,
|
|||
for tracing in [True, False]:
|
||||
config["eager_tracing"] = tracing
|
||||
print(f"framework={fw} (eager-tracing={tracing})")
|
||||
time_started = time.time()
|
||||
yield fw if session is False else (fw, sess)
|
||||
if time_iterations is not None:
|
||||
time_total = time.time() - time_started
|
||||
time_iterations[fw + ("+tracing" if tracing else "")] = \
|
||||
time_total
|
||||
print(f".. took {time_total}sec")
|
||||
config["eager_tracing"] = False
|
||||
# Yield current framework + tf-session (if necessary).
|
||||
else:
|
||||
print(f"framework={fw}")
|
||||
time_started = time.time()
|
||||
yield fw if session is False else (fw, sess)
|
||||
if time_iterations is not None:
|
||||
time_total = time.time() - time_started
|
||||
time_iterations[fw] = time_total
|
||||
print(f".. took {time_total}sec")
|
||||
|
||||
# Exit any context we may have entered.
|
||||
if eager_ctx:
|
||||
|
|
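Putting the extended framework_iterator signature to use: time_iterations collects wall-clock seconds per framework, with a "+tracing" suffix on the key for the traced tf2/tfe passes. The trainer and config names below are illustrative; the import path for framework_iterator is assumed to be ray.rllib.utils.test_utils.

import ray
import ray.rllib.agents.pg as pg
from ray.rllib.utils.test_utils import framework_iterator

ray.init(ignore_reinit_error=True)

config = pg.DEFAULT_CONFIG.copy()
timings = {}

for fw in framework_iterator(
        config,
        frameworks=("tf2", "tf", "torch"),
        with_eager_tracing=True,
        time_iterations=timings):
    trainer = pg.PGTrainer(config=config, env="CartPole-v0")
    trainer.train()
    trainer.stop()

# e.g. {"tf2+tracing": ..., "tf2": ..., "tf": ..., "torch": ...}
print(timings)
ray.shutdown()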