"""PyTorch policy class used for SlateQ."""

import gym
import logging
import numpy as np
import time
from typing import Dict, List, Tuple, Type

import ray
from ray.rllib.agents.sac.sac_torch_policy import TargetNetworkMixin
from ray.rllib.agents.slateq.slateq_torch_model import SlateQModel
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDistributionWrapper,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping, huber_loss
from ray.rllib.utils.typing import TensorType, TrainerConfigDict

torch, nn = try_import_torch()
logger = logging.getLogger(__name__)


def build_slateq_model_and_distribution(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Build models for SlateQ.

    Args:
        policy: The policy, which will use the model for optimization.
        obs_space: The policy's observation space.
        action_space: The policy's action space.
        config: The Trainer's config dict.

    Returns:
        Tuple consisting of 1) Q-model and 2) an action distribution class.
    """
    model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="slateq_model",
        user_embedding_size=obs_space.original_space["user"].shape[0],
        doc_embedding_size=obs_space.original_space["doc"]["0"].shape[0],
        num_docs=len(obs_space.original_space["doc"].spaces),
        q_hiddens=config["hiddens"],
        double_q=config["double_q"],
    )

    policy.target_model = SlateQModel(
        obs_space,
        action_space,
        model_config=config["model"],
        name="target_slateq_model",
        user_embedding_size=obs_space.original_space["user"].shape[0],
        doc_embedding_size=obs_space.original_space["doc"]["0"].shape[0],
        num_docs=len(obs_space.original_space["doc"].spaces),
        q_hiddens=config["hiddens"],
        double_q=config["double_q"],
    )

    return model, TorchCategorical


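# Note on the observation layout assumed above (a sketch of a typical
# RecSim-style space; exact field names and shapes depend on the environment):
#
#   obs_space.original_space = Dict({
#       "user": Box(shape=(user_embedding_size,)),
#       "doc": Dict({"0": Box(shape=(doc_embedding_size,)), "1": ..., ...}),
#       "response": Tuple(...),  # per-slate-position click/engagement info
#   })
#
# i.e. `num_docs` is the number of candidate documents offered per step.

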
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Tuple[TensorType, TensorType]:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        Tuple consisting of 1) the choice loss and 2) the Q-value loss tensors.
    """
    start = time.time()
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS], policy.observation_space, tensorlib=torch
    )
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # actions.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS], policy.observation_space, tensorlib=torch
    )

    # Step 1: Build the user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack([resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
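    # Note: the choice model is effectively trained as a (slate_size + 1)-way
    # classifier: one class per slate position plus a final "no click" class.
    # The target index is the position of the clicked document, or slate_size
    # if nothing was clicked.
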
    # Step 2: Build the Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    if learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat([val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
        )
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
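            # SlateQ decomposition: the next-state value is the expectation of
            # the per-item Q-values (incl. the no-click option) under the user
            # choice model, i.e. a softmax-weighted average of `q_values` with
            # weights proportional to exp(raw_scores).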
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
                scores, dim=1
            )
            next_q_values[dones.bool()] = 0.0
    elif learning_strategy == "MYOP":
        next_q_values = 0.0
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat([val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            if policy.config["double_q"]:
                next_target_per_slate_q_values = policy.target_models[
                    model
                ].get_per_slate_q_values(next_user, next_doc)
                _, next_q_values, _ = model.choose_slate(
                    next_user, next_doc, next_target_per_slate_q_values
                )
            else:
                _, next_q_values, _ = policy.target_models[model].choose_slate(
                    next_user, next_doc
                )
        next_q_values = next_q_values.detach()
        next_q_values[dones.bool()] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    target_q_values = (
        train_batch[SampleBatch.REWARDS] + policy.config["gamma"] * next_q_values
    )
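    # One-step TD target: r_t + gamma * V(s_{t+1}), where V(s_{t+1}) is zeroed
    # at terminal steps (and is identically 0.0 under the MYOP strategy).
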
    # q_values.shape: [batch_size, slate_size+1].
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1].
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
        scores, dim=1
    )  # shape=[batch_size]
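    # Same expectation as for the next state: Q(s, A) for the slate actually
    # shown is the choice-model-weighted average of its per-item Q-values.
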
    td_error = torch.abs(q_values - target_q_values)
    q_value_loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["choice_loss"] = choice_loss
    model.tower_stats["q_loss"] = q_value_loss
    model.tower_stats["q_values"] = q_values
    model.tower_stats["next_q_values"] = next_q_values
    model.tower_stats["next_q_minus_q"] = next_q_values - q_values
    model.tower_stats["td_error"] = td_error
    model.tower_stats["target_q_values"] = target_q_values
    model.tower_stats["raw_scores"] = raw_scores

    logger.debug(f"loss calculation took {time.time() - start}s")
    return choice_loss, q_value_loss


def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]:
    """Stats function for SlateQTorchPolicy. Returns a dict of training stats."""
    stats = {
        "q_loss": torch.mean(torch.stack(policy.get_tower_stats("q_loss"))),
        "choice_loss": torch.mean(torch.stack(policy.get_tower_stats("choice_loss"))),
        "raw_scores": torch.mean(torch.stack(policy.get_tower_stats("raw_scores"))),
        "q_values": torch.mean(torch.stack(policy.get_tower_stats("q_values"))),
        "next_q_values": torch.mean(
            torch.stack(policy.get_tower_stats("next_q_values"))
        ),
        "next_q_minus_q": torch.mean(
            torch.stack(policy.get_tower_stats("next_q_minus_q"))
        ),
        "target_q_values": torch.mean(
            torch.stack(policy.get_tower_stats("target_q_values"))
        ),
        "td_error": torch.mean(torch.stack(policy.get_tower_stats("td_error"))),
    }
    # Also report the mean of every trainable model variable.
    model_stats = {
        k: torch.mean(var)
        for k, var in policy.model.trainable_variables(as_dict=True).items()
    }
    stats.update(model_stats)
    return stats


def build_slateq_optimizers(
    policy: Policy, config: TrainerConfigDict
) -> List["torch.optim.Optimizer"]:
    optimizer_choice = torch.optim.Adam(
        policy.model.choice_model.parameters(), lr=config["lr_choice_model"]
    )
    optimizer_q_value = torch.optim.Adam(
        policy.model.q_model.parameters(),
        lr=config["lr_q_model"],
        eps=config["adam_epsilon"],
    )
    return [optimizer_choice, optimizer_q_value]


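# Note: TorchPolicy pairs the i-th loss returned by the loss function with the
# i-th optimizer, so the choice loss from `build_slateq_losses` is minimized by
# `optimizer_choice` and the Q-value loss by `optimizer_q_value`.

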
def action_distribution_fn(
    policy: Policy, model: SlateQModel, input_dict, *, explore, is_training, **kwargs
):
    """Computes per-slate Q-values as inputs to the action distribution."""
    # First, we transform the observation into its unflattened form.
    obs = restore_original_dimensions(
        input_dict[SampleBatch.CUR_OBS], policy.observation_space, tensorlib=torch
    )

    # user.shape: [batch_size(=1), embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size(=1), num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)

    _, _, per_slate_q_values = model.choose_slate(user, doc)

    return per_slate_q_values, TorchCategorical, []


def postprocess_fn_add_next_actions_for_sarsa(
    policy: Policy, batch: SampleBatch, other_agent=None, episode=None
) -> SampleBatch:
    """Add next_actions to SampleBatch for SARSA training"""
    if policy.config["slateq_strategy"] == "SARSA":
        if not batch["dones"][-1] and policy._no_tracing is False:
            raise RuntimeError(
                "Expected a complete episode in each sample batch. "
                f"But this batch is not: {batch}."
            )
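        # Shift the actions column by -1 so that next_actions[t] == actions[t + 1].
        # The last row wraps around to actions[0], which is harmless: the batch
        # holds a complete episode, so the final step is terminal and its
        # next-state value is zeroed out in the loss.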
batch["next_actions"] = np.roll(batch["actions"], -1, axis=0)
|
2022-02-04 17:01:12 +01:00
|
|
|
|
2020-11-03 00:52:04 -08:00
|
|
|
return batch
|
|
|
|
|
|
|
|
|
2022-02-04 17:01:12 +01:00
|
|
|
def setup_late_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call all mixin classes' constructors before SlateQTorchPolicy initialization.

    Args:
        policy: The Policy object.
        obs_space: The Policy's observation space.
        action_space: The Policy's action space.
        config: The Policy's config.
    """
    TargetNetworkMixin.__init__(policy)


SlateQTorchPolicy = build_policy_class(
    name="SlateQTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,
    after_init=setup_late_mixins,
    loss_fn=build_slateq_losses,
    stats_fn=build_slateq_stats,
    # Build model, loss functions, and optimizers
    make_model_and_action_dist=build_slateq_model_and_distribution,
    optimizer_fn=build_slateq_optimizers,
    # Define how to act.
    action_distribution_fn=action_distribution_fn,
    # Post processing sampled trajectory data.
    postprocess_fn=postprocess_fn_add_next_actions_for_sarsa,
    extra_grad_process_fn=apply_grad_clipping,
    mixins=[TargetNetworkMixin],
)
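
# Usage sketch (illustrative only; assumes a RecSim-style recommender env has
# been registered under the hypothetical name "my-recsim-env"):
#
#   from ray.rllib.agents.slateq import SlateQTrainer
#
#   trainer = SlateQTrainer(
#       env="my-recsim-env",
#       config={"framework": "torch", "slateq_strategy": "QL"},
#   )
#   print(trainer.train())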