mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00

* WIP. * Fixes. * LINT. * WIP. * WIP. * Fixes. * Fixes. * Fixes. * Fixes. * WIP. * Fixes. * Test * Fix. * Fixes and LINT. * Fixes and LINT. * LINT.
69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
# Unpack the three handles returned by try_import_tf(). Of these, only `tf`
# is referenced in the code visible below (to build the Keras Q head).
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class SimpleQModel(TFModelV2):
    """TFModelV2 extension that adds a Q-value output head.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_q_values() -> Q(s, a)

    This class only builds the Q head; it does not implement forward().
    A subclass has to provide forward() for the model to be usable.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 q_hiddens=(256, )):
        """Build the Q head that post-processes the model output.

        The first five arguments are passed unchanged to the TFModelV2
        constructor.

        Extra model kwargs:
            q_hiddens (sequence of int): width of each hidden layer in the
                Q head. Every entry becomes one ReLU Dense layer, followed
                by a final linear Dense layer with `action_space.n` units.
                If empty, the Q head has no layers and simply passes
                model_out through.

        The core layers used by forward() are not created here; they
        belong in subclasses of SimpleQModel.
        """
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Symbolic input of the Q head (i.e., the model for get_q_values).
        self.model_out = tf.keras.layers.Input(
            name="model_out", shape=(num_outputs, ))

        if not q_hiddens:
            # No Q head layers requested: pass model_out through unchanged.
            head_output = self.model_out
        else:
            hidden = self.model_out
            for idx, width in enumerate(q_hiddens):
                dense = tf.keras.layers.Dense(
                    width,
                    activation=tf.nn.relu,
                    name="q_hidden_{}".format(idx))
                hidden = dense(hidden)
            # Linear output layer: one score per action.
            head_output = tf.keras.layers.Dense(
                action_space.n, name="q_out", activation=None)(hidden)

        self.q_value_head = tf.keras.Model(self.model_out, head_output)
        # Register the head's variables with the base model.
        self.register_variables(self.q_value_head.variables)

    def get_q_values(self, model_out):
        """Compute Q(s, a) from a feature tensor for the state.

        Override this in a custom model to customize the Q output head.

        Arguments:
            model_out (Tensor): embedding from the model layers

        Returns:
            The output of the Q head applied to `model_out`. When the head
            has hidden layers this is one score per action, shape
            [None, action_space.n].
        """
        q_values = self.q_value_head(model_out)
        return q_values
|