ray/rllib/agents/dqn/simple_q_model.py
Sven Mika 43043ee4d5
[RLlib] Tf2x preparation; part 2 (upgrading try_import_tf()). (#9136)
* WIP.

* Fixes.

* LINT.

* WIP.

* WIP.

* Fixes.

* Fixes.

* Fixes.

* Fixes.

* WIP.

* Fixes.

* Test

* Fix.

* Fixes and LINT.

* Fixes and LINT.

* LINT.
2020-06-30 10:13:20 +02:00

69 lines
2.4 KiB
Python

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
class SimpleQModel(TFModelV2):
    """Extension of standard TFModel to provide Q values.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_q_values() -> Q(s, a)

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 q_hiddens=(256, )):
        """Initialize variables of this model.

        Extra model kwargs:
            q_hiddens (Sequence[int]): Sizes of the hidden (ReLU) layers of
                the Q head. These are applied to the model output
                (`model_out`, a vector of size `num_outputs`) to compute Q
                values and are followed by a linear output layer of size
                `action_space.n`. In that case `action_space` must expose
                an `.n` attribute (e.g. a discrete action space).
                If `q_hiddens` is empty/falsy, no layers are created: the
                Q head is the identity and `model_out` itself is returned
                as the Q values. For those to be one value per action,
                `num_outputs` should then match the number of actions.

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the Q head. Those layers for forward()
        should be defined in subclasses of SimpleQModel.
        """
        super(SimpleQModel, self).__init__(obs_space, action_space,
                                           num_outputs, model_config, name)

        # Setup the Q head output (i.e., model for get_q_values).
        # A symbolic Keras input stands in for the output of forward(), so
        # the Q head can be built as a separate Keras model and later be
        # called on the real `model_out` tensor.
        self.model_out = tf.keras.layers.Input(
            shape=(num_outputs, ), name="model_out")

        if q_hiddens:
            last_layer = self.model_out
            for i, n in enumerate(q_hiddens):
                last_layer = tf.keras.layers.Dense(
                    n, name="q_hidden_{}".format(i),
                    activation=tf.nn.relu)(last_layer)
            # Linear (no activation) layer: one Q value per discrete action.
            q_out = tf.keras.layers.Dense(
                action_space.n, activation=None, name="q_out")(last_layer)
        else:
            # No Q-head layers requested: identity head, Q values are the
            # raw model output of size `num_outputs`.
            q_out = self.model_out

        self.q_value_head = tf.keras.Model(self.model_out, q_out)
        # Register the Q head's variables with the TFModelV2 base class
        # (register_variables is defined there), presumably so they are
        # tracked together with the model's own variables.
        self.register_variables(self.q_value_head.variables)

    def get_q_values(self, model_out):
        """Returns Q(s, a) given a feature tensor for the state.

        Override this in your custom model to customize the Q output head.

        Arguments:
            model_out (Tensor): embedding from the model layers, i.e. the
                output of forward(), of shape [None, num_outputs].

        Returns:
            Tensor: action scores Q(s, a) for each action. The shape is
                [None, action_space.n] when the model was built with a
                non-empty `q_hiddens`. Otherwise the head is the identity
                and the shape is [None, num_outputs].
        """
        return self.q_value_head(model_out)