ray/rllib/agents/slateq/slateq_tf_policy.py

"""TensorFlow policy class used for SlateQ."""
import functools
import gym
import logging
import numpy as np
from typing import Dict
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import clip_gradients
from ray.rllib.agents.sac.sac_tf_policy import TargetNetworkMixin
from ray.rllib.agents.slateq.slateq_tf_model import SlateQTFModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import SlateMultiCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import huber_loss
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


def build_slateq_model(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> SlateQTFModel:
    """Build models for the SlateQTFPolicy.

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space: The policy's observation space.
        action_space: The policy's action space.
        config: The Trainer's config dict.

    Returns:
        The slate-Q specific Q-model instance.
    """
    model = SlateQTFModel(
        obs_space,
        action_space,
        num_outputs=action_space.nvec[0],
        model_config=config["model"],
        name="slateq_model",
        fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
    )

    policy.target_model = SlateQTFModel(
        obs_space,
        action_space,
        num_outputs=action_space.nvec[0],
        model_config=config["model"],
        name="target_slateq_model",
        fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
    )

    return model
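
# Note: `policy.target_model` built above is kept in sync with `model` by the
# TargetNetworkMixin (initialized in `setup_late_mixins` below) and is only used
# to compute the bootstrapped Q-targets in `build_slateq_losses`.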


def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the Q-value loss for the SlateQTFPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The Q-value loss tensor.
    """
    # Shape abbreviations used in the comments below:
    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates
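    #
    # Worked size example (hypothetical numbers, not taken from the config): with
    # C=10 candidate docs and slate size S=2, there are A = 10 * 9 = 90 possible
    # ordered slates without repetition.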

    # Q-value computations.
    # ---------------------
    observation = train_batch[SampleBatch.OBS]
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size = tf.shape(user_obs)[0]
    # doc.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())
    # action.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    # click_indicator.shape: [B, S]
    click_indicator = tf.cast(
        tf.stack([k["click"] for k in observation["response"]], 1), tf.float32
    )
    # item_reward.shape: [B, S]
    item_reward = tf.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = tf.gather(
        q_values, tf.cast(actions, dtype=tf.int32), batch_dims=-1
    )
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = tf.reduce_sum(
        input_tensor=slate_q_values * click_indicator, axis=1, name="replay_click_q"
    )
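    # At this point, `replay_click_q[b]` is the Q-value of the document that was
    # actually clicked in batch row b, or 0.0 if the slate received no click.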

    # Target computations.
    # --------------------
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())
    # Only compute the watch time reward of the clicked item.
    reward = tf.reduce_sum(input_tensor=item_reward * click_indicator, axis=1)

    # TODO: Find out whether it's correct here to use obs, not next_obs!
    #  Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_model.get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    next_q_values_slate = tf.gather(next_q_values, policy.slates, axis=1)
    # scores_slate.shape: [B, A, S]
    scores_slate = tf.gather(scores, policy.slates, axis=1)
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click, tf.shape(input=policy.slates)[:1]), [batch_size, -1]
    )
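
    # SlateQ decomposition implemented below (a sketch of the math): the value of
    # a whole slate A is the user-choice-weighted sum of its per-item Q-values,
    #   Q(s, A) = sum_{i in A} v(s, i) * Q(s, i)
    #             / (sum_{j in A} v(s, j) + v(s, no_click)),
    # where v(., .) are the non-negative choice scores from `score_documents`.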
    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate, axis=2
    ) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) + score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - tf.cast(train_batch["dones"], tf.float32)
    )
    target = tf.stop_gradient(target)

    clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
    clicked_indices = tf.squeeze(tf.where(tf.equal(clicked, 1)), axis=1)
    # Clicked_indices is a vector and tf.gather selects the batch dimension.
    q_clicked = tf.gather(replay_click_q, clicked_indices)
    target_clicked = tf.gather(target, clicked_indices)

    td_error = tf.where(
        tf.cast(clicked, tf.bool),
        replay_click_q - target,
        tf.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = tf.math.square(td_error)
    loss = tf.reduce_mean(loss)
    td_error = tf.abs(td_error)
    mean_td_error = tf.reduce_mean(td_error)

    policy._q_values = tf.reduce_mean(q_values)
    policy._q_clicked = tf.reduce_mean(q_clicked)
    policy._scores = tf.reduce_mean(scores)
    policy._score_no_click = tf.reduce_mean(score_no_click)
    policy._slate_q_values = tf.reduce_mean(slate_q_values)
    policy._replay_click_q = tf.reduce_mean(replay_click_q)
    policy._bellman_reward = tf.reduce_mean(reward)
    policy._next_q_values = tf.reduce_mean(next_q_values)
    policy._target = tf.reduce_mean(target)
    policy._next_q_target_slate = tf.reduce_mean(next_q_target_slate)
    policy._next_q_target_max = tf.reduce_mean(next_q_target_max)
    policy._target_clicked = tf.reduce_mean(target_clicked)
    policy._q_loss = loss
    policy._td_error = td_error
    policy._mean_td_error = mean_td_error
    policy._mean_actions = tf.reduce_mean(actions)

    return loss


def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Returns the learner stats stored on the policy by `build_slateq_losses`."""
    stats = {
        "q_values": policy._q_values,
        "q_clicked": policy._q_clicked,
        "scores": policy._scores,
        "score_no_click": policy._score_no_click,
        "slate_q_values": policy._slate_q_values,
        "replay_click_q": policy._replay_click_q,
        "bellman_reward": policy._bellman_reward,
        "next_q_values": policy._next_q_values,
        "target": policy._target,
        "next_q_target_slate": policy._next_q_target_slate,
        "next_q_target_max": policy._next_q_target_max,
        "target_clicked": policy._target_clicked,
        "td_error": policy._td_error,
        "mean_td_error": policy._mean_td_error,
        "q_loss": policy._q_loss,
        "mean_actions": policy._mean_actions,
    }
    # if hasattr(policy, "_mean_grads_0"):
    #     stats.update({"mean_grads_0": policy._mean_grads_0})
    #     stats.update({"mean_grads_1": policy._mean_grads_1})
    #     stats.update({"mean_grads_2": policy._mean_grads_2})
    #     stats.update({"mean_grads_3": policy._mean_grads_3})
    #     stats.update({"mean_grads_4": policy._mean_grads_4})
    #     stats.update({"mean_grads_5": policy._mean_grads_5})
    #     stats.update({"mean_grads_6": policy._mean_grads_6})
    #     stats.update({"mean_grads_7": policy._mean_grads_7})
    return stats
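
# These values are computed once per learner update and are surfaced in the
# Trainer's result dict as part of the policy's learner stats.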


def action_distribution_fn(
    policy: Policy, model: SlateQTFModel, input_dict, *, explore, is_training, **kwargs
):
    """Determines which slate to recommend next.

    Computes per-slate Q-values over all possible slates and returns them as the
    inputs for a `SlateMultiCategorical` action distribution.
    """
    # First, we transform the observation into its unflattened form.
    observation = input_dict[SampleBatch.OBS]
    # user.shape: [B, E]
    user_obs = observation["user"]
    doc_obs = list(observation["doc"].values())

    # Compute scores per candidate.
    scores, score_no_click = score_documents(user_obs, doc_obs)
    # Compute Q-values per candidate.
    q_values = model.get_q_values(user_obs, doc_obs)

    with tf.name_scope("select_slate"):
        per_slate_q_values = get_per_slate_q_values(
            policy.slates, score_no_click, scores, q_values
        )
    return (
        per_slate_q_values,
        functools.partial(
            SlateMultiCategorical,
            action_space=policy.action_space,
            all_slates=policy.slates,
        ),
        [],
    )
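
# The returned triple follows RLlib's `action_distribution_fn` convention:
# (distribution inputs, distribution class, RNN state-outs). Using the per-slate
# Q-values as distribution inputs lets `SlateMultiCategorical` pick (greedily or
# via exploration) among all precomputed slates in `policy.slates`.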


def get_per_slate_q_values(slates, s_no_click, s, q):
    """Returns the choice-score-weighted Q-value of every possible slate.

    A slate's Q-value is the sum of its items' Q-values `q`, each weighted by the
    normalized user choice score `s` (the no-click score is part of the
    normalizer).
    """
    slate_q_values = tf.gather(s * q, slates, axis=1)
    slate_scores = tf.gather(s, slates, axis=1)
    slate_normalizer = tf.reduce_sum(
        input_tensor=slate_scores, axis=2
    ) + tf.expand_dims(s_no_click, 1)

    slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 2)
    slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=2)
    return slate_sum_q_values
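
# Shape sketch for `get_per_slate_q_values` (hypothetical sizes): with B=2 users,
# C=3 candidates, and S=2 slate positions, `slates` has shape [6, 2], `s` and `q`
# have shape [2, 3], and the returned per-slate Q-values have shape [2, 6].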


def score_documents(
    user_obs, doc_obs, no_click_score=1.0, multinomial_logits=False, min_normalizer=-1.0
):
    """Computes dot-product scores for user vs doc (plus no-click) feature vectors."""
    # Dot product between the user and each document feature vector.
    scores_per_candidate = tf.reduce_sum(
        tf.multiply(tf.expand_dims(user_obs, 1), tf.stack(doc_obs, axis=1)), 2
    )
    # Compile a constant no-click score tensor.
    score_no_click = tf.fill([tf.shape(user_obs)[0], 1], no_click_score)
    # Concatenate click and no-click scores.
    all_scores = tf.concat([scores_per_candidate, score_no_click], axis=1)

    # Logits: Softmax to yield probabilities.
    if multinomial_logits:
        all_scores = tf.nn.softmax(all_scores)
    # Multinomial proportional model: Shift to `[0.0,..[`.
    else:
        all_scores = all_scores - min_normalizer

    # Return click (per candidate document) and no-click scores.
    return all_scores[:, :-1], all_scores[:, -1]
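
# Quick sanity check for `score_documents` (a sketch; the numbers are made up and
# assume eager/tf2 execution):
#   user = tf.constant([[1.0, 0.0]])                                  # B=1, E=2
#   docs = [tf.constant([[1.0, 0.0]]), tf.constant([[0.0, 1.0]])]     # C=2
#   scores, no_click = score_documents(user, docs)
#   # -> scores == [[2.0, 1.0]], no_click == [2.0]: the raw dot products
#   #    [1.0, 0.0] and the no-click score 1.0, each shifted by -min_normalizer.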


def setup_early(policy, obs_space, action_space, config):
    """Obtain all possible slates given current docs in the candidate set."""
    num_candidates = action_space.nvec[0]
    slate_size = len(action_space.nvec)

    num_all_slates = np.prod([(num_candidates - i) for i in range(slate_size)])
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool,
    )
    # slates.shape: [A, S]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)
    slates.set_shape([num_all_slates, slate_size])

    # Store all possible slates only once in policy object.
    policy.slates = slates
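
# Example of the enumeration above (hypothetical sizes): with C=3 candidates and
# slate size S=2, `policy.slates` ends up holding the A = 3 * 2 = 6 ordered,
# duplicate-free slates, i.e. every permutation of two distinct documents such as
# [0, 1], [1, 0], [1, 2], etc.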


def setup_mid_mixins(policy: Policy, obs_space, action_space, config) -> None:
    """Call mixin classes' constructors before SlateQTFPolicy loss initialization.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call mixin classes' constructors after SlateQTFPolicy loss initialization.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, config)


def rmsprop_optimizer(
    policy: Policy, config: TrainerConfigDict
) -> "tf.keras.optimizers.Optimizer":
    """Builds the RMSProp optimizer (keras for tf2/eager, tf1 static-graph otherwise)."""
    if policy.config["framework"] in ["tf2", "tfe"]:
        return tf.keras.optimizers.RMSprop(
            learning_rate=policy.cur_lr,
            epsilon=config["rmsprop_epsilon"],
            # Keras' `rho` corresponds to tf1 RMSProp's `decay` (discounting
            # factor for the gradient history), not to a learning-rate decay.
            rho=0.95,
            momentum=0.0,
            centered=True,
        )
    else:
        return tf1.train.RMSPropOptimizer(
            learning_rate=policy.cur_lr,
            epsilon=config["rmsprop_epsilon"],
            decay=0.95,
            momentum=0.0,
            centered=True,
        )


SlateQTFPolicy = build_tf_policy(
    name="SlateQTFPolicy",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
    # Build model, loss functions, and optimizers.
    make_model=build_slateq_model,
    loss_fn=build_slateq_losses,
    stats_fn=build_slateq_stats,
    optimizer_fn=rmsprop_optimizer,
    # Define how to act.
    action_distribution_fn=action_distribution_fn,
    compute_gradients_fn=clip_gradients,
    before_init=setup_early,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
    mixins=[LearningRateSchedule, TargetNetworkMixin],
)
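
# Minimal usage sketch (not part of this module; assumes the pre-2.0 `agents` API
# and a RecSim-style recommender env registered under the hypothetical name
# "my-recsim-env"):
#
#   from ray.rllib.agents.slateq import SlateQTrainer
#
#   trainer = SlateQTrainer(env="my-recsim-env", config={"framework": "tf"})
#   print(trainer.train())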