# ray/rllib/models/tf/visionnet.py


import gym
from typing import Dict, List, Optional, Sequence

from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
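# `tf1` is the TF1.x compat module, `tf` the installed TensorFlow module, and
# `tfv` its major version; all of them may be None when TensorFlow is not
# installed, which is why `Keras_VisionNetwork` below guards on `tf`.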


# TODO: (sven) obsolete this class once we only support native keras models.
@DeveloperAPI
class VisionNetwork(TFModelV2):
    """Generic vision network implemented in ModelV2 API.

    An additional post-conv fully connected stack can be added and configured
    via the config keys:
    `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
    `post_fcnet_activation`: Activation function to use for this FC stack.
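
    Example (an illustrative sketch, not documented RLlib usage; the spaces,
    `model_config` values, and input shape below are assumptions):
        >>> import gym
        >>> import numpy as np
        >>> obs_space = gym.spaces.Box(0.0, 1.0, (84, 84, 3), np.float32)
        >>> model = VisionNetwork(
        ...     obs_space, gym.spaces.Discrete(4), num_outputs=4,
        ...     model_config={"conv_activation": "relu"}, name="vision")
        >>> logits, _ = model({"obs": np.zeros((1, 84, 84, 3), np.float32)})
        >>> # logits has shape (1, 4); model.value_function() yields shape (1,).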
"""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)
        super(VisionNetwork, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        activation = get_activation_fn(
            self.model_config.get("conv_activation"), framework="tf"
        )
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!"
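        # Each `conv_filters` entry has the form
        # [num_out_channels, kernel_size, stride], where kernel_size and
        # stride may be single ints or (height, width) pairs.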

        # Post FC net config.
        # (list() so the "+"-concatenation below also works for tuple configs.)
        post_fcnet_hiddens = list(model_config.get("post_fcnet_hiddens", []))
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="tf"
        )

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        input_shape = obs_space.shape
        self.data_format = "channels_last"

        inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
        last_layer = inputs
        # Whether the last layer is the output of a Flattened (rather than
        # an n x (1,1) Conv2D).
        self.last_layer_is_flattened = False

        # Build the action layers.
        for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=activation,
                padding="same",
                data_format="channels_last",
                name="conv{}".format(i),
            )(last_layer)

        out_size, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            last_layer = tf.keras.layers.Conv2D(
                out_size if post_fcnet_hiddens else num_outputs,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=activation,
                padding="valid",
                data_format="channels_last",
                name="conv_out",
            )(last_layer)

            # Add (optional) post-fc-stack after last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else []
            )
            feature_out = last_layer

            for i, out_size in enumerate(layer_sizes):
                feature_out = last_layer
                last_layer = tf.keras.layers.Dense(
                    out_size,
                    name="post_fcnet_{}".format(i),
                    activation=post_fcnet_activation,
                    kernel_initializer=normc_initializer(1.0),
                )(last_layer)
        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=activation,
                padding="valid",
                data_format="channels_last",
                name="conv{}".format(len(filters)),
            )(last_layer)

            # num_outputs defined. Use that to create an exact
            # `num_outputs`-sized (1,1)-Conv2D.
            if num_outputs:
                if post_fcnet_hiddens:
                    last_cnn = last_layer = tf.keras.layers.Conv2D(
                        post_fcnet_hiddens[0],
                        [1, 1],
                        activation=post_fcnet_activation,
                        padding="same",
                        data_format="channels_last",
                        name="conv_out",
                    )(last_layer)
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    for i, out_size in enumerate(
                        post_fcnet_hiddens[1:] + [num_outputs]
                    ):
                        feature_out = last_layer
                        last_layer = tf.keras.layers.Dense(
                            out_size,
                            name="post_fcnet_{}".format(i + 1),
                            activation=post_fcnet_activation
                            if i < len(post_fcnet_hiddens) - 1
                            else None,
                            kernel_initializer=normc_initializer(1.0),
                        )(last_layer)
                else:
                    feature_out = last_layer
                    last_cnn = last_layer = tf.keras.layers.Conv2D(
                        num_outputs,
                        [1, 1],
                        activation=None,
                        padding="same",
                        data_format="channels_last",
                        name="conv_out",
                    )(last_layer)

                if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
                    raise ValueError(
                        "Given `conv_filters` ({}) do not result in a [B, 1, "
                        "1, {} (`num_outputs`)] shape (but in {})! Please "
                        "adjust your Conv2D stack such that the dims 1 and 2 "
                        "are both 1.".format(
                            self.model_config["conv_filters"],
                            self.num_outputs,
                            list(last_cnn.shape),
                        )
                    )

            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                last_layer = tf.keras.layers.Flatten(data_format="channels_last")(
                    last_layer
                )

                # Add (optional) post-fc-stack after last Conv2D layer.
                for i, out_size in enumerate(post_fcnet_hiddens):
                    last_layer = tf.keras.layers.Dense(
                        out_size,
                        name="post_fcnet_{}".format(i),
                        activation=post_fcnet_activation,
                        kernel_initializer=normc_initializer(1.0),
                    )(last_layer)
                feature_out = last_layer
                self.num_outputs = last_layer.shape[1]
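                # `num_outputs` was not given: `self.num_outputs` now reflects
                # the flattened feature size, so downstream code (e.g. the
                # model catalog) can size the action distribution from it.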
        logits_out = last_layer

        # Build the value layers.
        if vf_share_layers:
            if not self.last_layer_is_flattened:
                feature_out = tf.keras.layers.Lambda(
                    lambda x: tf.squeeze(x, axis=[1, 2])
                )(feature_out)
            value_out = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01),
            )(feature_out)
        else:
            # Build a parallel set of hidden layers for the value net.
            last_layer = inputs
            for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
                last_layer = tf.keras.layers.Conv2D(
                    out_size,
                    kernel,
                    strides=stride
                    if isinstance(stride, (list, tuple))
                    else (stride, stride),
                    activation=activation,
                    padding="same",
                    data_format="channels_last",
                    name="conv_value_{}".format(i),
                )(last_layer)
            out_size, kernel, stride = filters[-1]
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=activation,
                padding="valid",
                data_format="channels_last",
                name="conv_value_{}".format(len(filters)),
            )(last_layer)
            last_layer = tf.keras.layers.Conv2D(
                1,
                [1, 1],
                activation=None,
                padding="same",
                data_format="channels_last",
                name="conv_value_out",
            )(last_layer)
            value_out = tf.keras.layers.Lambda(lambda x: tf.squeeze(x, axis=[1, 2]))(
                last_layer
            )

        self.base_model = tf.keras.Model(inputs, [logits_out, value_out])

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        obs = input_dict["obs"]
        if self.data_format == "channels_first":
            obs = tf.transpose(obs, [0, 2, 3, 1])
        # Explicit cast to float32 needed in eager.
        model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
        # Our last layer is already flat.
        if self.last_layer_is_flattened:
            return model_out, state
        # Last layer is an n x [1,1] Conv2D -> Flatten.
        else:
            return tf.squeeze(model_out, axis=[1, 2]), state

    def value_function(self) -> TensorType:
        return tf.reshape(self._value_out, [-1])


@DeveloperAPI
class Keras_VisionNetwork(tf.keras.Model if tf else object):
    """Generic vision network implemented in tf keras.

    An additional post-conv fully connected stack can be added and configured
    via the config keys:
    `post_fcnet_hiddens`: Dense layer sizes after the Conv2D stack.
    `post_fcnet_activation`: Activation function to use for this FC stack.
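
    Example (an illustrative sketch; the spaces, kwargs, and input shape are
    assumptions, not defaults):
        >>> import gym
        >>> import numpy as np
        >>> net = Keras_VisionNetwork(
        ...     gym.spaces.Box(0.0, 1.0, (84, 84, 3), np.float32),
        ...     gym.spaces.Discrete(4), num_outputs=4,
        ...     name="keras_vision", conv_activation="relu")
        >>> out, state, extra = net({"obs": np.zeros((1, 84, 84, 3), np.float32)})
        >>> # out has shape (1, 4); extra["vf_preds"] has shape (1,).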
"""

    def __init__(
        self,
        input_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: Optional[int] = None,
        *,
        name: str = "",
        conv_filters: Optional[Sequence[Sequence[int]]] = None,
        conv_activation: Optional[str] = None,
        post_fcnet_hiddens: Optional[Sequence[int]] = (),
        post_fcnet_activation: Optional[str] = None,
        no_final_linear: bool = False,
        vf_share_layers: bool = False,
        free_log_std: bool = False,
        **kwargs,
    ):
        super().__init__(name=name)

        if not conv_filters:
            conv_filters = get_filter_config(input_space.shape)
        assert len(conv_filters) > 0, "Must provide at least 1 entry in `conv_filters`!"

        conv_activation = get_activation_fn(conv_activation, framework="tf")
        # Normalize to a list: the default is a tuple, and the slice/concat
        # logic below (`hiddens[:-1] + [...]`) requires list semantics.
        post_fcnet_hiddens = list(post_fcnet_hiddens or [])
        post_fcnet_activation = get_activation_fn(post_fcnet_activation, framework="tf")

        input_shape = input_space.shape
        self.data_format = "channels_last"

        inputs = tf.keras.layers.Input(shape=input_shape, name="observations")
        last_layer = inputs
        # Whether the last layer is the output of a Flattened (rather than
        # an n x (1,1) Conv2D).
        self.last_layer_is_flattened = False

        # Build the action layers.
        for i, (out_size, kernel, stride) in enumerate(conv_filters[:-1], 1):
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=conv_activation,
                padding="same",
                data_format="channels_last",
                name="conv{}".format(i),
            )(last_layer)

        out_size, kernel, stride = conv_filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            last_layer = tf.keras.layers.Conv2D(
                out_size if post_fcnet_hiddens else num_outputs,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=conv_activation,
                padding="valid",
                data_format="channels_last",
                name="conv_out",
            )(last_layer)

            # Add (optional) post-fc-stack after last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else []
            )
            for i, out_size in enumerate(layer_sizes):
                last_layer = tf.keras.layers.Dense(
                    out_size,
                    name="post_fcnet_{}".format(i),
                    activation=post_fcnet_activation,
                    kernel_initializer=normc_initializer(1.0),
                )(last_layer)
        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=conv_activation,
                padding="valid",
                data_format="channels_last",
                name="conv{}".format(len(conv_filters)),
            )(last_layer)

            # num_outputs defined. Use that to create an exact
            # `num_outputs`-sized (1,1)-Conv2D.
            if num_outputs:
                if post_fcnet_hiddens:
                    last_cnn = last_layer = tf.keras.layers.Conv2D(
                        post_fcnet_hiddens[0],
                        [1, 1],
                        activation=post_fcnet_activation,
                        padding="same",
                        data_format="channels_last",
                        name="conv_out",
                    )(last_layer)
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    for i, out_size in enumerate(
                        post_fcnet_hiddens[1:] + [num_outputs]
                    ):
                        last_layer = tf.keras.layers.Dense(
                            out_size,
                            name="post_fcnet_{}".format(i + 1),
                            activation=post_fcnet_activation
                            if i < len(post_fcnet_hiddens) - 1
                            else None,
                            kernel_initializer=normc_initializer(1.0),
                        )(last_layer)
                else:
                    last_cnn = last_layer = tf.keras.layers.Conv2D(
                        num_outputs,
                        [1, 1],
                        activation=None,
                        padding="same",
                        data_format="channels_last",
                        name="conv_out",
                    )(last_layer)

                if last_cnn.shape[1] != 1 or last_cnn.shape[2] != 1:
                    raise ValueError(
                        "Given `conv_filters` ({}) do not result in a [B, 1, "
                        "1, {} (`num_outputs`)] shape (but in {})! Please "
                        "adjust your Conv2D stack such that the dims 1 and 2 "
                        "are both 1.".format(
                            # This class has no `model_config` attribute
                            # (unlike the ModelV2 variant above), so report
                            # the local `conv_filters` arg directly.
                            conv_filters,
                            num_outputs,
                            list(last_cnn.shape),
                        )
                    )

            # num_outputs not known -> Flatten.
            else:
                self.last_layer_is_flattened = True
                last_layer = tf.keras.layers.Flatten(data_format="channels_last")(
                    last_layer
                )

                # Add (optional) post-fc-stack after last Conv2D layer.
                for i, out_size in enumerate(post_fcnet_hiddens):
                    last_layer = tf.keras.layers.Dense(
                        out_size,
                        name="post_fcnet_{}".format(i),
                        activation=post_fcnet_activation,
                        kernel_initializer=normc_initializer(1.0),
                    )(last_layer)
        logits_out = last_layer

        # Build the value layers.
        if vf_share_layers:
            if not self.last_layer_is_flattened:
                last_layer = tf.keras.layers.Lambda(
                    lambda x: tf.squeeze(x, axis=[1, 2])
                )(last_layer)
            value_out = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01),
            )(last_layer)
        else:
            # Build a parallel set of hidden layers for the value net.
            last_layer = inputs
            for i, (out_size, kernel, stride) in enumerate(conv_filters[:-1], 1):
                last_layer = tf.keras.layers.Conv2D(
                    out_size,
                    kernel,
                    strides=stride
                    if isinstance(stride, (list, tuple))
                    else (stride, stride),
                    activation=conv_activation,
                    padding="same",
                    data_format="channels_last",
                    name="conv_value_{}".format(i),
                )(last_layer)
            out_size, kernel, stride = conv_filters[-1]
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=stride
                if isinstance(stride, (list, tuple))
                else (stride, stride),
                activation=conv_activation,
                padding="valid",
                data_format="channels_last",
                name="conv_value_{}".format(len(conv_filters)),
            )(last_layer)
            last_layer = tf.keras.layers.Conv2D(
                1,
                [1, 1],
                activation=None,
                padding="same",
                data_format="channels_last",
                name="conv_value_out",
            )(last_layer)
            value_out = tf.keras.layers.Lambda(lambda x: tf.squeeze(x, axis=[1, 2]))(
                last_layer
            )

        self.base_model = tf.keras.Model(inputs, [logits_out, value_out])

    def call(
        self, input_dict: SampleBatch
    ) -> (TensorType, List[TensorType], Dict[str, TensorType]):
        obs = input_dict["obs"]
        if self.data_format == "channels_first":
            obs = tf.transpose(obs, [0, 2, 3, 1])
        # Explicit cast to float32 needed in eager.
        model_out, self._value_out = self.base_model(tf.cast(obs, tf.float32))
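        # Native Keras models take any RNN state via `state_in_*` keys of the
        # input dict (there is no separate `state` argument) and return value
        # predictions through the extra-outs dict.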
        state = [v for k, v in input_dict.items() if k.startswith("state_in_")]
        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(self._value_out, [-1])}
        # Our last layer is already flat.
        if self.last_layer_is_flattened:
            return model_out, state, extra_outs
        # Last layer is an n x [1,1] Conv2D -> Flatten.
        else:
            return tf.squeeze(model_out, axis=[1, 2]), state, extra_outs