import logging

import gym
from typing import Dict, Tuple, List, Optional, Any, Type

import ray
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
    postprocess_nstep_and_prio,
    PRIO_WEIGHTS,
)
from ray.rllib.evaluation import Episode
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (
    TorchDeterministic,
    TorchDirichlet,
    TorchDistributionWrapper,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_utils import (
    apply_grad_clipping,
    concat_multi_gpu_td_errors,
    huber_loss,
    l2_loss,
)
from ray.rllib.utils.typing import (
    ModelGradients,
    TensorType,
    AlgorithmConfigDict,
)
from ray.rllib.algorithms.ddpg.utils import make_ddpg_models, validate_spaces

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class ComputeTDErrorMixin:
    def __init__(self: TorchPolicyV2):
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            input_dict = self._lazy_tensor_dict(
                SampleBatch(
                    {
                        SampleBatch.CUR_OBS: obs_t,
                        SampleBatch.ACTIONS: act_t,
                        SampleBatch.REWARDS: rew_t,
                        SampleBatch.NEXT_OBS: obs_tp1,
                        SampleBatch.DONES: done_mask,
                        PRIO_WEIGHTS: importance_weights,
                    }
                )
            )
            # Do forward pass on loss to update the TD-error attribute
            # (one TD-error value per item in batch to update PR weights).
            self.loss(self.model, None, input_dict)

            # `model.tower_stats["td_error"]` is set within the `loss()` call above.
            return self.model.tower_stats["td_error"]

        self.compute_td_error = compute_td_error


class TargetNetworkMixin:
    """Mixin class adding a method for (soft) target net(s) synchronizations.

    - Adds the `update_target` method to the policy.
      Calling `update_target` updates all target Q-networks' weights from
      their respective "main" Q-networks, based on tau (smooth, partial
      updating).
    """

    def __init__(self):
        # Hard initial update from Q-net(s) to target Q-net(s).
        self.update_target(tau=1.0)

    def update_target(self: TorchPolicyV2, tau=None):
        # Update_target_fn will be called periodically to copy Q network to
        # target Q network, using (soft) tau-syncing.
        tau = tau or self.config.get("tau")
        model_state_dict = self.model.state_dict()
        # Support partial (soft) syncing.
        # If tau == 1.0: Full sync from Q-model to target Q-model.
        target_state_dict = next(iter(self.target_models.values())).state_dict()
        model_state_dict = {
            k: tau * model_state_dict[k] + (1 - tau) * v
            for k, v in target_state_dict.items()
        }

        for target in self.target_models.values():
            target.load_state_dict(model_state_dict)

    @override(TorchPolicyV2)
    def set_weights(self: TorchPolicyV2, weights):
        # Make sure that, whenever we restore weights for this policy's
        # model, we sync the target network (from the main model)
        # at the same time.
        TorchPolicyV2.set_weights(self, weights)
        self.update_target()


class DDPGTorchPolicy(TargetNetworkMixin, ComputeTDErrorMixin, TorchPolicyV2):
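    """PyTorch policy class used with the DDPG algorithm."""
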
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
    ):
        config = dict(ray.rllib.algorithms.ddpg.ddpg.DDPGConfig().to_dict(), **config)

        # Create global step for counting the number of update operations.
        self.global_step = 0

        # Validate action space for DDPG.
        validate_spaces(self, observation_space, action_space)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ComputeTDErrorMixin.__init__(self)

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

        TargetNetworkMixin.__init__(self)

    @override(TorchPolicyV2)
    def make_model_and_action_dist(
        self,
    ) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
        model = make_ddpg_models(self)

        if isinstance(self.action_space, Simplex):
            distr_class = TorchDirichlet
        else:
            distr_class = TorchDeterministic
        return model, distr_class

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> List["torch.optim.Optimizer"]:
        """Create separate optimizers for actor & critic losses."""
        # Set epsilons to match tf.keras.optimizers.Adam's epsilon default.
        self._actor_optimizer = torch.optim.Adam(
            params=self.model.policy_variables(), lr=self.config["actor_lr"], eps=1e-7
        )

        self._critic_optimizer = torch.optim.Adam(
            params=self.model.q_variables(), lr=self.config["critic_lr"], eps=1e-7
        )

        # Return them in the same order as the respective loss terms are returned.
        return [self._actor_optimizer, self._critic_optimizer]

    @override(TorchPolicyV2)
    def apply_gradients(self, gradients: ModelGradients) -> None:
        # For policy gradient, update the policy net once vs. updating the
        # critic net every `policy_delay` time(s).
        if self.global_step % self.config["policy_delay"] == 0:
            self._actor_optimizer.step()

        self._critic_optimizer.step()

        # Increment global step & apply ops.
        self.global_step += 1

    @override(TorchPolicyV2)
    def action_distribution_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        is_training: bool = False,
        **kwargs
    ) -> Tuple[TensorType, type, List[TensorType]]:
        model_out, _ = model(
            SampleBatch(obs=obs_batch[SampleBatch.CUR_OBS], _is_training=is_training)
        )
        dist_inputs = model.get_policy_output(model_out)

        if isinstance(self.action_space, Simplex):
            distr_class = TorchDirichlet
        else:
            distr_class = TorchDeterministic
        return dist_inputs, distr_class, []  # []=state out

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional[Episode] = None,
    ) -> SampleBatch:
        return postprocess_nstep_and_prio(
            self, sample_batch, other_agent_batches, episode
        )

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> List[TensorType]:
        target_model = self.target_models[model]

        twin_q = self.config["twin_q"]
        gamma = self.config["gamma"]
        n_step = self.config["n_step"]
        use_huber = self.config["use_huber"]
        huber_threshold = self.config["huber_threshold"]
        l2_reg = self.config["l2_reg"]

        input_dict = SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS], _is_training=True
        )
        input_dict_next = SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True
        )

        model_out_t, _ = model(input_dict, [], None)
        model_out_tp1, _ = model(input_dict_next, [], None)
        target_model_out_tp1, _ = target_model(input_dict_next, [], None)

        # Policy network evaluation.
        # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
        policy_t = model.get_policy_output(model_out_t)
        # policy_batchnorm_update_ops = list(
        #     set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

        policy_tp1 = target_model.get_policy_output(target_model_out_tp1)

        # Action outputs.
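        # If `smooth_target_policy` is enabled (the TD3-style "target policy
        # smoothing" trick), perturb the target policy's actions with clipped
        # Gaussian noise, then clip the result back to the action-space bounds.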
if self.config["smooth_target_policy"]: target_noise_clip = self.config["target_noise_clip"] clipped_normal_sample = torch.clamp( torch.normal( mean=torch.zeros(policy_tp1.size()), std=self.config["target_noise"] ).to(policy_tp1.device), -target_noise_clip, target_noise_clip, ) policy_tp1_smoothed = torch.min( torch.max( policy_tp1 + clipped_normal_sample, torch.tensor( self.action_space.low, dtype=torch.float32, device=policy_tp1.device, ), ), torch.tensor( self.action_space.high, dtype=torch.float32, device=policy_tp1.device, ), ) else: # No smoothing, just use deterministic actions. policy_tp1_smoothed = policy_tp1 # Q-net(s) evaluation. # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) # Q-values for given actions & observations in given current q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy (no noise) in given current state q_t_det_policy = model.get_q_values(model_out_t, policy_t) actor_loss = -torch.mean(q_t_det_policy) if twin_q: twin_q_t = model.get_twin_q_values( model_out_t, train_batch[SampleBatch.ACTIONS] ) # q_batchnorm_update_ops = list( # set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) # Target q-net(s) evaluation. q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed) if twin_q: twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, policy_tp1_smoothed ) q_t_selected = torch.squeeze(q_t, axis=len(q_t.shape) - 1) if twin_q: twin_q_t_selected = torch.squeeze(twin_q_t, axis=len(q_t.shape) - 1) q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1_best = torch.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best # Compute RHS of bellman equation. q_t_selected_target = ( train_batch[SampleBatch.REWARDS] + gamma ** n_step * q_tp1_best_masked ).detach() # Compute the error (potentially clipped). if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) + huber_loss( twin_td_error, huber_threshold ) else: errors = 0.5 * ( torch.pow(td_error, 2.0) + torch.pow(twin_td_error, 2.0) ) else: td_error = q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) else: errors = 0.5 * torch.pow(td_error, 2.0) critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors) # Add l2-regularization if required. if l2_reg is not None: for name, var in model.policy_variables(as_dict=True).items(): if "bias" not in name: actor_loss += l2_reg * l2_loss(var) for name, var in model.q_variables(as_dict=True).items(): if "bias" not in name: critic_loss += l2_reg * l2_loss(var) # Model self-supervised losses. if self.config["use_state_preprocessor"]: # Expand input_dict in case custom_loss' need them. input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS] input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict ) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. 
model.tower_stats["q_t"] = q_t model.tower_stats["actor_loss"] = actor_loss model.tower_stats["critic_loss"] = critic_loss # TD-error tensor in final stats # will be concatenated and retrieved for each individual batch item. model.tower_stats["td_error"] = td_error # Return two loss terms (corresponding to the two optimizers, we create). return [actor_loss, critic_loss] @override(TorchPolicyV2) def extra_grad_process( self, optimizer: torch.optim.Optimizer, loss: TensorType ) -> Dict[str, TensorType]: # Clip grads if configured. return apply_grad_clipping(self, optimizer, loss) @override(TorchPolicyV2) def extra_compute_grad_fetches(self) -> Dict[str, Any]: fetches = convert_to_numpy(concat_multi_gpu_td_errors(self)) return dict({LEARNER_STATS_KEY: {}}, **fetches) @override(TorchPolicyV2) def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: q_t = torch.stack(self.get_tower_stats("q_t")) stats = { "actor_loss": torch.mean(torch.stack(self.get_tower_stats("actor_loss"))), "critic_loss": torch.mean(torch.stack(self.get_tower_stats("critic_loss"))), "mean_q": torch.mean(q_t), "max_q": torch.max(q_t), "min_q": torch.min(q_t), } return convert_to_numpy(stats)