[RLlib] Dreamer Policy sub-classing schema. (#25585)
parent 65d7a610ab
commit 7495e9c89c
2 changed files with 231 additions and 248 deletions
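In short, the diff below retires the functional build_policy_class(...) definition of DreamerTorchPolicy and re-expresses the same pieces (loss_fn, stats_fn, make_model, optimizer_fn, action_sampler_fn, postprocess_fn, extra_grad_process_fn) as overridden methods on a TorchPolicyV2 subclass. A minimal sketch of that sub-classing pattern, with a hypothetical MyTorchPolicy standing in for the real class (the hook names are taken from the diff; this is not a drop-in implementation):

from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class MyTorchPolicy(TorchPolicyV2):
    def __init__(self, observation_space, action_space, config):
        TorchPolicyV2.__init__(self, observation_space, action_space, config)
        # As in the diff: run the loss once on a dummy batch at construction time.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def make_model(self):
        ...  # build and return the ModelV2 (was the make_model=... kwarg)

    @override(TorchPolicyV2)
    def loss(self, model, dist_class, train_batch):
        ...  # return one loss tensor or a list of them (was loss_fn=...)

    @override(TorchPolicyV2)
    def optimizer(self):
        ...  # return the matching optimizer(s) (was optimizer_fn=...)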
@@ -1,19 +1,27 @@
import logging
from typing import (
    List,
    Tuple,
    Union,
)

import logging
import ray
import numpy as np
from typing import Dict, Optional

import ray

from ray.rllib.algorithms.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import AgentID, TensorType
from ray.rllib.utils.annotations import override
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

torch, nn = try_import_torch()
if torch:
@@ -22,74 +30,91 @@ if torch:
logger = logging.getLogger(__name__)


# This is the computation graph for workers (inner adaptation steps)
def compute_dreamer_loss(
    obs: TensorType,
    action: TensorType,
    reward: TensorType,
    model: TorchModelV2,
    imagine_horizon: int,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    kl_coeff: float = 1.0,
    free_nats: float = 3.0,
    log: bool = False,
):
    """Constructs loss for the Dreamer objective.
class DreamerTorchPolicy(TorchPolicyV2):
    def __init__(self, observation_space, action_space, config):

    Args:
        obs: Observations (o_t).
        action: Actions (a_(t-1)).
        reward: Rewards (r_(t-1)).
        model: DreamerModel, encompassing all other models.
        imagine_horizon: Imagine horizon for actor and critic loss.
        gamma: Discount factor gamma.
        lambda_: Lambda, like in GAE.
        kl_coeff: KL Coefficient for Divergence loss in model loss.
        free_nats: Threshold for minimum divergence in model loss.
        log: If log, generate gifs.
    """
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    critic_weights = list(model.value.parameters())
        config = dict(ray.rllib.algorithms.dreamer.DreamerConfig().to_dict(), **config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self, model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch
    ) -> Union[TensorType, List[TensorType]]:
        log_gif = False
        if "log_gif" in train_batch:
            log_gif = True

        # This is the computation graph for workers (inner adaptation steps)
        encoder_weights = list(self.model.encoder.parameters())
        decoder_weights = list(self.model.decoder.parameters())
        reward_weights = list(self.model.reward.parameters())
        dynamics_weights = list(self.model.dynamics.parameters())
        critic_weights = list(self.model.value.parameters())
        model_weights = list(
            encoder_weights + decoder_weights + reward_weights + dynamics_weights
        )

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    # PlaNET Model Loss
    latent = model.encoder(obs)
    post, prior = model.dynamics.observe(latent, action)
    features = model.dynamics.get_feature(post)
    image_pred = model.decoder(features)
    reward_pred = model.reward(features)
    image_loss = -torch.mean(image_pred.log_prob(obs))
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    prior_dist = model.dynamics.get_dist(prior[0], prior[1])
    post_dist = model.dynamics.get_dist(post[0], post[1])
        latent = self.model.encoder(train_batch["obs"])
        post, prior = self.model.dynamics.observe(latent, train_batch["actions"])
        features = self.model.dynamics.get_feature(post)
        image_pred = self.model.decoder(features)
        reward_pred = self.model.reward(features)
        image_loss = -torch.mean(image_pred.log_prob(train_batch["obs"]))
        reward_loss = -torch.mean(reward_pred.log_prob(train_batch["rewards"]))
        prior_dist = self.model.dynamics.get_dist(prior[0], prior[1])
        post_dist = self.model.dynamics.get_dist(post[0], post[1])
        div = torch.mean(
            torch.distributions.kl_divergence(post_dist, prior_dist).sum(dim=2)
        )
    div = torch.clamp(div, min=free_nats)
    model_loss = kl_coeff * div + reward_loss + image_loss
        div = torch.clamp(div, min=(self.config["free_nats"]))
        model_loss = self.config["kl_coeff"] * div + reward_loss + image_loss
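Spelled out in our notation (none of these symbols appear in the code), the world-model loss assembled in this block, with q the posterior from dynamics.observe, p the prior, and the KL floored at free_nats by the clamp, corresponds to:

\mathcal{L}_{\text{model}}
  = \texttt{kl\_coeff} \cdot \max\!\Big(\mathbb{E}\big[\mathrm{KL}\big(q(s_t \mid o_{\le t}, a_{<t}) \,\big\|\, p(s_t \mid s_{t-1}, a_{t-1})\big)\big],\ \texttt{free\_nats}\Big)
  \;-\; \mathbb{E}\big[\log p(o_t \mid s_t)\big]
  \;-\; \mathbb{E}\big[\log p(r_t \mid s_t)\big]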

        # Actor Loss
        # [imagine_horizon, batch_length*batch_size, feature_size]
        with torch.no_grad():
            actor_states = [v.detach() for v in post]
        with FreezeParameters(model_weights):
        imag_feat = model.imagine_ahead(actor_states, imagine_horizon)
            imag_feat = self.model.imagine_ahead(
                actor_states, self.config["imagine_horizon"]
            )
        with FreezeParameters(model_weights + critic_weights):
        reward = model.reward(imag_feat).mean
        value = model.value(imag_feat).mean
    pcont = gamma * torch.ones_like(reward)
    returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
            reward = self.model.reward(imag_feat).mean
            value = self.model.value(imag_feat).mean
        pcont = self.config["gamma"] * torch.ones_like(reward)

        # Similar to GAE-Lambda, calculate value targets
        next_values = torch.cat([value[:-1][1:], value[-1][None]], dim=0)
        inputs = reward[:-1] + pcont[:-1] * next_values * (1 - self.config["lambda"])

        def agg_fn(x, y):
            return y[0] + y[1] * self.config["lambda"] * x

        last = value[-1]
        returns = []
        for i in reversed(range(len(inputs))):
            last = agg_fn(last, [inputs[i], pcont[:-1][i]])
            returns.append(last)

        returns = list(reversed(returns))
        returns = torch.stack(returns, dim=0)
        discount_shape = pcont[:1].size()
        discount = torch.cumprod(
            torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0), dim=0
            torch.cat([torch.ones(*discount_shape).to(device), pcont[:-2]], dim=0),
            dim=0,
        )
        actor_loss = -torch.mean(discount * returns)


@@ -98,17 +123,21 @@ def compute_dreamer_loss(
        val_feat = imag_feat.detach()[:-1]
        target = returns.detach()
        val_discount = discount.detach()
    val_pred = model.value(val_feat)
        val_pred = self.model.value(val_feat)
        critic_loss = -torch.mean(val_discount * val_pred.log_prob(target))

        # Logging purposes
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())

    log_gif = None
    if log:
        log_gif = log_summary(obs, action, latent, image_pred, model)

        gif = None
        if log_gif:
            gif = log_summary(
                train_batch["obs"],
                train_batch["actions"],
                latent,
                image_pred,
                self.model,
            )
        return_dict = {
            "model_loss": model_loss,
            "reward_loss": reward_loss,
@@ -119,140 +148,27 @@ def compute_dreamer_loss(
            "prior_ent": prior_ent,
            "post_ent": post_ent,
        }
        if gif is not None:
            return_dict["log_gif"] = gif
        self.stats_dict = return_dict

    if log_gif is not None:
        return_dict["log_gif"] = log_gif
    return return_dict
        loss_dict = self.stats_dict


# Similar to GAE-Lambda, calculate value targets
def lambda_return(reward, value, pcont, bootstrap, lambda_):
    def agg_fn(x, y):
        return y[0] + y[1] * lambda_ * x

    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    inputs = reward + pcont * next_values * (1 - lambda_)

    last = bootstrap
    returns = []
    for i in reversed(range(len(inputs))):
        last = agg_fn(last, [inputs[i], pcont[i]])
        returns.append(last)

    returns = list(reversed(returns))
    returns = torch.stack(returns, dim=0)
    return returns
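Both the inlined code above and the removed lambda_return helper implement the same backward recursion for the λ-return (our notation, not from the code; pcont plays the role of the discount and the recursion is seeded with the value of the last imagined state):

R^{\lambda}_t = r_t + \gamma_t \Big[(1 - \lambda)\, V(s_{t+1}) + \lambda\, R^{\lambda}_{t+1}\Big],
\qquad R^{\lambda}_{T} = V(s_{T})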


# Creates gif
def log_summary(obs, action, embed, image_pred, model):
    truth = obs[:6] + 0.5
    recon = image_pred.mean[:6]
    init, _ = model.dynamics.observe(embed[:6, :5], action[:6, :5])
    init = [itm[:, -1] for itm in init]
    prior = model.dynamics.imagine(action[:6, 5:], init)
    openl = model.decoder(model.dynamics.get_feature(prior)).mean

    mod = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (mod - truth + 1.0) / 2.0
    return torch.cat([truth, mod, error], 3)


def dreamer_loss(policy, model, dist_class, train_batch):
    log_gif = False
    if "log_gif" in train_batch:
        log_gif = True

    policy.stats_dict = compute_dreamer_loss(
        train_batch["obs"],
        train_batch["actions"],
        train_batch["rewards"],
        policy.model,
        policy.config["imagine_horizon"],
        policy.config["gamma"],
        policy.config["lambda"],
        policy.config["kl_coeff"],
        policy.config["free_nats"],
        log_gif,
        return (
            loss_dict["model_loss"],
            loss_dict["actor_loss"],
            loss_dict["critic_loss"],
        )

    loss_dict = policy.stats_dict

    return (loss_dict["model_loss"], loss_dict["actor_loss"], loss_dict["critic_loss"])


def build_dreamer_model(policy, obs_space, action_space, config):

    model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        1,
        config["dreamer_model"],
        name="DreamerModel",
        framework="torch",
    )

    policy.model_variables = model.variables()

    return model


def action_sampler_fn(policy, model, input_dict, state, explore, timestep):
    """Action sampler function has two phases. During the prefill phase,
    actions are sampled uniformly [-1, 1]. During training phase, actions
    are evaluated through DreamerPolicy and an additive Gaussian is added
    to incentivize exploration.
    """
    obs = input_dict["obs"]

    # Custom Exploration
    if timestep <= policy.config["prefill_timesteps"]:
        logp = None
        # Random action in space [-1.0, 1.0]
        action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
        state = model.get_initial_state()
    else:
        # Weird RLlib handling: this happens when the env resets
        if len(state[0].size()) == 3:
            # Very hacky, but works on all envs
            state = model.get_initial_state()
        action, logp, state = model.policy(obs, state, explore)
        action = td.Normal(action, policy.config["explore_noise"]).sample()
        action = torch.clamp(action, min=-1.0, max=1.0)

    policy.global_timestep += policy.config["action_repeat"]

    return action, logp, state


def dreamer_stats(policy, train_batch):
    return policy.stats_dict


def dreamer_optimizer_fn(policy, config):
    model = policy.model
    encoder_weights = list(model.encoder.parameters())
    decoder_weights = list(model.decoder.parameters())
    reward_weights = list(model.reward.parameters())
    dynamics_weights = list(model.dynamics.parameters())
    actor_weights = list(model.actor.parameters())
    critic_weights = list(model.value.parameters())
    model_opt = torch.optim.Adam(
        encoder_weights + decoder_weights + reward_weights + dynamics_weights,
        lr=config["td_model_lr"],
    )
    actor_opt = torch.optim.Adam(actor_weights, lr=config["actor_lr"])
    critic_opt = torch.optim.Adam(critic_weights, lr=config["critic_lr"])

    return (model_opt, actor_opt, critic_opt)


def preprocess_episode(
    policy: Policy,
    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None,
    ) -> SampleBatch:
    other_agent_batches: Optional[
        Dict[AgentID, Tuple["Policy", SampleBatch]]
    ] = None,
    episode: Optional["Episode"] = None,
) -> SampleBatch:
        """Batch format should be in the form of (s_t, a_(t-1), r_(t-1))
        When t=0, the reset obs is paired with action and reward of 0.
        """
@@ -280,16 +196,84 @@ def preprocess_episode(
        }
        return SampleBatch(new_batch)
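The method body is collapsed in this view; per the docstring, it pairs each observation o_t with the previous action a_(t-1) and reward r_(t-1), using zeros at t=0. A small stand-alone sketch of that alignment (hypothetical helper on numpy arrays, not the code from this commit):

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch


def align_obs_action_reward(obs, actions, rewards):
    # Shift actions/rewards one step back so row t holds (o_t, a_(t-1), r_(t-1));
    # the first row gets zero action and zero reward, matching the reset case.
    shifted_actions = np.concatenate([np.zeros_like(actions[:1]), actions[:-1]])
    shifted_rewards = np.concatenate([np.zeros_like(rewards[:1]), rewards[:-1]])
    return SampleBatch(
        {"obs": obs, "actions": shifted_actions, "rewards": shifted_rewards}
    )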

    def stats_fn(self, train_batch):
        return self.stats_dict

DreamerTorchPolicy = build_policy_class(
    name="DreamerTorchPolicy",
    @override(TorchPolicyV2)
    def optimizer(self):
        model = self.model
        encoder_weights = list(model.encoder.parameters())
        decoder_weights = list(model.decoder.parameters())
        reward_weights = list(model.reward.parameters())
        dynamics_weights = list(model.dynamics.parameters())
        actor_weights = list(model.actor.parameters())
        critic_weights = list(model.value.parameters())
        model_opt = torch.optim.Adam(
            encoder_weights + decoder_weights + reward_weights + dynamics_weights,
            lr=self.config["td_model_lr"],
        )
        actor_opt = torch.optim.Adam(actor_weights, lr=self.config["actor_lr"])
        critic_opt = torch.optim.Adam(critic_weights, lr=self.config["critic_lr"])

        return (model_opt, actor_opt, critic_opt)

    def action_sampler_fn(policy, model, obs_batch, state_batches, explore, timestep):
        """Action sampler function has two phases. During the prefill phase,
        actions are sampled uniformly [-1, 1]. During training phase, actions
        are evaluated through DreamerPolicy and an additive Gaussian is added
        to incentivize exploration.
        """
        obs = obs_batch["obs"]

        # Custom Exploration
        if timestep <= policy.config["prefill_timesteps"]:
            logp = None
            # Random action in space [-1.0, 1.0]
            action = 2.0 * torch.rand(1, model.action_space.shape[0]) - 1.0
            state_batches = model.get_initial_state()
        else:
            # Weird RLlib handling: this happens when the env resets
            if len(state_batches[0].size()) == 3:
                # Very hacky, but works on all envs
                state_batches = model.get_initial_state()
            action, logp, state_batches = model.policy(obs, state_batches, explore)
            action = td.Normal(action, policy.config["explore_noise"]).sample()
            action = torch.clamp(action, min=-1.0, max=1.0)

        policy.global_timestep += policy.config["action_repeat"]

        return action, logp, state_batches
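In equation form, the two exploration phases described in the docstring are (our notation; σ stands for explore_noise and T_p for prefill_timesteps):

a_t \sim \mathrm{Uniform}(-1,\, 1) \quad \text{for } t \le T_p,
\qquad
a_t = \mathrm{clip}\big(\pi(o_t, s_t) + \epsilon,\, -1,\, 1\big),\ \epsilon \sim \mathcal{N}(0, \sigma^2) \quad \text{for } t > T_p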

    def make_model(self):

        model = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            1,
            self.config["dreamer_model"],
            name="DreamerModel",
            framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.dreamer.dreamer.DEFAULT_CONFIG,
    action_sampler_fn=action_sampler_fn,
    postprocess_fn=preprocess_episode,
    loss_fn=dreamer_loss,
    stats_fn=dreamer_stats,
    make_model=build_dreamer_model,
    optimizer_fn=dreamer_optimizer_fn,
    extra_grad_process_fn=apply_grad_clipping,
        )
)

        self.model_variables = model.variables()

        return model

    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)


# Creates gif
def log_summary(obs, action, embed, image_pred, model):
    truth = obs[:6] + 0.5
    recon = image_pred.mean[:6]
    init, _ = model.dynamics.observe(embed[:6, :5], action[:6, :5])
    init = [itm[:, -1] for itm in init]
    prior = model.dynamics.imagine(action[:6, 5:], init)
    openl = model.decoder(model.dynamics.get_feature(prior)).mean

    mod = torch.cat([recon[:, :5] + 0.5, openl + 0.5], 1)
    error = (mod - truth + 1.0) / 2.0
    return torch.cat([truth, mod, error], 3)

@@ -1012,7 +1012,6 @@ class TorchPolicyV2(Policy):
        if is_overridden(self.action_sampler_fn):
            action_dist = dist_inputs = None
            actions, logp, state_out = self.action_sampler_fn(
                self,
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,