[rllib] Fix some eager execution regressions with 1.13 (#5537)

* fix bugs with 1.13

* allow disable
This commit is contained in:
Eric Liang 2019-08-26 23:23:35 -07:00 committed by GitHub
parent 948b1b09e8
commit 03a1b75852
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 7 deletions

View file

@ -72,6 +72,9 @@ COMMON_CONFIG = {
"log_sys_usage": True, "log_sys_usage": True,
# Enable TF eager execution (TF policies only) # Enable TF eager execution (TF policies only)
"eager": False, "eager": False,
# Disable eager execution on workers (but allow it on the driver). This
# only has an effect if eager is enabled.
"no_eager_on_workers": False,
# === Policy === # === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the # Arguments to pass to model. See models/catalog.py for a full list of the

View file

@ -239,7 +239,8 @@ class RolloutWorker(EvaluatorInterface):
_global_worker = self _global_worker = self
policy_config = policy_config or {} policy_config = policy_config or {}
if tf and policy_config.get("eager"): if (tf and policy_config.get("eager")
and not policy_config.get("no_eager_on_workers")):
tf.enable_eager_execution() tf.enable_eager_execution()
if log_level: if log_level:

View file

@ -65,7 +65,8 @@ class VisionNetwork(TFModelV2):
# Build the value layers # Build the value layers
if vf_share_layers: if vf_share_layers:
last_layer = tf.squeeze(last_layer, axis=[1, 2]) last_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
value_out = tf.keras.layers.Dense( value_out = tf.keras.layers.Dense(
1, 1,
name="value_out", name="value_out",
@ -95,7 +96,8 @@ class VisionNetwork(TFModelV2):
activation=None, activation=None,
padding="same", padding="same",
name="conv_value_out")(last_layer) name="conv_value_out")(last_layer)
value_out = tf.squeeze(last_layer, axis=[1, 2]) value_out = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
self.base_model = tf.keras.Model(inputs, [conv_out, value_out]) self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
self.register_variables(self.base_model.variables) self.register_variables(self.base_model.variables)

View file

@ -204,13 +204,16 @@ def build_eager_tf_policy(name,
@override(Policy) @override(Policy)
def get_weights(self): def get_weights(self):
return tf.nest.map_structure(lambda var: var.numpy(), variables = self.model.variables()
self.model.variables()) return [v.numpy() for v in variables]
@override(Policy) @override(Policy)
def set_weights(self, weights): def set_weights(self, weights):
tf.nest.map_structure(lambda var, value: var.assign(value), variables = self.model.variables()
self.model.variables(), weights) assert len(weights) == len(variables), (len(weights),
len(variables))
for v, w in zip(variables, weights):
v.assign(w)
def is_recurrent(self): def is_recurrent(self):
return len(self._state_in) > 0 return len(self._state_in) > 0