import numpy as np
from typing import Any, List, Tuple

from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import TensorType

torch, nn = try_import_torch()

if torch:
    from torch import distributions as td

from ray.rllib.algorithms.dreamer.utils import (
    Linear,
    Conv2d,
    ConvTranspose2d,
    GRUCell,
    TanhBijector,
)

ActFunc = Any


# Encoder, part of PlaNET
class ConvEncoder(nn.Module):
    """Standard Convolutional Encoder for Dreamer. This encoder is used
    to encode images from an environment into a latent state for the
    RSSM model in PlaNET.
    """

    def __init__(
        self, depth: int = 32, act: ActFunc = None, shape: Tuple[int] = (3, 64, 64)
    ):
        """Initializes Conv Encoder

        Args:
            depth: Number of channels in the first conv layer
            act: Activation for Encoder, default ReLU
            shape: Shape of observation input
        """
        super().__init__()
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.depth = depth
        self.shape = shape

        init_channels = self.shape[0]
        self.layers = [
            Conv2d(init_channels, self.depth, 4, stride=2),
            self.act(),
            Conv2d(self.depth, 2 * self.depth, 4, stride=2),
            self.act(),
            Conv2d(2 * self.depth, 4 * self.depth, 4, stride=2),
            self.act(),
            Conv2d(4 * self.depth, 8 * self.depth, 4, stride=2),
            self.act(),
        ]
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        # Flatten to [batch * horizon, 3, 64, 64] before the conv stack
        orig_shape = list(x.size())
        x = x.view(-1, *(orig_shape[-3:]))
        x = self.model(x)

        # Restore leading dims, with the conv output flattened to 32 * depth
        new_shape = orig_shape[:-3] + [32 * self.depth]
        x = x.view(*new_shape)
        return x
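
# Usage sketch (illustrative, shapes assumed): with the default depth=32 and
# 64x64 RGB observations, the four stride-2 convs shrink the spatial dims
# 64 -> 31 -> 14 -> 6 -> 2, so the flattened embedding has
# 8 * depth * 2 * 2 = 32 * depth features.
#
#   encoder = ConvEncoder(depth=32)
#   obs = torch.rand(16, 10, 3, 64, 64)   # [batch, horizon, C, H, W]
#   embed = encoder(obs)                  # [16, 10, 1024] == [..., 32 * depth]
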
# Decoder, part of PlaNET
class ConvDecoder(nn.Module):
    """Standard Convolutional Decoder for Dreamer.

    This decoder is used to decode images from the latent state generated
    by the transition dynamics model. This is used in calculating loss and
    logging gifs for imagined trajectories.
    """

    def __init__(
        self,
        input_size: int,
        depth: int = 32,
        act: ActFunc = None,
        shape: Tuple[int] = (3, 64, 64),
    ):
        """Initializes a ConvDecoder instance.

        Args:
            input_size: Input size, usually the feature size output by the RSSM
            depth: Number of channels in the first conv layer
            act: Activation for Decoder, default ReLU
            shape: Shape of observation output
        """
        super().__init__()
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.depth = depth
        self.shape = shape

        self.layers = [
            Linear(input_size, 32 * self.depth),
            Reshape([-1, 32 * self.depth, 1, 1]),
            ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
            self.act(),
            ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
            self.act(),
            ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
            self.act(),
            ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
        ]
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        # x is [batch, horizon, input_size]
        orig_shape = list(x.size())
        x = self.model(x)

        reshape_size = orig_shape[:-1] + list(self.shape)
        mean = x.view(*reshape_size)

        # Equivalent to making a multivariate diagonal Normal
        return td.Independent(td.Normal(mean, 1), len(self.shape))
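
# Usage sketch (illustrative, sizes assumed): the decoder inverts the encoder.
# A latent feature of size stoch + deter (230 for the RSSM defaults of 30 and
# 200) is projected to a 1x1 map with 32 * depth channels, then upsampled
# 1 -> 5 -> 13 -> 30 -> 64 by the four transposed convs.
#
#   decoder = ConvDecoder(input_size=230)
#   feat = torch.rand(16, 10, 230)
#   obs_dist = decoder(feat)              # Independent Normal over [3, 64, 64]
#   loss = -obs_dist.log_prob(torch.rand(16, 10, 3, 64, 64)).mean()
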
# Reward Model (PlaNET), and Value Function
class DenseDecoder(nn.Module):
    """FC network that outputs a distribution for calculating log_prob.

    Used later in DreamerLoss.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        layers: int,
        units: int,
        dist: str = "normal",
        act: ActFunc = None,
    ):
        """Initializes FC network

        Args:
            input_size: Input size to network
            output_size: Output size of network
            layers: Number of layers in network
            units: Size of the hidden layers
            dist: Output distribution, parameterized by FC output logits.
            act: Activation function
        """
        super().__init__()
        self.num_layers = layers
        self.units = units
        self.act = act
        if not act:
            self.act = nn.ELU
        self.dist = dist
        self.input_size = input_size
        self.output_size = output_size
        self.layers = []
        cur_size = input_size
        for _ in range(self.num_layers):
            self.layers.extend([Linear(cur_size, self.units), self.act()])
            cur_size = units
        self.layers.append(Linear(cur_size, output_size))
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        x = self.model(x)
        if self.output_size == 1:
            x = torch.squeeze(x)
        if self.dist == "normal":
            output_dist = td.Normal(x, 1)
        elif self.dist == "binary":
            output_dist = td.Bernoulli(logits=x)
        else:
            raise NotImplementedError("Distribution type not implemented!")
        return td.Independent(output_dist, 0)
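
# Usage sketch (illustrative, sizes assumed): as the reward model,
# DenseDecoder maps the RSSM feature to a unit-variance Normal whose
# log_prob drives the reward loss; dist="binary" yields a Bernoulli instead.
#
#   reward_model = DenseDecoder(230, 1, 2, 400)   # input, output, layers, units
#   feat = torch.rand(16, 10, 230)
#   reward_dist = reward_model(feat)              # Normal(mean, 1), output squeezed
#   loss = -reward_dist.log_prob(torch.rand(16, 10)).mean()
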
# Represents dreamer policy
class ActionDecoder(nn.Module):
    """ActionDecoder is the policy module in Dreamer.

    It outputs a distribution parameterized by mean and std, later to be
    transformed by a custom TanhBijector in utils.py for Dreamer.
    """

    def __init__(
        self,
        input_size: int,
        action_size: int,
        layers: int,
        units: int,
        dist: str = "tanh_normal",
        act: ActFunc = None,
        min_std: float = 1e-4,
        init_std: float = 5.0,
        mean_scale: float = 5.0,
    ):
        """Initializes Policy

        Args:
            input_size: Input size to network
            action_size: Action space size
            layers: Number of layers in network
            units: Size of the hidden layers
            dist: Output distribution, with tanh_normal implemented
            act: Activation function
            min_std: Minimum std for output distribution
            init_std: Initial std
            mean_scale: Augmenting mean output from FC network
        """
        super().__init__()
        self.num_layers = layers
        self.units = units
        self.dist = dist
        self.act = act
        if not act:
            self.act = nn.ReLU
        self.min_std = min_std
        self.init_std = init_std
        self.mean_scale = mean_scale
        self.action_size = action_size

        self.layers = []
        self.softplus = nn.Softplus()

        # MLP Construction
        cur_size = input_size
        for _ in range(self.num_layers):
            self.layers.extend([Linear(cur_size, self.units), self.act()])
            cur_size = self.units
        if self.dist == "tanh_normal":
            self.layers.append(Linear(cur_size, 2 * action_size))
        elif self.dist == "onehot":
            self.layers.append(Linear(cur_size, action_size))
        self.model = nn.Sequential(*self.layers)

    # Returns distribution
    def forward(self, x):
        raw_init_std = np.log(np.exp(self.init_std) - 1)
        x = self.model(x)
        if self.dist == "tanh_normal":
            mean, std = torch.chunk(x, 2, dim=-1)
            mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
            std = self.softplus(std + raw_init_std) + self.min_std
            dist = td.Normal(mean, std)
            transforms = [TanhBijector()]
            dist = td.transformed_distribution.TransformedDistribution(dist, transforms)
            dist = td.Independent(dist, 1)
        elif self.dist == "onehot":
            dist = td.OneHotCategorical(logits=x)
            raise NotImplementedError("Atari not implemented yet!")
        return dist
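
# Usage sketch (illustrative, sizes assumed): for continuous control the
# default "tanh_normal" head emits a Normal squashed through TanhBijector,
# so samples lie in (-1, 1); rsample() keeps the pathwise gradient that the
# actor loss needs.
#
#   actor = ActionDecoder(input_size=230, action_size=6, layers=4, units=400)
#   feat = torch.rand(16, 230)
#   action = actor(feat).rsample()   # [16, 6], values in (-1, 1)
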
# Represents TD model in PlaNET
class RSSM(nn.Module):
    """RSSM is the core recurrent part of the PlaNET module. It consists of
    two networks, one (obs) to calculate posterior beliefs and states and
    the second (img) to calculate prior beliefs and states. The prior network
    takes in the previous state and action, while the posterior network takes
    in the previous state, action, and a latent embedding of the most recent
    observation.
    """

    def __init__(
        self,
        action_size: int,
        embed_size: int,
        stoch: int = 30,
        deter: int = 200,
        hidden: int = 200,
        act: ActFunc = None,
    ):
        """Initializes RSSM

        Args:
            action_size: Action space size
            embed_size: Size of ConvEncoder embedding
            stoch: Size of the distributional hidden state
            deter: Size of the deterministic hidden state
            hidden: General size of hidden layers
            act: Activation function
        """
        super().__init__()
        self.stoch_size = stoch
        self.deter_size = deter
        self.hidden_size = hidden
        self.act = act
        if act is None:
            self.act = nn.ELU

        self.obs1 = Linear(embed_size + deter, hidden)
        self.obs2 = Linear(hidden, 2 * stoch)

        self.cell = GRUCell(self.hidden_size, hidden_size=self.deter_size)
        self.img1 = Linear(stoch + action_size, hidden)
        self.img2 = Linear(deter, hidden)
        self.img3 = Linear(hidden, 2 * stoch)

        self.softplus = nn.Softplus

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
    def get_initial_state(self, batch_size: int) -> List[TensorType]:
        """Returns the initial state for the RSSM, which consists of mean,
        std for the stochastic state, the sampled stochastic hidden state
        (from mean, std), and the deterministic hidden state, which is
        pushed through the GRUCell.

        Args:
            batch_size: Batch size for initial state

        Returns:
            List of tensors
        """
        return [
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.stoch_size).to(self.device),
            torch.zeros(batch_size, self.deter_size).to(self.device),
        ]
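
    # Note on state layout: throughout this class a "state" is the list
    # [mean, std, stoch, deter] -- the Normal parameters of the stochastic
    # state, a sample drawn from them, and the GRU's deterministic hidden
    # state, in that order (get_feature() below reads state[2] and state[3]).
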
    def observe(
        self, embed: TensorType, action: TensorType, state: List[TensorType] = None
    ) -> Tuple[List[TensorType], List[TensorType]]:
        """Returns the corresponding states for an embedding from the
        ConvEncoder and a sequence of actions. This is accomplished by
        rolling out the RNN from the starting state through each index
        of embed and action, saving all intermediate states in between.

        Args:
            embed: ConvEncoder embedding
            action: Actions
            state: Initial state before rollout

        Returns:
            Posterior states and prior states (both List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        if embed.dim() <= 2:
            embed = torch.unsqueeze(embed, 1)

        if action.dim() <= 2:
            action = torch.unsqueeze(action, 1)

        # Move the time axis first for the rollout
        embed = embed.permute(1, 0, 2)
        action = action.permute(1, 0, 2)

        priors = [[] for _ in range(len(state))]
        posts = [[] for _ in range(len(state))]
        last = (state, state)
        for index in range(len(action)):
            # Tuple of post and prior
            last = self.obs_step(last[0], action[index], embed[index])
            for s, o in zip(last[0], posts):
                o.append(s)
            for s, o in zip(last[1], priors):
                o.append(s)

        prior = [torch.stack(x, dim=0) for x in priors]
        post = [torch.stack(x, dim=0) for x in posts]

        # Back to batch-major
        prior = [e.permute(1, 0, 2) for e in prior]
        post = [e.permute(1, 0, 2) for e in post]

        return post, prior
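
    # Shape sketch (illustrative, values assumed): observe() takes batch-major
    # inputs and returns batch-major posterior/prior states. With stoch=30,
    # deter=200, and a hypothetical `rssm` instance:
    #
    #   embed = torch.rand(16, 10, 1024)   # ConvEncoder output
    #   action = torch.rand(16, 10, 6)
    #   post, prior = rssm.observe(embed, action)
    #   # post[0].shape == (16, 10, 30); post[3].shape == (16, 10, 200)
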
    def imagine(
        self, action: TensorType, state: List[TensorType] = None
    ) -> List[TensorType]:
        """Imagines the trajectory starting from state through a list of actions.
        Similar to observe(), requires rolling out the RNN for each timestep.

        Args:
            action: Actions
            state: Starting state before rollout

        Returns:
            Prior states
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        action = action.permute(1, 0, 2)

        priors = [[] for _ in range(len(state))]
        last = state
        for index in range(len(action)):
            last = self.img_step(last, action[index])
            for s, o in zip(last, priors):
                o.append(s)

        prior = [torch.stack(x, dim=0) for x in priors]
        prior = [e.permute(1, 0, 2) for e in prior]
        return prior
    def obs_step(
        self, prev_state: TensorType, prev_action: TensorType, embed: TensorType
    ) -> Tuple[List[TensorType], List[TensorType]]:
        """Runs through the posterior model and returns the posterior state

        Args:
            prev_state: The previous state
            prev_action: The previous action
            embed: Embedding from ConvEncoder

        Returns:
            Post and Prior state
        """
        prior = self.img_step(prev_state, prev_action)
        x = torch.cat([prior[3], embed], dim=-1)
        x = self.obs1(x)
        x = self.act()(x)
        x = self.obs2(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        post = [mean, std, stoch, prior[3]]
        return post, prior
    def img_step(
        self, prev_state: TensorType, prev_action: TensorType
    ) -> List[TensorType]:
        """Runs through the prior model and returns the prior state

        Args:
            prev_state: The previous state
            prev_action: The previous action

        Returns:
            Prior state
        """
        x = torch.cat([prev_state[2], prev_action], dim=-1)
        x = self.img1(x)
        x = self.act()(x)
        deter = self.cell(x, prev_state[3])
        x = deter
        x = self.img2(x)
        x = self.act()(x)
        x = self.img3(x)
        mean, std = torch.chunk(x, 2, dim=-1)
        std = self.softplus()(std) + 0.1
        stoch = self.get_dist(mean, std).rsample()
        return [mean, std, stoch, deter]

    def get_feature(self, state: List[TensorType]) -> TensorType:
        # Constructs feature for input to reward, decoder, actor, critic
        return torch.cat([state[2], state[3]], dim=-1)

    def get_dist(self, mean: TensorType, std: TensorType) -> TensorType:
        return td.Normal(mean, std)
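
# Usage sketch (illustrative, sizes assumed): a single posterior update, as
# DreamerModel.policy() below performs per environment step.
#
#   rssm = RSSM(action_size=6, embed_size=1024)
#   state = rssm.get_initial_state(batch_size=16)
#   post, prior = rssm.obs_step(state, torch.rand(16, 6), torch.rand(16, 1024))
#   feat = rssm.get_feature(post)   # [16, stoch + deter] == [16, 230]
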
# Represents all models in Dreamer, unifies them all into a single interface
class DreamerModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        nn.Module.__init__(self)
        self.depth = model_config["depth_size"]
        self.deter_size = model_config["deter_size"]
        self.stoch_size = model_config["stoch_size"]
        self.hidden_size = model_config["hidden_size"]

        self.action_size = action_space.shape[0]

        self.encoder = ConvEncoder(self.depth)
        self.decoder = ConvDecoder(self.stoch_size + self.deter_size, depth=self.depth)
        self.reward = DenseDecoder(
            self.stoch_size + self.deter_size, 1, 2, self.hidden_size
        )
        self.dynamics = RSSM(
            self.action_size,
            32 * self.depth,
            stoch=self.stoch_size,
            deter=self.deter_size,
        )
        self.actor = ActionDecoder(
            self.stoch_size + self.deter_size, self.action_size, 4, self.hidden_size
        )
        self.value = DenseDecoder(
            self.stoch_size + self.deter_size, 1, 3, self.hidden_size
        )
        self.state = None

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    def policy(
        self, obs: TensorType, state: List[TensorType], explore=True
    ) -> Tuple[TensorType, List[float], List[TensorType]]:
        """Returns the action. Runs through the encoder, recurrent model,
        and policy to obtain action.
        """
        if state is None:
            self.state = self.get_initial_state(batch_size=obs.shape[0])
        else:
            self.state = state
        post = self.state[:4]
        action = self.state[4]

        embed = self.encoder(obs)
        post, _ = self.dynamics.obs_step(post, action, embed)
        feat = self.dynamics.get_feature(post)

        action_dist = self.actor(feat)
        if explore:
            action = action_dist.sample()
        else:
            action = action_dist.mean
        logp = action_dist.log_prob(action)

        self.state = post + [action]
        return action, logp, self.state

    def imagine_ahead(self, state: List[TensorType], horizon: int) -> TensorType:
        """Given a batch of states, rolls out imagined trajectories of length horizon."""
        start = []
        for s in state:
            s = s.contiguous().detach()
            shape = [-1] + list(s.size())[2:]
            start.append(s.view(*shape))

        def next_state(state):
            feature = self.dynamics.get_feature(state).detach()
            action = self.actor(feature).rsample()
            next_state = self.dynamics.img_step(state, action)
            return next_state

        last = start
        outputs = [[] for _ in range(len(start))]
        for _ in range(horizon):
            last = next_state(last)
            for s, o in zip(last, outputs):
                o.append(s)
        outputs = [torch.stack(x, dim=0) for x in outputs]

        imag_feat = self.dynamics.get_feature(outputs)
        return imag_feat

    def get_initial_state(self, batch_size: int = 1) -> List[TensorType]:
        # Zero-initialized RSSM state plus a zero action, sized to the batch
        # (policy() above calls this with batch_size=obs.shape[0]).
        self.state = self.dynamics.get_initial_state(batch_size) + [
            torch.zeros(batch_size, self.action_space.shape[0]).to(self.device)
        ]
        return self.state

    def value_function(self) -> TensorType:
        return None
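
# Usage sketch (illustrative): RLlib normally constructs DreamerModel from the
# algorithm's model config; the keys below mirror the ones read in __init__,
# with sizes shown only as assumed examples, and `obs_space`, `action_space`,
# and `obs` standing in for real spaces/tensors.
#
#   model_config = {
#       "depth_size": 32,
#       "deter_size": 200,
#       "stoch_size": 30,
#       "hidden_size": 400,
#   }
#   model = DreamerModel(obs_space, action_space, 1, model_config, "dreamer")
#   action, logp, state = model.policy(obs, state=None, explore=True)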