# ray/rllib/algorithms/dreamer/dreamer_model.py

import numpy as np
from typing import Any, List, Tuple
from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import TensorType
torch, nn = try_import_torch()
if torch:
from torch import distributions as td
from ray.rllib.algorithms.dreamer.utils import (
Linear,
Conv2d,
ConvTranspose2d,
GRUCell,
TanhBijector,
)
ActFunc = Any
# Encoder, part of PlaNET
class ConvEncoder(nn.Module):
"""Standard Convolutional Encoder for Dreamer. This encoder is used
to encode images frm an enviornment into a latent state for the
RSSM model in PlaNET.
"""
def __init__(
self, depth: int = 32, act: ActFunc = None, shape: Tuple[int] = (3, 64, 64)
):
"""Initializes Conv Encoder
Args:
depth: Number of channels in the first conv layer
act: Activation for Encoder, default ReLU
shape: Shape of observation input
"""
super().__init__()
self.act = act
if not act:
self.act = nn.ReLU
self.depth = depth
self.shape = shape
init_channels = self.shape[0]
self.layers = [
Conv2d(init_channels, self.depth, 4, stride=2),
self.act(),
Conv2d(self.depth, 2 * self.depth, 4, stride=2),
self.act(),
Conv2d(2 * self.depth, 4 * self.depth, 4, stride=2),
self.act(),
Conv2d(4 * self.depth, 8 * self.depth, 4, stride=2),
self.act(),
]
self.model = nn.Sequential(*self.layers)
def forward(self, x):
# Flatten to [batch*horizon, 3, 64, 64] in loss function
orig_shape = list(x.size())
x = x.view(-1, *(orig_shape[-3:]))
x = self.model(x)
new_shape = orig_shape[:-3] + [32 * self.depth]
x = x.view(*new_shape)
return x
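# Shape walkthrough (illustrative sketch, not part of the original module):
# with the default depth=32 and 64x64 inputs, each stride-2, kernel-4 conv maps
# the spatial size 64 -> 31 -> 14 -> 6 -> 2, so the final feature map is
# [8 * depth, 2, 2], i.e. 8 * 32 * 2 * 2 = 1024 = 32 * depth features, which is
# why forward() reshapes the output to `32 * self.depth`. Hypothetical usage:
#
#     encoder = ConvEncoder(depth=32)
#     emb = encoder(torch.zeros(4, 10, 3, 64, 64))  # -> [4, 10, 1024]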
# Decoder, part of PlaNET
class ConvDecoder(nn.Module):
"""Standard Convolutional Decoder for Dreamer.
This decoder is used to decode images from the latent state generated
by the transition dynamics model. This is used in calculating loss and
logging gifs for imagined trajectories.
"""
def __init__(
self,
input_size: int,
depth: int = 32,
act: ActFunc = None,
shape: Tuple[int] = (3, 64, 64),
):
"""Initializes a ConvDecoder instance.
Args:
input_size: Input size, usually feature size output from
RSSM.
depth: Number of channels in the first conv layer
            act: Activation for Decoder, default ReLU
shape: Shape of observation input
"""
super().__init__()
self.act = act
if not act:
self.act = nn.ReLU
self.depth = depth
self.shape = shape
self.layers = [
Linear(input_size, 32 * self.depth),
Reshape([-1, 32 * self.depth, 1, 1]),
ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
self.act(),
ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
self.act(),
ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
self.act(),
ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
]
self.model = nn.Sequential(*self.layers)
def forward(self, x):
# x is [batch, hor_length, input_size]
orig_shape = list(x.size())
x = self.model(x)
reshape_size = orig_shape[:-1] + list(self.shape)
mean = x.view(*reshape_size)
# Equivalent to making a multivariate diag
return td.Independent(td.Normal(mean, 1), len(self.shape))
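# Shape walkthrough (illustrative sketch): starting from the 1x1 feature map
# produced by Reshape, the transposed convs map the spatial size
# 1 -> 5 -> 13 -> 30 -> 64, recovering the (3, 64, 64) observation shape.
# Wrapping the mean in td.Independent(td.Normal(mean, 1), 3) treats each image
# as a single event, so log_prob() returns one unit-variance Gaussian
# log-likelihood per image (a scaled negative squared error, up to a constant).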
# Reward Model (PlaNET), and Value Function
class DenseDecoder(nn.Module):
"""FC network that outputs a distribution for calculating log_prob.
Used later in DreamerLoss.
"""
def __init__(
self,
input_size: int,
output_size: int,
layers: int,
units: int,
dist: str = "normal",
act: ActFunc = None,
):
"""Initializes FC network
Args:
input_size: Input size to network
            output_size: Output size of network
layers: Number of layers in network
units: Size of the hidden layers
dist: Output distribution, parameterized by FC output
logits.
act: Activation function
"""
super().__init__()
self.layrs = layers
self.units = units
self.act = act
if not act:
self.act = nn.ELU
self.dist = dist
self.input_size = input_size
self.output_size = output_size
self.layers = []
cur_size = input_size
for _ in range(self.layrs):
self.layers.extend([Linear(cur_size, self.units), self.act()])
cur_size = units
self.layers.append(Linear(cur_size, output_size))
self.model = nn.Sequential(*self.layers)
def forward(self, x):
x = self.model(x)
if self.output_size == 1:
x = torch.squeeze(x)
if self.dist == "normal":
output_dist = td.Normal(x, 1)
elif self.dist == "binary":
output_dist = td.Bernoulli(logits=x)
else:
raise NotImplementedError("Distribution type not implemented!")
return td.Independent(output_dist, 0)
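# Hypothetical usage sketch: DreamerModel below builds its reward head as
# DenseDecoder(stoch + deter, 1, 2, hidden_size), i.e. with dist="normal".
# Calling `reward(feat).log_prob(r)` then yields a per-timestep unit-variance
# Gaussian log-likelihood of the observed rewards; the "binary" branch targets
# Bernoulli labels (e.g. discount/continuation flags in some Dreamer variants).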
# Represents dreamer policy
class ActionDecoder(nn.Module):
"""ActionDecoder is the policy module in Dreamer.
It outputs a distribution parameterized by mean and std, later to be
transformed by a custom TanhBijector in utils.py for Dreamer.
"""
def __init__(
self,
input_size: int,
action_size: int,
layers: int,
units: int,
dist: str = "tanh_normal",
act: ActFunc = None,
min_std: float = 1e-4,
init_std: float = 5.0,
mean_scale: float = 5.0,
):
"""Initializes Policy
Args:
input_size: Input size to network
action_size: Action space size
layers: Number of layers in network
units: Size of the hidden layers
dist: Output distribution, with tanh_normal implemented
act: Activation function
min_std: Minimum std for output distribution
            init_std: Initial std
            mean_scale: Scale applied (via tanh squashing) to the mean output
                of the FC network
"""
super().__init__()
self.layrs = layers
self.units = units
self.dist = dist
self.act = act
if not act:
self.act = nn.ReLU
self.min_std = min_std
self.init_std = init_std
self.mean_scale = mean_scale
self.action_size = action_size
self.layers = []
self.softplus = nn.Softplus()
# MLP Construction
cur_size = input_size
for _ in range(self.layrs):
self.layers.extend([Linear(cur_size, self.units), self.act()])
cur_size = self.units
if self.dist == "tanh_normal":
self.layers.append(Linear(cur_size, 2 * action_size))
elif self.dist == "onehot":
self.layers.append(Linear(cur_size, action_size))
self.model = nn.Sequential(*self.layers)
# Returns distribution
def forward(self, x):
raw_init_std = np.log(np.exp(self.init_std) - 1)
x = self.model(x)
if self.dist == "tanh_normal":
mean, std = torch.chunk(x, 2, dim=-1)
mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
std = self.softplus(std + raw_init_std) + self.min_std
dist = td.Normal(mean, std)
transforms = [TanhBijector()]
dist = td.transformed_distribution.TransformedDistribution(dist, transforms)
dist = td.Independent(dist, 1)
elif self.dist == "onehot":
dist = td.OneHotCategorical(logits=x)
raise NotImplementedError("Atari not implemented yet!")
return dist
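# Worked initialization check (illustrative): raw_init_std = log(exp(5.0) - 1)
# ≈ 4.993, so when the FC output is near zero the std is
# softplus(0 + 4.993) + 1e-4 ≈ 5.0 = init_std, while the mean is squashed as
# mean_scale * tanh(mean / mean_scale), keeping the pre-tanh mean in
# (-mean_scale, mean_scale) before the TanhBijector maps samples into (-1, 1).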
# Represents TD model in PlaNET
class RSSM(nn.Module):
"""RSSM is the core recurrent part of the PlaNET module. It consists of
two networks, one (obs) to calculate posterior beliefs and states and
the second (img) to calculate prior beliefs and states. The prior network
takes in the previous state and action, while the posterior network takes
in the previous state, action, and a latent embedding of the most recent
observation.
"""
def __init__(
self,
action_size: int,
embed_size: int,
stoch: int = 30,
deter: int = 200,
hidden: int = 200,
act: ActFunc = None,
):
"""Initializes RSSM
Args:
action_size: Action space size
embed_size: Size of ConvEncoder embedding
stoch: Size of the distributional hidden state
deter: Size of the deterministic hidden state
hidden: General size of hidden layers
act: Activation function
"""
super().__init__()
self.stoch_size = stoch
self.deter_size = deter
self.hidden_size = hidden
self.act = act
if act is None:
self.act = nn.ELU
self.obs1 = Linear(embed_size + deter, hidden)
self.obs2 = Linear(hidden, 2 * stoch)
self.cell = GRUCell(self.hidden_size, hidden_size=self.deter_size)
self.img1 = Linear(stoch + action_size, hidden)
self.img2 = Linear(deter, hidden)
self.img3 = Linear(hidden, 2 * stoch)
self.softplus = nn.Softplus
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def get_initial_state(self, batch_size: int) -> List[TensorType]:
"""Returns the inital state for the RSSM, which consists of mean,
std for the stochastic state, the sampled stochastic hidden state
(from mean, std), and the deterministic hidden state, which is
pushed through the GRUCell.
Args:
batch_size: Batch size for initial state
Returns:
List of tensors
"""
return [
torch.zeros(batch_size, self.stoch_size).to(self.device),
torch.zeros(batch_size, self.stoch_size).to(self.device),
torch.zeros(batch_size, self.stoch_size).to(self.device),
torch.zeros(batch_size, self.deter_size).to(self.device),
]
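    # State layout used throughout this class (as returned by obs_step and
    # img_step): state = [mean, std, stoch, deter] with shapes
    # [batch, stoch_size], [batch, stoch_size], [batch, stoch_size],
    # [batch, deter_size].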
def observe(
self, embed: TensorType, action: TensorType, state: List[TensorType] = None
) -> Tuple[List[TensorType], List[TensorType]]:
"""Returns the corresponding states from the embedding from ConvEncoder
and actions. This is accomplished by rolling out the RNN from the
starting state through each index of embed and action, saving all
        intermediate states in between.
Args:
embed: ConvEncoder embedding
action: Actions
state (List[TensorType]): Initial state before rollout
Returns:
Posterior states and prior states (both List[TensorType])
"""
if state is None:
state = self.get_initial_state(action.size()[0])
if embed.dim() <= 2:
embed = torch.unsqueeze(embed, 1)
if action.dim() <= 2:
action = torch.unsqueeze(action, 1)
embed = embed.permute(1, 0, 2)
action = action.permute(1, 0, 2)
priors = [[] for i in range(len(state))]
posts = [[] for i in range(len(state))]
last = (state, state)
for index in range(len(action)):
# Tuple of post and prior
last = self.obs_step(last[0], action[index], embed[index])
[o.append(s) for s, o in zip(last[0], posts)]
[o.append(s) for s, o in zip(last[1], priors)]
prior = [torch.stack(x, dim=0) for x in priors]
post = [torch.stack(x, dim=0) for x in posts]
prior = [e.permute(1, 0, 2) for e in prior]
post = [e.permute(1, 0, 2) for e in post]
return post, prior
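    # Shape sketch (illustrative): given embed of shape [B, T, embed_size] and
    # action of shape [B, T, action_size], both are permuted to time-major
    # [T, B, ...], rolled out step by step, stacked, and permuted back, so each
    # returned posterior/prior tensor has shape [B, T, stoch_size] (or
    # [B, T, deter_size] for the deterministic state).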
def imagine(
self, action: TensorType, state: List[TensorType] = None
) -> List[TensorType]:
"""Imagines the trajectory starting from state through a list of actions.
Similar to observe(), requires rolling out the RNN for each timestep.
Args:
action: Actions
state (List[TensorType]): Starting state before rollout
Returns:
Prior states
"""
if state is None:
state = self.get_initial_state(action.size()[0])
action = action.permute(1, 0, 2)
indices = range(len(action))
priors = [[] for _ in range(len(state))]
last = state
for index in indices:
last = self.img_step(last, action[index])
[o.append(s) for s, o in zip(last, priors)]
prior = [torch.stack(x, dim=0) for x in priors]
prior = [e.permute(1, 0, 2) for e in prior]
return prior
def obs_step(
self, prev_state: TensorType, prev_action: TensorType, embed: TensorType
) -> Tuple[List[TensorType], List[TensorType]]:
"""Runs through the posterior model and returns the posterior state
Args:
prev_state: The previous state
prev_action: The previous action
embed: Embedding from ConvEncoder
Returns:
Post and Prior state
"""
prior = self.img_step(prev_state, prev_action)
x = torch.cat([prior[3], embed], dim=-1)
x = self.obs1(x)
x = self.act()(x)
x = self.obs2(x)
mean, std = torch.chunk(x, 2, dim=-1)
std = self.softplus()(std) + 0.1
stoch = self.get_dist(mean, std).rsample()
post = [mean, std, stoch, prior[3]]
return post, prior
def img_step(
self, prev_state: TensorType, prev_action: TensorType
) -> List[TensorType]:
"""Runs through the prior model and returns the prior state
Args:
prev_state: The previous state
prev_action: The previous action
Returns:
Prior state
"""
x = torch.cat([prev_state[2], prev_action], dim=-1)
x = self.img1(x)
x = self.act()(x)
deter = self.cell(x, prev_state[3])
x = deter
x = self.img2(x)
x = self.act()(x)
x = self.img3(x)
mean, std = torch.chunk(x, 2, dim=-1)
std = self.softplus()(std) + 0.1
stoch = self.get_dist(mean, std).rsample()
return [mean, std, stoch, deter]
def get_feature(self, state: List[TensorType]) -> TensorType:
# Constructs feature for input to reward, decoder, actor, critic
return torch.cat([state[2], state[3]], dim=-1)
def get_dist(self, mean: TensorType, std: TensorType) -> TensorType:
return td.Normal(mean, std)
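# Note (illustrative): obs_step() and img_step() draw the stochastic state with
# rsample(), i.e. the reparameterization trick, so gradients flow through the
# sampled state. The Dreamer loss (defined in the policy, not in this file)
# compares get_dist(post_mean, post_std) against get_dist(prior_mean,
# prior_std) with a KL term to train the RSSM.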
# Represents all models in Dreamer, unifies them all into a single interface
class DreamerModel(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
super().__init__(obs_space, action_space, num_outputs, model_config, name)
nn.Module.__init__(self)
self.depth = model_config["depth_size"]
self.deter_size = model_config["deter_size"]
self.stoch_size = model_config["stoch_size"]
self.hidden_size = model_config["hidden_size"]
self.action_size = action_space.shape[0]
self.encoder = ConvEncoder(self.depth)
self.decoder = ConvDecoder(self.stoch_size + self.deter_size, depth=self.depth)
self.reward = DenseDecoder(
self.stoch_size + self.deter_size, 1, 2, self.hidden_size
)
self.dynamics = RSSM(
self.action_size,
32 * self.depth,
stoch=self.stoch_size,
deter=self.deter_size,
)
self.actor = ActionDecoder(
self.stoch_size + self.deter_size, self.action_size, 4, self.hidden_size
)
self.value = DenseDecoder(
self.stoch_size + self.deter_size, 1, 3, self.hidden_size
)
self.state = None
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def policy(
self, obs: TensorType, state: List[TensorType], explore=True
) -> Tuple[TensorType, List[float], List[TensorType]]:
"""Returns the action. Runs through the encoder, recurrent model,
and policy to obtain action.
"""
if state is None:
self.state = self.get_initial_state(batch_size=obs.shape[0])
else:
self.state = state
post = self.state[:4]
action = self.state[4]
embed = self.encoder(obs)
post, _ = self.dynamics.obs_step(post, action, embed)
feat = self.dynamics.get_feature(post)
action_dist = self.actor(feat)
if explore:
action = action_dist.sample()
else:
action = action_dist.mean
logp = action_dist.log_prob(action)
self.state = post + [action]
return action, logp, self.state
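    # Hypothetical usage sketch: the recurrent `state` carried between calls is
    # the four RSSM tensors [mean, std, stoch, deter] plus the previous action.
    #
    #     action, logp, state = model.policy(obs, state=None, explore=True)
    #     action, logp, state = model.policy(next_obs, state, explore=True)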
def imagine_ahead(self, state: List[TensorType], horizon: int) -> TensorType:
"""Given a batch of states, rolls out more state of length horizon."""
start = []
for s in state:
s = s.contiguous().detach()
shpe = [-1] + list(s.size())[2:]
start.append(s.view(*shpe))
def next_state(state):
feature = self.dynamics.get_feature(state).detach()
action = self.actor(feature).rsample()
next_state = self.dynamics.img_step(state, action)
return next_state
last = start
outputs = [[] for i in range(len(start))]
for _ in range(horizon):
last = next_state(last)
[o.append(s) for s, o in zip(last, outputs)]
outputs = [torch.stack(x, dim=0) for x in outputs]
imag_feat = self.dynamics.get_feature(outputs)
return imag_feat
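    # Shape sketch (illustrative): each incoming state tensor of shape
    # [B, T, ...] is flattened to [B * T, ...]; after `horizon` imagined steps
    # the stacked features have shape [horizon, B * T, stoch_size + deter_size].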
    def get_initial_state(self, batch_size: int = 1) -> List[TensorType]:
        # policy() calls this with batch_size=obs.shape[0], so the batch size
        # must be an argument rather than being hard-coded to 1.
        self.state = self.dynamics.get_initial_state(batch_size) + [
            torch.zeros(batch_size, self.action_space.shape[0]).to(self.device)
        ]
        return self.state
def value_function(self) -> TensorType:
return None
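

# Minimal smoke-test sketch (not part of the original module). It assumes a
# gymnasium (or classic gym) installation and uses hand-picked model_config
# values; the real defaults come from the Dreamer algorithm config.
if __name__ == "__main__":
    try:
        import gymnasium as gym  # assumption: gymnasium is available
    except ImportError:
        import gym  # fall back to classic gym

    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(3, 64, 64), dtype=np.float32)
    action_space = gym.spaces.Box(-1.0, 1.0, shape=(6,), dtype=np.float32)
    model_config = {
        "depth_size": 32,
        "deter_size": 200,
        "stoch_size": 30,
        "hidden_size": 400,
    }
    # num_outputs is stored by ModelV2 but not used by this model.
    model = DreamerModel(obs_space, action_space, 6, model_config, "dreamer")

    # One policy step from a blank state.
    obs = torch.zeros(1, 3, 64, 64)
    action, logp, state = model.policy(obs, state=None, explore=True)
    print(action.shape, logp.shape, [s.shape for s in state])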