mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
185 lines
6.7 KiB
Python
185 lines
6.7 KiB
Python
from typing import List, Sequence
|
|
|
|
import gym
|
|
from ray.rllib.models.torch.misc import SlimFC
|
|
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|
from ray.rllib.utils.framework import try_import_torch
|
|
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
|
|
|
# `try_import_torch` returns (None, None) when torch is not installed, so
# guard the `nn.functional` lookup accordingly.
torch, nn = try_import_torch()

F = nn.functional if nn else None
|
|
|
|
|
|
class QValueModel(nn.Module):
    """Per-candidate Q-value network for SlateQ.

    Builds one independent fully-connected Q-value stack per document
    candidate (sizes given by `fcnet_hiddens_per_candidate`). Each stack
    consumes the concatenation [user-embedding, doc[i]-embedding] and emits a
    single scalar Q-value for candidate i.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        fcnet_hiddens_per_candidate=(256, 32),
    ):
        """Initializes a QValueModel instance.

        Args:
            obs_space: Observation space containing a "doc" dict (one entry
                per candidate document) and a "user" space whose embedding
                size must match the doc embedding size.
            fcnet_hiddens_per_candidate: List of layer-sizes for each(!) of
                the candidate documents.
        """
        super().__init__()

        self.orig_obs_space = obs_space
        # Doc and user embeddings must have the same size (asserted below).
        self.embedding_size = self.orig_obs_space["doc"]["0"].shape[0]
        self.num_candidates = len(self.orig_obs_space["doc"])
        assert self.orig_obs_space["user"].shape[0] == self.embedding_size

        # One independent Q-stack per candidate document.
        self.q_nets = nn.ModuleList()
        for i in range(self.num_candidates):
            stack = nn.Sequential()
            # Input is [user] concat [doc[i]] -> twice the embedding size.
            in_size = 2 * self.embedding_size
            for j, hidden in enumerate(fcnet_hiddens_per_candidate):
                stack.add_module(
                    f"q_layer_{i}_{j}",
                    SlimFC(in_size=in_size, out_size=hidden, activation_fn="relu"),
                )
                in_size = hidden
            # Final linear layer producing the scalar Q-value.
            stack.add_module(f"q_out_{i}", SlimFC(in_size, 1, activation_fn=None))
            self.q_nets.append(stack)

    def forward(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Returns Q-values, 1 for each candidate document.

        Args:
            user: [B x u] where u=embedding of user features.
            docs: List[[B x d]] where d=embedding of doc features. Each item
                in the list represents one document candidate.

        Returns:
            Tensor ([batch, num candidates]) of Q-values, one per candidate.
        """
        # Feed [user || doc[i]] through candidate i's dedicated stack.
        q_outs = [
            self.q_nets[i](torch.cat([user, docs[i]], dim=1))
            for i in range(self.num_candidates)
        ]
        return torch.cat(q_outs, dim=1)
|
|
|
|
|
|
class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ.

    This class implements a multinomial logit model for predicting user clicks.

    Under this model, the click probability of a document is proportional to:

    .. math::
        \exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click})
    """

    def __init__(self):
        """Initializes a UserChoiceModel instance."""
        super().__init__()
        # Learnable scalar scaling the user/doc affinity scores.
        self.beta = nn.Parameter(torch.tensor(0.0, dtype=torch.float))
        # Learnable logit for the extra "no click" option.
        self.score_no_click = nn.Parameter(torch.tensor(0.0, dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model.

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional user click probabilities.
        Here we return the scores unnormalized because only some of the
        documents will be selected and shown to the user.

        Args:
            user: User embeddings of shape (batch_size, user embedding size).
            doc: Doc embeddings of shape (batch_size, num_docs, doc embedding size).

        Returns:
            score: logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        # Per-document affinity: dot-product of user and doc embeddings,
        # reduced across the embedding axis.
        s = torch.einsum("be,bde->bd", user, doc)
        # s=[batch, num-docs]

        # Multiply with learnable single "click" weight.
        s = s * self.beta
        # Append the learnable no-click score as the last logit.
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)

        return s
|
|
|
|
|
|
class SlateQTorchModel(TorchModelV2, nn.Module):
    """Torch model for SlateQ.

    Combines both the user choice model and the per-candidate Q-value model.

    For the Q-value model, each document candidate receives one full Q-value
    stack, defined by `fcnet_hiddens_per_candidate`. The input to each of these
    Q-value stacks is always {[user] concat [document[i]] for i in document_candidates}.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        *,
        fcnet_hiddens_per_candidate: Sequence[int] = (256, 32),
        double_q: bool = True,
    ):
        """Initializes a SlateQTorchModel instance.

        Args:
            obs_space: The observation space; must contain "user" and "doc"
                sub-spaces (see `QValueModel`).
            action_space: The action space.
            num_outputs: Stored on the model as `self.num_outputs`; not used
                to size any layer here.
            model_config: The model config dict.
            name: The name of this model.
            fcnet_hiddens_per_candidate: List of layer-sizes for each(!) of the
                candidate documents.
            double_q: Whether "double Q-learning" is applied in the loss
                function. NOTE(review): not used inside this model itself;
                presumably consumed by the loss — confirm against caller.
        """
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has no
            # real impact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name,
        )
        self.num_outputs = num_outputs

        # Learnable multinomial-logit user click model.
        self.choice_model = UserChoiceModel()

        # Per-candidate Q-value stacks.
        self.q_model = QValueModel(self.obs_space, fcnet_hiddens_per_candidate)

    def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Returns Q-values, 1 for each candidate document, given user and doc tensors.

        Args:
            user: [B x u] where u=embedding of user features.
            docs: List[[B x d]] where d=embedding of doc features. Each item in the
                list represents one document candidate.

        Returns:
            Tensor ([batch, num candidates]) of Q-values.
            1 Q-value per document candidate.
        """
        return self.q_model(user, docs)
|