From 03a1b758526b2699a21e44a932bb2abdfe636f2b Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 26 Aug 2019 23:23:35 -0700
Subject: [PATCH] [rllib] Fix some eager execution regressions with 1.13 (#5537)

* fix bugs with 1.13

* allow disable
---
 rllib/agents/trainer.py            | 3 +++
 rllib/evaluation/rollout_worker.py | 3 ++-
 rllib/models/tf/visionnet_v2.py    | 6 ++++--
 rllib/policy/eager_tf_policy.py    | 11 +++++++----
 4 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 7c224797e..2af10565f 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -72,6 +72,9 @@ COMMON_CONFIG = {
     "log_sys_usage": True,
     # Enable TF eager execution (TF policies only)
     "eager": False,
+    # Disable eager execution on workers (but allow it on the driver). This
+    # only has an effect if eager is enabled.
+    "no_eager_on_workers": False,

     # === Policy ===
     # Arguments to pass to model. See models/catalog.py for a full list of the
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 870a4309e..e66e20700 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -239,7 +239,8 @@ class RolloutWorker(EvaluatorInterface):
         _global_worker = self

         policy_config = policy_config or {}
-        if tf and policy_config.get("eager"):
+        if (tf and policy_config.get("eager")
+                and not policy_config.get("no_eager_on_workers")):
             tf.enable_eager_execution()

         if log_level:
diff --git a/rllib/models/tf/visionnet_v2.py b/rllib/models/tf/visionnet_v2.py
index 730cbab54..21f01241e 100644
--- a/rllib/models/tf/visionnet_v2.py
+++ b/rllib/models/tf/visionnet_v2.py
@@ -65,7 +65,8 @@ class VisionNetwork(TFModelV2):

         # Build the value layers
         if vf_share_layers:
-            last_layer = tf.squeeze(last_layer, axis=[1, 2])
+            last_layer = tf.keras.layers.Lambda(
+                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
             value_out = tf.keras.layers.Dense(
                 1,
                 name="value_out",
@@ -95,7 +96,8 @@ class VisionNetwork(TFModelV2):
                 activation=None,
                 padding="same",
                 name="conv_value_out")(last_layer)
-            value_out = tf.squeeze(last_layer, axis=[1, 2])
+            value_out = tf.keras.layers.Lambda(
+                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)

         self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
         self.register_variables(self.base_model.variables)
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py
index b69e3f0a8..105a95dee 100644
--- a/rllib/policy/eager_tf_policy.py
+++ b/rllib/policy/eager_tf_policy.py
@@ -204,13 +204,16 @@ def build_eager_tf_policy(name,

         @override(Policy)
         def get_weights(self):
-            return tf.nest.map_structure(lambda var: var.numpy(),
-                                         self.model.variables())
+            variables = self.model.variables()
+            return [v.numpy() for v in variables]

         @override(Policy)
         def set_weights(self, weights):
-            tf.nest.map_structure(lambda var, value: var.assign(value),
-                                  self.model.variables(), weights)
+            variables = self.model.variables()
+            assert len(weights) == len(variables), (len(weights),
+                                                    len(variables))
+            for v, w in zip(variables, weights):
+                v.assign(w)

         def is_recurrent(self):
             return len(self._state_in) > 0
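
The new "no_eager_on_workers" key lets the driver policy run eagerly while rollout workers stay in graph mode. A minimal usage sketch, assuming the PPO trainer and the CartPole-v0 Gym environment (both are illustrative choices, not part of this patch):

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            # Run the driver's TF policy in eager mode for easier debugging.
            "eager": True,
            # Keep the rollout workers in graph mode (only matters when
            # "eager" is True).
            "no_eager_on_workers": True,
        })
    print(trainer.train())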
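
The visionnet_v2.py change appears to work around a tf.keras restriction in TF 1.13: a bare TF op applied to a Keras symbolic tensor is not a layer output, so tf.keras.Model can refuse it, and the patch wraps the squeeze in a Lambda layer instead. A stripped-down sketch of that pattern; the input shape, filter sizes, and layer names here are made up for illustration:

    import tensorflow as tf

    inputs = tf.keras.layers.Input(shape=(42, 42, 4), name="observations")
    last_layer = tf.keras.layers.Conv2D(
        16, (42, 42), strides=(1, 1), padding="valid", name="conv_out")(inputs)

    # A bare tf.squeeze(last_layer, axis=[1, 2]) would produce a plain TF
    # tensor rather than a Keras layer output; wrapping the op in a Lambda
    # layer keeps the result usable as a tf.keras.Model output on TF 1.13.
    value_out = tf.keras.layers.Lambda(
        lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)

    model = tf.keras.Model(inputs, value_out)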
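
The eager_tf_policy.py change replaces tf.nest.map_structure with explicit list handling for the weight round trip. A self-contained sketch of the same get_weights/set_weights logic on a hand-built list of eager variables, which stand in for self.model.variables():

    import numpy as np
    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x only; eager is the default in TF 2.x

    variables = [tf.Variable(np.zeros((2, 3), dtype=np.float32)),
                 tf.Variable(np.ones(3, dtype=np.float32))]

    # get_weights: eager variables -> plain numpy arrays.
    weights = [v.numpy() for v in variables]

    # set_weights: numpy arrays -> variables, with a length sanity check as in
    # the patched method.
    assert len(weights) == len(variables), (len(weights), len(variables))
    for v, w in zip(variables, weights):
        v.assign(w)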