# ray/rllib/algorithms/slateq/slateq_torch_policy.py
#
# 439 lines
# 16 KiB
# Python

"""PyTorch policy class used for SlateQ."""
import gym
import logging
import numpy as np
from typing import Dict, Tuple, Type
import ray
from ray.rllib.algorithms.sac.sac_torch_policy import TargetNetworkMixin
from ray.rllib.algorithms.slateq.slateq_torch_model import SlateQTorchModel
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical,
TorchDistributionWrapper,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import (
apply_grad_clipping,
concat_multi_gpu_td_errors,
convert_to_torch_tensor,
huber_loss,
)
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
# `torch` and `nn` as used throughout this module are whatever
# try_import_torch() returns (presumably the torch package and torch.nn --
# confirm against ray.rllib.utils.framework).
torch, nn = try_import_torch()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def build_slateq_model_and_distribution(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Create the SlateQ Q-model, its target twin and the action dist class.

    Two ``SlateQTorchModel`` instances are constructed from identical
    settings (only their names differ). The first one is returned, the second
    one is attached to ``policy.target_model``.

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space: The policy's observation space.
        action_space: The policy's action space.
        config: The Algorithm's config dict.

    Returns:
        Tuple consisting of 1) Q-model and 2) an action distribution class.
    """

    def _make_model(model_name):
        # Both networks share every constructor argument except the name.
        return SlateQTorchModel(
            obs_space,
            action_space,
            num_outputs=action_space.nvec[0],
            model_config=config["model"],
            name=model_name,
            fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
        )

    # Build the main model first, then the target model (same order as the
    # two explicit constructor calls this replaces).
    q_model = _make_model("slateq_model")
    policy.target_model = _make_model("target_slateq_model")
    return q_model, TorchCategorical
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> Tuple[TensorType, TensorType]:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Q-loss: TD error between the Q-value of the clicked document in the
    replayed slate and ``reward + gamma * max_over_slates(target)``. Samples
    without a click get a TD error of zero. The error is passed through either
    the Huber loss or a square (``config["use_huber"]``) and then averaged.

    Choice loss: cross-entropy between the choice model's scores for the
    replayed slate and the position that was clicked (or the extra
    "no click" position).

    As a side effect, a number of diagnostic means are written into
    ``model.tower_stats``.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        _: Unused positional argument.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors, as the tuple
        ``(choice_loss, q_loss)``.
    """
    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # action.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]
    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc: list with one tensor per candidate; stacked along dim 1 further
    # below, which yields [B, C, E].
    doc_obs = list(observation["doc"].values())
    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc: list with one tensor per candidate (see doc_obs above).
    doc_next_obs = list(next_obs["doc"].values())
    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)
    # TODO: Find out, whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    # scores_slate.shape [B, A, S]
    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    # slate_normalizer.shape: [B, A]
    # score_no_click is [B]; unsqueeze(1) makes it [B, 1] so that each batch
    # row's own no-click score is broadcast over that row's A slates. (Tiling
    # the [B] vector A times and reshaping to [B, A] would interleave values
    # from different batch rows whenever the no-click score is not constant.)
    slate_normalizer = torch.sum(scores_slate, dim=2) + score_no_click.unsqueeze(1)
    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = (
        torch.sum(next_q_values_slate * scores_slate, dim=2) / slate_normalizer
    )
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)
    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    # No gradients flow through the bootstrap target.
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    # Clicked_indices is a vector and torch.gather selects the batch dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)

    # Samples without any click do not contribute a TD error.
    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # Choice-model loss.
    # ------------------
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [batch_size, slate_size, embedding_size]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )
    # Note: rebinds `scores`; from here on it holds the choice model's output
    # (presumably one score per slate position plus one for "no click", to
    # line up with `targets` below -- confirm against the choice model).
    scores = model.choice_model(user_obs, selected_doc)
    # click_indicator.shape: [batch_size, slate_size]
    # no_clicks.shape: [batch_size, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
    # The class label is the clicked slate position, or the extra last index
    # when nothing was clicked.
    choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1))
    model.tower_stats["choice_loss"] = choice_loss
    return choice_loss, loss
def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Reduce the per-tower loss statistics to a single mean per metric.

    Args:
        policy: The Policy; ``policy.get_tower_stats(name)`` is queried once
            per reported metric and its values are stacked and averaged.
        batch: Unused.

    Returns:
        Dict mapping each metric name to the mean over its stacked tower
        values.
    """
    # Reported metrics, in output order.
    metric_names = (
        "q_values",
        "q_clicked",
        "scores",
        "score_no_click",
        "slate_q_values",
        "replay_click_q",
        "bellman_reward",
        "next_q_values",
        "target",
        "next_q_target_slate",
        "next_q_target_max",
        "target_clicked",
        "q_loss",
        "mean_actions",
        "choice_loss",
    )
    return {
        metric: torch.mean(torch.stack(policy.get_tower_stats(metric)))
        for metric in metric_names
    }
def action_distribution_fn(
    policy: Policy,
    model: SlateQTorchModel,
    input_dict,
    *,
    explore,
    is_training,
    **kwargs,
):
    """Compute the distribution inputs used to pick a slate.

    Args:
        policy: The Policy; provides ``slates`` and ``slates_indices``.
        model: The Q-model; ``get_q_values`` is called on it.
        input_dict: Input batch; only ``SampleBatch.OBS`` is read.
        explore: Unused.
        is_training: Unused.
        **kwargs: Unused.

    Returns:
        Tuple of 1) one Q-value per possible slate, 2) the
        ``TorchCategorical`` class and 3) an empty list.
    """
    obs = input_dict[SampleBatch.OBS]
    # user_features.shape: [B, E]
    user_features = obs["user"]
    doc_features = list(obs["doc"].values())

    # Per-candidate choice scores, plus the no-click score.
    candidate_scores, no_click_score = score_documents(user_features, doc_features)
    # Per-candidate Q-values.
    candidate_q_values = model.get_q_values(user_features, doc_features)

    slate_values = get_per_slate_q_values(
        policy, no_click_score, candidate_scores, candidate_q_values
    )
    # Hand the slate table to the model once, if it does not have one yet.
    if not hasattr(model, "slates"):
        model.slates = policy.slates
    return slate_values, TorchCategorical, []
def get_per_slate_q_values(policy, score_no_click, scores, q_values):
    """Return one score-weighted Q-value per possible slate.

    For every slate, the per-item ``score * q_value`` products are divided by
    (sum of the slate's item scores + no-click score) and then summed.

    Args:
        policy: Object providing ``slates`` ([A, S]) and ``slates_indices``
            (the flattened slates, [1, A*S]).
        score_no_click: No-click score per batch row, [B].
        scores: Per-candidate scores, [B, C].
        q_values: Per-candidate Q-values, [B, C].

    Returns:
        Tensor of shape [B, A].
    """
    flat_slate_indices = policy.slates_indices.to(scores.device)
    num_slates, slate_size = policy.slates.shape
    per_slate_shape = [-1, num_slates, slate_size]

    def _gather_per_slate(per_candidate):
        # [B, C] -> [B, A, S]
        picked = torch.take_along_dim(per_candidate, flat_slate_indices, dim=1)
        return picked.reshape(per_slate_shape)

    weighted_q = _gather_per_slate(scores * q_values)
    item_scores = _gather_per_slate(scores)
    # normalizer.shape: [B, A]
    normalizer = torch.sum(item_scores, dim=2) + score_no_click.unsqueeze(1)
    normalized_q = weighted_q / normalizer.unsqueeze(2)
    return torch.sum(normalized_q, dim=2)
def score_documents(
    user_obs, doc_obs, no_click_score=1.0, multinomial_logits=False, min_normalizer=-1.0
):
    """Computes dot-product scores for user vs doc (plus no-click) feature vectors.

    Args:
        user_obs: User feature tensor, [B, E].
        doc_obs: Sequence of C document feature tensors, each [B, E].
        no_click_score: Constant raw score assigned to the "no click" option.
        multinomial_logits: If True, treat the raw scores as logits and
            softmax them over all candidates plus the no-click option.
            Otherwise shift all raw scores by ``-min_normalizer``.
        min_normalizer: Value subtracted from every raw score in the
            non-logits case.

    Returns:
        Tuple of 1) the per-candidate click scores, [B, C], and 2) the
        no-click score, [B].
    """
    # Dot product between user and each document feature vector:
    # [B, 1, E] * [B, C, E], summed over E -> [B, C].
    scores_per_candidate = torch.sum(
        torch.multiply(user_obs.unsqueeze(1), torch.stack(doc_obs, dim=1)), dim=2
    )
    # Compile a constant no-click score tensor, [B, 1], directly on the
    # device the candidate scores live on.
    score_no_click = torch.full(
        size=[user_obs.shape[0], 1],
        fill_value=no_click_score,
        device=scores_per_candidate.device,
    )
    # Concatenate click and no-click scores -> [B, C + 1].
    all_scores = torch.cat([scores_per_candidate, score_no_click], dim=1)
    if multinomial_logits:
        # Logits: Softmax over the candidate axis to yield probabilities.
        # `dim` is given explicitly; leaving it out is deprecated and warns.
        all_scores = nn.functional.softmax(all_scores, dim=1)
    else:
        # Multinomial proportional model: Shift to `[0.0,..[`.
        all_scores = all_scores - min_normalizer
    # Return click (per candidate document) and no-click scores.
    return all_scores[:, :-1], all_scores[:, -1]
def setup_early(policy, obs_space, action_space, config):
    """Obtain all possible slates given current docs in the candidate set.

    A slate is an ordered selection of ``slate_size`` distinct candidate
    indices. All of them are enumerated once, in lexicographic order, and
    stored on the policy:

    - ``policy.slates``: long tensor, [A, S].
    - ``policy.slates_indices``: the same values flattened to [1, A*S]
      (useful for ``torch.take_along_dim()``).

    Args:
        policy: The Policy object to attach the slate tensors to.
        obs_space: Unused.
        action_space: Action space; ``nvec[0]`` is taken as the number of
            candidates and ``len(nvec)`` as the slate size.
        config: Unused.
    """
    # Function-scope import so that this block brings its own dependency.
    from itertools import permutations

    num_candidates = int(action_space.nvec[0])
    slate_size = len(action_space.nvec)

    # permutations() yields exactly the duplicate-free ordered selections, in
    # lexicographic order. This avoids materializing all
    # num_candidates**slate_size combinations and filtering them row by row.
    all_slates = list(permutations(range(num_candidates), slate_size))

    # slates.shape: [A, S]. The explicit row count keeps the reshape
    # well-defined even when no valid slate exists (A == 0).
    slates = torch.tensor(all_slates, dtype=torch.long).reshape(
        len(all_slates), slate_size
    )

    # Store all possible slates only once in policy object.
    policy.slates = slates
    # [1, AxS] Useful for torch.take_along_dim()
    policy.slates_indices = policy.slates.reshape(-1).unsqueeze(0)
def optimizer_fn(
    policy: Policy, config: AlgorithmConfigDict
) -> Tuple["torch.optim.Optimizer", "torch.optim.Optimizer"]:
    """Create the two optimizers used by SlateQ.

    Args:
        policy: The Policy; ``policy.model.choice_model`` and
            ``policy.model.q_model`` supply the parameters to optimize.
        config: Config dict; reads ``lr_choice_model``, ``lr`` and
            ``rmsprop_epsilon``.

    Returns:
        Tuple of 1) an Adam optimizer over the choice model's parameters and
        2) a centered RMSprop optimizer over the Q-model's parameters.
    """
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"]
    )
    optimizer_q_value = torch.optim.RMSprop(
        policy.model.q_model.parameters(),
        lr=config["lr"],
        eps=config["rmsprop_epsilon"],
        # 0.95 is the decay of RMSprop's running average of squared
        # gradients, which torch calls `alpha`. It must not be passed as
        # `weight_decay`: that is an L2 penalty coefficient, and 0.95 there
        # would add 0.95 * weight to every gradient.
        alpha=0.95,
        momentum=0.0,
        centered=True,
    )
    return optimizer_choice, optimizer_q_value
def postprocess_fn_add_next_actions_for_sarsa(
    policy: Policy, batch: SampleBatch, other_agent=None, episode=None
) -> SampleBatch:
    """Attach a ``next_actions`` column to the batch when training with SARSA.

    Args:
        policy: The Policy; reads ``config["slateq_strategy"]`` and
            ``_no_tracing``.
        batch: The sample batch to extend in place.
        other_agent: Unused.
        episode: Unused.

    Returns:
        The same batch object. With the "SARSA" strategy it additionally
        holds ``next_actions``: the ``actions`` column rolled by one step
        along axis 0 (so the last row wraps around to the first action).

    Raises:
        RuntimeError: With the "SARSA" strategy, if the last ``dones`` entry
            is falsy while ``policy._no_tracing`` is False.
    """
    # Any other strategy: hand the batch back untouched.
    if policy.config["slateq_strategy"] != "SARSA":
        return batch

    ends_mid_episode = not batch["dones"][-1]
    if ends_mid_episode and policy._no_tracing is False:
        raise RuntimeError(
            "Expected a complete episode in each sample batch. "
            f"But this batch is not: {batch}."
        )

    rolled_actions = np.roll(batch["actions"], -1, axis=0)
    batch["next_actions"] = rolled_actions
    return batch
def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call the constructors of the policy's mixin classes.

    This function is handed to ``build_policy_class`` as the ``after_init``
    hook of SlateQTorchPolicy.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space (unused here).
        action_space: The Policy's action space (unused here).
        config: The Policy's config (unused here).
    """
    # TargetNetworkMixin is the only mixin of SlateQTorchPolicy; its
    # constructor is invoked explicitly on the policy instance.
    TargetNetworkMixin.__init__(policy)
# The SlateQ torch policy class, assembled from the functions defined in this
# module via `build_policy_class`.
SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    # Wrapped in a lambda so the attribute path below is only resolved when
    # the callable is invoked, not when this module is imported.
    get_default_config=lambda: ray.rllib.algorithms.slateq.slateq.DEFAULT_CONFIG,
    # Enumerates all possible slates and stores them on the policy.
    before_init=setup_early,
    # Calls the TargetNetworkMixin constructor.
    after_init=setup_late_mixins,
    # Returns the (choice_loss, q_loss) pair.
    loss_fn=build_slateq_losses,
    stats_fn=build_slateq_stats,
    # Build model, loss functions, and optimizers
    make_model_and_action_dist=build_slateq_model_and_distribution,
    # Returns two optimizers: one for the choice model, one for the Q-model.
    optimizer_fn=optimizer_fn,
    # Define how to act.
    action_distribution_fn=action_distribution_fn,
    # Post processing sampled trajectory data.
    # NOTE: The SARSA postprocessing hook is currently not registered:
    # postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
    extra_grad_process_fn=apply_grad_clipping,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    mixins=[TargetNetworkMixin],
)