# Mirror of https://github.com/vale981/ray
# (synced 2025-03-06 02:21:39 -05:00; 453 lines, 18 KiB, Python)
from functools import partial
|
|
import logging
|
|
import numpy as np
|
|
import gym
|
|
from typing import Dict, Tuple, List, Type, Union, Optional, Any
|
|
|
|
import ray
|
|
import ray.experimental.tf_utils
|
|
from ray.rllib.algorithms.ddpg.utils import make_ddpg_models, validate_spaces
|
|
from ray.rllib.algorithms.dqn.dqn_tf_policy import (
|
|
postprocess_nstep_and_prio,
|
|
PRIO_WEIGHTS,
|
|
)
|
|
from ray.rllib.evaluation import Episode
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.models.tf.tf_action_dist import (
|
|
Deterministic,
|
|
Dirichlet,
|
|
TFActionDistribution,
|
|
)
|
|
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
|
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.tf_policy import TFPolicy
|
|
from ray.rllib.utils.framework import get_variable, try_import_tf
|
|
from ray.rllib.utils.spaces.simplex import Simplex
|
|
from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable
|
|
from ray.rllib.utils.typing import (
|
|
AlgorithmConfigDict,
|
|
TensorType,
|
|
LocalOptimizer,
|
|
ModelGradients,
|
|
)
|
|
from ray.util.debug import log_once
|
|
|
|
# Module-level logger shared by the mixins and the policy class below.
logger = logging.getLogger(__name__)

# TF handles as returned by RLlib's `try_import_tf()`. `tf1` is used below for
# TF1-style graph APIs (`tf1.train`, `tf1.control_dependencies`), `tf` for the
# main API; `tfv` is presumably the TF major version (unused in this file).
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class ComputeTDErrorMixin:
    """Give the policy a TF-callable `compute_td_error` method.

    The callable packs its six arguments into a `SampleBatch`, runs the
    policy's `loss()` on it, and returns the TD errors that `loss()` stores on
    the policy as `td_error`.
    """

    def __init__(self: Union[DynamicTFPolicyV2, EagerTFPolicyV2]):
        policy = self
        # Batch columns, in the same order as the callable's parameters.
        batch_keys = (
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.REWARDS,
            SampleBatch.NEXT_OBS,
            SampleBatch.DONES,
            PRIO_WEIGHTS,
        )

        @make_tf_callable(policy.get_session(), dynamic_shape=True)
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            columns = (obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights)
            batch = SampleBatch(
                {
                    key: tf.convert_to_tensor(column)
                    for key, column in zip(batch_keys, columns)
                }
            )
            # The loss is evaluated only for its side effect: it sets
            # `policy.td_error` (one value per batch item, used to update
            # prioritized-replay weights). Its return value is discarded.
            policy.loss(policy.model, None, batch)
            return policy.td_error

        policy.compute_td_error = compute_td_error
|
|
|
|
|
|
class TargetNetworkMixin:
    """Maintain `self.target_model` as a (soft) copy of `self.model`.

    On construction, builds a TF-callable that mixes the trainable variables of
    `self.model` into those of `self.target_model`, then performs a hard sync
    (tau=1.0).
    """

    def __init__(self: Union[DynamicTFPolicyV2, EagerTFPolicyV2]):
        @make_tf_callable(self.get_session())
        def update_target_fn(tau):
            # Per variable: target <- tau * source + (1 - tau) * target.
            tau = tf.convert_to_tensor(tau, dtype=tf.float32)
            update_target_expr = []
            model_vars = self.model.trainable_variables()
            target_model_vars = self.target_model.trainable_variables()
            # Variables are paired positionally, so both models must expose
            # the same number of trainable variables.
            assert len(model_vars) == len(target_model_vars), (
                model_vars,
                target_model_vars,
            )
            for var, var_target in zip(model_vars, target_model_vars):
                update_target_expr.append(
                    var_target.assign(tau * var + (1.0 - tau) * var_target)
                )
                logger.debug("Update target op %s", var_target)
            return tf.group(*update_target_expr)

        self._do_update = update_target_fn
        # Hard initial update: copy the model's weights into the target model.
        self.update_target(tau=1.0)

    # Support both hard and soft sync.
    def update_target(
        self: Union[DynamicTFPolicyV2, EagerTFPolicyV2],
        tau: Optional[float] = None,
    ) -> None:
        """Move the target network's weights toward the main network's.

        Args:
            tau: Mixing coefficient; 1.0 is a hard copy. When None,
                `self.config["tau"]` is used. An explicit 0.0 is honored
                (target left unchanged) instead of being replaced by the
                config value, which the previous `tau or ...` form did.
        """
        if tau is None:
            tau = self.config.get("tau")
        self._do_update(np.float32(tau))

    @override(TFPolicy)
    def variables(self: Union[DynamicTFPolicyV2, EagerTFPolicyV2]) -> List[TensorType]:
        """Return the main model's variables followed by the target model's."""
        return self.model.variables() + self.target_model.variables()
|
|
|
|
|
|
# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_ddpg_tf_policy(
    name: str, base: Type[Union[DynamicTFPolicyV2, EagerTFPolicyV2]]
) -> Type:
    """Construct a DDPGTFPolicy inheriting either dynamic or eager base policies.

    Args:
        name: Name assigned to the returned class's `__name__` and
            `__qualname__`.
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with DDPG.
    """

    class DDPGTFPolicy(TargetNetworkMixin, ComputeTDErrorMixin, base):
        def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: AlgorithmConfigDict,
            *,
            existing_inputs: Optional[Dict[str, "tf1.placeholder"]] = None,
            existing_model: Optional[ModelV2] = None,
        ):
            # First thing first, enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            # Fill in DDPG defaults; keys in the user-supplied `config` take
            # precedence over the DDPGConfig defaults.
            config = dict(
                ray.rllib.algorithms.ddpg.ddpg.DDPGConfig().to_dict(), **config
            )

            # Validate action space for DDPG
            validate_spaces(self, observation_space, action_space)

            base.__init__(
                self,
                observation_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            ComputeTDErrorMixin.__init__(self)

            self.maybe_initialize_optimizer_and_loss()

            # Runs last: it immediately performs a hard sync and needs
            # `self.target_model` (presumably created by `make_ddpg_models`
            # during base init -- confirm).
            TargetNetworkMixin.__init__(self)

        @override(base)
        def make_model(self) -> ModelV2:
            """Build the policy's model(s) via `make_ddpg_models`."""
            return make_ddpg_models(self)

        @override(base)
        def optimizer(
            self,
        ) -> List["tf.keras.optimizers.Optimizer"]:
            """Create separate optimizers for actor & critic losses.

            Also creates `self.global_step`, which `apply_gradients_fn` uses
            to delay actor updates by `policy_delay` steps.

            Returns:
                [actor_optimizer, critic_optimizer].
            """
            if self.config["framework"] in ["tf2", "tfe"]:
                self.global_step = get_variable(0, tf_name="global_step")
                self._actor_optimizer = tf.keras.optimizers.Adam(
                    learning_rate=self.config["actor_lr"]
                )
                self._critic_optimizer = tf.keras.optimizers.Adam(
                    learning_rate=self.config["critic_lr"]
                )
            # Static graph mode.
            else:
                self.global_step = tf1.train.get_or_create_global_step()
                self._actor_optimizer = tf1.train.AdamOptimizer(
                    learning_rate=self.config["actor_lr"]
                )
                self._critic_optimizer = tf1.train.AdamOptimizer(
                    learning_rate=self.config["critic_lr"]
                )
            return [self._actor_optimizer, self._critic_optimizer]

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            """Compute actor and critic gradients separately.

            The `loss` argument is not used. Gradients are taken from
            `self.actor_loss` (w.r.t. policy variables) and `self.critic_loss`
            (w.r.t. Q variables), both set by `loss()`.

            Returns:
                Actor (grad, var) pairs followed by critic (grad, var) pairs,
                with None gradients dropped and optional norm clipping applied.
            """
            if self.config["framework"] in ["tf2", "tfe"]:
                # NOTE(review): the GradientTape is presumably attached to
                # `optimizer` by the eager base policy -- not visible here.
                tape = optimizer.tape
                pol_weights = self.model.policy_variables()
                actor_grads_and_vars = list(
                    zip(tape.gradient(self.actor_loss, pol_weights), pol_weights)
                )
                q_weights = self.model.q_variables()
                critic_grads_and_vars = list(
                    zip(tape.gradient(self.critic_loss, q_weights), q_weights)
                )
            else:
                actor_grads_and_vars = self._actor_optimizer.compute_gradients(
                    self.actor_loss, var_list=self.model.policy_variables()
                )
                critic_grads_and_vars = self._critic_optimizer.compute_gradients(
                    self.critic_loss, var_list=self.model.q_variables()
                )

            # Clip if necessary (each gradient tensor is clipped by its own norm).
            if self.config["grad_clip"]:
                clip_func = partial(tf.clip_by_norm, clip_norm=self.config["grad_clip"])
            else:
                clip_func = tf.identity

            # Save grads and vars for later use in `apply_gradients_fn`.
            # Variables whose gradient is None are skipped.
            self._actor_grads_and_vars = [
                (clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None
            ]
            self._critic_grads_and_vars = [
                (clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None
            ]

            grads_and_vars = self._actor_grads_and_vars + self._critic_grads_and_vars

            return grads_and_vars

        @override(base)
        def apply_gradients_fn(
            self,
            optimizer: "tf.keras.optimizers.Optimizer",
            grads: ModelGradients,
        ) -> "tf.Operation":
            """Apply the gradients stored by `compute_gradients_fn`.

            The `optimizer` and `grads` arguments are not used. The critic is
            updated on every call. The actor is updated only when
            `global_step % policy_delay == 0`. `global_step` is then
            incremented by one.
            """
            # For policy gradient, update policy net one time v.s.
            # update critic net `policy_delay` time(s).
            should_apply_actor_opt = tf.equal(
                tf.math.floormod(self.global_step, self.config["policy_delay"]), 0
            )

            def make_apply_op():
                return self._actor_optimizer.apply_gradients(self._actor_grads_and_vars)

            actor_op = tf.cond(
                should_apply_actor_opt,
                true_fn=make_apply_op,
                false_fn=lambda: tf.no_op(),
            )
            critic_op = self._critic_optimizer.apply_gradients(
                self._critic_grads_and_vars
            )
            # Increment global step & apply ops.
            if self.config["framework"] in ["tf2", "tfe"]:
                # In eager mode the apply ops above have presumably already
                # run, so only the step counter is left to update.
                self.global_step.assign_add(1)
                return tf.no_op()
            else:
                # NOTE(review): the control dependency orders only the returned
                # group after the increment; actor_op/critic_op were created
                # outside this context, so the relative order of the
                # `global_step` read in `should_apply_actor_opt` and the
                # `assign_add` is not constrained here -- confirm intended.
                with tf1.control_dependencies([tf1.assign_add(self.global_step, 1)]):
                    return tf.group(actor_op, critic_op)

        @override(base)
        def action_distribution_fn(
            self,
            model: ModelV2,
            *,
            obs_batch: TensorType,
            state_batches: TensorType,
            is_training: bool = False,
            **kwargs,
        ) -> Tuple[TensorType, type, List[TensorType]]:
            """Return (policy-head outputs, distribution class, empty state-out).

            The distribution is Dirichlet for Simplex action spaces and
            Deterministic otherwise.
            """
            model_out, _ = model(SampleBatch(obs=obs_batch, _is_training=is_training))
            dist_inputs = model.get_policy_output(model_out)

            if isinstance(self.action_space, Simplex):
                distr_class = Dirichlet
            else:
                distr_class = Deterministic
            return dist_inputs, distr_class, []  # []=state out

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
            episode: Optional[Episode] = None,
        ) -> SampleBatch:
            """Delegate to DQN's `postprocess_nstep_and_prio`."""
            return postprocess_nstep_and_prio(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> TensorType:
            """Build the DDPG/TD3 actor and critic losses.

            Side effects: sets `self.actor_loss`, `self.critic_loss`,
            `self.td_error`, `self.q_t` and `self.target_q_func_vars`, which
            are consumed by the gradient, stats, and TD-error code.

            Returns:
                critic_loss + actor_loss. The two parts are optimized
                separately in `compute_gradients_fn`.
            """
            twin_q = self.config["twin_q"]
            gamma = self.config["gamma"]
            n_step = self.config["n_step"]
            use_huber = self.config["use_huber"]
            huber_threshold = self.config["huber_threshold"]
            l2_reg = self.config["l2_reg"]

            input_dict = SampleBatch(
                obs=train_batch[SampleBatch.CUR_OBS], _is_training=True
            )
            input_dict_next = SampleBatch(
                obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True
            )

            model_out_t, _ = model(input_dict, [], None)
            # NOTE(review): `model_out_tp1` is not used below -- candidate for
            # removal.
            model_out_tp1, _ = model(input_dict_next, [], None)
            target_model_out_tp1, _ = self.target_model(input_dict_next, [], None)

            self.target_q_func_vars = self.target_model.variables()

            # Policy network evaluation.
            policy_t = model.get_policy_output(model_out_t)
            policy_tp1 = self.target_model.get_policy_output(target_model_out_tp1)

            # Action outputs.
            if self.config["smooth_target_policy"]:
                # Target policy smoothing: add clipped Gaussian noise to the
                # target actions, then clip them back into the action bounds.
                target_noise_clip = self.config["target_noise_clip"]
                clipped_normal_sample = tf.clip_by_value(
                    tf.random.normal(
                        tf.shape(policy_tp1), stddev=self.config["target_noise"]
                    ),
                    -target_noise_clip,
                    target_noise_clip,
                )
                policy_tp1_smoothed = tf.clip_by_value(
                    policy_tp1 + clipped_normal_sample,
                    self.action_space.low * tf.ones_like(policy_tp1),
                    self.action_space.high * tf.ones_like(policy_tp1),
                )
            else:
                # No smoothing, just use deterministic actions.
                policy_tp1_smoothed = policy_tp1

            # Q-net(s) evaluation.
            # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
            # Q-values for given actions & observations in given current state.
            q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

            # Q-values for current policy (no noise) in given current state
            q_t_det_policy = model.get_q_values(model_out_t, policy_t)

            if twin_q:
                twin_q_t = model.get_twin_q_values(
                    model_out_t, train_batch[SampleBatch.ACTIONS]
                )

            # Target q-net(s) evaluation.
            q_tp1 = self.target_model.get_q_values(
                target_model_out_tp1, policy_tp1_smoothed
            )

            if twin_q:
                twin_q_tp1 = self.target_model.get_twin_q_values(
                    target_model_out_tp1, policy_tp1_smoothed
                )

            # Drop the trailing axis of the Q outputs.
            q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
            if twin_q:
                # NOTE(review): axis is taken from `q_t.shape`, not
                # `twin_q_t.shape`; equivalent only if both heads have the
                # same rank -- confirm.
                twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
                # Clipped double-Q: use the elementwise minimum of both target
                # heads.
                q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

            q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
            # Zero the bootstrap term for transitions flagged as done.
            q_tp1_best_masked = (
                1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)
            ) * q_tp1_best

            # Compute RHS of bellman equation (n-step discount gamma**n_step);
            # no gradient flows into the target.
            q_t_selected_target = tf.stop_gradient(
                tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
                + gamma ** n_step * q_tp1_best_masked
            )

            # Compute the error (potentially clipped).
            if twin_q:
                td_error = q_t_selected - q_t_selected_target
                twin_td_error = twin_q_t_selected - q_t_selected_target
                if use_huber:
                    errors = huber_loss(td_error, huber_threshold) + huber_loss(
                        twin_td_error, huber_threshold
                    )
                else:
                    errors = 0.5 * tf.math.square(td_error) + 0.5 * tf.math.square(
                        twin_td_error
                    )
            else:
                td_error = q_t_selected - q_t_selected_target
                if use_huber:
                    errors = huber_loss(td_error, huber_threshold)
                else:
                    errors = 0.5 * tf.math.square(td_error)

            # Importance-sampling weights (PRIO_WEIGHTS) scale each sample's
            # critic error.
            critic_loss = tf.reduce_mean(
                tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors
            )
            # Actor maximizes Q of its own (noise-free) action.
            actor_loss = -tf.reduce_mean(q_t_det_policy)

            # Add l2-regularization if required (variables whose name contains
            # "bias" are excluded).
            if l2_reg is not None:
                for var in self.model.policy_variables():
                    if "bias" not in var.name:
                        actor_loss += l2_reg * tf.nn.l2_loss(var)
                for var in self.model.q_variables():
                    if "bias" not in var.name:
                        critic_loss += l2_reg * tf.nn.l2_loss(var)

            # Model self-supervised losses.
            if self.config["use_state_preprocessor"]:
                # Expand input_dict in case custom_loss' need them.
                input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
                input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
                input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
                input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
                if log_once("ddpg_custom_loss"):
                    logger.warning(
                        "You are using a state-preprocessor with DDPG and "
                        "therefore, `custom_loss` will be called on your Model! "
                        "Please be aware that DDPG now uses the ModelV2 API, which "
                        "merges all previously separate sub-models (policy_model, "
                        "q_model, and twin_q_model) into one ModelV2, on which "
                        "`custom_loss` is called, passing it "
                        "[actor_loss, critic_loss] as 1st argument. "
                        "You may have to change your custom loss function to handle "
                        "this."
                    )
                [actor_loss, critic_loss] = model.custom_loss(
                    [actor_loss, critic_loss], input_dict
                )

            # Store values for stats function.
            self.actor_loss = actor_loss
            self.critic_loss = critic_loss
            # With twin_q, only the first Q head's TD error is stored.
            self.td_error = td_error
            self.q_t = q_t

            # Return one loss value (even though we treat them separately in our
            # 2 optimizers: actor and critic).
            return self.critic_loss + self.actor_loss

        @override(base)
        def extra_learn_fetches_fn(self) -> Dict[str, Any]:
            """Expose the TD errors computed in `loss()` as a learner fetch."""
            return {"td_error": self.td_error}

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            """Report mean/max/min of the Q-values computed in `loss()`."""
            stats = {
                "mean_q": tf.reduce_mean(self.q_t),
                "max_q": tf.reduce_max(self.q_t),
                "min_q": tf.reduce_min(self.q_t),
            }
            return stats

    DDPGTFPolicy.__name__ = name
    DDPGTFPolicy.__qualname__ = name

    return DDPGTFPolicy
|
|
|
|
|
|
# Concrete DDPG policy classes: TF1 static-graph and TF2 eager variants.
DDPGTF1Policy, DDPGTF2Policy = (
    get_ddpg_tf_policy("DDPGTF1Policy", DynamicTFPolicyV2),
    get_ddpg_tf_policy("DDPGTF2Policy", EagerTFPolicyV2),
)
|