The DDPG/TD3 algorithms currently do not have a PyTorch implementation. This PR adds PyTorch support for DDPG/TD3 to RLlib. This PR:
- Depends on the re-factor PR for DDPG (Functional Algorithm API).
- Adds learning regression tests for the PyTorch version of DDPG and a DDPG (torch)
- Updates the documentation to reflect that DDPG and TD3 now support PyTorch.
* Learning Pendulum-v0 on the torch version (same config as tf); wall time is a little slower (~20%) than tf.
* Fix GPU target model problem.

from gym.spaces import Discrete

import ray
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
    PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf
from ray.rllib.utils import try_import_torch

torch, nn = try_import_torch()
F = None
if nn:
    F = torch.nn.functional


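# Helper computing the done-masked n-step Bellman target, the TD error against
# the selected online Q-values, and the importance-weighted Huber loss, plus
# stats for logging.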
class QLoss:
    def __init__(self,
                 q_t_selected,
                 q_tp1_best,
                 importance_weights,
                 rewards,
                 done_mask,
                 gamma=0.99,
                 n_step=1,
                 num_atoms=1,
                 v_min=-10.0,
                 v_max=10.0):

        if num_atoms > 1:
            raise ValueError("Torch version of DQN does not support "
                             "distributional Q yet!")

        q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

        # Compute the RHS of the Bellman equation.
        q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

        # Compute the TD error (clipping happens via the Huber loss below).
        self.td_error = q_t_selected - q_t_selected_target.detach()
        self.loss = torch.mean(
            importance_weights.float() * huber_loss(self.td_error))
        self.stats = {
            "mean_q": torch.mean(q_t_selected),
            "min_q": torch.min(q_t_selected),
            "max_q": torch.max(q_t_selected),
            "td_error": self.td_error,
            "mean_td_error": torch.mean(self.td_error),
        }


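# Adds a `compute_td_error` method to the policy: it assembles an input dict
# from the given transition data, re-runs `build_q_losses` on it, and returns
# the resulting per-sample TD errors.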
class ComputeTDErrorMixin:
    def __init__(self):
        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                             importance_weights):
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.DONES] = done_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do forward pass on loss to update td error attribute
            build_q_losses(self, self.model, None, input_dict)

            return self.q_loss.td_error

        self.compute_td_error = compute_td_error


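# Builds the online and target Q-networks (both DQNTorchModel instances) for a
# Discrete action space and returns the online model together with the
# TorchCategorical action-distribution class.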
def build_q_model_and_distribution(policy, obs_space, action_space, config):

    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # try to infer the last layer size, otherwise fall back to 256
        num_outputs = ([256] + config["model"]["fcnet_hiddens"])[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    # generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    policy.q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_SCOPE,
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        # generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    policy.q_func_vars = policy.q_model.variables()

    policy.target_q_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_TARGET_SCOPE,
        dueling=config["dueling"],
        q_hiddens=config["hiddens"],
        use_noisy=config["noisy"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        # generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    policy.target_q_func_vars = policy.target_q_model.variables()

    return policy.q_model, TorchCategorical


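# Computes Q-values for the given observation batch and returns them as the
# inputs to a TorchCategorical action distribution (no recurrent state).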
def get_distribution_inputs_and_class(policy,
                                      model,
                                      obs_batch,
                                      *,
                                      explore=True,
                                      is_training=False,
                                      **kwargs):
    q_vals = compute_q_values(policy, model, obs_batch, explore, is_training)
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

    policy.q_values = q_vals
    return policy.q_values, TorchCategorical, []  # state-out


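# Builds the DQN loss: evaluates the online net on obs_t and the target net on
# obs_tp1, selects Q(s, a) for the taken actions, picks the best next-state
# Q-value (optionally via double-Q action selection on the online net), and
# wraps everything in a QLoss object stored as `policy.q_loss`.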
def build_q_losses(policy, model, _, train_batch):
    config = policy.config
    # q network evaluation
    q_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # target q network evaluation
    q_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True)

    # q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # compute estimate of best possible value starting from state at t + 1
    if config["double_q"]:
        q_tp1_using_online_net = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=True)
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net,
                                                 policy.action_space.n)
        q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n)
        q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)

    policy.q_loss = QLoss(q_t_selected, q_tp1_best, train_batch[PRIO_WEIGHTS],
                          train_batch[SampleBatch.REWARDS],
                          train_batch[SampleBatch.DONES].float(),
                          config["gamma"], config["n_step"],
                          config["num_atoms"], config["v_min"],
                          config["v_max"])

    return policy.q_loss.loss


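# Creates the Adam optimizer over the Q-network variables, using the policy's
# (possibly scheduled) current learning rate.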
def adam_optimizer(policy, config):
    return torch.optim.Adam(
        policy.q_func_vars, lr=policy.cur_lr, eps=config["adam_epsilon"])


def build_q_stats(policy, batch):
    return dict({
        "cur_lr": policy.cur_lr,
    }, **policy.q_loss.stats)


def setup_early_mixins(policy, obs_space, action_space, config):
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def after_init(policy, obs_space, action_space, config):
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
    # Move target net to device (this is done automatically for the
    # policy.model, but not for any other models the policy has).
    policy.target_q_model = policy.target_q_model.to(policy.device)


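# Runs a forward pass through the given model and returns per-action Q-values.
# For dueling networks, the state value and the mean-centered advantages are
# combined here; otherwise the advantage/Q head output is used as-is.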
def compute_q_values(policy, model, obs, explore, is_training=False):
    if policy.config["num_atoms"] > 1:
        raise ValueError("torch DQN does not support distributional DQN yet!")

    model_out, state = model({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    advantages_or_q_values = model.get_advantages_or_q_values(model_out)

    if policy.config["dueling"]:
        state_value = model.get_state_value(model_out)
        advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
        advantages_centered = advantages_or_q_values - torch.unsqueeze(
            advantages_mean, 1)
        q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values


def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
    return {"q_values": policy.q_values}


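# Assembles the PyTorch DQN policy from the functional building blocks above
# via the `build_torch_policy` template.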
DQNTorchPolicy = build_torch_policy(
    name="DQNTorchPolicy",
    loss_fn=build_q_losses,
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model_and_action_dist=build_q_model_and_distribution,
    action_distribution_fn=get_distribution_inputs_and_class,
    stats_fn=build_q_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    optimizer_fn=adam_optimizer,
    extra_grad_process_fn=apply_grad_clipping,
    extra_action_out_fn=extra_action_out_fn,
    before_init=setup_early_mixins,
    after_init=after_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ])
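

# Example usage (a minimal sketch, not part of the original module): the torch
# policy above is picked up by the DQN trainer when the torch config flag is
# set. The flag name below (`use_pytorch`) is an assumption for RLlib releases
# of this era; newer versions use {"framework": "torch"} instead.
#
#   import ray
#   from ray.rllib.agents.dqn import DQNTrainer
#
#   ray.init()
#   trainer = DQNTrainer(
#       env="CartPole-v0",
#       config={"use_pytorch": True},  # assumed flag; may differ by version
#   )
#   print(trainer.train()["episode_reward_mean"])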