mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
133 lines
4.5 KiB
Python
133 lines
4.5 KiB
Python
import numpy as np
|
|
import pickle
|
|
|
|
import ray
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.models.tf.misc import normc_initializer
|
|
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class SpyLayer(tf.keras.layers.Layer):
    """Keras layer that applies a Dense projection and records its inputs.

    On every forward pass a subset of the layer's inputs is routed through
    `SpyLayer.spy`, which pickles them into Ray's internal key-value store
    so that they can be inspected from outside the TF graph.
    """

    # Dummy value returned by `spy`. Its dtype has to agree with the output
    # type declared for the py_func op built in `call`.
    output = np.array(0, dtype=np.int64)

    def __init__(self, num_outputs, **kwargs):
        """Creates the wrapped Dense layer with `num_outputs` units."""
        super().__init__(**kwargs)

        initializer = normc_initializer(0.01)
        self.dense = tf.keras.layers.Dense(
            units=num_outputs, kernel_initializer=initializer)

    def call(self, inputs, **kwargs):
        """Runs the Dense layer on `inputs[1]` after spying on the rest.

        `inputs` is indexed as: 0 = observations, 1 = the tensor fed to the
        Dense layer, 2 = seq_lens, 3 = h_in, 4 = c_in, 5 = h_out, 6 = c_out.
        """
        del kwargs

        # Everything except index 1 is handed to the spy function, in the
        # same order as `spy`'s positional parameters.
        spied_tensors = [inputs[i] for i in (0, 2, 3, 4, 5, 6)]
        spy_op = tf1.py_func(
            self.spy,
            spied_tensors,
            tf.int64,  # Has to be the dtype of SpyLayer.output.
            stateful=True)

        # The control dependency forces the spy op to run whenever the
        # Dense output is computed.
        with tf1.control_dependencies([spy_op]):
            return self.dense(inputs[1])

    @staticmethod
    def spy(inputs, seq_lens, h_in, c_in, h_out, c_out):
        """Pickles the given values into Ray's internal KV store.

        The entry is written under a key derived from
        `RNNSpyModel.capture_index`, which is incremented afterwards.
        Always returns `SpyLayer.output`.
        """
        # A batch of length 1 is treated as an inference call and skipped.
        if len(inputs) != 1:
            # This function is invoked by TF via py_func, so results are
            # handed back to the test suite through the KV store instead of
            # a return value.
            key = "rnn_spy_in_{}".format(RNNSpyModel.capture_index)
            captured = {
                "sequences": inputs,
                "seq_lens": seq_lens,
                "state_in": [h_in, c_in],
                "state_out": [h_out, c_out]
            }
            ray.experimental.internal_kv._internal_kv_put(
                key, pickle.dumps(captured), overwrite=True)
            RNNSpyModel.capture_index += 1

        return SpyLayer.output
|
|
|
|
|
|
class RNNSpyModel(RecurrentNetwork):
    """Recurrent model with an LSTM core whose logits come from a SpyLayer.

    The SpyLayer records the observations, sequence lengths and LSTM states
    seen during a forward pass; a separate Dense head produces the value
    estimates.
    """

    # Running counter used by SpyLayer.spy to build its storage keys.
    capture_index = 0
    # Number of units in the LSTM (size of each of the h and c states).
    cell_size = 3

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds the Keras model: inputs -> LSTM -> (SpyLayer, value)."""
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.cell_size = RNNSpyModel.cell_size

        # Symbolic inputs: observation sequences, the two LSTM states and
        # the per-sequence lengths.
        obs_in = tf.keras.layers.Input(
            shape=(None, ) + obs_space.shape, name="input")
        h_in = tf.keras.layers.Input(shape=(self.cell_size, ), name="h")
        c_in = tf.keras.layers.Input(shape=(self.cell_size, ), name="c")
        seq_lens = tf.keras.layers.Input(
            shape=(), name="seq_lens", dtype=tf.int32)

        # LSTM core; the mask is derived from the sequence lengths.
        lstm = tf.keras.layers.LSTM(
            self.cell_size,
            return_sequences=True,
            return_state=True,
            name="lstm")
        lstm_out, h_out, c_out = lstm(
            inputs=obs_in,
            mask=tf.sequence_mask(seq_lens),
            initial_state=[h_in, c_in])

        # Policy branch: the SpyLayer projects lstm_out to the logits and
        # intercepts all other tensors passed to it.
        spy_layer = SpyLayer(num_outputs=self.num_outputs)
        logits = spy_layer(
            [obs_in, lstm_out, seq_lens, h_in, c_in, h_out, c_out])

        # Value branch: single-unit Dense on top of the LSTM output.
        value_head = tf.keras.layers.Dense(
            units=1, kernel_initializer=normc_initializer(1.0))
        value_out = value_head(lstm_out)

        model_inputs = [obs_in, seq_lens, h_in, c_in]
        model_outputs = [logits, value_out, h_out, c_out]
        self.base_model = tf.keras.Model(model_inputs, model_outputs)
        self.base_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Runs the base model; returns the logits and the new [h, c]."""
        # The class-level counter used to be re-initialized implicitly when
        # a new class object was created during deserialization. That no
        # longer happens, so it is reset by hand on every forward pass.
        RNNSpyModel.capture_index = 0

        h_in, c_in = state[0], state[1]
        outputs = self.base_model([inputs, seq_lens, h_in, c_in])
        model_out, value_out, h, c = outputs

        # Stash the value head's output for `value_function`.
        self._value_out = value_out
        return model_out, [h, c]

    @override(ModelV2)
    def value_function(self):
        """Returns the most recent value output flattened to rank 1."""
        return tf.reshape(self._value_out, [-1])

    @override(ModelV2)
    def get_initial_state(self):
        """Returns zero-filled float32 arrays for the h and c states."""
        return [np.zeros(self.cell_size, np.float32) for _ in range(2)]
|