Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)

[RLlib] ConvTranspose2D module (#11231)

Parent: d1579819e9
Commit: d3bc20b727
12 changed files with 621 additions and 453 deletions
rllib/BUILD (21 changed lines)

@@ -1040,13 +1040,6 @@ py_test(
 # Tag: models
 # --------------------------------------------------------------------
 
-py_test(
-    name = "test_distributions",
-    tags = ["models"],
-    size = "medium",
-    srcs = ["models/tests/test_distributions.py"]
-)
-
 py_test(
     name = "test_attention_nets",
     tags = ["models"],
@@ -1054,6 +1047,20 @@ py_test(
     srcs = ["models/tests/test_attention_nets.py"]
 )
 
+py_test(
+    name = "test_convtranspose2d_stack",
+    tags = ["models"],
+    size = "small",
+    data = glob(["tests/data/images/obstacle_tower.png"]),
+    srcs = ["models/tests/test_convtranspose2d_stack.py"]
+)
+
+py_test(
+    name = "test_distributions",
+    tags = ["models"],
+    size = "medium",
+    srcs = ["models/tests/test_distributions.py"]
+)
 
 # --------------------------------------------------------------------
 # Evaluation components
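
With the new target registered under the "models" tag, the test can be run in isolation (assuming the usual RLlib Bazel setup, where target names match the py_test names):

    bazel test //rllib:test_convtranspose2d_stack

or, without Bazel:

    python rllib/models/tests/test_convtranspose2d_stack.py
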

@@ -1,5 +1,6 @@
import numpy as np
from typing import Any, List, Tuple

from ray.rllib.models.torch.modules.reshape import Reshape
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import TensorType
@@ -12,203 +13,187 @@ if torch:

ActFunc = Any


# Encoder, part of PlaNET
class ConvEncoder(nn.Module):
    """Standard Convolutional Encoder for Dreamer. This encoder is used
    to encode images from an environment into a latent state for the
    RSSM model in PlaNET.
    """

    def __init__(self,
                 depth: int = 32,
                 act: ActFunc = None,
                 shape: Tuple[int] = (3, 64, 64)):
        """Initializes Conv Encoder

        Args:
            depth (int): Number of channels in the first conv layer
            act (Any): Activation for Encoder, default ReLU
            shape (List): Shape of observation input
        """
        super().__init__()
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.depth = depth
        self.shape = shape

        init_channels = self.shape[0]
        self.layers = [
            Conv2d(init_channels, self.depth, 4, stride=2),
            self.act(),
            Conv2d(self.depth, 2 * self.depth, 4, stride=2),
            self.act(),
            Conv2d(2 * self.depth, 4 * self.depth, 4, stride=2),
            self.act(),
            Conv2d(4 * self.depth, 8 * self.depth, 4, stride=2),
            self.act(),
        ]
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        # Flatten to [batch*horizon, 3, 64, 64] in loss function
        orig_shape = list(x.size())
        x = x.view(-1, *(orig_shape[-3:]))
        x = self.model(x)

        new_shape = orig_shape[:-3] + [32 * self.depth]
        x = x.view(*new_shape)
        return x
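
For intuition, the four stride-2 convolutions shrink a 64x64 input down to 2x2 with 8*depth channels, which is where the flattened 32*depth embedding size comes from. A quick sanity check of the arithmetic (a sketch, assuming the default depth=32 and a working torch install):

    # Each Conv2d(kernel=4, stride=2) maps H -> floor((H - 4) / 2) + 1:
    # 64 -> 31 -> 14 -> 6 -> 2.
    # Final maps: [B, 8 * depth, 2, 2], flattened: 8 * depth * 4 = 32 * depth.
    import torch
    enc = ConvEncoder(depth=32)
    obs = torch.zeros(5, 10, 3, 64, 64)  # [batch, horizon, C, H, W]
    print(enc(obs).shape)                # torch.Size([5, 10, 1024])
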

# Decoder, part of PlaNET
class ConvDecoder(nn.Module):
    """Standard Convolutional Decoder for Dreamer.

    This decoder is used to decode images from the latent state generated
    by the transition dynamics model. This is used in calculating loss and
    logging gifs for imagined trajectories.
    """

    def __init__(self,
                 input_size: int,
                 depth: int = 32,
                 act: ActFunc = None,
                 shape: List = [3, 64, 64]):
        """Initializes a ConvDecoder instance.

        Args:
            input_size (int): Input size, usually feature size output from
                RSSM.
            depth (int): Number of channels in the first conv layer
            act (Any): Activation for Encoder, default ReLU
            shape (List): Shape of observation input
        """
        super().__init__()
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.depth = depth
        self.shape = shape

        self.layers = [
            Linear(input_size, 32 * self.depth),
            Reshape([-1, 32 * self.depth, 1, 1]),
            ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
            self.act(),
            ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
            self.act(),
            ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
            self.act(),
            ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
        ]
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        # x is [batch, hor_length, input_size]
        orig_shape = list(x.size())
        x = self.model(x)

        reshape_size = orig_shape[:-1] + self.shape
        mean = x.view(*reshape_size)

        # Equivalent to making a multivariate diag
        return td.Independent(td.Normal(mean, 1), len(self.shape))
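
The transposed stack inverts the encoder's path: a ConvTranspose2d with kernel k and stride s maps H -> (H - 1) * s + k, so starting from the 1x1 feature map the spatial sizes go 1 -> 5 -> 13 -> 30 -> 64, ending at the 3x64x64 observation shape. A usage sketch, assuming the defaults above (input_size=230 stands in for stoch 30 + deter 200 from the RSSM):

    import torch
    dec = ConvDecoder(input_size=230)
    feat = torch.zeros(5, 10, 230)             # [batch, horizon, feature]
    dist = dec(feat)
    print(dist.batch_shape, dist.event_shape)  # (5, 10), (3, 64, 64)
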

# Reward Model (PlaNET), and Value Function
class DenseDecoder(nn.Module):
    """FC network that outputs a distribution for calculating log_prob.

    Used later in DreamerLoss.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 layers: int,
                 units: int,
                 dist: str = "normal",
                 act: ActFunc = None):
        """Initializes FC network

        Args:
            input_size (int): Input size to network
            output_size (int): Output size to network
            layers (int): Number of layers in network
            units (int): Size of the hidden layers
            dist (str): Output distribution, parameterized by FC output
                logits.
            act (Any): Activation function
        """
        super().__init__()
        self.layrs = layers
        self.units = units
        self.act = act
        if not act:
            self.act = nn.ELU
        self.dist = dist
        self.input_size = input_size
        self.output_size = output_size
        self.layers = []
        cur_size = input_size
        for _ in range(self.layrs):
            self.layers.extend([Linear(cur_size, self.units), self.act()])
            cur_size = units
        self.layers.append(Linear(cur_size, output_size))
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        x = self.model(x)
        if self.output_size == 1:
            x = torch.squeeze(x)
        if self.dist == "normal":
            output_dist = td.Normal(x, 1)
        elif self.dist == "binary":
            output_dist = td.Bernoulli(logits=x)
        else:
            raise NotImplementedError("Distribution type not implemented!")
        return td.Independent(output_dist, 0)
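
In Dreamer this head is instantiated twice, as the reward model and as the value function, both mapping the concatenated [stoch, deter] feature to a scalar Normal. A sketch; the sizes are assumptions matching the defaults used elsewhere in this file (feature 230 = stoch 30 + deter 200):

    import torch
    reward_model = DenseDecoder(230, 1, 2, 400)    # dist="normal"
    feat = torch.zeros(16, 230)
    r_dist = reward_model(feat)                    # td.Independent(Normal, 0)
    logp = r_dist.log_prob(torch.zeros(16))        # elementwise log-likelihood
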

# Represents dreamer policy
class ActionDecoder(nn.Module):
    """ActionDecoder is the policy module in Dreamer.

    It outputs a distribution parameterized by mean and std, later to be
    transformed by a custom TanhBijector in utils.py for Dreamer.
    """

    def __init__(self,
                 input_size: int,
                 action_size: int,
                 layers: int,
                 units: int,
                 dist: str = "tanh_normal",
                 act: ActFunc = None,
                 min_std: float = 1e-4,
                 init_std: float = 5.0,
                 mean_scale: float = 5.0):
        """Initializes Policy

        Args:
            input_size (int): Input size to network
            action_size (int): Action space size
            layers (int): Number of layers in network

@@ -218,342 +203,335 @@ if torch:

            min_std (float): Minimum std for output distribution
            init_std (float): Initial std
            mean_scale (float): Augmenting mean output from FC network
        """
        super().__init__()
        self.layrs = layers
        self.units = units
        self.dist = dist
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.min_std = min_std
        self.init_std = init_std
        self.mean_scale = mean_scale
        self.action_size = action_size

        self.layers = []
        self.softplus = nn.Softplus()

        # MLP Construction
        cur_size = input_size
        for _ in range(self.layrs):
            self.layers.extend([Linear(cur_size, self.units), self.act()])
            cur_size = self.units
        if self.dist == "tanh_normal":
            self.layers.append(Linear(cur_size, 2 * action_size))
        elif self.dist == "onehot":
            self.layers.append(Linear(cur_size, action_size))
        self.model = nn.Sequential(*self.layers)

    # Returns distribution
    def forward(self, x):
        raw_init_std = np.log(np.exp(self.init_std) - 1)
        x = self.model(x)
        if self.dist == "tanh_normal":
            mean, std = torch.chunk(x, 2, dim=-1)
            mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
            std = self.softplus(std + raw_init_std) + self.min_std
            dist = td.Normal(mean, std)
            transforms = [TanhBijector()]
            dist = td.transformed_distribution.TransformedDistribution(
                dist, transforms)
            dist = td.Independent(dist, 1)
        elif self.dist == "onehot":
            dist = td.OneHotCategorical(logits=x)
            raise NotImplementedError("Atari not implemented yet!")
        return dist
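
The raw_init_std shift is what makes init_std the effective standard deviation at initialization: softplus is inverted, so a zero network output yields exactly init_std. A numeric check (a sketch, not part of the commit):

    import numpy as np
    init_std = 5.0
    raw_init_std = np.log(np.exp(init_std) - 1)  # inverse softplus of 5.0
    softplus = lambda v: np.log1p(np.exp(v))
    print(softplus(0.0 + raw_init_std))          # ~5.0; the model adds min_std
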

# Represents TD model in PlaNET
class RSSM(nn.Module):
    """RSSM is the core recurrent part of the PlaNET module. It consists of
    two networks, one (obs) to calculate posterior beliefs and states and
    the second (img) to calculate prior beliefs and states. The prior network
    takes in the previous state and action, while the posterior network takes
    in the previous state, action, and a latent embedding of the most recent
    observation.
    """

    def __init__(self,
                 action_size: int,
                 embed_size: int,
                 stoch: int = 30,
                 deter: int = 200,
                 hidden: int = 200,
                 act: ActFunc = None):
        """Initializes RSSM

        Args:
            action_size (int): Action space size
            embed_size (int): Size of ConvEncoder embedding
            stoch (int): Size of the distributional hidden state
            deter (int): Size of the deterministic hidden state
            hidden (int): General size of hidden layers
            act (Any): Activation function
        """
        super().__init__()
        self.stoch_size = stoch
        self.deter_size = deter
        self.hidden_size = hidden
        self.act = act
        if act is None:
            self.act = nn.ELU

        self.obs1 = Linear(embed_size + deter, hidden)
        self.obs2 = Linear(hidden, 2 * stoch)

        self.cell = GRUCell(self.hidden_size, hidden_size=self.deter_size)
        self.img1 = Linear(stoch + action_size, hidden)
        self.img2 = Linear(deter, hidden)
        self.img3 = Linear(hidden, 2 * stoch)

        self.softplus = nn.Softplus

        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

    def get_initial_state(self, batch_size: int) -> List[TensorType]:
        """Returns the initial state for the RSSM, which consists of mean,
        std for the stochastic state, the sampled stochastic hidden state
        (from mean, std), and the deterministic hidden state, which is
        pushed through the GRUCell.

        Args:
            batch_size (int): Batch size for initial state

        Returns:
            List of tensors
        """
        return [
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.deter_size).to(self.device),
        ]

    def observe(self,
                embed: TensorType,
                action: TensorType,
                state: List[TensorType] = None
                ) -> Tuple[List[TensorType], List[TensorType]]:
        """Returns the corresponding states from the embedding from ConvEncoder
        and actions. This is accomplished by rolling out the RNN from the
        starting state through each index of embed and action, saving all
        intermediate states between.

        Args:
            embed (TensorType): ConvEncoder embedding
            action (TensorType): Actions
            state (List[TensorType]): Initial state before rollout

        Returns:
            Posterior states and prior states (both List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        embed = embed.permute(1, 0, 2)
        action = action.permute(1, 0, 2)

        priors = [[] for i in range(len(state))]
        posts = [[] for i in range(len(state))]
        last = (state, state)
        for index in range(len(action)):
            # Tuple of post and prior
            last = self.obs_step(last[0], action[index], embed[index])
            [o.append(l) for l, o in zip(last[0], posts)]
            [o.append(l) for l, o in zip(last[1], priors)]

        prior = [torch.stack(x, dim=0) for x in priors]
        post = [torch.stack(x, dim=0) for x in posts]

        prior = [e.permute(1, 0, 2) for e in prior]
        post = [e.permute(1, 0, 2) for e in post]

        return post, prior

    def imagine(self, action: TensorType,
                state: List[TensorType] = None) -> List[TensorType]:
        """Imagines the trajectory starting from state through a list of actions.
        Similar to observe(), requires rolling out the RNN for each timestep.

        Args:
            action (TensorType): Actions
            state (List[TensorType]): Starting state before rollout

        Returns:
            Prior states
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        action = action.permute(1, 0, 2)

        indices = range(len(action))
        priors = [[] for _ in range(len(state))]
        last = state
        for index in indices:
            last = self.img_step(last, action[index])
            [o.append(l) for l, o in zip(last, priors)]

        prior = [torch.stack(x, dim=0) for x in priors]
        prior = [e.permute(1, 0, 2) for e in prior]
        return prior

    def obs_step(
            self, prev_state: TensorType, prev_action: TensorType,
            embed: TensorType) -> Tuple[List[TensorType], List[TensorType]]:
        """Runs through the posterior model and returns the posterior state

        Args:
            prev_state (TensorType): The previous state
            prev_action (TensorType): The previous action
            embed (TensorType): Embedding from ConvEncoder

        Returns:
            Post and Prior state
        """
        prior = self.img_step(prev_state, prev_action)
        x = torch.cat([prior[3], embed], dim=-1)
        x = self.obs1(x)
        x = self.act()(x)
        x = self.obs2(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        post = [mean, std, stoch, prior[3]]
        return post, prior

    def img_step(self, prev_state: TensorType,
                 prev_action: TensorType) -> List[TensorType]:
        """Runs through the prior model and returns the prior state

        Args:
            prev_state (TensorType): The previous state
            prev_action (TensorType): The previous action

        Returns:
            Prior state
        """
        x = torch.cat([prev_state[2], prev_action], dim=-1)
        x = self.img1(x)
        x = self.act()(x)
        deter = self.cell(x, prev_state[3])
        x = deter
        x = self.img2(x)
        x = self.act()(x)
        x = self.img3(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        return [mean, std, stoch, deter]

    def get_feature(self, state: List[TensorType]) -> TensorType:
        # Constructs feature for input to reward, decoder, actor, critic
        return torch.cat([state[2], state[3]], dim=-1)

    def get_dist(self, mean: TensorType, std: TensorType) -> TensorType:
        return td.Normal(mean, std)
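
Each RSSM state is the list [mean, std, stoch, deter]; observe() consumes batch-major [batch, time, ...] tensors and returns posterior and prior rollouts in the same layout. A minimal shape sketch (assuming a CPU-only torch build so that the internal zero-states and these inputs share a device; embed_size=1024 matches the default ConvEncoder output):

    import torch
    rssm = RSSM(action_size=6, embed_size=1024)
    embed = torch.zeros(4, 10, 1024)     # [batch, time, embed]
    actions = torch.zeros(4, 10, 6)      # [batch, time, action]
    post, prior = rssm.observe(embed, actions)
    print(post[2].shape, post[3].shape)  # [4, 10, 30] stoch, [4, 10, 200] deter
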

# Represents all models in Dreamer, unifies them all into a single interface
class DreamerModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs,
                         model_config, name)

        nn.Module.__init__(self)
        self.depth = model_config["depth_size"]
        self.deter_size = model_config["deter_size"]
        self.stoch_size = model_config["stoch_size"]
        self.hidden_size = model_config["hidden_size"]

        self.action_size = action_space.shape[0]

        self.encoder = ConvEncoder(self.depth)
        self.decoder = ConvDecoder(
            self.stoch_size + self.deter_size, depth=self.depth)
        self.reward = DenseDecoder(self.stoch_size + self.deter_size, 1, 2,
                                   self.hidden_size)
        self.dynamics = RSSM(
            self.action_size,
            32 * self.depth,
            stoch=self.stoch_size,
            deter=self.deter_size)
        self.actor = ActionDecoder(self.stoch_size + self.deter_size,
                                   self.action_size, 4, self.hidden_size)
        self.value = DenseDecoder(self.stoch_size + self.deter_size, 1, 3,
                                  self.hidden_size)
        self.state = None

        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

    def policy(self,
               obs: TensorType,
               state: List[TensorType],
               explore=True
               ) -> Tuple[TensorType, List[float], List[TensorType]]:
        """Returns the action. Runs through the encoder, recurrent model,
        and policy to obtain action.
        """
        if state is None:
            self.get_initial_state()
        else:
            self.state = state
        post = self.state[:4]
        action = self.state[4]

        embed = self.encoder(obs)
        post, _ = self.dynamics.obs_step(post, action, embed)
        feat = self.dynamics.get_feature(post)

        action_dist = self.actor(feat)
        if explore:
            action = action_dist.sample()
        else:
            action = action_dist.mean
        logp = action_dist.log_prob(action)

        self.state = post + [action]
        return action, logp, self.state

    def imagine_ahead(self, state: List[TensorType],
                      horizon: int) -> TensorType:
        """Given a batch of states, rolls out more state of length horizon.
        """
        start = []
        for s in state:
            s = s.contiguous().detach()
            shpe = [-1] + list(s.size())[2:]
            start.append(s.view(*shpe))

        def next_state(state):
            feature = self.dynamics.get_feature(state).detach()
            action = self.actor(feature).rsample()
            next_state = self.dynamics.img_step(state, action)
            return next_state

        last = start
        outputs = [[] for i in range(len(start))]
        for _ in range(horizon):
            last = next_state(last)
            [o.append(l) for l, o in zip(last, outputs)]
        outputs = [torch.stack(x, dim=0) for x in outputs]

        imag_feat = self.dynamics.get_feature(outputs)
        return imag_feat

    def get_initial_state(self) -> List[TensorType]:
        self.state = self.dynamics.get_initial_state(1) + [
            torch.zeros(1, self.action_space.shape[0]).to(self.device)
        ]
        return self.state

    def value_function(self) -> TensorType:
        return None
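
A construction sketch; the config keys are the ones read above, and the sizes shown are only assumed to match Dreamer's defaults, which this diff does not pin down:

    import gym
    import numpy as np
    obs_space = gym.spaces.Box(-1.0, 1.0, (3, 64, 64), np.float32)
    act_space = gym.spaces.Box(-1.0, 1.0, (6,), np.float32)
    conf = {"depth_size": 32, "deter_size": 200,
            "stoch_size": 30, "hidden_size": 400}
    model = DreamerModel(obs_space, act_space, 1, conf, "dreamer")
    state = model.get_initial_state()  # [mean, std, stoch, deter, last_action]
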

@@ -1,6 +1,7 @@
+import numpy as np
+
 from ray.rllib.utils.framework import try_import_torch
 
 torch, nn = try_import_torch()
 
 # Custom initialization for different types of layers
@@ -15,9 +16,6 @@ if torch:
             if self.bias is not None:
                 nn.init.zeros_(self.bias)
 
-
-if torch:
-
     class Conv2d(nn.Conv2d):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -27,9 +25,6 @@ if torch:
             if self.bias is not None:
                 nn.init.zeros_(self.bias)
 
-
-if torch:
-
     class ConvTranspose2d(nn.ConvTranspose2d):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -39,9 +34,6 @@ if torch:
             if self.bias is not None:
                 nn.init.zeros_(self.bias)
 
-
-if torch:
-
     class GRUCell(nn.GRUCell):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
@@ -52,10 +44,7 @@
             nn.init.zeros_(self.bias_ih)
             nn.init.zeros_(self.bias_hh)
 
-
-# Custom Tanh Bijector due to big gradients through Dreamer Actor
-if torch:
-
+    # Custom Tanh Bijector due to big gradients through Dreamer Actor
     class TanhBijector(torch.distributions.Transform):
         def __init__(self):
             super().__init__()
@@ -16,7 +16,7 @@ torch, nn = try_import_torch()
 tf1, tf, tfv = try_import_tf()
 
 
-class TestModules(unittest.TestCase):
+class TestAttentionNets(unittest.TestCase):
     """Tests various torch/modules and tf/layers required for AttentionNet"""
 
     def train_torch_full_model(self,
rllib/models/tests/test_convtranspose2d_stack.py (new file, 64 lines)

@@ -0,0 +1,64 @@
import cv2
import gym
import numpy as np
import os
from pathlib import Path
import unittest

from ray.rllib.models.preprocessors import GenericPixelPreprocessor
from ray.rllib.models.torch.modules.convtranspose2d_stack import \
    ConvTranspose2DStack
from ray.rllib.utils.framework import try_import_torch, try_import_tf

torch, nn = try_import_torch()
tf1, tf, tfv = try_import_tf()


class TestConvTranspose2DStack(unittest.TestCase):
    """Tests our ConvTranspose2D Stack modules/layers."""

    def test_convtranspose2d_stack(self):
        """Tests, whether the conv2d stack can be trained to predict an image.
        """
        batch_size = 128
        input_size = 1
        module = ConvTranspose2DStack(input_size=input_size)
        preprocessor = GenericPixelPreprocessor(
            gym.spaces.Box(0, 255, (64, 64, 3), np.uint8), options={"dim": 64})
        optim = torch.optim.Adam(module.parameters(), lr=0.0001)

        rllib_dir = Path(__file__).parent.parent.parent
        img_file = os.path.join(rllib_dir,
                                "tests/data/images/obstacle_tower.png")
        img = cv2.imread(img_file).astype(np.float32)
        # Preprocess.
        img = preprocessor.transform(img)
        # Make channels first.
        img = np.transpose(img, (2, 0, 1))
        # Add batch rank and repeat.
        imgs = np.reshape(img, (1, ) + img.shape)
        imgs = np.repeat(imgs, batch_size, axis=0)
        # Move to torch.
        imgs = torch.from_numpy(imgs)
        init_loss = loss = None
        for _ in range(10):
            # Random inputs.
            inputs = torch.from_numpy(
                np.random.normal(0.0, 1.0, (batch_size, input_size))).float()
            distribution = module(inputs)
            # Construct a loss.
            loss = -torch.mean(distribution.log_prob(imgs))
            if init_loss is None:
                init_loss = loss
            print("loss={}".format(loss))
            # Minimize loss.
            loss.backward()
            optim.step()
        self.assertLess(loss, init_loss)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
@@ -7,13 +7,15 @@ tf1, tf, tfv = try_import_tf()
 
 
 class NoisyLayer(tf.keras.layers.Layer if tf else object):
-    """A Layer that adds learnable Noise
-    a common dense layer: y = w^{T}x + b
-    a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
+    """A Layer that adds learnable Noise to some previous layer's outputs.
+
+    Consists of:
+    - a common dense layer: y = w^{T}x + b
+    - a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
         (b+\\epsilon_b*\\sigma_b)
-    where \\epsilon are random variables sampled from factorized normal
+    , where \\epsilon are random variables sampled from factorized normal
     distributions and \\sigma are trainable variables which are expected to
-    vanish along the training procedure
+    vanish along the training procedure.
     """
 
     def __init__(self, prefix, out_size, sigma0, activation="relu"):
@@ -1,5 +1,6 @@
 """ Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
 import numpy as np
+from typing import List
 
 from ray.rllib.utils.framework import get_activation_fn, try_import_torch
 
@@ -138,3 +139,15 @@ class AppendBiasLayer(nn.Module):
         out = torch.cat(
             [x, self.log_std.unsqueeze(0).repeat([len(x), 1])], axis=1)
         return out
+
+
+class Reshape(nn.Module):
+    """Standard module that reshapes/views a tensor
+    """
+
+    def __init__(self, shape: List):
+        super().__init__()
+        self.shape = shape
+
+    def forward(self, x):
+        return x.view(*self.shape)
rllib/models/torch/modules/convtranspose2d_stack.py (new file, 77 lines)

@@ -0,0 +1,77 @@
from typing import Tuple

from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.utils import get_initializer
from ray.rllib.utils.framework import get_activation_fn, try_import_torch

torch, nn = try_import_torch()
if torch:
    import torch.distributions as td


class ConvTranspose2DStack(nn.Module):
    """ConvTranspose2D decoder generating an image distribution from a vector.
    """

    def __init__(self,
                 *,
                 input_size: int,
                 filters: Tuple[Tuple[int]] = ((1024, 5, 2), (128, 5, 2),
                                               (64, 6, 2), (32, 6, 2)),
                 initializer="default",
                 bias_init=0,
                 activation_fn: str = "relu",
                 output_shape: Tuple[int] = (3, 64, 64)):
        """Initializes a TransposedConv2DStack instance.

        Args:
            input_size (int): The size of the 1D input vector, from which to
                generate the image distribution.
            filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each
                ConvTranspose2D layer): [in_channels, kernel, stride].
            initializer (Union[str]): The initializer to use for the weights
                ("default", "xavier_uniform", or "xavier_normal").
            bias_init (float): The initial bias values to use.
            activation_fn (str): Activation function descriptor (str).
            output_shape (Tuple[int]): Shape of the final output image.
        """
        super().__init__()
        self.activation = get_activation_fn(activation_fn, framework="torch")
        self.output_shape = output_shape
        initializer = get_initializer(initializer, framework="torch")

        in_channels = filters[0][0]
        self.layers = [
            # Map from 1D-input vector to correct initial size for the
            # Conv2DTransposed stack.
            nn.Linear(input_size, in_channels),
            # Reshape from the incoming 1D vector (input_size) to 1x1 image
            # format (channels first).
            Reshape([-1, in_channels, 1, 1]),
        ]
        for i, (_, kernel, stride) in enumerate(filters):
            out_channels = filters[i + 1][0] if i < len(filters) - 1 else \
                output_shape[0]
            conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel,
                                             stride)
            # Apply initializer.
            initializer(conv_transp.weight)
            nn.init.constant_(conv_transp.bias, bias_init)
            self.layers.append(conv_transp)
            # Apply activation function, if provided and if not last layer.
            if self.activation is not None and i < len(filters) - 1:
                self.layers.append(self.activation())

            # num-outputs == num-inputs for next layer.
            in_channels = out_channels

        self._model = nn.Sequential(*self.layers)

    def forward(self, x):
        # x is [batch, hor_length, input_size]
        batch_dims = x.shape[:-1]
        model_out = self._model(x)

        # Equivalent to making a multivariate diag.
        reshape_size = batch_dims + self.output_shape
        mean = model_out.view(*reshape_size)
        return td.Independent(td.Normal(mean, 1.0), len(self.output_shape))
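
Under the default filters, the spatial path mirrors the Dreamer decoder: the Linear maps to 1024 channels at 1x1, then (H - 1) * stride + kernel gives 1 -> 5 -> 13 -> 30 -> 64. A usage sketch (input_size=230 is an assumption, e.g. stoch + deter from the RSSM):

    import torch
    stack = ConvTranspose2DStack(input_size=230)
    x = torch.randn(8, 230)
    dist = stack(x)                              # td.Independent(Normal, 3)
    print(dist.batch_shape, dist.event_shape)    # (8,), (3, 64, 64)
    img = dist.mean                              # [8, 3, 64, 64]
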
@@ -7,13 +7,15 @@ torch, nn = try_import_torch()
 
 
 class NoisyLayer(nn.Module):
-    """A Layer that adds learnable Noise
-    a common dense layer: y = w^{T}x + b
-    a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
+    """A Layer that adds learnable Noise to some previous layer's outputs.
+
+    Consists of:
+    - a common dense layer: y = w^{T}x + b
+    - a noisy layer: y = (w + \\epsilon_w*\\sigma_w)^{T}x +
         (b+\\epsilon_b*\\sigma_b)
-    where \\epsilon are random variables sampled from factorized normal
+    , where \\epsilon are random variables sampled from factorized normal
     distributions and \\sigma are trainable variables which are expected to
-    vanish along the training procedure
+    vanish along the training procedure.
     """
 
     def __init__(self, in_size, out_size, sigma0, activation="relu"):
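
In code terms, the docstring's equation amounts to sampling a fresh perturbation of the weights on every forward pass. A generic hedged sketch of that idea (not RLlib's exact implementation, which uses factorized noise sampling for efficiency):

    import torch
    def noisy_linear(x, w, b, sigma_w, sigma_b):
        # y = (w + eps_w * sigma_w)^T x + (b + eps_b * sigma_b)
        eps_w = torch.randn_like(sigma_w)
        eps_b = torch.randn_like(sigma_b)
        return x @ (w + eps_w * sigma_w).t() + (b + eps_b * sigma_b)
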

@@ -1,3 +1,6 @@
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+
+
 def get_filter_config(shape):
     """Returns a default Conv2D filter config (list) for a given image shape.
 
@@ -30,3 +33,35 @@ def get_filter_config(shape):
         "Default configurations are only available for inputs of shape "
         "[42, 42, K] and [84, 84, K]. You may alternatively want "
         "to use a custom model or preprocessor.")
+
+
+def get_initializer(name, framework="tf"):
+    """Returns a framework specific initializer, given a name string.
+
+    Args:
+        name (str): One of "xavier_uniform" (default), "xavier_normal".
+        framework (str): One of "tf" or "torch".
+
+    Returns:
+        A framework-specific initializer function, e.g.
+        tf.keras.initializers.GlorotUniform or
+        torch.nn.init.xavier_uniform_.
+
+    Raises:
+        ValueError: If name is an unknown initializer.
+    """
+    if framework == "torch":
+        _, nn = try_import_torch()
+        if name in [None, "default", "xavier_uniform"]:
+            return nn.init.xavier_uniform_
+        elif name == "xavier_normal":
+            return nn.init.xavier_normal_
+    else:
+        tf1, tf, tfv = try_import_tf()
+        if name in [None, "default", "xavier_uniform"]:
+            return tf.keras.initializers.GlorotUniform
+        elif name == "xavier_normal":
+            return tf.keras.initializers.GlorotNormal
+
+    raise ValueError("Unknown activation ({}) for framework={}!".format(
+        name, framework))
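
Usage sketch for the new helper (torch branch):

    from ray.rllib.utils.framework import try_import_torch
    torch, nn = try_import_torch()

    init = get_initializer("default", framework="torch")  # xavier_uniform_
    layer = nn.Linear(16, 32)
    init(layer.weight)  # initializes the weight tensor in place
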

BIN  rllib/tests/data/images/obstacle_tower.png (new file)
Binary file not shown. After: 8.5 KiB.

@@ -219,11 +219,12 @@ def get_variable(value,
     return value
 
 
+# TODO: (sven) move to models/utils.py
 def get_activation_fn(name, framework="tf"):
     """Returns a framework specific activation function, given a name string.
 
     Args:
-        name (str): One of "relu" (default), "tanh", or "linear".
+        name (str): One of "relu" (default), "tanh", "swish", or "linear".
         framework (str): One of "tf" or "torch".
 
     Returns:
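
As used by the new ConvTranspose2DStack above, the torch branch returns an activation class to be instantiated by the caller, e.g.:

    act_cls = get_activation_fn("relu", framework="torch")  # -> nn.ReLU
    act = act_cls()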