ray/rllib/agents/sac/sac_tf_model.py

from gym.spaces import Discrete
import numpy as np
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class SACTFModel(TFModelV2):
"""Extension of standard TFModel for SAC.
Data flow:
obs -> forward() -> model_out
model_out -> get_policy_output() -> pi(s)
model_out, actions -> get_q_values() -> Q(s, a)
model_out, actions -> get_twin_q_values() -> Q_twin(s, a)
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
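    # A minimal usage sketch (illustrative only): `my_model`, `obs_batch`,
    # and `actions` are assumed names, and forward() must be supplied by a
    # subclass (see the sketch at the end of this file).
    #
    #   model_out, _ = my_model.forward({"obs": obs_batch}, [], None)
    #   dist_inputs = my_model.get_policy_output(model_out)      # pi(s)
    #   q_values = my_model.get_q_values(model_out, actions)     # Q(s, a)
    #   q_twin = my_model.get_twin_q_values(model_out, actions)  # Q_twin(s, a)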
def __init__(self,
obs_space,
action_space,
num_outputs,
model_config,
name,
actor_hidden_activation="relu",
actor_hiddens=(256, 256),
critic_hidden_activation="relu",
critic_hiddens=(256, 256),
twin_q=False,
initial_alpha=1.0,
target_entropy=None):
"""Initialize variables of this model.
Extra model kwargs:
actor_hidden_activation (str): activation for actor network
actor_hiddens (list): hidden layers sizes for actor network
critic_hidden_activation (str): activation for critic network
critic_hiddens (list): hidden layers sizes for critic network
twin_q (bool): build twin Q networks.
initial_alpha (float): The initial value for the to-be-optimized
alpha parameter (default: 1.0).
Note that the core layers for forward() are not defined here, this
only defines the layers for the output heads. Those layers for
forward() should be defined in subclasses of SACModel.
"""
super(SACTFModel, self).__init__(obs_space, action_space, num_outputs,
model_config, name)
if isinstance(action_space, Discrete):
self.action_dim = action_space.n
self.discrete = True
action_outs = q_outs = self.action_dim
else:
            self.action_dim = np.prod(action_space.shape)
self.discrete = False
action_outs = 2 * self.action_dim
q_outs = 1
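        # Head sizes derived above: for discrete spaces, the policy head
        # emits one logit per action and each Q head emits one Q-value per
        # action; for continuous spaces, the policy head emits a mean and a
        # log-std per action dimension (2 * action_dim) and each Q head
        # emits a single Q(s, a).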
self.model_out = tf.keras.layers.Input(
shape=(self.num_outputs, ), name="model_out")
self.action_model = tf.keras.Sequential([
tf.keras.layers.Dense(
units=hidden,
activation=getattr(tf.nn, actor_hidden_activation, None),
name="action_{}".format(i + 1))
for i, hidden in enumerate(actor_hiddens)
] + [
tf.keras.layers.Dense(
units=action_outs, activation=None, name="action_out")
])
self.shift_and_log_scale_diag = self.action_model(self.model_out)
self.register_variables(self.action_model.variables)
self.actions_input = None
if not self.discrete:
self.actions_input = tf.keras.layers.Input(
shape=(self.action_dim, ), name="actions")
def build_q_net(name, observations, actions):
# For continuous actions: Feed obs and actions (concatenated)
# through the NN. For discrete actions, only obs.
q_net = tf.keras.Sequential(([
tf.keras.layers.Concatenate(axis=1),
] if not self.discrete else []) + [
tf.keras.layers.Dense(
units=units,
activation=getattr(tf.nn, critic_hidden_activation, None),
name="{}_hidden_{}".format(name, i))
for i, units in enumerate(critic_hiddens)
] + [
tf.keras.layers.Dense(
units=q_outs, activation=None, name="{}_out".format(name))
])
# TODO(hartikainen): Remove the unnecessary Model calls here
if self.discrete:
q_net = tf.keras.Model(observations, q_net(observations))
else:
q_net = tf.keras.Model([observations, actions],
q_net([observations, actions]))
return q_net
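        # Shape sketch (hypothetical sizes): with a 3-dim continuous action
        # space, build_q_net maps ([BATCH, num_outputs], [BATCH, 3]) ->
        # [BATCH, 1]; in the discrete case, it maps [BATCH, num_outputs] ->
        # [BATCH, action_dim].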
self.q_net = build_q_net("q", self.model_out, self.actions_input)
self.register_variables(self.q_net.variables)
if twin_q:
self.twin_q_net = build_q_net("twin_q", self.model_out,
self.actions_input)
self.register_variables(self.twin_q_net.variables)
else:
self.twin_q_net = None
self.log_alpha = tf.Variable(
np.log(initial_alpha), dtype=tf.float32, name="log_alpha")
self.alpha = tf.exp(self.log_alpha)
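        # `alpha` is SAC's entropy temperature; optimizing `log_alpha`
        # (rather than alpha directly) keeps the temperature strictly
        # positive.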
# Auto-calculate the target entropy.
if target_entropy is None or target_entropy == "auto":
# See hyperparams in [2] (README.md).
if self.discrete:
target_entropy = 0.98 * np.array(
-np.log(1.0 / action_space.n), dtype=np.float32)
# See [1] (README.md).
else:
target_entropy = -np.prod(action_space.shape)
self.target_entropy = target_entropy
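        # Worked examples (illustrative spaces only): for Discrete(6),
        # target_entropy = 0.98 * -log(1/6) ~= 1.76 nats; for a Box space
        # of shape (3,), target_entropy = -3.0.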
self.register_variables([self.log_alpha])
def get_q_values(self, model_out, actions=None):
"""Return the Q estimates for the most recent forward pass.
This implements Q(s, a).
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
actions (Optional[Tensor]): Actions to return the Q-values for.
Shape: [BATCH_SIZE, action_dim]. If None (discrete action
case), return Q-values for all actions.
        Returns:
            Tensor of shape [BATCH_SIZE, 1] for continuous actions, or
            [BATCH_SIZE, action_dim] (one Q-value per action) in the
            discrete case.
"""
if actions is not None:
return self.q_net([model_out, actions])
else:
return self.q_net(model_out)
def get_twin_q_values(self, model_out, actions=None):
"""Same as get_q_values but using the twin Q net.
This implements the twin Q(s, a).
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
actions (Optional[Tensor]): Actions to return the Q-values for.
Shape: [BATCH_SIZE, action_dim]. If None (discrete action
case), return Q-values for all actions.
        Returns:
            Tensor of shape [BATCH_SIZE, 1] for continuous actions, or
            [BATCH_SIZE, action_dim] (one Q-value per action) in the
            discrete case.
"""
if actions is not None:
return self.twin_q_net([model_out, actions])
else:
return self.twin_q_net(model_out)
def get_policy_output(self, model_out):
"""Return the action output for the most recent forward pass.
        This outputs the inputs to the policy's action distribution pi(s).
        For continuous action spaces, these are the mean and log-std for
        each action dimension; for discrete spaces, these are the logits
        over all actions.
Arguments:
model_out (Tensor): obs embeddings from the model layers, of shape
[BATCH_SIZE, num_outputs].
Returns:
tensor of shape [BATCH_SIZE, action_out_size]
"""
return self.action_model(model_out)
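    # Note: the returned tensor parameterizes the policy's action
    # distribution (in RLlib's SAC, typically a squashed Gaussian for
    # continuous spaces and a Categorical for discrete spaces).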
def policy_variables(self):
"""Return the list of variables for the policy net."""
return list(self.action_model.variables)
def q_variables(self):
"""Return the list of variables for Q / twin Q nets."""
return self.q_net.variables + (self.twin_q_net.variables
if self.twin_q_net else [])
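

# ----------------------------------------------------------------------------
# A minimal, illustrative subclass sketch (an assumption, not part of the
# RLlib API): it shows one way to supply the forward() pass that SACTFModel
# leaves to subclasses. The class name, layer sizes, and activation below
# are placeholders.
# ----------------------------------------------------------------------------
class _ExampleSACTFModel(SACTFModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super(_ExampleSACTFModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name,
            **kwargs)
        # Core observation encoder producing `model_out` of size
        # `num_outputs`; the policy/Q heads built by SACTFModel consume it.
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="obs")
        hidden = tf.keras.layers.Dense(
            units=256, activation=tf.nn.relu, name="fc_1")(inputs)
        out = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(hidden)
        self.base_model = tf.keras.Model(inputs, out)
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        # Map raw observations to the embedding consumed by the output heads.
        model_out = self.base_model(input_dict["obs"])
        return model_out, state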