ray/rllib/models/tf/modelv1_compat.py

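"""Compatibility wrapper for using legacy (ModelV1-style) TF models as ModelV2.

`make_v1_wrapper()` returns a TFModelV2 subclass that builds the given legacy
model class inside `__call__()`, re-using its variable scope on subsequent
calls so that weights are shared.
"""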
import copy
import logging
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import linear, normc_initializer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import scope_vars

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def make_v1_wrapper(legacy_model_cls):
    class ModelV1Wrapper(TFModelV2):
        """Wrapper that allows V1 models to be used as ModelV2."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            TFModelV2.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name)
            self.legacy_model_cls = legacy_model_cls

            # Tracks the last v1 model created by the call to forward
            self.cur_instance = None

            # XXX: Try to guess the initial state size. Since the size of the
            # state is known only after forward() for V1 models, it might be
            # wrong.
            if model_config.get("state_shape"):
                self.initial_state = [
                    np.zeros(s, np.float32)
                    for s in model_config["state_shape"]
                ]
            elif model_config.get("use_lstm"):
                cell_size = model_config.get("lstm_cell_size", 256)
                self.initial_state = [
                    np.zeros(cell_size, np.float32),
                    np.zeros(cell_size, np.float32),
                ]
            else:
                self.initial_state = []

            # Tracks update ops
            self._update_ops = None

            with tf1.variable_scope(self.name) as scope:
                self.variable_scope = scope

        @override(ModelV2)
        def get_initial_state(self):
            return self.initial_state

        @override(ModelV2)
        def __call__(self, input_dict, state, seq_lens):
            if self.cur_instance:
                # Create a weight-sharing model copy.
                with tf1.variable_scope(self.cur_instance.scope, reuse=True):
                    new_instance = self.legacy_model_cls(
                        input_dict, self.obs_space, self.action_space,
                        self.num_outputs, self.model_config, state, seq_lens)
            else:
                # Create a new model instance.
                with tf1.variable_scope(self.name):
                    prev_update_ops = set(
                        tf1.get_collection(tf1.GraphKeys.UPDATE_OPS))
                    new_instance = self.legacy_model_cls(
                        input_dict, self.obs_space, self.action_space,
                        self.num_outputs, self.model_config, state, seq_lens)
                    self._update_ops = list(
                        set(tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)) -
                        prev_update_ops)

            if len(new_instance.state_init) != len(self.get_initial_state()):
                raise ValueError(
                    "When using a custom recurrent ModelV1 model, you should "
                    "declare the state_shape in the model options. For "
                    "example, set 'state_shape': [256, 256] for a lstm with "
                    "cell size 256. The guessed state shape was {} which "
                    "appears to be incorrect.".format(
                        [s.shape[0] for s in self.get_initial_state()]))

            self.cur_instance = new_instance
            self.variable_scope = new_instance.scope
            return new_instance.outputs, new_instance.state_out

        @override(TFModelV2)
        def update_ops(self):
            if self._update_ops is None:
                raise ValueError(
                    "Cannot get update ops before wrapped v1 model init")
            return list(self._update_ops)

        @override(TFModelV2)
        def variables(self):
            var_list = super(ModelV1Wrapper, self).variables()
            for v in scope_vars(self.variable_scope):
                if v not in var_list:
                    var_list.append(v)
            return var_list

        @override(ModelV2)
        def custom_loss(self, policy_loss, loss_inputs):
            return self.cur_instance.custom_loss(policy_loss, loss_inputs)

        @override(ModelV2)
        def metrics(self):
            return self.cur_instance.custom_stats()

        @override(ModelV2)
        def value_function(self):
            assert self.cur_instance is not None, "must call forward first"
            with tf1.variable_scope(self.variable_scope):
                with tf1.variable_scope(
                        "value_function", reuse=tf1.AUTO_REUSE):
                    # Simple case: sharing the feature layer.
                    if self.model_config["vf_share_layers"]:
                        return tf.reshape(
                            linear(self.cur_instance.last_layer, 1,
                                   "value_function", normc_initializer(1.0)),
                            [-1])

                    # Create a new separate model with no RNN state, etc.
                    branch_model_config = self.model_config.copy()
                    branch_model_config["free_log_std"] = False
                    obs_space_vf = self.obs_space
                    if branch_model_config["use_lstm"]:
                        branch_model_config["use_lstm"] = False
                        logger.warning(
                            "It is not recommended to use an LSTM model "
                            "with the `vf_share_layers=False` option. "
                            "If you want to use separate policy- and vf-"
                            "networks with LSTMs, you can implement a custom "
                            "LSTM model that overrides the value_function() "
                            "method. "
                            "NOTE: Your policy- and vf-NNs will use the same "
                            "shared LSTM!")
                        # Remove original space from obs-space not to trigger
                        # preprocessing (input to vf-NN is already vectorized
                        # LSTM output).
                        obs_space_vf = copy.copy(self.obs_space)
                        if hasattr(obs_space_vf, "original_space"):
                            delattr(obs_space_vf, "original_space")

                    branch_instance = self.legacy_model_cls(
                        self.cur_instance.input_dict,
                        obs_space_vf,
                        self.action_space,
                        1,
                        branch_model_config,
                        state_in=None,
                        seq_lens=None)
                    return tf.reshape(branch_instance.outputs, [-1])

        @override(ModelV2)
        def last_output(self):
            return self.cur_instance.outputs

    return ModelV1Wrapper
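
A minimal usage sketch of the factory above, under stated assumptions: `MyLegacyModel` stands in for a hypothetical ModelV1 subclass, and `obs_space`, `action_space`, and `obs_input` are placeholders the caller would provide. RLlib normally performs this wrapping internally when a legacy custom model is configured, so this is illustrative rather than the required entry point.

# Hypothetical legacy ModelV1 subclass; not defined in this file.
from my_project.legacy_models import MyLegacyModel

wrapper_cls = make_v1_wrapper(MyLegacyModel)
model = wrapper_cls(obs_space, action_space, num_outputs=4,
                    model_config={"vf_share_layers": True}, name="legacy")

# The first call builds the wrapped V1 graph and returns (outputs, state_out);
# later calls re-enter the same variable scope with reuse=True.
logits, state_out = model({"obs": obs_input}, state=[], seq_lens=None)

# With vf_share_layers=True this adds a linear head on the V1 model's
# last_layer and returns a [batch]-shaped value tensor.
value_out = model.value_function()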