from gym.spaces import Discrete

import ray
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
    PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()
# `F` stays None when `try_import_torch()` returned a falsy `nn`.
F = None
if nn:
    F = torch.nn.functional


class QLoss:
    """Importance-weighted Huber loss on the n-step TD error.

    After construction the instance exposes:
        td_error: `q_t_selected` minus the (detached) Bellman target.
        loss: Mean of `importance_weights * huber_loss(td_error)`.
        stats: Dict with mean/min/max of `q_t_selected`, the raw
            `td_error` and its mean.
    """

    def __init__(self,
                 q_t_selected,
                 q_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,
                 v_min=-10.0,
                 v_max=10.0):
        """Computes `td_error`, `loss` and `stats`.

        Args:
            q_t_selected: Q-values of the actions that were taken.
            q_tp1_best: Best next-state Q-value estimates.
            importance_weights: Per-sample loss weights (cast to float).
            rewards: Rewards added to the discounted next-state value.
            done_mask: Float mask; entries equal to 1.0 remove the
                bootstrap term from the target.
            gamma: Discount factor.
            n_step: Exponent applied to `gamma`.
            num_atoms: Must be <= 1 (distributional Q is unsupported).
            v_min: Accepted for signature compatibility; not used here.
            v_max: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If `num_atoms` > 1.
        """
        if num_atoms > 1:
            raise ValueError("Torch version of DQN does not support "
                             "distributional Q yet!")

        # Zero out the bootstrap value wherever done_mask is 1.
        q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

        # compute RHS of bellman equation
        q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

        # compute the error (potentially clipped)
        # The target is detached so no gradient flows through it.
        self.td_error = q_t_selected - q_t_selected_target.detach()
        self.loss = torch.mean(
            importance_weights.float() * huber_loss(self.td_error))
        self.stats = {
            "mean_q": torch.mean(q_t_selected),
            "min_q": torch.min(q_t_selected),
            "max_q": torch.max(q_t_selected),
            "td_error": self.td_error,
            "mean_td_error": torch.mean(self.td_error),
        }


class ComputeTDErrorMixin:
    """Attaches a `compute_td_error` callable to the policy instance."""

    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            """Returns the TD error for the given transition batch.

            Builds a batch dict from the arguments and re-runs
            `build_q_losses`, which stores a fresh `QLoss` on
            `self.q_loss`.
            """
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.DONES] = done_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do forward pass on loss to update td error attribute
            build_q_losses(self, self.model, None, input_dict)

            return self.q_loss.td_error

        self.compute_td_error = compute_td_error


def build_q_model_and_distribution(policy, obs_space, action_space, config):
    """Builds the online and target Q-models and stores them on `policy`.

    Side effects: sets `policy.q_model`, `policy.q_func_vars`,
    `policy.target_q_model` and `policy.target_q_func_vars`; when
    `config["hiddens"]` is truthy it also sets
    `config["model"]["no_final_linear"] = True`.

    Returns:
        Tuple of (`policy.q_model`, `TorchCategorical`).

    Raises:
        UnsupportedSpaceException: If `action_space` is not `Discrete`.
    """
    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # try to infer the last layer size, otherwise fall back to 256
        num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    #  generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    policy.q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_SCOPE,
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    policy.q_func_vars = policy.q_model.variables()

    # The target model is built with the same arguments; only the name
    # differs.
    policy.target_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_TARGET_SCOPE,
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    policy.target_q_func_vars = policy.target_q_model.variables()

    return policy.q_model, TorchCategorical


def get_distribution_inputs_and_class(policy,
                                      q_model,
                                      obs_batch,
                                      *,
                                      explore=True,
                                      is_training=False,
                                      **kwargs):
    """Computes Q-values for `obs_batch` and caches them on the policy.

    Side effect: stores the Q-values in `policy.q_values`.

    Returns:
        Tuple of (Q-values, `TorchCategorical`, empty state-out list).
    """
    q_vals = compute_q_values(policy, q_model, obs_batch, explore, is_training)
    # Keep only the first element if a tuple is returned.
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

    policy.q_values = q_vals
    return policy.q_values, TorchCategorical, []  # state-out


def build_q_losses(policy, model, _, train_batch):
    """Builds the DQN loss and stores the `QLoss` on `policy.q_loss`.

    Args:
        policy: Policy holding `q_model`, `target_q_model`, `config`
            and `action_space`.
        model: Unused here; `policy.q_model` is evaluated instead.
        _: Unused positional argument.
        train_batch: Batch providing CUR_OBS, NEXT_OBS, ACTIONS,
            REWARDS, DONES and PRIO_WEIGHTS entries.

    Returns:
        The scalar loss `policy.q_loss.loss`.
    """
    config = policy.config
    # q network evaluation
    q_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=False)

    # target q network evaluation
    q_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=False)

    # q scores for actions which we know were selected in the given state.
    # `F.one_hot` only accepts int64 index tensors; cast explicitly so that
    # actions arriving as int32/float do not raise. This is a no-op for
    # tensors that are already int64.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # compute estimate of best possible value starting from state at t + 1
    if config["double_q"]:
        # Select the next action with the online net, evaluate it with the
        # target net.
        q_tp1_using_online_net = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=False)
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net,
                                                 policy.action_space.n)
        q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n)
        q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)

    policy.q_loss = QLoss(q_t_selected, q_tp1_best, train_batch[PRIO_WEIGHTS],
                          train_batch[SampleBatch.REWARDS],
                          train_batch[SampleBatch.DONES].float(),
                          config["gamma"], config["n_step"],
                          config["num_atoms"], config["v_min"],
                          config["v_max"])

    return policy.q_loss.loss


def adam_optimizer(policy, config):
    """Returns an Adam optimizer over the online Q-model's variables."""
    return torch.optim.Adam(
        policy.q_func_vars, lr=policy.cur_lr, eps=config["adam_epsilon"])


def build_q_stats(policy, batch):
    """Returns the current learning rate merged with the `QLoss` stats."""
    return dict({
        "cur_lr": policy.cur_lr,
    }, **policy.q_loss.stats)


def setup_early_mixins(policy, obs_space, action_space, config):
    """Initializes the learning-rate schedule mixin on `policy`."""
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def after_init(policy, obs_space, action_space, config):
    """Initializes the TD-error and target-network mixins on `policy`."""
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
    # Move target net to device (this is done automatically for the
    # policy.model, but not for any other models the policy has).
    policy.target_q_model = policy.target_q_model.to(policy.device)


def compute_q_values(policy, model, obs, explore, is_training=False):
    """Runs `model` on `obs` and returns the resulting Q-values.

    Args:
        policy: Policy whose `config` selects dueling / num_atoms.
        model: Model exposing `get_advantages_or_q_values` (and
            `get_state_value` when dueling is enabled).
        obs: Observations passed to the model as CUR_OBS.
        explore: Not used in this function body.
        is_training: Forwarded to the model in the input dict.

    Returns:
        The Q-values; with dueling enabled, the state value plus the
        mean-centered advantages.

    Raises:
        ValueError: If `policy.config["num_atoms"]` > 1.
    """
    if policy.config["num_atoms"] > 1:
        raise ValueError("torch DQN does not support distributional DQN yet!")

    # The returned state is not used.
    model_out, state = model({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    advantages_or_q_values = model.get_advantages_or_q_values(model_out)

    if policy.config["dueling"]:
        state_value = model.get_state_value(model_out)
        advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
        advantages_centered = advantages_or_q_values - torch.unsqueeze(
            advantages_mean, 1)
        q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values


def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
    """Returns the Q-values last cached on the policy as extra output."""
    return {"q_values": policy.q_values}


DQNTorchPolicy = build_torch_policy(
    name="DQNTorchPolicy",
    loss_fn=build_q_losses,
    # The lambda defers the attribute lookup until the config is requested
    # (presumably to avoid importing `dqn.dqn` at module load — confirm).
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=apply_grad_clipping,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    after_init=after_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])