# ray/rllib/agents/dqn/r2d2_tf_policy.py
"""TensorFlow policy class used for R2D2."""
from typing import Dict, List, Optional, Tuple
import gym
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import clip_gradients, \
compute_q_values, PRIO_WEIGHTS, postprocess_nstep_and_prio
from ray.rllib.agents.dqn.dqn_tf_policy import build_q_model
from ray.rllib.agents.dqn.simple_q_tf_policy import TargetNetworkMixin
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import huber_loss
from ray.rllib.utils.typing import ModelInputDict, TensorType, \
TrainerConfigDict
tf1, tf, tfv = try_import_tf()


def build_r2d2_model(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space, config: TrainerConfigDict
) -> Tuple[ModelV2, ActionDistribution]:
"""Build q_model and target_model for DQN
Args:
policy (Policy): The policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict):
Returns:
q_model
Note: The target q model will not be returned, just assigned to
`policy.target_model`.
"""
# Create the policy's models.
model = build_q_model(policy, obs_space, action_space, config)
# Assert correct model type by checking the init state to be present.
# For attention nets: These don't necessarily publish their init state via
# Model.get_initial_state, but may only use the trajectory view API
# (view_requirements).
assert (model.get_initial_state() != [] or
model.view_requirements.get("state_in_0") is not None), \
"R2D2 requires its model to be a recurrent one! Try using " \
"`model.use_lstm` or `model.use_attention` in your config " \
"to auto-wrap your model with an LSTM- or attention net."
return model
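

# A minimal, illustrative config sketch (hypothetical values, not taken from
# this file): R2D2 needs a recurrent model, which is most easily obtained by
# letting RLlib auto-wrap the default net:
#
#     config = ray.rllib.agents.dqn.r2d2.DEFAULT_CONFIG.copy()
#     config["model"]["use_lstm"] = True    # auto-wrap with an LSTM
#     config["model"]["max_seq_len"] = 20   # length of stored sequences
#     config["burn_in"] = 4                 # warm-up prefix, masked in loss
#
# Setting `config["model"]["use_attention"] = True` instead wraps the model
# with an attention net.

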
def r2d2_loss(policy: Policy, model, _,
train_batch: SampleBatch) -> TensorType:
"""Constructs the loss for R2D2TFPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
train_batch (SampleBatch): The training data.
Returns:
TensorType: A single loss tensor.
"""
config = policy.config
# Construct internal state inputs.
i = 0
state_batches = []
while "state_in_{}".format(i) in train_batch:
state_batches.append(train_batch["state_in_{}".format(i)])
i += 1
assert state_batches
# Q-network evaluation (at t).
q, _, _, _ = compute_q_values(
policy,
model,
train_batch,
state_batches=state_batches,
seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
explore=False,
is_training=True)
# Target Q-network evaluation (at t+1).
q_target, _, _, _ = compute_q_values(
policy,
policy.target_model,
train_batch,
state_batches=state_batches,
seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
explore=False,
is_training=True)
if not hasattr(policy, "target_q_func_vars"):
policy.target_q_func_vars = policy.target_model.variables()
actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
rewards = train_batch[SampleBatch.REWARDS]
weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
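    # B = number of stored sequences in the train batch; T = max sequence
    # length. `q` comes back time-flattened as [B*T, num_actions], so T can
    # be recovered as (B*T) // B.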
B = tf.shape(state_batches[0])[0]
T = tf.shape(q)[0] // B
# Q scores for actions which we know were selected in the given state.
one_hot_selection = tf.one_hot(actions, policy.action_space.n)
q_selected = tf.reduce_sum(
tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection,
axis=1)
if config["double_q"]:
best_actions = tf.argmax(q, axis=1)
else:
best_actions = tf.argmax(q_target, axis=1)
best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
q_target_best = tf.reduce_sum(
tf.where(q_target > tf.float32.min, q_target, tf.zeros_like(q_target))
* best_actions_one_hot,
axis=1)
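    # With double_q=True (van Hasselt et al., 2016), the online network above
    # selects the argmax actions while the target network evaluates them,
    # decoupling action selection from value estimation.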
if config["num_atoms"] > 1:
raise ValueError("Distributional R2D2 not supported yet!")
else:
q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
[q_target_best[1:], tf.constant([0.0])], axis=0)
if config["use_h_function"]:
h_inv = h_inverse(q_target_best_masked_tp1,
config["h_function_epsilon"])
target = h_function(
rewards + config["gamma"]**config["n_step"] * h_inv,
config["h_function_epsilon"])
else:
target = rewards + \
config["gamma"] ** config["n_step"] * q_target_best_masked_tp1
# Seq-mask all loss-related terms.
seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
T)[:, :-1]
    # Also mask out the burn-in sequence at the beginning of each rollout.
    burn_in = policy.config["burn_in"]
    # Make sure this works in both static-graph and eager mode.
if burn_in > 0:
seq_mask = tf.cond(
pred=tf.convert_to_tensor(burn_in, tf.int32) < T,
true_fn=lambda: tf.concat([tf.fill([B, burn_in], False),
seq_mask[:, burn_in:]], 1),
false_fn=lambda: seq_mask,
)
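    # Worked example (hypothetical numbers): with T=5, burn_in=2 and
    # seq_lens=[5, 3], the mask over the T-1=4 loss steps becomes
    #     [[False, False, True, True],
    #      [False, False, True, False]],
    # i.e. the first `burn_in` steps only warm up the RNN state and never
    # contribute to the loss.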
def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, seq_mask))
# Make sure to use the correct time indices:
# Q(t) - [gamma * r + Q^(t+1)]
q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
td_error = q_selected - tf.stop_gradient(
tf.reshape(target, [B, T])[:, :-1])
td_error = td_error * tf.cast(seq_mask, tf.float32)
weights = tf.reshape(weights, [B, T])[:, :-1]
policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
# Store the TD-error per time chunk (b/c we need only one mean
# prioritized replay weight per stored sequence).
policy._td_error = tf.reduce_mean(td_error, axis=-1)
policy._loss_stats = {
"mean_q": reduce_mean_valid(q_selected),
"min_q": tf.reduce_min(q_selected),
"max_q": tf.reduce_max(q_selected),
"mean_td_error": reduce_mean_valid(td_error),
}
    return policy._total_loss


def h_function(x, epsilon=1.0):
    """h-function to normalize target Qs, described in the paper [1].

    h(x) = sign(x) * [sqrt(abs(x) + 1) - 1] + epsilon * x

    Used in [1] in combination with h_inverse:
        targets = h(r + gamma * h_inverse(Q^))
    """
    return tf.sign(x) * (tf.sqrt(tf.abs(x) + 1.0) - 1.0) + epsilon * x


def h_inverse(x, epsilon=1.0):
    """Inverse of the above h-function, described in the paper [1].

    If x > 0.0:
    h-1(x) = [2eps * x + (2eps + 1) - sqrt(4eps x + (2eps + 1)^2)] /
        (2 * eps^2)

    If x < 0.0:
    h-1(x) = [2eps * x - (2eps + 1) + sqrt(-4eps x + (2eps + 1)^2)] /
        (2 * eps^2)
    """
two_epsilon = epsilon * 2
if_x_pos = (two_epsilon * x + (two_epsilon + 1.0) -
tf.sqrt(4.0 * epsilon * x +
(two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
if_x_neg = (two_epsilon * x - (two_epsilon + 1.0) +
tf.sqrt(-4.0 * epsilon * x +
(two_epsilon + 1.0)**2)) / (2.0 * epsilon**2)
return tf.where(x < 0.0, if_x_neg, if_x_pos)
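

# Quick sanity check (an illustrative sketch, not part of this module's API):
# since `h_inverse` inverts `h_function`, composing the two should recover
# the identity up to numerical precision, e.g. in eager mode:
#
#     x = tf.constant([-10.0, -1.0, 0.0, 1.0, 10.0])
#     tf.debugging.assert_near(h_inverse(h_function(x)), x, atol=1e-4)

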
class ComputeTDErrorMixin:
    """Assigns the `compute_td_error` method to the R2D2TFPolicy.

    This allows us to prioritize on the worker side.
    """

def __init__(self):
def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
importance_weights):
input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t})
input_dict[SampleBatch.ACTIONS] = act_t
input_dict[SampleBatch.REWARDS] = rew_t
input_dict[SampleBatch.NEXT_OBS] = obs_tp1
input_dict[SampleBatch.DONES] = done_mask
input_dict[PRIO_WEIGHTS] = importance_weights
# Do forward pass on loss to update td error attribute
r2d2_loss(self, self.model, None, input_dict)
return self._td_error
self.compute_td_error = compute_td_error
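

# Usage sketch (tensor names here are illustrative): once this mixin's
# __init__ has run, workers can recompute TD-errors for prioritized replay
# via
#
#     td = policy.compute_td_error(
#         obs_t, actions, rewards, obs_tp1, dones, importance_weights)
#
# which performs a forward pass through `r2d2_loss` and returns
# `policy._td_error` (one mean TD-error per stored sequence).

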
def get_distribution_inputs_and_class(
policy: Policy,
model: ModelV2,
*,
input_dict: ModelInputDict,
state_batches: Optional[List[TensorType]] = None,
seq_lens: Optional[TensorType] = None,
explore: bool = True,
is_training: bool = False,
**kwargs) -> Tuple[TensorType, type, List[TensorType]]:
if policy.config["framework"] == "torch":
from ray.rllib.agents.dqn.r2d2_torch_policy import \
compute_q_values as torch_compute_q_values
func = torch_compute_q_values
else:
func = compute_q_values
q_vals, logits, probs_or_logits, state_out = func(
policy, model, input_dict, state_batches, seq_lens, explore,
is_training)
policy.q_values = q_vals
if not hasattr(policy, "q_func_vars"):
policy.q_func_vars = model.variables()
action_dist_class = TorchCategorical if \
policy.config["framework"] == "torch" else Categorical
    return policy.q_values, action_dist_class, state_out


def adam_optimizer(policy: Policy, config: TrainerConfigDict
) -> "tf.keras.optimizers.Optimizer":
return tf1.train.AdamOptimizer(
        learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"])


def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:
return dict({
"cur_lr": policy.cur_lr,
    }, **policy._loss_stats)


def setup_early_mixins(policy: Policy, obs_space, action_space,
config: TrainerConfigDict) -> None:
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


R2D2TFPolicy = build_tf_policy(
name="R2D2TFPolicy",
loss_fn=r2d2_loss,
get_default_config=lambda: ray.rllib.agents.dqn.r2d2.DEFAULT_CONFIG,
postprocess_fn=postprocess_nstep_and_prio,
stats_fn=build_q_stats,
make_model=build_r2d2_model,
action_distribution_fn=get_distribution_inputs_and_class,
optimizer_fn=adam_optimizer,
extra_action_out_fn=lambda policy: {"q_values": policy.q_values},
compute_gradients_fn=clip_gradients,
extra_learn_fetches_fn=lambda policy: {"td_error": policy._td_error},
before_init=setup_early_mixins,
before_loss_init=before_loss_init,
mixins=[
TargetNetworkMixin,
ComputeTDErrorMixin,
LearningRateSchedule,
])
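

# Example usage (a hedged sketch; config values are illustrative). The policy
# is normally created indirectly through the R2D2 trainer:
#
#     from ray.rllib.agents.dqn import R2D2Trainer
#
#     trainer = R2D2Trainer(
#         env="CartPole-v0",
#         config={
#             "framework": "tf",
#             "model": {"use_lstm": True, "max_seq_len": 20},
#             "burn_in": 4,
#         })
#     print(trainer.train())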