"""PyTorch policy class used for SlateQ."""
|
|
|
|
import gym
|
|
import logging
|
|
import numpy as np
|
|
from typing import Dict, Tuple, Type
|
|
|
|
import ray
|
|
from ray.rllib.algorithms.sac.sac_torch_policy import TargetNetworkMixin
|
|
from ray.rllib.algorithms.slateq.slateq_torch_model import SlateQTorchModel
|
|
from ray.rllib.models.modelv2 import ModelV2
|
|
from ray.rllib.models.torch.torch_action_dist import (
|
|
TorchCategorical,
|
|
TorchDistributionWrapper,
|
|
)
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.policy.policy_template import build_policy_class
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
from ray.rllib.utils.torch_utils import (
|
|
apply_grad_clipping,
|
|
concat_multi_gpu_td_errors,
|
|
convert_to_torch_tensor,
|
|
huber_loss,
|
|
)
|
|
from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict
|
|
|
|
torch, nn = try_import_torch()
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def build_slateq_model_and_distribution(
|
|
policy: Policy,
|
|
obs_space: gym.spaces.Space,
|
|
action_space: gym.spaces.Space,
|
|
config: AlgorithmConfigDict,
|
|
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
|
|
"""Build models for SlateQ
|
|
|
|
Args:
|
|
policy: The policy, which will use the model for optimization.
|
|
obs_space: The policy's observation space.
|
|
action_space: The policy's action space.
|
|
config: The Algorithm's config dict.
|
|
|
|
Returns:
|
|
Tuple consisting of 1) Q-model and 2) an action distribution class.
|
|
"""
|
|
model = SlateQTorchModel(
|
|
obs_space,
|
|
action_space,
|
|
num_outputs=action_space.nvec[0],
|
|
model_config=config["model"],
|
|
name="slateq_model",
|
|
fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
|
|
)
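    # Construct a separate target network with identical architecture. It is
    # kept in sync with the main model by TargetNetworkMixin (initialized in
    # `setup_late_mixins` below).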
    policy.target_model = SlateQTorchModel(
        obs_space,
        action_space,
        num_outputs=action_space.nvec[0],
        model_config=config["model"],
        name="target_slateq_model",
        fcnet_hiddens_per_candidate=config["fcnet_hiddens_per_candidate"],
    )

    return model, TorchCategorical


def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> Tuple[TensorType, TensorType]:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors.
    """

    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # actions.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())

    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )

    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())
    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)

    # TODO: Find out whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    # scores_slate.shape: [B, A, S]
    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.reshape(
        torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
    )
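    # SlateQ decomposition (Ie et al., 2019): the value of a slate is the
    # choice-probability-weighted sum of its items' Q-values,
    #   Q(s, slate) = sum_i v(s, i) * Q(s, i) / (sum_j v(s, j) + v(s, no_click)),
    # where v(s, i) are the user-choice scores computed by `score_documents`.
    # The bootstrap target below evaluates this for every enumerated slate and
    # takes the max over slates (Q-learning-style SlateQ).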
    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = torch.sum(next_q_values_slate * scores_slate, dim=2) / (
        torch.sum(scores_slate, dim=2) + score_no_click_slate
    )
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    # Clicked_indices is a vector and torch.gather selects the batch dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)
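    # The TD error is only non-zero for transitions in which the user clicked
    # at least one item; slates without a click contribute zero to the Q-loss.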
    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for the stats function in the model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # selected_doc.shape: [B, S, E]
    selected_doc = torch.gather(
        # input.shape: [B, C, E]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [B, S, E]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user_obs, selected_doc)

    # click_indicator.shape: [B, S]
    # no_clicks.shape: [B, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [B, S+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
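    # Train the choice model as an (S+1)-way classifier: classes 0..S-1 are the
    # slate positions, class S means "no click". `torch.argmax(targets, dim=1)`
    # recovers the index of the clicked position (or S if nothing was clicked).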
    choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss


def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
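    """Returns a dict of training stats, averaged over this policy's loss towers."""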
    stats = {
        "q_values": torch.mean(torch.stack(policy.get_tower_stats("q_values"))),
        "q_clicked": torch.mean(torch.stack(policy.get_tower_stats("q_clicked"))),
        "scores": torch.mean(torch.stack(policy.get_tower_stats("scores"))),
        "score_no_click": torch.mean(
            torch.stack(policy.get_tower_stats("score_no_click"))
        ),
        "slate_q_values": torch.mean(
            torch.stack(policy.get_tower_stats("slate_q_values"))
        ),
        "replay_click_q": torch.mean(
            torch.stack(policy.get_tower_stats("replay_click_q"))
        ),
        "bellman_reward": torch.mean(
            torch.stack(policy.get_tower_stats("bellman_reward"))
        ),
        "next_q_values": torch.mean(
            torch.stack(policy.get_tower_stats("next_q_values"))
        ),
        "target": torch.mean(torch.stack(policy.get_tower_stats("target"))),
        "next_q_target_slate": torch.mean(
            torch.stack(policy.get_tower_stats("next_q_target_slate"))
        ),
        "next_q_target_max": torch.mean(
            torch.stack(policy.get_tower_stats("next_q_target_max"))
        ),
        "target_clicked": torch.mean(
            torch.stack(policy.get_tower_stats("target_clicked"))
        ),
        "q_loss": torch.mean(torch.stack(policy.get_tower_stats("q_loss"))),
        "mean_actions": torch.mean(torch.stack(policy.get_tower_stats("mean_actions"))),
        "choice_loss": torch.mean(torch.stack(policy.get_tower_stats("choice_loss"))),
        # "choice_beta": torch.mean(torch.stack(policy.get_tower_stats("choice_beta"))),
        # "choice_score_no_click": torch.mean(
        #     torch.stack(policy.get_tower_stats("choice_score_no_click"))
        # ),
    }
    # model_stats = {
    #     k: torch.mean(var)
    #     for k, var in policy.model.trainable_variables(as_dict=True).items()
    # }
    # stats.update(model_stats)

    return stats


def action_distribution_fn(
    policy: Policy,
    model: SlateQTorchModel,
    input_dict,
    *,
    explore,
    is_training,
    **kwargs,
):
    """Determine which action to take."""

    observation = input_dict[SampleBatch.OBS]

    # user.shape: [B, E]
    user_obs = observation["user"]
    doc_obs = list(observation["doc"].values())

    # Compute scores per candidate.
    scores, score_no_click = score_documents(user_obs, doc_obs)
    # Compute Q-values per candidate.
    q_values = model.get_q_values(user_obs, doc_obs)

    per_slate_q_values = get_per_slate_q_values(
        policy, score_no_click, scores, q_values
    )
    if not hasattr(model, "slates"):
        model.slates = policy.slates
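    # The per-slate Q-values serve as the logits of a categorical distribution
    # over all A enumerated slates; greedy argmax selection and exploration
    # (e.g. epsilon-greedy) are then handled by RLlib's exploration component.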
    return per_slate_q_values, TorchCategorical, []


def get_per_slate_q_values(policy, score_no_click, scores, q_values):
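    """Returns the [B, A] tensor of Q-values, one per possible (unique) slate.

    Each slate's value is the sum of its items' Q-values, weighted by the
    normalized user-choice scores (with the no-click score included in the
    normalizer).
    """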
    indices = policy.slates_indices.to(scores.device)
    A, S = policy.slates.shape
    slate_q_values = torch.take_along_dim(scores * q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    slate_scores = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    slate_normalizer = torch.sum(slate_scores, dim=2) + score_no_click.unsqueeze(1)

    slate_q_values = slate_q_values / slate_normalizer.unsqueeze(2)
    slate_sum_q_values = torch.sum(slate_q_values, dim=2)
    return slate_sum_q_values


def score_documents(
    user_obs, doc_obs, no_click_score=1.0, multinomial_logits=False, min_normalizer=-1.0
):
    """Computes dot-product scores for user vs doc (plus no-click) feature vectors."""
    # Dot product between user and each document feature vector.
    scores_per_candidate = torch.sum(
        torch.multiply(user_obs.unsqueeze(1), torch.stack(doc_obs, dim=1)), dim=2
    )
    # Compile a constant no-click score tensor.
    score_no_click = torch.full(
        size=[user_obs.shape[0], 1], fill_value=no_click_score
    ).to(scores_per_candidate.device)
    # Concatenate click and no-click scores.
    all_scores = torch.cat([scores_per_candidate, score_no_click], dim=1)

    # Logits: Softmax to yield probabilities.
    if multinomial_logits:
        all_scores = nn.functional.softmax(all_scores, dim=1)
    # Multinomial proportional model: Shift to `[0.0,..[`.
    else:
        all_scores = all_scores - min_normalizer

    # Return click (per candidate document) and no-click scores.
    return all_scores[:, :-1], all_scores[:, -1]


def setup_early(policy, obs_space, action_space, config):
    """Obtain all possible slates given current docs in the candidate set."""

    num_candidates = action_space.nvec[0]
    slate_size = len(action_space.nvec)

    mesh_args = [torch.Tensor(list(range(num_candidates)))] * slate_size
    slates = torch.stack(torch.meshgrid(*mesh_args), dim=-1)
    slates = torch.reshape(slates, shape=(-1, slate_size))
    # Filter out slates that contain duplicates to ensure each document is
    # picked at most once per slate.
    unique_mask = []
    for i in range(slates.shape[0]):
        x = slates[i]
        unique_mask.append(len(x) == len(torch.unique(x)))
    unique_mask = torch.Tensor(unique_mask).bool().unsqueeze(1)
    # slates.shape: [A, S]
    slates = torch.masked_select(slates, mask=unique_mask).reshape([-1, slate_size])
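    # Example: with num_candidates=3 and slate_size=2, this yields the A=6
    # ordered slates without repeated documents:
    #   [[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]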
    # Store all possible slates only once in the policy object.
    policy.slates = slates.long()
    # Shape [1, A*S]; useful for torch.take_along_dim().
    policy.slates_indices = policy.slates.reshape(-1).unsqueeze(0)


def optimizer_fn(
    policy: Policy, config: AlgorithmConfigDict
) -> Tuple["torch.optim.Optimizer", "torch.optim.Optimizer"]:
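    # Two separate optimizers over disjoint parameter sets: Adam for the user
    # choice model, RMSprop for the Q-network. Note that in torch's RMSprop the
    # moving-average coefficient is `alpha` (default 0.99), while `weight_decay`
    # adds L2 regularization, so `weight_decay=0.95` below acts as a strong L2
    # term.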
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"]
    )
    optimizer_q_value = torch.optim.RMSprop(
        policy.model.q_model.parameters(),
        lr=config["lr"],
        eps=config["rmsprop_epsilon"],
        momentum=0.0,
        weight_decay=0.95,
        centered=True,
    )
    return optimizer_choice, optimizer_q_value


def postprocess_fn_add_next_actions_for_sarsa(
    policy: Policy, batch: SampleBatch, other_agent=None, episode=None
) -> SampleBatch:
    """Add next_actions to the SampleBatch for SARSA training."""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1] and policy._no_tracing is False:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}."
            )
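        # np.roll shifts the actions one step back in time so that row t holds
        # the action taken at t+1; the last row wraps around to the first
        # action, but it corresponds to the episode's terminal step (checked
        # above).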
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)

    return batch


def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
) -> None:
    """Call all mixin classes' constructors before SlateQTorchPolicy initialization.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    TargetNetworkMixin.__init__(policy)


SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.algorithms.slateq.slateq.DEFAULT_CONFIG,
    before_init=setup_early,
    after_init=setup_late_mixins,
    loss_fn=build_slateq_losses,
    stats_fn=build_slateq_stats,
    # Build model, loss functions, and optimizers.
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=optimizer_fn,
    # Define how to act.
    action_distribution_fn=action_distribution_fn,
    # Post-processing of sampled trajectory data.
    # postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
    extra_grad_process_fn=apply_grad_clipping,
    extra_learn_fetches_fn=concat_multi_gpu_td_errors,
    mixins=[TargetNetworkMixin],
)
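# Usage sketch (illustrative): build the torch SlateQ algorithm via its config;
# the algorithm then instantiates `SlateQTorchPolicy` internally. The env name
# below is a placeholder, and the snippet assumes the `SlateQConfig` API of the
# same RLlib version.
#
#   from ray.rllib.algorithms.slateq import SlateQConfig
#
#   config = (
#       SlateQConfig()
#       .environment("my_recsim_env")  # a registered RecSim-style slate env
#       .framework("torch")
#   )
#   algo = config.build()
#   algo.train()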