mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00

Enable Twin Delayed DDPG for RLlib DDPG agent (#3353)

parent 6b3236349c
commit 24bfe8ab76
6 changed files with 242 additions and 56 deletions
@@ -16,6 +16,22 @@ OPTIMIZER_SHARED_CONFIGS = [
 # yapf: disable
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
+    # === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
+    # TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
+    # twin Q-net
+    "twin_q": False,
+    # delayed policy update
+    "policy_delay": 1,
+    # target policy smoothing
+    # this also forces the use of gaussian instead of OU noise for exploration
+    "smooth_target_policy": False,
+    # gaussian stddev of act noise
+    "act_noise": 0.1,
+    # gaussian stddev of target noise
+    "target_noise": 0.2,
+    # target noise limit (bound)
+    "noise_clip": 0.5,
+
     # === Model ===
     # Hidden layer sizes of the policy network
     "actor_hiddens": [64, 64],
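Taken together, the new keys turn the stock DDPG configuration into TD3. Below is a minimal sketch of switching them on from Python; the DDPGAgent entry point, its keyword arguments, and the result keys are assumptions about the RLlib of this era and are not part of the diff:

# Hedged sketch: enabling the TD3 tricks on top of the DDPG defaults.
# Assumes the DDPGAgent class exposed by ray.rllib.agents.ddpg around this commit.
import ray
from ray.rllib.agents.ddpg import DDPGAgent

ray.init()

td3_overrides = {
    "twin_q": True,                # clipped double-Q learning
    "policy_delay": 2,             # update the actor every 2 critic updates
    "smooth_target_policy": True,  # gaussian noise on the target action
    "act_noise": 0.1,
    "target_noise": 0.2,
    "noise_clip": 0.5,
}

agent = DDPGAgent(env="Pendulum-v0", config=td3_overrides)
for _ in range(10):
    result = agent.train()
    print(result["episode_reward_mean"])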
@@ -67,9 +83,11 @@ DEFAULT_CONFIG = with_common_config({
     "compress_observations": False,

     # === Optimization ===
-    # Learning rate for adam optimizer
-    "actor_lr": 1e-4,
-    "critic_lr": 1e-3,
+    # Learning rate for adam optimizer.
+    # Instead of using two optimizers, we use two different loss coefficients
+    "lr": 1e-3,
+    "actor_loss_coeff": 0.1,
+    "critic_loss_coeff": 1.0,
     # If True, use huber loss instead of squared loss for critic network
     # Conventionally, no need to clip gradients if using a huber loss
     "use_huber": False,
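Replacing the two Adam optimizers (actor_lr / critic_lr) with a single lr plus per-loss coefficients preserves the relative step sizes, since scaling a loss scales its gradient by the same factor; with Adam the equivalence is only approximate because of its per-parameter normalization. A small NumPy sketch of the plain-SGD equivalence, with illustrative values that are not taken from RLlib:

# Hedged sketch: scaling a loss by a coefficient is equivalent (for plain SGD)
# to using a smaller learning rate for that loss's parameters.
import numpy as np

grad = np.array([0.5, 0.25])            # gradient of the unscaled actor loss

lr, actor_loss_coeff = 1e-3, 0.1        # new-style config
actor_lr = 1e-4                         # old-style config

step_new = lr * (actor_loss_coeff * grad)   # single optimizer, scaled loss
step_old = actor_lr * grad                  # dedicated actor optimizer

assert np.allclose(step_new, step_old)
print(step_new)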
@@ -19,6 +19,8 @@ P_SCOPE = "p_func"
 P_TARGET_SCOPE = "target_p_func"
 Q_SCOPE = "q_func"
 Q_TARGET_SCOPE = "target_q_func"
+TWIN_Q_SCOPE = "twin_q_func"
+TWIN_Q_TARGET_SCOPE = "twin_target_q_func"


 class PNetwork(object):
@@ -50,24 +52,45 @@ class ActionNetwork(object):
                  stochastic,
                  eps,
                  theta=0.15,
-                 sigma=0.2):
+                 sigma=0.2,
+                 use_gaussian_noise=False,
+                 act_noise=0.1,
+                 is_target=False,
+                 target_noise=0.2,
+                 noise_clip=0.5):

         # shape is [None, dim_action]
         deterministic_actions = (
             (high_action - low_action) * p_values + low_action)

-        exploration_sample = tf.get_variable(
-            name="ornstein_uhlenbeck",
-            dtype=tf.float32,
-            initializer=low_action.size * [.0],
-            trainable=False)
-        normal_sample = tf.random_normal(
-            shape=[low_action.size], mean=0.0, stddev=1.0)
-        exploration_value = tf.assign_add(
-            exploration_sample,
-            theta * (.0 - exploration_sample) + sigma * normal_sample)
-        stochastic_actions = deterministic_actions + eps * (
-            high_action - low_action) * exploration_value
+        if use_gaussian_noise:
+            if is_target:
+                normal_sample = tf.random_normal(
+                    tf.shape(deterministic_actions), stddev=target_noise)
+                normal_sample = tf.clip_by_value(normal_sample, -noise_clip,
+                                                 noise_clip)
+                stochastic_actions = tf.clip_by_value(
+                    deterministic_actions + normal_sample, low_action,
+                    high_action)
+            else:
+                normal_sample = tf.random_normal(
+                    tf.shape(deterministic_actions), stddev=act_noise)
+                stochastic_actions = tf.clip_by_value(
+                    deterministic_actions + normal_sample, low_action,
+                    high_action)
+        else:
+            exploration_sample = tf.get_variable(
+                name="ornstein_uhlenbeck",
+                dtype=tf.float32,
+                initializer=low_action.size * [.0],
+                trainable=False)
+            normal_sample = tf.random_normal(
+                shape=[low_action.size], mean=0.0, stddev=1.0)
+            exploration_value = tf.assign_add(
+                exploration_sample,
+                theta * (.0 - exploration_sample) + sigma * normal_sample)
+            stochastic_actions = deterministic_actions + eps * (
+                high_action - low_action) * exploration_value

         self.actions = tf.cond(stochastic, lambda: stochastic_actions,
                                lambda: deterministic_actions)
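The two branches above amount to: clipped Gaussian noise for TD3 (with extra, clipped smoothing noise on the target action) versus Ornstein-Uhlenbeck noise for vanilla DDPG. A standalone NumPy sketch of both schemes, with illustrative bounds and constants rather than values pulled from the code:

# Hedged sketch of the exploration noise schemes, outside of TensorFlow.
import numpy as np

low, high = -2.0, 2.0        # action bounds (Pendulum-like, illustrative)
a_det = np.array([0.7])      # deterministic action from the policy

# Gaussian exploration noise (TD3 behaviour policy).
act_noise = 0.1
a_explore = np.clip(a_det + np.random.normal(0.0, act_noise, a_det.shape),
                    low, high)

# Target policy smoothing (TD3 target action): clipped noise, clipped action.
target_noise, noise_clip = 0.2, 0.5
smoothing = np.clip(np.random.normal(0.0, target_noise, a_det.shape),
                    -noise_clip, noise_clip)
a_target = np.clip(a_det + smoothing, low, high)

# Ornstein-Uhlenbeck exploration (vanilla DDPG); `scale` plays the role of eps.
theta, sigma, scale = 0.15, 0.2, 0.1
ou_state = np.zeros_like(a_det)
for _ in range(5):
    ou_state += theta * (0.0 - ou_state) + sigma * np.random.normal(size=a_det.shape)
a_ou = a_det + scale * (high - low) * ou_state

print(a_explore, a_target, a_ou)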
@@ -97,12 +120,21 @@ class ActorCriticLoss(object):
                  importance_weights,
                  rewards,
                  done_mask,
+                 twin_q_t,
+                 twin_q_tp1,
+                 actor_loss_coeff=0.1,
+                 critic_loss_coeff=1.0,
                  gamma=0.99,
                  n_step=1,
                  use_huber=False,
-                 huber_threshold=1.0):
+                 huber_threshold=1.0,
+                 twin_q=False,
+                 policy_delay=1):

         q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
+        if twin_q:
+            twin_q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
+            q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

         q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
         q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best
@@ -111,16 +143,36 @@ class ActorCriticLoss(object):
         q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

         # compute the error (potentially clipped)
-        self.td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
-        if use_huber:
-            errors = _huber_loss(self.td_error, huber_threshold)
+        if twin_q:
+            td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
+            twin_td_error = twin_q_t_selected - tf.stop_gradient(
+                q_t_selected_target)
+            self.td_error = td_error + twin_td_error
+            if use_huber:
+                errors = _huber_loss(td_error, huber_threshold) + _huber_loss(
+                    twin_td_error, huber_threshold)
+            else:
+                errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(
+                    twin_td_error)
         else:
-            errors = 0.5 * tf.square(self.td_error)
+            self.td_error = (
+                q_t_selected - tf.stop_gradient(q_t_selected_target))
+            if use_huber:
+                errors = _huber_loss(self.td_error, huber_threshold)
+            else:
+                errors = 0.5 * tf.square(self.td_error)

-        self.critic_loss = tf.reduce_mean(importance_weights * errors)
+        self.critic_loss = critic_loss_coeff * tf.reduce_mean(
+            importance_weights * errors)

-        # for policy gradient
-        self.actor_loss = -1.0 * tf.reduce_mean(q_tp0)
+        # for policy gradient, update policy net one time v.s.
+        # update critic net `policy_delay` time(s)
+        global_step = tf.train.get_or_create_global_step()
+        policy_delay_mask = tf.to_float(
+            tf.equal(tf.mod(global_step, policy_delay), 0))
+        self.actor_loss = (-1.0 * actor_loss_coeff * policy_delay_mask *
+                           tf.reduce_mean(q_tp0))
+
         self.total_loss = self.actor_loss + self.critic_loss


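The two TD3 ingredients added to the loss are: the TD target bootstraps from the minimum of the two target critics, and the actor term is masked to be nonzero only every policy_delay-th update. A NumPy sketch of the target computation and the delay mask, with illustrative values only:

# Hedged sketch: clipped double-Q target and the policy-delay mask.
import numpy as np

rewards = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
gamma, n_step = 0.99, 1

q_tp1 = np.array([10.0, 8.0])        # target critic 1 at the next state
twin_q_tp1 = np.array([9.5, 8.4])    # target critic 2 at the next state

# Clipped double-Q: bootstrap from the element-wise minimum of the two critics.
q_tp1_min = np.minimum(q_tp1, twin_q_tp1)
q_target = rewards + gamma**n_step * (1.0 - done) * q_tp1_min
print(q_target)   # 10.405 for the first transition, 0.0 for the terminal one

# Delayed policy update: the actor loss is nonzero only every
# policy_delay-th gradient step, counted by a global step.
policy_delay = 2
for global_step in range(6):
    mask = float(global_step % policy_delay == 0)
    # actor_loss = -actor_loss_coeff * mask * mean(Q(s, pi(s)))
    print(global_step, mask)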
@@ -137,10 +189,9 @@ class DDPGPolicyGraph(TFPolicyGraph):
         self.dim_actions = action_space.shape[0]
         self.low_action = action_space.low
         self.high_action = action_space.high
-        self.actor_optimizer = tf.train.AdamOptimizer(
-            learning_rate=config["actor_lr"])
-        self.critic_optimizer = tf.train.AdamOptimizer(
-            learning_rate=config["critic_lr"])
+
+        # create global step for counting the number of update operations
+        self.global_step = tf.train.get_or_create_global_step()

         # Action inputs
         self.stochastic = tf.placeholder(tf.bool, (), name="stochastic")
@@ -159,10 +210,13 @@ class DDPGPolicyGraph(TFPolicyGraph):
             self.output_actions = self._build_action_network(
                 p_values, self.stochastic, self.eps)

-        with tf.variable_scope(A_SCOPE, reuse=True):
-            exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
-            self.reset_noise_op = tf.assign(exploration_sample,
-                                            self.dim_actions * [.0])
+        if self.config["smooth_target_policy"]:
+            self.reset_noise_op = tf.no_op()
+        else:
+            with tf.variable_scope(A_SCOPE, reuse=True):
+                exploration_sample = tf.get_variable(name="ornstein_uhlenbeck")
+                self.reset_noise_op = tf.assign(exploration_sample,
+                                                self.dim_actions * [.0])

         # Replay inputs
         self.obs_t = tf.placeholder(
@@ -189,13 +243,16 @@ class DDPGPolicyGraph(TFPolicyGraph):

         # Action outputs
         with tf.variable_scope(A_SCOPE, reuse=True):
-            deterministic_flag = tf.constant(value=False, dtype=tf.bool)
-            zero_eps = tf.constant(value=.0, dtype=tf.float32)
             output_actions = self._build_action_network(
-                self.p_t, deterministic_flag, zero_eps)
-
+                self.p_t,
+                stochastic=tf.constant(value=False, dtype=tf.bool),
+                eps=.0)
             output_actions_estimated = self._build_action_network(
-                p_tp1, deterministic_flag, zero_eps)
+                p_tp1,
+                stochastic=tf.constant(
+                    value=self.config["smooth_target_policy"], dtype=tf.bool),
+                eps=.0,
+                is_target=True)

         # q network evaluation
         with tf.variable_scope(Q_SCOPE) as scope:
@@ -205,14 +262,28 @@ class DDPGPolicyGraph(TFPolicyGraph):
         with tf.variable_scope(Q_SCOPE, reuse=True):
             q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
                                              output_actions)
+        if self.config["twin_q"]:
+            with tf.variable_scope(TWIN_Q_SCOPE) as scope:
+                twin_q_t, twin_model = self._build_q_network(
+                    self.obs_t, observation_space, self.act_t)
+                self.twin_q_func_vars = _scope_vars(scope.name)

         # target q network evalution
         with tf.variable_scope(Q_TARGET_SCOPE) as scope:
             q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
                                              output_actions_estimated)
             target_q_func_vars = _scope_vars(scope.name)
+        if self.config["twin_q"]:
+            with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope:
+                twin_q_tp1, _ = self._build_q_network(
+                    self.obs_tp1, observation_space, output_actions_estimated)
+                twin_target_q_func_vars = _scope_vars(scope.name)

-        self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0)
+        if self.config["twin_q"]:
+            self.loss = self._build_actor_critic_loss(
+                q_t, q_tp1, q_tp0, twin_q_t=twin_q_t, twin_q_tp1=twin_q_tp1)
+        else:
+            self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0)

         if config["l2_reg"] is not None:
             for var in self.p_func_vars:
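The point of the second, independently initialized critic pair is that bootstrapping from min(Q1, Q2) counteracts the overestimation that noisy value estimates introduce. A small illustrative simulation of that effect, independent of RLlib and using assumed numbers:

# Hedged sketch: the minimum of two noisy estimates is biased low rather
# than high, which is what clipped double-Q learning relies on.
import numpy as np

rng = np.random.RandomState(0)
true_q = 5.0
noise = 0.5
q1 = true_q + rng.normal(0.0, noise, size=100000)
q2 = true_q + rng.normal(0.0, noise, size=100000)

print("single critic bias:      ", q1.mean() - true_q)                  # ~0
print("min of two critics bias: ", np.minimum(q1, q2).mean() - true_q)  # < 0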
@@ -223,6 +294,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
                 if "bias" not in var.name:
                     self.loss.critic_loss += (
                         config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
+            if self.config["twin_q"]:
+                for var in self.twin_q_func_vars:
+                    if "bias" not in var.name:
+                        self.loss.critic_loss += (
+                            config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))

         # update_target_fn will be called periodically to copy Q network to
         # target Q network
@@ -235,6 +311,13 @@ class DDPGPolicyGraph(TFPolicyGraph):
             update_target_expr.append(
                 var_target.assign(self.tau * var +
                                   (1.0 - self.tau) * var_target))
+        if self.config["twin_q"]:
+            for var, var_target in zip(
+                    sorted(self.twin_q_func_vars, key=lambda v: v.name),
+                    sorted(twin_target_q_func_vars, key=lambda v: v.name)):
+                update_target_expr.append(
+                    var_target.assign(self.tau * var +
+                                      (1.0 - self.tau) * var_target))
         for var, var_target in zip(
                 sorted(self.p_func_vars, key=lambda v: v.name),
                 sorted(target_p_func_vars, key=lambda v: v.name)):
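The block above extends the usual soft ("Polyak") target update to the twin critic's target copy. A standalone sketch of the per-variable update, with tau as in the tuned examples; this is illustrative NumPy, not RLlib code:

# Hedged sketch of the soft target update applied to every (online, target) pair.
import numpy as np

tau = 0.001

def soft_update(online_vars, target_vars, tau):
    # target <- tau * online + (1 - tau) * target, element-wise per variable
    return [tau * w + (1.0 - tau) * w_t
            for w, w_t in zip(online_vars, target_vars)]

q_vars = [np.ones((2, 2)), np.zeros(2)]
q_target_vars = [np.zeros((2, 2)), np.ones(2)]
q_target_vars = soft_update(q_vars, q_target_vars, tau)
print(q_target_vars[0][0, 0])   # 0.001: the target drifts slowly toward the online net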
@@ -288,34 +371,52 @@ class DDPGPolicyGraph(TFPolicyGraph):
             self.config["actor_hiddens"],
             self.config["actor_hidden_activation"]).action_scores

-    def _build_action_network(self, p_values, stochastic, eps):
-        return ActionNetwork(p_values, self.low_action, self.high_action,
-                             stochastic, eps, self.config["exploration_theta"],
-                             self.config["exploration_sigma"]).actions
+    def _build_action_network(self, p_values, stochastic, eps,
+                              is_target=False):
+        return ActionNetwork(
+            p_values, self.low_action, self.high_action, stochastic, eps,
+            self.config["exploration_theta"], self.config["exploration_sigma"],
+            self.config["smooth_target_policy"], self.config["act_noise"],
+            is_target, self.config["target_noise"],
+            self.config["noise_clip"]).actions

-    def _build_actor_critic_loss(self, q_t, q_tp1, q_tp0):
+    def _build_actor_critic_loss(self,
+                                 q_t,
+                                 q_tp1,
+                                 q_tp0,
+                                 twin_q_t=None,
+                                 twin_q_tp1=None):
         return ActorCriticLoss(
             q_t, q_tp1, q_tp0, self.importance_weights, self.rew_t,
-            self.done_mask, self.config["gamma"], self.config["n_step"],
-            self.config["use_huber"], self.config["huber_threshold"])
+            self.done_mask, twin_q_t, twin_q_tp1,
+            self.config["actor_loss_coeff"], self.config["critic_loss_coeff"],
+            self.config["gamma"], self.config["n_step"],
+            self.config["use_huber"], self.config["huber_threshold"],
+            self.config["twin_q"])

+    def optimizer(self):
+        return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
+
     def gradients(self, optimizer):
         if self.config["grad_norm_clipping"] is not None:
             actor_grads_and_vars = _minimize_and_clip(
-                self.actor_optimizer,
+                optimizer,
                 self.loss.actor_loss,
                 var_list=self.p_func_vars,
                 clip_val=self.config["grad_norm_clipping"])
             critic_grads_and_vars = _minimize_and_clip(
-                self.critic_optimizer,
+                optimizer,
                 self.loss.critic_loss,
-                var_list=self.q_func_vars,
+                var_list=self.q_func_vars + self.twin_q_func_vars
+                if self.config["twin_q"] else self.q_func_vars,
                 clip_val=self.config["grad_norm_clipping"])
         else:
-            actor_grads_and_vars = self.actor_optimizer.compute_gradients(
+            actor_grads_and_vars = optimizer.compute_gradients(
                 self.loss.actor_loss, var_list=self.p_func_vars)
-            critic_grads_and_vars = self.critic_optimizer.compute_gradients(
-                self.loss.critic_loss, var_list=self.q_func_vars)
+            critic_grads_and_vars = optimizer.compute_gradients(
+                self.loss.critic_loss,
+                var_list=self.q_func_vars + self.twin_q_func_vars
+                if self.config["twin_q"] else self.q_func_vars)
         actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
                                 if g is not None]
         critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
@@ -102,12 +102,17 @@ class TFPolicyGraph(PolicyGraph):
         self._seq_lens = seq_lens
         self._max_seq_len = max_seq_len
         self._batch_divisibility_req = batch_divisibility_req

         self._optimizer = self.optimizer()
         self._grads_and_vars = [(g, v)
                                 for (g, v) in self.gradients(self._optimizer)
                                 if g is not None]
         self._grads = [g for (g, v) in self._grads_and_vars]
-        self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars)
+        # specify global_step for TD3 which needs to count the num updates
+        self._apply_op = self._optimizer.apply_gradients(
+            self._grads_and_vars,
+            global_step=tf.train.get_or_create_global_step())

         self._variables = ray.experimental.TensorFlowVariables(
             self._loss, self._sess)
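Passing global_step into apply_gradients makes TensorFlow increment that counter once per optimizer step, and that counter is what the tf.mod(global_step, policy_delay) mask in the DDPG loss reads. A minimal TF 1.x-style sketch of the same mechanism, assuming the TensorFlow API of this era:

# Hedged sketch (TensorFlow 1.x): apply_gradients(global_step=...) increments
# the counter that a policy-delay mask can be computed from.
import tensorflow as tf

x = tf.get_variable("x", initializer=1.0)
loss = tf.square(x)
global_step = tf.train.get_or_create_global_step()
opt = tf.train.GradientDescentOptimizer(0.1)
train_op = opt.apply_gradients(opt.compute_gradients(loss),
                               global_step=global_step)
policy_delay_mask = tf.to_float(tf.equal(tf.mod(global_step, 2), 0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(4):
        print(sess.run([global_step, policy_delay_mask]))
        sess.run(train_op)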
@@ -34,8 +34,9 @@ mountaincarcontinuous-ddpg:
         clip_rewards: False

         # === Optimization ===
-        actor_lr: 0.0001
-        critic_lr: 0.001
+        lr: 0.001
+        actor_loss_coeff: 0.1
+        critic_loss_coeff: 1.0
         use_huber: False
         huber_threshold: 1.0
         l2_reg: 0.00001
@@ -34,8 +34,9 @@ pendulum-ddpg:
         clip_rewards: False

         # === Optimization ===
-        actor_lr: 0.0001
-        critic_lr: 0.001
+        lr: 0.001
+        actor_loss_coeff: 0.1
+        critic_loss_coeff: 1.0
         use_huber: True
         huber_threshold: 1.0
         l2_reg: 0.000001
python/ray/rllib/tuned_examples/pendulum-td3.yaml (new file, 60 lines)
@@ -0,0 +1,60 @@
+# This configuration can expect to reach -160 reward in 10k-20k timesteps
+pendulum-ddpg:
+    env: Pendulum-v0
+    run: DDPG
+    stop:
+        episode_reward_mean: -160
+        time_total_s: 600 # 10 minutes
+    config:
+        # === Tricks ===
+        twin_q: True
+        policy_delay: 2
+        smooth_target_policy: True
+        act_noise: 0.1
+        target_noise: 0.2
+        noise_clip: 0.5
+
+        # === Model ===
+        actor_hiddens: [64, 64]
+        critic_hiddens: [64, 64]
+        n_step: 1
+        model: {}
+        gamma: 0.99
+        env_config: {}
+
+        # === Exploration ===
+        schedule_max_timesteps: 100000
+        timesteps_per_iteration: 600
+        exploration_fraction: 0.1
+        exploration_final_eps: 0.02
+        noise_scale: 0.1
+        exploration_theta: 0.15
+        exploration_sigma: 0.2
+        target_network_update_freq: 0
+        tau: 0.001
+
+        # === Replay buffer ===
+        buffer_size: 10000
+        prioritized_replay: True
+        prioritized_replay_alpha: 0.6
+        prioritized_replay_beta: 0.4
+        prioritized_replay_eps: 0.000001
+        clip_rewards: False
+
+        # === Optimization ===
+        lr: 0.001
+        actor_loss_coeff: 0.1
+        critic_loss_coeff: 1.0
+        use_huber: True
+        huber_threshold: 1.0
+        l2_reg: 0.000001
+        learning_starts: 500
+        sample_batch_size: 1
+        train_batch_size: 64
+
+        # === Parallelism ===
+        num_workers: 0
+        num_gpus_per_worker: 0
+        optimizer_class: "SyncReplayOptimizer"
+        per_worker_exploration: False
+        worker_side_prioritization: False
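The new tuned example can be launched the same way as the other tuned_examples files. A hedged sketch using ray.tune.run_experiments, which is assumed to be how the bundled training script consumed these YAMLs at the time; the file path is the one shown above:

# Hedged sketch: launching the pendulum-td3.yaml experiment via ray.tune.
import yaml
import ray
from ray.tune import run_experiments

with open("python/ray/rllib/tuned_examples/pendulum-td3.yaml") as f:
    experiments = yaml.safe_load(f)   # {"pendulum-ddpg": {...}}

ray.init()
run_experiments(experiments)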