"""TensorFlow policy class used for R2D2."""

from typing import Dict, List, Optional, Tuple

import gym
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import (
    clip_gradients,
    compute_q_values,
    PRIO_WEIGHTS,
    postprocess_nstep_and_prio,
)
from ray.rllib.agents.dqn.dqn_tf_policy import build_q_model
from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import huber_loss
from ray.rllib.utils.typing import ModelInputDict, TensorType, TrainerConfigDict

tf1, tf, tfv = try_import_tf()


def build_r2d2_model(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> Tuple[ModelV2, ActionDistribution]:
    """Build q_model and target_model for DQN.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        ModelV2: The Q-model (`q_model`).
            Note: The target q model will not be returned, just assigned to
            `policy.target_model`.
    """

    # Create the policy's models.
    model = build_q_model(policy, obs_space, action_space, config)

    # Assert correct model type by checking the init state to be present.
    # For attention nets: These don't necessarily publish their init state via
    # Model.get_initial_state, but may only use the trajectory view API
    # (view_requirements).
    assert (
        model.get_initial_state() != []
        or model.view_requirements.get("state_in_0") is not None
    ), (
        "R2D2 requires its model to be a recurrent one! Try using "
        "`model.use_lstm` or `model.use_attention` in your config "
        "to auto-wrap your model with an LSTM- or attention net."
    )

    return model


def r2d2_loss(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True,
    )

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True,
    )

    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    rewards = train_batch[SampleBatch.REWARDS]
    weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)

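    # All Q-tensors above are time-flattened to [B*T, ...]: recover the number
    # of sequences (B) from the RNN state inputs and the (padded) max sequence
    # length (T) from the flattened Q output.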
    B = tf.shape(state_batches[0])[0]
    T = tf.shape(q)[0] // B

    # Q scores for actions which we know were selected in the given state.
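    # Entries equal to tf.float32.min (presumably used upstream to mark
    # masked-out/invalid actions) are zeroed out before summing, so they
    # cannot leak into the selected-Q values.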
    one_hot_selection = tf.one_hot(actions, policy.action_space.n)
    q_selected = tf.reduce_sum(
        tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection, axis=1
    )

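    # Double-Q: select the argmax action with the online network, but evaluate
    # it with the target network (below). Without double-Q, both selection and
    # evaluation use the target network.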
    if config["double_q"]:
        best_actions = tf.argmax(q, axis=1)
    else:
        best_actions = tf.argmax(q_target, axis=1)

    best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
    q_target_best = tf.reduce_sum(
        tf.where(q_target > tf.float32.min, q_target, tf.zeros_like(q_target))
        * best_actions_one_hot,
        axis=1,
    )

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
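        # Shift the target-Q values one timestep to the left so that index t
        # holds Q^(t+1); pad the final step with 0.0 and zero out terminal
        # transitions via (1.0 - dones) so no bootstrapping happens there.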
        q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
            [q_target_best[1:], tf.constant([0.0])], axis=0
        )

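        # Optional invertible value rescaling ("h-function") from the R2D2
        # paper: targets are computed as h(r + gamma^n * h^-1(Q^)), which keeps
        # target magnitudes in a well-behaved range without reward clipping.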
        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1, config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"] ** config["n_step"] * h_inv,
                config["h_function_epsilon"],
            )
        else:
            target = (
                rewards + config["gamma"] ** config["n_step"] * q_target_best_masked_tp1
            )

        # Seq-mask all loss-related terms.
        seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Also mask away the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
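        # E.g. with burn_in=20, the first 20 timesteps of every stored sequence
        # are excluded from the loss; they only serve to build up a useful RNN
        # state before the learnable part of the sequence begins.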
        # Make sure this works for both static graph and eager mode.
        if burn_in > 0:
            seq_mask = tf.cond(
                pred=tf.convert_to_tensor(burn_in, tf.int32) < T,
                true_fn=lambda: tf.concat(
                    [tf.fill([B, burn_in], False), seq_mask[:, burn_in:]], 1
                ),
                false_fn=lambda: seq_mask,
            )

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, seq_mask))

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
        td_error = q_selected - tf.stop_gradient(tf.reshape(target, [B, T])[:, :-1])
        td_error = td_error * tf.cast(seq_mask, tf.float32)
        weights = tf.reshape(weights, [B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        # Store the TD-error per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        policy._td_error = tf.reduce_mean(td_error, axis=-1)
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": tf.reduce_min(q_selected),
            "max_q": tf.reduce_max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss


def h_function(x, epsilon=1.0):
    """h-function to normalize target Qs, described in the paper [1].

    h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x

    Used in [1] in combination with h_inverse:
      targets = h(r + gamma * h_inverse(Q^))
    """
    return tf.sign(x) * (tf.sqrt(tf.abs(x) + 1.0) - 1.0) + epsilon * x


def h_inverse(x, epsilon=1.0):
    """Inverse of the above h-function, described in the paper [1].

    If x > 0.0:
    h-1(x) = [2eps * x + (2eps + 1) - sqrt(4eps x + (2eps + 1)^2)] /
        (2 * eps^2)

    If x < 0.0:
    h-1(x) = [2eps * x + (2eps + 1) + sqrt(-4eps x + (2eps + 1)^2)] /
        (2 * eps^2)
    """
    two_epsilon = epsilon * 2
    if_x_pos = (
        two_epsilon * x
        + (two_epsilon + 1.0)
        - tf.sqrt(4.0 * epsilon * x + (two_epsilon + 1.0) ** 2)
    ) / (2.0 * epsilon ** 2)
    if_x_neg = (
        two_epsilon * x
        - (two_epsilon + 1.0)
        + tf.sqrt(-4.0 * epsilon * x + (two_epsilon + 1.0) ** 2)
    ) / (2.0 * epsilon ** 2)
    return tf.where(x < 0.0, if_x_neg, if_x_pos)


class ComputeTDErrorMixin:
    """Assign the `compute_td_error` method to the R2D2TFPolicy.

    This allows us to prioritize on the worker side.
    """

    def __init__(self):
        def compute_td_error(
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights
        ):
            input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
            input_dict[SampleBatch.ACTIONS] = act_t
            input_dict[SampleBatch.REWARDS] = rew_t
            input_dict[SampleBatch.NEXT_OBS] = obs_tp1
            input_dict[SampleBatch.DONES] = done_mask
            input_dict[PRIO_WEIGHTS] = importance_weights

            # Do a forward pass on the loss to update the TD-error attribute.
            r2d2_loss(self, self.model, None, input_dict)

            return self._td_error

        self.compute_td_error = compute_td_error


def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    *,
    input_dict: ModelInputDict,
    state_batches: Optional[List[TensorType]] = None,
    seq_lens: Optional[TensorType] = None,
    explore: bool = True,
    is_training: bool = False,
    **kwargs
) -> Tuple[TensorType, type, List[TensorType]]:

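    # Presumably deferred so the torch-specific policy module is only imported
    # when the torch framework is actually selected.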
    if policy.config["framework"] == "torch":
        from ray.rllib.agents.dqn.r2d2_torch_policy import (
            compute_q_values as torch_compute_q_values,
        )

        func = torch_compute_q_values
    else:
        func = compute_q_values

    q_vals, logits, probs_or_logits, state_out = func(
        policy, model, input_dict, state_batches, seq_lens, explore, is_training
    )

    policy.q_values = q_vals
    if not hasattr(policy, "q_func_vars"):
        policy.q_func_vars = model.variables()

    action_dist_class = (
        TorchCategorical if policy.config["framework"] == "torch" else Categorical
    )

    return policy.q_values, action_dist_class, state_out


def adam_optimizer(
    policy: Policy, config: TrainerConfigDict
) -> "tf.keras.optimizers.Optimizer":
    return tf1.train.AdamOptimizer(
        learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
    )


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    return dict(
        {
            "cur_lr": policy.cur_lr,
        },
        **policy._loss_stats
    )


def setup_early_mixins(
    policy: Policy, obs_space, action_space, config: TrainerConfigDict
) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def before_loss_init(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


R2D2TFPolicy = build_tf_policy(
    name="R2D2TFPolicy",
    loss_fn=r2d2_loss,
    get_default_config=lambda: ray.rllib.agents.dqn.r2d2.R2D2_DEFAULT_CONFIG,
    postprocess_fn=postprocess_nstep_and_prio,
    stats_fn=build_q_stats,
    make_model=build_r2d2_model,
    action_distribution_fn=get_distribution_inputs_and_class,
    optimizer_fn=adam_optimizer,
    extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
    compute_gradients_fn=clip_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
    before_init=setup_early_mixins,
    before_loss_init=before_loss_init,
    mixins=[
        TargetNetworkMixin,
        ComputeTDErrorMixin,
        LearningRateSchedule,
    ],
)
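
# Example usage (a minimal sketch, not part of the original module): this
# policy is normally created indirectly by the R2D2 trainer when the config
# selects the "tf" framework and a recurrent model, roughly:
#
#   from ray.rllib.agents.dqn import r2d2
#   trainer = r2d2.R2D2Trainer(
#       env="CartPole-v0",
#       config={
#           "framework": "tf",
#           "model": {"use_lstm": True},
#       },
#   )
#
# The exact trainer class and config keys (e.g. `burn_in`, `use_h_function`,
# `h_function_epsilon`) live in ray.rllib.agents.dqn.r2d2; see the
# R2D2_DEFAULT_CONFIG referenced above.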