2020-07-24 12:01:46 -07:00
|
|
|
import contextlib
from typing import Dict, List, Union

import gym

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.types import ModelConfigDict, TensorType
|
2019-07-03 15:59:47 -07:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
# Obtain TensorFlow through RLlib's framework helper, which returns a
# 3-tuple. In this file `tf1` supplies executing_eagerly() and
# get_default_graph(). NOTE(review): presumably the tuple is
# (tf.compat.v1, tf, major version), and the values may be None when TF is
# not installed. Confirm against ray.rllib.utils.framework.try_import_tf.
tf1, tf, tfv = try_import_tf()
|
2019-07-03 15:59:47 -07:00
|
|
|
|
|
|
|
|
2019-07-25 11:02:53 -07:00
|
|
|
@PublicAPI
class TFModelV2(ModelV2):
    """TF version of ModelV2.

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass."""

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Initialize a TFModelV2.

        Here is an example implementation for a subclass
        ``MyModelClass(TFModelV2)``::

            def __init__(self, *args, **kwargs):
                super(MyModelClass, self).__init__(*args, **kwargs)
                input_layer = tf.keras.layers.Input(...)
                hidden_layer = tf.keras.layers.Dense(...)(input_layer)
                output_layer = tf.keras.layers.Dense(...)(hidden_layer)
                value_layer = tf.keras.layers.Dense(...)(hidden_layer)
                self.base_model = tf.keras.Model(
                    input_layer, [output_layer, value_layer])
                self.register_variables(self.base_model.variables)

        Args:
            obs_space: Observation space, forwarded to ModelV2.__init__.
            action_space: Action space, forwarded to ModelV2.__init__.
            num_outputs: Number of model outputs, forwarded to
                ModelV2.__init__.
            model_config: Model config dict, forwarded to ModelV2.__init__.
            name: Model name, forwarded to ModelV2.__init__.
        """
        ModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            framework="tf")
        # Variables added via register_variables(); this list backs
        # variables() and trainable_variables().
        self.var_list: List[TensorType] = []
        # In eager mode there is no graph to capture. Otherwise remember the
        # graph that was active at construction so context() can re-enter it.
        if tf1.executing_eagerly():
            self.graph = None
        else:
            self.graph = tf1.get_default_graph()

    def context(self) -> contextlib.AbstractContextManager:
        """Returns a contextmanager for the current TF graph.

        If a graph was captured in __init__ (graph mode), this returns that
        graph's ``as_default()`` context. Otherwise (eager mode) it defers
        to ``ModelV2.context()``.
        """
        # `self.graph` is either a captured graph or the None sentinel set
        # in __init__, so test for the sentinel explicitly.
        if self.graph is not None:
            return self.graph.as_default()
        return ModelV2.context(self)

    def update_ops(self) -> List[TensorType]:
        """Return the list of update ops for this model.

        For example, this should include any BatchNorm update ops.
        The base implementation has none and returns an empty list.
        """
        return []

    def register_variables(self, variables: List[TensorType]) -> None:
        """Register the given list of variables with this model.

        Args:
            variables: Variables to append, in order, to this model's
                internal variable list.
        """
        self.var_list.extend(variables)

    @override(ModelV2)
    def variables(self, as_dict: bool = False
                  ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns all variables registered with this model.

        Args:
            as_dict: If True, return a dict keyed by each variable's
                ``name``. If several registered variables share a name, the
                last registered one wins. If False, return a list.

        Returns:
            Either a new list of the registered variables in registration
            order, or a ``{name: variable}`` dict when ``as_dict`` is True.
        """
        if as_dict:
            return {v.name: v for v in self.var_list}
        # Return a copy so callers cannot mutate the internal list.
        return list(self.var_list)

    @override(ModelV2)
    def trainable_variables(
            self, as_dict: bool = False
    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Returns the registered variables whose ``trainable`` flag is set.

        Args:
            as_dict: If True, return a ``{name: variable}`` dict; otherwise
                return a list.

        Returns:
            The subset of ``variables(as_dict)`` whose ``trainable``
            attribute is truthy, in the same container type.
        """
        if as_dict:
            return {
                k: v
                for k, v in self.variables(as_dict=True).items() if v.trainable
            }
        return [v for v in self.variables() if v.trainable]