2019-07-03 15:59:47 -07:00
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from ray.rllib.models.modelv2 import ModelV2
|
2019-07-25 11:02:53 -07:00
|
|
|
from ray.rllib.utils.annotations import PublicAPI
|
2019-07-03 15:59:47 -07:00
|
|
|
from ray.rllib.utils import try_import_tf
|
|
|
|
|
|
|
|
# Module-level TensorFlow handle obtained through RLlib's import helper
# (presumably None when TensorFlow is not installed — confirm against
# ray.rllib.utils.try_import_tf).
tf = try_import_tf()
|
|
|
|
|
|
|
|
|
2019-07-25 11:02:53 -07:00
|
|
|
@PublicAPI
class TFModelV2(ModelV2):
    """TensorFlow-backed specialization of ModelV2.

    This base class only tags the model with the TF framework and keeps a
    registry of variables; it is not a usable model on its own. Subclasses
    must implement forward()."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # Hand all arguments to the generic base, tagging the framework as TF.
        ModelV2.__init__(self, obs_space, action_space, num_outputs,
                         model_config, name, framework="tf")
        # Variables added via register_variables(), kept in the order they
        # were registered.
        self.var_list = []

    def update_ops(self):
        """Return the update ops this model needs run during training.

        BatchNorm update ops are the typical example. The base class has
        none, so it returns an empty list."""
        return []

    def register_variables(self, variables):
        """Append every variable in ``variables`` to this model's registry."""
        self.var_list.extend(variables)

    def variables(self):
        """Return a new list holding every registered variable."""
        return [var for var in self.var_list]

    def trainable_variables(self):
        """Return the registered variables whose ``trainable`` flag is truthy."""
        trainable = []
        for var in self.variables():
            if var.trainable:
                trainable.append(var)
        return trainable
|