# ray/rllib/algorithms/cql/cql_torch_policy.py
"""
PyTorch policy class used for CQL.
"""
import numpy as np
import gym
import logging
import tree
from typing import Dict, List, Tuple, Type, Union
import ray
import ray.experimental.tf_utils
from ray.rllib.algorithms.sac.sac_tf_policy import (
postprocess_trajectory,
validate_spaces,
)
from ray.rllib.algorithms.sac.sac_torch_policy import (
_get_dist_class,
stats,
build_sac_model_and_action_dist,
optimizer_fn,
ComputeTDErrorMixin,
TargetNetworkMixin,
setup_late_mixins,
action_distribution_fn,
)
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict
from ray.rllib.utils.torch_utils import (
apply_grad_clipping,
convert_to_torch_tensor,
concat_multi_gpu_td_errors,
)
torch, nn = try_import_torch()
F = nn.functional
logger = logging.getLogger(__name__)
MEAN_MIN = -9.0
MEAN_MAX = 9.0
def _repeat_tensor(t: TensorType, n: int):
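    """Repeats each row of tensor `t` `n` times along a new dim 1.

    Shapes: (B, ...) -> (B * n, ...), with the n copies of each row laid
    out contiguously in the batch dimension.
    """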
    # Insert a new dimension at position 1 into tensor t.
t_rep = t.unsqueeze(1)
# Repeat tensor t_rep along new dimension n times
t_rep = torch.repeat_interleave(t_rep, n, dim=1)
# Merge new dimension into batch dimension
t_rep = t_rep.view(-1, *t.shape[1:])
return t_rep
# Returns policy tiled actions and log probabilities for CQL Loss
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
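    """Samples `num_repeat` actions (and log-probs) per observation from the policy.

    Returns:
        Tuple of sampled actions (shape [B * num_repeat, act_dim]) and their
        log-probabilities (shape [B, num_repeat, 1]).
    """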
batch_size = tree.flatten(obs)[0].shape[0]
obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs)
logits, _ = model.get_action_model_outputs(obs_temp)
policy_dist = action_dist(logits, model)
actions, logp_ = policy_dist.sample_logp()
logp = logp_.unsqueeze(-1)
return actions, logp.view(batch_size, num_repeat, 1)
def q_values_repeat(model, obs, actions, twin=False):
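    """Computes (twin) Q-values for `actions` against the repeated observations.

    Returns:
        Q-value tensor of shape [B, num_repeat, 1], where num_repeat is
        inferred from the ratio of action- to observation-batch sizes.
    """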
action_shape = actions.shape[0]
obs_shape = tree.flatten(obs)[0].shape[0]
num_repeat = int(action_shape / obs_shape)
obs_temp = tree.map_structure(lambda t: _repeat_tensor(t, num_repeat), obs)
if not twin:
preds_, _ = model.get_q_values(obs_temp, actions)
else:
preds_, _ = model.get_twin_q_values(obs_temp, actions)
preds = preds_.view(obs_shape, num_repeat, 1)
return preds
def cql_loss(
policy: Policy,
model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
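    """Constructs the CQL loss terms: actor, critic(s), alpha and (optionally) alpha_prime.

    The SAC losses are computed as usual and the critic loss is then augmented
    with the conservative CQL(H) regularizer. Note that all optimizers are
    stepped manually inside this function.
    """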
logger.info(f"Current iteration = {policy.cur_iter}")
policy.cur_iter += 1
    # Look up the target model (tower) corresponding to this model tower.
target_model = policy.target_models[model]
    # `_deterministic_loss` is meant for debugging/tests only and must be off here.
deterministic = policy.config["_deterministic_loss"]
assert not deterministic
twin_q = policy.config["twin_q"]
discount = policy.config["gamma"]
action_low = model.action_space.low[0]
action_high = model.action_space.high[0]
# CQL Parameters
bc_iters = policy.config["bc_iters"]
cql_temp = policy.config["temperature"]
num_actions = policy.config["num_actions"]
min_q_weight = policy.config["min_q_weight"]
use_lagrange = policy.config["lagrangian"]
target_action_gap = policy.config["lagrangian_thresh"]
obs = train_batch[SampleBatch.CUR_OBS]
actions = train_batch[SampleBatch.ACTIONS]
rewards = train_batch[SampleBatch.REWARDS].float()
next_obs = train_batch[SampleBatch.NEXT_OBS]
terminals = train_batch[SampleBatch.DONES]
model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)
model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)
target_model_out_tp1, _ = target_model(
SampleBatch(obs=next_obs, _is_training=True), [], None
)
action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
action_dist_t = action_dist_class(action_dist_inputs_t, model)
policy_t, log_pis_t = action_dist_t.sample_logp()
log_pis_t = torch.unsqueeze(log_pis_t, -1)
# Unlike original SAC, Alpha and Actor Loss are computed first.
# Alpha Loss
alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean()
batch_size = tree.flatten(obs)[0].shape[0]
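    # Optimizers are stepped manually inside this loss function. Only do so for
    # full-size train batches (this skips, e.g., the initial dummy batch used
    # for loss initialization).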
if batch_size == policy.config["train_batch_size"]:
policy.alpha_optim.zero_grad()
alpha_loss.backward()
policy.alpha_optim.step()
# Policy Loss (Either Behavior Clone Loss or SAC Loss)
alpha = torch.exp(model.log_alpha)
if policy.cur_iter >= bc_iters:
min_q, _ = model.get_q_values(model_out_t, policy_t)
if twin_q:
twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
min_q = torch.min(min_q, twin_q_)
actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
else:
bc_logp = action_dist_t.logp(actions)
actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
# actor_loss = -bc_logp.mean()
if batch_size == policy.config["train_batch_size"]:
policy.actor_optim.zero_grad()
actor_loss.backward(retain_graph=True)
policy.actor_optim.step()
# Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
# SAC Loss:
# Q-values for the batched actions.
action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
policy_tp1, _ = action_dist_tp1.sample_logp()
q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
q_t_selected = torch.squeeze(q_t, dim=-1)
if twin_q:
twin_q_t, _ = model.get_twin_q_values(
model_out_t, train_batch[SampleBatch.ACTIONS]
)
twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
# Target q network evaluation.
q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
if twin_q:
twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1)
# Take min over both twin-NNs.
q_tp1 = torch.min(q_tp1, twin_q_tp1)
q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best
    # Compute the RHS of the Bellman equation.
q_t_target = (
rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
).detach()
# Compute the TD-error (potentially clipped), for priority replay buffer
base_td_error = torch.abs(q_t_selected - q_t_target)
if twin_q:
twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
td_error = 0.5 * (base_td_error + twin_td_error)
else:
td_error = base_td_error
critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
if twin_q:
critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)
    # CQL loss. We use the entropy-regularized variant, CQL(H), which performed best in the paper.
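    # The regularizer compares Q-values of the dataset actions against Q-values
    # of three sampled action sets (num_actions each): uniform random actions,
    # current-policy actions at s_t, and current-policy actions at s_tp1.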
rand_actions = convert_to_torch_tensor(
torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_(
action_low, action_high
),
policy.device,
)
curr_actions, curr_logp = policy_actions_repeat(
model, action_dist_class, model_out_t, num_actions
)
next_actions, next_logp = policy_actions_repeat(
model, action_dist_class, model_out_tp1, num_actions
)
q1_rand = q_values_repeat(model, model_out_t, rand_actions)
q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
q1_next_actions = q_values_repeat(model, model_out_t, next_actions)
if twin_q:
q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)
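    # Importance-sampling correction for the logsumexp below: subtract the
    # log-density of each proposal distribution. `random_density` assumes a
    # uniform distribution over a [-1, 1]-normalized action space (density
    # 0.5 per dimension).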
random_density = np.log(0.5 ** curr_actions.shape[-1])
cat_q1 = torch.cat(
[
q1_rand - random_density,
q1_next_actions - next_logp.detach(),
q1_curr_actions - curr_logp.detach(),
],
1,
)
if twin_q:
cat_q2 = torch.cat(
[
q2_rand - random_density,
q2_next_actions - next_logp.detach(),
q2_curr_actions - curr_logp.detach(),
],
1,
)
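    # Conservative penalty: temperature-scaled logsumexp over the sampled
    # Q-values minus the mean Q-value of the dataset actions, weighted by
    # `min_q_weight`.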
min_qf1_loss_ = (
torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
)
min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
if twin_q:
min_qf2_loss_ = (
torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
)
min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)
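    # With the Lagrangian version, `alpha_prime` is auto-tuned so that the
    # conservative gap stays close to `target_action_gap`.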
if use_lagrange:
alpha_prime = torch.clamp(model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[
0
]
min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
if twin_q:
min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
else:
alpha_prime_loss = -min_qf1_loss
cql_loss = [min_qf1_loss]
if twin_q:
cql_loss.append(min_qf2_loss)
critic_loss = [critic_loss_1 + min_qf1_loss]
if twin_q:
critic_loss.append(critic_loss_2 + min_qf2_loss)
if batch_size == policy.config["train_batch_size"]:
policy.critic_optims[0].zero_grad()
critic_loss[0].backward(retain_graph=True)
policy.critic_optims[0].step()
if twin_q:
policy.critic_optims[1].zero_grad()
critic_loss[1].backward(retain_graph=False)
policy.critic_optims[1].step()
# Store values for stats function in model (tower), such that for
# multi-GPU, we do not override them during the parallel loss phase.
# SAC stats.
model.tower_stats["q_t"] = q_t_selected
model.tower_stats["policy_t"] = policy_t
model.tower_stats["log_pis_t"] = log_pis_t
model.tower_stats["actor_loss"] = actor_loss
model.tower_stats["critic_loss"] = critic_loss
model.tower_stats["alpha_loss"] = alpha_loss
model.tower_stats["log_alpha_value"] = model.log_alpha
model.tower_stats["alpha_value"] = alpha
model.tower_stats["target_entropy"] = model.target_entropy
# CQL stats.
model.tower_stats["cql_loss"] = cql_loss
# TD-error tensor in final stats
# will be concatenated and retrieved for each individual batch item.
model.tower_stats["td_error"] = td_error
if use_lagrange:
model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
model.tower_stats["alpha_prime_value"] = alpha_prime
model.tower_stats["alpha_prime_loss"] = alpha_prime_loss
if batch_size == policy.config["train_batch_size"]:
policy.alpha_prime_optim.zero_grad()
alpha_prime_loss.backward()
policy.alpha_prime_optim.step()
# Return all loss terms corresponding to our optimizers.
return tuple(
[actor_loss]
+ critic_loss
+ [alpha_loss]
+ ([alpha_prime_loss] if use_lagrange else [])
)
def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
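    """Returns the SAC stats dict, extended with CQL-specific stats."""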
# Get SAC loss stats.
stats_dict = stats(policy, train_batch)
# Add CQL loss stats to the dict.
stats_dict["cql_loss"] = torch.mean(
torch.stack(*policy.get_tower_stats("cql_loss"))
)
if policy.config["lagrangian"]:
stats_dict["log_alpha_prime_value"] = torch.mean(
torch.stack(policy.get_tower_stats("log_alpha_prime_value"))
)
stats_dict["alpha_prime_value"] = torch.mean(
torch.stack(policy.get_tower_stats("alpha_prime_value"))
)
stats_dict["alpha_prime_loss"] = torch.mean(
torch.stack(policy.get_tower_stats("alpha_prime_loss"))
)
return stats_dict
def cql_optimizer_fn(
policy: Policy, config: AlgorithmConfigDict
) -> Tuple[LocalOptimizer]:
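    """Creates the SAC optimizers and, if `lagrangian=True`, an extra Adam
    optimizer for the `log_alpha_prime` parameter."""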
policy.cur_iter = 0
opt_list = optimizer_fn(policy, config)
if config["lagrangian"]:
log_alpha_prime = nn.Parameter(torch.zeros(1, requires_grad=True).float())
policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
policy.alpha_prime_optim = torch.optim.Adam(
params=[policy.model.log_alpha_prime],
lr=config["optimization"]["critic_learning_rate"],
eps=1e-7, # to match tf.keras.optimizers.Adam's epsilon default
)
return tuple(
[policy.actor_optim]
+ policy.critic_optims
+ [policy.alpha_optim]
+ [policy.alpha_prime_optim]
)
return opt_list
def cql_setup_late_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: AlgorithmConfigDict,
) -> None:
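    """Runs SAC's late mixin setup and, if `lagrangian=True`, moves
    `log_alpha_prime` onto the policy's device."""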
setup_late_mixins(policy, obs_space, action_space, config)
if config["lagrangian"]:
policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(policy.device)
def compute_gradients_fn(policy, postprocessed_batch):
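    # All optimizer updates already happen inside `cql_loss`, so this only runs
    # the loss for its side effects and returns the stats (no gradients).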
batches = [policy._lazy_tensor_dict(postprocessed_batch)]
model = policy.model
policy._loss(policy, model, policy.dist_class, batches[0])
stats = {LEARNER_STATS_KEY: policy._convert_to_numpy(cql_stats(policy, batches[0]))}
return [None, stats]
def apply_gradients_fn(policy, gradients):
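    # No-op: gradients are already applied inside the loss function.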
return
# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
CQLTorchPolicy = build_policy_class(
name="CQLTorchPolicy",
framework="torch",
loss_fn=cql_loss,
get_default_config=lambda: ray.rllib.algorithms.cql.cql.DEFAULT_CONFIG,
stats_fn=cql_stats,
postprocess_fn=postprocess_trajectory,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=cql_optimizer_fn,
validate_spaces=validate_spaces,
before_loss_init=cql_setup_late_mixins,
make_model_and_action_dist=build_sac_model_and_action_dist,
extra_learn_fetches_fn=concat_multi_gpu_td_errors,
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
action_distribution_fn=action_distribution_fn,
compute_gradients_fn=compute_gradients_fn,
apply_gradients_fn=apply_gradients_fn,
)
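
# Example usage (a minimal sketch, not part of this module): `CQLTorchPolicy`
# is normally not instantiated directly but picked up by the CQL algorithm
# when `framework="torch"`. Assuming the `CQLConfig` API of the same RLlib
# version and a placeholder offline-data path, a training setup could look
# roughly like this:
#
#     from ray.rllib.algorithms.cql import CQLConfig
#
#     config = (
#         CQLConfig()
#         .environment("Pendulum-v1")
#         .framework("torch")
#         .offline_data(input_="/path/to/offline/data")  # placeholder path
#         .training(bc_iters=20000, temperature=1.0, num_actions=10,
#                   lagrangian=False)
#     )
#     algo = config.build()
#     results = algo.train()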