# ray/rllib/agents/slateq/slateq_tf_model.py
#
# 77 lines
# 2.8 KiB
# Python

"""Tensorflow model for SlateQ"""
from typing import List
import gym
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
tf1, tf, tfv = try_import_tf()
class SlateQTFModel(TFModelV2):
    """TF model that holds the per-candidate Q-value networks for SlateQ.

    A separate stack of dense layers is created for every candidate document.
    All stacks are bundled into one Keras model (`self.q_value_head`), which
    is evaluated via `get_q_values()`.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        fcnet_hiddens_per_candidate=(256, 32),
    ):
        """Builds the Q-value head of a SlateQTFModel.

        Every candidate document gets its own Q-value stack whose hidden
        layers are given by `fcnet_hiddens_per_candidate`. The input to the
        i-th stack is the user embedding concatenated with the embedding of
        document i.

        Args:
            obs_space: Observation space. `obs_space["doc"]["0"].shape[0]`
                is taken as the embedding size, `len(obs_space["doc"])` as
                the number of candidates, and `obs_space["user"].shape[0]`
                must equal the embedding size.
            action_space: Forwarded to the parent constructor.
            num_outputs: Stored on the instance as `self.num_outputs`
                (the parent constructor itself is called with `None`).
            model_config: Forwarded to the parent constructor.
            name: Forwarded to the parent constructor.
            fcnet_hiddens_per_candidate: Sizes of the ReLU-activated dense
                layers in each(!) candidate's Q-value stack.
        """
        # The parent receives `None` as its num_outputs argument; the value
        # passed to this constructor is assigned explicitly further below.
        super().__init__(obs_space, action_space, None, model_config, name)

        self.embedding_size = self.obs_space["doc"]["0"].shape[0]
        self.num_candidates = len(self.obs_space["doc"])
        # User and doc embeddings are required to have the same length.
        assert self.obs_space["user"].shape[0] == self.embedding_size

        self.num_outputs = num_outputs

        # Inputs of the Q head (i.e., the model behind `get_q_values`): the
        # user embedding and all doc embeddings flattened into one vector.
        self.user_in = tf.keras.layers.Input(
            shape=(self.embedding_size,), name="user_in"
        )
        self.docs_in = tf.keras.layers.Input(
            shape=(self.embedding_size * self.num_candidates,), name="docs_in"
        )

        width = self.embedding_size
        per_candidate_q = []
        for idx in range(self.num_candidates):
            # Cut document `idx` out of the flattened docs input.
            start = width * idx
            stop = width * (idx + 1)
            doc_slice = self.docs_in[:, start:stop]

            hidden = tf.keras.layers.concatenate([self.user_in, doc_slice], axis=1)
            for units in fcnet_hiddens_per_candidate:
                hidden = tf.keras.layers.Dense(units, activation=tf.nn.relu)(hidden)

            # Single linear unit: the Q-value of this candidate.
            per_candidate_q.append(
                tf.keras.layers.Dense(1, name=f"q_value_{idx}")(hidden)
            )

        # Join the single-unit outputs along axis 1: one column per candidate.
        all_q_values = tf.concat(per_candidate_q, axis=1)
        self.q_value_head = tf.keras.Model(
            [self.user_in, self.docs_in], all_q_values
        )

    def get_q_values(self, user: TensorType, docs: List[TensorType]) -> TensorType:
        """Computes one Q-value per candidate document.

        Args:
            user: [B x u] where u=embedding of user features.
            docs: List[[B x d]] where d=embedding of doc features. Each item
                in the list represents one document candidate.

        Returns:
            Tensor ([batch, num candidates]) of Q-values.
            1 Q-value per document candidate.
        """
        # The Q head expects all doc embeddings concatenated along axis 1.
        flat_docs = tf.concat(docs, 1)
        return self.q_value_head([user, flat_docs])