2020-03-06 19:37:12 +01:00
|
|
|
from gym.spaces import Discrete
|
2019-08-01 23:37:36 -07:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
2020-03-06 19:37:12 +01:00
|
|
|
from ray.rllib.utils import try_import_tf
|
2019-08-01 23:37:36 -07:00
|
|
|
|
|
|
|
# TensorFlow is obtained through RLlib's try_import_tf() helper rather than a
# plain `import tensorflow`.
# NOTE(review): presumably this lets the module be imported even when
# TensorFlow is unavailable — confirm against try_import_tf's contract.
tf = try_import_tf()

# A (min, max) pair of bounds. It is not referenced by SACModel in this file.
# NOTE(review): presumably the clipping range applied to the log of the
# diagonal scale produced by the continuous policy head — verify at the
# call sites that import this constant.
SCALE_DIAG_MIN_MAX = (-20, 2)
|
|
|
|
|
|
|
|
|
|
|
|
class SACModel(TFModelV2):
    """Extension of standard TFModel for SAC.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_policy_output() -> pi(s)
        model_out, actions -> get_q_values() -> Q(s, a)
        model_out, actions -> get_twin_q_values() -> Q_twin(s, a)

    For Discrete action spaces the Q heads consume only model_out and
    return one Q-value per action (call them with actions=None).

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 actor_hidden_activation="relu",
                 actor_hiddens=(256, 256),
                 critic_hidden_activation="relu",
                 critic_hiddens=(256, 256),
                 twin_q=False,
                 initial_alpha=1.0):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network.
                Looked up by name on ``tf.nn``; a name that does not exist
                there silently falls back to no (linear) activation.
            actor_hiddens (list|tuple): hidden layers sizes for actor network
            critic_hidden_activation (str): activation for critic network.
                Same ``tf.nn`` lookup / linear fallback as above.
            critic_hiddens (list|tuple): hidden layers sizes for critic
                network
            twin_q (bool): build twin Q networks.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0). Stored as log(alpha), so it
                must be > 0.

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of SACModel.
        """
        super(SACModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)

        self.discrete = False
        if isinstance(action_space, Discrete):
            # Discrete: the policy head emits one output per action and each
            # Q head emits one Q-value per action.
            self.action_dim = action_space.n
            self.discrete = True
            action_outs = q_outs = self.action_dim
        else:
            # Continuous: the policy head emits 2 * action_dim outputs (kept
            # as `shift_and_log_scale_diag`) and each Q head scores the single
            # action it is given.
            # np.prod (not the deprecated alias np.product, which was removed
            # in NumPy 2.0) flattens multi-dimensional action shapes.
            self.action_dim = np.prod(action_space.shape)
            action_outs = 2 * self.action_dim
            q_outs = 1

        self.model_out = tf.keras.layers.Input(
            shape=(num_outputs, ), name="model_out")
        self.action_model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                units=hidden,
                activation=getattr(tf.nn, actor_hidden_activation, None),
                name="action_{}".format(i + 1))
            for i, hidden in enumerate(actor_hiddens)
        ] + [
            tf.keras.layers.Dense(
                units=action_outs, activation=None, name="action_out")
        ])
        self.shift_and_log_scale_diag = self.action_model(self.model_out)

        self.register_variables(self.action_model.variables)

        # Only continuous Q heads take actions as an extra input; discrete Q
        # heads output values for every action at once.
        self.actions_input = None
        if not self.discrete:
            self.actions_input = tf.keras.layers.Input(
                shape=(self.action_dim, ), name="actions")

        def build_q_net(name, observations, actions):
            # For continuous actions: Feed obs and actions (concatenated)
            # through the NN. For discrete actions, only obs.
            q_net = tf.keras.Sequential(([
                tf.keras.layers.Concatenate(axis=1),
            ] if not self.discrete else []) + [
                tf.keras.layers.Dense(
                    units=units,
                    activation=getattr(tf.nn, critic_hidden_activation, None),
                    name="{}_hidden_{}".format(name, i))
                for i, units in enumerate(critic_hiddens)
            ] + [
                tf.keras.layers.Dense(
                    units=q_outs, activation=None, name="{}_out".format(name))
            ])

            # TODO(hartikainen): Remove the unnecessary Model calls here
            if self.discrete:
                q_net = tf.keras.Model(observations, q_net(observations))
            else:
                q_net = tf.keras.Model([observations, actions],
                                       q_net([observations, actions]))
            return q_net

        self.q_net = build_q_net("q", self.model_out, self.actions_input)
        self.register_variables(self.q_net.variables)

        if twin_q:
            self.twin_q_net = build_q_net("twin_q", self.model_out,
                                          self.actions_input)
            self.register_variables(self.twin_q_net.variables)
        else:
            self.twin_q_net = None

        # alpha is optimized in log-space; self.alpha is exp(log_alpha).
        self.log_alpha = tf.Variable(
            np.log(initial_alpha), dtype=tf.float32, name="log_alpha")
        self.alpha = tf.exp(self.log_alpha)

        self.register_variables([self.log_alpha])

    def get_q_values(self, model_out, actions=None):
        """Return the Q estimates for the most recent forward pass.

        This implements Q(s, a).

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Optional[Tensor]): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim]. If None (discrete action
                case), return Q-values for all actions.

        Returns:
            tensor of shape [BATCH_SIZE, 1] when `actions` is given
            (continuous case), or [BATCH_SIZE, action_dim] when `actions`
            is None (discrete case).
        """
        if actions is not None:
            return self.q_net([model_out, actions])
        else:
            return self.q_net(model_out)

    def get_twin_q_values(self, model_out, actions=None):
        """Same as get_q_values but using the twin Q net.

        This implements the twin Q(s, a). Only usable when the model was
        built with twin_q=True; otherwise `self.twin_q_net` is None.

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Optional[Tensor]): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim]. If None (discrete action
                case), return Q-values for all actions.

        Returns:
            tensor of shape [BATCH_SIZE, 1] when `actions` is given
            (continuous case), or [BATCH_SIZE, action_dim] when `actions`
            is None (discrete case).
        """
        if actions is not None:
            return self.twin_q_net([model_out, actions])
        else:
            return self.twin_q_net(model_out)

    def get_policy_output(self, model_out):
        """Return the action output for the most recent forward pass.

        This outputs the support for pi(s). For continuous action spaces,
        this is the concatenated shift and log diagonal scale of the action
        distribution (2 * action_dim values). For discrete action spaces,
        it is one output per action (action_dim values).

        Arguments:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].

        Returns:
            tensor of shape [BATCH_SIZE, action_out_size]
        """
        return self.action_model(model_out)

    def policy_variables(self):
        """Return the list of variables for the policy net."""

        return list(self.action_model.variables)

    def q_variables(self):
        """Return the list of variables for Q / twin Q nets."""

        twin_vars = (self.twin_q_net.variables
                     if self.twin_q_net is not None else [])
        return self.q_net.variables + twin_vars
|