from gym.spaces import Box, Discrete
import logging

import ray
import ray.experimental.tf_utils
from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
    TargetNetworkMixin
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio
from ray.rllib.agents.sac.sac_tf_model import SACTFModel
from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
    DiagGaussian, SquashedGaussian
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.tf_ops import minimize_and_clip

tf = try_import_tf()
tfp = try_import_tfp()

logger = logging.getLogger(__name__)


def build_sac_model(policy, obs_space, action_space, config):
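    """Constructs the SAC model and target model for the given policy.

    Both models are built via `ModelCatalog.get_model_v2()` using the
    SACTFModel (or SACTorchModel) interface. Hidden layer sizes come from
    the `policy_model` and `Q_model` sub-configs, not from
    `model.fcnet_hiddens`.

    Returns:
        The policy's model (the target model is stored as
        `policy.target_model`).
    """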
    # 2 cases:
    # 1) with separate state-preprocessor (before obs+action concat).
    # 2) no separate state-preprocessor: concat obs+actions right away.
    if config["use_state_preprocessor"]:
        num_outputs = 256  # Flatten last Conv2D to this many nodes.
    else:
        num_outputs = 0
        # No state preprocessor: fcnet_hiddens should be empty.
        if config["model"]["fcnet_hiddens"]:
            logger.warning(
                "When not using a state-preprocessor with SAC, "
                "`fcnet_hiddens` will be set to an empty list! Any hidden "
                "layer sizes must be defined via `policy_model.fcnet_hiddens`"
                " and `Q_model.fcnet_hiddens`.")
            config["model"]["fcnet_hiddens"] = []

    # Force-ignore any additionally provided hidden layer sizes.
    # Everything should be configured using SAC's "Q_model" and "policy_model"
    # settings.
    policy.model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        model_interface=SACTorchModel
        if config["framework"] == "torch" else SACTFModel,
        name="sac_model",
        actor_hidden_activation=config["policy_model"]["fcnet_activation"],
        actor_hiddens=config["policy_model"]["fcnet_hiddens"],
        critic_hidden_activation=config["Q_model"]["fcnet_activation"],
        critic_hiddens=config["Q_model"]["fcnet_hiddens"],
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        model_interface=SACTorchModel
        if config["framework"] == "torch" else SACTFModel,
        name="target_sac_model",
        actor_hidden_activation=config["policy_model"]["fcnet_activation"],
        actor_hiddens=config["policy_model"]["fcnet_hiddens"],
        critic_hidden_activation=config["Q_model"]["fcnet_activation"],
        critic_hiddens=config["Q_model"]["fcnet_hiddens"],
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    return policy.model


def postprocess_trajectory(policy,
                           sample_batch,
                           other_agent_batches=None,
                           episode=None):
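    """Postprocesses a sampled trajectory batch.

    Simply delegates to DQN's `postprocess_nstep_and_prio()` helper (n-step
    returns and prioritized-replay weights).
    """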
    return postprocess_nstep_and_prio(policy, sample_batch)


def get_dist_class(config, action_space):
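    """Returns the action distribution class matching the action space.

    Discrete spaces use a Categorical distribution. Box spaces use a
    SquashedGaussian (or Beta, if `_use_beta_distribution` is set) when
    `normalize_actions` is on, and a DiagGaussian otherwise.
    """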
    if isinstance(action_space, Discrete):
        return Categorical
    else:
        if config["normalize_actions"]:
            return SquashedGaussian if \
                not config["_use_beta_distribution"] else Beta
        else:
            return DiagGaussian


def get_distribution_inputs_and_class(policy,
                                      model,
                                      obs_batch,
                                      *,
                                      explore=True,
                                      **kwargs):
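    """Computes the action-distribution inputs for action sampling.

    Runs the base model on `obs_batch`, feeds the result through the policy
    (actor) head, and returns the distribution inputs, the matching
    distribution class, and the model's state-outs.
    """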
    # Get base-model output.
    model_out, state_out = model({
        "obs": obs_batch,
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)
    # Get action model output from base-model output.
    distribution_inputs = model.get_policy_output(model_out)
    action_dist_class = get_dist_class(policy.config, policy.action_space)
    return distribution_inputs, action_dist_class, state_out


def sac_actor_critic_loss(policy, model, _, train_batch):
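    """Builds the SAC losses: critic (TD), actor, and entropy (alpha) losses.

    Critic(s): 0.5 * MSE between Q(s, a) and the entropy-regularized Bellman
    target. Actor: E[alpha * log pi(a|s) - Q(s, a)] with a sampled (or, in
    the discrete case, probability-weighted) action. Alpha:
    -E[log_alpha * (log pi(a|s) + target_entropy)].

    Returns the sum of all three losses; they are handled separately again
    in `gradients()` / `apply_gradients()`.
    """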
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    model_out_t, _ = model({
        "obs": train_batch[SampleBatch.CUR_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    model_out_tp1, _ = model({
        "obs": train_batch[SampleBatch.NEXT_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": train_batch[SampleBatch.NEXT_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1)
        policy_t = tf.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from the distribution.
        action_dist_class = get_dist_class(policy.config, policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # Target Q-network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    assert policy.config["n_step"] == 1, "TODO(hartikainen) n_step > 1"
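
    # The (soft) Bellman target computed below is:
    #   y = r + gamma^n_step * (1 - done) *
    #       (Q_target(s', a') - alpha * log pi(a'|s')),
    # with a' drawn from the current policy (expectation over all actions in
    # the discrete case); the entropy bonus and the done-mask are already
    # folded into `q_tp1_best_masked` above.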

    # Compute RHS of the Bellman equation.
    q_t_selected_target = tf.stop_gradient(
        train_batch[SampleBatch.REWARDS] +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        tf.losses.mean_squared_error(
            labels=q_t_selected_target, predictions=q_t_selected, weights=0.5)
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.losses.mean_squared_error(
                labels=q_t_selected_target,
                predictions=twin_q_t_selected,
                weights=0.5))
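
    # Temperature (alpha) and actor objectives, as implemented below:
    #   alpha_loss = -E[log_alpha * stop_grad(log pi(a|s) + target_entropy)]
    #   actor_loss =  E[alpha * log pi(a|s) - Q(s, a)]
    # (in the discrete case, the expectations over actions are taken
    # explicitly by weighting with pi(a|s)).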

    # Alpha- and actor losses.
    # Note: The papers use alpha directly; here we optimize its log instead.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy)),
                axis=-1))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for the stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss for now.
    return actor_loss + tf.add_n(critic_loss) + alpha_loss


def gradients(policy, optimizer, loss):
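    """Computes (and optionally clips) gradients for actor, critic(s), alpha.

    With `grad_clip` set, gradients are computed via `minimize_and_clip()`;
    otherwise each loss uses its own optimizer's `compute_gradients()`. The
    per-loss grads-and-vars lists are stored on the policy for
    `apply_gradients()` and returned concatenated.
    """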
    if policy.config["grad_clip"]:
        actor_grads_and_vars = minimize_and_clip(
            optimizer,  # Isn't `optimizer` ill-defined here (which one)?
            policy.actor_loss,
            var_list=policy.model.policy_variables(),
            clip_val=policy.config["grad_clip"])
        if policy.config["twin_q"]:
            q_variables = policy.model.q_variables()
            half_cutoff = len(q_variables) // 2
            critic_grads_and_vars = []
            critic_grads_and_vars += minimize_and_clip(
                optimizer,
                policy.critic_loss[0],
                var_list=q_variables[:half_cutoff],
                clip_val=policy.config["grad_clip"])
            critic_grads_and_vars += minimize_and_clip(
                optimizer,
                policy.critic_loss[1],
                var_list=q_variables[half_cutoff:],
                clip_val=policy.config["grad_clip"])
        else:
            critic_grads_and_vars = minimize_and_clip(
                optimizer,
                policy.critic_loss[0],
                var_list=policy.model.q_variables(),
                clip_val=policy.config["grad_clip"])
        alpha_grads_and_vars = minimize_and_clip(
            optimizer,
            policy.alpha_loss,
            var_list=[policy.model.log_alpha],
            clip_val=policy.config["grad_clip"])
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables())
        if policy.config["twin_q"]:
            q_variables = policy.model.q_variables()
            half_cutoff = len(q_variables) // 2
            base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
            critic_grads_and_vars = base_q_optimizer.compute_gradients(
                policy.critic_loss[0], var_list=q_variables[:half_cutoff]
            ) + twin_q_optimizer.compute_gradients(
                policy.critic_loss[1], var_list=q_variables[half_cutoff:])
        else:
            critic_grads_and_vars = policy._critic_optimizer[
                0].compute_gradients(
                    policy.critic_loss[0], var_list=policy.model.q_variables())
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=[policy.model.log_alpha])

    # Save these for later use in build_apply_op.
    policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars
                                     if g is not None]
    policy._alpha_grads_and_vars = [(g, v) for (g, v) in alpha_grads_and_vars
                                    if g is not None]
    grads_and_vars = (
        policy._actor_grads_and_vars + policy._critic_grads_and_vars +
        policy._alpha_grads_and_vars)
    return grads_and_vars


def apply_gradients(policy, optimizer, grads_and_vars):
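    """Applies the previously computed gradients via the separate optimizers.

    Actor, critic(s), and alpha gradients are applied by their respective
    optimizers; the returned op groups all apply-ops together.
    """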
    actor_apply_ops = policy._actor_optimizer.apply_gradients(
        policy._actor_grads_and_vars)

    cgrads = policy._critic_grads_and_vars
    half_cutoff = len(cgrads) // 2
    if policy.config["twin_q"]:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
            policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:])
        ]
    else:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads)
        ]

    alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
        policy._alpha_grads_and_vars,
        global_step=tf.train.get_or_create_global_step())
    return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)


def stats(policy, train_batch):
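    """Returns TD-error, loss, and Q-value statistics for this policy."""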
    return {
        # "policy_t": policy.policy_t,
        # "td_error": policy.td_error,
        "mean_td_error": tf.reduce_mean(policy.td_error),
        "actor_loss": tf.reduce_mean(policy.actor_loss),
        "critic_loss": tf.reduce_mean(policy.critic_loss),
        "alpha_loss": tf.reduce_mean(policy.alpha_loss),
        "alpha_value": tf.reduce_mean(policy.alpha_value),
        "target_entropy": tf.constant(policy.target_entropy),
        "mean_q": tf.reduce_mean(policy.q_t),
        "max_q": tf.reduce_max(policy.q_t),
        "min_q": tf.reduce_min(policy.q_t),
    }


class ActorCriticOptimizerMixin:
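    """Adds separate Adam optimizers for actor, critic(s), and alpha.

    The optimizers must exist before the loss and gradient ops are built,
    which is why this mixin is initialized in `before_init`
    (`setup_early_mixins`).
    """
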
    def __init__(self, config):
        # Create global step for counting the number of update operations.
        self.global_step = tf.train.get_or_create_global_step()

        # Use separate optimizers for actor, critic(s), and alpha.
        self._actor_optimizer = tf.train.AdamOptimizer(
            learning_rate=config["optimization"]["actor_learning_rate"])
        self._critic_optimizer = [
            tf.train.AdamOptimizer(
                learning_rate=config["optimization"]["critic_learning_rate"])
        ]
        if config["twin_q"]:
            self._critic_optimizer.append(
                tf.train.AdamOptimizer(learning_rate=config["optimization"][
                    "critic_learning_rate"]))
        self._alpha_optimizer = tf.train.AdamOptimizer(
            learning_rate=config["optimization"]["entropy_learning_rate"])


def setup_early_mixins(policy, obs_space, action_space, config):
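    """Initializes the ActorCriticOptimizerMixin (before policy init)."""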
    ActorCriticOptimizerMixin.__init__(policy, config)


def setup_mid_mixins(policy, obs_space, action_space, config):
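    """Initializes the ComputeTDErrorMixin (right before loss creation)."""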
    ComputeTDErrorMixin.__init__(policy, sac_actor_critic_loss)


def setup_late_mixins(policy, obs_space, action_space, config):
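    """Initializes the TargetNetworkMixin (after the policy is built)."""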
    TargetNetworkMixin.__init__(policy, config)


def validate_spaces(pid, observation_space, action_space, config):
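    """Validates that SAC supports the given action space.

    Only Discrete and flat (1D) Box action spaces are supported.
    """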
    if not isinstance(action_space, (Box, Discrete)):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "SAC.".format(action_space, pid))
    if isinstance(action_space, Box) and len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, pid, action_space.shape) +
            "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API.")


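# Build the TF policy class for SAC out of the functions above. As a rough
# usage sketch (not part of this module), a trainer such as
# `ray.rllib.agents.sac.SACTrainer` will typically instantiate this policy
# class when running with the TF framework.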
SACTFPolicy = build_tf_policy(
    name="SACTFPolicy",
    get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    make_model=build_sac_model,
    postprocess_fn=postprocess_trajectory,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=sac_actor_critic_loss,
    stats_fn=stats,
    gradients_fn=gradients,
    apply_gradients_fn=apply_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[
        TargetNetworkMixin, ActorCriticOptimizerMixin, ComputeTDErrorMixin
    ],
    validate_spaces=validate_spaces,
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    obs_include_prev_action_reward=False)