mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
SAC Performance Fixes (#6295)
* SAC Performance Fixes * Small Changes * Update sac_model.py * fix normalize wrapper * Update test_eager_support.py Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
parent
eca4cc7c00
commit
548df014ec
9 changed files with 230 additions and 54 deletions
|
@ -29,6 +29,8 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"hidden_activation": "relu",
|
||||
"hidden_layer_sizes": (256, 256),
|
||||
},
|
||||
# Unsquash actions to the upper and lower bounds of env's action space
|
||||
"normalize_actions": True,
|
||||
|
||||
# === Learning ===
|
||||
# Update the target by \tau * policy + (1-\tau) * target_policy
|
||||
|
|
|
@ -76,7 +76,6 @@ class SACModel(TFModelV2):
|
|||
|
||||
super(SACModel, self).__init__(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
self.action_dim = np.product(action_space.shape)
|
||||
self.model_out = tf.keras.layers.Input(
|
||||
shape=(num_outputs, ), name="model_out")
|
||||
|
@ -111,17 +110,68 @@ class SACModel(TFModelV2):
|
|||
shift_and_log_scale_diag = tf.keras.layers.Concatenate(axis=-1)(
|
||||
[shift, log_scale_diag])
|
||||
|
||||
raw_action_distribution = tfp.layers.MultivariateNormalTriL(
|
||||
self.action_dim)(shift_and_log_scale_diag)
|
||||
batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])(
|
||||
self.model_out)
|
||||
|
||||
action_distribution = tfp.layers.DistributionLambda(
|
||||
make_distribution_fn=SquashBijector())(raw_action_distribution)
|
||||
base_distribution = tfp.distributions.MultivariateNormalDiag(
|
||||
loc=tf.zeros(self.action_dim), scale_diag=tf.ones(self.action_dim))
|
||||
|
||||
# TODO(hartikainen): Remove the unnecessary Model call here
|
||||
self.action_distribution_model = tf.keras.Model(
|
||||
self.model_out, action_distribution)
|
||||
latents = tf.keras.layers.Lambda(
|
||||
lambda batch_size: base_distribution.sample(batch_size))(
|
||||
batch_size)
|
||||
|
||||
self.register_variables(self.action_distribution_model.variables)
|
||||
self.shift_and_log_scale_diag = latents
|
||||
self.latents_model = tf.keras.Model(self.model_out, latents)
|
||||
|
||||
def raw_actions_fn(inputs):
|
||||
shift, log_scale_diag, latents = inputs
|
||||
bijector = tfp.bijectors.Affine(
|
||||
shift=shift, scale_diag=tf.exp(log_scale_diag))
|
||||
actions = bijector.forward(latents)
|
||||
return actions
|
||||
|
||||
raw_actions = tf.keras.layers.Lambda(raw_actions_fn)(
|
||||
(shift, log_scale_diag, latents))
|
||||
|
||||
squash_bijector = (SquashBijector())
|
||||
|
||||
actions = tf.keras.layers.Lambda(
|
||||
lambda raw_actions: squash_bijector.forward(raw_actions))(
|
||||
raw_actions)
|
||||
self.actions_model = tf.keras.Model(self.model_out, actions)
|
||||
|
||||
deterministic_actions = tf.keras.layers.Lambda(
|
||||
lambda shift: squash_bijector.forward(shift))(shift)
|
||||
|
||||
self.deterministic_actions_model = tf.keras.Model(
|
||||
self.model_out, deterministic_actions)
|
||||
|
||||
def log_pis_fn(inputs):
|
||||
shift, log_scale_diag, actions = inputs
|
||||
base_distribution = tfp.distributions.MultivariateNormalDiag(
|
||||
loc=tf.zeros(self.action_dim),
|
||||
scale_diag=tf.ones(self.action_dim))
|
||||
bijector = tfp.bijectors.Chain((
|
||||
squash_bijector,
|
||||
tfp.bijectors.Affine(
|
||||
shift=shift, scale_diag=tf.exp(log_scale_diag)),
|
||||
))
|
||||
distribution = (tfp.distributions.TransformedDistribution(
|
||||
distribution=base_distribution, bijector=bijector))
|
||||
|
||||
log_pis = distribution.log_prob(actions)[:, None]
|
||||
return log_pis
|
||||
|
||||
self.actions_input = tf.keras.layers.Input(
|
||||
shape=(self.action_dim, ), name="actions")
|
||||
|
||||
log_pis_for_action_input = tf.keras.layers.Lambda(log_pis_fn)(
|
||||
[shift, log_scale_diag, self.actions_input])
|
||||
|
||||
self.log_pis_model = tf.keras.Model(
|
||||
(self.model_out, self.actions_input), log_pis_for_action_input)
|
||||
|
||||
self.register_variables(self.actions_model.variables)
|
||||
|
||||
def build_q_net(name, observations, actions):
|
||||
q_net = tf.keras.Sequential([
|
||||
|
@ -169,14 +219,12 @@ class SACModel(TFModelV2):
|
|||
Returns:
|
||||
tensor of shape [BATCH_SIZE, action_dim] with range [-inf, inf].
|
||||
"""
|
||||
action_distribution = self.action_distribution_model(model_out)
|
||||
if deterministic:
|
||||
actions = action_distribution.bijector(
|
||||
action_distribution.distribution.mean())
|
||||
actions = self.deterministic_actions_model(model_out)
|
||||
log_pis = None
|
||||
else:
|
||||
actions = action_distribution.sample()
|
||||
log_pis = action_distribution.log_prob(actions)
|
||||
actions = self.actions_model(model_out)
|
||||
log_pis = self.log_pis_model((model_out, actions))
|
||||
|
||||
return actions, log_pis
|
||||
|
||||
|
@ -217,7 +265,7 @@ class SACModel(TFModelV2):
|
|||
def policy_variables(self):
|
||||
"""Return the list of variables for the policy net."""
|
||||
|
||||
return list(self.action_distribution_model.variables)
|
||||
return list(self.actions_model.variables)
|
||||
|
||||
def q_variables(self):
|
||||
"""Return the list of variables for Q / twin Q nets."""
|
||||
|
|
|
@ -111,10 +111,13 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
|
|||
|
||||
squashed_stochastic_actions, log_pis = policy.model.get_policy_output(
|
||||
model_out, deterministic=False)
|
||||
stochastic_actions = unsquash_actions(squashed_stochastic_actions)
|
||||
stochastic_actions = squashed_stochastic_actions if config[
|
||||
"normalize_actions"] else unsquash_actions(squashed_stochastic_actions)
|
||||
squashed_deterministic_actions, _ = policy.model.get_policy_output(
|
||||
model_out, deterministic=True)
|
||||
deterministic_actions = unsquash_actions(squashed_deterministic_actions)
|
||||
deterministic_actions = squashed_deterministic_actions if config[
|
||||
"normalize_actions"] else unsquash_actions(
|
||||
squashed_deterministic_actions)
|
||||
|
||||
actions = tf.cond(policy.stochastic, lambda: stochastic_actions,
|
||||
lambda: deterministic_actions)
|
||||
|
@ -155,6 +158,10 @@ def actor_critic_loss(policy, model, _, train_batch):
|
|||
|
||||
# Q-values for current policy (no noise) in given current state
|
||||
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
|
||||
if policy.config["twin_q"]:
|
||||
twin_q_t_det_policy = model.get_q_values(model_out_t, policy_t)
|
||||
q_t_det_policy = tf.reduce_min(
|
||||
(q_t_det_policy, twin_q_t_det_policy), axis=0)
|
||||
|
||||
# target q network evaluation
|
||||
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
|
||||
|
@ -165,9 +172,8 @@ def actor_critic_loss(policy, model, _, train_batch):
|
|||
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
if policy.config["twin_q"]:
|
||||
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
|
||||
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
|
||||
|
||||
q_tp1 -= tf.expand_dims(alpha * log_pis_t, 1)
|
||||
q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
|
||||
q_tp1 -= alpha * log_pis_tp1
|
||||
|
||||
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
|
||||
q_tp1_best_masked = (
|
||||
|
@ -182,23 +188,29 @@ def actor_critic_loss(policy, model, _, train_batch):
|
|||
|
||||
# compute the error (potentially clipped)
|
||||
if policy.config["twin_q"]:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
base_td_error = q_t_selected - q_t_selected_target
|
||||
twin_td_error = twin_q_t_selected - q_t_selected_target
|
||||
td_error = td_error + twin_td_error
|
||||
errors = 0.5 * (tf.square(td_error) + tf.square(twin_td_error))
|
||||
td_error = 0.5 * (tf.square(base_td_error) + tf.square(twin_td_error))
|
||||
else:
|
||||
td_error = q_t_selected - q_t_selected_target
|
||||
errors = 0.5 * tf.square(td_error)
|
||||
td_error = tf.square(q_t_selected - q_t_selected_target)
|
||||
|
||||
critic_loss = model.custom_loss(
|
||||
tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors), train_batch)
|
||||
actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)
|
||||
critic_loss = [
|
||||
tf.losses.mean_squared_error(
|
||||
labels=q_t_selected_target, predictions=q_t_selected, weights=0.5)
|
||||
]
|
||||
if policy.config["twin_q"]:
|
||||
critic_loss.append(
|
||||
tf.losses.mean_squared_error(
|
||||
labels=q_t_selected_target,
|
||||
predictions=twin_q_t_selected,
|
||||
weights=0.5))
|
||||
|
||||
target_entropy = (-np.prod(policy.action_space.shape)
|
||||
if policy.config["target_entropy"] == "auto" else
|
||||
policy.config["target_entropy"])
|
||||
alpha_loss = -tf.reduce_mean(
|
||||
log_alpha * tf.stop_gradient(log_pis_t + target_entropy))
|
||||
actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)
|
||||
|
||||
# save for stats function
|
||||
policy.q_t = q_t
|
||||
|
@ -209,7 +221,7 @@ def actor_critic_loss(policy, model, _, train_batch):
|
|||
|
||||
# in a custom apply op we handle the losses separately, but return them
|
||||
# combined in one loss for now
|
||||
return actor_loss + critic_loss + alpha_loss
|
||||
return actor_loss + tf.add_n(critic_loss) + alpha_loss
|
||||
|
||||
|
||||
def gradients(policy, optimizer, loss):
|
||||
|
@ -219,23 +231,49 @@ def gradients(policy, optimizer, loss):
|
|||
policy.actor_loss,
|
||||
var_list=policy.model.policy_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
critic_grads_and_vars = minimize_and_clip(
|
||||
optimizer,
|
||||
policy.critic_loss,
|
||||
var_list=policy.model.q_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
if policy.config["twin_q"]:
|
||||
q_variables = policy.model.q_variables()
|
||||
half_cutoff = len(q_variables) // 2
|
||||
critic_grads_and_vars = []
|
||||
critic_grads_and_vars += minimize_and_clip(
|
||||
optimizer,
|
||||
policy.critic_loss[0],
|
||||
var_list=q_variables[:half_cutoff],
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
critic_grads_and_vars += minimize_and_clip(
|
||||
optimizer,
|
||||
policy.critic_loss[1],
|
||||
var_list=q_variables[half_cutoff:],
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
else:
|
||||
critic_grads_and_vars = minimize_and_clip(
|
||||
optimizer,
|
||||
policy.critic_loss[0],
|
||||
var_list=policy.model.q_variables(),
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
alpha_grads_and_vars = minimize_and_clip(
|
||||
optimizer,
|
||||
policy.alpha_loss,
|
||||
var_list=[policy.model.log_alpha],
|
||||
clip_val=policy.config["grad_norm_clipping"])
|
||||
else:
|
||||
actor_grads_and_vars = optimizer.compute_gradients(
|
||||
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
|
||||
policy.actor_loss, var_list=policy.model.policy_variables())
|
||||
critic_grads_and_vars = optimizer.compute_gradients(
|
||||
policy.critic_loss, var_list=policy.model.q_variables())
|
||||
alpha_grads_and_vars = optimizer.compute_gradients(
|
||||
if policy.config["twin_q"]:
|
||||
q_variables = policy.model.q_variables()
|
||||
half_cutoff = len(q_variables) // 2
|
||||
base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
|
||||
critic_grads_and_vars = base_q_optimizer.compute_gradients(
|
||||
policy.critic_loss[0], var_list=q_variables[:half_cutoff]
|
||||
) + twin_q_optimizer.compute_gradients(
|
||||
policy.critic_loss[1], var_list=q_variables[half_cutoff:])
|
||||
else:
|
||||
critic_grads_and_vars = policy._critic_optimizer[
|
||||
0].compute_gradients(
|
||||
policy.critic_loss[0], var_list=policy.model.q_variables())
|
||||
alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
|
||||
policy.alpha_loss, var_list=[policy.model.log_alpha])
|
||||
|
||||
# save these for later use in build_apply_op
|
||||
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
|
||||
if g is not None]
|
||||
|
@ -249,6 +287,28 @@ def gradients(policy, optimizer, loss):
|
|||
return grads_and_vars
|
||||
|
||||
|
||||
def apply_gradients(policy, optimizer, grads_and_vars):
|
||||
actor_apply_ops = policy._actor_optimizer.apply_gradients(
|
||||
policy._actor_grads_and_vars)
|
||||
|
||||
cgrads = policy._critic_grads_and_vars
|
||||
half_cutoff = len(cgrads) // 2
|
||||
if policy.config["twin_q"]:
|
||||
critic_apply_ops = [
|
||||
policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
|
||||
policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:])
|
||||
]
|
||||
else:
|
||||
critic_apply_ops = [
|
||||
policy._critic_optimizer[0].apply_gradients(cgrads)
|
||||
]
|
||||
|
||||
alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
|
||||
policy._alpha_grads_and_vars,
|
||||
global_step=tf.train.get_or_create_global_step())
|
||||
return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)
|
||||
|
||||
|
||||
def stats(policy, train_batch):
|
||||
return {
|
||||
"td_error": tf.reduce_mean(policy.td_error),
|
||||
|
@ -281,8 +341,14 @@ class ActorCriticOptimizerMixin(object):
|
|||
# use separate optimizers for actor & critic
|
||||
self._actor_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["optimization"]["actor_learning_rate"])
|
||||
self._critic_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["optimization"]["critic_learning_rate"])
|
||||
self._critic_optimizer = [
|
||||
tf.train.AdamOptimizer(
|
||||
learning_rate=config["optimization"]["critic_learning_rate"])
|
||||
]
|
||||
if config["twin_q"]:
|
||||
self._critic_optimizer.append(
|
||||
tf.train.AdamOptimizer(learning_rate=config["optimization"][
|
||||
"critic_learning_rate"]))
|
||||
self._alpha_optimizer = tf.train.AdamOptimizer(
|
||||
learning_rate=config["optimization"]["entropy_learning_rate"])
|
||||
|
||||
|
@ -359,6 +425,7 @@ SACTFPolicy = build_tf_policy(
|
|||
loss_fn=actor_critic_loss,
|
||||
stats_fn=stats,
|
||||
gradients_fn=gradients,
|
||||
apply_gradients_fn=apply_gradients,
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
|
||||
mixins=[
|
||||
TargetNetworkMixin, ExplorationStateMixin, ActorCriticOptimizerMixin,
|
||||
|
|
|
@ -28,6 +28,7 @@ from ray.tune.trial import ExportFormat
|
|||
from ray.tune.resources import Resources
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.rllib.env.normalize_actions import NormalizeActionWrapper
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
@ -101,6 +102,8 @@ COMMON_CONFIG = {
|
|||
"env_config": {},
|
||||
# Environment name can also be passed via config.
|
||||
"env": None,
|
||||
# Unsquash actions to the upper and lower bounds of env's action space
|
||||
"normalize_actions": False,
|
||||
# Whether to clip rewards prior to experience postprocessing. Setting to
|
||||
# None means clip for Atari only.
|
||||
"clip_rewards": None,
|
||||
|
@ -504,6 +507,12 @@ class Trainer(Trainable):
|
|||
self._allow_unknown_subkeys)
|
||||
self.raw_user_config = config
|
||||
self.config = merged_config
|
||||
|
||||
if self.config["normalize_actions"]:
|
||||
inner = self.env_creator
|
||||
self.env_creator = (
|
||||
lambda env_config: NormalizeActionWrapper(inner(env_config)))
|
||||
|
||||
Trainer._validate_config(self.config)
|
||||
log_level = self.config.get("log_level")
|
||||
if log_level in ["WARN", "ERROR"]:
|
||||
|
|
21
rllib/env/normalize_actions.py
vendored
Normal file
21
rllib/env/normalize_actions.py
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NormalizeActionWrapper(gym.ActionWrapper):
|
||||
"""Rescale the action space of the environment."""
|
||||
|
||||
def action(self, action):
|
||||
if not isinstance(self.env.action_space, spaces.Box):
|
||||
return action
|
||||
|
||||
# rescale the action
|
||||
low, high = self.env.action_space.low, self.env.action_space.high
|
||||
scaled_action = low + (action + 1.0) * (high - low) / 2.0
|
||||
scaled_action = np.clip(scaled_action, low, high)
|
||||
|
||||
return scaled_action
|
||||
|
||||
def reverse_action(self, action):
|
||||
raise NotImplementedError
|
|
@ -70,10 +70,7 @@ class ReplayBuffer(object):
|
|||
|
||||
@DeveloperAPI
|
||||
def sample_idxes(self, batch_size):
|
||||
return [
|
||||
random.randint(0,
|
||||
len(self._storage) - 1) for _ in range(batch_size)
|
||||
]
|
||||
return np.random.randint(0, len(self._storage), batch_size)
|
||||
|
||||
@DeveloperAPI
|
||||
def sample_with_idxes(self, idxes):
|
||||
|
|
|
@ -64,13 +64,6 @@ class TestEagerSupport(unittest.TestCase):
|
|||
"timesteps_per_iteration": 100
|
||||
})
|
||||
|
||||
def testSAC(self):
|
||||
check_support("SAC", {
|
||||
"num_workers": 0,
|
||||
"learning_starts": 0,
|
||||
"timesteps_per_iteration": 100
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
|
37
rllib/tuned_examples/pendulum-sac.yaml
Normal file
37
rllib/tuned_examples/pendulum-sac.yaml
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Pendulum SAC can attain -150+ reward in 6-7k
|
||||
# Configurations are the similar to original softlearning/sac codebase
|
||||
pendulum_sac:
|
||||
env: Pendulum-v0
|
||||
run: SAC
|
||||
stop:
|
||||
episode_reward_mean: -150
|
||||
config:
|
||||
horizon: 200
|
||||
soft_horizon: False
|
||||
Q_model:
|
||||
hidden_activation: relu
|
||||
hidden_layer_sizes: [256, 256]
|
||||
policy_model:
|
||||
hidden_activation: relu
|
||||
hidden_layer_sizes: [256, 256]
|
||||
tau: 0.005
|
||||
target_entropy: auto
|
||||
no_done_at_end: True
|
||||
n_step: 1
|
||||
sample_batch_size: 1
|
||||
prioritized_replay: False
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
timesteps_per_iteration: 1000
|
||||
learning_starts: 256
|
||||
exploration_enabled: True
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
entropy_learning_rate: 0.0003
|
||||
num_workers: 0
|
||||
num_gpus: 0
|
||||
clip_actions: False
|
||||
normalize_actions: True
|
||||
evaluation_interval: 1
|
||||
metrics_smoothing_episodes: 5
|
|
@ -3,8 +3,10 @@ pendulum-sac:
|
|||
run: SAC
|
||||
stop:
|
||||
episode_reward_mean: -300 # note that evaluation perf is higher
|
||||
timesteps_total: 15000
|
||||
timesteps_total: 10000
|
||||
config:
|
||||
evaluation_interval: 1 # logged under evaluation/* metric keys
|
||||
soft_horizon: true
|
||||
metrics_smoothing_episodes: 10
|
||||
soft_horizon: True
|
||||
clip_actions: False
|
||||
normalize_actions: True
|
||||
metrics_smoothing_episodes: 5
|
||||
|
|
Loading…
Add table
Reference in a new issue