# ray/rllib/algorithms/maml/maml_torch_policy.py

import logging
from typing import Dict, List, Type, Union
import ray
from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config
from ray.rllib.evaluation.postprocessing import (
Postprocessing,
compute_gae_for_sample_batch,
)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import ValueNetworkMixin
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
try:
import higher
except (ImportError, ModuleNotFoundError):
    raise ImportError(
        "The MAML and MB-MPO algorithms require the `higher` module to be "
        "installed! However, no installation was found. You can install it "
        "via `pip install higher`."
    )
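# Combined PPO objective used by both the worker (inner-adaptation) loss and the
# meta-update loss: a (optionally clipped) surrogate term, a KL penalty against
# the behaviour policy, a clipped value-function loss, and an entropy bonus.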
def PPOLoss(
dist_class,
actions,
curr_logits,
behaviour_logits,
advantages,
value_fn,
value_targets,
vf_preds,
cur_kl_coeff,
entropy_coeff,
clip_param,
vf_clip_param,
vf_loss_coeff,
clip_loss=False,
):
def surrogate_loss(
actions, curr_dist, prev_dist, advantages, clip_param, clip_loss
):
pi_new_logp = curr_dist.logp(actions)
pi_old_logp = prev_dist.logp(actions)
logp_ratio = torch.exp(pi_new_logp - pi_old_logp)
if clip_loss:
return torch.min(
advantages * logp_ratio,
advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param),
)
return advantages * logp_ratio
def kl_loss(curr_dist, prev_dist):
return prev_dist.kl(curr_dist)
def entropy_loss(dist):
return dist.entropy()
def vf_loss(value_fn, value_targets, vf_preds, vf_clip_param=0.1):
# GAE Value Function Loss
vf_loss1 = torch.pow(value_fn - value_targets, 2.0)
vf_clipped = vf_preds + torch.clamp(
value_fn - vf_preds, -vf_clip_param, vf_clip_param
)
vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0)
vf_loss = torch.max(vf_loss1, vf_loss2)
return vf_loss
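    # Build the current and behaviour action distributions from their logits and
    # reduce each loss term to its mean over the batch.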
pi_new_dist = dist_class(curr_logits, None)
pi_old_dist = dist_class(behaviour_logits, None)
surr_loss = torch.mean(
surrogate_loss(
actions, pi_new_dist, pi_old_dist, advantages, clip_param, clip_loss
)
)
kl_loss = torch.mean(kl_loss(pi_new_dist, pi_old_dist))
vf_loss = torch.mean(vf_loss(value_fn, value_targets, vf_preds, vf_clip_param))
entropy_loss = torch.mean(entropy_loss(pi_new_dist))
total_loss = -surr_loss + cur_kl_coeff * kl_loss
total_loss += vf_loss_coeff * vf_loss
total_loss -= entropy_coeff * entropy_loss
return total_loss, surr_loss, kl_loss, vf_loss, entropy_loss
# This is the computation graph for workers (inner adaptation steps)
class WorkerLoss(object):
def __init__(
self,
model,
dist_class,
actions,
curr_logits,
behaviour_logits,
advantages,
value_fn,
value_targets,
vf_preds,
cur_kl_coeff,
entropy_coeff,
clip_param,
vf_clip_param,
vf_loss_coeff,
clip_loss=False,
):
self.loss, surr_loss, kl_loss, vf_loss, ent_loss = PPOLoss(
dist_class=dist_class,
actions=actions,
curr_logits=curr_logits,
behaviour_logits=behaviour_logits,
advantages=advantages,
value_fn=value_fn,
value_targets=value_targets,
vf_preds=vf_preds,
cur_kl_coeff=cur_kl_coeff,
entropy_coeff=entropy_coeff,
clip_param=clip_param,
vf_clip_param=vf_clip_param,
vf_loss_coeff=vf_loss_coeff,
clip_loss=clip_loss,
)
# This is the Meta-Update computation graph for main (meta-update step)
class MAMLLoss(object):
def __init__(
self,
model,
config,
dist_class,
value_targets,
advantages,
actions,
behaviour_logits,
vf_preds,
cur_kl_coeff,
policy_vars,
obs,
num_tasks,
split,
meta_opt,
inner_adaptation_steps=1,
entropy_coeff=0,
clip_param=0.3,
vf_clip_param=0.1,
vf_loss_coeff=1.0,
use_gae=True,
):
self.config = config
self.num_tasks = num_tasks
self.inner_adaptation_steps = inner_adaptation_steps
self.clip_param = clip_param
self.dist_class = dist_class
self.cur_kl_coeff = cur_kl_coeff
self.model = model
self.vf_clip_param = vf_clip_param
self.vf_loss_coeff = vf_loss_coeff
self.entropy_coeff = entropy_coeff
# Split episode tensors into [inner_adaptation_steps+1, num_tasks, -1]
self.obs = self.split_placeholders(obs, split)
self.actions = self.split_placeholders(actions, split)
self.behaviour_logits = self.split_placeholders(behaviour_logits, split)
self.advantages = self.split_placeholders(advantages, split)
self.value_targets = self.split_placeholders(value_targets, split)
self.vf_preds = self.split_placeholders(vf_preds, split)
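        # Plain SGD is used for the inner-adaptation steps; `higher` wraps it
        # below so that the adaptation updates stay differentiable and the
        # meta-gradient can flow back through them.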
inner_opt = torch.optim.SGD(model.parameters(), lr=config["inner_lr"])
surr_losses = []
val_losses = []
kl_losses = []
entropy_losses = []
meta_losses = []
kls = []
meta_opt.zero_grad()
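        # For each task: create a functional (stateless) version of the model via
        # `higher.innerloop_ctx`, run `inner_adaptation_steps` differentiable SGD
        # updates on it, then evaluate the post-adaptation PPO loss and accumulate
        # its gradient into the original model parameters.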
for i in range(self.num_tasks):
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (
fnet,
diffopt,
):
inner_kls = []
for step in range(self.inner_adaptation_steps):
ppo_loss, _, inner_kl_loss, _, _ = self.compute_losses(
fnet, step, i
)
diffopt.step(ppo_loss)
inner_kls.append(inner_kl_loss)
kls.append(inner_kl_loss.detach())
# Meta Update
ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses(
fnet, self.inner_adaptation_steps - 1, i, clip_loss=True
)
inner_loss = torch.mean(
torch.stack(
[
a * b
for a, b in zip(
self.cur_kl_coeff[
i
* self.inner_adaptation_steps : (i + 1)
* self.inner_adaptation_steps
],
inner_kls,
)
]
)
)
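                # Per-task meta objective: post-adaptation PPO loss plus the
                # KL-penalized inner-adaptation losses, averaged over tasks.
                # backward() accumulates gradients into the original model
                # parameters because copy_initial_weights=False above.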
meta_loss = (ppo_loss + inner_loss) / self.num_tasks
meta_loss.backward()
surr_losses.append(s_loss.detach())
kl_losses.append(kl_loss.detach())
val_losses.append(v_loss.detach())
entropy_losses.append(ent.detach())
meta_losses.append(meta_loss.detach())
meta_opt.step()
# Stats Logging
self.mean_policy_loss = torch.mean(torch.stack(surr_losses))
self.mean_kl_loss = torch.mean(torch.stack(kl_losses))
self.mean_vf_loss = torch.mean(torch.stack(val_losses))
self.mean_entropy = torch.mean(torch.stack(entropy_losses))
self.mean_inner_kl = kls
self.loss = torch.sum(torch.stack(meta_losses))
# Hacky, needed to bypass RLlib backend
self.loss.requires_grad = True
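    # Runs a forward pass through the (possibly functional) model on the data
    # slice for the given inner-adaptation step and task, and returns the PPO
    # loss terms for that slice.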
def compute_losses(self, model, inner_adapt_iter, task_iter, clip_loss=False):
obs = self.obs[inner_adapt_iter][task_iter]
obs_dict = {"obs": obs, "obs_flat": obs}
curr_logits, _ = model.forward(obs_dict, None, None)
value_fns = model.value_function()
ppo_loss, surr_loss, kl_loss, val_loss, ent_loss = PPOLoss(
dist_class=self.dist_class,
actions=self.actions[inner_adapt_iter][task_iter],
curr_logits=curr_logits,
behaviour_logits=self.behaviour_logits[inner_adapt_iter][task_iter],
advantages=self.advantages[inner_adapt_iter][task_iter],
value_fn=value_fns,
value_targets=self.value_targets[inner_adapt_iter][task_iter],
vf_preds=self.vf_preds[inner_adapt_iter][task_iter],
cur_kl_coeff=0.0,
entropy_coeff=self.entropy_coeff,
clip_param=self.clip_param,
vf_clip_param=self.vf_clip_param,
vf_loss_coeff=self.vf_loss_coeff,
clip_loss=clip_loss,
)
return ppo_loss, surr_loss, kl_loss, val_loss, ent_loss
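    # Splits a flat batch tensor first by inner-adaptation step and then by task,
    # yielding an [inner_adaptation_steps + 1][num_tasks] nested list of tensors.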
def split_placeholders(self, placeholder, split):
inner_placeholder_list = torch.split(
placeholder, torch.sum(split, dim=1).tolist(), dim=0
)
placeholder_list = []
for index, split_placeholder in enumerate(inner_placeholder_list):
placeholder_list.append(
torch.split(split_placeholder, split[index].tolist(), dim=0)
)
return placeholder_list
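# Maintains one KL coefficient per (task, inner-adaptation step) pair and adapts
# each coefficient independently toward `kl_target`, mirroring PPO's adaptive-KL
# scheme.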
class KLCoeffMixin:
def __init__(self, config):
self.kl_coeff_val = (
[config["kl_coeff"]]
* config["inner_adaptation_steps"]
* config["num_workers"]
)
        self.kl_target = config["kl_target"]
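    # Standard adaptive-KL update, applied element-wise: halve a coefficient when
    # its sampled KL is well below the target, double it when well above.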
def update_kls(self, sampled_kls):
for i, kl in enumerate(sampled_kls):
if kl < self.kl_target / 1.5:
self.kl_coeff_val[i] *= 0.5
elif kl > 1.5 * self.kl_target:
self.kl_coeff_val[i] *= 2.0
return self.kl_coeff_val
class MAMLTorchPolicy(ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2):
"""PyTorch policy class used with MAMLTrainer."""
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG, **config)
validate_config(config)
TorchPolicyV2.__init__(
self,
observation_space,
action_space,
config,
max_seq_len=config["model"]["max_seq_len"],
)
KLCoeffMixin.__init__(self, config)
ValueNetworkMixin.__init__(self, config)
# TODO: Don't require users to call this manually.
self._initialize_loss_from_dummy_batch()
@override(TorchPolicyV2)
def loss(
self,
model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
"""Constructs the loss function.
Args:
model: The Model to calculate the loss for.
dist_class: The action distr. class.
train_batch: The training data.
Returns:
The PPO loss tensor given the input batch.
"""
logits, state = model(train_batch)
self.cur_lr = self.config["lr"]
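        # Worker policies (worker_index > 0) optimize the plain PPO loss used for
        # inner adaptation; the local (driver) policy builds the MAML meta-update
        # loss across all tasks instead.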
if self.config["worker_index"]:
self.loss_obj = WorkerLoss(
model=model,
dist_class=dist_class,
actions=train_batch[SampleBatch.ACTIONS],
curr_logits=logits,
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
advantages=train_batch[Postprocessing.ADVANTAGES],
value_fn=model.value_function(),
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=0.0,
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
clip_loss=False,
)
else:
self.var_list = model.named_parameters()
# `split` may not exist yet (during test-loss call), use a dummy value.
# Cannot use get here due to train_batch being a TrackingDict.
if "split" in train_batch:
split = train_batch["split"]
else:
split_shape = (
self.config["inner_adaptation_steps"],
self.config["num_workers"],
)
split_const = int(
train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])
)
split = torch.ones(split_shape, dtype=int) * split_const
self.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,
value_targets=train_batch[Postprocessing.VALUE_TARGETS],
advantages=train_batch[Postprocessing.ADVANTAGES],
actions=train_batch[SampleBatch.ACTIONS],
behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS],
vf_preds=train_batch[SampleBatch.VF_PREDS],
cur_kl_coeff=self.kl_coeff_val,
policy_vars=self.var_list,
obs=train_batch[SampleBatch.CUR_OBS],
num_tasks=self.config["num_workers"],
split=split,
config=self.config,
inner_adaptation_steps=self.config["inner_adaptation_steps"],
entropy_coeff=self.config["entropy_coeff"],
clip_param=self.config["clip_param"],
vf_clip_param=self.config["vf_clip_param"],
vf_loss_coeff=self.config["vf_loss_coeff"],
use_gae=self.config["use_gae"],
meta_opt=self.meta_opt,
)
return self.loss_obj.loss
@override(TorchPolicyV2)
def optimizer(
self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
"""
        Workers use plain SGD for the inner-adaptation steps; the meta-policy
        uses the Adam optimizer for the meta-update.
"""
if not self.config["worker_index"]:
self.meta_opt = torch.optim.Adam(
self.model.parameters(), lr=self.config["lr"]
)
return self.meta_opt
return torch.optim.SGD(self.model.parameters(), lr=self.config["inner_lr"])
@override(TorchPolicyV2)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
if self.config["worker_index"]:
return convert_to_numpy({"worker_loss": self.loss_obj.loss})
else:
return convert_to_numpy(
{
"cur_kl_coeff": self.kl_coeff_val,
"cur_lr": self.cur_lr,
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
"vf_loss": self.loss_obj.mean_vf_loss,
"kl_loss": self.loss_obj.mean_kl_loss,
"inner_kl": self.loss_obj.mean_inner_kl,
"entropy": self.loss_obj.mean_entropy,
}
)
@override(TorchPolicyV2)
def extra_grad_process(
self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
return apply_grad_clipping(self, optimizer, loss)
@override(TorchPolicyV2)
def postprocess_trajectory(
self, sample_batch, other_agent_batches=None, episode=None
):
# Do all post-processing always with no_grad().
# Not using this here will introduce a memory leak
# in torch (issue #6962).
# TODO: no_grad still necessary?
with torch.no_grad():
return compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)
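# Usage sketch (hedged): this policy is normally constructed by the MAML
# algorithm rather than instantiated by hand. Assuming this Ray version exposes
# the builder-style `MAMLConfig` API, training might look like the following;
# "my_meta_env" is a placeholder for a user-registered env that implements
# RLlib's meta-learning (task-sampling) interface:
#
#     from ray.rllib.algorithms.maml import MAMLConfig
#
#     config = (
#         MAMLConfig()
#         .framework("torch")
#         .environment("my_meta_env")
#     )
#     algo = config.build()
#     print(algo.train())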