# ray/rllib/models/torch/complex_input_net.py
#
# 175 lines
# 6.9 KiB
# Python

from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import tree # pip install dm_tree
# TODO (sven): add IMPALA-style option.
# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet
from ray.rllib.models.torch.misc import normc_initializer as \
torch_normc_initializer, SlimFC
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.torch_ops import one_hot
torch, nn = try_import_torch()
class ComplexInputNetwork(TorchModelV2, nn.Module):
    """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s).

    Note: This model should be used for complex (Dict or Tuple) observation
    spaces that have one or more image components.

    The data flow is as follows:

    `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT`
    `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out`
    `out` -> (optional) FC-stack -> `out2`
    `out2` -> action (logits) and value heads.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        """Builds per-component sub-networks, the FC-stack and the heads.

        Args:
            obs_space: The observation space. If it has an `original_space`
                attribute, that space is used; in either case the resulting
                space must be a Dict or Tuple space.
            action_space: The action space (passed on to the sub-models).
            num_outputs: Size of the action-logits head. If falsy, no
                logits/value heads are built and `forward()` returns the
                output of the post-concat FC-stack instead.
            model_config: Model config dict. Keys read here:
                "conv_filters", "conv_activation", "post_fcnet_hiddens",
                "post_fcnet_activation".
            name: Name of this model.
        """
        self.original_space = obs_space.original_space if \
            hasattr(obs_space, "original_space") else obs_space
        assert isinstance(self.original_space, (Dict, Tuple)), \
            "`obs_space.original_space` must be [Dict|Tuple]!"

        nn.Module.__init__(self)
        TorchModelV2.__init__(self, self.original_space, action_space,
                              num_outputs, model_config, name)

        # One entry per leaf component of the Dict/Tuple space; the index `i`
        # below is used to match components with `tree.flatten(obs)` in
        # `forward()`.
        self.flattened_input_space = flatten_space(self.original_space)

        # Atari type CNNs or IMPALA type CNNs (with residual layers)?
        # self.cnn_type = self.model_config["custom_model_config"].get(
        #     "conv_type", "atari")

        # Per-component handling, keyed by flattened component index:
        # - cnns: index -> CNN sub-model (3D/image components).
        # - one_hot: index -> True (Discrete/MultiDiscrete components).
        # - flatten: index -> flat size (everything else).
        self.cnns = {}
        self.one_hot = {}
        self.flatten = {}
        concat_size = 0
        for i, component in enumerate(self.flattened_input_space):
            # Image space.
            if len(component.shape) == 3:
                config = {
                    # Default filters must be derived from THIS image
                    # component's shape, not from the (flattened) overall
                    # observation space.
                    "conv_filters": model_config["conv_filters"]
                    if "conv_filters" in model_config else
                    get_filter_config(component.shape),
                    "conv_activation": model_config.get("conv_activation"),
                    "post_fcnet_hiddens": [],
                }
                # if self.cnn_type == "atari":
                cnn = ModelCatalog.get_model_v2(
                    component,
                    action_space,
                    num_outputs=None,
                    model_config=config,
                    framework="torch",
                    name="cnn_{}".format(i))
                # TODO (sven): add IMPALA-style option.
                # else:
                #     cnn = TorchImpalaVisionNet(
                #         component,
                #         action_space,
                #         num_outputs=None,
                #         model_config=config,
                #         name="cnn_{}".format(i))

                concat_size += cnn.num_outputs
                self.cnns[i] = cnn
                # `self.cnns` is a plain dict, so register the sub-model
                # explicitly to expose its parameters to torch.
                self.add_module("cnn_{}".format(i), cnn)
            # Discrete|MultiDiscrete inputs -> One-hot encode.
            elif isinstance(component, Discrete):
                self.one_hot[i] = True
                concat_size += component.n
            elif isinstance(component, MultiDiscrete):
                self.one_hot[i] = True
                concat_size += sum(component.nvec)
            # Everything else (1D Box).
            else:
                # `np.product` is deprecated/removed in newer NumPy.
                self.flatten[i] = int(np.prod(component.shape))
                concat_size += self.flatten[i]

        # Optional post-concat FC-stack.
        post_fc_stack_config = {
            "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []),
            "fcnet_activation": model_config.get("post_fcnet_activation",
                                                 "relu")
        }
        self.post_fc_stack = ModelCatalog.get_model_v2(
            Box(float("-inf"),
                float("inf"),
                shape=(concat_size, ),
                dtype=np.float32),
            self.action_space,
            None,
            post_fc_stack_config,
            framework="torch",
            name="post_fc_stack")

        # Actions and value heads.
        self.logits_layer = None
        self.value_layer = None
        self._value_out = None

        if num_outputs:
            # Action-distribution head.
            self.logits_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=num_outputs,
                activation_fn=None,
            )
            # Create the value branch model.
            self.value_layer = SlimFC(
                in_size=self.post_fc_stack.num_outputs,
                out_size=1,
                activation_fn=None,
                initializer=torch_normc_initializer(0.01))
        else:
            # Without heads, `forward()` returns the post-FC-stack output, so
            # our output size is that of the stack (equal to `concat_size`
            # only when `post_fcnet_hiddens` is empty).
            self.num_outputs = self.post_fc_stack.num_outputs

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Encodes each observation component, concatenates, runs the heads.

        Args:
            input_dict: Dict holding the observation under
                `SampleBatch.OBS`. If "obs_flat" is also present, the
                `SampleBatch.OBS` entry is used as the already-restored
                (nested) observation; otherwise `SampleBatch.OBS` is
                unpacked via `restore_original_dimensions`.
            state: Unused (this model is not recurrent).
            seq_lens: Unused.

        Returns:
            Tuple of (logits, []) if heads were built, else
            (post-FC-stack output, []).
        """
        if SampleBatch.OBS in input_dict and "obs_flat" in input_dict:
            orig_obs = input_dict[SampleBatch.OBS]
        else:
            # This is a torch model: unpack with the torch tensorlib.
            orig_obs = restore_original_dimensions(input_dict[SampleBatch.OBS],
                                                   self.obs_space, "torch")

        # Push image observations through our CNNs.
        outs = []
        for i, component in enumerate(tree.flatten(orig_obs)):
            if i in self.cnns:
                cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
                outs.append(cnn_out)
            elif i in self.one_hot:
                if component.dtype in [torch.int32, torch.int64, torch.uint8]:
                    outs.append(
                        one_hot(component, self.flattened_input_space[i]))
                else:
                    # NOTE(review): non-integer input is passed through
                    # as-is; presumably it is already one-hot encoded
                    # (e.g. by the preprocessor) — confirm.
                    outs.append(component)
            else:
                outs.append(torch.reshape(component, [-1, self.flatten[i]]))
        # Concat all outputs and the non-image inputs.
        out = torch.cat(outs, dim=1)
        # Push through (optional) FC-stack (this may be an empty stack).
        out, _ = self.post_fc_stack({SampleBatch.OBS: out}, [], None)

        # No logits/value branches.
        if self.logits_layer is None:
            return out, []

        # Logits- and value branches.
        logits, values = self.logits_layer(out), self.value_layer(out)
        self._value_out = torch.reshape(values, [-1])
        return logits, []

    @override(ModelV2)
    def value_function(self):
        """Returns the value output of the most recent `forward()` call.

        Returns None if no heads were built (`num_outputs` was falsy) or if
        `forward()` has not been called yet.
        """
        return self._value_out