from gym.spaces import Box
from functools import partial
import logging
import numpy as np
import gym
from typing import Dict, Tuple, List

import ray
import ray.experimental.tf_utils
from ray.rllib.algorithms.ddpg.ddpg_tf_model import DDPGTFModel
from ray.rllib.algorithms.ddpg.ddpg_torch_model import DDPGTorchModel
from ray.rllib.algorithms.ddpg.noop_model import NoopModel, TorchNoopModel
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
    postprocess_nstep_and_prio,
    PRIO_WEIGHTS,
)
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Deterministic, Dirichlet
from ray.rllib.models.torch.torch_action_dist import TorchDeterministic, TorchDirichlet
from ray.rllib.utils.annotations import override
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import get_variable, try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    LocalOptimizer,
    ModelGradients,
)
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def build_ddpg_models(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> ModelV2:
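    """Build the DDPG model and target model and store them on the policy.

    If `use_state_preprocessor` is set, the model catalog picks a default
    preprocessing model; otherwise a no-op model feeds the flattened
    observation straight into the DDPG actor/critic heads.

    Returns:
        ModelV2: The policy's main model (also stored as `policy.model`).
    """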
    if policy.config["use_state_preprocessor"]:
        default_model = None  # catalog decides
        num_outputs = 256  # arbitrary
        config["model"]["no_final_linear"] = True
    else:
        default_model = TorchNoopModel if config["framework"] == "torch" else NoopModel
        num_outputs = int(np.product(observation_space.shape))

    policy.model = ModelCatalog.get_model_v2(
        obs_space=observation_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        model_interface=(
            DDPGTorchModel if config["framework"] == "torch" else DDPGTFModel
        ),
        default_model=default_model,
        name="ddpg_model",
        actor_hidden_activation=config["actor_hidden_activation"],
        actor_hiddens=config["actor_hiddens"],
        critic_hidden_activation=config["critic_hidden_activation"],
        critic_hiddens=config["critic_hiddens"],
        twin_q=config["twin_q"],
        add_layer_norm=(
            policy.config["exploration_config"].get("type") == "ParameterNoise"
        ),
    )

    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=observation_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework=config["framework"],
        model_interface=(
            DDPGTorchModel if config["framework"] == "torch" else DDPGTFModel
        ),
        default_model=default_model,
        name="target_ddpg_model",
        actor_hidden_activation=config["actor_hidden_activation"],
        actor_hiddens=config["actor_hiddens"],
        critic_hidden_activation=config["critic_hidden_activation"],
        critic_hiddens=config["critic_hiddens"],
        twin_q=config["twin_q"],
        add_layer_norm=(
            policy.config["exploration_config"].get("type") == "ParameterNoise"
        ),
    )

    return policy.model


def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    obs_batch: SampleBatch,
    *,
    explore: bool = True,
    is_training: bool = False,
    **kwargs
) -> Tuple[TensorType, ActionDistribution, List[TensorType]]:
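    """Return the action-distribution inputs and class for DDPG.

    Dirichlet is used for Simplex action spaces; otherwise a Deterministic
    distribution returns the actor network's output as the action.
    """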
    model_out, _ = model(SampleBatch(obs=obs_batch, _is_training=is_training), [], None)
    dist_inputs = model.get_policy_output(model_out)

    if isinstance(policy.action_space, Simplex):
        distr_class = (
            TorchDirichlet if policy.config["framework"] == "torch" else Dirichlet
        )
    else:
        distr_class = (
            TorchDeterministic
            if policy.config["framework"] == "torch"
            else Deterministic
        )
    return dist_inputs, distr_class, []  # []=state out


def ddpg_actor_critic_loss(
    policy: Policy, model: ModelV2, _, train_batch: SampleBatch
) -> TensorType:
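    """Compute the combined DDPG actor and critic losses.

    Supports the optional TD3-style additions configured via `twin_q`,
    `smooth_target_policy`, and `n_step`. The two losses are summed into a
    single return value here and split again across the separate actor and
    critic optimizers in `gradients_fn` / `build_apply_op`.
    """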
    twin_q = policy.config["twin_q"]
    gamma = policy.config["gamma"]
    n_step = policy.config["n_step"]
    use_huber = policy.config["use_huber"]
    huber_threshold = policy.config["huber_threshold"]
    l2_reg = policy.config["l2_reg"]

    input_dict = SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True)
    input_dict_next = SampleBatch(
        obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True
    )

    model_out_t, _ = model(input_dict, [], None)
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)

    policy.target_q_func_vars = policy.target_model.variables()

    # Policy network evaluation.
    policy_t = model.get_policy_output(model_out_t)
    policy_tp1 = policy.target_model.get_policy_output(target_model_out_tp1)

    # Action outputs.
    if policy.config["smooth_target_policy"]:
        target_noise_clip = policy.config["target_noise_clip"]
        clipped_normal_sample = tf.clip_by_value(
            tf.random.normal(
                tf.shape(policy_tp1), stddev=policy.config["target_noise"]
            ),
            -target_noise_clip,
            target_noise_clip,
        )
        policy_tp1_smoothed = tf.clip_by_value(
            policy_tp1 + clipped_normal_sample,
            policy.action_space.low * tf.ones_like(policy_tp1),
            policy.action_space.high * tf.ones_like(policy_tp1),
        )
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    # Q-net(s) evaluation.
    # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Q-values for given actions & observations in the given current state.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

    # Q-values for current policy (no noise) in given current state.
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)

    if twin_q:
        twin_q_t = model.get_twin_q_values(
            model_out_t, train_batch[SampleBatch.ACTIONS]
        )

    # Target q-net(s) evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed)

    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1_smoothed
        )

    q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
    if twin_q:
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
    q_tp1_best_masked = (
        1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    ) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
        + gamma ** n_step * q_tp1_best_masked
    )

    # Compute the TD-error (potentially clipped via Huber loss).
    if twin_q:
        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + huber_loss(
                twin_td_error, huber_threshold
            )
        else:
            errors = 0.5 * tf.math.square(td_error) + 0.5 * tf.math.square(
                twin_td_error
            )
    else:
        td_error = q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold)
        else:
            errors = 0.5 * tf.math.square(td_error)

    critic_loss = tf.reduce_mean(
        tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors
    )
    actor_loss = -tf.reduce_mean(q_t_det_policy)

    # Add l2-regularization if required.
    if l2_reg is not None:
        for var in policy.model.policy_variables():
            if "bias" not in var.name:
                actor_loss += l2_reg * tf.nn.l2_loss(var)
        for var in policy.model.q_variables():
            if "bias" not in var.name:
                critic_loss += l2_reg * tf.nn.l2_loss(var)

    # Model self-supervised losses.
    if policy.config["use_state_preprocessor"]:
        # Expand input_dict in case the Model's `custom_loss` needs these keys.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        if log_once("ddpg_custom_loss"):
            logger.warning(
                "You are using a state-preprocessor with DDPG and "
                "therefore, `custom_loss` will be called on your Model! "
                "Please be aware that DDPG now uses the ModelV2 API, which "
                "merges all previously separate sub-models (policy_model, "
                "q_model, and twin_q_model) into one ModelV2, on which "
                "`custom_loss` is called, passing it "
                "[actor_loss, critic_loss] as 1st argument. "
                "You may have to change your custom loss function to handle "
                "this."
            )
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict
        )

    # Store values for stats function.
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.td_error = td_error
    policy.q_t = q_t

    # Return one loss value (even though we treat them separately in our
    # 2 optimizers: actor and critic).
    return policy.critic_loss + policy.actor_loss


def build_apply_op(
    policy: Policy, optimizer: LocalOptimizer, grads_and_vars: ModelGradients
) -> TensorType:
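    """Apply the actor and critic gradients with their respective optimizers.

    The critic is updated on every call; the actor update is gated so it only
    runs once every `policy_delay` steps (delayed policy updates).
    """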
    # For policy gradient, update the policy net once vs. updating the
    # critic net `policy_delay` time(s).
    should_apply_actor_opt = tf.equal(
        tf.math.floormod(policy.global_step, policy.config["policy_delay"]), 0
    )

    def make_apply_op():
        return policy._actor_optimizer.apply_gradients(policy._actor_grads_and_vars)

    actor_op = tf.cond(
        should_apply_actor_opt, true_fn=make_apply_op, false_fn=lambda: tf.no_op()
    )
    critic_op = policy._critic_optimizer.apply_gradients(policy._critic_grads_and_vars)
    # Increment global step & apply ops.
    if policy.config["framework"] in ["tf2", "tfe"]:
        policy.global_step.assign_add(1)
        return tf.no_op()
    else:
        with tf1.control_dependencies([tf1.assign_add(policy.global_step, 1)]):
            return tf.group(actor_op, critic_op)


def gradients_fn(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
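    """Compute (and optionally clip) gradients for the actor and critic losses.

    The per-optimizer grads-and-vars lists are cached on the policy for
    `build_apply_op` and returned here as one combined list.
    """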
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights)
        )
        q_weights = policy.model.q_variables()
        critic_grads_and_vars = list(
            zip(tape.gradient(policy.critic_loss, q_weights), q_weights)
        )
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables()
        )
        critic_grads_and_vars = policy._critic_optimizer.compute_gradients(
            policy.critic_loss, var_list=policy.model.q_variables()
        )

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [
        (clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None
    ]
    policy._critic_grads_and_vars = [
        (clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None
    ]

    grads_and_vars = policy._actor_grads_and_vars + policy._critic_grads_and_vars

    return grads_and_vars


def build_ddpg_stats(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
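    """Return Q-value statistics (mean/max/min of `policy.q_t`)."""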
    stats = {
        "mean_q": tf.reduce_mean(policy.q_t),
        "max_q": tf.reduce_max(policy.q_t),
        "min_q": tf.reduce_min(policy.q_t),
    }
    return stats


class ActorCriticOptimizerMixin:
    """Mixin class to generate the necessary optimizers for actor-critic algos.

    - Creates a global step for counting the number of update operations.
    - Creates separate optimizers for actor and critic.
    """

    def __init__(self, config):
        # Eager mode.
        if config["framework"] in ["tf2", "tfe"]:
            self.global_step = get_variable(0, tf_name="global_step")
            self._actor_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["actor_lr"]
            )
            self._critic_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["critic_lr"]
            )
        # Static graph mode.
        else:
            self.global_step = tf1.train.get_or_create_global_step()
            self._actor_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["actor_lr"]
            )
            self._critic_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["critic_lr"]
            )


def setup_early_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ActorCriticOptimizerMixin.__init__(policy, config)


class ComputeTDErrorMixin:
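    """Assigns the `compute_td_error` method to a DDPG TFPolicy.

    This allows prioritized-replay weights to be updated from TD-errors that
    are computed by re-running the loss function on a given batch.
    """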
    def __init__(self, loss_fn):
        @make_tf_callable(self.get_session(), dynamic_shape=True)
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            # Do a forward pass on the loss to update the `td_error` attribute
            # (one TD-error value per item in the batch, used to update PR
            # weights).
            loss_fn(
                self,
                self.model,
                None,
                {
                    SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t),
                    SampleBatch.ACTIONS: tf.convert_to_tensor(act_t),
                    SampleBatch.REWARDS: tf.convert_to_tensor(rew_t),
                    SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
                    SampleBatch.DONES: tf.convert_to_tensor(done_mask),
                    PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
                },
            )
            # `self.td_error` is set in loss_fn.
            return self.td_error

        self.compute_td_error = compute_td_error


def setup_mid_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
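    """Call mixin classes' constructors before the loss initialization.

    Adds the `compute_td_error` method (via ComputeTDErrorMixin) to the
    given Policy.
    """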
    ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss)


class TargetNetworkMixin:
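    """Mixin adding hard/soft target-network synchronization to the policy.

    `update_target(tau)` Polyak-averages the main model's trainable variables
    into the target model's variables; `tau=1.0` performs a hard copy.
    """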
    def __init__(self, config: TrainerConfigDict):
        @make_tf_callable(self.get_session())
        def update_target_fn(tau):
            tau = tf.convert_to_tensor(tau, dtype=tf.float32)
            update_target_expr = []
            model_vars = self.model.trainable_variables()
            target_model_vars = self.target_model.trainable_variables()
            assert len(model_vars) == len(target_model_vars), (
                model_vars,
                target_model_vars,
            )
            for var, var_target in zip(model_vars, target_model_vars):
                update_target_expr.append(
                    var_target.assign(tau * var + (1.0 - tau) * var_target)
                )
                logger.debug("Update target op {}".format(var_target))
            return tf.group(*update_target_expr)

        # Hard initial update.
        self._do_update = update_target_fn
        self.update_target(tau=1.0)

    # Support both hard and soft sync.
    def update_target(self, tau: float = None) -> None:
        self._do_update(np.float32(tau or self.config.get("tau")))

    @override(TFPolicy)
    def variables(self) -> List[TensorType]:
        return self.model.variables() + self.target_model.variables()


def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
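    """Call mixin classes' constructors after the Policy's initialization.

    Adds the target-network update machinery (TargetNetworkMixin) to the
    given Policy and performs the initial hard sync.
    """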
    TargetNetworkMixin.__init__(policy, config)


def validate_spaces(
    policy: Policy,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
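    """Validate the action space: DDPG requires a single-dimensional Box."""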
    if not isinstance(action_space, Box):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "DDPG.".format(action_space, policy)
        )
    elif len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, policy, action_space.shape)
            + "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API."
        )


DDPGTFPolicy = build_tf_policy(
    name="DDPGTFPolicy",
    get_default_config=lambda: ray.rllib.algorithms.ddpg.ddpg.DEFAULT_CONFIG,
    make_model=build_ddpg_models,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=ddpg_actor_critic_loss,
    stats_fn=build_ddpg_stats,
    postprocess_fn=postprocess_nstep_and_prio,
    compute_gradients_fn=gradients_fn,
    apply_gradients_fn=build_apply_op,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    validate_spaces=validate_spaces,
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    mixins=[
        TargetNetworkMixin,
        ActorCriticOptimizerMixin,
        ComputeTDErrorMixin,
    ],
)