[RLlib] SlateQ (tf GPU + multi-GPU) + Bandit fixes (#23276)
This commit is contained in:
parent da140a80e9
commit b1cda46681
9 changed files with 133 additions and 78 deletions
rllib/BUILD (10 changes)
@@ -545,6 +545,16 @@ py_test(
     args = ["--yaml-dir=tuned_examples/slateq"]
 )
 
+py_test(
+    name = "learning_tests_interest_evolution_10_candidates_recsim_env_slateq_fake_gpus",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:ml", "learning_tests", "learning_tests_discrete", "fake_gpus"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/slateq/interest-evolution-10-candidates-recsim-env-slateq.yaml"],
+    args = ["--yaml-dir=tuned_examples/slateq"]
+)
+
 # TD3
 py_test(
     name = "learning_tests_pendulum_td3",

@@ -14,18 +14,25 @@ class OnlineLinearRegression(nn.Module):
         super(OnlineLinearRegression, self).__init__()
         self.d = feature_dim
         self.alpha = alpha
+        # Diagonal matrix of size d (feature_dim).
+        # If lambda=1.0, this will be an identity matrix.
         self.precision = nn.Parameter(
             data=lambda_ * torch.eye(self.d), requires_grad=False
         )
+        # Inverse of the above diagonal. If lambda=1.0, this is also an
+        # identity matrix.
+        self.covariance = nn.Parameter(
+            data=torch.inverse(self.precision), requires_grad=False
+        )
+        # All-0s vector of size d (feature_dim).
         self.f = nn.Parameter(
             data=torch.zeros(
                 self.d,
             ),
             requires_grad=False,
         )
-        self.covariance = nn.Parameter(
-            data=torch.inverse(self.precision), requires_grad=False
-        )
+        # Dot product between f and covariance matrix
+        # (batch dim stays intact; reduce dim 1).
         self.theta = nn.Parameter(
             data=self.covariance.matmul(self.f), requires_grad=False
         )

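For context, the buffers built in the constructor above (precision, its inverse "covariance", the reward-weighted feature sum f, and theta) are the state of a standard online linear regression as used by LinUCB. A minimal sketch of the rank-1 update these buffers support is shown below; it is illustrative only (the helper name `linucb_update` and the plain-tensor signature are not part of this diff).

import torch

def linucb_update(precision, f, x, y):
    # x: [d] feature vector of the chosen arm, y: scalar reward.
    precision = precision + torch.outer(x, x)   # precision += x x^T
    f = f + y * x                                # accumulate reward-weighted features
    covariance = torch.inverse(precision)        # covariance = precision^-1
    theta = covariance.matmul(f)                 # ridge-regression weight estimate
    return precision, f, covariance, theta
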
@@ -23,7 +23,11 @@ from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts
-from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
+from ray.rllib.execution.train_ops import (
+    MultiGPUTrainOneStep,
+    TrainOneStep,
+    UpdateTargetNetwork,
+)
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE

@@ -150,14 +154,6 @@ class SlateQTrainer(Trainer):
     def get_default_config(cls) -> TrainerConfigDict:
         return DEFAULT_CONFIG
 
-    @override(Trainer)
-    def validate_config(self, config: TrainerConfigDict) -> None:
-        # Call super's validation method.
-        super().validate_config(config)
-
-        if config["num_gpus"] > 1:
-            raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
-
     @override(Trainer)
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
         if config["framework"] == "torch":

@@ -183,12 +179,23 @@ class SlateQTrainer(Trainer):
             StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"])
         )
 
+        if config["simple_optimizer"]:
+            train_step_op = TrainOneStep(workers)
+        else:
+            train_step_op = MultiGPUTrainOneStep(
+                workers=workers,
+                sgd_minibatch_size=config["train_batch_size"],
+                num_sgd_iter=1,
+                num_gpus=config["num_gpus"],
+                _fake_gpus=config["_fake_gpus"],
+            )
+
         # (2) Read and train on experiences from the replay buffer. Every batch
         # returned from the LocalReplay() iterator is passed to TrainOneStep to
         # take a SGD step.
         replay_op = (
             Replay(local_buffer=kwargs["local_replay_buffer"])
-            .for_each(TrainOneStep(workers))
+            .for_each(train_step_op)
             .for_each(
                 UpdateTargetNetwork(workers, config["target_network_update_freq"])
             )

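The execution-plan change above is what the new fake-GPU regression tests exercise: with `simple_optimizer=False`, training goes through `MultiGPUTrainOneStep`, which also runs on CPU-only machines when `_fake_gpus` is set. A minimal sketch of such a run follows (import paths assumed from this Ray version's layout; not a snippet from this PR).

import ray
from ray.rllib.agents.slateq import SlateQTrainer
from ray.rllib.examples.env.recommender_system_envs_with_recsim import (
    InterestEvolutionRecSimEnv,
)

ray.init()
trainer = SlateQTrainer(
    env=InterestEvolutionRecSimEnv,
    config={
        "framework": "tf",
        # Two fake GPUs: the train batch is split across two tower copies
        # that actually live on the CPU.
        "num_gpus": 2,
        "_fake_gpus": True,
        # Route training through MultiGPUTrainOneStep instead of TrainOneStep.
        "simple_optimizer": False,
    },
)
for _ in range(3):
    print(trainer.train()["episode_reward_mean"])
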
@@ -549,7 +549,7 @@ COMMON_CONFIG: TrainerConfigDict = {
     "output_config": {},
     # What sample batch columns to LZ4 compress in the output data.
     "output_compress_columns": ["obs", "new_obs"],
-    # Max output file size before rolling over to a new file.
+    # Max output file size (in bytes) before rolling over to a new file.
     "output_max_file_size": 64 * 1024 * 1024,
 
     # === Settings for Multi-Agent Environments ===

@@ -2,6 +2,7 @@ from collections import namedtuple, OrderedDict
 import gym
 import logging
 import re
+import tree  # pip install dm_tree
 from typing import Callable, Dict, List, Optional, Tuple, Type
 
 from ray.util.debug import log_once

@@ -458,18 +459,21 @@ class DynamicTFPolicy(TFPolicy):
     def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
         """Creates a copy of self using existing input placeholders."""
 
+        flat_loss_inputs = tree.flatten(self._loss_input_dict)
+        flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)
+
         # Note that there might be RNN state inputs at the end of the list
-        if len(self._loss_input_dict) != len(existing_inputs):
+        if len(flat_loss_inputs) != len(existing_inputs):
             raise ValueError(
                 "Tensor list mismatch",
                 self._loss_input_dict,
                 self._state_inputs,
                 existing_inputs,
             )
-        for i, (k, v) in enumerate(self._loss_input_dict_no_rnn.items()):
+        for i, v in enumerate(flat_loss_inputs_no_rnn):
             if v.shape.as_list() != existing_inputs[i].shape.as_list():
                 raise ValueError(
-                    "Tensor shape mismatch", i, k, v.shape, existing_inputs[i].shape
+                    "Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
                 )
         # By convention, the loss inputs are followed by state inputs and then
         # the seq len tensor.

@@ -478,15 +482,19 @@ class DynamicTFPolicy(TFPolicy):
             rnn_inputs.append(
                 (
                     "state_in_{}".format(i),
-                    existing_inputs[len(self._loss_input_dict_no_rnn) + i],
+                    existing_inputs[len(flat_loss_inputs_no_rnn) + i],
                 )
             )
         if rnn_inputs:
             rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
+        existing_inputs_unflattened = tree.unflatten_as(
+            self._loss_input_dict_no_rnn,
+            existing_inputs[: len(flat_loss_inputs_no_rnn)],
+        )
         input_dict = OrderedDict(
             [("is_exploring", self._is_exploring), ("timestep", self._timestep)]
             + [
-                (k, existing_inputs[i])
+                (k, existing_inputs_unflattened[k])
                 for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
             ]
             + rnn_inputs

@@ -509,7 +517,7 @@ class DynamicTFPolicy(TFPolicy):
         instance._loss_input_dict = input_dict
         losses = instance._do_loss_init(SampleBatch(input_dict))
         loss_inputs = [
-            (k, existing_inputs[i])
+            (k, existing_inputs_unflattened[k])
             for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
         ]
 
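The switch from `list(dict.values())` / `.items()` to `tree.flatten()` and `tree.unflatten_as()` in the copy() path above is what lets the loss-input dict hold nested structures (e.g. complex observation spaces) while still lining up with the flat list of existing placeholders. A small standalone illustration of the two dm_tree calls, on made-up data:

import tree  # pip install dm_tree

loss_input_dict = {
    "obs": {"cam": "obs_cam_ph", "pos": "obs_pos_ph"},
    "actions": "actions_ph",
}
# Flatten the nested structure into a deterministically ordered list of
# leaves (dict keys are traversed in sorted order).
flat = tree.flatten(loss_input_dict)
# -> ['actions_ph', 'obs_cam_ph', 'obs_pos_ph']

# Pack a parallel list of replacement leaves (e.g. existing placeholders)
# back into the same nested structure.
rebuilt = tree.unflatten_as(loss_input_dict, [leaf + "_copy" for leaf in flat])
# -> {'actions': 'actions_ph_copy',
#     'obs': {'cam': 'obs_cam_ph_copy', 'pos': 'obs_pos_ph_copy'}}
print(flat, rebuilt)
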
@@ -546,7 +554,7 @@ class DynamicTFPolicy(TFPolicy):
             return len(batch)
 
         input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
-        data_keys = list(self._loss_input_dict_no_rnn.values())
+        data_keys = tree.flatten(self._loss_input_dict_no_rnn)
         if self._state_inputs:
             state_keys = self._state_inputs + [self._seq_lens]
         else:

@@ -927,7 +935,7 @@ class TFMultiGPUTowerStack:
                 "sgd_minibatch_size", policy.config.get("train_batch_size", 999999)
             )
         ) // len(self.devices)
-        input_placeholders = list(self.policy._loss_input_dict_no_rnn.values())
+        input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn)
         rnn_inputs = []
         if self.policy._state_inputs:
             rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens]

@@ -954,19 +962,22 @@ class TFMultiGPUTowerStack:
         self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
         self._loaded_max_seq_len = 1
 
-        # Split on the CPU in case the data doesn't fit in GPU memory.
-        with tf.device("/cpu:0"):
-            data_splits = zip(
-                *[tf.split(ph, len(self.devices)) for ph in self.loss_inputs]
-            )
+        device_placeholders = [[] for _ in range(len(self.devices))]
+
+        for t in tree.flatten(self.loss_inputs):
+            # Split on the CPU in case the data doesn't fit in GPU memory.
+            with tf.device("/cpu:0"):
+                splits = tf.split(t, len(self.devices))
+            for i, d in enumerate(self.devices):
+                device_placeholders[i].append(splits[i])
 
         self._towers = []
-        for tower_i, (device, device_placeholders) in enumerate(
-            zip(self.devices, data_splits)
+        for tower_i, (device, placeholders) in enumerate(
+            zip(self.devices, device_placeholders)
         ):
             self._towers.append(
                 self._setup_device(
-                    tower_i, device, device_placeholders, len(input_placeholders)
+                    tower_i, device, placeholders, len(tree.flatten(input_placeholders))
                 )
             )

@@ -2,13 +2,11 @@ interest-evolution-recsim-env-bandit-linucb:
     env: ray.rllib.examples.env.recommender_system_envs_with_recsim.InterestEvolutionRecSimEnv
     run: BanditLinUCB
     stop:
-        episode_reward_mean: 170.0
-        timesteps_total: 1000
+        episode_reward_mean: 180.0
+        timesteps_total: 50000
     config:
         framework: torch
 
-        metrics_num_episodes_for_smoothing: 100
-
         # RLlib/RecSim wrapper specific settings:
         env_config:
             # Env class specified above takes one `config` arg in its c'tor:

@@ -17,7 +15,7 @@ interest-evolution-recsim-env-bandit-linucb:
                 # document sampler model (a logic that creates n documents to select
                 # the slate from).
                 resample_documents: true
-                num_candidates: 10
+                num_candidates: 100
                 # How many documents to recommend (out of `num_candidates`) each
                 # timestep?
                 slate_size: 2

@@ -27,3 +25,5 @@ interest-evolution-recsim-env-bandit-linucb:
                 convert_to_discrete_action_space: true
                 wrap_for_bandits: true
                 seed: 0
+
+        metrics_num_episodes_for_smoothing: 500

@@ -0,0 +1,46 @@
+interest-evolution-recsim-env-slateq:
+    env: ray.rllib.examples.env.recommender_system_envs_with_recsim.InterestEvolutionRecSimEnv
+    run: SlateQ
+    stop:
+        episode_reward_mean: 160.0
+        timesteps_total: 100000
+    config:
+        framework: tf
+
+        # RLlib/RecSim wrapper specific settings:
+        env_config:
+            # Env class specified above takes one `config` arg in its c'tor:
+            config:
+                # Each step, sample `num_candidates` documents using the env-internal
+                # document sampler model (a logic that creates n documents to select
+                # the slate from).
+                resample_documents: true
+                num_candidates: 10
+                # How many documents to recommend (out of `num_candidates`) each
+                # timestep?
+                slate_size: 2
+                # Should the action space be purely Discrete? Useful for algos that
+                # don't support MultiDiscrete (e.g. DQN or Bandits).
+                # SlateQ handles MultiDiscrete action spaces.
+                convert_to_discrete_action_space: false
+                seed: 0
+
+        # Fake 2 GPUs.
+        num_gpus: 2
+        _fake_gpus: true
+
+        exploration_config:
+            warmup_timesteps: 10000
+            epsilon_timesteps: 25000
+
+        replay_buffer_config:
+            capacity: 100000
+
+        # Double learning rate and batch size.
+        lr: 0.002
+        train_batch_size: 64
+
+        learning_starts: 10000
+        target_network_update_freq: 3200
+
+        metrics_num_episodes_for_smoothing: 200

@@ -8,10 +8,10 @@ long-term-satisfaction-recsim-env-slateq:
         evaluation/episode_reward_mean: 1000.0
         timesteps_total: 200000
     config:
-        # SlateQ only supported for torch so far.
-        framework: torch
+        # Works for both tf and torch.
+        framework: tf
 
-        metrics_num_episodes_for_smoothing: 1000
+        metrics_num_episodes_for_smoothing: 200
 
         # RLlib/RecSim wrapper specific settings:
         env_config:

@@ -31,23 +31,7 @@ long-term-satisfaction-recsim-env-slateq:
                 seed: 42
 
         exploration_config:
             type: SlateSoftQ
             temperature: 0.7
-            warmup_timesteps: 10000
-            epsilon_timesteps: 60000
-
-        hiddens: [256, 256]
-
-        num_workers: 0
-        num_gpus: 0
-        lr_choice_model: 0.002
-        lr_q_model: 0.001
-        target_network_update_freq: 800
-        tau: 1.0
-
-        # Evaluation settings.
-        evaluation_interval: 1
-        evaluation_num_workers: 4
-        evaluation_duration: 200
-        evaluation_duration_unit: episodes
-        evaluation_parallel_to_training: true
+
+        target_network_update_freq: 3200

@@ -5,7 +5,6 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
 from ray.rllib.utils.exploration.exploration import TensorType
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.torch_utils import FLOAT_MIN
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()

@@ -25,23 +24,22 @@ class SlateEpsilonGreedy(EpsilonGreedy):
 
         exploit_action = action_distribution.deterministic_sample()
 
-        batch_size = tf.shape(per_slate_q_values)[0]
+        batch_size, num_slates = (
+            tf.shape(per_slate_q_values)[0],
+            tf.shape(per_slate_q_values)[1],
+        )
         action_logp = tf.zeros(batch_size, dtype=tf.float32)
 
         # Get the current epsilon.
         epsilon = self.epsilon_schedule(
             timestep if timestep is not None else self.last_timestep
         )
-        # Mask out actions, whose Q-values are -inf, so that we don't
-        # even consider them for exploration.
-        random_valid_action_logits = tf.where(
-            tf.equal(per_slate_q_values, tf.float32.min),
-            tf.ones_like(per_slate_q_values) * tf.float32.min,
-            tf.ones_like(per_slate_q_values),
-        )
         # A random action.
-        random_indices = tf.squeeze(
-            tf.random.categorical(random_valid_action_logits, 1), axis=1
+        random_indices = tf.random.uniform(
+            (batch_size,),
+            minval=0,
+            maxval=num_slates,
+            dtype=tf.dtypes.int32,
         )
         random_actions = tf.gather(all_slates, random_indices)

@@ -92,16 +90,9 @@ class SlateEpsilonGreedy(EpsilonGreedy):
         if explore:
             # Get the current epsilon.
             epsilon = self.epsilon_schedule(self.last_timestep)
-            # Mask out actions, whose Q-values are -inf, so that we don't
-            # even consider them for exploration.
-            random_valid_action_logits = torch.where(
-                per_slate_q_values <= FLOAT_MIN,
-                torch.ones_like(per_slate_q_values) * 0.0,
-                torch.ones_like(per_slate_q_values),
-            )
             # A random action.
-            random_indices = torch.squeeze(
-                torch.multinomial(random_valid_action_logits, 1), axis=1
+            random_indices = torch.randint(
+                0, per_slate_q_values.shape[1], (per_slate_q_values.shape[0],)
             )
             random_actions = all_slates[random_indices]

@@ -111,7 +102,6 @@ class SlateEpsilonGreedy(EpsilonGreedy):
                 random_actions,
                 exploit_action,
             )
 
             return action, action_logp
-        # Return the deterministic "sample" (argmax) over the logits.
         else:

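The exploration fixes above replace masked categorical sampling with a uniform draw over all candidate slates, which is the intended epsilon-greedy behavior when every slate is a valid action. A minimal torch sketch of the resulting selection logic (illustrative only; the helper below is not RLlib API):

import torch

def epsilon_greedy_slates(per_slate_q_values, all_slates, epsilon):
    # per_slate_q_values: [batch, num_slates]; all_slates: [num_slates, slate_size].
    batch_size, num_slates = per_slate_q_values.shape
    # Exploit: slate with the highest Q-value per batch row.
    exploit_actions = all_slates[per_slate_q_values.argmax(dim=1)]
    # Explore: a slate index drawn uniformly at random per batch row.
    random_actions = all_slates[torch.randint(0, num_slates, (batch_size,))]
    # Mix the two with probability epsilon.
    use_random = torch.rand(batch_size) < epsilon
    return torch.where(use_random.unsqueeze(-1), random_actions, exploit_actions)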