# ray/rllib/agents/dreamer/dreamer_model.py
import numpy as np
from typing import Any, List, Tuple

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.framework import TensorType

torch, nn = try_import_torch()
if torch:
    from torch import distributions as td

from ray.rllib.agents.dreamer.utils import Linear, Conv2d, \
    ConvTranspose2d, GRUCell, TanhBijector

ActFunc = Any

# Encoder, part of PlaNET
if torch:

    class ConvEncoder(nn.Module):
        """Standard Convolutional Encoder for Dreamer. This encoder is used
        to encode images from an environment into a latent state for the
        RSSM model in PlaNET.
        """

        def __init__(self,
                     depth: int = 32,
                     act: ActFunc = None,
                     shape: List = [3, 64, 64]):
            """Initializes Conv Encoder

            Args:
                depth (int): Number of channels in the first conv layer
                act (Any): Activation for Encoder, default ReLU
                shape (List): Shape of observation input
            """
            super().__init__()
            self.act = act
            if not act:
                self.act = nn.ReLU
            self.depth = depth
            self.shape = shape

            init_channels = self.shape[0]
            self.layers = [
                Conv2d(init_channels, self.depth, 4, stride=2),
                self.act(),
                Conv2d(self.depth, 2 * self.depth, 4, stride=2),
                self.act(),
                Conv2d(2 * self.depth, 4 * self.depth, 4, stride=2),
                self.act(),
                Conv2d(4 * self.depth, 8 * self.depth, 4, stride=2),
                self.act(),
            ]
            self.model = nn.Sequential(*self.layers)

        def forward(self, x):
            # Flatten to [batch*horizon, 3, 64, 64] in loss function
            orig_shape = list(x.size())
            x = x.view(-1, *(orig_shape[-3:]))
            x = self.model(x)

            new_shape = orig_shape[:-3] + [32 * self.depth]
            x = x.view(*new_shape)
            return x
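
# With the default 64x64 input, the four stride-2 convolutions shrink the
# spatial dims 64 -> 31 -> 14 -> 6 -> 2, so the flattened embedding has
# 8 * depth * 2 * 2 = 32 * depth features (1024 for the default depth of 32),
# which is the embed_size the RSSM below expects.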

if torch:

    class Reshape(nn.Module):
        """Standard module that reshapes/views a tensor."""

        def __init__(self, shape: List):
            super().__init__()
            self.shape = shape

        def forward(self, x):
            return x.view(*self.shape)

# Decoder, part of PlaNET
if torch:

    class ConvDecoder(nn.Module):
        """Standard Convolutional Decoder for Dreamer. This decoder is used
        to decode images from the latent state generated by the transition
        dynamics model. This is used in calculating loss and logging gifs for
        imagined trajectories.
        """

        def __init__(self,
                     input_size: int,
                     depth: int = 32,
                     act: ActFunc = None,
                     shape: List = [3, 64, 64]):
            """Initializes Conv Decoder

            Args:
                input_size (int): Input size, usually feature size output
                    from RSSM
                depth (int): Number of channels in the first conv layer
                act (Any): Activation for Encoder, default ReLU
                shape (List): Shape of observation input
            """
            super().__init__()
            self.act = act
            if not act:
                self.act = nn.ReLU
            self.depth = depth
            self.shape = shape

            self.layers = [
                Linear(input_size, 32 * self.depth),
                Reshape((-1, 32 * self.depth, 1, 1)),
                ConvTranspose2d(32 * self.depth, 4 * self.depth, 5, stride=2),
                self.act(),
                ConvTranspose2d(4 * self.depth, 2 * self.depth, 5, stride=2),
                self.act(),
                ConvTranspose2d(2 * self.depth, self.depth, 6, stride=2),
                self.act(),
                ConvTranspose2d(self.depth, self.shape[0], 6, stride=2),
            ]
            self.model = nn.Sequential(*self.layers)

        def forward(self, x):
            # x is [batch, hor_length, input_size]
            orig_shape = list(x.size())
            x = self.model(x)

            reshape_size = orig_shape[:-1] + self.shape
            mean = x.view(*reshape_size)

            # Equivalent to making a multivariate diag
            return td.Independent(td.Normal(mean, 1), len(self.shape))
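
# The transposed convolutions mirror the encoder: starting from a 1x1 map with
# 32 * depth channels, the spatial dims grow 1 -> 5 -> 13 -> 30 -> 64, ending
# at the original [3, 64, 64] observation shape. The returned unit-variance
# diagonal Gaussian means that maximizing its log_prob is equivalent (up to a
# constant) to minimizing the squared reconstruction error.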

# Reward Model (PlaNET), and Value Function
if torch:

    class DenseDecoder(nn.Module):
        """Fully connected network that outputs a distribution for
        calculating log_prob later in DreamerLoss.
        """

        def __init__(self,
                     input_size: int,
                     output_size: int,
                     layers: int,
                     units: int,
                     dist: str = "normal",
                     act: ActFunc = None):
            """Initializes FC network

            Args:
                input_size (int): Input size to network
                output_size (int): Output size to network
                layers (int): Number of layers in network
                units (int): Size of the hidden layers
                dist (str): Output distribution, parameterized by FC output
                    logits
                act (Any): Activation function
            """
            super().__init__()
            self.layrs = layers
            self.units = units
            self.act = act
            if not act:
                self.act = nn.ELU
            self.dist = dist
            self.input_size = input_size
            self.output_size = output_size
            self.layers = []
            cur_size = input_size
            for _ in range(self.layrs):
                self.layers.extend([Linear(cur_size, self.units), self.act()])
                cur_size = units
            self.layers.append(Linear(cur_size, output_size))
            self.model = nn.Sequential(*self.layers)

        def forward(self, x):
            x = self.model(x)
            if self.output_size == 1:
                x = torch.squeeze(x)
            if self.dist == "normal":
                output_dist = td.Normal(x, 1)
            elif self.dist == "binary":
                output_dist = td.Bernoulli(logits=x)
            else:
                raise NotImplementedError("Distribution type not implemented!")
            return td.Independent(output_dist, 0)
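
# DreamerModel below uses DenseDecoder with the default dist="normal" and
# output_size=1 as both the reward head and the value head: each maps the
# concatenated [stoch, deter] feature vector to a unit-variance Gaussian over
# a scalar target.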

# Represents dreamer policy
if torch:

    class ActionDecoder(nn.Module):
        """ActionDecoder is the policy module in Dreamer. It outputs a
        distribution parameterized by mean and std, later to be transformed
        by a custom TanhBijector in utils.py for Dreamer.
        """

        def __init__(self,
                     input_size: int,
                     action_size: int,
                     layers: int,
                     units: int,
                     dist: str = "tanh_normal",
                     act: ActFunc = None,
                     min_std: float = 1e-4,
                     init_std: float = 5.0,
                     mean_scale: float = 5.0):
            """Initializes Policy

            Args:
                input_size (int): Input size to network
                action_size (int): Action space size
                layers (int): Number of layers in network
                units (int): Size of the hidden layers
                dist (str): Output distribution, with tanh_normal implemented
                act (Any): Activation function
                min_std (float): Minimum std for output distribution
                init_std (float): Initial std
                mean_scale (float): Augmenting mean output from FC network
            """
            super().__init__()
            self.layrs = layers
            self.units = units
            self.dist = dist
            self.act = act
            if not act:
                self.act = nn.ReLU
            self.min_std = min_std
            self.init_std = init_std
            self.mean_scale = mean_scale
            self.action_size = action_size

            self.layers = []
            self.softplus = nn.Softplus()

            # MLP Construction
            cur_size = input_size
            for _ in range(self.layrs):
                self.layers.extend([Linear(cur_size, self.units), self.act()])
                cur_size = self.units
            if self.dist == "tanh_normal":
                self.layers.append(Linear(cur_size, 2 * action_size))
            elif self.dist == "onehot":
                self.layers.append(Linear(cur_size, action_size))
            self.model = nn.Sequential(*self.layers)

        # Returns distribution
        def forward(self, x):
            raw_init_std = np.log(np.exp(self.init_std) - 1)
            x = self.model(x)
            if self.dist == "tanh_normal":
                mean, std = torch.chunk(x, 2, dim=-1)
                mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
                std = self.softplus(std + raw_init_std) + self.min_std
                dist = td.Normal(mean, std)
                transforms = [TanhBijector()]
                dist = td.transformed_distribution.TransformedDistribution(
                    dist, transforms)
                dist = td.Independent(dist, 1)
            elif self.dist == "onehot":
                dist = td.OneHotCategorical(logits=x)
                raise NotImplementedError("Atari not implemented yet!")
            return dist
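
# raw_init_std is the inverse softplus of init_std, so at initialization (zero
# network output) softplus(0 + raw_init_std) + min_std approximately recovers
# init_std. The tanh_normal branch squashes a Gaussian through TanhBijector,
# keeping sampled actions in (-1, 1) while still supporting rsample() for
# pathwise gradients through the actor.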

# Represents TD model in PlaNET
if torch:

    class RSSM(nn.Module):
        """RSSM is the core recurrent part of the PlaNET module. It consists
        of two networks, one (obs) to calculate posterior beliefs and states
        and the second (img) to calculate prior beliefs and states. The prior
        network takes in the previous state and action, while the posterior
        network takes in the previous state, action, and a latent embedding
        of the most recent observation.
        """

        def __init__(self,
                     action_size: int,
                     embed_size: int,
                     stoch: int = 30,
                     deter: int = 200,
                     hidden: int = 200,
                     act: ActFunc = None):
            """Initializes RSSM

            Args:
                action_size (int): Action space size
                embed_size (int): Size of ConvEncoder embedding
                stoch (int): Size of the distributional hidden state
                deter (int): Size of the deterministic hidden state
                hidden (int): General size of hidden layers
                act (Any): Activation function
            """
            super().__init__()
            self.stoch_size = stoch
            self.deter_size = deter
            self.hidden_size = hidden
            self.act = act
            if act is None:
                self.act = nn.ELU

            self.obs1 = Linear(embed_size + deter, hidden)
            self.obs2 = Linear(hidden, 2 * stoch)

            self.cell = GRUCell(self.hidden_size, hidden_size=self.deter_size)
            self.img1 = Linear(stoch + action_size, hidden)
            self.img2 = Linear(deter, hidden)
            self.img3 = Linear(hidden, 2 * stoch)

            self.softplus = nn.Softplus

            self.device = (torch.device("cuda")
                           if torch.cuda.is_available() else
                           torch.device("cpu"))

        def get_initial_state(self, batch_size: int) -> List[TensorType]:
            """Returns the initial state for the RSSM, which consists of mean,
            std for the stochastic state, the sampled stochastic hidden state
            (from mean, std), and the deterministic hidden state, which is
            pushed through the GRUCell.

            Args:
                batch_size (int): Batch size for initial state

            Returns:
                List of tensors
            """
            return [
                torch.zeros(batch_size, self.stoch_size).to(self.device),
                torch.zeros(batch_size, self.stoch_size).to(self.device),
                torch.zeros(batch_size, self.stoch_size).to(self.device),
                torch.zeros(batch_size, self.deter_size).to(self.device),
            ]
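
        # Throughout this class a "state" is the list [mean, std, stoch,
        # deter]: the Gaussian parameters of the stochastic latent, a sample
        # from that Gaussian, and the deterministic GRU hidden state
        # (state[3]).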

        def observe(self,
                    embed: TensorType,
                    action: TensorType,
                    state: List[TensorType] = None
                    ) -> Tuple[List[TensorType], List[TensorType]]:
            """Returns the corresponding states from the embedding from
            ConvEncoder and actions. This is accomplished by rolling out the
            RNN from the starting state through each index of embed and
            action, saving all intermediate states in between.

            Args:
                embed (TensorType): ConvEncoder embedding
                action (TensorType): Actions
                state (List[TensorType]): Initial state before rollout

            Returns:
                Posterior states and prior states (both List[TensorType])
            """
            if state is None:
                state = self.get_initial_state(action.size()[0])
            embed = embed.permute(1, 0, 2)
            action = action.permute(1, 0, 2)

            priors = [[] for i in range(len(state))]
            posts = [[] for i in range(len(state))]
            last = (state, state)
            for index in range(len(action)):
                # Tuple of post and prior
                last = self.obs_step(last[0], action[index], embed[index])
                [o.append(l) for l, o in zip(last[0], posts)]
                [o.append(l) for l, o in zip(last[1], priors)]

            prior = [torch.stack(x, dim=0) for x in priors]
            post = [torch.stack(x, dim=0) for x in posts]
            prior = [e.permute(1, 0, 2) for e in prior]
            post = [e.permute(1, 0, 2) for e in post]
            return post, prior
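
        # The permutes above switch between the batch-major [batch, time,
        # size] layout of the inputs/outputs and the time-major layout used
        # inside the rollout loop; each returned element of post/prior has
        # shape [batch, time, stoch_size or deter_size].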

        def imagine(self, action: TensorType,
                    state: List[TensorType] = None) -> List[TensorType]:
            """Imagines the trajectory starting from state through a list of
            actions. Similar to observe(), requires rolling out the RNN for
            each timestep.

            Args:
                action (TensorType): Actions
                state (List[TensorType]): Starting state before rollout

            Returns:
                Prior states
            """
            if state is None:
                state = self.get_initial_state(action.size()[0])

            action = action.permute(1, 0, 2)

            indices = range(len(action))
            priors = [[] for _ in range(len(state))]
            last = state
            for index in indices:
                last = self.img_step(last, action[index])
                [o.append(l) for l, o in zip(last, priors)]

            prior = [torch.stack(x, dim=0) for x in priors]
            prior = [e.permute(1, 0, 2) for e in prior]
            return prior

        def obs_step(self, prev_state: TensorType, prev_action: TensorType,
                     embed: TensorType
                     ) -> Tuple[List[TensorType], List[TensorType]]:
            """Runs through the posterior model and returns the posterior
            state

            Args:
                prev_state (TensorType): The previous state
                prev_action (TensorType): The previous action
                embed (TensorType): Embedding from ConvEncoder

            Returns:
                Post and Prior state
            """
            prior = self.img_step(prev_state, prev_action)
            x = torch.cat([prior[3], embed], dim=-1)
            x = self.obs1(x)
            x = self.act()(x)
            x = self.obs2(x)
            mean, std = torch.chunk(x, 2, dim=-1)
            std = self.softplus()(std) + 0.1
            stoch = self.get_dist(mean, std).rsample()
            post = [mean, std, stoch, prior[3]]
            return post, prior

        def img_step(self, prev_state: TensorType,
                     prev_action: TensorType) -> List[TensorType]:
            """Runs through the prior model and returns the prior state

            Args:
                prev_state (TensorType): The previous state
                prev_action (TensorType): The previous action

            Returns:
                Prior state
            """
            x = torch.cat([prev_state[2], prev_action], dim=-1)
            x = self.img1(x)
            x = self.act()(x)
            deter = self.cell(x, prev_state[3])
            x = deter
            x = self.img2(x)
            x = self.act()(x)
            x = self.img3(x)
            mean, std = torch.chunk(x, 2, dim=-1)
            std = self.softplus()(std) + 0.1
            stoch = self.get_dist(mean, std).rsample()
            return [mean, std, stoch, deter]

        def get_feature(self, state: List[TensorType]) -> TensorType:
            # Constructs feature for input to reward, decoder, actor, critic
            return torch.cat([state[2], state[3]], dim=-1)

        def get_dist(self, mean: TensorType, std: TensorType) -> TensorType:
            return td.Normal(mean, std)
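
        # obs_step first runs img_step and then conditions on the observation
        # embedding, so posterior and prior share the same deterministic state
        # prior[3]. get_feature concatenates the sampled stochastic state with
        # that deterministic state, giving every downstream head (decoder,
        # reward, actor, value) a vector of width stoch_size + deter_size.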

# Represents all models in Dreamer, unifies them all into a single interface
if torch:

    class DreamerModel(TorchModelV2, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super().__init__(obs_space, action_space, num_outputs,
                             model_config, name)

            nn.Module.__init__(self)
            self.depth = model_config["depth_size"]
            self.deter_size = model_config["deter_size"]
            self.stoch_size = model_config["stoch_size"]
            self.hidden_size = model_config["hidden_size"]

            self.action_size = action_space.shape[0]

            self.encoder = ConvEncoder(self.depth)
            self.decoder = ConvDecoder(
                self.stoch_size + self.deter_size, depth=self.depth)
            self.reward = DenseDecoder(self.stoch_size + self.deter_size, 1,
                                       2, self.hidden_size)
            self.dynamics = RSSM(
                self.action_size,
                32 * self.depth,
                stoch=self.stoch_size,
                deter=self.deter_size)
            self.actor = ActionDecoder(self.stoch_size + self.deter_size,
                                       self.action_size, 4, self.hidden_size)
            self.value = DenseDecoder(self.stoch_size + self.deter_size, 1, 3,
                                      self.hidden_size)
            self.state = None

            self.device = (torch.device("cuda")
                           if torch.cuda.is_available() else
                           torch.device("cpu"))

        def policy(self,
                   obs: TensorType,
                   state: List[TensorType],
                   explore=True
                   ) -> Tuple[TensorType, List[float], List[TensorType]]:
            """Returns the action. Runs through the encoder, recurrent model,
            and policy to obtain action.
            """
            if state is None:
                self.get_initial_state()
            else:
                self.state = state
            post = self.state[:4]
            action = self.state[4]

            embed = self.encoder(obs)
            post, _ = self.dynamics.obs_step(post, action, embed)
            feat = self.dynamics.get_feature(post)

            action_dist = self.actor(feat)
            if explore:
                action = action_dist.sample()
            else:
                action = action_dist.mean
            logp = action_dist.log_prob(action)

            self.state = post + [action]
            return action, logp, self.state
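
        # self.state is the length-5 list [mean, std, stoch, deter, action]:
        # the first four entries form the RSSM posterior carried between
        # calls, and the last entry is the previous action fed back into
        # obs_step on the next call.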

        def imagine_ahead(self, state: List[TensorType],
                          horizon: int) -> TensorType:
            """Given a batch of states, rolls out more states of length
            horizon.
            """
            start = []
            for s in state:
                s = s.contiguous().detach()
                shpe = [-1] + list(s.size())[2:]
                start.append(s.view(*shpe))

            def next_state(state):
                feature = self.dynamics.get_feature(state).detach()
                action = self.actor(feature).rsample()
                next_state = self.dynamics.img_step(state, action)
                return next_state

            last = start
            outputs = [[] for i in range(len(start))]
            for _ in range(horizon):
                last = next_state(last)
                [o.append(l) for l, o in zip(last, outputs)]
            outputs = [torch.stack(x, dim=0) for x in outputs]

            imag_feat = self.dynamics.get_feature(outputs)
            return imag_feat

        def get_initial_state(self) -> List[TensorType]:
            self.state = self.dynamics.get_initial_state(1) + [
                torch.zeros(1, self.action_space.shape[0]).to(self.device)
            ]
            return self.state

        def value_function(self) -> TensorType:
            return None
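
# A minimal standalone sketch of how this model is driven (hypothetical, for
# illustration only; obs_space and act_space stand for a gym image observation
# space and a continuous Box action space):
#
#   config = {"depth_size": 32, "deter_size": 200, "stoch_size": 30,
#             "hidden_size": 400}
#   model = DreamerModel(obs_space, act_space, 1, config, "dreamer")
#   obs = torch.rand(1, 3, 64, 64)  # one 64x64 RGB frame
#   action, logp, state = model.policy(obs, state=None, explore=True)
#
# In training, the posterior states produced by model.dynamics.observe() are
# passed to model.imagine_ahead(post, horizon) to obtain imagined features for
# the actor and value losses.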