# ray/rllib/algorithms/dreamer/dreamer_torch_policy.py


import logging
import numpy as np
from typing import Dict, Optional
import ray
from ray.rllib.algorithms.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import AgentID, TensorType

torch, nn = try_import_torch()
if torch:
    from torch import distributions as td

logger = logging.getLogger(__name__)
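
# This module wires together the pieces of the PyTorch Dreamer policy: the
# Dreamer loss (world-model, actor, and critic terms), the lambda-return
# helper, gif logging for reconstructions, model construction, the action
# sampler, per-component optimizers, and episode postprocessing, all assembled
# into `DreamerTorchPolicy` via `build_policy_class` at the bottom of the file.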


# Computes the full Dreamer loss (world-model, actor, and critic objectives).
def compute_dreamer_loss(
    obs: TensorType,
    action: TensorType,
    reward: TensorType,
    model: TorchModelV2,
    imagine_horizon: int,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    kl_coeff: float = 1.0,
    free_nats: float = 3.0,
    log: bool = False,
):
    """Constructs the loss for the Dreamer objective.

    Args:
        obs: Observations (o_t).
        action: Actions (a_(t-1)).
        reward: Rewards (r_(t-1)).
        model: DreamerModel, encompassing all other models.
        imagine_horizon: Imagination horizon for the actor and critic losses.
        gamma: Discount factor gamma.
        lambda_: Lambda, like in GAE.
        kl_coeff: KL coefficient for the divergence term in the model loss.
        free_nats: Threshold for the minimum divergence in the model loss.
        log: If True, generate reconstruction gifs for logging.
    """
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    critic_weights = list(model.value.parameters())
    model_weights = list(
        encoder_weights + decoder_weights + reward_weights + dynamics_weights
    )
    device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
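
    # Note: `model_weights` covers only the world-model components (encoder,
    # decoder, reward head, dynamics). The critic (`model.value`) is kept in a
    # separate list so it can be frozen together with the world model when
    # imagined trajectories are rolled out for the actor loss below.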

    # PlaNET Model Loss
    latent = model.encoder(obs)
    post, prior = model.dynamics.observe(latent, action)
    features = model.dynamics.get_feature(post)
    image_pred = model.decoder(features)
    reward_pred = model.reward(features)
    image_loss = -torch.mean(image_pred.log_prob(obs))
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    prior_dist = model.dynamics.get_dist(prior[0], prior[1])
    post_dist = model.dynamics.get_dist(post[0], post[1])
    div = torch.mean(
        torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2)
    )
    div = torch.clamp(div, min=free_nats)
    model_loss = kl_coeff * div + reward_loss + image_loss
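
    # Note on the model loss above: it sums the observation and reward
    # reconstruction negative log-likelihoods with the KL between posterior
    # and prior. Clamping the KL at `free_nats` makes the divergence term a
    # constant (zero gradient) whenever it is already below that threshold.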

    # Actor Loss
    # [imagine_horizon, batch_length*batch_size, feature_size]
    with torch.no_grad():
        actor_states = [v.detach() for v in post]
    with FreezeParameters(model_weights):
        imag_feat = model.imagine_ahead(actor_states, imagine_horizon)
    with FreezeParameters(model_weights + critic_weights):
        reward = model.reward(imag_feat).mean
        value = model.value(imag_feat).mean
    pcont = gamma * torch.ones_like(reward)
    returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
    discount_shape = pcont[:1].size()
    discount = torch.cumprod(
        torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0), dim=0
    )
    actor_loss = -torch.mean(discount * returns)
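
    # Note on the actor loss above: `imagine_ahead` rolls the latent dynamics
    # forward `imagine_horizon` steps from the detached posterior states, and
    # the actor maximizes the discounted lambda-returns of those imagined
    # trajectories. The cumulative-product discount starts at 1 and drops the
    # last entries of `pcont` so its length matches `returns`, which only
    # covers the first imagine_horizon - 1 steps.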

    # Critic Loss
    with torch.no_grad():
        val_feat = imag_feat.detach()[:-1]
        target = returns.detach()
        val_discount = discount.detach()
    val_pred = model.value(val_feat)
    critic_loss = -torch.mean(val_discount * val_pred.log_prob(target))

    # Logging purposes
    prior_ent = torch.mean(prior_dist.entropy())
    post_ent = torch.mean(post_dist.entropy())
    log_gif = None
    if log:
        log_gif = log_summary(obs, action, latent, image_pred, model)

    return_dict = {
        "model_loss": model_loss,
        "reward_loss": reward_loss,
        "image_loss": image_loss,
        "divergence": div,
        "actor_loss": actor_loss,
        "critic_loss": critic_loss,
        "prior_ent": prior_ent,
        "post_ent": post_ent,
    }

    if log_gif is not None:
        return_dict["log_gif"] = log_gif

    return return_dict


# Similar to GAE-Lambda, calculate value targets.
def lambda_return(reward, value, pcont, bootstrap, lambda_):
    def agg_fn(x, y):
        return y[0] + y[1] * lambda_ * x

    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    inputs = reward + pcont * next_values * (1 - lambda_)

    last = bootstrap
    returns = []
    for i in reversed(range(len(inputs))):
        last = agg_fn(last, [inputs[i], pcont[i]])
        returns.append(last)

    returns = list(reversed(returns))
    returns = torch.stack(returns, dim=0)
    return returns
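
# Illustration (comment only, not executed): with lambda_=1.0 the recursion in
# lambda_return reduces to the plain discounted return bootstrapped with
# `bootstrap`; with lambda_=0.0 it reduces to the one-step TD target
# inputs[t] = reward[t] + pcont[t] * next_values[t]. For example,
# reward=[1, 1], pcont=[0.9, 0.9], bootstrap=v and lambda_=1.0 give
# returns = [1 + 0.9 * (1 + 0.9 * v), 1 + 0.9 * v].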


# Creates a gif comparing ground truth, model reconstructions, and their error.
def log_summary(obs, action, embed, image_pred, model):
    truth = obs[:6] + 0.5
    recon = image_pred.mean[:6]
    init, _ = model.dynamics.observe(embed[:6, :5], action[:6, :5])
    init = [itm[:, -1] for itm in init]
    prior = model.dynamics.imagine(action[:6, 5:], init)
    openl = model.decoder(model.dynamics.get_feature(prior)).mean

    mod = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (mod - truth + 1.0) / 2.0
    return torch.cat([truth, mod, error], 3)


def dreamer_loss(policy, model, dist_class, train_batch):
    log_gif = False
    if "log_gif" in train_batch:
        log_gif = True

    policy.stats_dict = compute_dreamer_loss(
        train_batch["obs"],
        train_batch["actions"],
        train_batch["rewards"],
        policy.model,
        policy.config["imagine_horizon"],
        policy.config["gamma"],
        policy.config["lambda"],
        policy.config["kl_coeff"],
        policy.config["free_nats"],
        log_gif,
    )

    loss_dict = policy.stats_dict

    return (loss_dict["model_loss"], loss_dict["actor_loss"], loss_dict["critic_loss"])


def build_dreamer_model(policy, obs_space, action_space, config):
    model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        1,
        config["dreamer_model"],
        name="DreamerModel",
        framework="torch",
    )

    policy.model_variables = model.variables()

    return model
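

# Note: the `1` passed as num_outputs above appears to be a placeholder; the
# DreamerModel builds its own encoder, decoder, reward, value, and action
# components from config["dreamer_model"] rather than sizing a generic output
# layer. See the model implementation for the authoritative behavior.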


def action_sampler_fn(policy, model, input_dict, state, explore, timestep):
    """Action sampler function with two phases. During the prefill phase,
    actions are sampled uniformly from [-1, 1]. During the training phase,
    actions are computed by the DreamerPolicy and additive Gaussian noise is
    applied to encourage exploration.
    """
    obs = input_dict["obs"]

    # Custom exploration.
    if timestep <= policy.config["prefill_timesteps"]:
        logp = None
        # Random action in the space [-1.0, 1.0].
        action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
        state = model.get_initial_state()
    else:
        # Weird RLlib handling: this state shape shows up when the env resets.
        if len(state[0].size()) == 3:
            # Very hacky, but works on all envs.
            state = model.get_initial_state()
        action, logp, state = model.policy(obs, state, explore)
        action = td.Normal(action, policy.config["explore_noise"]).sample()
        action = torch.clamp(action, min=-1.0, max=1.0)

    policy.global_timestep += policy.config["action_repeat"]

    return action, logp, state
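
# Note: `global_timestep` is advanced by `action_repeat` per sampled action,
# so the `prefill_timesteps` threshold above is effectively measured in
# environment frames (each policy action is repeated `action_repeat` times).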


def dreamer_stats(policy, train_batch):
    return policy.stats_dict


def dreamer_optimizer_fn(policy, config):
    model = policy.model
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    actor_weights = list(model.actor.parameters())
    critic_weights = list(model.value.parameters())

    model_opt = torch.optim.Adam(
        encoder_weights + decoder_weights + reward_weights + dynamics_weights,
        lr=config["td_model_lr"],
    )
    actor_opt = torch.optim.Adam(actor_weights, lr=config["actor_lr"])
    critic_opt = torch.optim.Adam(critic_weights, lr=config["critic_lr"])

    return (model_opt, actor_opt, critic_opt)
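

# The three optimizers returned above are paired index-by-index with the three
# loss terms returned by `dreamer_loss` (model, actor, critic), so the world
# model, actor, and critic are each updated by their own Adam instance and
# learning rate.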


def preprocess_episode(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)).

    When t=0, the reset obs is paired with an action and reward of 0.
    """
    obs = sample_batch[SampleBatch.OBS]
    new_obs = sample_batch[SampleBatch.NEXT_OBS]
    action = sample_batch[SampleBatch.ACTIONS]
    reward = sample_batch[SampleBatch.REWARDS]
    eps_ids = sample_batch[SampleBatch.EPS_ID]

    act_shape = action.shape
    act_reset = np.array([0.0] * act_shape[-1])[None]
    rew_reset = np.array(0.0)[None]
    obs_end = np.array(new_obs[act_shape[0] - 1])[None]

    batch_obs = np.concatenate([obs, obs_end], axis=0)
    batch_action = np.concatenate([act_reset, action], axis=0)
    batch_rew = np.concatenate([rew_reset, reward], axis=0)
    batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0)

    new_batch = {
        SampleBatch.OBS: batch_obs,
        SampleBatch.REWARDS: batch_rew,
        SampleBatch.ACTIONS: batch_action,
        SampleBatch.EPS_ID: batch_eps_ids,
    }
    return SampleBatch(new_batch)
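
# Illustration (comment only): for a 3-step rollout with obs [s0, s1, s2],
# next_obs [s1, s2, s3], actions [a0, a1, a2] and rewards [r0, r1, r2], the
# postprocessed batch has four rows:
#   (s0, 0, 0), (s1, a0, r0), (s2, a1, r1), (s3, a2, r2)
# i.e. each observation s_t is paired with the action and reward that led to it.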


DreamerTorchPolicy = build_policy_class(
    name="DreamerTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
    action_sampler_fn=action_sampler_fn,
    postprocess_fn=preprocess_episode,
    loss_fn=dreamer_loss,
    stats_fn=dreamer_stats,
    make_model=build_dreamer_model,
    optimizer_fn=dreamer_optimizer_fn,
    extra_grad_process_fn=apply_grad_clipping,
)
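
# Usage sketch (comment only, assuming this Ray version exposes DreamerConfig
# in ray.rllib.algorithms.dreamer.dreamer; the env id below is hypothetical):
#
#   from ray.rllib.algorithms.dreamer.dreamer import DreamerConfig
#
#   config = DreamerConfig().framework("torch").environment("MyPixelEnv-v0")
#   algo = config.build()
#   print(algo.train())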