diff --git a/rllib/algorithms/dreamer/dreamer_torch_policy.py b/rllib/algorithms/dreamer/dreamer_torch_policy.py index 5915c1150..0f63de0d4 100644 --- a/rllib/algorithms/dreamer/dreamer_torch_policy.py +++ b/rllib/algorithms/dreamer/dreamer_torch_policy.py @@ -1,19 +1,27 @@ -import logging +from typing import ( + List, + Tuple, + Union, +) +import logging +import ray import numpy as np from typing import Dict, Optional -import ray + from ray.rllib.algorithms.dreamer.utils import FreezeParameters from ray.rllib.evaluation.episode import Episode from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.policy import Policy -from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import apply_grad_clipping from ray.rllib.utils.typing import AgentID, TensorType +from ray.rllib.utils.annotations import override +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 torch, nn = try_import_torch() if torch: @@ -22,126 +30,239 @@ if torch: logger = logging.getLogger(__name__) -# This is the computation graph for workers (inner adaptation steps) -def compute_dreamer_loss( - obs: TensorType, - action: TensorType, - reward: TensorType, - model: TorchModelV2, - imagine_horizon: int, - gamma: float = 0.99, - lambda_: float = 0.95, - kl_coeff: float = 1.0, - free_nats: float = 3.0, - log: bool = False, -): - """Constructs loss for the Dreamer objective. +class DreamerTorchPolicy(TorchPolicyV2): + def __init__(self, observation_space, action_space, config): - Args: - obs: Observations (o_t). - action: Actions (a_(t-1)). - reward: Rewards (r_(t-1)). - model: DreamerModel, encompassing all other models. - imagine_horizon: Imagine horizon for actor and critic loss. - gamma: Discount factor gamma. - lambda_: Lambda, like in GAE. - kl_coeff: KL Coefficient for Divergence loss in model loss. - free_nats: Threshold for minimum divergence in model loss. - log: If log, generate gifs. 
- """ - encoder_weights = list(model.encoder.parameters()) - decoder_weights = list(model.decoder.parameters()) - reward_weights = list(model.reward.parameters()) - dynamics_weights = list(model.dynamics.parameters()) - critic_weights = list(model.value.parameters()) - model_weights = list( - encoder_weights + decoder_weights + reward_weights + dynamics_weights - ) + config = dict(ray.rllib.algorithms.dreamer.DreamerConfig().to_dict(), **config) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) - # PlaNET Model Loss - latent = model.encoder(obs) - post, prior = model.dynamics.observe(latent, action) - features = model.dynamics.get_feature(post) - image_pred = model.decoder(features) - reward_pred = model.reward(features) - image_loss = -torch.mean(image_pred.log_prob(obs)) - reward_loss = -torch.mean(reward_pred.log_prob(reward)) - prior_dist = model.dynamics.get_dist(prior[0], prior[1]) - post_dist = model.dynamics.get_dist(post[0], post[1]) - div = torch.mean( - torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2) - ) - div = torch.clamp(div, min=free_nats) - model_loss = kl_coeff * div + reward_loss + image_loss + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() - # Actor Loss - # [imagine_horizon, batch_length*batch_size, feature_size] - with torch.no_grad(): - actor_states = [v.detach() for v in post] - with FreezeParameters(model_weights): - imag_feat = model.imagine_ahead(actor_states, imagine_horizon) - with FreezeParameters(model_weights + critic_weights): - reward = model.reward(imag_feat).mean - value = model.value(imag_feat).mean - pcont = gamma * torch.ones_like(reward) - returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_) - discount_shape = pcont[:1].size() - discount = torch.cumprod( - torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0), dim=0 - ) - actor_loss = -torch.mean(discount * returns) + @override(TorchPolicyV2) + def loss( + self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch + ) -> Union[TensorType, List[TensorType]]: + log_gif = False + if "log_gif" in train_batch: + log_gif = True - # Critic Loss - with torch.no_grad(): - val_feat = imag_feat.detach()[:-1] - target = returns.detach() - val_discount = discount.detach() - val_pred = model.value(val_feat) - critic_loss = -torch.mean(val_discount * val_pred.log_prob(target)) + # This is the computation graph for workers (inner adaptation steps) + encoder_weights = list(self.model.encoder.parameters()) + decoder_weights = list(self.model.decoder.parameters()) + reward_weights = list(self.model.reward.parameters()) + dynamics_weights = list(self.model.dynamics.parameters()) + critic_weights = list(self.model.value.parameters()) + model_weights = list( + encoder_weights + decoder_weights + reward_weights + dynamics_weights + ) + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) - # Logging purposes - prior_ent = torch.mean(prior_dist.entropy()) - post_ent = torch.mean(post_dist.entropy()) + # PlaNET Model Loss + latent = self.model.encoder(train_batch["obs"]) + post, prior = self.model.dynamics.observe(latent, train_batch["actions"]) + features = self.model.dynamics.get_feature(post) + image_pred = self.model.decoder(features) + reward_pred = self.model.reward(features) + image_loss 
= -torch.mean(image_pred.log_prob(train_batch["obs"])) + reward_loss = -torch.mean(reward_pred.log_prob(train_batch["rewards"])) + prior_dist = self.model.dynamics.get_dist(prior[0], prior[1]) + post_dist = self.model.dynamics.get_dist(post[0], post[1]) + div = torch.mean( + torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2) + ) + div = torch.clamp(div, min=(self.config["free_nats"])) + model_loss = self.config["kl_coeff"] * div + reward_loss + image_loss - log_gif = None - if log: - log_gif = log_summary(obs, action, latent, image_pred, model) + # Actor Loss + # [imagine_horizon, batch_length*batch_size, feature_size] + with torch.no_grad(): + actor_states = [v.detach() for v in post] + with FreezeParameters(model_weights): + imag_feat = self.model.imagine_ahead( + actor_states, self.config["imagine_horizon"] + ) + with FreezeParameters(model_weights + critic_weights): + reward = self.model.reward(imag_feat).mean + value = self.model.value(imag_feat).mean + pcont = self.config["gamma"] * torch.ones_like(reward) - return_dict = { - "model_loss": model_loss, - "reward_loss": reward_loss, - "image_loss": image_loss, - "divergence": div, - "actor_loss": actor_loss, - "critic_loss": critic_loss, - "prior_ent": prior_ent, - "post_ent": post_ent, - } + # Similar to GAE-Lambda, calculate value targets + next_values = torch.cat([value[:-1][1:], value[-1][None]], dim=0) + inputs = reward[:-1] + pcont[:-1] * next_values * (1 - self.config["lambda"]) - if log_gif is not None: - return_dict["log_gif"] = log_gif - return return_dict + def agg_fn(x, y): + return y[0] + y[1] * self.config["lambda"] * x + last = value[-1] + returns = [] + for i in reversed(range(len(inputs))): + last = agg_fn(last, [inputs[i], pcont[:-1][i]]) + returns.append(last) -# Similar to GAE-Lambda, calculate value targets -def lambda_return(reward, value, pcont, bootstrap, lambda_): - def agg_fn(x, y): - return y[0] + y[1] * lambda_ * x + returns = list(reversed(returns)) + returns = torch.stack(returns, dim=0) + discount_shape = pcont[:1].size() + discount = torch.cumprod( + torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0), + dim=0, + ) + actor_loss = -torch.mean(discount * returns) - next_values = torch.cat([value[1:], bootstrap[None]], dim=0) - inputs = reward + pcont * next_values * (1 - lambda_) + # Critic Loss + with torch.no_grad(): + val_feat = imag_feat.detach()[:-1] + target = returns.detach() + val_discount = discount.detach() + val_pred = self.model.value(val_feat) + critic_loss = -torch.mean(val_discount * val_pred.log_prob(target)) - last = bootstrap - returns = [] - for i in reversed(range(len(inputs))): - last = agg_fn(last, [inputs[i], pcont[i]]) - returns.append(last) + # Logging purposes + prior_ent = torch.mean(prior_dist.entropy()) + post_ent = torch.mean(post_dist.entropy()) + gif = None + if log_gif: + gif = log_summary( + train_batch["obs"], + train_batch["actions"], + latent, + image_pred, + self.model, + ) + return_dict = { + "model_loss": model_loss, + "reward_loss": reward_loss, + "image_loss": image_loss, + "divergence": div, + "actor_loss": actor_loss, + "critic_loss": critic_loss, + "prior_ent": prior_ent, + "post_ent": post_ent, + } + if gif is not None: + return_dict["log_gif"] = gif + self.stats_dict = return_dict - returns = list(reversed(returns)) - returns = torch.stack(returns, dim=0) - return returns + loss_dict = self.stats_dict + + return ( + loss_dict["model_loss"], + loss_dict["actor_loss"], + loss_dict["critic_loss"], + ) + + 
@override(TorchPolicyV2)
+    def postprocess_trajectory(
+        self,
+        sample_batch: SampleBatch,
+        other_agent_batches: Optional[
+            Dict[AgentID, Tuple["Policy", SampleBatch]]
+        ] = None,
+        episode: Optional["Episode"] = None,
+    ) -> SampleBatch:
+        """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)).
+        When t=0, the reset obs is paired with an action and reward of 0.
+        """
+        obs = sample_batch[SampleBatch.OBS]
+        new_obs = sample_batch[SampleBatch.NEXT_OBS]
+        action = sample_batch[SampleBatch.ACTIONS]
+        reward = sample_batch[SampleBatch.REWARDS]
+        eps_ids = sample_batch[SampleBatch.EPS_ID]
+
+        act_shape = action.shape
+        act_reset = np.array([0.0] * act_shape[-1])[None]
+        rew_reset = np.array(0.0)[None]
+        obs_end = np.array(new_obs[act_shape[0] - 1])[None]
+
+        batch_obs = np.concatenate([obs, obs_end], axis=0)
+        batch_action = np.concatenate([act_reset, action], axis=0)
+        batch_rew = np.concatenate([rew_reset, reward], axis=0)
+        batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0)
+
+        new_batch = {
+            SampleBatch.OBS: batch_obs,
+            SampleBatch.REWARDS: batch_rew,
+            SampleBatch.ACTIONS: batch_action,
+            SampleBatch.EPS_ID: batch_eps_ids,
+        }
+        return SampleBatch(new_batch)
+
+    def stats_fn(self, train_batch):
+        return self.stats_dict
+
+    @override(TorchPolicyV2)
+    def optimizer(self):
+        model = self.model
+        encoder_weights = list(model.encoder.parameters())
+        decoder_weights = list(model.decoder.parameters())
+        reward_weights = list(model.reward.parameters())
+        dynamics_weights = list(model.dynamics.parameters())
+        actor_weights = list(model.actor.parameters())
+        critic_weights = list(model.value.parameters())
+        model_opt = torch.optim.Adam(
+            encoder_weights + decoder_weights + reward_weights + dynamics_weights,
+            lr=self.config["td_model_lr"],
+        )
+        actor_opt = torch.optim.Adam(actor_weights, lr=self.config["actor_lr"])
+        critic_opt = torch.optim.Adam(critic_weights, lr=self.config["critic_lr"])
+
+        return (model_opt, actor_opt, critic_opt)
+
+    def action_sampler_fn(policy, model, obs_batch, state_batches, explore, timestep):
+        """Action sampler function has two phases. During the prefill phase,
+        actions are sampled uniformly in [-1, 1]. During the training phase,
+        actions are evaluated through DreamerPolicy and additive Gaussian
+        noise is added to incentivize exploration.
+ """ + obs = obs_batch["obs"] + + # Custom Exploration + if timestep <= policy.config["prefill_timesteps"]: + logp = None + # Random action in space [-1.0, 1.0] + action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0 + state_batches = model.get_initial_state() + else: + # Weird RLlib Handling, this happens when env rests + if len(state_batches[0].size()) == 3: + # Very hacky, but works on all envs + state_batches = model.get_initial_state() + action, logp, state_batches = model.policy(obs, state_batches, explore) + action = td.Normal(action, policy.config["explore_noise"]).sample() + action = torch.clamp(action, min=-1.0, max=1.0) + + policy.global_timestep += policy.config["action_repeat"] + + return action, logp, state_batches + + def make_model(self): + + model = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + 1, + self.config["dreamer_model"], + name="DreamerModel", + framework="torch", + ) + + self.model_variables = model.variables() + + return model + + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) # Creates gif @@ -156,140 +277,3 @@ def log_summary(obs, action, embed, image_pred, model): mod = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1) error = (mod - truth + 1.0) / 2.0 return torch.cat([truth, mod, error], 3) - - -def dreamer_loss(policy, model, dist_class, train_batch): - log_gif = False - if "log_gif" in train_batch: - log_gif = True - - policy.stats_dict = compute_dreamer_loss( - train_batch["obs"], - train_batch["actions"], - train_batch["rewards"], - policy.model, - policy.config["imagine_horizon"], - policy.config["gamma"], - policy.config["lambda"], - policy.config["kl_coeff"], - policy.config["free_nats"], - log_gif, - ) - - loss_dict = policy.stats_dict - - return (loss_dict["model_loss"], loss_dict["actor_loss"], loss_dict["critic_loss"]) - - -def build_dreamer_model(policy, obs_space, action_space, config): - - model = ModelCatalog.get_model_v2( - obs_space, - action_space, - 1, - config["dreamer_model"], - name="DreamerModel", - framework="torch", - ) - - policy.model_variables = model.variables() - - return model - - -def action_sampler_fn(policy, model, input_dict, state, explore, timestep): - """Action sampler function has two phases. During the prefill phase, - actions are sampled uniformly [-1, 1]. During training phase, actions - are evaluated through DreamerPolicy and an additive gaussian is added - to incentivize exploration. 
- """ - obs = input_dict["obs"] - - # Custom Exploration - if timestep <= policy.config["prefill_timesteps"]: - logp = None - # Random action in space [-1.0, 1.0] - action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0 - state = model.get_initial_state() - else: - # Weird RLlib Handling, this happens when env rests - if len(state[0].size()) == 3: - # Very hacky, but works on all envs - state = model.get_initial_state() - action, logp, state = model.policy(obs, state, explore) - action = td.Normal(action, policy.config["explore_noise"]).sample() - action = torch.clamp(action, min=-1.0, max=1.0) - - policy.global_timestep += policy.config["action_repeat"] - - return action, logp, state - - -def dreamer_stats(policy, train_batch): - return policy.stats_dict - - -def dreamer_optimizer_fn(policy, config): - model = policy.model - encoder_weights = list(model.encoder.parameters()) - decoder_weights = list(model.decoder.parameters()) - reward_weights = list(model.reward.parameters()) - dynamics_weights = list(model.dynamics.parameters()) - actor_weights = list(model.actor.parameters()) - critic_weights = list(model.value.parameters()) - model_opt = torch.optim.Adam( - encoder_weights + decoder_weights + reward_weights + dynamics_weights, - lr=config["td_model_lr"], - ) - actor_opt = torch.optim.Adam(actor_weights, lr=config["actor_lr"]) - critic_opt = torch.optim.Adam(critic_weights, lr=config["critic_lr"]) - - return (model_opt, actor_opt, critic_opt) - - -def preprocess_episode( - policy: Policy, - sample_batch: SampleBatch, - other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[Episode] = None, -) -> SampleBatch: - """Batch format should be in the form of (s_t, a_(t-1), r_(t-1)) - When t=0, the resetted obs is paired with action and reward of 0. - """ - obs = sample_batch[SampleBatch.OBS] - new_obs = sample_batch[SampleBatch.NEXT_OBS] - action = sample_batch[SampleBatch.ACTIONS] - reward = sample_batch[SampleBatch.REWARDS] - eps_ids = sample_batch[SampleBatch.EPS_ID] - - act_shape = action.shape - act_reset = np.array([0.0] * act_shape[-1])[None] - rew_reset = np.array(0.0)[None] - obs_end = np.array(new_obs[act_shape[0] - 1])[None] - - batch_obs = np.concatenate([obs, obs_end], axis=0) - batch_action = np.concatenate([act_reset, action], axis=0) - batch_rew = np.concatenate([rew_reset, reward], axis=0) - batch_eps_ids = np.concatenate([eps_ids, eps_ids[-1:]], axis=0) - - new_batch = { - SampleBatch.OBS: batch_obs, - SampleBatch.REWARDS: batch_rew, - SampleBatch.ACTIONS: batch_action, - SampleBatch.EPS_ID: batch_eps_ids, - } - return SampleBatch(new_batch) - - -DreamerTorchPolicy = build_policy_class( - name="DreamerTorchPolicy", - framework="torch", - get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG, - action_sampler_fn=action_sampler_fn, - postprocess_fn=preprocess_episode, - loss_fn=dreamer_loss, - stats_fn=dreamer_stats, - make_model=build_dreamer_model, - optimizer_fn=dreamer_optimizer_fn, - extra_grad_process_fn=apply_grad_clipping, -) diff --git a/rllib/policy/torch_policy_v2.py b/rllib/policy/torch_policy_v2.py index 6489bd4b8..bbb1d17f8 100644 --- a/rllib/policy/torch_policy_v2.py +++ b/rllib/policy/torch_policy_v2.py @@ -1012,7 +1012,6 @@ class TorchPolicyV2(Policy): if is_overridden(self.action_sampler_fn): action_dist = dist_inputs = None actions, logp, state_out = self.action_sampler_fn( - self, self.model, obs_batch=input_dict, state_batches=state_batches,
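
Note on the value targets: this patch deletes the standalone lambda_return() helper and inlines its recursion into DreamerTorchPolicy.loss(). For reference, below is a minimal self-contained sketch of that same recursion (GAE-lambda-style targets over the imagined horizon); the tensor shapes in the smoke test are invented for illustration and are not Dreamer defaults.

import torch


def lambda_return(reward, value, pcont, bootstrap, lambda_):
    # Value targets over an imagined horizon, mirroring the recursion that
    # this diff inlines into loss(); shapes are [horizon, batch, ...].
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    inputs = reward + pcont * next_values * (1 - lambda_)
    last = bootstrap
    outputs = []
    for i in reversed(range(len(inputs))):
        last = inputs[i] + pcont[i] * lambda_ * last
        outputs.append(last)
    return torch.stack(list(reversed(outputs)), dim=0)


if __name__ == "__main__":
    horizon, batch = 3, 2
    reward = torch.ones(horizon, batch)
    value = torch.zeros(horizon, batch)
    pcont = 0.99 * torch.ones(horizon, batch)
    targets = lambda_return(reward, value, pcont, value[-1], lambda_=0.95)
    print(targets.shape)  # torch.Size([3, 2])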
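
Note on the batch shifting: postprocess_trajectory() rearranges a rollout into (s_t, a_(t-1), r_(t-1)) rows by prepending a zero action/reward and appending the final next-obs, so every array ends up with length T + 1. A small NumPy illustration of that alignment is below; the toy rollout values are made up.

import numpy as np

# Invented toy rollout: T = 3 steps, 1-D observations, 2-D actions.
obs = np.array([[0.0], [1.0], [2.0]])      # o_0, o_1, o_2
new_obs = np.array([[1.0], [2.0], [3.0]])  # o_1, o_2, o_3
actions = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
rewards = np.array([1.0, 2.0, 3.0])

# Same shifting as postprocess_trajectory(): prepend a zero action/reward,
# append the last next-obs, so row t pairs o_t with a_(t-1) and r_(t-1).
act_reset = np.zeros((1, actions.shape[-1]))
rew_reset = np.array(0.0)[None]
obs_end = new_obs[-1][None]

batch_obs = np.concatenate([obs, obs_end], axis=0)           # shape (4, 1)
batch_action = np.concatenate([act_reset, actions], axis=0)  # shape (4, 2)
batch_rew = np.concatenate([rew_reset, rewards], axis=0)     # shape (4,)

print(batch_obs.shape, batch_action.shape, batch_rew.shape)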
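
Note on exploration: action_sampler_fn() uses a two-phase scheme, uniform random actions while timestep is within the prefill window, then additive Gaussian noise on the policy's action afterwards. Below is a minimal sketch of that scheme in isolation; prefill_timesteps, explore_noise, and action_dim are illustrative placeholder values, not the config defaults.

import torch
import torch.distributions as td


def sample_action(policy_action, timestep, prefill_timesteps=1000,
                  explore_noise=0.3, action_dim=6):
    # Prefill phase: uniform random action in [-1.0, 1.0].
    if timestep <= prefill_timesteps:
        return 2.0 * torch.rand(1, action_dim) - 1.0
    # Training phase: add Gaussian noise to the policy's action and clamp.
    noisy = td.Normal(policy_action, explore_noise).sample()
    return torch.clamp(noisy, min=-1.0, max=1.0)


print(sample_action(torch.zeros(1, 6), timestep=100))     # random prefill action
print(sample_action(torch.zeros(1, 6), timestep=10_000))  # noisy policy action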