diff --git a/rllib/BUILD b/rllib/BUILD index 0335a24bb..eb65e8123 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -148,7 +148,7 @@ py_test( ) py_test( - name = "run_regression_tests_frozenlake_appo", + name = "learning_frozenlake_appo", main = "tests/run_regression_tests.py", tags = ["team:ml", "learning_tests", "learning_tests_discrete"], size = "large", diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 4c8a25924..7d2547a52 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -24,8 +24,8 @@ class TestA2C(unittest.TestCase): num_iterations = 1 # Test against all frameworks. - for _ in framework_iterator(config): - for env in ["PongDeterministic-v0"]: + for _ in framework_iterator(config, with_eager_tracing=True): + for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]: trainer = a3c.A2CTrainer(config=config, env=env) for i in range(num_iterations): results = trainer.train() diff --git a/rllib/agents/a3c/tests/test_a3c.py b/rllib/agents/a3c/tests/test_a3c.py index 5803c1ba1..e6b9f7a22 100644 --- a/rllib/agents/a3c/tests/test_a3c.py +++ b/rllib/agents/a3c/tests/test_a3c.py @@ -27,7 +27,7 @@ class TestA3C(unittest.TestCase): num_iterations = 1 # Test against all frameworks. - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): for env in ["CartPole-v1", "Pendulum-v1", "PongDeterministic-v0"]: print("env={}".format(env)) config["model"]["use_lstm"] = env == "CartPole-v1" diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py index 67b19bcb7..87c744ca7 100644 --- a/rllib/agents/cql/tests/test_cql.py +++ b/rllib/agents/cql/tests/test_cql.py @@ -66,7 +66,7 @@ class TestCQL(unittest.TestCase): num_iterations = 4 # Test for tf/torch frameworks. - for fw in framework_iterator(config): + for fw in framework_iterator(config, with_eager_tracing=True): trainer = cql.CQLTrainer(config=config) for i in range(num_iterations): results = trainer.train() diff --git a/rllib/agents/ddpg/tests/test_apex_ddpg.py b/rllib/agents/ddpg/tests/test_apex_ddpg.py index 53f44f3bf..f40799392 100644 --- a/rllib/agents/ddpg/tests/test_apex_ddpg.py +++ b/rllib/agents/ddpg/tests/test_apex_ddpg.py @@ -24,7 +24,7 @@ class TestApexDDPG(unittest.TestCase): config["learning_starts"] = 0 config["optimizer"]["num_replay_buffer_shards"] = 1 num_iterations = 1 - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): plain_config = config.copy() trainer = apex_ddpg.ApexDDPGTrainer( config=plain_config, env="Pendulum-v1") diff --git a/rllib/agents/ddpg/tests/test_ddpg.py b/rllib/agents/ddpg/tests/test_ddpg.py index fe2f0ac21..03025a29c 100644 --- a/rllib/agents/ddpg/tests/test_ddpg.py +++ b/rllib/agents/ddpg/tests/test_ddpg.py @@ -41,7 +41,7 @@ class TestDDPG(unittest.TestCase): num_iterations = 1 # Test against all frameworks. 
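Note: all of the test changes in this diff funnel through `framework_iterator` (see the rllib/utils/test_utils.py hunk at the end of the diff). A minimal sketch, with a hypothetical config, of how such a test loop consumes it:

from ray.rllib.utils.test_utils import framework_iterator

config = {"num_workers": 0}  # hypothetical minimal config
for fw in framework_iterator(config, with_eager_tracing=True):
    # `config["framework"]` is set in place by the iterator; for "tf2" and
    # "tfe" it yields twice, once with `config["eager_tracing"] = True` and
    # once with False, so the traced eager code paths get covered as well.
    print(fw, config.get("eager_tracing", False))
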
- for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1") for i in range(num_iterations): results = trainer.train() diff --git a/rllib/agents/dqn/tests/test_apex_dqn.py b/rllib/agents/dqn/tests/test_apex_dqn.py index 46d475608..1f4f3f749 100644 --- a/rllib/agents/dqn/tests/test_apex_dqn.py +++ b/rllib/agents/dqn/tests/test_apex_dqn.py @@ -44,7 +44,7 @@ class TestApexDQN(unittest.TestCase): config["min_iter_time_s"] = 1 config["optimizer"]["num_replay_buffer_shards"] = 1 - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): plain_config = config.copy() trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0") diff --git a/rllib/agents/dqn/tests/test_simple_q.py b/rllib/agents/dqn/tests/test_simple_q.py index 299bf39f6..4277f530f 100644 --- a/rllib/agents/dqn/tests/test_simple_q.py +++ b/rllib/agents/dqn/tests/test_simple_q.py @@ -34,7 +34,7 @@ class TestSimpleQ(unittest.TestCase): num_iterations = 2 - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0") rw = trainer.workers.local_worker() for i in range(num_iterations): diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py index 13e22240e..1af4ab95e 100644 --- a/rllib/agents/impala/tests/test_impala.py +++ b/rllib/agents/impala/tests/test_impala.py @@ -30,7 +30,7 @@ class TestIMPALA(unittest.TestCase): num_iterations = 1 env = "CartPole-v0" - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): local_cfg = config.copy() for lstm in [False, True]: local_cfg["num_aggregation_workers"] = 0 if not lstm else 1 diff --git a/rllib/agents/ppo/tests/test_appo.py b/rllib/agents/ppo/tests/test_appo.py index 9e2950568..3ae6ae033 100644 --- a/rllib/agents/ppo/tests/test_appo.py +++ b/rllib/agents/ppo/tests/test_appo.py @@ -24,7 +24,7 @@ class TestAPPO(unittest.TestCase): config["num_workers"] = 1 num_iterations = 2 - for _ in framework_iterator(config): + for _ in framework_iterator(config, with_eager_tracing=True): print("w/o v-trace") _config = config.copy() _config["vtrace"] = False diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index 9d9faa1b2..94242e4ab 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -106,7 +106,7 @@ class TestPPO(unittest.TestCase): config["compress_observations"] = True num_iterations = 2 - for fw in framework_iterator(config): + for fw in framework_iterator(config, with_eager_tracing=True): for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]: print("Env={}".format(env)) for lstm in [True, False]: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 81405f330..d3f668565 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -230,14 +230,23 @@ COMMON_CONFIG: TrainerConfigDict = { # === Deep Learning Framework Settings === # tf: TensorFlow (static-graph) - # tf2: TensorFlow 2.x (eager) - # tfe: TensorFlow eager + # tf2: TensorFlow 2.x (eager or traced, if eager_tracing=True) + # tfe: TensorFlow eager (or traced, if eager_tracing=True) # torch: PyTorch "framework": "tf", - # Enable tracing in eager mode. This greatly improves performance, but - # makes it slightly harder to debug since Python code won't be evaluated - # after the initial eager pass. 
Only possible if framework=tfe. + # Enable tracing in eager mode. This greatly improves performance + # (speedup ~2x), but makes it slightly harder to debug since Python + # code won't be evaluated after the initial eager pass. + # Only possible if framework=[tf2|tfe]. "eager_tracing": False, + # Maximum number of tf.function re-traces before a runtime error is raised. + # This is to prevent unnoticed retraces of methods inside the + # `..._eager_traced` Policy, which could slow down execution by a + # factor of 4, without the user noticing what the root cause for this + # slowdown could be. + # Only necessary for framework=[tf2|tfe]. + # Set to None to ignore the re-trace count and never throw an error. + "eager_max_retraces": 20, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into diff --git a/rllib/examples/env/random_env.py b/rllib/examples/env/random_env.py index ceeca2342..dabf2d483 100644 --- a/rllib/examples/env/random_env.py +++ b/rllib/examples/env/random_env.py @@ -10,8 +10,8 @@ class RandomEnv(gym.Env): Can be instantiated with arbitrary action-, observation-, and reward spaces. Observations and rewards are generated by simply sampling from the - observation/reward spaces. The probability of a `done=True` can be - configured as well. + observation/reward spaces. The probability of a `done=True` after each + action can be configured, as well as the max episode length. """ def __init__(self, config=None): @@ -26,8 +26,13 @@ class RandomEnv(gym.Env): "reward_space", gym.spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32)) # Chance that an episode ends at any step. + # Note that a max episode length can be specified via + # `max_episode_len`. self.p_done = config.get("p_done", 0.1) - # A max episode length. + # A max episode length. Even if the `p_done` sampling does not lead + # to a terminus, the episode will end after at most this many + # timesteps. + # Set to 0 or None for using no limit on the episode length. self.max_episode_len = config.get("max_episode_len", None) # Whether to check action bounds. self.check_action_bounds = config.get("check_action_bounds", False) @@ -49,11 +54,10 @@ class RandomEnv(gym.Env): self.steps += 1 done = False - # We are done as per our max-episode-len. - if self.max_episode_len is not None and \ - self.steps >= self.max_episode_len: + # We are `done` as per our max-episode-len. + if self.max_episode_len and self.steps >= self.max_episode_len: done = True - # Max not reached yet -> Sample done via p_done. + # Max episode length not reached yet -> Sample `done` via `p_done`. elif self.p_done > 0.0: done = bool( np.random.choice( diff --git a/rllib/examples/models/batch_norm_model.py b/rllib/examples/models/batch_norm_model.py index 70f4274f3..aec7fba6e 100644 --- a/rllib/examples/models/batch_norm_model.py +++ b/rllib/examples/models/batch_norm_model.py @@ -196,7 +196,8 @@ class TorchBatchNormModel(TorchModelV2, nn.Module): def forward(self, input_dict, state, seq_lens): # Set the correct train-mode for our hidden module (only important # b/c we have some batch-norm layers). 
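A hypothetical user-facing snippet for the two eager-tracing options documented in the rllib/agents/trainer.py hunk above (sketch only; the PPO/CartPole choices are placeholders):

import ray
import ray.rllib.agents.ppo as ppo

ray.init(ignore_reinit_error=True)
config = ppo.DEFAULT_CONFIG.copy()
config["framework"] = "tf2"
config["eager_tracing"] = True      # ~2x speedup over plain eager mode
config["eager_max_retraces"] = 20   # None: never raise on re-traces
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])
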
- self._hidden_layers.train(mode=input_dict.get("is_training", False)) + self._hidden_layers.train( + mode=bool(input_dict.get("is_training", False))) self._hidden_out = self._hidden_layers(input_dict["obs"]) logits = self._logits(self._hidden_out) return logits, [] diff --git a/rllib/execution/multi_gpu_impl.py b/rllib/execution/multi_gpu_impl.py index c648c8b32..eb74f0c55 100644 --- a/rllib/execution/multi_gpu_impl.py +++ b/rllib/execution/multi_gpu_impl.py @@ -1,5 +1,11 @@ from ray.rllib.policy.dynamic_tf_policy import TFMultiGPUTowerStack from ray.rllib.utils.deprecation import deprecation_warning -deprecation_warning("LocalSyncParallelOptimizer", "TFMultiGPUTowerStack") +# Backward compatibility. +deprecation_warning( + old="ray.rllib.execution.multi_gpu_impl.LocalSyncParallelOptimizer", + new="ray.rllib.policy.dynamic_tf_policy.TFMultiGPUTowerStack", + error=False, +) +# Old name. LocalSyncParallelOptimizer = TFMultiGPUTowerStack diff --git a/rllib/execution/multi_gpu_learner.py b/rllib/execution/multi_gpu_learner.py index c5cd0e496..b9101fbbe 100644 --- a/rllib/execution/multi_gpu_learner.py +++ b/rllib/execution/multi_gpu_learner.py @@ -2,6 +2,12 @@ from ray.rllib.execution.multi_gpu_learner_thread import \ MultiGPULearnerThread, _MultiGPULoaderThread from ray.rllib.utils.deprecation import deprecation_warning -deprecation_warning("multi_gpu_learner.py", "multi_gpu_learner_thread.py") +# Backward compatibility. +deprecation_warning( + old="ray.rllib.execution.multi_gpu_learner.py", + new="ray.rllib.execution.multi_gpu_learner_thread.py", + error=False, +) +# Old names. TFMultiGPULearner = MultiGPULearnerThread _LoaderThread = _MultiGPULoaderThread diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 971fc952c..4e72cd0ea 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -243,6 +243,11 @@ class ModelV2: with self.context(): res = self.forward(restored, state or [], seq_lens) + if isinstance(input_dict, SampleBatch): + input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"} + input_dict.deleted_keys = restored.deleted_keys + input_dict.added_keys = restored.added_keys - {"obs_flat"} + if ((not isinstance(res, list) and not isinstance(res, tuple)) or len(res) != 2): raise ValueError( diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index 3f5ac3044..5856e7112 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -250,7 +250,7 @@ class DynamicTFPolicy(TFPolicy): True, (), name="is_exploring") # Placeholder for `is_training` flag. - self._input_dict.is_training = self._get_is_training_placeholder() + self._input_dict.set_training(self._get_is_training_placeholder()) # Multi-GPU towers do not need any action computing/exploration # graphs. @@ -464,7 +464,7 @@ class DynamicTFPolicy(TFPolicy): buffer_index: int = 0, ) -> int: # Set the is_training flag of the batch. - batch.is_training = True + batch.set_training(True) # Shortcut for 1 CPU only: Store batch in # `self._loaded_single_cpu_batch`. 
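Standalone sketch of the two SampleBatch mechanisms the hunks above rely on (both also appear in the rllib/policy/sample_batch.py hunk further down): the accessed/added/deleted key tracking that `ModelV2.__call__` now copies back onto the caller's `input_dict` so that `Policy._initialize_loss_from_dummy_batch()` can prune unused view requirements, and the new `set_training()` setter that replaces direct `is_training` assignment:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "obs": np.zeros((2, 4), dtype=np.float32),
    "prev_rewards": np.zeros((2, ), dtype=np.float32),
})
_ = batch["obs"]                    # recorded in batch.accessed_keys
batch["advantages"] = np.ones(2)    # recorded in batch.added_keys
del batch["prev_rewards"]           # recorded in batch.deleted_keys
print(batch.accessed_keys, batch.added_keys, batch.deleted_keys)

batch.set_training(True)            # new setter
print(batch.is_training)            # now a (read-only) property -> True
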
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 562b0c1d3..bf262b81e 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -44,9 +44,11 @@ def _convert_to_tf(x, dtype=None): if x is not None: d = dtype - x = tf.nest.map_structure( + return tf.nest.map_structure( lambda f: _convert_to_tf(f, d) if isinstance(f, RepeatedValues) - else tf.convert_to_tensor(f, d) if f is not None else None, x) + else tf.convert_to_tensor(f, d) if + f is not None and not tf.is_tensor(f) else f, x) + return x @@ -101,32 +103,39 @@ def _disallow_var_creation(next_creator, **kw): "model initialization: {}".format(v.name)) -def traced_eager_policy(eager_policy_cls): - """Wrapper that enables tracing for all eager policy methods. +def check_too_many_retraces(obj): + """Asserts that a given number of re-traces is not breached.""" - This is enabled by the --trace / "eager_tracing" config.""" + def _func(self_, *args, **kwargs): + if self_.config.get("eager_max_retraces") is not None and \ + self_._re_trace_counter > self_.config["eager_max_retraces"]: + raise RuntimeError( + "Too many tf-eager re-traces detected! This could lead to" + " significant slow-downs (even slower than running in " + "tf-eager mode w/ `eager_tracing=False`). To switch off " + "these re-trace counting checks, set `eager_max_retraces`" + " in your config to None.") + return obj(self_, *args, **kwargs) + + return _func + + +def traced_eager_policy(eager_policy_cls): + """Wrapper class that enables tracing for all eager policy methods. + + This is enabled by the `--trace`/`eager_tracing=True` config when + framework=[tf2|tfe]. + """ class TracedEagerPolicy(eager_policy_cls): def __init__(self, *args, **kwargs): - self._traced_learn_on_batch = None - self._traced_compute_action_helper = False - self._traced_compute_gradients = None - self._traced_apply_gradients = None + self._traced_learn_on_batch_helper = False + self._traced_compute_actions_helper = False + self._traced_compute_gradients_helper = False + self._traced_apply_gradients_helper = False super(TracedEagerPolicy, self).__init__(*args, **kwargs) - @override(eager_policy_cls) - @convert_eager_inputs - @convert_eager_outputs - def _learn_on_batch_eager(self, samples): - - if self._traced_learn_on_batch is None: - self._traced_learn_on_batch = tf.function( - super(TracedEagerPolicy, self)._learn_on_batch_eager, - autograph=False, - experimental_relax_shapes=True) - - return self._traced_learn_on_batch(samples) - + @check_too_many_retraces @override(Policy) def compute_actions_from_input_dict( self, @@ -136,16 +145,20 @@ def traced_eager_policy(eager_policy_cls): episodes: Optional[List[Episode]] = None, **kwargs ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: + """Traced version of Policy.compute_actions_from_input_dict.""" - # Create a traced version of `self._compute_action_helper`. - if self._traced_compute_action_helper is False: - self._compute_action_helper = convert_eager_inputs( + # Create a traced version of `self._compute_actions_helper`. 
+ if self._traced_compute_actions_helper is False and \ + not self._no_tracing: + self._compute_actions_helper = convert_eager_inputs( tf.function( - super(TracedEagerPolicy, self)._compute_action_helper, + super(TracedEagerPolicy, self)._compute_actions_helper, autograph=False, experimental_relax_shapes=True)) - self._traced_compute_action_helper = True + self._traced_compute_actions_helper = True + # Now that the helper method is traced, call super's + # apply_gradients (which will call the traced helper). return super(TracedEagerPolicy, self).\ compute_actions_from_input_dict( input_dict=input_dict, @@ -155,55 +168,67 @@ def traced_eager_policy(eager_policy_cls): **kwargs, ) + @check_too_many_retraces @override(eager_policy_cls) - @convert_eager_outputs - def _compute_gradients_eager(self, samples: SampleBatch) -> \ - ModelGradients: - """Traced version of EagerTFPolicy's `_compute_gradients_eager`. + def learn_on_batch(self, samples): + """Traced version of Policy.learn_on_batch.""" - Note that `samples` is already zero-padded and has the is_training - flag set. We convert this sample batch into tensors here and do the - actual computation. - - Args: - samples: The SampleBatch to compute gradients for. - - Returns: - The computed model gradients. - """ - # Have we traced the `_compute_gradients_eager` function yet? - if self._traced_compute_gradients is None: - self._traced_compute_gradients = convert_eager_inputs( + # Create a traced version of `self._learn_on_batch_helper`. + if self._traced_learn_on_batch_helper is False and \ + not self._no_tracing: + self._learn_on_batch_helper = convert_eager_inputs( tf.function( - super()._compute_gradients_eager, + super(TracedEagerPolicy, self)._learn_on_batch_helper, autograph=False, experimental_relax_shapes=True)) + self._traced_learn_on_batch_helper = True - # Call the only-once compiled traced function with the SampleBatch - # (will be converted to tensors beforehand). - return self._traced_compute_gradients(samples) + # Now that the helper method is traced, call super's + # apply_gradients (which will call the traced helper). + return super(TracedEagerPolicy, self).learn_on_batch(samples) + @check_too_many_retraces + @override(eager_policy_cls) + def compute_gradients(self, samples: SampleBatch) -> \ + ModelGradients: + """Traced version of Policy.compute_gradients.""" + + # Create a traced version of `self._compute_gradients_helper`. + if self._traced_compute_gradients_helper is False and \ + not self._no_tracing: + self._compute_gradients_helper = convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, + self)._compute_gradients_helper, + autograph=False, + experimental_relax_shapes=True)) + self._traced_compute_gradients_helper = True + + # Now that the helper method is traced, call super's + # apply_gradients (which will call the traced helper). + return super(TracedEagerPolicy, self).compute_gradients(samples) + + @check_too_many_retraces @override(Policy) - @convert_eager_outputs def apply_gradients(self, grads: ModelGradients) -> None: - """Traced version of EagerTFPolicy's `apply_gradients` method. + """Traced version of Policy.apply_gradients.""" - Args: - grads: The ModelGradients to apply using our optimizer(s). - """ - # Have we traced the `apply_gradients` function yet? - if self._traced_apply_gradients is None: - self._traced_apply_gradients = tf.function( - super().apply_gradients, - autograph=False, - experimental_relax_shapes=True) + # Create a traced version of `self._apply_gradients_helper`. 
+ if self._traced_apply_gradients_helper is False and \ + not self._no_tracing: + self._apply_gradients_helper = convert_eager_inputs( + tf.function( + super(TracedEagerPolicy, self)._apply_gradients_helper, + autograph=False, + experimental_relax_shapes=True)) + self._traced_apply_gradients_helper = True - # Call the only-once compiled traced function with the grads - # (will be converted to tensors beforehand). - return self._traced_apply_gradients(grads) + # Now that the helper method is traced, call super's + # apply_gradients (which will call the traced helper). + return super(TracedEagerPolicy, self).apply_gradients(grads) - TracedEagerPolicy.__name__ = eager_policy_cls.__name__ - TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced" + TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced" return TracedEagerPolicy @@ -289,9 +314,20 @@ def build_eager_tf_policy( worker_idx if worker_idx > 0 else "local")) self._is_training = False - self._loss_initialized = False + # Only for `config.eager_tracing=True`: A counter to keep track of + # how many times an eager-traced method (e.g. + # `self._compute_actions_helper`) has been re-traced by tensorflow. + # We will raise an error if more than n re-tracings have been + # detected, since this would considerably slow down execution. + # The variable below should only get incremented during the + # tf.function trace operations, never when calling the already + # traced function after that. + self._re_trace_counter = 0 + + self._loss_initialized = False self._loss = loss_fn + self.batch_divisibility_req = get_batch_divisibility_req(self) if \ callable(get_batch_divisibility_req) else \ (get_batch_divisibility_req or 1) @@ -376,107 +412,6 @@ def build_eager_tf_policy( # Got to reset global_timestep again after fake run-throughs. self.global_timestep = 0 - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - assert tf.executing_eagerly() - # Call super's postprocess_trajectory first. - sample_batch = Policy.postprocess_trajectory(self, sample_batch) - if postprocess_fn: - return postprocess_fn(self, sample_batch, other_agent_batches, - episode) - return sample_batch - - @with_lock - @override(Policy) - def learn_on_batch(self, postprocessed_batch): - # Callback handling. 
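A toy, standalone sketch (not RLlib's actual classes) of the two mechanisms combined above: the helper is wrapped in `tf.function` once on first use, and it bumps a re-trace counter that only advances while TensorFlow is (re)tracing it; the calling method raises once the counter exceeds the limit (the new `eager_max_retraces` config key, default 20, None disables the check):

import tensorflow as tf

MAX_RETRACES = 20  # mirrors the new `eager_max_retraces` default


class Base:
    def _compute_helper(self, x):
        # Python side effects only run while tf traces (or re-traces) the
        # function, so once wrapped this effectively counts (re-)traces.
        self._re_trace_counter += 1
        return x * 2.0


class Traced(Base):
    def __init__(self):
        self._re_trace_counter = 0
        self._traced_helper = False

    def compute(self, x):
        if self._re_trace_counter > MAX_RETRACES:
            raise RuntimeError("Too many tf-eager re-traces detected!")
        if not self._traced_helper:
            # Wrap the parent's helper exactly once; later calls reuse the
            # already-traced function.
            self._compute_helper = tf.function(
                super()._compute_helper, experimental_relax_shapes=True)
            self._traced_helper = True
        return self._compute_helper(tf.convert_to_tensor(x))


t = Traced()
t.compute([1.0, 2.0])  # first call traces -> counter == 1
t.compute([3.0, 4.0])  # same input signature -> no re-trace
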
- learn_stats = {} - self.callbacks.on_learn_on_batch( - policy=self, - train_batch=postprocessed_batch, - result=learn_stats) - - if not isinstance(postprocessed_batch, SampleBatch) or \ - not postprocessed_batch.zero_padded: - pad_batch_to_sequences_of_same_size( - postprocessed_batch, - max_seq_len=self._max_seq_len, - shuffle=False, - batch_divisibility_req=self.batch_divisibility_req, - view_requirements=self.view_requirements, - ) - - self._is_training = True - postprocessed_batch.is_training = True - stats = self._learn_on_batch_eager(postprocessed_batch) - stats.update({"custom_metrics": learn_stats}) - return stats - - @convert_eager_inputs - @convert_eager_outputs - def _learn_on_batch_eager(self, samples): - with tf.variable_creator_scope(_disallow_var_creation): - grads_and_vars, stats = self._compute_gradients(samples) - self._apply_gradients(grads_and_vars) - return stats - - @override(Policy) - def compute_gradients(self, samples): - pad_batch_to_sequences_of_same_size( - samples, - shuffle=False, - max_seq_len=self._max_seq_len, - batch_divisibility_req=self.batch_divisibility_req, - view_requirements=self.view_requirements, - ) - - self._is_training = True - samples.is_training = True - return self._compute_gradients_eager(samples) - - @convert_eager_inputs - @convert_eager_outputs - def _compute_gradients_eager(self, samples): - with tf.variable_creator_scope(_disallow_var_creation): - grads_and_vars, stats = self._compute_gradients(samples) - grads = [g for g, v in grads_and_vars] - return grads, stats - - @override(Policy) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - explore=None, - timestep=None, - **kwargs): - - # Create input dict to simply pass the entire call to - # self.compute_actions_from_input_dict(). - input_dict = SampleBatch( - { - SampleBatch.CUR_OBS: obs_batch, - }, - _is_training=tf.constant(False)) - if prev_action_batch is not None: - input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch - if prev_reward_batch is not None: - input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch - - return self.compute_actions_from_input_dict( - input_dict=input_dict, - explore=explore, - timestep=timestep, - episodes=episodes, - **kwargs, - ) - @override(Policy) def compute_actions_from_input_dict( self, @@ -502,7 +437,8 @@ def build_eager_tf_policy( # Pass lazy (eager) tensor dict to Model as `input_dict`. input_dict = self._lazy_tensor_dict(input_dict) - input_dict.is_training = False + input_dict.set_training(False) + # Pack internal state inputs into (separate) list. state_batches = [ input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] @@ -514,7 +450,7 @@ def build_eager_tf_policy( self.exploration.before_compute_actions( timestep=timestep, explore=explore, tf_sess=self.get_session()) - ret = self._compute_action_helper( + ret = self._compute_actions_helper( input_dict, state_batches, # TODO: Passing episodes into a traced method does not work. @@ -525,9 +461,254 @@ def build_eager_tf_policy( self.global_timestep += int(tree.flatten(ret[0])[0].shape[0]) return convert_to_numpy(ret) + @override(Policy) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + explore=None, + timestep=None, + **kwargs): + + # Create input dict to simply pass the entire call to + # self.compute_actions_from_input_dict(). 
+ input_dict = SampleBatch( + { + SampleBatch.CUR_OBS: obs_batch, + }, + _is_training=tf.constant(False)) + if state_batches is not None: + for i, s in enumerate(state_batches): + input_dict[f"state_in_{i}"] = s + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch + if info_batch is not None: + input_dict[SampleBatch.INFOS] = info_batch + + return self.compute_actions_from_input_dict( + input_dict=input_dict, + explore=explore, + timestep=timestep, + episodes=episodes, + **kwargs, + ) + @with_lock - def _compute_action_helper(self, input_dict, state_batches, episodes, - explore, timestep): + @override(Policy) + def compute_log_likelihoods( + self, + actions, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + actions_normalized=True, + ): + if action_sampler_fn and action_distribution_fn is None: + raise ValueError("Cannot compute log-prob/likelihood w/o an " + "`action_distribution_fn` and a provided " + "`action_sampler_fn`!") + + seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) + input_batch = SampleBatch( + { + SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch) + }, + _is_training=False) + if prev_action_batch is not None: + input_batch[SampleBatch.PREV_ACTIONS] = \ + tf.convert_to_tensor(prev_action_batch) + if prev_reward_batch is not None: + input_batch[SampleBatch.PREV_REWARDS] = \ + tf.convert_to_tensor(prev_reward_batch) + + # Exploration hook before each forward pass. + self.exploration.before_compute_actions(explore=False) + + # Action dist class and inputs are generated via custom function. + if action_distribution_fn: + dist_inputs, dist_class, _ = action_distribution_fn( + self, + self.model, + input_batch, + explore=False, + is_training=False) + # Default log-likelihood calculation. + else: + dist_inputs, _ = self.model(input_batch, state_batches, + seq_lens) + dist_class = self.dist_class + + action_dist = dist_class(dist_inputs, self.model) + + # Normalize actions if necessary. + if not actions_normalized and self.config["normalize_actions"]: + actions = normalize_action(actions, self.action_space_struct) + + log_likelihoods = action_dist.logp(actions) + + return log_likelihoods + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + assert tf.executing_eagerly() + # Call super's postprocess_trajectory first. + sample_batch = Policy.postprocess_trajectory(self, sample_batch) + if postprocess_fn: + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + return sample_batch + + @with_lock + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + # Callback handling.
+ learn_stats = {} + self.callbacks.on_learn_on_batch( + policy=self, + train_batch=postprocessed_batch, + result=learn_stats) + + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + max_seq_len=self._max_seq_len, + shuffle=False, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + stats = self._learn_on_batch_helper(postprocessed_batch) + stats.update({"custom_metrics": learn_stats}) + return convert_to_numpy(stats) + + @override(Policy) + def compute_gradients(self, postprocessed_batch: SampleBatch) -> \ + Tuple[ModelGradients, Dict[str, TensorType]]: + + pad_batch_to_sequences_of_same_size( + postprocessed_batch, + shuffle=False, + max_seq_len=self._max_seq_len, + batch_divisibility_req=self.batch_divisibility_req, + view_requirements=self.view_requirements, + ) + + self._is_training = True + self._lazy_tensor_dict(postprocessed_batch) + postprocessed_batch.set_training(True) + grads_and_vars, grads, stats = self._compute_gradients_helper( + postprocessed_batch) + return convert_to_numpy((grads, stats)) + + @override(Policy) + def apply_gradients(self, gradients: ModelGradients) -> None: + self._apply_gradients_helper( + list( + zip([(tf.convert_to_tensor(g) + if g is not None else None) for g in gradients], + self.model.trainable_variables()))) + + @override(Policy) + def get_weights(self, as_dict=False): + variables = self.variables() + if as_dict: + return {v.name: v.numpy() for v in variables} + return [v.numpy() for v in variables] + + @override(Policy) + def set_weights(self, weights): + variables = self.variables() + assert len(weights) == len(variables), (len(weights), + len(variables)) + for v, w in zip(variables, weights): + v.assign(w) + + @override(Policy) + def get_exploration_state(self): + return convert_to_numpy(self.exploration.get_state()) + + @override(Policy) + def is_recurrent(self): + return self._is_recurrent + + @override(Policy) + def num_state_tensors(self): + return len(self._state_inputs) + + @override(Policy) + def get_initial_state(self): + if hasattr(self, "model"): + return self.model.get_initial_state() + return [] + + @override(Policy) + def get_state(self): + state = super().get_state() + if self._optimizer and \ + len(self._optimizer.variables()) > 0: + state["_optimizer_variables"] = \ + self._optimizer.variables() + # Add exploration state. + state["_exploration_state"] = self.exploration.get_state() + return state + + @override(Policy) + def set_state(self, state): + state = state.copy() # shallow copy + # Set optimizer vars first. + optimizer_vars = state.get("_optimizer_variables", None) + if optimizer_vars and self._optimizer.variables(): + logger.warning( + "Cannot restore an optimizer's state for tf eager! Keras " + "is not able to save the v1.x optimizers (from " + "tf.compat.v1.train) since they aren't compatible with " + "checkpoints.") + for opt_var, value in zip(self._optimizer.variables(), + optimizer_vars): + opt_var.assign(value) + # Set exploration's state. + if hasattr(self, "exploration") and "_exploration_state" in state: + self.exploration.set_state(state=state["_exploration_state"]) + # Then the Policy's (NN) weights. 
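A usage sketch of the public gradient API kept by the overrides above (assuming `policy` is a built tf2/tfe eager policy and `batch` is a training SampleBatch carrying the columns its loss needs). The traced helpers now return `(grads_and_vars, grads, stats)` internally, while the public methods keep returning numpy structures:

grads, stats = policy.compute_gradients(batch)  # numpy grads + stats dict
policy.apply_gradients(grads)                   # re-tensorized via tf.convert_to_tensor
# Or in one step (also runs the on_learn_on_batch callback and returns
# stats including "custom_metrics"):
learn_stats = policy.learn_on_batch(batch)
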
+ super().set_state(state) + + @override(Policy) + def export_checkpoint(self, export_dir): + raise NotImplementedError # TODO: implement this + + @override(Policy) + def export_model(self, export_dir): + raise NotImplementedError # TODO: implement this + + def variables(self): + """Return the list of all savable variables for this policy.""" + if isinstance(self.model, tf.keras.Model): + return self.model.variables + else: + return self.model.variables() + + def loss_initialized(self): + return self._loss_initialized + + @with_lock + def _compute_actions_helper(self, input_dict, state_batches, episodes, + explore, timestep): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 # Calculate RNN sequence lengths. batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0] @@ -612,166 +793,32 @@ def build_eager_tf_policy( return actions, state_out, extra_fetches - @with_lock - @override(Policy) - def compute_log_likelihoods( - self, - actions, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - actions_normalized=True, - ): - if action_sampler_fn and action_distribution_fn is None: - raise ValueError("Cannot compute log-prob/likelihood w/o an " - "`action_distribution_fn` and a provided " - "`action_sampler_fn`!") + def _learn_on_batch_helper(self, samples): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 - seq_lens = tf.ones(len(obs_batch), dtype=tf.int32) - input_batch = SampleBatch( - { - SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch) - }, - _is_training=False) - if prev_action_batch is not None: - input_batch[SampleBatch.PREV_ACTIONS] = \ - tf.convert_to_tensor(prev_action_batch) - if prev_reward_batch is not None: - input_batch[SampleBatch.PREV_REWARDS] = \ - tf.convert_to_tensor(prev_reward_batch) - - # Exploration hook before each forward pass. - self.exploration.before_compute_actions(explore=False) - - # Action dist class and inputs are generated via custom function. - if action_distribution_fn: - dist_inputs, dist_class, _ = action_distribution_fn( - self, - self.model, - input_batch, - explore=False, - is_training=False) - # Default log-likelihood calculation. - else: - dist_inputs, _ = self.model(input_batch, state_batches, - seq_lens) - dist_class = self.dist_class - - action_dist = dist_class(dist_inputs, self.model) - - # Normalize actions if necessary. 
- if not actions_normalized and self.config["normalize_actions"]: - actions = normalize_action(actions, self.action_space_struct) - - log_likelihoods = action_dist.logp(actions) - - return log_likelihoods - - @override(Policy) - def apply_gradients(self, gradients): - self._apply_gradients( - zip([(tf.convert_to_tensor(g) if g is not None else None) - for g in gradients], self.model.trainable_variables())) - - @override(Policy) - def get_exploration_state(self): - return convert_to_numpy(self.exploration.get_state()) - - @override(Policy) - def get_weights(self, as_dict=False): - variables = self.variables() - if as_dict: - return {v.name: v.numpy() for v in variables} - return [v.numpy() for v in variables] - - @override(Policy) - def set_weights(self, weights): - variables = self.variables() - assert len(weights) == len(variables), (len(weights), - len(variables)) - for v, w in zip(variables, weights): - v.assign(w) - - @override(Policy) - def get_state(self): - state = super().get_state() - if self._optimizer and \ - len(self._optimizer.variables()) > 0: - state["_optimizer_variables"] = \ - self._optimizer.variables() - # Add exploration state. - state["_exploration_state"] = self.exploration.get_state() - return state - - @override(Policy) - def set_state(self, state): - state = state.copy() # shallow copy - # Set optimizer vars first. - optimizer_vars = state.get("_optimizer_variables", None) - if optimizer_vars and self._optimizer.variables(): - logger.warning( - "Cannot restore an optimizer's state for tf eager! Keras " - "is not able to save the v1.x optimizers (from " - "tf.compat.v1.train) since they aren't compatible with " - "checkpoints.") - for opt_var, value in zip(self._optimizer.variables(), - optimizer_vars): - opt_var.assign(value) - # Set exploration's state. - if hasattr(self, "exploration") and "_exploration_state" in state: - self.exploration.set_state(state=state["_exploration_state"]) - # Then the Policy's (NN) weights. - super().set_state(state) - - def variables(self): - """Return the list of all savable variables for this policy.""" - if isinstance(self.model, tf.keras.Model): - return self.model.variables - else: - return self.model.variables() - - @override(Policy) - def is_recurrent(self): - return self._is_recurrent - - @override(Policy) - def num_state_tensors(self): - return len(self._state_inputs) - - @override(Policy) - def get_initial_state(self): - if hasattr(self, "model"): - return self.model.get_initial_state() - return [] - - def get_session(self): - return None # None implies eager - - def get_placeholder(self, ph): - raise ValueError( - "get_placeholder() is not allowed in eager mode. Try using " - "rllib.utils.tf_utils.make_tf_callable() to write " - "functions that work in both graph and eager mode.") - - def loss_initialized(self): - return self._loss_initialized - - @override(Policy) - def export_model(self, export_dir): - pass - - @override(Policy) - def export_checkpoint(self, export_dir): - pass + with tf.variable_creator_scope(_disallow_var_creation): + grads_and_vars, _, stats = self._compute_gradients_helper( + samples) + self._apply_gradients_helper(grads_and_vars) + return stats def _get_is_training_placeholder(self): return tf.convert_to_tensor(self._is_training) @with_lock - def _compute_gradients(self, samples): + def _compute_gradients_helper(self, samples): """Computes and returns grads as eager tensors.""" + # Increase the tracing counter to make sure we don't re-trace too + # often. 
If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 + # Gather all variables for which to calculate losses. if isinstance(self.model, tf.keras.Model): variables = self.model.trainable_variables @@ -823,9 +870,15 @@ def build_eager_tf_policy( grads = [g for g, _ in grads_and_vars] stats = self._stats(self, samples, grads) - return grads_and_vars, stats + return grads_and_vars, grads, stats + + def _apply_gradients_helper(self, grads_and_vars): + # Increase the tracing counter to make sure we don't re-trace too + # often. If eager_tracing=True, this counter should only get + # incremented during the @tf.function trace operations, never when + # calling the already traced function after that. + self._re_trace_counter += 1 - def _apply_gradients(self, grads_and_vars): if apply_gradients_fn: if self.config["_tf_policy_handles_more_than_one_loss"]: apply_gradients_fn(self, self._optimizers, grads_and_vars) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index e96c5c9f3..7069a2401 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -764,6 +764,12 @@ class Policy(metaclass=ABCMeta): TensorType]]]): An optional stats function to be called after the loss. """ + # Signal Policy that currently we do not like to eager/jit trace + # any function calls. This is to be able to track, which columns + # in the dummy batch are accessed by the different function (e.g. + # loss) such that we can then adjust our view requirements. + self._no_tracing = True + sample_batch_size = max(self.batch_divisibility_req * 4, 32) self._dummy_batch = self._get_dummy_batch_from_view_requirements( sample_batch_size) @@ -771,6 +777,9 @@ class Policy(metaclass=ABCMeta): actions, state_outs, extra_outs = \ self.compute_actions_from_input_dict( self._dummy_batch, explore=False) + for key, view_req in self.view_requirements.items(): + if key not in self._dummy_batch.accessed_keys: + view_req.used_for_compute_actions = False # Add all extra action outputs to view reqirements (these may be # filtered out later again, if not needed for postprocessing or loss). for key, value in extra_outs.items(): @@ -806,7 +815,7 @@ class Policy(metaclass=ABCMeta): # Switch on lazy to-tensor conversion on `postprocessed_batch`. train_batch = self._lazy_tensor_dict(postprocessed_batch) # Calling loss, so set `is_training` to True. - train_batch.is_training = True + train_batch.set_training(True) if seq_lens is not None: train_batch[SampleBatch.SEQ_LENS] = seq_lens train_batch.count = self._dummy_batch.count @@ -817,6 +826,9 @@ class Policy(metaclass=ABCMeta): if stats_fn is not None: stats_fn(self, train_batch) + # Re-enable tracing. + self._no_tracing = False + # Add new columns automatically to view-reqs. if auto_remove_unneeded_view_reqs: # Add those needed for postprocessing and training. diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 1e89151ab..61f702e51 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -46,21 +46,22 @@ def pad_batch_to_sequences_of_same_size( Padding depends on episodes found in batch and `max_seq_len`. Args: - batch (SampleBatch): The SampleBatch object. All values in here have + batch: The SampleBatch object. All values in here have the shape [B, ...]. - max_seq_len (int): The max. sequence length to use for chopping. - shuffle (bool): Whether to shuffle batch sequences. 
Shuffle may + max_seq_len: The max. sequence length to use for chopping. + shuffle: Whether to shuffle batch sequences. Shuffle may be done in-place. This only makes sense if you're further applying minibatch SGD after getting the outputs. - batch_divisibility_req (int): The int by which the batch dimension + batch_divisibility_req: The int by which the batch dimension must be dividable. - feature_keys (Optional[List[str]]): An optional list of keys to apply - sequence-chopping to. If None, use all keys in batch that are not + feature_keys: An optional list of keys to apply sequence-chopping + to. If None, use all keys in batch that are not "state_in/out_"-type keys. - view_requirements (Optional[ViewRequirementsDict]): An optional - Policy ViewRequirements dict to be able to infer whether - e.g. dynamic max'ing should be applied over the seq_lens. + view_requirements: An optional Policy ViewRequirements dict to + be able to infer whether e.g. dynamic max'ing should be + applied over the seq_lens. """ + # If already zero-padded, skip. if batch.zero_padded: return diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index 07b108cc5..ef977b0d6 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -91,7 +91,7 @@ class SampleBatch(dict): # Is alredy right-zero-padded? self.zero_padded = kwargs.pop("_zero_padded", False) # Whether this batch is used for training (vs inference). - self.is_training = kwargs.pop("_is_training", None) + self._is_training = kwargs.pop("_is_training", None) # Call super constructor. This will make the actual data accessible # by column name (str) via e.g. self["some-col"]. @@ -118,8 +118,8 @@ class SampleBatch(dict): len(seq_lens_) > 0: self.max_seq_len = max(seq_lens_) - if self.is_training is None: - self.is_training = self.pop("is_training", False) + if self._is_training is None: + self._is_training = self.pop("is_training", False) lengths = [] copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS} @@ -271,6 +271,9 @@ class SampleBatch(dict): ) copy_ = SampleBatch(data) copy_.set_get_interceptor(self.get_interceptor) + copy_.added_keys = self.added_keys + copy_.deleted_keys = self.deleted_keys + copy_.accessed_keys = self.accessed_keys return copy_ @PublicAPI @@ -501,6 +504,7 @@ class SampleBatch(dict): return SampleBatch( data, seq_lens=seq_lens, + _is_training=self.is_training, _time_major=self.time_major, ) else: @@ -731,7 +735,7 @@ class SampleBatch(dict): old="SampleBatch['is_training']", new="SampleBatch.is_training", error=False) - self.is_training = item + self._is_training = item return if key not in self: @@ -741,6 +745,20 @@ class SampleBatch(dict): if key in self.intercepted_values: self.intercepted_values[key] = item + @property + def is_training(self): + if self.get_interceptor is not None and \ + isinstance(self._is_training, bool): + if "_is_training" not in self.intercepted_values: + self.intercepted_values["_is_training"] = \ + self.get_interceptor(self._is_training) + return self.intercepted_values["_is_training"] + return self._is_training + + def set_training(self, training: Union[bool, "tf1.placeholder"] = True): + self._is_training = training + self.intercepted_values.pop("_is_training", None) + @PublicAPI def __delitem__(self, key): self.deleted_keys.add(key) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 4c7f12bc1..efea97a1a 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -403,7 +403,7 @@ class TFPolicy(Policy): assert 
self.loss_initialized() # Switch on is_training flag in our batch. - postprocessed_batch.is_training = True + postprocessed_batch.set_training(True) builder = TFRunBuilder(self.get_session(), "learn_on_batch") @@ -425,7 +425,7 @@ class TFPolicy(Policy): Tuple[ModelGradients, Dict[str, TensorType]]: assert self.loss_initialized() # Switch on is_training flag in our batch. - postprocessed_batch.is_training = True + postprocessed_batch.set_training(True) builder = TFRunBuilder(self.get_session(), "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @@ -1124,7 +1124,7 @@ class TFPolicy(Policy): # Mark the batch as "is_training" so the Model can use this # information. - train_batch.is_training = True + train_batch.set_training(True) # Build the feed dict from the batch. feed_dict = {} diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 618069997..6fa6ab4fb 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -256,7 +256,7 @@ class TorchPolicy(Policy): with torch.no_grad(): # Pass lazy (torch) tensor dict to Model as `input_dict`. input_dict = self._lazy_tensor_dict(input_dict) - input_dict.is_training = False + input_dict.set_training(True) # Pack internal state inputs into (separate) list. state_batches = [ input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] @@ -424,7 +424,7 @@ class TorchPolicy(Policy): buffer_index: int = 0, ) -> int: # Set the is_training flag of the batch. - batch.is_training = True + batch.set_training(True) # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`. if len(self.devices) == 1 and self.devices[0].type == "cpu": @@ -572,7 +572,7 @@ class TorchPolicy(Policy): view_requirements=self.view_requirements, ) - postprocessed_batch.is_training = True + postprocessed_batch.set_training(True) self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0]) # Do the (maybe parallelized) gradient calculation step. diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index f440da06d..3ee6a8737 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -7,12 +7,13 @@ import random import re import time import tree # pip install dm_tree -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import yaml import ray from ray.rllib.utils.framework import try_import_jax, try_import_tf, \ try_import_torch +from ray.rllib.utils.typing import PartialTrainerConfigDict from ray.tune import run_experiments jax, _ = try_import_jax() @@ -29,32 +30,37 @@ torch, _ = try_import_torch() logger = logging.getLogger(__name__) -def framework_iterator(config=None, - frameworks=("tf2", "tf", "tfe", "torch"), - session=False, - with_eager_tracing=False): +def framework_iterator( + config: Optional[PartialTrainerConfigDict] = None, + frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"), + session: bool = False, + with_eager_tracing: bool = False, + time_iterations: Optional[dict] = None, +) -> Union[str, Tuple[str, Optional["tf1.Session"]]]: """An generator that allows for looping through n frameworks for testing. Provides the correct config entries ("framework") as well as the correct eager/non-eager contexts for tfe/tf. Args: - config (Optional[dict]): An optional config dict to alter in place - depending on the iteration. - frameworks (Tuple[str]): A list/tuple of the frameworks to be tested. 
+ config: An optional config dict to alter in place depending on the + iteration. + frameworks: A list/tuple of the frameworks to be tested. Allowed are: "tf2", "tf", "tfe", "torch", and None. - session (bool): If True and only in the tf-case: Enter a tf.Session() + session: If True and only in the tf-case: Enter a tf.Session() and yield that as second return value (otherwise yield (fw, None)). Also sets a seed (42) on the session to make the test deterministic. with_eager_tracing: Include `eager_tracing=True` in the returned configs, when framework=[tfe|tf2]. + time_iterations: If provided, will write to the given dict (by + framework key) the times in seconds that each (framework's) + iteration takes. Yields: - str: If enter_session is False: - The current framework ("tf2", "tf", "tfe", "torch") used. - Tuple(str, Union[None,tf.Session]: If enter_session is True: - A tuple of the current fw and the tf.Session if fw="tf". + If `session` is False: The current framework [tf2|tf|tfe|torch] used. + If `session` is True: A tuple consisting of the current framework + string and the tf1.Session (if fw="tf", otherwise None). """ config = config or {} frameworks = [frameworks] if isinstance(frameworks, str) else \ @@ -111,12 +117,24 @@ def framework_iterator(config=None, for tracing in [True, False]: config["eager_tracing"] = tracing print(f"framework={fw} (eager-tracing={tracing})") + time_started = time.time() yield fw if session is False else (fw, sess) + if time_iterations is not None: + time_total = time.time() - time_started + time_iterations[fw + ("+tracing" if tracing else "")] = \ + time_total + print(f".. took {time_total}sec") config["eager_tracing"] = False # Yield current framework + tf-session (if necessary). else: print(f"framework={fw}") + time_started = time.time() yield fw if session is False else (fw, sess) + if time_iterations is not None: + time_total = time.time() - time_started + time_iterations[fw + ("+tracing" if tracing else "")] = \ + time_total + print(f".. took {time_total}sec") # Exit any context we may have entered. if eager_ctx:
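A usage sketch of the new `time_iterations` argument (hypothetical config and framework selection; the listed frameworks must be installed): pass in a dict and the iterator records per-framework wall-clock times, with separate "+tracing" entries for the traced tf2/tfe passes when `with_eager_tracing=True`.

from ray.rllib.utils.test_utils import framework_iterator

timings = {}
config = {}
for fw in framework_iterator(
        config, frameworks=("tf2", "torch"), with_eager_tracing=True,
        time_iterations=timings):
    pass  # ... run one short training iteration per framework here ...
# e.g. {"tf2+tracing": 12.3, "tf2": 15.1, "torch": 10.8}
print(timings)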