mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
77 lines
2.8 KiB
Python
77 lines
2.8 KiB
Python
"""TensorFlow model for SlateQ."""
|
|
|
|
from typing import List
|
|
|
|
import gym
|
|
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
|
|
|
tf1, tf, tfv = try_import_tf()
|
|
|
|
|
|
class SlateQTFModel(TFModelV2):
    """SlateQ Q-value model with an independent MLP Q-stack per candidate."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        fcnet_hiddens_per_candidate=(256, 32),
    ):
        """Builds the per-candidate Q-value network.

        For every document candidate i, the tensor
        [user] concat [document[i]] is passed through its own stack of
        ReLU Dense layers (sizes from `fcnet_hiddens_per_candidate`),
        followed by a linear Dense(1) layer named `q_value_{i}`. The
        per-candidate outputs are concatenated along axis 1 and wrapped
        in a Keras model stored as `self.q_value_head`.

        Args:
            obs_space: Observation space indexable by "user" and "doc".
                "doc" must hold one sub-space per candidate and contain
                the key "0"; that sub-space's first dimension defines
                the embedding size.
            action_space: Forwarded to the base class.
            num_outputs: Stored as `self.num_outputs`. Note that the base
                class itself is initialized with `None` here.
            model_config: Forwarded to the base class.
            name: Forwarded to the base class.
            fcnet_hiddens_per_candidate: Hidden layer sizes used for
                each(!) candidate document's Q-value stack.
        """
        super().__init__(obs_space, action_space, None, model_config, name)

        doc_space = self.obs_space["doc"]
        emb_dim = doc_space["0"].shape[0]
        self.embedding_size = emb_dim
        self.num_candidates = len(doc_space)
        # The user embedding is concatenated with each doc embedding, and
        # the user input below is sized with the doc width, so they must match.
        assert self.obs_space["user"].shape[0] == self.embedding_size

        # Inputs to the Q head: the user embedding, plus all candidate
        # embeddings laid side by side (emb_dim columns per candidate).
        self.user_in = tf.keras.layers.Input(shape=(emb_dim,), name="user_in")
        self.docs_in = tf.keras.layers.Input(
            shape=(emb_dim * self.num_candidates,), name="docs_in"
        )

        self.num_outputs = num_outputs

        def _candidate_q_value(idx):
            # Columns [idx * emb_dim, (idx + 1) * emb_dim) hold candidate idx.
            start = idx * emb_dim
            doc_slice = self.docs_in[:, start : start + emb_dim]
            hidden = tf.keras.layers.concatenate(
                [self.user_in, doc_slice], axis=1
            )
            for width in fcnet_hiddens_per_candidate:
                hidden = tf.keras.layers.Dense(
                    width, activation=tf.nn.relu
                )(hidden)
            return tf.keras.layers.Dense(1, name=f"q_value_{idx}")(hidden)

        # Candidates are processed in index order, matching docs_in's layout.
        per_candidate_q = [
            _candidate_q_value(idx) for idx in range(self.num_candidates)
        ]
        all_q = tf.concat(per_candidate_q, axis=1)

        self.q_value_head = tf.keras.Model(
            inputs=[self.user_in, self.docs_in], outputs=all_q
        )

    def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Computes one Q-value per candidate document.

        Args:
            user: [B x u] batch of user embeddings (u == self.embedding_size).
            docs: List of [B x d] document-embedding batches
                (d == self.embedding_size), one entry per candidate, in
                candidate order.

        Returns:
            [B x num_candidates] tensor holding one Q-value per candidate
            document.
        """
        flat_docs = tf.concat(docs, axis=1)
        return self.q_value_head([user, flat_docs])
|