[RLlib] Implement the SlateQ algorithm (#11450)

desktable 2020-11-03 00:52:04 -08:00 committed by GitHub
parent e735add268
commit 5af745c90d
7 changed files with 803 additions and 4 deletions

View file

@@ -110,6 +110,11 @@ def _import_simple_q():
return dqn.SimpleQTrainer
def _import_slate_q():
from ray.rllib.agents import slateq
return slateq.SlateQTrainer
def _import_td3():
from ray.rllib.agents import ddpg
return ddpg.TD3Trainer
@@ -127,6 +132,7 @@ ALGORITHMS = {
"DDPG": _import_ddpg,
"DDPPO": _import_ddppo,
"DQN": _import_dqn,
"SlateQ": _import_slate_q,
"DREAMER": _import_dreamer,
"IMPALA": _import_impala,
"MAML": _import_maml,

View file

@@ -0,0 +1,8 @@
from ray.rllib.agents.slateq.slateq import SlateQTrainer, DEFAULT_CONFIG
from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
__all__ = [
"SlateQTrainer",
"SlateQTorchPolicy",
"DEFAULT_CONFIG",
]

View file

@@ -0,0 +1,232 @@
"""
SlateQ (Reinforcement Learning for Recommendation)
==================================================
This file defines the trainer class for the SlateQ algorithm from the
`"Reinforcement Learning for Slate-based Recommender Systems: A Tractable
Decomposition and Practical Methodology" <https://arxiv.org/abs/1905.12767>`_
paper.
See `slateq_torch_policy.py` for the definition of the policy. Currently, only
PyTorch is supported. The algorithm is written and tested for Google's RecSim
environment (https://github.com/google-research/recsim).
"""
import logging
from typing import List, Type
from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
# Defines all SlateQ strategies implemented.
ALL_SLATEQ_STRATEGIES = [
# RANDOM: Randomly select documents for slates.
"RANDOM",
# MYOP: Select documents that maximize user click probabilities. This is
# a myopic strategy and ignores long term rewards. This is equivalent to
# setting a zero discount rate for future rewards.
"MYOP",
# SARSA: Use the SlateQ SARSA learning algorithm.
"SARSA",
# QL: Use the SlateQ Q-learning algorithm.
"QL",
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Model ===
# Hidden layer sizes for the Q-value model.
"hiddens": [256, 64, 16],
# Use complete episodes as batch_mode (required by the SARSA strategy).
"batch_mode": "complete_episodes",
# === Deep Learning Framework Settings ===
# Currently, only PyTorch is supported
"framework": "torch",
# === Exploration Settings ===
"exploration_config": {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.02,
"epsilon_timesteps": 10000, # Timesteps over which to anneal epsilon.
},
# Switch to greedy actions in evaluation workers.
"evaluation_config": {
"explore": False,
},
# Minimum env steps to optimize for per train call. This value does
# not affect learning, only the length of iterations.
"timesteps_per_iteration": 1000,
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": 50000,
# Whether to LZ4 compress observations
"compress_observations": False,
# If set, this fixes the ratio of timesteps replayed from the buffer (and
# trained on) to timesteps sampled from the environment (and stored in the
# buffer). Otherwise, replay proceeds at the native ratio determined by
# (train_batch_size / rollout_fragment_length).
"training_intensity": None,
# === Optimization ===
# Learning rate for the Adam optimizer of the user choice model.
"lr_choice_model": 1e-2,
# Learning rate for the Adam optimizer of the Q model.
"lr_q_model": 1e-2,
# Adam epsilon hyperparameter.
"adam_epsilon": 1e-8,
# If not None, clip gradients during optimization at this value
"grad_clip": 40,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 1000,
# Size of a batch sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 32,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# === SlateQ specific options ===
# Learning method used by the SlateQ policy. Choose from: RANDOM,
# MYOP (myopic), SARSA, QL (Q-learning).
"slateq_strategy": "QL",
# user/doc embedding size for the recsim environment
"recsim_embedding_size": 20,
})
# __sphinx_doc_end__
# yapf: enable
def validate_config(config: TrainerConfigDict) -> None:
"""Checks the config based on settings"""
if config["framework"] != "torch":
raise ValueError("SlateQ only runs on PyTorch")
if config["slateq_strategy"] not in ALL_SLATEQ_STRATEGIES:
raise ValueError("Unknown slateq_strategy: "
f"{config['slateq_strategy']}.")
if config["slateq_strategy"] == "SARSA":
if config["batch_mode"] != "complete_episodes":
raise ValueError(
"For SARSA strategy, batch_mode must be 'complete_episodes'")
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
"""Execution plan of the SlateQ algorithm. Defines the distributed dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
of the Trainer.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over training metrics.
"""
local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config["replay_sequence_length"],
)
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our local replay buffer. Calling
# next() on store_op drives this.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer))
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take an SGD step.
replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(TrainOneStep(workers))
if config["slateq_strategy"] != "RANDOM":
# Alternate deterministically between (1) and (2). Only return the
# output of (2) since training metrics are not available until (2)
# runs.
train_op = Concurrently(
[store_op, replay_op],
mode="round_robin",
output_indexes=[1],
round_robin_weights=calculate_round_robin_weights(config))
else:
# No training is needed for the RANDOM strategy.
train_op = rollouts
return StandardMetricsReporting(train_op, workers, config)
def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
"""Calculate the round robin weights for the rollout and train steps"""
if not config["training_intensity"]:
return [1, 1]
# e.g., 32 / 4 -> native ratio of 8.0
native_ratio = (
config["train_batch_size"] / config["rollout_fragment_length"])
# Training intensity is specified in terms of
# (steps_replayed / steps_sampled), so adjust for the native ratio.
weights = [1, config["training_intensity"] / native_ratio]
return weights
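# Worked example: with the defaults above (train_batch_size=32,
# rollout_fragment_length=1000), the native ratio is 32 / 1000 = 0.032
# replayed-and-trained steps per sampled step. A hypothetical setting of
# training_intensity=0.064 would therefore yield round-robin weights
# [1, 2.0], i.e. two replay/train steps for every rollout step.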
def get_policy_class(config: TrainerConfigDict) -> Type[Policy]:
"""Policy class picker function.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Type[Policy]: The Policy class to use with SlateQTrainer.
"""
if config["slateq_strategy"] == "RANDOM":
return RandomPolicy
else:
return SlateQTorchPolicy
SlateQTrainer = build_trainer(
name="SlateQ",
get_policy_class=get_policy_class,
default_config=DEFAULT_CONFIG,
validate_config=validate_config,
execution_plan=execution_plan)
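
For orientation, here is a minimal usage sketch of the trainer defined above. It mirrors the debugging path of `rllib/examples/slateq.py` added later in this commit and assumes RecSim is installed so that the registered RecSim env can be built:

    import ray
    from ray.rllib.agents import slateq
    from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name
    from ray.tune.logger import pretty_print

    ray.init()
    config = slateq.DEFAULT_CONFIG.copy()
    config["slateq_strategy"] = "QL"  # RANDOM, MYOP, SARSA, or QL
    config["env_config"] = {
        "slate_size": 2,
        "seed": 0,
        "convert_to_discrete_action_space": False,
    }
    trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
    for _ in range(3):
        print(pretty_print(trainer.train()))  # run a few training iterations
    ray.shutdown()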

View file

@@ -0,0 +1,420 @@
"""PyTorch policy class used for SlateQ"""
from typing import Dict, List, Sequence, Tuple
import gym
import numpy as np
import ray
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
TorchDistributionWrapper)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
TrainerConfigDict)
torch, nn = try_import_torch()
F = None
if nn:
F = nn.functional
class QValueModel(nn.Module):
"""The Q-value model for SlateQ"""
def __init__(self, embedding_size: int, q_hiddens: Sequence[int]):
super().__init__()
# construct hidden layers
layers = []
ins = 2 * embedding_size
for n in q_hiddens:
layers.append(nn.Linear(ins, n))
layers.append(nn.LeakyReLU())
ins = n
layers.append(nn.Linear(ins, 1))
self.layers = nn.Sequential(*layers)
def forward(self, user: TensorType, doc: TensorType) -> TensorType:
"""Evaluate the user-doc Q model
Args:
user (TensorType): User embedding of shape (batch_size,
embedding_size).
doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
embedding_size).
Returns:
score (TensorType): q_values of shape (batch_size, num_docs + 1).
"""
batch_size, num_docs, embedding_size = doc.shape
doc_flat = doc.view((batch_size * num_docs, embedding_size))
# repeat_interleave aligns each user row with its own flattened docs.
user_repeated = user.repeat_interleave(num_docs, dim=0)
x = torch.cat([user_repeated, doc_flat], dim=1)
x = self.layers(x)
# Similar to Google's SlateQ implementation in RecSim, we force the
# Q-value of the no-click option to zero.
x_no_click = torch.zeros((batch_size, 1), device=x.device)
return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
class UserChoiceModel(nn.Module):
r"""The user choice model for SlateQ
This class implements a multinomial logit model for predicting user clicks.
Under this model, the click probability of a document is proportional to:
.. math::
\exp(\text{beta} \cdot \text{doc_user_affinity})
while the probability of not clicking anything is proportional to
:math:`\exp(\text{score_no_click})`.
"""
def __init__(self):
super().__init__()
self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float))
self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float))
def forward(self, user: TensorType, doc: TensorType) -> TensorType:
"""Evaluate the user choice model
This function outputs user click scores for candidate documents. The
exponentials of these scores are proportional to the user click
probabilities. The scores are returned unnormalized because only some of
the documents will be selected and shown to the user.
Args:
user (TensorType): User embeddings of shape (batch_size,
embedding_size).
doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
embedding_size).
Returns:
score (TensorType): logits of shape (batch_size, num_docs + 1),
where the last dimension represents no_click.
"""
batch_size = user.shape[0]
s = torch.einsum("be,bde->bd", user, doc)
s = s * self.beta
s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1)
return s
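# For intuition, a minimal sketch (not used by the policy itself): turning
# the raw scores above into actual choice probabilities is a plain softmax
# over the num_docs + 1 columns, the last column being the no-click
# probability. With `choice_model` an instance of UserChoiceModel:
#   probs = torch.softmax(choice_model(user, doc), dim=1)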
class SlateQModel(TorchModelV2, nn.Module):
"""The SlateQ model class
It includes both the user choice model and the Q-value model.
"""
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
model_config: ModelConfigDict,
name: str,
*,
embedding_size: int,
q_hiddens: Sequence[int],
):
nn.Module.__init__(self)
TorchModelV2.__init__(
self,
obs_space,
action_space,
# This required parameter (num_outputs) seems redundant: it has no
# real impact, and can be set arbitrarily. TODO: fix this.
num_outputs=0,
model_config=model_config,
name=name)
self.choice_model = UserChoiceModel()
self.q_model = QValueModel(embedding_size, q_hiddens)
self.slate_size = len(action_space.nvec)
def choose_slate(self, user: TensorType,
doc: TensorType) -> Tuple[TensorType, TensorType]:
"""Build a slate by selecting from candidate documents
Args:
user (TensorType): User embeddings of shape (batch_size,
embedding_size).
doc (TensorType): Doc embeddings of shape (batch_size,
num_docs, embedding_size).
Returns:
slate_selected (TensorType): Indices of documents selected for
the slate, with shape (batch_size, slate_size).
best_slate_q_value (TensorType): The Q-value of the selected slate,
with shape (batch_size).
"""
# Step 1: compute item scores (proportional to click probabilities)
# raw_scores.shape=[batch_size, num_docs+1]
raw_scores = self.choice_model(user, doc)
# max_raw_scores.shape=[batch_size, 1]
max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
# Subtract max_raw_scores before exponentiating, for numerical stability.
scores = torch.exp(raw_scores - max_raw_scores)
scores_doc = scores[:, :-1] # shape=[batch_size, num_docs]
scores_no_click = scores[:, [-1]] # shape=[batch_size, 1]
# Step 2: calculate the item-wise Q values
# q_values.shape=[batch_size, num_docs+1]
q_values = self.q_model(user, doc)
q_values_doc = q_values[:, :-1] # shape=[batch_size, num_docs]
q_values_no_click = q_values[:, [-1]] # shape=[batch_size, 1]
# Step 3: construct all possible slates
_, num_docs, _ = doc.shape
indices = torch.arange(num_docs, dtype=torch.long, device=doc.device)
# slates.shape = [num_slates, slate_size]
slates = torch.combinations(indices, r=self.slate_size)
num_slates, _ = slates.shape
# Step 4: calculate slate Q values
batch_size, _ = q_values_doc.shape
# slate_decomp_q_values.shape: [batch_size, num_slates, slate_size]
slate_decomp_q_values = torch.gather(
# input.shape: [batch_size, num_slates, num_docs]
input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1),
dim=2,
# index.shape: [batch_size, num_slates, slate_size]
index=slates.unsqueeze(0).expand(batch_size, -1, -1))
# slate_scores.shape: [batch_size, num_slates, slate_size]
slate_scores = torch.gather(
# input.shape: [batch_size, num_slates, num_docs]
input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1),
dim=2,
# index.shape: [batch_size, num_slates, slate_size]
index=slates.unsqueeze(0).expand(batch_size, -1, -1))
# slate_q_values.shape: [batch_size, num_slates]
slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
(q_values_no_click * scores_no_click)) / (
slate_scores.sum(dim=2) + scores_no_click)
# Step 5: find the slate that maximizes q value
best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
# slates_selected.shape: [batch_size, slate_size]
slates_selected = slates[max_idx]
return slates_selected, best_slate_q_value
def forward(self, input_dict: Dict[str, TensorType],
state: List[TensorType],
seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
# user.shape: [batch_size, embedding_size]
user = input_dict[SampleBatch.OBS]["user"]
# doc.shape: [batch_size, num_docs, embedding_size]
doc = torch.cat([
val.unsqueeze(1)
for val in input_dict[SampleBatch.OBS]["doc"].values()
], 1)
slates_selected, _ = self.choose_slate(user, doc)
state_out = []
return slates_selected, state_out
def build_slateq_model_and_distribution(
policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
"""Build models for SlateQ
Args:
policy (Policy): The policy, which will use the model for optimization.
obs_space (gym.spaces.Space): The policy's observation space.
action_space (gym.spaces.Space): The policy's action space.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
(q_model, TorchCategorical)
"""
model = SlateQModel(
obs_space,
action_space,
model_config=config["model"],
name="slateq_model",
embedding_size=config["recsim_embedding_size"],
q_hiddens=config["hiddens"],
)
return model, TorchCategorical
def build_slateq_losses(policy: Policy, model: SlateQModel, _,
train_batch: SampleBatch) -> List[TensorType]:
"""Constructs the losses for SlateQPolicy.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
train_batch (SampleBatch): The training data.
Returns:
List[TensorType]: The user choice model loss and the Q-value loss.
"""
obs = restore_original_dimensions(
train_batch[SampleBatch.OBS],
policy.observation_space,
tensorlib=torch)
# user.shape: [batch_size, embedding_size]
user = obs["user"]
# doc.shape: [batch_size, num_docs, embedding_size]
doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
# action.shape: [batch_size, slate_size]
actions = train_batch[SampleBatch.ACTIONS]
next_obs = restore_original_dimensions(
train_batch[SampleBatch.NEXT_OBS],
policy.observation_space,
tensorlib=torch)
# Step 1: Build user choice model loss
_, _, embedding_size = doc.shape
# selected_doc.shape: [batch_size, slate_size, embedding_size]
selected_doc = torch.gather(
# input.shape: [batch_size, num_docs, embedding_size]
input=doc,
dim=1,
# index.shape: [batch_size, slate_size, embedding_size]
index=actions.unsqueeze(2).expand(-1, -1, embedding_size))
scores = model.choice_model(user, selected_doc)
choice_loss_fn = nn.CrossEntropyLoss()
# clicks.shape: [batch_size, slate_size]
clicks = torch.stack(
[resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
# targets.shape: [batch_size, slate_size+1]
targets = torch.cat([clicks, no_clicks], dim=1)
choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
# print(model.choice_model.beta.item(), model.choice_model.score_no_click.item())
# Step 2: Build qvalue loss
# Fields available in train_batch: ['t', 'eps_id', 'agent_index',
# 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
# 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
# 'batch_indexes']
learning_strategy = policy.config["slateq_strategy"]
if learning_strategy == "SARSA":
# next_doc.shape: [batch_size, num_docs, embedding_size]
next_doc = torch.cat(
[val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
next_actions = train_batch["next_actions"]
_, _, embedding_size = next_doc.shape
# selected_doc.shape: [batch_size, slate_size, embedding_size]
next_selected_doc = torch.gather(
# input.shape: [batch_size, num_docs, embedding_size]
input=next_doc,
dim=1,
# index.shape: [batch_size, slate_size, embedding_size]
index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
next_user = next_obs["user"]
dones = train_batch["dones"]
with torch.no_grad():
# q_values.shape: [batch_size, slate_size+1]
q_values = model.q_model(next_user, next_selected_doc)
# raw_scores.shape: [batch_size, slate_size+1]
raw_scores = model.choice_model(next_user, next_selected_doc)
max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
scores = torch.exp(raw_scores - max_raw_scores)
# next_q_values.shape: [batch_size]
next_q_values = torch.sum(
q_values * scores, dim=1) / torch.sum(
scores, dim=1)
next_q_values[dones] = 0.0
elif learning_strategy == "MYOP":
next_q_values = 0.
elif learning_strategy == "QL":
# next_doc.shape: [batch_size, num_docs, embedding_size]
next_doc = torch.cat(
[val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
next_user = next_obs["user"]
dones = train_batch["dones"]
with torch.no_grad():
_, next_q_values = model.choose_slate(next_user, next_doc)
next_q_values[dones] = 0.0
else:
raise ValueError(learning_strategy)
# target_q_values.shape: [batch_size]
target_q_values = next_q_values + train_batch["rewards"]
q_values = model.q_model(user,
selected_doc) # shape: [batch_size, slate_size+1]
# raw_scores.shape: [batch_size, slate_size+1]
raw_scores = model.choice_model(user, selected_doc)
max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
scores = torch.exp(raw_scores - max_raw_scores)
q_values = torch.sum(
q_values * scores, dim=1) / torch.sum(
scores, dim=1) # shape=[batch_size]
q_value_loss = nn.MSELoss()(q_values, target_q_values)
return [choice_loss, q_value_loss]
def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict
) -> List["torch.optim.Optimizer"]:
optimizer_choice = torch.optim.Adam(
policy.model.choice_model.parameters(), lr=config["lr_choice_model"])
optimizer_q_value = torch.optim.Adam(
policy.model.q_model.parameters(),
lr=config["lr_q_model"],
eps=config["adam_epsilon"])
return [optimizer_choice, optimizer_q_value]
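# Note: the order here is intended to line up positionally with the two
# losses returned by build_slateq_losses, so the choice loss updates only the
# choice model and the Q-value loss updates only the Q model, each with its
# own learning rate.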
def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state,
explore, timestep):
"""Determine which action to take"""
# First, we transform the observation into its unflattened form
obs = restore_original_dimensions(
input_dict[SampleBatch.CUR_OBS],
policy.observation_space,
tensorlib=torch)
# user.shape: [batch_size(=1), embedding_size]
user = obs["user"]
# doc.shape: [batch_size(=1), num_docs, embedding_size]
doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
selected_slates, _ = model.choose_slate(user, doc)
action = selected_slates
logp = None
state_out = []
return action, logp, state_out
def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
batch: SampleBatch,
other_agent=None,
episode=None) -> SampleBatch:
"""Add next_actions to SampleBatch for SARSA training"""
if policy.config["slateq_strategy"] == "SARSA":
if not batch["dones"][-1]:
raise RuntimeError(
"Expected a complete episode in each sample batch. "
f"But this batch is not: {batch}.")
batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
return batch
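# Example of the roll above: for a complete episode with actions
# [a0, a1, a2], next_actions becomes [a1, a2, a0]. The wrapped-around last
# entry never affects training, because the SARSA loss zeroes next_q_values
# wherever dones is True, and the final step of a complete episode is done.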
SlateQTorchPolicy = build_torch_policy(
name="SlateQTorchPolicy",
get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
# build model, loss functions, and optimizers
make_model_and_action_dist=build_slateq_model_and_distribution,
optimizer_fn=build_slateq_optimizers,
loss_fn=build_slateq_losses,
# define how to act
action_sampler_fn=action_sampler_fn,
# post processing batch sampled data
postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
)
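
The heart of both `choose_slate()` and the Q-value loss above is the SlateQ decomposition: a slate's Q-value is the choice-weighted average of its items' Q-values, with the no-click option contributing a Q-value of zero (QValueModel forces it to zero). A small self-contained sketch of that formula with toy numbers, independent of RecSim and of the classes above:

    import torch

    # One user, one candidate slate of three items.
    item_q = torch.tensor([1.0, 0.5, 2.0])       # per-item Q-values
    item_score = torch.tensor([0.2, 0.1, 0.3])   # choice-model scores (pre-exp)
    no_click_score = torch.tensor(0.0)

    v = torch.exp(item_score)                    # unnormalized choice weights
    v_no_click = torch.exp(no_click_score)
    # Same formula as slate_q_values in choose_slate(), with Q(no click) = 0.
    slate_q = (v * item_q).sum() / (v.sum() + v_no_click)
    print(float(slate_q))  # choice-weighted average of the item Q-values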

View file

@@ -55,11 +55,14 @@ class RecSimObservationSpaceWrapper(gym.ObservationWrapper):
class RecSimResetWrapper(gym.Wrapper):
"""Fix RecSim environment's reset() function
"""Fix RecSim environment's reset() and close() function
RecSim's reset() function returns an observation without the "response"
field, breaking RLlib's check. This wrapper fixes that by assigning a
random "response".
RecSim's close() function raises NotImplementedError. We change the
behavior to doing nothing.
"""
def reset(self):
@@ -67,6 +70,9 @@ class RecSimResetWrapper(gym.Wrapper):
obs["response"] = self.env.observation_space["response"].sample()
return obs
def close(self):
pass
class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper):
"""Convert the action space from MultiDiscrete to Discrete
@@ -108,7 +114,7 @@ def make_recsim_env(config):
env = interest_evolution.create_environment(env_config)
env = RecSimResetWrapper(env)
env = RecSimObservationSpaceWrapper(env)
if config and config["convert_to_discrete_action_space"]:
if env_config and env_config["convert_to_discrete_action_space"]:
env = MultiDiscreteToDiscreteActionWrapper(env)
return env

View file

@@ -1,9 +1,11 @@
from gym.spaces import Box
import numpy as np
import random
import numpy as np
from gym.spaces import Box
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModelWeights
class RandomPolicy(Policy):
@@ -50,3 +52,13 @@ class RandomPolicy(Policy):
prev_action_batch=None,
prev_reward_batch=None):
return np.array([random.random()] * len(obs_batch))
@override(Policy)
def get_weights(self) -> ModelWeights:
"""No weights to save."""
return {}
@override(Policy)
def set_weights(self, weights: ModelWeights) -> None:
"""No weights to set."""
pass

rllib/examples/slateq.py (new file, 115 lines)
View file

@@ -0,0 +1,115 @@
"""The SlateQ algorithm for recommendation"""
import argparse
from datetime import datetime
import ray
from ray import tune
from ray.rllib.agents import slateq
from ray.rllib.agents import dqn
from ray.rllib.agents.slateq.slateq import ALL_SLATEQ_STRATEGIES
from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name
from ray.tune.logger import pretty_print
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--agent",
type=str,
default="SlateQ",
help=("Select agent policy. Choose from: DQN and SlateQ. "
"Default value: SlateQ."),
)
parser.add_argument(
"--strategy",
type=str,
default="QL",
help=("Strategy for the SlateQ agent. Choose from: " +
", ".join(ALL_SLATEQ_STRATEGIES) + ". "
"Default value: QL. Ignored when using Tune."),
)
parser.add_argument(
"--use-tune",
action="store_true",
help=("Run with Tune so that the results are logged into Tensorboard. "
"For debugging, it's easier to run without Ray Tune."),
)
parser.add_argument("--tune-num-samples", type=int, default=10)
parser.add_argument("--env-slate-size", type=int, default=2)
parser.add_argument("--env-seed", type=int, default=0)
parser.add_argument(
"--num-gpus",
type=float,
default=0.,
help="Only used if running with Tune.")
parser.add_argument(
"--num-workers",
type=int,
default=0,
help="Only used if running with Tune.")
args = parser.parse_args()
if args.agent not in ["DQN", "SlateQ"]:
raise ValueError(args.agent)
env_config = {
"slate_size": args.env_slate_size,
"seed": args.env_seed,
"convert_to_discrete_action_space": args.agent == "DQN",
}
ray.init()
if args.use_tune:
time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}"
if args.agent == "DQN":
tune.run(
"DQN",
stop={"timesteps_total": 4000000},
name=name,
config={
"env": recsim_env_name,
"num_gpus": args.num_gpus,
"num_workers": args.num_workers,
"env_config": env_config,
},
num_samples=args.tune_num_samples,
verbose=1)
else:
tune.run(
"SlateQ",
stop={"timesteps_total": 4000000},
name=name,
config={
"env": recsim_env_name,
"num_gpus": args.num_gpus,
"num_workers": args.num_workers,
"slateq_strategy": tune.grid_search(ALL_SLATEQ_STRATEGIES),
"env_config": env_config,
},
num_samples=args.tune_num_samples,
verbose=1)
else:
# directly run using the trainer interface (good for debugging)
if args.agent == "DQN":
config = dqn.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 0
config["env_config"] = env_config
trainer = dqn.DQNTrainer(config=config, env=recsim_env_name)
else:
config = slateq.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 0
config["slateq_strategy"] = args.strategy
config["env_config"] = env_config
trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name)
for i in range(10):
result = trainer.train()
print(pretty_print(result))
ray.shutdown()
if __name__ == "__main__":
main()
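
Assuming RecSim is installed, the example above can be run directly for debugging, e.g. `python rllib/examples/slateq.py --strategy SARSA`, or with `--use-tune` to grid-search all SlateQ strategies and log results to TensorBoard.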