Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00

[RLlib] Implement the SlateQ algorithm (#11450)

This commit is contained in:
  parent e735add268
  commit 5af745c90d

7 changed files with 803 additions and 4 deletions
rllib/agents/registry.py

@@ -110,6 +110,11 @@ def _import_simple_q():
     return dqn.SimpleQTrainer
 
 
+def _import_slate_q():
+    from ray.rllib.agents import slateq
+    return slateq.SlateQTrainer
+
+
 def _import_td3():
     from ray.rllib.agents import ddpg
     return ddpg.TD3Trainer

@@ -127,6 +132,7 @@ ALGORITHMS = {
     "DDPG": _import_ddpg,
     "DDPPO": _import_ddppo,
     "DQN": _import_dqn,
+    "SlateQ": _import_slate_q,
     "DREAMER": _import_dreamer,
     "IMPALA": _import_impala,
     "MAML": _import_maml,
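Illustration (not part of the commit): with the registry entry in place, the trainer is addressable by the string "SlateQ", e.g. via Tune, just as the example script added below does. A minimal sketch, assuming RecSim is installed so the wrapped environment can be built:

import ray
from ray import tune
from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name

ray.init()
# "SlateQ" is resolved via ALGORITHMS -> _import_slate_q() -> slateq.SlateQTrainer.
tune.run(
    "SlateQ",
    stop={"timesteps_total": 100000},
    config={
        "env": recsim_env_name,
        "slateq_strategy": "QL",  # one of RANDOM, MYOP, SARSA, QL
    },
)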
rllib/agents/slateq/__init__.py (new file, 8 lines)

@@ -0,0 +1,8 @@
from ray.rllib.agents.slateq.slateq import SlateQTrainer, DEFAULT_CONFIG
from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy

__all__ = [
    "SlateQTrainer",
    "SlateQTorchPolicy",
    "DEFAULT_CONFIG",
]
rllib/agents/slateq/slateq.py (new file, 232 lines)

@@ -0,0 +1,232 @@
"""
SlateQ (Reinforcement Learning for Recommendation)
==================================================

This file defines the trainer class for the SlateQ algorithm from the
`"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
paper.

See `slateq_torch_policy.py` for the definition of the policy. Currently, only
PyTorch is supported. The algorithm is written and tested for Google's RecSim
environment (https://github.com/google-research/recsim).
"""

import logging
from typing import List, Type

from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

# Defines all SlateQ strategies implemented.
ALL_SLATEQ_STRATEGIES = [
    # RANDOM: Randomly select documents for slates.
    "RANDOM",
    # MYOP: Select documents that maximize user click probabilities. This is
    # a myopic strategy that ignores long-term rewards; it is equivalent to
    # setting a zero discount rate for future rewards.
    "MYOP",
    # SARSA: Use the SlateQ SARSA learning algorithm.
    "SARSA",
    # QL: Use the SlateQ Q-learning algorithm.
    "QL",
]

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Model ===
    # Dense-layer setup for each of the advantage branch and the value branch
    # in a dueling architecture.
    "hiddens": [256, 64, 16],

    # Set the batch mode. The SARSA strategy requires complete episodes
    # (see validate_config below).
    "batch_mode": "complete_episodes",

    # === Deep Learning Framework Settings ===
    # Currently, only PyTorch is supported.
    "framework": "torch",

    # === Exploration Settings ===
    "exploration_config": {
        # The Exploration class to use.
        "type": "EpsilonGreedy",
        # Config for the Exploration class' constructor:
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # Timesteps over which to anneal epsilon.
    },
    # Switch to greedy actions in evaluation workers.
    "evaluation_config": {
        "explore": False,
    },

    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 1000,
    # === Replay buffer ===
    # Size of the replay buffer. Note that if async_updates is set, then
    # each worker will have a replay buffer of this size.
    "buffer_size": 50000,
    # Whether to LZ4-compress observations.
    "compress_observations": False,
    # If set, this fixes the ratio of timesteps replayed from the buffer and
    # learned on, to timesteps sampled from the environment and stored in the
    # replay buffer. Otherwise, replay proceeds at the native ratio determined
    # by (train_batch_size / rollout_fragment_length).
    "training_intensity": None,

    # === Optimization ===
    # Learning rate for the Adam optimizer of the user choice model.
    "lr_choice_model": 1e-2,
    # Learning rate for the Adam optimizer of the Q model.
    "lr_q_model": 1e-2,
    # Adam epsilon hyperparameter.
    "adam_epsilon": 1e-8,
    # If not None, clip gradients during optimization at this value.
    "grad_clip": 40,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 1000,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span.
    "min_iter_time_s": 1,

    # === SlateQ specific options ===
    # Learning method used by the SlateQ policy. Choose from: RANDOM,
    # MYOP (myopic), SARSA, QL (Q-Learning).
    "slateq_strategy": "QL",
    # User/doc embedding size for the RecSim environment.
    "recsim_embedding_size": 20,
})
# __sphinx_doc_end__
# yapf: enable

def validate_config(config: TrainerConfigDict) -> None:
    """Checks the config based on settings"""
    if config["framework"] != "torch":
        raise ValueError("SlateQ only runs on PyTorch")

    if config["slateq_strategy"] not in ALL_SLATEQ_STRATEGIES:
        raise ValueError("Unknown slateq_strategy: "
                         f"{config['slateq_strategy']}.")

    if config["slateq_strategy"] == "SARSA":
        if config["batch_mode"] != "complete_episodes":
            raise ValueError(
                "For SARSA strategy, batch_mode must be 'complete_episodes'")


def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the SlateQ algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config["replay_sequence_length"],
    )

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer. Calling
    #     next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Read and train on experiences from the replay buffer. Every batch
    #     returned from the LocalReplay() iterator is passed to TrainOneStep to
    #     take an SGD step.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers))

    if config["slateq_strategy"] != "RANDOM":
        # Alternate deterministically between (1) and (2). Only return the
        # output of (2), since training metrics are not available until (2)
        # runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_round_robin_weights(config))
    else:
        # No training is needed for the RANDOM strategy.
        train_op = rollouts

    return StandardMetricsReporting(train_op, workers, config)


def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
    """Calculate the round robin weights for the rollout and train steps"""
    if not config["training_intensity"]:
        return [1, 1]
    # e.g., 32 / 4 -> native ratio of 8.0
    native_ratio = (
        config["train_batch_size"] / config["rollout_fragment_length"])
    # Training intensity is specified in terms of
    # (steps_replayed / steps_sampled), so adjust for the native ratio.
    weights = [1, config["training_intensity"] / native_ratio]
    return weights

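Aside (not part of the file): with the defaults above (train_batch_size=32, rollout_fragment_length=1000) the native ratio is 32 / 1000 = 0.032, so a hypothetical training_intensity of 0.1 gives weights [1, 3.125], i.e. about three train batches per sampled rollout fragment. The same arithmetic as a standalone sketch:

def round_robin_weights(train_batch_size, rollout_fragment_length, training_intensity):
    # Mirrors calculate_round_robin_weights() above.
    if not training_intensity:
        return [1, 1]
    native_ratio = train_batch_size / rollout_fragment_length
    return [1, training_intensity / native_ratio]

print(round_robin_weights(32, 1000, None))  # [1, 1]
print(round_robin_weights(32, 1000, 0.1))   # [1, 3.125]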

def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
    """Policy class picker function.

    Args:
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        Type[Policy]: The Policy class to use with SlateQTrainer.
    """
    if config["slateq_strategy"] == "RANDOM":
        return RandomPolicy
    else:
        return SlateQTorchPolicy


SlateQTrainer = build_trainer(
    name="SlateQ",
    get_policy_class=get_policy_class,
    default_config=DEFAULT_CONFIG,
    validate_config=validate_config,
    execution_plan=execution_plan)
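Illustration (not part of the commit): for a quick smoke test without Tune, mirroring the trainer-interface branch of the example script below, the config can be copied and overridden field by field:

import ray
from ray.rllib.agents import slateq
from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name

ray.init()
config = slateq.DEFAULT_CONFIG.copy()
config["slateq_strategy"] = "SARSA"  # needs batch_mode="complete_episodes" (the default above)
config["num_workers"] = 0
trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
print(trainer.train())  # one training iteration; returns a metrics dict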
rllib/agents/slateq/slateq_torch_policy.py (new file, 420 lines)

@@ -0,0 +1,420 @@
"""PyTorch policy class used for SlateQ"""

from typing import Dict, List, Sequence, Tuple

import gym
import numpy as np

import ray
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
                                                      TorchDistributionWrapper)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
                                    TrainerConfigDict)

torch, nn = try_import_torch()
F = None
if nn:
    F = nn.functional


class QValueModel(nn.Module):
    """The Q-value model for SlateQ"""

    def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
        super().__init__()

        # construct hidden layers
        layers = []
        ins = 2 * embedding_size
        for n in q_hiddens:
            layers.append(nn.Linear(ins, n))
            layers.append(nn.LeakyReLU())
            ins = n
        layers.append(nn.Linear(ins, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        user_repeated = user.repeat(num_docs, 1)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)
        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)


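Aside (not part of the file): the forward pass scores every user-doc pair with one shared MLP and appends a forced zero Q-value for the no-click outcome. A quick shape check with arbitrary toy sizes:

import torch
from ray.rllib.agents.slateq.slateq_torch_policy import QValueModel

batch_size, num_docs, embedding_size = 4, 10, 20
model = QValueModel(embedding_size, q_hiddens=[256, 64, 16])
user = torch.randn(batch_size, embedding_size)
doc = torch.randn(batch_size, num_docs, embedding_size)
q = model(user, doc)
print(q.shape)  # torch.Size([4, 11]) == (batch_size, num_docs + 1)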
class UserChoiceModel(nn.Module):
    r"""The user choice model for SlateQ

    This class implements a multinomial logit model for predicting user
    clicks.

    Under this model, the click probability of a document is proportional to:

    .. math::
        \exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click})
    """

    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
        self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))

    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user choice model.

        This function outputs user click scores for candidate documents. The
        exponentials of these scores are proportional to user click
        probabilities. The scores are returned unnormalized because only some
        of the documents will be selected and shown to the user.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): logits of shape (batch_size, num_docs + 1),
                where the last dimension represents no_click.
        """
        batch_size = user.shape[0]
        s = torch.einsum("be,bde->bd", user, doc)
        s = s * self.beta
        s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
        return s

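Aside (not part of the file): the exponentials of these scores, normalized over the num_docs + 1 options, are the click probabilities used by choose_slate() and by the choice-model loss below. A toy sketch; shapes follow the docstring, numbers are made up:

import torch
import torch.nn.functional as F

batch_size, num_docs, embedding_size = 1, 3, 4
user = torch.randn(batch_size, embedding_size)
doc = torch.randn(batch_size, num_docs, embedding_size)
beta, score_no_click = 1.0, 0.0  # stand-ins for the two learned scalars in UserChoiceModel

# Same scoring as UserChoiceModel.forward(): beta * <user, doc_i>, plus a no-click score.
scores = beta * torch.einsum("be,bde->bd", user, doc)
scores = torch.cat([scores, torch.full((batch_size, 1), score_no_click)], dim=1)

# Multinomial logit: click probabilities are the softmax over the (num_docs + 1) scores.
probs = F.softmax(scores, dim=1)  # shape [batch_size, num_docs + 1]; last entry = P(no click)
print(probs.sum(dim=1))           # sums to 1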

class SlateQModel(TorchModelV2, nn.Module):
    """The SlateQ model class

    It includes both the user choice model and the Q-value model.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            model_config: ModelConfigDict,
            name: str,
            *,
            embedding_size: int,
            q_hiddens: Sequence[int],
    ):
        nn.Module.__init__(self)
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            # This required parameter (num_outputs) seems redundant: it has no
            # real impact, and can be set arbitrarily. TODO: fix this.
            num_outputs=0,
            model_config=model_config,
            name=name)
        self.choice_model = UserChoiceModel()
        self.q_model = QValueModel(embedding_size, q_hiddens)
        self.slate_size = len(action_space.nvec)

    def choose_slate(self, user: TensorType,
                     doc: TensorType) -> Tuple[TensorType, TensorType]:
        """Build a slate by selecting from candidate documents.

        Args:
            user (TensorType): User embeddings of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size,
                num_docs, embedding_size).

        Returns:
            slate_selected (TensorType): Indices of documents selected for
                the slate, with shape (batch_size, slate_size).
            best_slate_q_value (TensorType): The Q-value of the selected slate,
                with shape (batch_size).
        """
        # Step 1: compute item scores (proportional to click probabilities).
        # raw_scores.shape=[batch_size, num_docs+1]
        raw_scores = self.choice_model(user, doc)
        # max_raw_scores.shape=[batch_size, 1]
        max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
        # Subtract max_raw_scores before exponentiating to avoid value
        # explosion.
        scores = torch.exp(raw_scores - max_raw_scores)
        scores_doc = scores[:, :-1]  # shape=[batch_size, num_docs]
        scores_no_click = scores[:, [-1]]  # shape=[batch_size, 1]

        # Step 2: calculate the item-wise Q values.
        # q_values.shape=[batch_size, num_docs+1]
        q_values = self.q_model(user, doc)
        q_values_doc = q_values[:, :-1]  # shape=[batch_size, num_docs]
        q_values_no_click = q_values[:, [-1]]  # shape=[batch_size, 1]

        # Step 3: construct all possible slates.
        _, num_docs, _ = doc.shape
        indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
        # slates.shape = [num_slates, slate_size]
        slates = torch.combinations(indices, r=self.slate_size)
        num_slates, _ = slates.shape

        # Step 4: calculate slate Q values.
        batch_size, _ = q_values_doc.shape
        # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
        slate_decomp_q_values = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_scores.shape: [batch_size, num_slates, slate_size]
        slate_scores = torch.gather(
            # input.shape: [batch_size, num_slates, num_docs]
            input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
            dim=2,
            # index.shape: [batch_size, num_slates, slate_size]
            index=slates.unsqueeze(0).expand(batch_size, -1, -1))
        # slate_q_values.shape: [batch_size, num_slates]
        slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
                          (q_values_no_click * scores_no_click)) / (
                              slate_scores.sum(dim=2) + scores_no_click)

        # Step 5: find the slate that maximizes the Q value.
        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
        # slates_selected.shape: [batch_size, slate_size]
        slates_selected = slates[max_idx]
        return slates_selected, best_slate_q_value

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        # user.shape: [batch_size, embedding_size]
        user = input_dict[SampleBatch.OBS]["user"]
        # doc.shape: [batch_size, num_docs, embedding_size]
        doc = torch.cat([
            val.unsqueeze(1)
            for val in input_dict[SampleBatch.OBS]["doc"].values()
        ], 1)

        slates_selected, _ = self.choose_slate(user, doc)

        state_out = []
        return slates_selected, state_out

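Aside (not part of the file): Step 4 above is the SlateQ decomposition. A slate's value is the choice-model-weighted average of its per-item Q-values together with the no-click option, Q(s, A) = (sum_i q_i * v_i + q_noclick * v_noclick) / (sum_i v_i + v_noclick), with v = exp(choice score). A toy check with made-up numbers for single-item slates:

import torch

q_items = torch.tensor([[2.0, 1.0]])  # Q-values for docs A and B, shape [batch=1, num_docs]
v_items = torch.tensor([[1.0, 3.0]])  # exp(choice scores) for A and B
q_no_click, v_no_click = 0.0, 1.0     # no-click Q is forced to 0 in QValueModel

def slate_value(idx):
    # Weighted average over the single-item slate plus the no-click outcome.
    num = q_items[0, idx] * v_items[0, idx] + q_no_click * v_no_click
    den = v_items[0, idx] + v_no_click
    return (num / den).item()

print(slate_value(0))  # doc A: 2*1 / (1+1) = 1.0
print(slate_value(1))  # doc B: 1*3 / (3+1) = 0.75 -> A wins despite B's higher click score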

def build_slateq_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build models for SlateQ.

    Args:
        policy (Policy): The policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        (q_model, TorchCategorical)
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        embedding_size=config["recsim_embedding_size"],
        q_hiddens=config["hiddens"],
    )
    return model, TorchCategorical


def build_slateq_losses(policy: Policy, model: SlateQModel, _,
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the losses for SlateQPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build the user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    # Step 2: Build the Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(
                q_values * scores, dim=1) / torch.sum(
                    scores, dim=1)
            next_q_values[dones] = 0.0
    elif learning_strategy == "MYOP":
        next_q_values = 0.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch["dones"]
        with torch.no_grad():
            _, next_q_values = model.choose_slate(next_user, next_doc)
            next_q_values[dones] = 0.0
    else:
        raise ValueError(learning_strategy)
    # target_q_values.shape: [batch_size]
    target_q_values = next_q_values + train_batch["rewards"]

    q_values = model.q_model(user,
                             selected_doc)  # shape: [batch_size, slate_size+1]
    # raw_scores.shape: [batch_size, slate_size+1]
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    q_values = torch.sum(
        q_values * scores, dim=1) / torch.sum(
            scores, dim=1)  # shape=[batch_size]

    q_value_loss = nn.MSELoss()(q_values, target_q_values)
    return [choice_loss, q_value_loss]


def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
                            ) -> List["torch.optim.Optimizer"]:
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"])
    return [optimizer_choice, optimizer_q_value]


def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
                      explore, timestep):
    """Determine which action to take"""
    # First, we transform the observation into its unflattened form.
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS],
        policy.observation_space,
        tensorlib=torch)

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    selected_slates, _ = model.choose_slate(user, doc)

    action = selected_slates
    logp = None
    state_out = []
    return action, logp, state_out


def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
                                              batch: SampleBatch,
                                              other_agent=None,
                                              episode=None) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training"""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1]:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}.")
        batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
    return batch

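Aside (not part of the file): because batch_mode is "complete_episodes", every batch ends at a terminal step, so np.roll simply shifts the action column up by one and row t ends up holding the action taken at t+1; the wrapped-around last row is harmless because build_slateq_losses() zeroes next_q_values where dones is True. A small illustration:

import numpy as np

actions = np.array([[3, 7], [1, 4], [5, 2]])  # slate actions for timesteps 0..2 of one episode
next_actions = np.roll(actions, -1, axis=0)
print(next_actions)
# [[1 4]
#  [5 2]
#  [3 7]]  <- wrapped row at the terminal step; masked via next_q_values[dones] = 0.0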

SlateQTorchPolicy = build_torch_policy(
    name="SlateQTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

    # build model, loss functions, and optimizers
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    loss_fn=build_slateq_losses,

    # define how to act
    action_sampler_fn=action_sampler_fn,

    # post processing batch sampled data
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
rllib/env/wrappers/recsim_wrapper.py (vendored, 10 changes)

@@ -55,11 +55,14 @@ class RecSimObservationSpaceWrapper(gym.ObservationWrapper):
 
 
 class RecSimResetWrapper(gym.Wrapper):
-    """Fix RecSim environment's reset() function
+    """Fix RecSim environment's reset() and close() function
 
     RecSim's reset() function returns an observation without the "response"
     field, breaking RLlib's check. This wrapper fixes that by assigning a
     random "response".
+
+    RecSim's close() function raises NotImplementedError. We change the
+    behavior to doing nothing.
     """
 
     def reset(self):
@@ -67,6 +70,9 @@ class RecSimResetWrapper(gym.Wrapper):
         obs["response"] = self.env.observation_space["response"].sample()
         return obs
 
+    def close(self):
+        pass
+
 
 class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper):
     """Convert the action space from MultiDiscrete to Discrete
@@ -108,7 +114,7 @@ def make_recsim_env(config):
     env = interest_evolution.create_environment(env_config)
     env = RecSimResetWrapper(env)
     env = RecSimObservationSpaceWrapper(env)
-    if config and config["convert_to_discrete_action_space"]:
+    if env_config and env_config["convert_to_discrete_action_space"]:
         env = MultiDiscreteToDiscreteActionWrapper(env)
     return env
 
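Illustration (not part of the commit): with the added close() override, tearing the wrapped environment down no longer raises. A small sketch using the module's factory; the env_config keys match the example script below, and the exact observation keys are an assumption based on the wrappers:

from ray.rllib.env.wrappers.recsim_wrapper import make_recsim_env

env = make_recsim_env({
    "slate_size": 2,
    "seed": 0,
    "convert_to_discrete_action_space": False,
})
obs = env.reset()   # RecSimResetWrapper adds a sampled "response" field
print(sorted(obs))  # expect "doc", "response", "user" among the keys
env.close()         # previously raised NotImplementedError; now a no-op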
rllib/examples/policy/random_policy.py

@@ -1,9 +1,11 @@
-from gym.spaces import Box
-import numpy as np
 import random
+
+import numpy as np
+from gym.spaces import Box
 
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.typing import ModelWeights
 
 
 class RandomPolicy(Policy):
@@ -50,3 +52,13 @@ class RandomPolicy(Policy):
                               prev_action_batch=None,
                               prev_reward_batch=None):
         return np.array([random.random()] * len(obs_batch))
+
+    @override(Policy)
+    def get_weights(self) -> ModelWeights:
+        """No weights to save."""
+        return {}
+
+    @override(Policy)
+    def set_weights(self, weights: ModelWeights) -> None:
+        """No weights to set."""
+        pass
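Illustration (not part of the commit): the new weight accessors are presumably what lets RandomPolicy serve as a regular trainer policy for the RANDOM strategy, where workers sync (empty) weights. A minimal sketch with hypothetical spaces:

import gym
from ray.rllib.examples.policy.random_policy import RandomPolicy

policy = RandomPolicy(
    gym.spaces.Box(-1.0, 1.0, shape=(4,)),  # observation space (hypothetical)
    gym.spaces.Discrete(2),                 # action space (hypothetical)
    {},                                     # config
)
assert policy.get_weights() == {}  # nothing to sync
policy.set_weights({})             # accepted and ignored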
rllib/examples/slateq.py (new file, 115 lines)

@@ -0,0 +1,115 @@
"""The SlateQ algorithm for recommendation"""

import argparse
from datetime import datetime

import ray
from ray import tune
from ray.rllib.agents import slateq
from ray.rllib.agents import dqn
from ray.rllib.agents.slateq.slateq import ALL_SLATEQ_STRATEGIES
from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name
from ray.tune.logger import pretty_print


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--agent",
        type=str,
        default="SlateQ",
        help=("Select agent policy. Choose from: DQN and SlateQ. "
              "Default value: SlateQ."),
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="QL",
        help=("Strategy for the SlateQ agent. Choose from: " +
              ", ".join(ALL_SLATEQ_STRATEGIES) + ". "
              "Default value: QL. Ignored when using Tune."),
    )
    parser.add_argument(
        "--use-tune",
        action="store_true",
        help=("Run with Tune so that the results are logged into Tensorboard. "
              "For debugging, it's easier to run without Ray Tune."),
    )
    parser.add_argument("--tune-num-samples", type=int, default=10)
    parser.add_argument("--env-slate-size", type=int, default=2)
    parser.add_argument("--env-seed", type=int, default=0)
    parser.add_argument(
        "--num-gpus",
        type=float,
        default=0.,
        help="Only used if running with Tune.")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Only used if running with Tune.")
    args = parser.parse_args()

    if args.agent not in ["DQN", "SlateQ"]:
        raise ValueError(args.agent)

    env_config = {
        "slate_size": args.env_slate_size,
        "seed": args.env_seed,
        "convert_to_discrete_action_space": args.agent == "DQN",
    }

    ray.init()
    if args.use_tune:
        time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}"
        if args.agent == "DQN":
            tune.run(
                "DQN",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1)
        else:
            tune.run(
                "SlateQ",
                stop={"timesteps_total": 4000000},
                name=name,
                config={
                    "env": recsim_env_name,
                    "num_gpus": args.num_gpus,
                    "num_workers": args.num_workers,
                    "slateq_strategy": tune.grid_search(ALL_SLATEQ_STRATEGIES),
                    "env_config": env_config,
                },
                num_samples=args.tune_num_samples,
                verbose=1)
    else:
        # directly run using the trainer interface (good for debugging)
        if args.agent == "DQN":
            config = dqn.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["env_config"] = env_config
            trainer = dqn.DQNTrainer(config=config, env=recsim_env_name)
        else:
            config = slateq.DEFAULT_CONFIG.copy()
            config["num_gpus"] = 0
            config["num_workers"] = 0
            config["slateq_strategy"] = args.strategy
            config["env_config"] = env_config
            trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
        for i in range(10):
            result = trainer.train()
            print(pretty_print(result))
    ray.shutdown()


if __name__ == "__main__":
    main()