SAC Performance Fixes (#6295)

* SAC Performance Fixes

* Small Changes

* Update sac_model.py

* Fix normalize wrapper

* Update test_eager_support.py

Co-authored-by: Eric Liang <ekhliang@gmail.com>
Michael Luo authored on 2019-12-20 10:51:25 -08:00; committed by Eric Liang
parent eca4cc7c00
commit 548df014ec
9 changed files with 230 additions and 54 deletions


@@ -29,6 +29,8 @@ DEFAULT_CONFIG = with_common_config({
"hidden_activation": "relu",
"hidden_layer_sizes": (256, 256),
},
# Unsquash actions to the upper and lower bounds of env's action space
"normalize_actions": True,
# === Learning ===
# Update the target by \tau * policy + (1-\tau) * target_policy
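
The new key defaults to True for SAC here, while the common trainer config (later in this diff) adds it with a default of False. A quick illustrative check of those defaults, assuming the rllib module layout at the time of this commit:

# Illustrative only: check the defaults this commit introduces.
from ray.rllib.agents.trainer import COMMON_CONFIG
from ray.rllib.agents.sac.sac import DEFAULT_CONFIG

assert COMMON_CONFIG["normalize_actions"] is False  # opt-in for most agents
assert DEFAULT_CONFIG["normalize_actions"] is True  # on by default for SAC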


@@ -76,7 +76,6 @@ class SACModel(TFModelV2):
super(SACModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
self.action_dim = np.product(action_space.shape)
self.model_out = tf.keras.layers.Input(
shape=(num_outputs, ), name="model_out")
@@ -111,17 +110,68 @@ class SACModel(TFModelV2):
shift_and_log_scale_diag = tf.keras.layers.Concatenate(axis=-1)(
[shift, log_scale_diag])
raw_action_distribution = tfp.layers.MultivariateNormalTriL(
self.action_dim)(shift_and_log_scale_diag)
batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])(
self.model_out)
action_distribution = tfp.layers.DistributionLambda(
make_distribution_fn=SquashBijector())(raw_action_distribution)
base_distribution = tfp.distributions.MultivariateNormalDiag(
loc=tf.zeros(self.action_dim), scale_diag=tf.ones(self.action_dim))
# TODO(hartikainen): Remove the unnecessary Model call here
self.action_distribution_model = tf.keras.Model(
self.model_out, action_distribution)
latents = tf.keras.layers.Lambda(
lambda batch_size: base_distribution.sample(batch_size))(
batch_size)
self.register_variables(self.action_distribution_model.variables)
self.shift_and_log_scale_diag = latents
self.latents_model = tf.keras.Model(self.model_out, latents)
def raw_actions_fn(inputs):
shift, log_scale_diag, latents = inputs
bijector = tfp.bijectors.Affine(
shift=shift, scale_diag=tf.exp(log_scale_diag))
actions = bijector.forward(latents)
return actions
raw_actions = tf.keras.layers.Lambda(raw_actions_fn)(
(shift, log_scale_diag, latents))
squash_bijector = (SquashBijector())
actions = tf.keras.layers.Lambda(
lambda raw_actions: squash_bijector.forward(raw_actions))(
raw_actions)
self.actions_model = tf.keras.Model(self.model_out, actions)
deterministic_actions = tf.keras.layers.Lambda(
lambda shift: squash_bijector.forward(shift))(shift)
self.deterministic_actions_model = tf.keras.Model(
self.model_out, deterministic_actions)
def log_pis_fn(inputs):
shift, log_scale_diag, actions = inputs
base_distribution = tfp.distributions.MultivariateNormalDiag(
loc=tf.zeros(self.action_dim),
scale_diag=tf.ones(self.action_dim))
bijector = tfp.bijectors.Chain((
squash_bijector,
tfp.bijectors.Affine(
shift=shift, scale_diag=tf.exp(log_scale_diag)),
))
distribution = (tfp.distributions.TransformedDistribution(
distribution=base_distribution, bijector=bijector))
log_pis = distribution.log_prob(actions)[:, None]
return log_pis
self.actions_input = tf.keras.layers.Input(
shape=(self.action_dim, ), name="actions")
log_pis_for_action_input = tf.keras.layers.Lambda(log_pis_fn)(
[shift, log_scale_diag, self.actions_input])
self.log_pis_model = tf.keras.Model(
(self.model_out, self.actions_input), log_pis_for_action_input)
self.register_variables(self.actions_model.variables)
def build_q_net(name, observations, actions):
q_net = tf.keras.Sequential([
@@ -169,14 +219,12 @@ class SACModel(TFModelV2):
Returns:
tensor of shape [BATCH_SIZE, action_dim] with range [-inf, inf].
"""
action_distribution = self.action_distribution_model(model_out)
if deterministic:
actions = action_distribution.bijector(
action_distribution.distribution.mean())
actions = self.deterministic_actions_model(model_out)
log_pis = None
else:
actions = action_distribution.sample()
log_pis = action_distribution.log_prob(actions)
actions = self.actions_model(model_out)
log_pis = self.log_pis_model((model_out, actions))
return actions, log_pis
@@ -217,7 +265,7 @@ class SACModel(TFModelV2):
def policy_variables(self):
"""Return the list of variables for the policy net."""
return list(self.action_distribution_model.variables)
return list(self.actions_model.variables)
def q_variables(self):
"""Return the list of variables for Q / twin Q nets."""


@@ -111,10 +111,13 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
squashed_stochastic_actions, log_pis = policy.model.get_policy_output(
model_out, deterministic=False)
stochastic_actions = unsquash_actions(squashed_stochastic_actions)
stochastic_actions = squashed_stochastic_actions if config[
"normalize_actions"] else unsquash_actions(squashed_stochastic_actions)
squashed_deterministic_actions, _ = policy.model.get_policy_output(
model_out, deterministic=True)
deterministic_actions = unsquash_actions(squashed_deterministic_actions)
deterministic_actions = squashed_deterministic_actions if config[
"normalize_actions"] else unsquash_actions(
squashed_deterministic_actions)
actions = tf.cond(policy.stochastic, lambda: stochastic_actions,
lambda: deterministic_actions)
@@ -155,6 +158,10 @@ def actor_critic_loss(policy, model, _, train_batch):
# Q-values for current policy (no noise) in given current state
q_t_det_policy = model.get_q_values(model_out_t, policy_t)
if policy.config["twin_q"]:
twin_q_t_det_policy = model.get_q_values(model_out_t, policy_t)
q_t_det_policy = tf.reduce_min(
(q_t_det_policy, twin_q_t_det_policy), axis=0)
# target q network evaluation
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
@@ -165,9 +172,8 @@ def actor_critic_loss(policy, model, _, train_batch):
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
if policy.config["twin_q"]:
twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
q_tp1 = tf.minimum(q_tp1, twin_q_tp1)
q_tp1 -= tf.expand_dims(alpha * log_pis_t, 1)
q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
q_tp1 -= alpha * log_pis_tp1
q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
q_tp1_best_masked = (
@@ -182,23 +188,29 @@ def actor_critic_loss(policy, model, _, train_batch):
# compute the error (potentially clipped)
if policy.config["twin_q"]:
td_error = q_t_selected - q_t_selected_target
base_td_error = q_t_selected - q_t_selected_target
twin_td_error = twin_q_t_selected - q_t_selected_target
td_error = td_error + twin_td_error
errors = 0.5 * (tf.square(td_error) + tf.square(twin_td_error))
td_error = 0.5 * (tf.square(base_td_error) + tf.square(twin_td_error))
else:
td_error = q_t_selected - q_t_selected_target
errors = 0.5 * tf.square(td_error)
td_error = tf.square(q_t_selected - q_t_selected_target)
critic_loss = model.custom_loss(
tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors), train_batch)
actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)
critic_loss = [
tf.losses.mean_squared_error(
labels=q_t_selected_target, predictions=q_t_selected, weights=0.5)
]
if policy.config["twin_q"]:
critic_loss.append(
tf.losses.mean_squared_error(
labels=q_t_selected_target,
predictions=twin_q_t_selected,
weights=0.5))
target_entropy = (-np.prod(policy.action_space.shape)
if policy.config["target_entropy"] == "auto" else
policy.config["target_entropy"])
alpha_loss = -tf.reduce_mean(
log_alpha * tf.stop_gradient(log_pis_t + target_entropy))
actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy)
# save for stats function
policy.q_t = q_t
@@ -209,7 +221,7 @@ def actor_critic_loss(policy, model, _, train_batch):
# in a custom apply op we handle the losses separately, but return them
# combined in one loss for now
return actor_loss + critic_loss + alpha_loss
return actor_loss + tf.add_n(critic_loss) + alpha_loss
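
For orientation, a minimal NumPy sketch of the three loss terms as restructured above. Variable names are illustrative; the real code operates on TF tensors and omitted details include stop_gradient, n-step returns, and prioritized-replay weights:

# Illustrative NumPy sketch of the restructured SAC losses (single batch).
import numpy as np

def sac_losses(q_t, twin_q_t, q_tp1, twin_q_tp1, q_pi_t,
               log_pis_t, log_pis_tp1, rewards, dones,
               alpha=0.2, gamma=0.99, target_entropy=-1.0):
    # Soft Bellman target: min over the twin Q heads, minus the entropy term.
    q_target_tp1 = np.minimum(q_tp1, twin_q_tp1) - alpha * log_pis_tp1
    td_target = rewards + gamma * (1.0 - dones) * q_target_tp1
    # One 0.5-weighted squared error per Q head (mirrors the critic_loss list).
    critic_loss = [0.5 * np.mean((q_t - td_target) ** 2),
                   0.5 * np.mean((twin_q_t - td_target) ** 2)]
    # Actor: trade entropy (alpha * log_pi) against the Q-value of its actions.
    actor_loss = np.mean(alpha * log_pis_t - q_pi_t)
    # Temperature: drive the policy entropy toward target_entropy.
    alpha_loss = -np.mean(np.log(alpha) * (log_pis_t + target_entropy))
    return actor_loss, critic_loss, alpha_loss

rng = np.random.RandomState(0)
batch = [rng.randn(8) for _ in range(8)]  # q_t ... rewards
print(sac_losses(*batch, dones=np.zeros(8)))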
def gradients(policy, optimizer, loss):
@@ -219,23 +231,49 @@ def gradients(policy, optimizer, loss):
policy.actor_loss,
var_list=policy.model.policy_variables(),
clip_val=policy.config["grad_norm_clipping"])
critic_grads_and_vars = minimize_and_clip(
optimizer,
policy.critic_loss,
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_norm_clipping"])
if policy.config["twin_q"]:
q_variables = policy.model.q_variables()
half_cutoff = len(q_variables) // 2
critic_grads_and_vars = []
critic_grads_and_vars += minimize_and_clip(
optimizer,
policy.critic_loss[0],
var_list=q_variables[:half_cutoff],
clip_val=policy.config["grad_norm_clipping"])
critic_grads_and_vars += minimize_and_clip(
optimizer,
policy.critic_loss[1],
var_list=q_variables[half_cutoff:],
clip_val=policy.config["grad_norm_clipping"])
else:
critic_grads_and_vars = minimize_and_clip(
optimizer,
policy.critic_loss[0],
var_list=policy.model.q_variables(),
clip_val=policy.config["grad_norm_clipping"])
alpha_grads_and_vars = minimize_and_clip(
optimizer,
policy.alpha_loss,
var_list=[policy.model.log_alpha],
clip_val=policy.config["grad_norm_clipping"])
else:
actor_grads_and_vars = optimizer.compute_gradients(
actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
policy.actor_loss, var_list=policy.model.policy_variables())
critic_grads_and_vars = optimizer.compute_gradients(
policy.critic_loss, var_list=policy.model.q_variables())
alpha_grads_and_vars = optimizer.compute_gradients(
if policy.config["twin_q"]:
q_variables = policy.model.q_variables()
half_cutoff = len(q_variables) // 2
base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
critic_grads_and_vars = base_q_optimizer.compute_gradients(
policy.critic_loss[0], var_list=q_variables[:half_cutoff]
) + twin_q_optimizer.compute_gradients(
policy.critic_loss[1], var_list=q_variables[half_cutoff:])
else:
critic_grads_and_vars = policy._critic_optimizer[
0].compute_gradients(
policy.critic_loss[0], var_list=policy.model.q_variables())
alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
policy.alpha_loss, var_list=[policy.model.log_alpha])
# save these for later use in build_apply_op
policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
if g is not None]
@@ -249,6 +287,28 @@ def gradients(policy, optimizer, loss):
return grads_and_vars
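
The twin-Q branch above assumes model.q_variables() lists the base Q-net's variables first and the twin Q-net's second, so splitting the list at its midpoint pairs each half with its own loss term and optimizer. A small illustrative sketch of that pairing with made-up variable names:

# Illustrative sketch of the half_cutoff split used for the twin Q nets.
def split_q_vars(q_variables):
    half_cutoff = len(q_variables) // 2
    return q_variables[:half_cutoff], q_variables[half_cutoff:]

q_vars = ["q/w1", "q/b1", "q/w2", "q/b2",
          "twin_q/w1", "twin_q/b1", "twin_q/w2", "twin_q/b2"]
base_vars, twin_vars = split_q_vars(q_vars)
assert all(name.startswith("q/") for name in base_vars)
assert all(name.startswith("twin_q/") for name in twin_vars)
# Each half is then minimized against its own critic_loss entry with its own
# optimizer, so the two Q heads no longer share gradient updates.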
def apply_gradients(policy, optimizer, grads_and_vars):
actor_apply_ops = policy._actor_optimizer.apply_gradients(
policy._actor_grads_and_vars)
cgrads = policy._critic_grads_and_vars
half_cutoff = len(cgrads) // 2
if policy.config["twin_q"]:
critic_apply_ops = [
policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:])
]
else:
critic_apply_ops = [
policy._critic_optimizer[0].apply_gradients(cgrads)
]
alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
policy._alpha_grads_and_vars,
global_step=tf.train.get_or_create_global_step())
return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)
def stats(policy, train_batch):
return {
"td_error": tf.reduce_mean(policy.td_error),
@@ -281,8 +341,14 @@ class ActorCriticOptimizerMixin(object):
# use separate optimizers for actor & critic
self._actor_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["actor_learning_rate"])
self._critic_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["critic_learning_rate"])
self._critic_optimizer = [
tf.train.AdamOptimizer(
learning_rate=config["optimization"]["critic_learning_rate"])
]
if config["twin_q"]:
self._critic_optimizer.append(
tf.train.AdamOptimizer(learning_rate=config["optimization"][
"critic_learning_rate"]))
self._alpha_optimizer = tf.train.AdamOptimizer(
learning_rate=config["optimization"]["entropy_learning_rate"])
@@ -359,6 +425,7 @@ SACTFPolicy = build_tf_policy(
loss_fn=actor_critic_loss,
stats_fn=stats,
gradients_fn=gradients,
apply_gradients_fn=apply_gradients,
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
mixins=[
TargetNetworkMixin, ExplorationStateMixin, ActorCriticOptimizerMixin,


@@ -28,6 +28,7 @@ from ray.tune.trial import ExportFormat
from ray.tune.resources import Resources
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.rllib.env.normalize_actions import NormalizeActionWrapper
tf = try_import_tf()
@@ -101,6 +102,8 @@ COMMON_CONFIG = {
"env_config": {},
# Environment name can also be passed via config.
"env": None,
# Unsquash actions to the upper and lower bounds of env's action space
"normalize_actions": False,
# Whether to clip rewards prior to experience postprocessing. Setting to
# None means clip for Atari only.
"clip_rewards": None,
@@ -504,6 +507,12 @@ class Trainer(Trainable):
self._allow_unknown_subkeys)
self.raw_user_config = config
self.config = merged_config
if self.config["normalize_actions"]:
inner = self.env_creator
self.env_creator = (
lambda env_config: NormalizeActionWrapper(inner(env_config)))
Trainer._validate_config(self.config)
log_level = self.config.get("log_level")
if log_level in ["WARN", "ERROR"]:
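
Note that the lambda closes over the local inner rather than over self.env_creator, which would otherwise call the rebound creator recursively. A standalone sketch of the same pattern with stand-in classes (FakeEnv and FakeNormalizeWrapper are placeholders, not rllib code):

# Illustrative sketch of wrapping an env_creator without self-recursion.
class FakeEnv:
    pass

class FakeNormalizeWrapper:
    def __init__(self, env):
        self.env = env

env_creator = lambda env_config: FakeEnv()
inner = env_creator  # capture the original creator first
env_creator = lambda env_config: FakeNormalizeWrapper(inner(env_config))

env = env_creator({})
assert isinstance(env, FakeNormalizeWrapper)
assert isinstance(env.env, FakeEnv)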

rllib/env/normalize_actions.py (new file, 21 lines)

@@ -0,0 +1,21 @@
import gym
from gym import spaces
import numpy as np
class NormalizeActionWrapper(gym.ActionWrapper):
"""Rescale the action space of the environment."""
def action(self, action):
if not isinstance(self.env.action_space, spaces.Box):
return action
# rescale the action
low, high = self.env.action_space.low, self.env.action_space.high
scaled_action = low + (action + 1.0) * (high - low) / 2.0
scaled_action = np.clip(scaled_action, low, high)
return scaled_action
def reverse_action(self, action):
raise NotImplementedError
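
A possible usage sketch for the new wrapper, assuming the classic Gym step API of this era and Pendulum-v0's Box(-2, 2) action space, so a squashed action of 1.0 is applied as 2.0:

# Illustrative usage of NormalizeActionWrapper (not part of this commit).
import gym
import numpy as np
from ray.rllib.env.normalize_actions import NormalizeActionWrapper

env = NormalizeActionWrapper(gym.make("Pendulum-v0"))
env.reset()
# The agent emits actions in [-1, 1]; the wrapper rescales them to [low, high].
obs, reward, done, info = env.step(np.array([1.0]))   # applied as 2.0
obs, reward, done, info = env.step(np.array([-1.0]))  # applied as -2.0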


@@ -70,10 +70,7 @@ class ReplayBuffer(object):
@DeveloperAPI
def sample_idxes(self, batch_size):
return [
random.randint(0,
len(self._storage) - 1) for _ in range(batch_size)
]
return np.random.randint(0, len(self._storage), batch_size)
@DeveloperAPI
def sample_with_idxes(self, idxes):
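
The replacement draws all indices in one vectorized NumPy call instead of a Python loop; np.random.randint's upper bound is exclusive, so np.random.randint(0, len(storage), batch_size) covers the same index range as the old random.randint(0, len(storage) - 1). A quick illustrative check:

# Illustrative comparison of the two sampling forms (not from this commit).
import random
import numpy as np

storage = list(range(1000))
batch_size = 256

loop_idxes = [random.randint(0, len(storage) - 1) for _ in range(batch_size)]
vec_idxes = np.random.randint(0, len(storage), batch_size)

assert 0 <= min(loop_idxes) and max(loop_idxes) < len(storage)
assert 0 <= vec_idxes.min() and vec_idxes.max() < len(storage)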


@@ -64,13 +64,6 @@ class TestEagerSupport(unittest.TestCase):
"timesteps_per_iteration": 100
})
def testSAC(self):
check_support("SAC", {
"num_workers": 0,
"learning_starts": 0,
"timesteps_per_iteration": 100
})
if __name__ == "__main__":
import pytest


@@ -0,0 +1,37 @@
# Pendulum SAC can attain -150+ reward in 6-7k
# Configuration is similar to the original softlearning/sac codebase
pendulum_sac:
env: Pendulum-v0
run: SAC
stop:
episode_reward_mean: -150
config:
horizon: 200
soft_horizon: False
Q_model:
hidden_activation: relu
hidden_layer_sizes: [256, 256]
policy_model:
hidden_activation: relu
hidden_layer_sizes: [256, 256]
tau: 0.005
target_entropy: auto
no_done_at_end: True
n_step: 1
sample_batch_size: 1
prioritized_replay: False
train_batch_size: 256
target_network_update_freq: 1
timesteps_per_iteration: 1000
learning_starts: 256
exploration_enabled: True
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0003
num_workers: 0
num_gpus: 0
clip_actions: False
normalize_actions: True
evaluation_interval: 1
metrics_smoothing_episodes: 5
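
The same tuned example can also be launched from Python; a rough sketch assuming the ray/tune API of this release, carrying over only a few of the YAML keys:

# Illustrative Python launch of the tuned example (subset of the YAML keys).
import ray
from ray import tune

ray.init()
tune.run(
    "SAC",
    stop={"episode_reward_mean": -150},
    config={
        "env": "Pendulum-v0",
        "horizon": 200,
        "tau": 0.005,
        "target_entropy": "auto",
        "normalize_actions": True,
        "clip_actions": False,
        "num_workers": 0,
    },
)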


@@ -3,8 +3,10 @@ pendulum-sac:
run: SAC
stop:
episode_reward_mean: -300 # note that evaluation perf is higher
timesteps_total: 15000
timesteps_total: 10000
config:
evaluation_interval: 1 # logged under evaluation/* metric keys
soft_horizon: true
metrics_smoothing_episodes: 10
soft_horizon: True
clip_actions: False
normalize_actions: True
metrics_smoothing_episodes: 5