ray/rllib/agents/sac/sac_model.py
Sven Mika 510c850651
[RLlib] SAC add discrete action support. (#7320)
Co-authored-by: Eric Liang <ekhliang@gmail.com>
Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
2020-03-06 10:37:12 -08:00

from gym.spaces import Discrete
import numpy as np

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf

tf = try_import_tf()
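
# (min, max) bounds for the log-std ("scale diag") output of the policy head;
# these are the clipping values commonly used in SAC implementations.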
SCALE_DIAG_MIN_MAX = (-20, 2)

class SACModel(TFModelV2):
    """Extension of standard TFModel for SAC.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_policy_output() -> pi(s)
        model_out, actions -> get_q_values() -> Q(s, a)
        model_out, actions -> get_twin_q_values() -> Q_twin(s, a)

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass."""
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 actor_hidden_activation="relu",
                 actor_hiddens=(256, 256),
                 critic_hidden_activation="relu",
                 critic_hiddens=(256, 256),
                 twin_q=False,
                 initial_alpha=1.0):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layer sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layer sizes for critic network
            twin_q (bool): build twin Q networks.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0).

        Note that the core layers for forward() are not defined here; this
        only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of SACModel.
        """
        super(SACModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)

        self.discrete = False
        if isinstance(action_space, Discrete):
            self.action_dim = action_space.n
            self.discrete = True
            action_outs = q_outs = self.action_dim
        else:
            self.action_dim = np.product(action_space.shape)
            action_outs = 2 * self.action_dim
            q_outs = 1

        self.model_out = tf.keras.layers.Input(
            shape=(num_outputs, ), name="model_out")
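
        # Policy ("actor") head on top of model_out. For continuous actions
        # it emits 2 * action_dim outputs (read downstream as the mean and
        # log-std of a squashed Gaussian); for discrete actions it emits one
        # logit per action.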
        self.action_model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                units=hidden,
                activation=getattr(tf.nn, actor_hidden_activation, None),
                name="action_{}".format(i + 1))
            for i, hidden in enumerate(actor_hiddens)
        ] + [
            tf.keras.layers.Dense(
                units=action_outs, activation=None, name="action_out")
        ])
        self.shift_and_log_scale_diag = self.action_model(self.model_out)
        self.register_variables(self.action_model.variables)

        self.actions_input = None
        if not self.discrete:
            self.actions_input = tf.keras.layers.Input(
                shape=(self.action_dim, ), name="actions")

        def build_q_net(name, observations, actions):
            # For continuous actions: Feed obs and actions (concatenated)
            # through the NN. For discrete actions, only obs.
            q_net = tf.keras.Sequential(([
                tf.keras.layers.Concatenate(axis=1),
            ] if not self.discrete else []) + [
                tf.keras.layers.Dense(
                    units=units,
                    activation=getattr(tf.nn, critic_hidden_activation, None),
                    name="{}_hidden_{}".format(name, i))
                for i, units in enumerate(critic_hiddens)
            ] + [
                tf.keras.layers.Dense(
                    units=q_outs, activation=None, name="{}_out".format(name))
            ])

            # TODO(hartikainen): Remove the unnecessary Model calls here
            if self.discrete:
                q_net = tf.keras.Model(observations, q_net(observations))
            else:
                q_net = tf.keras.Model([observations, actions],
                                       q_net([observations, actions]))
            return q_net
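
        # Main Q-head: outputs one Q-value per action in the discrete case,
        # a single Q(s, a) for the fed-in actions otherwise.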
        self.q_net = build_q_net("q", self.model_out, self.actions_input)
        self.register_variables(self.q_net.variables)

        if twin_q:
            self.twin_q_net = build_q_net("twin_q", self.model_out,
                                          self.actions_input)
            self.register_variables(self.twin_q_net.variables)
        else:
            self.twin_q_net = None
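
        # The entropy temperature alpha is optimized in log space, which
        # keeps alpha strictly positive; `self.alpha` is the exponentiated
        # value intended for use by the SAC loss.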
        self.log_alpha = tf.Variable(
            np.log(initial_alpha), dtype=tf.float32, name="log_alpha")
        self.alpha = tf.exp(self.log_alpha)
        self.register_variables([self.log_alpha])
    def get_q_values(self, model_out, actions=None):
        """Return Q estimates for the given model outputs (and actions).

        This implements Q(s, a).

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Optional[Tensor]): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim]. If None (discrete action
                case), return Q-values for all actions.

        Returns:
            Tensor of shape [BATCH_SIZE, 1] (continuous case) or
            [BATCH_SIZE, num_actions] (discrete case).
        """
        if actions is not None:
            return self.q_net([model_out, actions])
        else:
            return self.q_net(model_out)

    def get_twin_q_values(self, model_out, actions=None):
        """Same as get_q_values, but using the twin Q net.

        This implements the twin Q(s, a).

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Optional[Tensor]): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim]. If None (discrete action
                case), return Q-values for all actions.

        Returns:
            Tensor of shape [BATCH_SIZE, 1] (continuous case) or
            [BATCH_SIZE, num_actions] (discrete case).
        """
        if actions is not None:
            return self.twin_q_net([model_out, actions])
        else:
            return self.twin_q_net(model_out)
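
    # The two helpers below expose the actor vs. critic weights separately,
    # e.g. so that the SAC policy can apply distinct optimizers to the policy
    # net and the Q-network(s).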
    def policy_variables(self):
        """Return the list of variables for the policy net."""
        return list(self.action_model.variables)

    def q_variables(self):
        """Return the list of variables for Q / twin Q nets."""
        return self.q_net.variables + (self.twin_q_net.variables
                                       if self.twin_q_net else [])
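

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, assuming gym and TensorFlow
    # are installed; not part of the documented API). It builds the model for
    # a toy continuous-action space and wires a batch of symbolic model
    # outputs and actions through the Q-heads. `obs_dim` and `act_dim` are
    # hypothetical sizes chosen for the example.
    from gym.spaces import Box

    obs_dim, act_dim = 4, 2
    obs_space = Box(-1.0, 1.0, shape=(obs_dim, ))
    act_space = Box(-1.0, 1.0, shape=(act_dim, ))
    model = SACModel(
        obs_space,
        act_space,
        num_outputs=obs_dim,
        model_config={},
        name="sac_model",
        twin_q=True)
    # Symbolic inputs standing in for the outputs of forward() and a batch
    # of actions.
    model_out = tf.keras.layers.Input(shape=(obs_dim, ))
    actions = tf.keras.layers.Input(shape=(act_dim, ))
    print(model.get_q_values(model_out, actions))
    print(model.get_twin_q_values(model_out, actions))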