""" TensorFlow policy class used for CQL. """ from functools import partial import numpy as np import gym import logging from typing import Dict, List, Type, Union import ray import ray.experimental.tf_utils from ray.rllib.agents.sac.sac_tf_policy import ( apply_gradients as sac_apply_gradients, compute_and_clip_gradients as sac_compute_and_clip_gradients, get_distribution_inputs_and_class, _get_dist_class, build_sac_model, postprocess_trajectory, setup_late_mixins, stats, validate_spaces, ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, ComputeTDErrorMixin, TargetNetworkMixin, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp from ray.rllib.utils.typing import ( LocalOptimizer, ModelGradients, TensorType, TrainerConfigDict, ) tf1, tf, tfv = try_import_tf() tfp = try_import_tfp() logger = logging.getLogger(__name__) MEAN_MIN = -9.0 MEAN_MAX = 9.0 # Returns policy tiled actions and log probabilities for CQL Loss def policy_actions_repeat(model, action_dist, obs, num_repeat=1): obs_temp = tf.reshape( tf.tile(tf.expand_dims(obs, 1), [1, num_repeat, 1]), [-1, obs.shape[1]] ) logits = model.get_policy_output(obs_temp) policy_dist = action_dist(logits, model) actions, logp_ = policy_dist.sample_logp() logp = tf.expand_dims(logp_, -1) return actions, tf.reshape(logp, [tf.shape(obs)[0], num_repeat, 1]) def q_values_repeat(model, obs, actions, twin=False): action_shape = tf.shape(actions)[0] obs_shape = tf.shape(obs)[0] num_repeat = action_shape // obs_shape obs_temp = tf.reshape( tf.tile(tf.expand_dims(obs, 1), [1, num_repeat, 1]), [-1, tf.shape(obs)[1]] ) if not twin: preds_ = model.get_q_values(obs_temp, actions) else: preds_ = model.get_twin_q_values(obs_temp, actions) preds = tf.reshape(preds_, [tf.shape(obs)[0], num_repeat, 1]) return preds def cql_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: logger.info(f"Current iteration = {policy.cur_iter}") policy.cur_iter += 1 # For best performance, turn deterministic off deterministic = policy.config["_deterministic_loss"] assert not deterministic twin_q = policy.config["twin_q"] discount = policy.config["gamma"] # CQL Parameters bc_iters = policy.config["bc_iters"] cql_temp = policy.config["temperature"] num_actions = policy.config["num_actions"] min_q_weight = policy.config["min_q_weight"] use_lagrange = policy.config["lagrangian"] target_action_gap = policy.config["lagrangian_thresh"] obs = train_batch[SampleBatch.CUR_OBS] actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) next_obs = train_batch[SampleBatch.NEXT_OBS] terminals = train_batch[SampleBatch.DONES] model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) target_model_out_tp1, _ = policy.target_model( SampleBatch(obs=next_obs, _is_training=True), [], None ) action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class(model.get_policy_output(model_out_t), model) policy_t, log_pis_t = 


def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)
    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None)
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None
    )

    action_dist_class = _get_dist_class(policy, policy.config, policy.action_space)
    action_dist_t = action_dist_class(model.get_policy_output(model_out_t), model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)
    )

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(tf.stop_gradient(alpha) * log_pis_t - bc_logp)
        # actor_loss = -tf.reduce_mean(bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(model.get_policy_output(model_out_tp1), model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1
        )
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = tf.stop_gradient(
        rewards + (discount ** policy.config["n_step"]) * q_tp1_best_masked
    )

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=action_dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model
        ),
        timestep=0,
        explore=True,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions
    )
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions
    )

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5 ** int(curr_actions.shape[-1]))
    cat_q1 = tf.concat(
        [
            q1_rand - random_density,
            q1_next_actions - tf.stop_gradient(next_logp),
            q1_curr_actions - tf.stop_gradient(curr_logp),
        ],
        1,
    )
    if twin_q:
        cat_q2 = tf.concat(
            [
                q2_rand - random_density,
                q2_next_actions - tf.stop_gradient(next_logp),
                q2_curr_actions - tf.stop_gradient(curr_logp),
            ],
            1,
        )

    min_qf1_loss_ = (
        tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1))
        * min_q_weight
        * cql_temp
    )
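
    # Note (CQL(H) regularizer, cf. Kumar et al., 2020): per Q-network, the
    # penalty assembled here and below is
    #   min_q_weight * (cql_temp * E_s[logsumexp_a(Q(s, a) / cql_temp)]
    #                   - E_{(s, a)~D}[Q(s, a)]),
    # with the logsumexp over actions estimated from the random, current-policy,
    # and next-state-policy action samples above (each importance-corrected by
    # its log-density).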
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1))
            * min_q_weight
            * cql_temp
        )
        min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight)

    if use_lagrange:
        alpha_prime = tf.clip_by_value(
            tf.math.exp(model.log_alpha_prime), 0.0, 1000000.0
        )
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss


def cql_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    sac_dict = stats(policy, train_batch)
    sac_dict["cql_loss"] = tf.reduce_mean(tf.stack(policy.cql_loss))
    if policy.config["lagrangian"]:
        sac_dict["log_alpha_prime_value"] = policy.log_alpha_prime_value
        sac_dict["alpha_prime_value"] = policy.alpha_prime_value
        sac_dict["alpha_prime_loss"] = policy.alpha_prime_loss
    return sac_dict


class ActorCriticOptimizerMixin(SACActorCriticOptimizerMixin):
    def __init__(self, config):
        super().__init__(config)
        if config["lagrangian"]:
            # Eager mode.
            if config["framework"] in ["tf2", "tfe"]:
                self._alpha_prime_optimizer = tf.keras.optimizers.Adam(
                    learning_rate=config["optimization"]["critic_learning_rate"]
                )
            # Static graph mode.
            else:
                self._alpha_prime_optimizer = tf1.train.AdamOptimizer(
                    learning_rate=config["optimization"]["critic_learning_rate"]
                )


def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    policy.cur_iter = 0
    ActorCriticOptimizerMixin.__init__(policy, config)
    if config["lagrangian"]:
        # Scalar dual variable (log of alpha_prime) for the Lagrangian version.
        policy.model.log_alpha_prime = get_variable(
            0.0, framework="tf", trainable=True, tf_name="log_alpha_prime"
        )
        policy.alpha_prime_optim = tf.keras.optimizers.Adam(
            learning_rate=config["optimization"]["critic_learning_rate"],
        )
    # Generic random action generator for calculating CQL-loss.
    policy._random_action_generator = Random(
        action_space,
        model=None,
        framework="tf2",
        policy_config=config,
        num_workers=0,
        worker_index=0,
    )


def compute_gradients_fn(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    grads_and_vars = sac_compute_and_clip_gradients(policy, optimizer, loss)

    if policy.config["lagrangian"]:
        # Eager: Use GradientTape (which is a property of the `optimizer`
        # object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
        if policy.config["framework"] in ["tf2", "tfe"]:
            tape = optimizer.tape
            log_alpha_prime = [policy.model.log_alpha_prime]
            alpha_prime_grads_and_vars = list(
                zip(
                    tape.gradient(policy.alpha_prime_loss, log_alpha_prime),
                    log_alpha_prime,
                )
            )
        # Tf1.x: Use optimizer.compute_gradients()
        else:
            alpha_prime_grads_and_vars = (
                policy._alpha_prime_optimizer.compute_gradients(
                    policy.alpha_prime_loss, var_list=[policy.model.log_alpha_prime]
                )
            )

        # Clip if necessary.
        if policy.config["grad_clip"]:
            clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
        else:
            clip_func = tf.identity

        # Save grads and vars for later use in `build_apply_op`.
        policy._alpha_prime_grads_and_vars = [
            (clip_func(g), v) for (g, v) in alpha_prime_grads_and_vars if g is not None
        ]

        grads_and_vars += policy._alpha_prime_grads_and_vars
    return grads_and_vars


def apply_gradients_fn(policy, optimizer, grads_and_vars):
    sac_results = sac_apply_gradients(policy, optimizer, grads_and_vars)

    if policy.config["lagrangian"]:
        # Eager mode -> Just apply and return None.
        if policy.config["framework"] in ["tf2", "tfe"]:
            policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars
            )
            return
        # Tf static graph -> Return grouped op.
        else:
            alpha_prime_apply_op = policy._alpha_prime_optimizer.apply_gradients(
                policy._alpha_prime_grads_and_vars,
                global_step=tf1.train.get_or_create_global_step(),
            )
            return tf.group([sac_results, alpha_prime_apply_op])
    return sac_results


# Build a child class of `TFPolicy`, given the custom functions defined above.
CQLTFPolicy = build_tf_policy(
    name="CQLTFPolicy",
    loss_fn=cql_loss,
    get_default_config=lambda: ray.rllib.agents.cql.cql.CQL_DEFAULT_CONFIG,
    validate_spaces=validate_spaces,
    stats_fn=cql_stats,
    postprocess_fn=postprocess_trajectory,
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    make_model=build_sac_model,
    mixins=[ActorCriticOptimizerMixin, TargetNetworkMixin, ComputeTDErrorMixin],
    action_distribution_fn=get_distribution_inputs_and_class,
    compute_gradients_fn=compute_gradients_fn,
    apply_gradients_fn=apply_gradients_fn,
)
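
# Usage sketch (not part of the policy definition): `CQLTFPolicy` is normally
# built by RLlib's CQL Trainer rather than instantiated directly. Assuming the
# `ray.rllib.agents.cql` API of this Ray version and an offline dataset on
# disk, a minimal TF setup could look like:
#
#     from ray.rllib.agents.cql import CQLTrainer, CQL_DEFAULT_CONFIG
#
#     config = CQL_DEFAULT_CONFIG.copy()
#     config["framework"] = "tf"
#     config["input"] = "/path/to/offline/dataset"  # hypothetical path
#     trainer = CQLTrainer(config=config, env="Pendulum-v1")  # or "Pendulum-v0"
#     print(trainer.train())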