mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
import random
|
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
|
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
# Resolve the TensorFlow handles used by this module: `tf` is used below for
# Keras layers, `tf.py_function` and tensor ops, and `tf1` for
# `control_dependencies`. `tfv` is not referenced in the code visible here
# (presumably a version indicator -- confirm against `try_import_tf`).
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class EagerModel(TFModelV2):
    """Example of using embedded eager execution in a custom model.

    This shows how to use tf.py_function() to execute a snippet of TF code
    in eager mode. Here the `self.forward_eager` method just prints out
    the intermediate tensor for debug purposes, but you can in general
    perform any TF eager operation in tf.py_function().

    The model wraps a `FullyConnectedNetwork` and routes its policy output
    through a Keras `Lambda` layer that calls `forward_eager`; the value
    output of the wrapped network is passed through unchanged.
    """

    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        """Build the wrapped network and the eager pass-through layer.

        Args:
            observation_space: Observation space; its `shape` defines the
                Keras input shape.
            action_space: Action space, forwarded to the base class.
            num_outputs: Number of outputs, forwarded to the base class.
            model_config: Model config, forwarded to the base class.
            name: Model name, forwarded to the base class.
        """
        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        inputs = tf.keras.layers.Input(shape=observation_space.shape)
        self.fcnet = FullyConnectedNetwork(
            obs_space=self.obs_space,
            action_space=self.action_space,
            num_outputs=self.num_outputs,
            model_config=self.model_config,
            name="fc1",
        )
        out, value_out = self.fcnet.base_model(inputs)

        def lambda_(x):
            # Run `forward_eager` on the tensor via tf.py_function, declaring
            # a float32 result.
            eager_out = tf.py_function(self.forward_eager, [x], tf.float32)
            with tf1.control_dependencies([eager_out]):
                # Copy the static shape of the input onto the py_function
                # result so downstream consumers see a known shape.
                eager_out.set_shape(x.shape)
                return eager_out

        out = tf.keras.layers.Lambda(lambda_)(out)
        self.base_model = tf.keras.models.Model(inputs, [out, value_out])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute the model output for a batch of observations.

        Args:
            input_dict: Dict holding the batch; only `input_dict["obs"]`
                is read.
            state: Unused; this model carries no recurrent state.
            seq_lens: Unused; this model carries no recurrent state.

        Returns:
            Tuple of (policy output, empty state list). The value output is
            stored on `self._value_out` for `value_function()`.
        """
        # Pass only the observations: the second and third positional
        # parameters of a Keras model call are `training` and `mask`, so
        # forwarding `state` / `seq_lens` would misuse them as those flags.
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        """Return the value output of the last `forward()` call, flattened
        to 1-D. `forward()` must have been called first."""
        return tf.reshape(self._value_out, [-1])

    def forward_eager(self, feature_layer):
        """Eager-mode hook invoked through tf.py_function.

        Occasionally (when `random.random()` exceeds 0.99) prints the mean
        of `feature_layer`; always returns `feature_layer` unchanged.
        """
        assert tf.executing_eagerly()
        if random.random() > 0.99:
            print(
                "Eagerly printing the feature layer mean value",
                tf.reduce_mean(feature_layer),
            )
        return feature_layer
|