ray/rllib/algorithms/slateq/slateq_torch_model.py

185 lines
6.7 KiB
Python

from typing import List, Sequence
import gym
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
torch, nn = try_import_torch()
F = None
if nn:
F = nn.functional
class QValueModel(nn.Module):
def __init__(
self,
obs_space: gym.spaces.Space,
fcnet_hiddens_per_candidate=(256, 32),
):
"""Initializes a QValueModel instance.
Each document candidate receives one full Q-value stack, defined by
`fcnet_hiddens_per_candidate`. The input to each of these Q-value stacks
is always {[user] concat [document[i]] for i in document_candidates}.
Extra model kwargs:
fcnet_hiddens_per_candidate: List of layer-sizes for each(!) of the
candidate documents.
"""
super().__init__()
self.orig_obs_space = obs_space
self.embedding_size = self.orig_obs_space["doc"]["0"].shape[0]
self.num_candidates = len(self.orig_obs_space["doc"])
assert self.orig_obs_space["user"].shape[0] == self.embedding_size
self.q_nets = nn.ModuleList()
for i in range(self.num_candidates):
layers = nn.Sequential()
ins = 2 * self.embedding_size
for j, h in enumerate(fcnet_hiddens_per_candidate):
layers.add_module(
f"q_layer_{i}_{j}",
SlimFC(in_size=ins, out_size=h, activation_fn="relu"),
)
ins = h
layers.add_module(f"q_out_{i}", SlimFC(ins, 1, activation_fn=None))
self.q_nets.append(layers)
def forward(self, user: TensorType, docs: List[TensorType]) -> TensorType:
"""Returns Q-values, 1 for each candidate document, given user and doc tensors.
Args:
user: [B x u] where u=embedding of user features.
docs: List[[B x d]] where d=embedding of doc features. Each item in the
list represents one document candidate.
Returns:
Tensor ([batch, num candidates) of Q-values.
1 Q-value per document candidate.
"""
q_outs = []
for i in range(self.num_candidates):
user_cat_doc = torch.cat([user, docs[i]], dim=1)
q_outs.append(self.q_nets[i](user_cat_doc))
return torch.cat(q_outs, dim=1)
class UserChoiceModel(nn.Module):
"""The user choice model for SlateQ.
This class implements a multinomial logit model for predicting user clicks.
Under this model, the click probability of a document is proportional to:
.. math::
\exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click})
"""
def __init__(self):
"""Initializes a UserChoiceModel instance."""
super().__init__()
self.beta = nn.Parameter(torch.tensor(0.0, dtype=torch.float))
self.score_no_click = nn.Parameter(torch.tensor(0.0, dtype=torch.float))
def forward(self, user: TensorType, doc: TensorType) -> TensorType:
"""Evaluate the user choice model.
This function outputs user click scores for candidate documents. The
exponentials of these scores are proportional user click probabilities.
Here we return the scores unnormalized because because only some of the
documents will be selected and shown to the user.
Args:
user: User embeddings of shape (batch_size, user embedding size).
doc: Doc embeddings of shape (batch_size, num_docs, doc embedding size).
Returns:
score: logits of shape (batch_size, num_docs + 1),
where the last dimension represents no_click.
"""
batch_size = user.shape[0]
# Reduce across the embedding axis.
s = torch.einsum("be,bde->bd", user, doc)
# s=[batch, num-docs]
# Multiply with learnable single "click" weight.
s = s * self.beta
# Add the learnable no-click score.
s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
return s
class SlateQTorchModel(TorchModelV2, nn.Module):
"""Initializes a SlateQTFModel instance.
Model includes both the user choice model and the Q-value model.
For the Q-value model, each document candidate receives one full Q-value
stack, defined by `fcnet_hiddens_per_candidate`. The input to each of these
Q-value stacks is always {[user] concat [document[i]] for i in document_candidates}.
Extra model kwargs:
fcnet_hiddens_per_candidate: List of layer-sizes for each(!) of the
candidate documents.
"""
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
*,
fcnet_hiddens_per_candidate: Sequence[int] = (256, 32),
double_q: bool = True,
):
"""Initializes a SlateQModel instance.
Args:
user_embedding_size: The size of the user embedding (number of
user specific features).
doc_embedding_size: The size of the doc embedding (number of doc
specific features).
num_docs: The number of docs to select a slate from. Note that the slate
size is inferred from the action space.
fcnet_hiddens_per_candidate: List of layer-sizes for each(!) of the
candidate documents.
double_q: Whether "double Q-learning" is applied in the loss function.
"""
nn.Module.__init__(self)
TorchModelV2.__init__(
self,
obs_space,
action_space,
# This required parameter (num_outputs) seems redundant: it has no
# real impact, and can be set arbitrarily. TODO: fix this.
num_outputs=0,
model_config=model_config,
name=name,
)
self.num_outputs = num_outputs
self.choice_model = UserChoiceModel()
self.q_model = QValueModel(self.obs_space, fcnet_hiddens_per_candidate)
def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
"""Returns Q-values, 1 for each candidate document, given user and doc tensors.
Args:
user: [B x u] where u=embedding of user features.
docs: List[[B x d]] where d=embedding of doc features. Each item in the
list represents one document candidate.
Returns:
Tensor ([batch, num candidates) of Q-values.
1 Q-value per document candidate.
"""
return self.q_model(user, docs)