ray/rllib/algorithms/dreamer/dreamer_torch_policy.py

import logging
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

import ray
from ray.rllib.algorithms.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import AgentID, TensorType

torch, nn = try_import_torch()
if torch:
    from torch import distributions as td

logger = logging.getLogger(__name__)


class DreamerTorchPolicy(TorchPolicyV2):
    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.algorithms.dreamer.DreamerConfig().to_dict(), **config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
    ) -> Union[TensorType, List[TensorType]]:
        log_gif = False
        if "log_gif" in train_batch:
            log_gif = True

        # Collect world-model and critic parameters so they can be frozen
        # (no gradient updates) while the actor loss is computed below.
        encoder_weights = list(self.model.encoder.parameters())
        decoder_weights = list(self.model.decoder.parameters())
        reward_weights = list(self.model.reward.parameters())
        dynamics_weights = list(self.model.dynamics.parameters())
        critic_weights = list(self.model.value.parameters())
        model_weights = list(
            encoder_weights + decoder_weights + reward_weights + dynamics_weights
        )
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
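
        # The total loss has three parts, computed below:
        #   1. World-model (PlaNET) loss: image reconstruction + reward
        #      prediction + KL(posterior || prior).
        #   2. Actor loss: maximize discounted lambda-returns over trajectories
        #      imagined by the (frozen) world model.
        #   3. Critic loss: regress the value head onto those lambda-returns.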

        # PlaNET Model Loss
        latent = self.model.encoder(train_batch["obs"])
        post, prior = self.model.dynamics.observe(latent, train_batch["actions"])
        features = self.model.dynamics.get_feature(post)
        image_pred = self.model.decoder(features)
        reward_pred = self.model.reward(features)
        image_loss = -torch.mean(image_pred.log_prob(train_batch["obs"]))
        reward_loss = -torch.mean(reward_pred.log_prob(train_batch["rewards"]))
        prior_dist = self.model.dynamics.get_dist(prior[0], prior[1])
        post_dist = self.model.dynamics.get_dist(post[0], post[1])
        div = torch.mean(
            torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2)
        )
        # "Free nats": KL below `free_nats` is clamped and contributes no gradient.
        div = torch.clamp(div, min=(self.config["free_nats"]))
        model_loss = self.config["kl_coeff"] * div + reward_loss + image_loss

        # Actor Loss
        # imag_feat: [imagine_horizon, batch_length * batch_size, feature_size]
        with torch.no_grad():
            actor_states = [v.detach() for v in post]
        with FreezeParameters(model_weights):
            imag_feat = self.model.imagine_ahead(
                actor_states, self.config["imagine_horizon"]
            )
        with FreezeParameters(model_weights + critic_weights):
            reward = self.model.reward(imag_feat).mean
            value = self.model.value(imag_feat).mean
        pcont = self.config["gamma"] * torch.ones_like(reward)

        # Similar to GAE-Lambda, calculate value targets
        next_values = torch.cat([value[:-1][1:], value[-1][None]], dim=0)
        inputs = reward[:-1] + pcont[:-1] * next_values * (1 - self.config["lambda"])
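
        # Lambda-return recurrence, computed backwards over the imagined horizon:
        #   V_lambda[t] = r[t] + gamma * ((1-lambda) * V[t+1] + lambda * V_lambda[t+1]),
        #   with V_lambda[H] = V[H].
        # `inputs` above is the non-recursive part r[t] + gamma * (1-lambda) * V[t+1];
        # the loop below adds the recursive gamma * lambda * V_lambda[t+1] term.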
        def agg_fn(x, y):
            return y[0] + y[1] * self.config["lambda"] * x

        last = value[-1]
        returns = []
        for i in reversed(range(len(inputs))):
            last = agg_fn(last, [inputs[i], pcont[:-1][i]])
            returns.append(last)
        returns = list(reversed(returns))
        returns = torch.stack(returns, dim=0)
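
        # Per-step discount weights over the imagined horizon: 1 for the first
        # step, then cumulative products of gamma (i.e. gamma ** t), so later
        # imagined steps contribute less to the actor objective.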
        discount_shape = pcont[:1].size()
        discount = torch.cumprod(
            torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0),
            dim=0,
        )
        actor_loss = -torch.mean(discount * returns)

        # Critic Loss
        with torch.no_grad():
            val_feat = imag_feat.detach()[:-1]
            target = returns.detach()
            val_discount = discount.detach()
        val_pred = self.model.value(val_feat)
        critic_loss = -torch.mean(val_discount * val_pred.log_prob(target))
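
        # Note: only `val_pred` above carries gradients. The imagined features
        # and the lambda-return targets are detached, so this term only trains
        # the value head.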

        # Logging purposes
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())
        gif = None
        if log_gif:
            gif = log_summary(
                train_batch["obs"],
                train_batch["actions"],
                latent,
                image_pred,
                self.model,
            )

        return_dict = {
            "model_loss": model_loss,
            "reward_loss": reward_loss,
            "image_loss": image_loss,
            "divergence": div,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
            "prior_ent": prior_ent,
            "post_ent": post_ent,
        }
        if gif is not None:
            return_dict["log_gif"] = gif
        self.stats_dict = return_dict

        loss_dict = self.stats_dict

        return (
            loss_dict["model_loss"],
            loss_dict["actor_loss"],
            loss_dict["critic_loss"],
        )

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[
            Dict[AgentID, Tuple["Policy", SampleBatch]]
        ] = None,
        episode: Optional["Episode"] = None,
    ) -> SampleBatch:
        """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)).

        At t=0, the reset observation is paired with an action and reward of 0.
        """
        obs = sample_batch[SampleBatch.OBS]
        new_obs = sample_batch[SampleBatch.NEXT_OBS]
        action = sample_batch[SampleBatch.ACTIONS]
        reward = sample_batch[SampleBatch.REWARDS]
        eps_ids = sample_batch[SampleBatch.EPS_ID]

        act_shape = action.shape
        act_reset = np.array([0.0] * act_shape[-1])[None]
        rew_reset = np.array(0.0)[None]
        obs_end = np.array(new_obs[act_shape[0] - 1])[None]

        batch_obs = np.concatenate([obs, obs_end], axis=0)
        batch_action = np.concatenate([act_reset, action], axis=0)
        batch_rew = np.concatenate([rew_reset, reward], axis=0)
        batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0)

        new_batch = {
            SampleBatch.OBS: batch_obs,
            SampleBatch.REWARDS: batch_rew,
            SampleBatch.ACTIONS: batch_action,
            SampleBatch.EPS_ID: batch_eps_ids,
        }
        return SampleBatch(new_batch)

    def stats_fn(self, train_batch):
        return self.stats_dict

    @override(TorchPolicyV2)
    def optimizer(self):
        model = self.model
        encoder_weights = list(model.encoder.parameters())
        decoder_weights = list(model.decoder.parameters())
        reward_weights = list(model.reward.parameters())
        dynamics_weights = list(model.dynamics.parameters())
        actor_weights = list(model.actor.parameters())
        critic_weights = list(model.value.parameters())
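
        # Three optimizers, one per loss term returned by `loss()`; TorchPolicyV2
        # applies the i-th optimizer to the i-th returned loss (world model,
        # actor, critic).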
        model_opt = torch.optim.Adam(
            encoder_weights + decoder_weights + reward_weights + dynamics_weights,
            lr=self.config["td_model_lr"],
        )
        actor_opt = torch.optim.Adam(actor_weights, lr=self.config["actor_lr"])
        critic_opt = torch.optim.Adam(critic_weights, lr=self.config["critic_lr"])

        return (model_opt, actor_opt, critic_opt)

    def action_sampler_fn(policy, model, obs_batch, state_batches, explore, timestep):
        """Action sampler function has two phases. During the prefill phase,
        actions are sampled uniformly in [-1, 1]. During the training phase,
        actions come from the Dreamer policy, with additive Gaussian noise
        to encourage exploration.
        """
obs = obs_batch["obs"]
# Custom Exploration
if timestep <= policy.config["prefill_timesteps"]:
logp = None
# Random action in space [-1.0, 1.0]
action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
state_batches = model.get_initial_state()
else:
# Weird RLlib Handling, this happens when env rests
if len(state_batches[0].size()) == 3:
# Very hacky, but works on all envs
state_batches = model.get_initial_state()
action, logp, state_batches = model.policy(obs, state_batches, explore)
action = td.Normal(action, policy.config["explore_noise"]).sample()
action = torch.clamp(action, min=-1.0, max=1.0)
policy.global_timestep += policy.config["action_repeat"]
return action, logp, state_batches

    def make_model(self):
        model = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            1,
            self.config["dreamer_model"],
            name="DreamerModel",
            framework="torch",
        )

        self.model_variables = model.variables()

        return model

    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)


# Builds a side-by-side video tensor of ground-truth frames, model
# reconstructions / open-loop predictions, and their error (for gif logging).
def log_summary(obs, action, embed, image_pred, model):
    # First 6 sequences; the +0.5 shifts pixel values back from the
    # zero-centered preprocessing range.
    truth = obs[:6] + 0.5
    recon = image_pred.mean[:6]
    # Condition the latent state on the first 5 steps, then predict the
    # remaining steps open-loop with the dynamics model.
    init, _ = model.dynamics.observe(embed[:6, :5], action[:6, :5])
    init = [itm[:, -1] for itm in init]
    prior = model.dynamics.imagine(action[:6, 5:], init)
    openl = model.decoder(model.dynamics.get_feature(prior)).mean
    # Reconstructions for the conditioning steps, open-loop predictions after.
    mod = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (mod - truth + 1.0) / 2.0
    return torch.cat([truth, mod, error], 3)