# ray/rllib/agents/slateq/slateq_torch_policy.py
"""PyTorch policy class used for SlateQ"""
from typing import Dict, List, Sequence, Tuple
import gym
import numpy as np
import ray
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
TorchDistributionWrapper)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
TrainerConfigDict)
torch, nn = try_import_torch()
F = None
if nn:
F = nn.functional


class QValueModel(nn.Module):
    """The Q-value model for SlateQ."""

    def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
        super().__init__()

        # Construct hidden layers.
        layers = []
        ins = 2 * embedding_size
        for n in q_hiddens:
            layers.append(nn.Linear(ins, n))
            layers.append(nn.LeakyReLU())
            ins = n
        layers.append(nn.Linear(ins, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model.

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        # doc_flat is laid out batch-major, so repeat each user embedding
        # num_docs times in place (interleaved, not tiled) to pair every doc
        # with its own user.
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)

        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
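

# Illustrative only (not part of the original policy): a minimal sketch of
# the tensor shapes QValueModel consumes and produces, assuming an embedding
# size of 4 and 5 candidate docs. The helper name `_example_q_value_model`
# is hypothetical and is never called by the policy itself.
def _example_q_value_model():
    batch_size, num_docs, embedding_size = 2, 5, 4
    model = QValueModel(embedding_size=embedding_size, q_hiddens=(8, 8))
    user = torch.rand(batch_size, embedding_size)
    doc = torch.rand(batch_size, num_docs, embedding_size)
    q_values = model(user, doc)
    # One Q-value per candidate doc, plus a trailing column that is fixed to
    # zero for the "no click" outcome.
    assert q_values.shape == (batch_size, num_docs + 1)
    assert torch.all(q_values[:, -1] == 0.0)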


class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ.

    This class implements a multinomial logit model for predicting user
    clicks. Under this model, the click probability of a document is
    proportional to:

    .. math::
        \exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click})
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
        self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model.

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional to user click
        probabilities. The scores are returned unnormalized because only some
        of the documents will be selected and shown to the user.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        # Affinity of each user to each candidate doc (dot product of the
        # embeddings), scaled by the learnable parameter beta.
        s = torch.einsum("be,bde->bd", user, doc)
        s = s * self.beta
        # Append the learnable no-click score as the last logit.
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
        return s
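

# Illustrative only (not part of the original policy): under the multinomial
# logit model above, click probabilities for a presented slate are obtained
# by a softmax over the scores of the slate's docs plus the no-click score.
# The helper name `_example_choice_probabilities` is hypothetical.
def _example_choice_probabilities():
    batch_size, slate_size, embedding_size = 2, 3, 4
    choice_model = UserChoiceModel()
    user = torch.rand(batch_size, embedding_size)
    slate_docs = torch.rand(batch_size, slate_size, embedding_size)
    # scores.shape: [batch_size, slate_size + 1]; the last column is no_click.
    scores = choice_model(user, slate_docs)
    probs = F.softmax(scores, dim=1)
    # The probabilities of {click doc 0, ..., click doc k-1, no click} sum
    # to one for each row.
    assert torch.allclose(probs.sum(dim=1), torch.ones(batch_size))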


class SlateQModel(TorchModelV2, nn.Module):
    """The SlateQ model class.

    It includes both the user choice model and the Q-value model.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            model_config: ModelConfigDict,
            name: str,
            *,
            embedding_size: int,
            q_hiddens: Sequence[int],
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has no
            # real impact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name)
        self.choice_model = UserChoiceModel()
        self.q_model = QValueModel(embedding_size, q_hiddens)
        self.slate_size = len(action_space.nvec)

    def choose_slate(self, user: TensorType,
                     doc: TensorType) -> Tuple[TensorType, TensorType]:
        """Build a slate by selecting from candidate documents.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size,
                num_docs, embedding_size).

        Returns:
            slate_selected (TensorType): Indices of documents selected for
                the slate, with shape (batch_size, slate_size).
            best_slate_q_value (TensorType): The Q-value of the selected slate,
                with shape (batch_size).
        """
        # Step 1: compute item scores (proportional to click probabilities).
        # raw_scores.shape=[batch_size, num_docs+1]
        raw_scores = self.choice_model(user, doc)
        # max_raw_scores.shape=[batch_size, 1]
        max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
        # Subtract the per-row max score before exponentiating, for numerical
        # stability.
        scores = torch.exp(raw_scores - max_raw_scores)
        scores_doc = scores[:, :-1]  # shape=[batch_size, num_docs]
        scores_no_click = scores[:, [-1]]  # shape=[batch_size, 1]

        # Step 2: calculate the item-wise Q values.
        # q_values.shape=[batch_size, num_docs+1]
        q_values = self.q_model(user, doc)
        q_values_doc = q_values[:, :-1]  # shape=[batch_size, num_docs]
        q_values_no_click = q_values[:, [-1]]  # shape=[batch_size, 1]

        # Step 3: construct all possible slates.
        _, num_docs, _ = doc.shape
        indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
        # slates.shape = [num_slates, slate_size]
        slates = torch.combinations(indices, r=self.slate_size)
        num_slates, _ = slates.shape

        # Step 4: calculate slate Q values.
        batch_size, _ = q_values_doc.shape
        # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
        slate_decomp_q_values = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_scores.shape: [batch_size, num_slates, slate_size]
        slate_scores = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_q_values.shape: [batch_size, num_slates]
        slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
                          (q_values_no_click * scores_no_click)) / (
                              slate_scores.sum(dim=2) + scores_no_click)

        # Step 5: find the slate that maximizes the Q value.
        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
        # slates_selected.shape: [batch_size, slate_size]
        slates_selected = slates[max_idx]
        return slates_selected, best_slate_q_value

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        # user.shape: [batch_size, embedding_size]
        user = input_dict[SampleBatch.OBS]["user"]
        # doc.shape: [batch_size, num_docs, embedding_size]
        doc = torch.cat([
            val.unsqueeze(1)
            for val in input_dict[SampleBatch.OBS]["doc"].values()
        ], 1)
        slates_selected, _ = self.choose_slate(user, doc)
        state_out = []
        return slates_selected, state_out
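

# Illustrative only (not part of the original policy): choose_slate() above
# scores every candidate slate A with the SlateQ decomposition
#
#     Q(s, A) = (sum_{i in A} v(s, i) * Q(s, i) + v_no_click * 0)
#               / (sum_{i in A} v(s, i) + v_no_click),
#
# where v(s, i) = exp(choice-model score) is an unnormalized click weight and
# the no-click Q-value is fixed to zero by QValueModel. A minimal shape
# sketch, assuming a toy MultiDiscrete action space (slates of size 2 drawn
# from 4 docs); the helper name and the toy spaces are hypothetical.
def _example_choose_slate():
    batch_size, num_docs, embedding_size, slate_size = 3, 4, 8, 2
    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(embedding_size, ))
    action_space = gym.spaces.MultiDiscrete([num_docs] * slate_size)
    model = SlateQModel(
        obs_space,
        action_space,
        model_config={},
        name="toy_slateq_model",
        embedding_size=embedding_size,
        q_hiddens=(16, ))
    user = torch.rand(batch_size, embedding_size)
    doc = torch.rand(batch_size, num_docs, embedding_size)
    slates, slate_q_values = model.choose_slate(user, doc)
    # One slate (of doc indices) and one scalar Q-value per batch element.
    assert slates.shape == (batch_size, slate_size)
    assert slate_q_values.shape == (batch_size, )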


def build_slateq_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build models for SlateQ.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        Tuple consisting of 1) the SlateQModel and 2) the TorchCategorical
        distribution class.
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        embedding_size=config["recsim_embedding_size"],
        q_hiddens=config["hiddens"],
    )
    return model, TorchCategorical


def build_slateq_losses(policy: Policy, model: SlateQModel, _,
                        train_batch: SampleBatch) -> List[TensorType]:
    """Constructs the losses for SlateQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the losses for.
        model (SlateQModel): The Model to calculate the losses for.
        train_batch (SampleBatch): The training data.

    Returns:
        List[TensorType]: The choice-model loss and the Q-value loss (one
            tensor per optimizer returned by `build_slateq_optimizers`).
    """
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build the user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.beta.item(),
    #       model.choice_model.score_no_click.item())

    # Step 2: Build the Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(
                q_values * scores, dim=1) / torch.sum(
                    scores, dim=1)
        # Do not bootstrap from terminal states.
        next_q_values[dones] = 0.0
    elif learning_strategy == "MYOP":
        # Myopic strategy: no bootstrapping, use immediate rewards only.
        next_q_values = 0.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # Q-learning: bootstrap from the best slate in the next state.
            _, next_q_values = model.choose_slate(next_user, next_doc)
        # Do not bootstrap from terminal states.
        next_q_values[dones] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    target_q_values = next_q_values + train_batch["rewards"]

    # q_values.shape: [batch_size, slate_size+1]
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    # q_values.shape: [batch_size]
    q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)

    q_value_loss = nn.MSELoss()(q_values, target_q_values)
    return [choice_loss, q_value_loss]
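

# Illustrative only (not part of the original policy): the SARSA branch above
# bootstraps with a choice-model-weighted average of the next slate's
# item-wise Q-values, i.e. V(s', A') = sum_i softmax(scores)_i * Q(s', i).
# A minimal numeric sketch of that reduction; the helper name is hypothetical.
def _example_sarsa_backup_value():
    # Two docs plus the trailing no-click slot.
    q_values = torch.tensor([[2.0, 1.0, 0.0]])
    raw_scores = torch.tensor([[0.0, 0.0, 0.0]])
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    next_q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)
    # With uniform scores this is simply the mean: (2 + 1 + 0) / 3 = 1.
    assert torch.allclose(next_q_values, torch.tensor([1.0]))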


def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
                            ) -> List["torch.optim.Optimizer"]:
    # One optimizer per loss returned by `build_slateq_losses`: the first
    # updates only the choice model, the second only the Q-value model.
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"])
    return [optimizer_choice, optimizer_q_value]


def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
                      explore, timestep):
    """Determine which action (slate) to take."""
    # First, we transform the observation into its unflattened form.
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS],
        policy.observation_space,
        tensorlib=torch)

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    selected_slates, _ = model.choose_slate(user, doc)

    action = selected_slates
    logp = None
    state_out = []
    return action, logp, state_out


def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training."""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        # Shift the actions up by one step, so that row t holds the action
        # taken at t + 1. The last row wraps around, but it belongs to a
        # terminal step and is never used for bootstrapping.
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch
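

# Illustrative only (not part of the original policy): a tiny numeric sketch
# of the np.roll() call above. Row t of "next_actions" holds the slate chosen
# at step t + 1; the wrapped-around last row belongs to a terminal step, whose
# bootstrap value is zeroed out in `build_slateq_losses`. The helper name is
# hypothetical.
def _example_next_actions_roll():
    actions = np.array([[0, 1], [2, 3], [4, 5]])
    next_actions = np.roll(actions, -1, axis=0)
    expected = np.array([[2, 3], [4, 5], [0, 1]])
    assert np.array_equal(next_actions, expected)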


SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

    # Build model, loss functions, and optimizers.
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,

    # Define how to act.
    action_sampler_fn=action_sampler_fn,

    # Post process the sampled batch.
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
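

# Illustrative only (not part of the original file): run the hypothetical
# sketch helpers defined above as a quick smoke test when this module is
# executed directly (requires torch to be installed).
if __name__ == "__main__":
    _example_q_value_model()
    _example_choice_probabilities()
    _example_choose_slate()
    _example_sarsa_backup_value()
    _example_next_actions_roll()
    print("All illustrative sketches ran successfully.")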