[RLlib] New Offline RL Algorithm: CQL (based on SAC) (#13118)
This commit is contained in:
parent 33089c44e2
commit 42cd414e5b

6 changed files with 417 additions and 6 deletions
rllib/agents/cql/__init__.py (new file, 8 lines)

@@ -0,0 +1,8 @@
from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy

__all__ = [
    "CQL_DEFAULT_CONFIG",
    "CQLTorchPolicy",
    "CQLTrainer",
]
rllib/agents/cql/cql.py (new file, 51 lines)

@@ -0,0 +1,51 @@
"""CQL (derived from SAC)."""
from typing import Optional, Type

from ray.rllib.agents.sac.sac import SACTrainer, \
    DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import merge_dicts

# yapf: disable
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
    SAC_CONFIG, {
        # You should override this to point to an offline dataset.
        "input": "sampler",
        # Number of iterations of behavior-cloning pretraining.
        "bc_iters": 20000,
        # CQL loss temperature.
        "temperature": 1.0,
        # Number of actions to sample for the CQL loss.
        "num_actions": 10,
        # Whether to use the Lagrangian formulation for alpha prime
        # (in the CQL loss).
        "lagrangian": False,
        # Lagrangian threshold.
        "lagrangian_thresh": 5.0,
        # Min-Q weight multiplier.
        "min_q_weight": 5.0,
    })
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["framework"] == "tf":
        raise ValueError("Tensorflow CQL not implemented yet!")


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    if config["framework"] == "torch":
        return CQLTorchPolicy


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTorchPolicy,
    get_policy_class=get_policy_class,
)
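Since "input" defaults to "sampler", training on offline data means pointing it at a dataset. Below is a minimal usage sketch, assuming a directory of RLlib-format offline JSON batches at a hypothetical path; the env name and iteration count are placeholders as well, not part of this commit.

import ray
from ray.rllib.agents.cql import CQLTrainer, CQL_DEFAULT_CONFIG

ray.init()
config = CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"  # Only the torch policy exists so far.
config["env"] = "Pendulum-v0"  # Still needed to infer obs/action spaces.
config["input"] = "/tmp/cql-offline-data"  # Assumed offline JSON directory.

trainer = CQLTrainer(config=config)
for _ in range(3):
    result = trainer.train()
    print(result["info"])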
rllib/agents/cql/cql_torch_policy.py (new file, 301 lines)

@@ -0,0 +1,301 @@
"""
PyTorch policy class used for CQL.
"""
import numpy as np
import gym
import logging
from typing import Dict, List, Tuple, Type, Union

import ray
import ray.experimental.tf_utils
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
    validate_spaces
from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \
    build_sac_model_and_action_dist, optimizer_fn, ComputeTDErrorMixin, \
    TargetNetworkMixin, setup_late_mixins, action_distribution_fn
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
    TrainerConfigDict
from ray.rllib.utils.torch_ops import convert_to_torch_tensor

torch, nn = try_import_torch()
F = nn.functional

logger = logging.getLogger(__name__)

# Returns policy tiled actions and log probabilities for the CQL loss.
def policy_actions_repeat(model, action_dist, obs, num_repeat=1):
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    policy_dist = action_dist(model.get_policy_output(obs_temp), model)
    actions = policy_dist.sample()
    log_p = torch.unsqueeze(policy_dist.logp(actions), -1)
    return actions, log_p.squeeze()

def q_values_repeat(model, obs, actions, twin=False):
    action_shape = actions.shape[0]
    obs_shape = obs.shape[0]
    num_repeat = int(action_shape / obs_shape)
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    # Evaluate the main or the twin Q-network on the tiled observations.
    if not twin:
        preds = model.get_q_values(obs_temp, actions)
    else:
        preds = model.get_twin_q_values(obs_temp, actions)
    preds = preds.view(obs.shape[0], num_repeat, 1)
    return preds

def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    policy.cur_iter += 1
    # For best performance, turn deterministic off.
    deterministic = policy.config["_deterministic_loss"]
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL parameters.
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t = action_dist_t.sample() if not deterministic else \
        action_dist_t.deterministic_sample()
    log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

    # Unlike original SAC, the alpha- and actor losses are computed first.
    # Alpha loss.
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Policy loss (either behavior-cloning loss or SAC loss).
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # Behavior-cloning pretraining phase.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha * log_pis_t - bc_logp).mean()

    # Critic loss (standard SAC critic L2 loss + CQL entropy loss).
    # SAC loss.
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1 = action_dist_tp1.sample() if not deterministic else \
        action_dist_tp1.deterministic_sample()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t = torch.squeeze(twin_q_t, dim=-1)

    # Target Q-network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)
    q_tp1 = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1 = (1.0 - terminals.float()) * q_tp1

    # Compute the RHS of the Bellman equation.
    q_t_target = (
        rewards + (discount**policy.config["n_step"]) * q_tp1).detach()

    # Compute the TD-error (potentially clipped), for the priority replay
    # buffer.
    base_td_error = torch.abs(q_t - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error
    critic_loss = [nn.MSELoss()(q_t, q_t_target)]
    if twin_q:
        critic_loss.append(nn.MSELoss()(twin_q_t, q_t_target))

    # CQL loss (we use the entropy-regularized version of CQL, CQL(H)).
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high))
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    obs, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    next_obs, num_actions)
    curr_logp = curr_logp.view(actions.shape[0], num_actions, 1)
    next_logp = next_logp.view(actions.shape[0], num_actions, 1)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss - q_t.mean() * min_q_weight
    if twin_q:
        min_qf2_loss = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss - twin_q_t.mean() * min_q_weight

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss[0] += min_qf1_loss
    if twin_q:
        critic_loss[1] += min_qf2_loss

    # Save for stats function.
    policy.q_t = q_t
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL stats.
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])

def cql_stats(policy: Policy,
              train_batch: SampleBatch) -> Dict[str, TensorType]:
    sac_dict = stats(policy, train_batch)
    sac_dict["cql_loss"] = torch.mean(torch.stack(policy.cql_loss))
    if policy.config["lagrangian"]:
        sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value
        sac_dict["alpha_prime_value"] = policy.alpha_prime_value
        sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss
    return sac_dict

def cql_optimizer_fn(policy: Policy, config: TrainerConfigDict) -> \
        Tuple[LocalOptimizer]:
    policy.cur_iter = 0
    opt_list = optimizer_fn(policy, config)
    if config["lagrangian"]:
        log_alpha_prime = nn.Parameter(
            torch.zeros(1, requires_grad=True).float())
        policy.model.register_parameter("log_alpha_prime", log_alpha_prime)
        policy.alpha_prime_optim = torch.optim.Adam(
            params=[policy.model.log_alpha_prime],
            lr=config["optimization"]["critic_learning_rate"],
            eps=1e-7,  # To match tf.keras.optimizers.Adam's epsilon default.
        )
        return tuple([policy.actor_optim] + policy.critic_optims +
                     [policy.alpha_optim] + [policy.alpha_prime_optim])
    return opt_list

def cql_setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                          action_space: gym.spaces.Space,
                          config: TrainerConfigDict) -> None:
    setup_late_mixins(policy, obs_space, action_space, config)
    if config["lagrangian"]:
        policy.model.log_alpha_prime = policy.model.log_alpha_prime.to(
            policy.device)


# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
CQLTorchPolicy = build_policy_class(
    name="CQLTorchPolicy",
    framework="torch",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=cql_optimizer_fn,
    validate_spaces=validate_spaces,
    before_loss_init=cql_setup_late_mixins,
    make_model_and_action_dist=build_sac_model_and_action_dist,
    mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=action_distribution_fn,
)
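For intuition, the conservative term added to each critic loss above is min_q_weight * (temperature * logsumexp(cat_q / temperature) - Q(s, a_data)), where cat_q stacks Q-values of random, current-policy, and next-state-policy actions corrected by their (log) sampling densities. A minimal standalone sketch of that penalty on dummy tensors (all names below are illustrative only, not part of the commit):

import torch

def conservative_penalty(cat_q, q_data, temperature=1.0, min_q_weight=5.0):
    # cat_q: [batch, num_sampled] density-corrected Q-values of sampled
    # actions; q_data: [batch] Q-values of the dataset actions.
    push_down = torch.logsumexp(
        cat_q / temperature, dim=1).mean() * temperature * min_q_weight
    push_up = q_data.mean() * min_q_weight
    return push_down - push_up

penalty = conservative_penalty(torch.randn(256, 30), torch.randn(256))
print(penalty)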
@@ -40,6 +40,11 @@ def _import_bc():
    return marwil.BCTrainer


def _import_cql():
    from ray.rllib.agents import cql
    return cql.CQLTrainer


def _import_ddpg():
    from ray.rllib.agents import ddpg
    return ddpg.DDPGTrainer


@@ -128,6 +133,7 @@ ALGORITHMS = {
    "APPO": _import_appo,
    "ARS": _import_ars,
    "BC": _import_bc,
    "CQL": _import_cql,
    "ES": _import_es,
    "DDPG": _import_ddpg,
    "DDPPO": _import_ddppo,
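With the registry entry above, the new trainer can also be launched by its string name through Tune. A minimal sketch (env and stopping criterion are placeholders, not from the commit):

from ray import tune

tune.run(
    "CQL",
    stop={"training_iteration": 10},
    config={"env": "Pendulum-v0", "framework": "torch"},
)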
@@ -441,12 +441,12 @@ class TargetNetworkMixin:
        model_state_dict = self.model.state_dict()
        # Support partial (soft) synching.
        # If tau == 1.0: Full sync from Q-model to target Q-model.
        if tau != 1.0:
            target_state_dict = self.target_model.state_dict()
            model_state_dict = {
                k: tau * model_state_dict[k] + (1 - tau) * v
                for k, v in target_state_dict.items()
            }

        self.target_model.load_state_dict(model_state_dict)
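The update guarded above is the usual Polyak soft sync, roughly target <- tau * online + (1 - tau) * target per parameter, with tau == 1.0 reducing to a plain copy of the state dict. A standalone sketch of the same rule in plain PyTorch (not RLlib API):

import torch

def soft_update(target_params, online_params, tau):
    # Polyak averaging; tau == 1.0 degenerates to a hard copy.
    with torch.no_grad():
        for t, o in zip(target_params, online_params):
            t.mul_(1.0 - tau).add_(tau * o)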
rllib/tuned_examples/cql/halfcheetah-cql.yaml (new file, 45 lines)

@@ -0,0 +1,45 @@
halfcheetah_cql:
    env: HalfCheetah-v3
    run: CQL
    stop:
        episode_reward_mean: 9000
    config:
        # SAC Configs
        framework: torch
        horizon: 1000
        soft_horizon: false
        Q_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
        policy_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
        tau: 0.005
        target_entropy: auto
        no_done_at_end: false
        n_step: 1
        rollout_fragment_length: 1
        prioritized_replay: true
        train_batch_size: 256
        target_network_update_freq: 1
        timesteps_per_iteration: 1000
        learning_starts: 10000
        optimization:
            actor_learning_rate: 0.0003
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0003
        num_workers: 0
        num_gpus: 0
        clip_actions: false
        normalize_actions: true
        evaluation_interval: 1
        metrics_smoothing_episodes: 5
        # CQL Configs
        min_q_weight: 5.0
        bc_iters: 20000
        temperature: 1.0
        num_actions: 10
        lagrangian: False
        evaluation_config:
            input: sampler
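A tuned example like this is normally launched with the `rllib train -f <yaml>` CLI or loaded from Python; a sketch of the latter, assuming the yaml path from this commit is available locally and MuJoCo is installed for HalfCheetah-v3:

import yaml
from ray import tune

with open("rllib/tuned_examples/cql/halfcheetah-cql.yaml") as f:
    experiments = yaml.safe_load(f)

exp = experiments["halfcheetah_cql"]
tune.run(
    exp["run"],
    config=dict(exp["config"], env=exp["env"]),
    stop=exp["stop"],
)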