ray/rllib/algorithms/slateq/slateq_tf_model.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

78 lines
2.8 KiB
Python
Raw Normal View History

"""Tensorflow model for SlateQ"""
from typing import List
import gym
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
class SlateQTFModel(TFModelV2):
    """TF model computing one Q-value per candidate document for SlateQ.

    For every candidate document a separate MLP (weights are not shared
    between candidates) is applied to the concatenation of the user embedding
    and that document's embedding, ending in a single linear Q-value unit.
    The per-candidate outputs are concatenated along axis 1 and exposed via
    `get_q_values()`.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        fcnet_hiddens_per_candidate=(256, 32),
    ):
        """Initializes a SlateQTFModel instance.

        Each document candidate receives one full Q-value stack, defined by
        `fcnet_hiddens_per_candidate`. The input to each of these Q-value stacks
        is always {[user] concat [document[i]] for i in document_candidates}.

        Args:
            obs_space: Observation space. Must be indexable as
                `obs_space["user"]` and `obs_space["doc"]`; `obs_space["doc"]`
                holds one sub-space per candidate and must contain key "0".
            action_space: Action space, forwarded to the base class.
            num_outputs: Stored as `self.num_outputs`.
            model_config: Model config dict, forwarded to the base class.
            name: Model name, forwarded to the base class.
            fcnet_hiddens_per_candidate: Sequence of hidden-layer sizes used
                for each(!) of the candidate documents' Q-value stacks.

        Raises:
            ValueError: If the user embedding size differs from the document
                embedding size.
        """
        # The base class is given `None` for num_outputs; the actual value is
        # assigned to `self.num_outputs` explicitly below.
        super(SlateQTFModel, self).__init__(
            obs_space, action_space, None, model_config, name
        )

        # Every candidate doc is assumed to have the same shape as doc "0".
        self.embedding_size = self.obs_space["doc"]["0"].shape[0]
        self.num_candidates = len(self.obs_space["doc"])

        user_embedding_size = self.obs_space["user"].shape[0]
        # Explicit check rather than `assert` (which `python -O` strips): the
        # `user_in` input below is sized by the *doc* embedding size, so a
        # mismatch must be rejected up front.
        if user_embedding_size != self.embedding_size:
            raise ValueError(
                "SlateQTFModel requires the user embedding size "
                f"({user_embedding_size}) to equal the document embedding "
                f"size ({self.embedding_size})."
            )

        # Setup the Q head output (i.e., model for get_q_values).
        self.user_in = tf.keras.layers.Input(
            shape=(self.embedding_size,), name="user_in"
        )
        # All candidate docs arrive flattened into one input; candidate i
        # occupies columns [embedding_size * i, embedding_size * (i + 1)).
        self.docs_in = tf.keras.layers.Input(
            shape=(self.embedding_size * self.num_candidates,), name="docs_in"
        )

        self.num_outputs = num_outputs

        q_outs = []
        for i in range(self.num_candidates):
            doc = self.docs_in[
                :, self.embedding_size * i : self.embedding_size * (i + 1)
            ]
            out = tf.keras.layers.concatenate([self.user_in, doc], axis=1)
            # New Dense layers are created per candidate, so no weights are
            # shared across candidates.
            for h in fcnet_hiddens_per_candidate:
                out = tf.keras.layers.Dense(h, activation=tf.nn.relu)(out)
            q_value = tf.keras.layers.Dense(1, name=f"q_value_{i}")(out)
            q_outs.append(q_value)
        q_outs = tf.concat(q_outs, axis=1)

        self.q_value_head = tf.keras.Model([self.user_in, self.docs_in], q_outs)

    def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Returns Q-values, 1 for each candidate document, given user and doc tensors.

        Args:
            user: [B x u] where u=embedding of user features.
            docs: List[[B x d]] where d=embedding of doc features. Each item in
                the list represents one document candidate; the i-th item is
                routed to the i-th candidate's Q-value stack.

        Returns:
            Tensor ([batch, num candidates]) of Q-values.
            1 Q-value per document candidate.
        """
        # Concatenate along axis 1 to rebuild the flattened `docs_in` layout
        # the Q head was built with (candidate i -> i-th slice).
        return self.q_value_head([user, tf.concat(docs, 1)])