mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

162 lines
6.9 KiB
Python
import copy
import logging
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import scope_vars

tf1, tf, tfv = try_import_tf()
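# try_import_tf() returns the TF1-style API (tf.compat.v1 when running
# under TF2), the base tf module, and the installed major version.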

logger = logging.getLogger(__name__)


def make_v1_wrapper(legacy_model_cls):
    class ModelV1Wrapper(TFModelV2):
        """Wrapper that allows V1 models to be used as ModelV2."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name)
            self.legacy_model_cls = legacy_model_cls

            # Tracks the last V1 model instance created by __call__().
            self.cur_instance = None

            # XXX: Try to guess the initial state size. Since the size of the
            # state is known only after forward() for V1 models, it might be
            # wrong.
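            # For example, 'state_shape': [256, 256] in the model config
            # yields two zero-filled float32 vectors of length 256,
            # matching an LSTM's cell and hidden state.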
            if model_config.get("state_shape"):
                self.initial_state = [
                    np.zeros(s, np.float32)
                    for s in model_config["state_shape"]
                ]
            elif model_config.get("use_lstm"):
                cell_size = model_config.get("lstm_cell_size", 256)
                self.initial_state = [
                    np.zeros(cell_size, np.float32),
                    np.zeros(cell_size, np.float32),
                ]
            else:
                self.initial_state = []

            # Tracks update ops
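            # (populated in __call__() with any ops, e.g. batch-norm
            # moving-average updates, that the wrapped model adds to the
            # UPDATE_OPS collection).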
            self._update_ops = None

            with tf1.variable_scope(self.name) as scope:
                self.variable_scope = scope

        @override(ModelV2)
        def get_initial_state(self):
            return self.initial_state

        @override(ModelV2)
        def __call__(self, input_dict, state, seq_lens):
            if self.cur_instance:
                # Create a weight-sharing model copy: re-enter the first
                # instance's scope with reuse=True so that all copies share
                # the same variables.
                with tf1.variable_scope(self.cur_instance.scope, reuse=True):
                    new_instance = self.legacy_model_cls(
                        input_dict, self.obs_space, self.action_space,
                        self.num_outputs, self.model_config, state, seq_lens)
            else:
                # Create a new model instance.
                with tf1.variable_scope(self.name):
                    prev_update_ops = set(
                        tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
                    new_instance = self.legacy_model_cls(
                        input_dict, self.obs_space, self.action_space,
                        self.num_outputs, self.model_config, state, seq_lens)
                    self._update_ops = list(
                        set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)) -
                        prev_update_ops)
                if len(new_instance.state_init) != len(
                        self.get_initial_state()):
                    raise ValueError(
                        "When using a custom recurrent ModelV1 model, you "
                        "should declare the state_shape in the model "
                        "options. For example, set 'state_shape': [256, 256] "
                        "for an LSTM with cell size 256. The guessed state "
                        "shape was {}, which appears to be incorrect.".format(
                            [s.shape[0] for s in self.get_initial_state()]))
            self.cur_instance = new_instance
            self.variable_scope = new_instance.scope
            return new_instance.outputs, new_instance.state_out

        @override(TFModelV2)
        def update_ops(self):
            if self._update_ops is None:
                raise ValueError(
                    "Cannot get update ops before wrapped v1 model init")
            return list(self._update_ops)

        @override(TFModelV2)
        def variables(self):
            # Merge the ModelV2-registered variables with any additional
            # variables the V1 model created in its own scope.
            var_list = super(ModelV1Wrapper, self).variables()
            for v in scope_vars(self.variable_scope):
                if v not in var_list:
                    var_list.append(v)
            return var_list

        @override(ModelV2)
        def custom_loss(self, policy_loss, loss_inputs):
            return self.cur_instance.custom_loss(policy_loss, loss_inputs)

        @override(ModelV2)
        def metrics(self):
            return self.cur_instance.custom_stats()

        @override(ModelV2)
        def value_function(self):
            assert self.cur_instance is not None, "must call forward first"

            with tf1.variable_scope(self.variable_scope):
                with tf1.variable_scope(
                        "value_function", reuse=tf1.AUTO_REUSE):
                    # Simple case: share the feature layer and put a linear
                    # value head on top of it.
                    if self.model_config["vf_share_layers"]:
                        return tf.reshape(
                            linear(self.cur_instance.last_layer, 1,
                                   "value_function", normc_initializer(1.0)),
                            [-1])

                    # Otherwise, create a separate value branch with no RNN
                    # state, etc.
                    branch_model_config = self.model_config.copy()
                    branch_model_config["free_log_std"] = False
                    obs_space_vf = self.obs_space

                    if branch_model_config["use_lstm"]:
                        branch_model_config["use_lstm"] = False
                        logger.warning(
                            "It is not recommended to use an LSTM model "
                            "with the `vf_share_layers=False` option. If "
                            "you want separate policy and value networks "
                            "with LSTMs, implement a custom LSTM model that "
                            "overrides the value_function() method. NOTE: "
                            "your policy and value networks will use the "
                            "same shared LSTM!")
                        # Remove the original space from the obs space so
                        # that no preprocessing is triggered (the input to
                        # the value network is the already-vectorized LSTM
                        # output).
                        obs_space_vf = copy.copy(self.obs_space)
                        if hasattr(obs_space_vf, "original_space"):
                            delattr(obs_space_vf, "original_space")

                    branch_instance = self.legacy_model_cls(
                        self.cur_instance.input_dict,
                        obs_space_vf,
                        self.action_space,
                        1,
                        branch_model_config,
                        state_in=None,
                        seq_lens=None)
                    return tf.reshape(branch_instance.outputs, [-1])

        @override(ModelV2)
        def last_output(self):
            return self.cur_instance.outputs

    return ModelV1Wrapper
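
# Usage sketch (illustrative only): `MyLegacyModel` is a hypothetical
# ModelV1-style class, and `obs_space`, `action_space`, `input_dict`,
# `state_in` and `seq_lens` stand in for the spaces/tensors RLlib normally
# provides. In practice, RLlib's ModelCatalog normally applies this wrapper
# automatically when a V1 model class is configured.
#
#     MyModelV2 = make_v1_wrapper(MyLegacyModel)
#     model = MyModelV2(obs_space, action_space, num_outputs=4,
#                       model_config={"use_lstm": True,
#                                     "lstm_cell_size": 256},
#                       name="wrapped_legacy_model")
#     # The first call builds the underlying V1 graph; subsequent calls
#     # re-enter its scope with reuse=True and share all weights.
#     logits, state_out = model(input_dict, state_in, seq_lens)
#     value = model.value_function()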