# ray/rllib/examples/models/cnn_plus_fc_concat_model.py


from gym.spaces import Discrete, Tuple
from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as \
    torch_normc_initializer, SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
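

# The __sphinx_doc_begin__/__sphinx_doc_end__ markers below delimit the code
# excerpt that is pulled into Ray's documentation.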
# __sphinx_doc_begin__
class CNNPlusFCConcatModel(TFModelV2):
    """TFModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex (Dict or Tuple) observation
    spaces that have one or more image components.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        concat_size = 0
        for i, component in enumerate(obs_space.original_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters", get_filter_config(component.shape)),
                    "conv_activation": model_config.get("conv_activation"),
                }
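                # Build a default vision net for this image component. With
                # `num_outputs=None`, the catalog returns a model that outputs
                # its (flattened) conv features; their size is then available
                # via `cnn.num_outputs` below.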
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="tf",
                    name="cnn_{}".format(i))
                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                concat_size += component.n
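                # Note: The flat observations are unpacked back into this
                # Tuple structure before reaching `forward()`; Discrete
                # components arrive one-hot encoded by the preprocessor
                # (length `n`), so they can be concatenated as-is.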
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                assert len(component.shape) == 1, \
                    "Only input Box 1D or 3D spaces allowed!"
                concat_size += component.shape[-1]

        self.logits_and_value_model = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            concat_layer = tf.keras.layers.Input((concat_size, ))
            logits_layer = tf.keras.layers.Dense(
                num_outputs,
                activation=tf.keras.activations.linear,
                name="logits")(concat_layer)

            # Create the value branch model.
            value_layer = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(concat_layer)
            self.logits_and_value_model = tf.keras.models.Model(
                concat_layer, [logits_layer, value_layer])
        else:
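            # No `num_outputs` given (e.g. when this model is wrapped by
            # another model): only expose the concat'd features and report
            # their size via `self.num_outputs`.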
            self.num_outputs = concat_size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            else:
                outs.append(component)
        # Concat all outputs and the non-image inputs.
        out = tf.concat(outs, axis=1)
        if not self.logits_and_value_model:
            return out, []
        # Logits- and value branches.
        logits, values = self.logits_and_value_model(out)
        self._value_out = tf.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        return self._value_out
# __sphinx_doc_end__


class TorchCNNPlusFCConcatModel(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex (Dict or Tuple) observation
    spaces that have one or more image components.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # TODO: (sven) Support Dicts as well.
        assert isinstance(obs_space.original_space, (Tuple)), \
            "`obs_space.original_space` must be Tuple!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        self.cnn_type = self.model_config["custom_model_config"].get(
            "conv_type", "atari")
        # Build the CNN(s) given obs_space's image components.
        self.cnns = {}
        concat_size = 0
        for i, component in enumerate(obs_space.original_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    "conv_filters": model_config.get(
                        "conv_filters", get_filter_config(component.shape)),
                    "conv_activation": model_config.get("conv_activation"),
                }
                if self.cnn_type == "atari":
                    cnn = ModelCatalog.get_model_v2(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        framework="torch",
                        name="cnn_{}".format(i))
                else:
                    cnn = TorchImpalaVisionNet(
                        component,
                        action_space,
                        num_outputs=None,
                        model_config=config,
                        name="cnn_{}".format(i))
                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
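                # Register the sub-CNN as a child module so that its
                # parameters show up in `.parameters()` (and thus get
                # trained and checkpointed).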
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                concat_size += component.n
            # TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
            # Everything else (1D Box).
            else:
                assert len(component.shape) == 1, \
                    "Only input Box 1D or 3D spaces allowed!"
                concat_size += component.shape[-1]

        self.logits_layer = None
        self.value_layer = None
        self._value_out = None
        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=concat_size,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=concat_size,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            self.num_outputs = concat_size

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(input_dict["obs"]):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({"obs": component})
                outs.append(cnn_out)
            else:
                outs.append(component)
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        if self.logits_layer is None:
            return out, []
        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        return self._value_out
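

# A minimal usage sketch (not part of the original example): register the two
# models above with RLlib's ModelCatalog and point a trainer config at them.
# "YourTupleObsEnv-v0" is a placeholder for an env whose observation space is
# a Tuple with at least one image component; swap in your own env.
#
#   from ray import tune
#
#   ModelCatalog.register_custom_model("cnn_fc_concat", CNNPlusFCConcatModel)
#   ModelCatalog.register_custom_model("torch_cnn_fc_concat",
#                                      TorchCNNPlusFCConcatModel)
#
#   config = {
#       "env": "YourTupleObsEnv-v0",  # placeholder env name
#       "framework": "torch",
#       "model": {
#           "custom_model": "torch_cnn_fc_concat",
#           # Only read by the Torch model: "atari" (default) or any other
#           # value for the IMPALA-style residual net.
#           "custom_model_config": {"conv_type": "atari"},
#       },
#   }
#   tune.run("PPO", config=config)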