From 548df014ec022d3fe2b49131026ffcfc2ece5668 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Fri, 20 Dec 2019 10:51:25 -0800 Subject: [PATCH] SAC Performance Fixes (#6295) * SAC Performance Fixes * Small Changes * Update sac_model.py * fix normalize wrapper * Update test_eager_support.py Co-authored-by: Eric Liang --- rllib/agents/sac/sac.py | 2 + rllib/agents/sac/sac_model.py | 78 +++++++++--- rllib/agents/sac/sac_policy.py | 117 ++++++++++++++---- rllib/agents/trainer.py | 9 ++ rllib/env/normalize_actions.py | 21 ++++ rllib/optimizers/replay_buffer.py | 5 +- rllib/tests/test_eager_support.py | 7 -- rllib/tuned_examples/pendulum-sac.yaml | 37 ++++++ .../regression_tests/pendulum-sac.yaml | 8 +- 9 files changed, 230 insertions(+), 54 deletions(-) create mode 100644 rllib/env/normalize_actions.py create mode 100644 rllib/tuned_examples/pendulum-sac.yaml diff --git a/rllib/agents/sac/sac.py b/rllib/agents/sac/sac.py index 41e2b94b6..2573b4c26 100644 --- a/rllib/agents/sac/sac.py +++ b/rllib/agents/sac/sac.py @@ -29,6 +29,8 @@ DEFAULT_CONFIG = with_common_config({ "hidden_activation": "relu", "hidden_layer_sizes": (256, 256), }, + # Unsquash actions to the upper and lower bounds of env's action space + "normalize_actions": True, # === Learning === # Update the target by \tau * policy + (1-\tau) * target_policy diff --git a/rllib/agents/sac/sac_model.py b/rllib/agents/sac/sac_model.py index c10d54b0c..d72a6063f 100644 --- a/rllib/agents/sac/sac_model.py +++ b/rllib/agents/sac/sac_model.py @@ -76,7 +76,6 @@ class SACModel(TFModelV2): super(SACModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) - self.action_dim = np.product(action_space.shape) self.model_out = tf.keras.layers.Input( shape=(num_outputs, ), name="model_out") @@ -111,17 +110,68 @@ class SACModel(TFModelV2): shift_and_log_scale_diag = tf.keras.layers.Concatenate(axis=-1)( [shift, log_scale_diag]) - raw_action_distribution = tfp.layers.MultivariateNormalTriL( - self.action_dim)(shift_and_log_scale_diag) + batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])( + self.model_out) - action_distribution = tfp.layers.DistributionLambda( - make_distribution_fn=SquashBijector())(raw_action_distribution) + base_distribution = tfp.distributions.MultivariateNormalDiag( + loc=tf.zeros(self.action_dim), scale_diag=tf.ones(self.action_dim)) - # TODO(hartikainen): Remove the unnecessary Model call here - self.action_distribution_model = tf.keras.Model( - self.model_out, action_distribution) + latents = tf.keras.layers.Lambda( + lambda batch_size: base_distribution.sample(batch_size))( + batch_size) - self.register_variables(self.action_distribution_model.variables) + self.shift_and_log_scale_diag = latents + self.latents_model = tf.keras.Model(self.model_out, latents) + + def raw_actions_fn(inputs): + shift, log_scale_diag, latents = inputs + bijector = tfp.bijectors.Affine( + shift=shift, scale_diag=tf.exp(log_scale_diag)) + actions = bijector.forward(latents) + return actions + + raw_actions = tf.keras.layers.Lambda(raw_actions_fn)( + (shift, log_scale_diag, latents)) + + squash_bijector = (SquashBijector()) + + actions = tf.keras.layers.Lambda( + lambda raw_actions: squash_bijector.forward(raw_actions))( + raw_actions) + self.actions_model = tf.keras.Model(self.model_out, actions) + + deterministic_actions = tf.keras.layers.Lambda( + lambda shift: squash_bijector.forward(shift))(shift) + + self.deterministic_actions_model = tf.keras.Model( + self.model_out, deterministic_actions) + + def 
log_pis_fn(inputs): + shift, log_scale_diag, actions = inputs + base_distribution = tfp.distributions.MultivariateNormalDiag( + loc=tf.zeros(self.action_dim), + scale_diag=tf.ones(self.action_dim)) + bijector = tfp.bijectors.Chain(( + squash_bijector, + tfp.bijectors.Affine( + shift=shift, scale_diag=tf.exp(log_scale_diag)), + )) + distribution = (tfp.distributions.TransformedDistribution( + distribution=base_distribution, bijector=bijector)) + + log_pis = distribution.log_prob(actions)[:, None] + return log_pis + + self.actions_input = tf.keras.layers.Input( + shape=(self.action_dim, ), name="actions") + + log_pis_for_action_input = tf.keras.layers.Lambda(log_pis_fn)( + [shift, log_scale_diag, self.actions_input]) + + self.log_pis_model = tf.keras.Model( + (self.model_out, self.actions_input), log_pis_for_action_input) + + self.register_variables(self.actions_model.variables) def build_q_net(name, observations, actions): q_net = tf.keras.Sequential([ @@ -169,14 +219,12 @@ class SACModel(TFModelV2): Returns: tensor of shape [BATCH_SIZE, action_dim] with range [-inf, inf]. """ - action_distribution = self.action_distribution_model(model_out) if deterministic: - actions = action_distribution.bijector( - action_distribution.distribution.mean()) + actions = self.deterministic_actions_model(model_out) log_pis = None else: - actions = action_distribution.sample() - log_pis = action_distribution.log_prob(actions) + actions = self.actions_model(model_out) + log_pis = self.log_pis_model((model_out, actions)) return actions, log_pis @@ -217,7 +265,7 @@ class SACModel(TFModelV2): def policy_variables(self): """Return the list of variables for the policy net.""" - return list(self.action_distribution_model.variables) + return list(self.actions_model.variables) def q_variables(self): """Return the list of variables for Q / twin Q nets.""" diff --git a/rllib/agents/sac/sac_policy.py b/rllib/agents/sac/sac_policy.py index f1ac1224e..da5cc14d1 100644 --- a/rllib/agents/sac/sac_policy.py +++ b/rllib/agents/sac/sac_policy.py @@ -111,10 +111,13 @@ def build_action_output(policy, model, input_dict, obs_space, action_space, squashed_stochastic_actions, log_pis = policy.model.get_policy_output( model_out, deterministic=False) - stochastic_actions = unsquash_actions(squashed_stochastic_actions) + stochastic_actions = squashed_stochastic_actions if config[ + "normalize_actions"] else unsquash_actions(squashed_stochastic_actions) squashed_deterministic_actions, _ = policy.model.get_policy_output( model_out, deterministic=True) - deterministic_actions = unsquash_actions(squashed_deterministic_actions) + deterministic_actions = squashed_deterministic_actions if config[ + "normalize_actions"] else unsquash_actions( + squashed_deterministic_actions) actions = tf.cond(policy.stochastic, lambda: stochastic_actions, lambda: deterministic_actions) @@ -155,6 +158,10 @@ def actor_critic_loss(policy, model, _, train_batch): # Q-values for current policy (no noise) in given current state q_t_det_policy = model.get_q_values(model_out_t, policy_t) + if policy.config["twin_q"]: + twin_q_t_det_policy = model.get_q_values(model_out_t, policy_t) + q_t_det_policy = tf.reduce_min( + (q_t_det_policy, twin_q_t_det_policy), axis=0) # target q network evaluation q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) @@ -165,9 +172,8 @@ def actor_critic_loss(policy, model, _, train_batch): q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) if policy.config["twin_q"]: twin_q_t_selected = tf.squeeze(twin_q_t, 
axis=len(q_t.shape) - 1) - q_tp1 = tf.minimum(q_tp1, twin_q_tp1) - - q_tp1 -= tf.expand_dims(alpha * log_pis_t, 1) + q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0) + q_tp1 -= alpha * log_pis_tp1 q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = ( @@ -182,23 +188,29 @@ def actor_critic_loss(policy, model, _, train_batch): # compute the error (potentially clipped) if policy.config["twin_q"]: - td_error = q_t_selected - q_t_selected_target + base_td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target - td_error = td_error + twin_td_error - errors = 0.5 * (tf.square(td_error) + tf.square(twin_td_error)) + td_error = 0.5 * (tf.square(base_td_error) + tf.square(twin_td_error)) else: - td_error = q_t_selected - q_t_selected_target - errors = 0.5 * tf.square(td_error) + td_error = tf.square(q_t_selected - q_t_selected_target) - critic_loss = model.custom_loss( - tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors), train_batch) - actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy) + critic_loss = [ + tf.losses.mean_squared_error( + labels=q_t_selected_target, predictions=q_t_selected, weights=0.5) + ] + if policy.config["twin_q"]: + critic_loss.append( + tf.losses.mean_squared_error( + labels=q_t_selected_target, + predictions=twin_q_t_selected, + weights=0.5)) target_entropy = (-np.prod(policy.action_space.shape) if policy.config["target_entropy"] == "auto" else policy.config["target_entropy"]) alpha_loss = -tf.reduce_mean( log_alpha * tf.stop_gradient(log_pis_t + target_entropy)) + actor_loss = tf.reduce_mean(alpha * log_pis_t - q_t_det_policy) # save for stats function policy.q_t = q_t @@ -209,7 +221,7 @@ def actor_critic_loss(policy, model, _, train_batch): # in a custom apply op we handle the losses separately, but return them # combined in one loss for now - return actor_loss + critic_loss + alpha_loss + return actor_loss + tf.add_n(critic_loss) + alpha_loss def gradients(policy, optimizer, loss): @@ -219,23 +231,49 @@ def gradients(policy, optimizer, loss): policy.actor_loss, var_list=policy.model.policy_variables(), clip_val=policy.config["grad_norm_clipping"]) - critic_grads_and_vars = minimize_and_clip( - optimizer, - policy.critic_loss, - var_list=policy.model.q_variables(), - clip_val=policy.config["grad_norm_clipping"]) + if policy.config["twin_q"]: + q_variables = policy.model.q_variables() + half_cutoff = len(q_variables) // 2 + critic_grads_and_vars = [] + critic_grads_and_vars += minimize_and_clip( + optimizer, + policy.critic_loss[0], + var_list=q_variables[:half_cutoff], + clip_val=policy.config["grad_norm_clipping"]) + critic_grads_and_vars += minimize_and_clip( + optimizer, + policy.critic_loss[1], + var_list=q_variables[half_cutoff:], + clip_val=policy.config["grad_norm_clipping"]) + else: + critic_grads_and_vars = minimize_and_clip( + optimizer, + policy.critic_loss[0], + var_list=policy.model.q_variables(), + clip_val=policy.config["grad_norm_clipping"]) alpha_grads_and_vars = minimize_and_clip( optimizer, policy.alpha_loss, var_list=[policy.model.log_alpha], clip_val=policy.config["grad_norm_clipping"]) else: - actor_grads_and_vars = optimizer.compute_gradients( + actor_grads_and_vars = policy._actor_optimizer.compute_gradients( policy.actor_loss, var_list=policy.model.policy_variables()) - critic_grads_and_vars = optimizer.compute_gradients( - policy.critic_loss, var_list=policy.model.q_variables()) - alpha_grads_and_vars = optimizer.compute_gradients( + if 
policy.config["twin_q"]: + q_variables = policy.model.q_variables() + half_cutoff = len(q_variables) // 2 + base_q_optimizer, twin_q_optimizer = policy._critic_optimizer + critic_grads_and_vars = base_q_optimizer.compute_gradients( + policy.critic_loss[0], var_list=q_variables[:half_cutoff] + ) + twin_q_optimizer.compute_gradients( + policy.critic_loss[1], var_list=q_variables[half_cutoff:]) + else: + critic_grads_and_vars = policy._critic_optimizer[ + 0].compute_gradients( + policy.critic_loss[0], var_list=policy.model.q_variables()) + alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients( policy.alpha_loss, var_list=[policy.model.log_alpha]) + # save these for later use in build_apply_op policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars if g is not None] @@ -249,6 +287,28 @@ def gradients(policy, optimizer, loss): return grads_and_vars +def apply_gradients(policy, optimizer, grads_and_vars): + actor_apply_ops = policy._actor_optimizer.apply_gradients( + policy._actor_grads_and_vars) + + cgrads = policy._critic_grads_and_vars + half_cutoff = len(cgrads) // 2 + if policy.config["twin_q"]: + critic_apply_ops = [ + policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]), + policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:]) + ] + else: + critic_apply_ops = [ + policy._critic_optimizer[0].apply_gradients(cgrads) + ] + + alpha_apply_ops = policy._alpha_optimizer.apply_gradients( + policy._alpha_grads_and_vars, + global_step=tf.train.get_or_create_global_step()) + return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops) + + def stats(policy, train_batch): return { "td_error": tf.reduce_mean(policy.td_error), @@ -281,8 +341,14 @@ class ActorCriticOptimizerMixin(object): # use separate optimizers for actor & critic self._actor_optimizer = tf.train.AdamOptimizer( learning_rate=config["optimization"]["actor_learning_rate"]) - self._critic_optimizer = tf.train.AdamOptimizer( - learning_rate=config["optimization"]["critic_learning_rate"]) + self._critic_optimizer = [ + tf.train.AdamOptimizer( + learning_rate=config["optimization"]["critic_learning_rate"]) + ] + if config["twin_q"]: + self._critic_optimizer.append( + tf.train.AdamOptimizer(learning_rate=config["optimization"][ + "critic_learning_rate"])) self._alpha_optimizer = tf.train.AdamOptimizer( learning_rate=config["optimization"]["entropy_learning_rate"]) @@ -359,6 +425,7 @@ SACTFPolicy = build_tf_policy( loss_fn=actor_critic_loss, stats_fn=stats, gradients_fn=gradients, + apply_gradients_fn=apply_gradients, extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, mixins=[ TargetNetworkMixin, ExplorationStateMixin, ActorCriticOptimizerMixin, diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e6c231135..06ab3d35a 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -28,6 +28,7 @@ from ray.tune.trial import ExportFormat from ray.tune.resources import Resources from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR +from ray.rllib.env.normalize_actions import NormalizeActionWrapper tf = try_import_tf() @@ -101,6 +102,8 @@ COMMON_CONFIG = { "env_config": {}, # Environment name can also be passed via config. "env": None, + # Unsquash actions to the upper and lower bounds of env's action space + "normalize_actions": False, # Whether to clip rewards prior to experience postprocessing. Setting to # None means clip for Atari only. 
"clip_rewards": None, @@ -504,6 +507,12 @@ class Trainer(Trainable): self._allow_unknown_subkeys) self.raw_user_config = config self.config = merged_config + + if self.config["normalize_actions"]: + inner = self.env_creator + self.env_creator = ( + lambda env_config: NormalizeActionWrapper(inner(env_config))) + Trainer._validate_config(self.config) log_level = self.config.get("log_level") if log_level in ["WARN", "ERROR"]: diff --git a/rllib/env/normalize_actions.py b/rllib/env/normalize_actions.py new file mode 100644 index 000000000..18110aa82 --- /dev/null +++ b/rllib/env/normalize_actions.py @@ -0,0 +1,21 @@ +import gym +from gym import spaces +import numpy as np + + +class NormalizeActionWrapper(gym.ActionWrapper): + """Rescale the action space of the environment.""" + + def action(self, action): + if not isinstance(self.env.action_space, spaces.Box): + return action + + # rescale the action + low, high = self.env.action_space.low, self.env.action_space.high + scaled_action = low + (action + 1.0) * (high - low) / 2.0 + scaled_action = np.clip(scaled_action, low, high) + + return scaled_action + + def reverse_action(self, action): + raise NotImplementedError diff --git a/rllib/optimizers/replay_buffer.py b/rllib/optimizers/replay_buffer.py index 1012a5b76..90e1f87e3 100644 --- a/rllib/optimizers/replay_buffer.py +++ b/rllib/optimizers/replay_buffer.py @@ -70,10 +70,7 @@ class ReplayBuffer(object): @DeveloperAPI def sample_idxes(self, batch_size): - return [ - random.randint(0, - len(self._storage) - 1) for _ in range(batch_size) - ] + return np.random.randint(0, len(self._storage), batch_size) @DeveloperAPI def sample_with_idxes(self, idxes): diff --git a/rllib/tests/test_eager_support.py b/rllib/tests/test_eager_support.py index a00e56add..143d6970e 100644 --- a/rllib/tests/test_eager_support.py +++ b/rllib/tests/test_eager_support.py @@ -64,13 +64,6 @@ class TestEagerSupport(unittest.TestCase): "timesteps_per_iteration": 100 }) - def testSAC(self): - check_support("SAC", { - "num_workers": 0, - "learning_starts": 0, - "timesteps_per_iteration": 100 - }) - if __name__ == "__main__": import pytest diff --git a/rllib/tuned_examples/pendulum-sac.yaml b/rllib/tuned_examples/pendulum-sac.yaml new file mode 100644 index 000000000..f0c28c839 --- /dev/null +++ b/rllib/tuned_examples/pendulum-sac.yaml @@ -0,0 +1,37 @@ +# Pendulum SAC can attain -150+ reward in 6-7k +# Configurations are the similar to original softlearning/sac codebase +pendulum_sac: + env: Pendulum-v0 + run: SAC + stop: + episode_reward_mean: -150 + config: + horizon: 200 + soft_horizon: False + Q_model: + hidden_activation: relu + hidden_layer_sizes: [256, 256] + policy_model: + hidden_activation: relu + hidden_layer_sizes: [256, 256] + tau: 0.005 + target_entropy: auto + no_done_at_end: True + n_step: 1 + sample_batch_size: 1 + prioritized_replay: False + train_batch_size: 256 + target_network_update_freq: 1 + timesteps_per_iteration: 1000 + learning_starts: 256 + exploration_enabled: True + optimization: + actor_learning_rate: 0.0003 + critic_learning_rate: 0.0003 + entropy_learning_rate: 0.0003 + num_workers: 0 + num_gpus: 0 + clip_actions: False + normalize_actions: True + evaluation_interval: 1 + metrics_smoothing_episodes: 5 diff --git a/rllib/tuned_examples/regression_tests/pendulum-sac.yaml b/rllib/tuned_examples/regression_tests/pendulum-sac.yaml index 635afd838..47404fb6e 100644 --- a/rllib/tuned_examples/regression_tests/pendulum-sac.yaml +++ b/rllib/tuned_examples/regression_tests/pendulum-sac.yaml @@ -3,8 +3,10 
@@ pendulum-sac: run: SAC stop: episode_reward_mean: -300 # note that evaluation perf is higher - timesteps_total: 15000 + timesteps_total: 10000 config: evaluation_interval: 1 # logged under evaluation/* metric keys - soft_horizon: true - metrics_smoothing_episodes: 10 + soft_horizon: True + clip_actions: False + normalize_actions: True + metrics_smoothing_episodes: 5
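
The new `log_pis_model` built in the sac_model.py hunk evaluates the density of a tanh-squashed diagonal Gaussian through a tfp.bijectors.Chain((SquashBijector(), Affine(...))) transform. The plain-NumPy sketch below shows the same change-of-variables math for readers who want to check the numbers; unlike the TFP version (which takes the already-squashed action and inverts the bijector internally), it works directly on the pre-squash sample u, and the helper name and the 1e-6 clamp are illustrative, not part of the patch.

    import numpy as np

    def squashed_gaussian_log_prob(mean, log_std, raw_action):
        """Log-density of a = tanh(u), where u ~ N(mean, exp(log_std)^2).

        A minimal sketch of the math behind log_pis_fn above, operating
        on the pre-squash sample u (raw_action) rather than on a.
        """
        std = np.exp(log_std)
        # Diagonal Gaussian log-density of the pre-squash sample u.
        gaussian_log_prob = np.sum(
            -0.5 * ((raw_action - mean) / std) ** 2
            - log_std - 0.5 * np.log(2.0 * np.pi),
            axis=-1)
        # Change-of-variables correction for the tanh squashing:
        # log|det da/du| = sum_i log(1 - tanh(u_i)^2).
        # The small constant only guards against log(0) in this sketch.
        squash_correction = np.sum(
            np.log(1.0 - np.tanh(raw_action) ** 2 + 1e-6), axis=-1)
        return gaussian_log_prob - squash_correction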
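
The new NormalizeActionWrapper added in rllib/env/normalize_actions.py (and enabled per-trainer via the new "normalize_actions" config key) rescales agent actions from [-1, 1] to the environment's native bounds. A small usage sketch, assuming the patch is applied and using Pendulum-v0, whose native action space is Box(-2.0, 2.0, (1,)):

    import gym
    import numpy as np
    from ray.rllib.env.normalize_actions import NormalizeActionWrapper

    env = NormalizeActionWrapper(gym.make("Pendulum-v0"))
    obs = env.reset()

    # The agent emits actions in [-1, 1]; the wrapper rescales them as
    #   low + (a + 1) * (high - low) / 2
    # so a = 1.0 maps to the upper bound 2.0, a = -1.0 to the lower bound -2.0.
    obs, reward, done, info = env.step(np.array([1.0]))

The tuned Pendulum configuration added under rllib/tuned_examples/pendulum-sac.yaml pairs normalize_actions: True with clip_actions: False, so the squashed policy output is rescaled by the wrapper rather than clipped; it can be run in the usual way, e.g. `rllib train -f rllib/tuned_examples/pendulum-sac.yaml`.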