ray/rllib/models/torch/modules/convtranspose2d_stack.py


from typing import Tuple

from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.utils import get_activation_fn, get_initializer
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

if torch:
    import torch.distributions as td
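
# NOTE: `try_import_torch()` returns None for `torch` when PyTorch is not
# installed, hence the guarded `torch.distributions` import above.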


class ConvTranspose2DStack(nn.Module):
    """ConvTranspose2D decoder generating an image distribution from a vector."""

    def __init__(
        self,
        *,
        input_size: int,
        filters: Tuple[Tuple[int]] = (
            (1024, 5, 2),
            (128, 5, 2),
            (64, 6, 2),
            (32, 6, 2),
        ),
        initializer="default",
        bias_init=0,
        activation_fn: str = "relu",
        output_shape: Tuple[int] = (3, 64, 64),
    ):
        """Initializes a ConvTranspose2DStack instance.

        Args:
            input_size (int): The size of the 1D input vector, from which to
                generate the image distribution.
            filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each
                ConvTranspose2D layer): [in_channels, kernel, stride].
            initializer (str): Descriptor of the weight initializer to use
                (resolved via `get_initializer`).
            bias_init (float): The initial bias values to use.
            activation_fn (str): Activation function descriptor (str).
            output_shape (Tuple[int]): Shape of the final output image
                (channels first).
        """
        super().__init__()
        self.activation = get_activation_fn(activation_fn, framework="torch")
        self.output_shape = output_shape
        initializer = get_initializer(initializer, framework="torch")

        in_channels = filters[0][0]
        self.layers = [
            # Map from 1D-input vector to correct initial size for the
            # Conv2DTransposed stack.
            nn.Linear(input_size, in_channels),
            # Reshape from the incoming 1D vector (input_size) to 1x1 image
            # format (channels first).
            Reshape([-1, in_channels, 1, 1]),
        ]
        for i, (_, kernel, stride) in enumerate(filters):
            out_channels = (
                filters[i + 1][0] if i < len(filters) - 1 else output_shape[0]
            )
            conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride)
            # Apply initializer.
            initializer(conv_transp.weight)
            nn.init.constant_(conv_transp.bias, bias_init)
            self.layers.append(conv_transp)
            # Apply activation function, if provided and if not last layer.
            if self.activation is not None and i < len(filters) - 1:
                self.layers.append(self.activation())
            # num-outputs == num-inputs for next layer.
            in_channels = out_channels
        self._model = nn.Sequential(*self.layers)
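        # Spatial sizes: with padding=0, each ConvTranspose2d produces
        # out = (in - 1) * stride + kernel, so the default filters grow the
        # 1x1 input as 1 -> 5 -> 13 -> 30 -> 64, matching the default
        # output_shape=(3, 64, 64).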

    def forward(self, x):
        # x is [batch, hor_length, input_size].
        batch_dims = x.shape[:-1]
        model_out = self._model(x)

        # Interpret all image dims as one event: equivalent to a multivariate
        # normal with a diagonal (unit) covariance.
        reshape_size = batch_dims + self.output_shape
        mean = model_out.view(*reshape_size)
        return td.Independent(td.Normal(mean, 1.0), len(self.output_shape))
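

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal example, assuming PyTorch is installed: decode a batch of 30-dim
# latent vectors (batch=4, horizon=2) into a distribution over 3x64x64
# images, then compute a log-likelihood as one might in a Dreamer-style
# reconstruction loss.
if __name__ == "__main__":
    decoder = ConvTranspose2DStack(input_size=30)
    latents = torch.randn(4, 2, 30)
    dist = decoder(latents)
    # batch_shape == (4, 2); event_shape == (3, 64, 64).
    images = dist.sample()
    assert images.shape == (4, 2, 3, 64, 64)
    # log_prob sums over the (3, 64, 64) event dims -> shape (4, 2).
    log_prob = dist.log_prob(images)
    assert log_prob.shape == (4, 2)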