diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index 001f921d4..924711cf0 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -110,6 +110,11 @@ def _import_simple_q(): return dqn.SimpleQTrainer +def _import_slate_q(): + from ray.rllib.agents import slateq + return slateq.SlateQTrainer + + def _import_td3(): from ray.rllib.agents import ddpg return ddpg.TD3Trainer @@ -127,6 +132,7 @@ ALGORITHMS = { "DDPG": _import_ddpg, "DDPPO": _import_ddppo, "DQN": _import_dqn, + "SlateQ": _import_slate_q, "DREAMER": _import_dreamer, "IMPALA": _import_impala, "MAML": _import_maml, diff --git a/rllib/agents/slateq/__init__.py b/rllib/agents/slateq/__init__.py new file mode 100644 index 000000000..fb973990f --- /dev/null +++ b/rllib/agents/slateq/__init__.py @@ -0,0 +1,8 @@ +from ray.rllib.agents.slateq.slateq import SlateQTrainer, DEFAULT_CONFIG +from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy + +__all__ = [ + "SlateQTrainer", + "SlateQTorchPolicy", + "DEFAULT_CONFIG", +] diff --git a/rllib/agents/slateq/slateq.py b/rllib/agents/slateq/slateq.py new file mode 100644 index 000000000..f392a16b2 --- /dev/null +++ b/rllib/agents/slateq/slateq.py @@ -0,0 +1,232 @@ +""" +SlateQ (Reinforcement Learning for Recommendation) +================================================== + +This file defines the trainer class for the SlateQ algorithm from the +`"Reinforcement Learning for Slate-based Recommender Systems: A Tractable +Decomposition and Practical Methodology" `_ +paper. + +See `slateq_torch_policy.py` for the definition of the policy. Currently, only +PyTorch is supported. The algorithm is written and tested for Google's RecSim +environment (https://github.com/google-research/recsim). +""" + +import logging +from typing import List, Type + +from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.examples.policy.random_policy import RandomPolicy +from ray.rllib.execution.concurrency_ops import Concurrently +from ray.rllib.execution.metric_ops import StandardMetricsReporting +from ray.rllib.execution.replay_buffer import LocalReplayBuffer +from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer +from ray.rllib.execution.rollout_ops import ParallelRollouts +from ray.rllib.execution.train_ops import TrainOneStep +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.typing import TrainerConfigDict +from ray.util.iter import LocalIterator + +logger = logging.getLogger(__name__) + +# Defines all SlateQ strategies implemented. +ALL_SLATEQ_STRATEGIES = [ + # RANDOM: Randomly select documents for slates. + "RANDOM", + # MYOP: Select documents that maximize user click probabilities. This is + # a myopic strategy and ignores long term rewards. This is equivalent to + # setting a zero discount rate for future rewards. + "MYOP", + # SARSA: Use the SlateQ SARSA learning algorithm. + "SARSA", + # QL: Use the SlateQ Q-learning algorithm. + "QL", +] + +# yapf: disable +# __sphinx_doc_begin__ +DEFAULT_CONFIG = with_common_config({ + # === Model === + # Dense-layer setup for each the advantage branch and the value branch + # in a dueling architecture. 
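+    # For SlateQ, these sizes are used as the hidden layers of the
+    # per-document Q-value MLP (QValueModel in slateq_torch_policy.py);
+    # there is no separate dueling value branch here.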
+ "hiddens": [256, 64, 16], + + # set batchmode + "batch_mode": "complete_episodes", + + # === Deep Learning Framework Settings === + # Currently, only PyTorch is supported + "framework": "torch", + + # === Exploration Settings === + "exploration_config": { + # The Exploration class to use. + "type": "EpsilonGreedy", + # Config for the Exploration class' constructor: + "initial_epsilon": 1.0, + "final_epsilon": 0.02, + "epsilon_timesteps": 10000, # Timesteps over which to anneal epsilon. + }, + # Switch to greedy actions in evaluation workers. + "evaluation_config": { + "explore": False, + }, + + # Minimum env steps to optimize for per train call. This value does + # not affect learning, only the length of iterations. + "timesteps_per_iteration": 1000, + # === Replay buffer === + # Size of the replay buffer. Note that if async_updates is set, then + # each worker will have a replay buffer of this size. + "buffer_size": 50000, + # Whether to LZ4 compress observations + "compress_observations": False, + # If set, this will fix the ratio of replayed from a buffer and learned on + # timesteps to sampled from an environment and stored in the replay buffer + # timesteps. Otherwise, the replay will proceed at the native ratio + # determined by (train_batch_size / rollout_fragment_length). + "training_intensity": None, + + # === Optimization === + # Learning rate for adam optimizer for the user choice model + "lr_choice_model": 1e-2, + # Learning rate for adam optimizer for the q model + "lr_q_model": 1e-2, + # Adam epsilon hyper parameter + "adam_epsilon": 1e-8, + # If not None, clip gradients during optimization at this value + "grad_clip": 40, + # How many steps of the model to sample before learning starts. + "learning_starts": 1000, + # Update the replay buffer with this many samples at once. Note that + # this setting applies per-worker if num_workers > 1. + "rollout_fragment_length": 1000, + # Size of a batch sampled from replay buffer for training. Note that + # if async_updates is set, then each worker returns gradients for a + # batch of this size. + "train_batch_size": 32, + + # === Parallelism === + # Number of workers for collecting samples with. This only makes sense + # to increase if your environment is particularly slow to sample, or if + # you"re using the Async or Ape-X optimizers. + "num_workers": 0, + # Whether to compute priorities on workers. + "worker_side_prioritization": False, + # Prevent iterations from going lower than this time span + "min_iter_time_s": 1, + + # === SlateQ specific options === + # Learning method used by the slateq policy. Choose from: RANDOM, + # MYOP (myopic), SARSA, QL (Q-Learning), + "slateq_strategy": "QL", + # user/doc embedding size for the recsim environment + "recsim_embedding_size": 20, +}) +# __sphinx_doc_end__ +# yapf: enable + + +def validate_config(config: TrainerConfigDict) -> None: + """Checks the config based on settings""" + if config["framework"] != "torch": + raise ValueError("SlateQ only runs on PyTorch") + + if config["slateq_strategy"] not in ALL_SLATEQ_STRATEGIES: + raise ValueError("Unknown slateq_strategy: " + f"{config['slateq_strategy']}.") + + if config["slateq_strategy"] == "SARSA": + if config["batch_mode"] != "complete_episodes": + raise ValueError( + "For SARSA strategy, batch_mode must be 'complete_episodes'") + + +def execution_plan(workers: WorkerSet, + config: TrainerConfigDict) -> LocalIterator[dict]: + """Execution plan of the SlateQ algorithm. Defines the distributed dataflow. 
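+
+    Rollout workers collect RecSim episodes and store them in a local replay
+    buffer; training batches are then replayed from that buffer and passed to
+    TrainOneStep. For the RANDOM strategy, the training step is skipped.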
+ + Args: + workers (WorkerSet): The WorkerSet for training the Polic(y/ies) + of the Trainer. + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + LocalIterator[dict]: A local iterator over training metrics. + """ + local_replay_buffer = LocalReplayBuffer( + num_shards=1, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + replay_batch_size=config["train_batch_size"], + replay_mode=config["multiagent"]["replay_mode"], + replay_sequence_length=config["replay_sequence_length"], + ) + + rollouts = ParallelRollouts(workers, mode="bulk_sync") + + # We execute the following steps concurrently: + # (1) Generate rollouts and store them in our local replay buffer. Calling + # next() on store_op drives this. + store_op = rollouts.for_each( + StoreToReplayBuffer(local_buffer=local_replay_buffer)) + + # (2) Read and train on experiences from the replay buffer. Every batch + # returned from the LocalReplay() iterator is passed to TrainOneStep to + # take a SGD step. + replay_op = Replay(local_buffer=local_replay_buffer) \ + .for_each(TrainOneStep(workers)) + + if config["slateq_strategy"] != "RANDOM": + # Alternate deterministically between (1) and (2). Only return the + # output of (2) since training metrics are not available until (2) + # runs. + train_op = Concurrently( + [store_op, replay_op], + mode="round_robin", + output_indexes=[1], + round_robin_weights=calculate_round_robin_weights(config)) + else: + # No training is needed for the RANDOM strategy. + train_op = rollouts + + return StandardMetricsReporting(train_op, workers, config) + + +def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]: + """Calculate the round robin weights for the rollout and train steps""" + if not config["training_intensity"]: + return [1, 1] + # e.g., 32 / 4 -> native ratio of 8.0 + native_ratio = ( + config["train_batch_size"] / config["rollout_fragment_length"]) + # Training intensity is specified in terms of + # (steps_replayed / steps_sampled), so adjust for the native ratio. + weights = [1, config["training_intensity"] / native_ratio] + return weights + + +def get_policy_class(config: TrainerConfigDict) -> Type[Policy]: + """Policy class picker function. + + Args: + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + Type[Policy]: The Policy class to use with SlateQTrainer. 
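+            (RandomPolicy for the RANDOM strategy, SlateQTorchPolicy
+            otherwise).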
+ """ + if config["slateq_strategy"] == "RANDOM": + return RandomPolicy + else: + return SlateQTorchPolicy + + +SlateQTrainer = build_trainer( + name="SlateQ", + get_policy_class=get_policy_class, + default_config=DEFAULT_CONFIG, + validate_config=validate_config, + execution_plan=execution_plan) diff --git a/rllib/agents/slateq/slateq_torch_policy.py b/rllib/agents/slateq/slateq_torch_policy.py new file mode 100644 index 000000000..0afb7cb12 --- /dev/null +++ b/rllib/agents/slateq/slateq_torch_policy.py @@ -0,0 +1,420 @@ +"""PyTorch policy class used for SlateQ""" + +from typing import Dict, List, Sequence, Tuple + +import gym +import numpy as np + +import ray +from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions +from ray.rllib.models.torch.torch_action_dist import (TorchCategorical, + TorchDistributionWrapper) +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import (ModelConfigDict, TensorType, + TrainerConfigDict) + +torch, nn = try_import_torch() +F = None +if nn: + F = nn.functional + + +class QValueModel(nn.Module): + """The Q-value model for SlateQ""" + + def __init__(self, embedding_size: int, q_hiddens: Sequence[int]): + super().__init__() + + # construct hidden layers + layers = [] + ins = 2 * embedding_size + for n in q_hiddens: + layers.append(nn.Linear(ins, n)) + layers.append(nn.LeakyReLU()) + ins = n + layers.append(nn.Linear(ins, 1)) + self.layers = nn.Sequential(*layers) + + def forward(self, user: TensorType, doc: TensorType) -> TensorType: + """Evaluate the user-doc Q model + + Args: + user (TensorType): User embedding of shape (batch_size, + embedding_size). + doc (TensorType): Doc embeddings of shape (batch_size, num_docs, + embedding_size). + + Returns: + score (TensorType): q_values of shape (batch_size, num_docs + 1). + """ + batch_size, num_docs, embedding_size = doc.shape + doc_flat = doc.view((batch_size * num_docs, embedding_size)) + user_repeated = user.repeat(num_docs, 1) + x = torch.cat([user_repeated, doc_flat], dim=1) + x = self.layers(x) + # Similar to Google's SlateQ implementation in RecSim, we force the + # Q-values to zeros if there are no clicks. + x_no_click = torch.zeros((batch_size, 1), device=x.device) + return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1) + + +class UserChoiceModel(nn.Module): + r"""The user choice model for SlateQ + + This class implements a multinomial logit model for predicting user clicks. + + Under this model, the click probability of a document is proportional to: + + .. math:: + \exp(\text{beta} * \text{doc_user_affinity} + \text{score_no_click}) + """ + + def __init__(self): + super().__init__() + self.beta = nn.Parameter(torch.tensor(0., dtype=torch.float)) + self.score_no_click = nn.Parameter(torch.tensor(0., dtype=torch.float)) + + def forward(self, user: TensorType, doc: TensorType) -> TensorType: + """Evaluate the user choice model + + This function outputs user click scores for candidate documents. The + exponentials of these scores are proportional user click probabilities. + Here we return the scores unnormalized because because only some of the + documents will be selected and shown to the user. + + Args: + user (TensorType): User embeddings of shape (batch_size, + embedding_size). 
+ doc (TensorType): Doc embeddings of shape (batch_size, num_docs, + embedding_size). + + Returns: + score (TensorType): logits of shape (batch_size, num_docs + 1), + where the last dimension represents no_click. + """ + batch_size = user.shape[0] + s = torch.einsum("be,bde->bd", user, doc) + s = s * self.beta + s = torch.cat([s, self.score_no_click.expand((batch_size, 1))], dim=1) + return s + + +class SlateQModel(TorchModelV2, nn.Module): + """The SlateQ model class + + It includes both the user choice model and the Q-value model. + """ + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + model_config: ModelConfigDict, + name: str, + *, + embedding_size: int, + q_hiddens: Sequence[int], + ): + nn.Module.__init__(self) + TorchModelV2.__init__( + self, + obs_space, + action_space, + # This required parameter (num_outputs) seems redundant: it has no + # real imact, and can be set arbitrarily. TODO: fix this. + num_outputs=0, + model_config=model_config, + name=name) + self.choice_model = UserChoiceModel() + self.q_model = QValueModel(embedding_size, q_hiddens) + self.slate_size = len(action_space.nvec) + + def choose_slate(self, user: TensorType, + doc: TensorType) -> Tuple[TensorType, TensorType]: + """Build a slate by selecting from candidate documents + + Args: + user (TensorType): User embeddings of shape (batch_size, + embedding_size). + doc (TensorType): Doc embeddings of shape (batch_size, + num_docs, embedding_size). + + Returns: + slate_selected (TensorType): Indices of documents selected for + the slate, with shape (batch_size, slate_size). + best_slate_q_value (TensorType): The Q-value of the selected slate, + with shape (batch_size). + """ + # Step 1: compute item scores (proportional to click probabilities) + # raw_scores.shape=[batch_size, num_docs+1] + raw_scores = self.choice_model(user, doc) + # max_raw_scores.shape=[batch_size, 1] + max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True) + # deduct scores by max_scores to avoid value explosion + scores = torch.exp(raw_scores - max_raw_scores) + scores_doc = scores[:, :-1] # shape=[batch_size, num_docs] + scores_no_click = scores[:, [-1]] # shape=[batch_size, 1] + + # Step 2: calculate the item-wise Q values + # q_values.shape=[batch_size, num_docs+1] + q_values = self.q_model(user, doc) + q_values_doc = q_values[:, :-1] # shape=[batch_size, num_docs] + q_values_no_click = q_values[:, [-1]] # shape=[batch_size, 1] + + # Step 3: construct all possible slates + _, num_docs, _ = doc.shape + indices = torch.arange(num_docs, dtype=torch.long, device=doc.device) + # slates.shape = [num_slates, slate_size] + slates = torch.combinations(indices, r=self.slate_size) + num_slates, _ = slates.shape + + # Step 4: calculate slate Q values + batch_size, _ = q_values_doc.shape + # slate_decomp_q_values.shape: [batch_size, num_slates, slate_size] + slate_decomp_q_values = torch.gather( + # input.shape: [batch_size, num_slates, num_docs] + input=q_values_doc.unsqueeze(1).expand(-1, num_slates, -1), + dim=2, + # index.shape: [batch_size, num_slates, slate_size] + index=slates.unsqueeze(0).expand(batch_size, -1, -1)) + # slate_scores.shape: [batch_size, num_slates, slate_size] + slate_scores = torch.gather( + # input.shape: [batch_size, num_slates, num_docs] + input=scores_doc.unsqueeze(1).expand(-1, num_slates, -1), + dim=2, + # index.shape: [batch_size, num_slates, slate_size] + index=slates.unsqueeze(0).expand(batch_size, -1, -1)) + # slate_q_values.shape: [batch_size, num_slates] + 
slate_q_values = ((slate_decomp_q_values * slate_scores).sum(dim=2) +
+                          (q_values_no_click * scores_no_click)) / (
+                              slate_scores.sum(dim=2) + scores_no_click)
+
+        # Step 5: find the slate that maximizes the Q value
+        best_slate_q_value, max_idx = torch.max(slate_q_values, dim=1)
+        # slates_selected.shape: [batch_size, slate_size]
+        slates_selected = slates[max_idx]
+        return slates_selected, best_slate_q_value
+
+    def forward(self, input_dict: Dict[str, TensorType],
+                state: List[TensorType],
+                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
+        # user.shape: [batch_size, embedding_size]
+        user = input_dict[SampleBatch.OBS]["user"]
+        # doc.shape: [batch_size, num_docs, embedding_size]
+        doc = torch.cat([
+            val.unsqueeze(1)
+            for val in input_dict[SampleBatch.OBS]["doc"].values()
+        ], 1)
+
+        slates_selected, _ = self.choose_slate(user, doc)
+
+        state_out = []
+        return slates_selected, state_out
+
+
+def build_slateq_model_and_distribution(
+        policy: Policy, obs_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        config: TrainerConfigDict) -> Tuple[ModelV2, TorchDistributionWrapper]:
+    """Build models for SlateQ
+
+    Args:
+        policy (Policy): The policy, which will use the model for
+            optimization.
+        obs_space (gym.spaces.Space): The policy's observation space.
+        action_space (gym.spaces.Space): The policy's action space.
+        config (TrainerConfigDict): The trainer's configuration dict.
+
+    Returns:
+        Tuple[ModelV2, TorchDistributionWrapper]: The SlateQModel and the
+            TorchCategorical action distribution class.
+    """
+    model = SlateQModel(
+        obs_space,
+        action_space,
+        model_config=config["model"],
+        name="slateq_model",
+        embedding_size=config["recsim_embedding_size"],
+        q_hiddens=config["hiddens"],
+    )
+    return model, TorchCategorical
+
+
+def build_slateq_losses(policy: Policy, model: SlateQModel, _,
+                        train_batch: SampleBatch) -> List[TensorType]:
+    """Constructs the losses for SlateQTorchPolicy.
+
+    Args:
+        policy (Policy): The Policy to calculate the losses for.
+        model (ModelV2): The Model to calculate the losses for.
+        train_batch (SampleBatch): The training data.
+
+    Returns:
+        List[TensorType]: The choice-model loss and the Q-value loss.
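+
+    Note:
+        The choice-model loss is a cross-entropy loss over which slate
+        position (or the no-click option) the user selected. The Q-value loss
+        is an MSE loss based on the SlateQ decomposition, in which the slate
+        Q-value is the choice-probability-weighted sum of per-item Q-values,
+        Q(s, A) = sum_i P(i | s, A) * Q(s, i), with no-click included.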
+    """
+    obs = restore_original_dimensions(
+        train_batch[SampleBatch.OBS],
+        policy.observation_space,
+        tensorlib=torch)
+    # user.shape: [batch_size, embedding_size]
+    user = obs["user"]
+    # doc.shape: [batch_size, num_docs, embedding_size]
+    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
+    # actions.shape: [batch_size, slate_size]
+    actions = train_batch[SampleBatch.ACTIONS]
+
+    next_obs = restore_original_dimensions(
+        train_batch[SampleBatch.NEXT_OBS],
+        policy.observation_space,
+        tensorlib=torch)
+
+    # Step 1: Build user choice model loss
+    _, _, embedding_size = doc.shape
+    # selected_doc.shape: [batch_size, slate_size, embedding_size]
+    selected_doc = torch.gather(
+        # input.shape: [batch_size, num_docs, embedding_size]
+        input=doc,
+        dim=1,
+        # index.shape: [batch_size, slate_size, embedding_size]
+        index=actions.unsqueeze(2).expand(-1, -1, embedding_size))
+
+    scores = model.choice_model(user, selected_doc)
+    choice_loss_fn = nn.CrossEntropyLoss()
+
+    # clicks.shape: [batch_size, slate_size]
+    clicks = torch.stack(
+        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
+    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
+    # targets.shape: [batch_size, slate_size+1]
+    targets = torch.cat([clicks, no_clicks], dim=1)
+    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
+
+    # Step 2: Build Q-value loss
+    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
+    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
+    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
+    # 'batch_indexes']
+    learning_strategy = policy.config["slateq_strategy"]
+
+    if learning_strategy == "SARSA":
+        # next_doc.shape: [batch_size, num_docs, embedding_size]
+        next_doc = torch.cat(
+            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
+        next_actions = train_batch["next_actions"]
+        _, _, embedding_size = next_doc.shape
+        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
+        next_selected_doc = torch.gather(
+            # input.shape: [batch_size, num_docs, embedding_size]
+            input=next_doc,
+            dim=1,
+            # index.shape: [batch_size, slate_size, embedding_size]
+            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size))
+        next_user = next_obs["user"]
+        dones = train_batch["dones"]
+        with torch.no_grad():
+            # q_values.shape: [batch_size, slate_size+1]
+            q_values = model.q_model(next_user, next_selected_doc)
+            # raw_scores.shape: [batch_size, slate_size+1]
+            raw_scores = model.choice_model(next_user, next_selected_doc)
+            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
+            scores = torch.exp(raw_scores - max_raw_scores)
+            # next_q_values.shape: [batch_size]
+            next_q_values = torch.sum(
+                q_values * scores, dim=1) / torch.sum(
+                    scores, dim=1)
+            next_q_values[dones] = 0.0
+    elif learning_strategy == "MYOP":
+        next_q_values = 0.
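+    # QL: bootstrap from the value of the best next-state slate, as computed
+    # by model.choose_slate().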
+ elif learning_strategy == "QL": + # next_doc.shape: [batch_size, num_docs, embedding_size] + next_doc = torch.cat( + [val.unsqueeze(1) for val in next_obs["doc"].values()], 1) + next_user = next_obs["user"] + dones = train_batch["dones"] + with torch.no_grad(): + _, next_q_values = model.choose_slate(next_user, next_doc) + next_q_values[dones] = 0.0 + else: + raise ValueError(learning_strategy) + # target_q_values.shape: [batch_size] + target_q_values = next_q_values + train_batch["rewards"] + + q_values = model.q_model(user, + selected_doc) # shape: [batch_size, slate_size+1] + # raw_scores.shape: [batch_size, slate_size+1] + raw_scores = model.choice_model(user, selected_doc) + max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True) + scores = torch.exp(raw_scores - max_raw_scores) + q_values = torch.sum( + q_values * scores, dim=1) / torch.sum( + scores, dim=1) # shape=[batch_size] + + q_value_loss = nn.MSELoss()(q_values, target_q_values) + return [choice_loss, q_value_loss] + + +def build_slateq_optimizers(policy: Policy, config: TrainerConfigDict + ) -> List["torch.optim.Optimizer"]: + optimizer_choice = torch.optim.Adam( + policy.model.choice_model.parameters(), lr=config["lr_choice_model"]) + optimizer_q_value = torch.optim.Adam( + policy.model.q_model.parameters(), + lr=config["lr_q_model"], + eps=config["adam_epsilon"]) + return [optimizer_choice, optimizer_q_value] + + +def action_sampler_fn(policy: Policy, model: SlateQModel, input_dict, state, + explore, timestep): + """Determine which action to take""" + # First, we transform the observation into its unflattened form + obs = restore_original_dimensions( + input_dict[SampleBatch.CUR_OBS], + policy.observation_space, + tensorlib=torch) + + # user.shape: [batch_size(=1), embedding_size] + user = obs["user"] + # doc.shape: [batch_size(=1), num_docs, embedding_size] + doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1) + + selected_slates, _ = model.choose_slate(user, doc) + + action = selected_slates + logp = None + state_out = [] + return action, logp, state_out + + +def postprocess_fn_add_next_actions_for_sarsa(policy: Policy, + batch: SampleBatch, + other_agent=None, + episode=None) -> SampleBatch: + """Add next_actions to SampleBatch for SARSA training""" + if policy.config["slateq_strategy"] == "SARSA": + if not batch["dones"][-1]: + raise RuntimeError( + "Expected a complete episode in each sample batch. 
" + f"But this batch is not: {batch}.") + batch["next_actions"] = np.roll(batch["actions"], -1, axis=0) + return batch + + +SlateQTorchPolicy = build_torch_policy( + name="SlateQTorchPolicy", + get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG, + + # build model, loss functions, and optimizers + make_model_and_action_dist=build_slateq_model_and_distribution, + optimizer_fn=build_slateq_optimizers, + loss_fn=build_slateq_losses, + + # define how to act + action_sampler_fn=action_sampler_fn, + + # post processing batch sampled data + postprocess_fn=postprocess_fn_add_next_actions_for_sarsa, +) diff --git a/rllib/env/wrappers/recsim_wrapper.py b/rllib/env/wrappers/recsim_wrapper.py index cab1db7ea..ad846262a 100644 --- a/rllib/env/wrappers/recsim_wrapper.py +++ b/rllib/env/wrappers/recsim_wrapper.py @@ -55,11 +55,14 @@ class RecSimObservationSpaceWrapper(gym.ObservationWrapper): class RecSimResetWrapper(gym.Wrapper): - """Fix RecSim environment's reset() function + """Fix RecSim environment's reset() and close() function RecSim's reset() function returns an observation without the "response" field, breaking RLlib's check. This wrapper fixes that by assigning a random "response". + + RecSim's close() function raises NotImplementedError. We change the + behavior to doing nothing. """ def reset(self): @@ -67,6 +70,9 @@ class RecSimResetWrapper(gym.Wrapper): obs["response"] = self.env.observation_space["response"].sample() return obs + def close(self): + pass + class MultiDiscreteToDiscreteActionWrapper(gym.ActionWrapper): """Convert the action space from MultiDiscrete to Discrete @@ -108,7 +114,7 @@ def make_recsim_env(config): env = interest_evolution.create_environment(env_config) env = RecSimResetWrapper(env) env = RecSimObservationSpaceWrapper(env) - if config and config["convert_to_discrete_action_space"]: + if env_config and env_config["convert_to_discrete_action_space"]: env = MultiDiscreteToDiscreteActionWrapper(env) return env diff --git a/rllib/examples/policy/random_policy.py b/rllib/examples/policy/random_policy.py index 0ae04627c..5eaf48952 100644 --- a/rllib/examples/policy/random_policy.py +++ b/rllib/examples/policy/random_policy.py @@ -1,9 +1,11 @@ -from gym.spaces import Box -import numpy as np import random +import numpy as np +from gym.spaces import Box + from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import ModelWeights class RandomPolicy(Policy): @@ -50,3 +52,13 @@ class RandomPolicy(Policy): prev_action_batch=None, prev_reward_batch=None): return np.array([random.random()] * len(obs_batch)) + + @override(Policy) + def get_weights(self) -> ModelWeights: + """No weights to save.""" + return {} + + @override(Policy) + def set_weights(self, weights: ModelWeights) -> None: + """No weights to set.""" + pass diff --git a/rllib/examples/slateq.py b/rllib/examples/slateq.py new file mode 100644 index 000000000..ee744314e --- /dev/null +++ b/rllib/examples/slateq.py @@ -0,0 +1,115 @@ +"""The SlateQ algorithm for recommendation""" + +import argparse +from datetime import datetime + +import ray +from ray import tune +from ray.rllib.agents import slateq +from ray.rllib.agents import dqn +from ray.rllib.agents.slateq.slateq import ALL_SLATEQ_STRATEGIES +from ray.rllib.env.wrappers.recsim_wrapper import env_name as recsim_env_name +from ray.tune.logger import pretty_print + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--agent", + type=str, + default="SlateQ", 
+ help=("Select agent policy. Choose from: DQN and SlateQ. " + "Default value: SlateQ."), + ) + parser.add_argument( + "--strategy", + type=str, + default="QL", + help=("Strategy for the SlateQ agent. Choose from: " + + ", ".join(ALL_SLATEQ_STRATEGIES) + ". " + "Default value: QL. Ignored when using Tune."), + ) + parser.add_argument( + "--use-tune", + action="store_true", + help=("Run with Tune so that the results are logged into Tensorboard. " + "For debugging, it's easier to run without Ray Tune."), + ) + parser.add_argument("--tune-num-samples", type=int, default=10) + parser.add_argument("--env-slate-size", type=int, default=2) + parser.add_argument("--env-seed", type=int, default=0) + parser.add_argument( + "--num-gpus", + type=float, + default=0., + help="Only used if running with Tune.") + parser.add_argument( + "--num-workers", + type=int, + default=0, + help="Only used if running with Tune.") + args = parser.parse_args() + + if args.agent not in ["DQN", "SlateQ"]: + raise ValueError(args.agent) + + env_config = { + "slate_size": args.env_slate_size, + "seed": args.env_seed, + "convert_to_discrete_action_space": args.agent == "DQN", + } + + ray.init() + if args.use_tune: + time_signature = datetime.now().strftime("%Y-%m-%d_%H_%M_%S") + name = f"SlateQ/{args.agent}-seed{args.env_seed}-{time_signature}" + if args.agent == "DQN": + tune.run( + "DQN", + stop={"timesteps_total": 4000000}, + name=name, + config={ + "env": recsim_env_name, + "num_gpus": args.num_gpus, + "num_workers": args.num_workers, + "env_config": env_config, + }, + num_samples=args.tune_num_samples, + verbose=1) + else: + tune.run( + "SlateQ", + stop={"timesteps_total": 4000000}, + name=name, + config={ + "env": recsim_env_name, + "num_gpus": args.num_gpus, + "num_workers": args.num_workers, + "slateq_strategy": tune.grid_search(ALL_SLATEQ_STRATEGIES), + "env_config": env_config, + }, + num_samples=args.tune_num_samples, + verbose=1) + else: + # directly run using the trainer interface (good for debugging) + if args.agent == "DQN": + config = dqn.DEFAULT_CONFIG.copy() + config["num_gpus"] = 0 + config["num_workers"] = 0 + config["env_config"] = env_config + trainer = dqn.DQNTrainer(config=config, env=recsim_env_name) + else: + config = slateq.DEFAULT_CONFIG.copy() + config["num_gpus"] = 0 + config["num_workers"] = 0 + config["slateq_strategy"] = args.strategy + config["env_config"] = env_config + trainer = slateq.SlateQTrainer(config=config, env=recsim_env_name) + for i in range(10): + result = trainer.train() + print(pretty_print(result)) + ray.shutdown() + + +if __name__ == "__main__": + main()
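+
+
+# Example invocations (assuming RecSim and its dependencies are installed):
+#   python rllib/examples/slateq.py --agent SlateQ --strategy QL
+#   python rllib/examples/slateq.py --agent SlateQ --use-tune --tune-num-samples 1
+#   python rllib/examples/slateq.py --agent DQN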