mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Fix some eager execution regressions with 1.13 (#5537)
* fix bugs with 1.13 * allow disable
This commit is contained in:
parent
948b1b09e8
commit
03a1b75852
4 changed files with 16 additions and 7 deletions
|
@ -72,6 +72,9 @@ COMMON_CONFIG = {
|
|||
"log_sys_usage": True,
|
||||
# Enable TF eager execution (TF policies only)
|
||||
"eager": False,
|
||||
# Disable eager execution on workers (but allow it on the driver). This
|
||||
# only has an effect if eager is enabled.
|
||||
"no_eager_on_workers": False,
|
||||
|
||||
# === Policy ===
|
||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||
|
|
|
@ -239,7 +239,8 @@ class RolloutWorker(EvaluatorInterface):
|
|||
_global_worker = self
|
||||
|
||||
policy_config = policy_config or {}
|
||||
if tf and policy_config.get("eager"):
|
||||
if (tf and policy_config.get("eager")
|
||||
and not policy_config.get("no_eager_on_workers")):
|
||||
tf.enable_eager_execution()
|
||||
|
||||
if log_level:
|
||||
|
|
|
@ -65,7 +65,8 @@ class VisionNetwork(TFModelV2):
|
|||
|
||||
# Build the value layers
|
||||
if vf_share_layers:
|
||||
last_layer = tf.squeeze(last_layer, axis=[1, 2])
|
||||
last_layer = tf.keras.layers.Lambda(
|
||||
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||
value_out = tf.keras.layers.Dense(
|
||||
1,
|
||||
name="value_out",
|
||||
|
@ -95,7 +96,8 @@ class VisionNetwork(TFModelV2):
|
|||
activation=None,
|
||||
padding="same",
|
||||
name="conv_value_out")(last_layer)
|
||||
value_out = tf.squeeze(last_layer, axis=[1, 2])
|
||||
value_out = tf.keras.layers.Lambda(
|
||||
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||
|
||||
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
||||
self.register_variables(self.base_model.variables)
|
||||
|
|
|
@ -204,13 +204,16 @@ def build_eager_tf_policy(name,
|
|||
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
return tf.nest.map_structure(lambda var: var.numpy(),
|
||||
self.model.variables())
|
||||
variables = self.model.variables()
|
||||
return [v.numpy() for v in variables]
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
tf.nest.map_structure(lambda var, value: var.assign(value),
|
||||
self.model.variables(), weights)
|
||||
variables = self.model.variables()
|
||||
assert len(weights) == len(variables), (len(weights),
|
||||
len(variables))
|
||||
for v, w in zip(variables, weights):
|
||||
v.assign(w)
|
||||
|
||||
def is_recurrent(self):
|
||||
return len(self._state_in) > 0
|
||||
|
|
Loading…
Add table
Reference in a new issue