ray/rllib/models/tf/tf_modelv2.py

import contextlib
import gym
from typing import List

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.types import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()


@PublicAPI
class TFModelV2(ModelV2):
"""TF version of ModelV2.
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Initialize a TFModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TFModelV2)``::

            def __init__(self, *args, **kwargs):
                super(MyModelClass, self).__init__(*args, **kwargs)
                input_layer = tf.keras.layers.Input(...)
                hidden_layer = tf.keras.layers.Dense(...)(input_layer)
                output_layer = tf.keras.layers.Dense(...)(hidden_layer)
                value_layer = tf.keras.layers.Dense(...)(hidden_layer)
                self.base_model = tf.keras.Model(
                    input_layer, [output_layer, value_layer])
                self.register_variables(self.base_model.variables)
        """
        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="tf")
        self.var_list = []
        if tf1.executing_eagerly():
            self.graph = None
        else:
            self.graph = tf1.get_default_graph()

    def context(self) -> contextlib.AbstractContextManager:
        """Returns a contextmanager for the current TF graph."""
        if self.graph:
            return self.graph.as_default()
        else:
            return ModelV2.context(self)
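
    # Illustrative usage only (``some_tensor`` is a hypothetical name): a
    # subclass that creates extra TF ops outside of forward() could build them
    # under this context so that, in graph (non-eager) mode, they land in the
    # same graph as the model's variables:
    #
    #     with self.context():
    #         extra_metric = tf.reduce_mean(some_tensor)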

    def update_ops(self) -> List[TensorType]:
        """Return the list of update ops for this model.

        For example, this should include any BatchNorm update ops.
        """
        return []
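
    # Illustrative override sketch (not part of this base class): a subclass
    # whose keras network contains BatchNormalization layers could surface
    # their update ops roughly like this, assuming it builds and registers a
    # ``self.base_model`` tf.keras.Model as in the __init__ docstring example:
    #
    #     def update_ops(self):
    #         return self.base_model.get_updates_for(None) + \
    #             self.base_model.get_updates_for(self.base_model.inputs)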

    def register_variables(self, variables: List[TensorType]) -> None:
        """Register the given list of variables with this model."""
        self.var_list.extend(variables)

    @override(ModelV2)
    def variables(self, as_dict: bool = False) -> List[TensorType]:
        if as_dict:
            return {v.name: v for v in self.var_list}
        return list(self.var_list)

    @override(ModelV2)
    def trainable_variables(self, as_dict: bool = False) -> List[TensorType]:
        if as_dict:
            return {
                k: v
                for k, v in self.variables(as_dict=True).items() if v.trainable
            }
        return [v for v in self.variables() if v.trainable]
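

# Illustrative subclass sketch (kept as a comment; not part of this module).
# It fills in the forward()/value_function() pair that the class docstring
# requires of subclasses, following the keras example in __init__'s docstring.
# ``MyFCModel``, the layer sizes, and the "my_fc_model" key are hypothetical.
#
#     import tensorflow as tf
#     from ray.rllib.models import ModelCatalog
#     from ray.rllib.models.tf.tf_modelv2 import TFModelV2
#
#     class MyFCModel(TFModelV2):
#         def __init__(self, obs_space, action_space, num_outputs,
#                      model_config, name):
#             super(MyFCModel, self).__init__(
#                 obs_space, action_space, num_outputs, model_config, name)
#             input_layer = tf.keras.layers.Input(shape=obs_space.shape)
#             hidden_layer = tf.keras.layers.Dense(
#                 256, activation="tanh")(input_layer)
#             output_layer = tf.keras.layers.Dense(num_outputs)(hidden_layer)
#             value_layer = tf.keras.layers.Dense(1)(hidden_layer)
#             self.base_model = tf.keras.Model(
#                 input_layer, [output_layer, value_layer])
#             self.register_variables(self.base_model.variables)
#
#         def forward(self, input_dict, state, seq_lens):
#             # The keras model returns both heads; stash the value head so
#             # value_function() can return it.
#             model_out, self._value_out = self.base_model(input_dict["obs"])
#             return model_out, state
#
#         def value_function(self):
#             return tf.reshape(self._value_out, [-1])
#
#     # Registering the model makes it usable via the "custom_model" key in
#     # an algorithm's model config:
#     ModelCatalog.register_custom_model("my_fc_model", MyFCModel)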