mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Fix some eager execution regressions with 1.13 (#5537)
* fix bugs with 1.13 * allow disable
This commit is contained in:
parent
948b1b09e8
commit
03a1b75852
4 changed files with 16 additions and 7 deletions
|
@ -72,6 +72,9 @@ COMMON_CONFIG = {
|
||||||
"log_sys_usage": True,
|
"log_sys_usage": True,
|
||||||
# Enable TF eager execution (TF policies only)
|
# Enable TF eager execution (TF policies only)
|
||||||
"eager": False,
|
"eager": False,
|
||||||
|
# Disable eager execution on workers (but allow it on the driver). This
|
||||||
|
# only has an effect is eager is enabled.
|
||||||
|
"no_eager_on_workers": False,
|
||||||
|
|
||||||
# === Policy ===
|
# === Policy ===
|
||||||
# Arguments to pass to model. See models/catalog.py for a full list of the
|
# Arguments to pass to model. See models/catalog.py for a full list of the
|
||||||
|
|
|
@ -239,7 +239,8 @@ class RolloutWorker(EvaluatorInterface):
|
||||||
_global_worker = self
|
_global_worker = self
|
||||||
|
|
||||||
policy_config = policy_config or {}
|
policy_config = policy_config or {}
|
||||||
if tf and policy_config.get("eager"):
|
if (tf and policy_config.get("eager")
|
||||||
|
and not policy_config.get("no_eager_on_workers")):
|
||||||
tf.enable_eager_execution()
|
tf.enable_eager_execution()
|
||||||
|
|
||||||
if log_level:
|
if log_level:
|
||||||
|
|
|
@ -65,7 +65,8 @@ class VisionNetwork(TFModelV2):
|
||||||
|
|
||||||
# Build the value layers
|
# Build the value layers
|
||||||
if vf_share_layers:
|
if vf_share_layers:
|
||||||
last_layer = tf.squeeze(last_layer, axis=[1, 2])
|
last_layer = tf.keras.layers.Lambda(
|
||||||
|
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||||
value_out = tf.keras.layers.Dense(
|
value_out = tf.keras.layers.Dense(
|
||||||
1,
|
1,
|
||||||
name="value_out",
|
name="value_out",
|
||||||
|
@ -95,7 +96,8 @@ class VisionNetwork(TFModelV2):
|
||||||
activation=None,
|
activation=None,
|
||||||
padding="same",
|
padding="same",
|
||||||
name="conv_value_out")(last_layer)
|
name="conv_value_out")(last_layer)
|
||||||
value_out = tf.squeeze(last_layer, axis=[1, 2])
|
value_out = tf.keras.layers.Lambda(
|
||||||
|
lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
|
||||||
|
|
||||||
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
|
||||||
self.register_variables(self.base_model.variables)
|
self.register_variables(self.base_model.variables)
|
||||||
|
|
|
@ -204,13 +204,16 @@ def build_eager_tf_policy(name,
|
||||||
|
|
||||||
@override(Policy)
|
@override(Policy)
|
||||||
def get_weights(self):
|
def get_weights(self):
|
||||||
return tf.nest.map_structure(lambda var: var.numpy(),
|
variables = self.model.variables()
|
||||||
self.model.variables())
|
return [v.numpy() for v in variables]
|
||||||
|
|
||||||
@override(Policy)
|
@override(Policy)
|
||||||
def set_weights(self, weights):
|
def set_weights(self, weights):
|
||||||
tf.nest.map_structure(lambda var, value: var.assign(value),
|
variables = self.model.variables()
|
||||||
self.model.variables(), weights)
|
assert len(weights) == len(variables), (len(weights),
|
||||||
|
len(variables))
|
||||||
|
for v, w in zip(variables, weights):
|
||||||
|
v.assign(w)
|
||||||
|
|
||||||
def is_recurrent(self):
|
def is_recurrent(self):
|
||||||
return len(self._state_in) > 0
|
return len(self._state_in) > 0
|
||||||
|
|
Loading…
Add table
Reference in a new issue