[RLlib] SlateQ (tf GPU + multi-GPU) + Bandit fixes (#23276)

Sven Mika 2022-03-18 13:45:16 +01:00 committed by GitHub
parent da140a80e9
commit b1cda46681
9 changed files with 133 additions and 78 deletions

@@ -545,6 +545,16 @@ py_test(
args = ["--yaml-dir=tuned_examples/slateq"]
)
py_test(
name = "learning_tests_interest_evolution_10_candidates_recsim_env_slateq_fake_gpus",
main = "tests/run_regression_tests.py",
tags = ["team:ml", "learning_tests", "learning_tests_discrete", "fake_gpus"],
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/slateq/interest-evolution-10-candidates-recsim-env-slateq.yaml"],
args = ["--yaml-dir=tuned_examples/slateq"]
)
# TD3
py_test(
name = "learning_tests_pendulum_td3",

@@ -14,18 +14,25 @@ class OnlineLinearRegression(nn.Module):
super(OnlineLinearRegression, self).__init__()
self.d = feature_dim
self.alpha = alpha
# Diagonal matrix of size d (feature_dim).
# If lambda=1.0, this will be an identity matrix.
self.precision = nn.Parameter(
data=lambda_ * torch.eye(self.d), requires_grad=False
)
# Inverse of the above diagonal matrix. If lambda=1.0, this is also an
# identity matrix.
self.covariance = nn.Parameter(
data=torch.inverse(self.precision), requires_grad=False
)
# All-0s vector of size d (feature_dim).
self.f = nn.Parameter(
data=torch.zeros(
self.d,
),
requires_grad=False,
)
self.covariance = nn.Parameter(
data=torch.inverse(self.precision), requires_grad=False
)
# Matrix-vector product of the covariance matrix and f; this is the
# current (posterior) mean estimate theta.
self.theta = nn.Parameter(
data=self.covariance.matmul(self.f), requires_grad=False
)
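For orientation, these four buffers are the usual sufficient statistics of Bayesian linear regression as used by LinUCB: `precision` plays the role of A, `f` of b, `covariance` of A^-1, and `theta` of the posterior mean A^-1 b. A minimal sketch of the rank-one update such buffers support is below; the `linucb_update` helper is illustrative only, not part of this diff.

import torch

def linucb_update(precision, f, x, y):
    # Rank-one update of the Bayesian linear-regression statistics:
    #   A <- A + x x^T,   b <- b + y * x,
    #   covariance = A^-1, theta = A^-1 b.
    precision = precision + torch.outer(x, x)
    f = f + y * x
    covariance = torch.inverse(precision)
    theta = covariance.matmul(f)
    return precision, f, covariance, theta

d = 4
precision = torch.eye(d)   # lambda_ = 1.0 -> identity, as in the module above
f = torch.zeros(d)
x, y = torch.randn(d), torch.tensor(0.5)
precision, f, covariance, theta = linucb_update(precision, f, x, y)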

@@ -23,7 +23,11 @@ from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.train_ops import (
MultiGPUTrainOneStep,
TrainOneStep,
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
@@ -150,14 +154,6 @@ class SlateQTrainer(Trainer):
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
if config["framework"] == "torch":
@@ -183,12 +179,23 @@ class SlateQTrainer(Trainer):
StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"])
)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to the train op
# (single- or multi-GPU) to take an SGD step.
replay_op = (
Replay(local_buffer=kwargs["local_replay_buffer"])
.for_each(TrainOneStep(workers))
.for_each(train_step_op)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
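As a usage note, this branch mirrors other RLlib trainers: `simple_optimizer: True` keeps the single-process `TrainOneStep`, while setting it to False shards each train batch across GPU towers via `MultiGPUTrainOneStep`. A hedged sketch of enabling (fake) multi-GPU SlateQ from Python follows, assuming the import paths of the Ray version this commit targets and an installed `recsim` package.

import ray
from ray.rllib.agents.slateq import SlateQTrainer
from ray.rllib.examples.env.recommender_system_envs_with_recsim import (
    InterestEvolutionRecSimEnv,
)

ray.init()
trainer = SlateQTrainer(
    config={
        "env": InterestEvolutionRecSimEnv,
        "framework": "tf",
        # With simple_optimizer=False, the MultiGPUTrainOneStep branch
        # above is used and the batch is split across `num_gpus` towers.
        "simple_optimizer": False,
        "num_gpus": 2,
        "_fake_gpus": True,  # simulate 2 GPUs on a CPU-only machine
    }
)
result = trainer.train()
print(result["episode_reward_mean"])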

@@ -549,7 +549,7 @@ COMMON_CONFIG: TrainerConfigDict = {
"output_config": {},
# What sample batch columns to LZ4 compress in the output data.
"output_compress_columns": ["obs", "new_obs"],
# Max output file size before rolling over to a new file.
# Max output file size (in bytes) before rolling over to a new file.
"output_max_file_size": 64 * 1024 * 1024,
# === Settings for Multi-Agent Environments ===

@@ -2,6 +2,7 @@ from collections import namedtuple, OrderedDict
import gym
import logging
import re
import tree # pip install dm_tree
from typing import Callable, Dict, List, Optional, Tuple, Type
from ray.util.debug import log_once
@@ -458,18 +459,21 @@ class DynamicTFPolicy(TFPolicy):
def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> TFPolicy:
"""Creates a copy of self using existing input placeholders."""
flat_loss_inputs = tree.flatten(self._loss_input_dict)
flat_loss_inputs_no_rnn = tree.flatten(self._loss_input_dict_no_rnn)
# Note that there might be RNN state inputs at the end of the list
if len(self._loss_input_dict) != len(existing_inputs):
if len(flat_loss_inputs) != len(existing_inputs):
raise ValueError(
"Tensor list mismatch",
self._loss_input_dict,
self._state_inputs,
existing_inputs,
)
for i, (k, v) in enumerate(self._loss_input_dict_no_rnn.items()):
for i, v in enumerate(flat_loss_inputs_no_rnn):
if v.shape.as_list() != existing_inputs[i].shape.as_list():
raise ValueError(
"Tensor shape mismatch", i, k, v.shape, existing_inputs[i].shape
"Tensor shape mismatch", i, v.shape, existing_inputs[i].shape
)
# By convention, the loss inputs are followed by state inputs and then
# the seq len tensor.
@@ -478,15 +482,19 @@ class DynamicTFPolicy(TFPolicy):
rnn_inputs.append(
(
"state_in_{}".format(i),
existing_inputs[len(self._loss_input_dict_no_rnn) + i],
existing_inputs[len(flat_loss_inputs_no_rnn) + i],
)
)
if rnn_inputs:
rnn_inputs.append((SampleBatch.SEQ_LENS, existing_inputs[-1]))
existing_inputs_unflattened = tree.unflatten_as(
self._loss_input_dict_no_rnn,
existing_inputs[: len(flat_loss_inputs_no_rnn)],
)
input_dict = OrderedDict(
[("is_exploring", self._is_exploring), ("timestep", self._timestep)]
+ [
(k, existing_inputs[i])
(k, existing_inputs_unflattened[k])
for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
]
+ rnn_inputs
@@ -509,7 +517,7 @@ class DynamicTFPolicy(TFPolicy):
instance._loss_input_dict = input_dict
losses = instance._do_loss_init(SampleBatch(input_dict))
loss_inputs = [
(k, existing_inputs[i])
(k, existing_inputs_unflattened[k])
for i, k in enumerate(self._loss_input_dict_no_rnn.keys())
]
@@ -546,7 +554,7 @@ class DynamicTFPolicy(TFPolicy):
return len(batch)
input_dict = self._get_loss_inputs_dict(batch, shuffle=False)
data_keys = list(self._loss_input_dict_no_rnn.values())
data_keys = tree.flatten(self._loss_input_dict_no_rnn)
if self._state_inputs:
state_keys = self._state_inputs + [self._seq_lens]
else:
@@ -927,7 +935,7 @@ class TFMultiGPUTowerStack:
"sgd_minibatch_size", policy.config.get("train_batch_size", 999999)
)
) // len(self.devices)
input_placeholders = list(self.policy._loss_input_dict_no_rnn.values())
input_placeholders = tree.flatten(self.policy._loss_input_dict_no_rnn)
rnn_inputs = []
if self.policy._state_inputs:
rnn_inputs = self.policy._state_inputs + [self.policy._seq_lens]
@@ -954,19 +962,22 @@ class TFMultiGPUTowerStack:
self._max_seq_len = tf1.placeholder(tf.int32, name="max_seq_len")
self._loaded_max_seq_len = 1
# Split on the CPU in case the data doesn't fit in GPU memory.
with tf.device("/cpu:0"):
data_splits = zip(
*[tf.split(ph, len(self.devices)) for ph in self.loss_inputs]
)
device_placeholders = [[] for _ in range(len(self.devices))]
for t in tree.flatten(self.loss_inputs):
# Split on the CPU in case the data doesn't fit in GPU memory.
with tf.device("/cpu:0"):
splits = tf.split(t, len(self.devices))
for i, d in enumerate(self.devices):
device_placeholders[i].append(splits[i])
self._towers = []
for tower_i, (device, device_placeholders) in enumerate(
zip(self.devices, data_splits)
for tower_i, (device, placeholders) in enumerate(
zip(self.devices, device_placeholders)
):
self._towers.append(
self._setup_device(
tower_i, device, device_placeholders, len(input_placeholders)
tower_i, device, placeholders, len(tree.flatten(input_placeholders))
)
)
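The recurring change in this file is the switch from `list(self._loss_input_dict_no_rnn.values())` to `tree.flatten(...)` plus `tree.unflatten_as(...)`, which keeps nested loss-input structures aligned with the flat list of existing placeholders. A self-contained sketch of that round trip with `dm-tree` follows; the dict and its values are invented purely for illustration.

import tree  # pip install dm_tree

# A nested "loss input dict" stand-in (values invented for illustration).
loss_inputs = {"obs": {"cam": 1, "state": 2}, "actions": 3}

# Deterministic flat ordering (sorted by key path), used above to match
# flat placeholders one-to-one.
flat = tree.flatten(loss_inputs)          # [3, 1, 2]

# Re-attach a flat list of replacements to the original structure.
replacements = [x * 10 for x in flat]
unflattened = tree.unflatten_as(loss_inputs, replacements)
# {'actions': 30, 'obs': {'cam': 10, 'state': 20}}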

@@ -2,13 +2,11 @@ interest-evolution-recsim-env-bandit-linucb:
env: ray.rllib.examples.env.recommender_system_envs_with_recsim.InterestEvolutionRecSimEnv
run: BanditLinUCB
stop:
episode_reward_mean: 170.0
timesteps_total: 1000
episode_reward_mean: 180.0
timesteps_total: 50000
config:
framework: torch
metrics_num_episodes_for_smoothing: 100
# RLlib/RecSim wrapper specific settings:
env_config:
# Env class specified above takes one `config` arg in its c'tor:
@@ -17,7 +15,7 @@ interest-evolution-recsim-env-bandit-linucb:
# document sampler model (a logic that creates n documents to select
# the slate from).
resample_documents: true
num_candidates: 10
num_candidates: 100
# How many documents to recommend (out of `num_candidates`) each
# timestep?
slate_size: 2
@@ -27,3 +25,5 @@ interest-evolution-recsim-env-bandit-linucb:
convert_to_discrete_action_space: true
wrap_for_bandits: true
seed: 0
metrics_num_episodes_for_smoothing: 500
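Tying this tuned example back to the bandit fix above: LinUCB scores each of the `num_candidates` arms by its predicted reward plus an `alpha`-scaled uncertainty bonus derived from the covariance matrix. A minimal torch sketch of that scoring rule follows (standalone and illustrative, reusing the statistics from the OnlineLinearRegression sketch earlier).

import torch

def linucb_scores(theta, covariance, X, alpha=1.0):
    # X: [num_candidates, d] feature matrix, one row per candidate arm.
    mean = X.matmul(theta)                                # predicted reward
    # Uncertainty bonus: alpha * sqrt(x^T Sigma x) per arm.
    bonus = alpha * torch.sqrt((X.matmul(covariance) * X).sum(dim=1))
    return mean + bonus  # the bandit picks the argmax over candidates

d, num_candidates = 4, 100
theta = torch.zeros(d)
covariance = torch.eye(d)
X = torch.randn(num_candidates, d)
best_arm = torch.argmax(linucb_scores(theta, covariance, X, alpha=1.0))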

@@ -0,0 +1,46 @@
interest-evolution-recsim-env-slateq:
env: ray.rllib.examples.env.recommender_system_envs_with_recsim.InterestEvolutionRecSimEnv
run: SlateQ
stop:
episode_reward_mean: 160.0
timesteps_total: 100000
config:
framework: tf
# RLlib/RecSim wrapper specific settings:
env_config:
# Env class specified above takes one `config` arg in its c'tor:
config:
# Each step, sample `num_candidates` documents using the env-internal
# document sampler model (a logic that creates n documents to select
# the slate from).
resample_documents: true
num_candidates: 10
# How many documents to recommend (out of `num_candidates`) each
# timestep?
slate_size: 2
# Should the action space be purely Discrete? Useful for algos that
# don't support MultiDiscrete (e.g. DQN or Bandits).
# SlateQ handles MultiDiscrete action spaces.
convert_to_discrete_action_space: false
seed: 0
# Fake 2 GPUs.
num_gpus: 2
_fake_gpus: true
exploration_config:
warmup_timesteps: 10000
epsilon_timesteps: 25000
replay_buffer_config:
capacity: 100000
# Double learning rate and batch size.
lr: 0.002
train_batch_size: 64
learning_starts: 10000
target_network_update_freq: 3200
metrics_num_episodes_for_smoothing: 200

@@ -8,10 +8,10 @@ long-term-satisfaction-recsim-env-slateq:
evaluation/episode_reward_mean: 1000.0
timesteps_total: 200000
config:
# SlateQ only supported for torch so far.
framework: torch
# Works for both tf and torch.
framework: tf
metrics_num_episodes_for_smoothing: 1000
metrics_num_episodes_for_smoothing: 200
# RLlib/RecSim wrapper specific settings:
env_config:
@@ -31,23 +31,7 @@ long-term-satisfaction-recsim-env-slateq:
seed: 42
exploration_config:
type: SlateSoftQ
temperature: 0.7
warmup_timesteps: 10000
epsilon_timesteps: 60000
hiddens: [256, 256]
num_workers: 0
num_gpus: 0
lr_choice_model: 0.002
lr_q_model: 0.001
target_network_update_freq: 800
tau: 1.0
# Evaluation settings.
evaluation_interval: 1
evaluation_num_workers: 4
evaluation_duration: 200
evaluation_duration_unit: episodes
evaluation_parallel_to_training: true
target_network_update_freq: 3200
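As a reference point for the `SlateSoftQ` exploration kept in this config (type and `temperature` above): a soft-Q choice over slates amounts to sampling a slate index from a softmax of the per-slate Q-values divided by the temperature. A standalone torch sketch of that idea follows; it is illustrative only, not the RLlib class itself.

import torch

def soft_q_slate_sample(per_slate_q_values, all_slates, temperature=0.7):
    # per_slate_q_values: [batch, num_slates];
    # all_slates: [num_slates, slate_size].
    probs = torch.softmax(per_slate_q_values / temperature, dim=1)
    # Sample one slate index per batch row, then look up the slate.
    idx = torch.multinomial(probs, num_samples=1).squeeze(1)
    return all_slates[idx]

q = torch.randn(4, 10)                   # batch of 4, 10 candidate slates
slates = torch.randint(0, 10, (10, 2))   # 10 slates of size 2
print(soft_q_slate_sample(q, slates).shape)  # torch.Size([4, 2])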

@@ -5,7 +5,6 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.epsilon_greedy import EpsilonGreedy
from ray.rllib.utils.exploration.exploration import TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -25,23 +24,22 @@ class SlateEpsilonGreedy(EpsilonGreedy):
exploit_action = action_distribution.deterministic_sample()
batch_size = tf.shape(per_slate_q_values)[0]
batch_size, num_slates = (
tf.shape(per_slate_q_values)[0],
tf.shape(per_slate_q_values)[1],
)
action_logp = tf.zeros(batch_size, dtype=tf.float32)
# Get the current epsilon.
epsilon = self.epsilon_schedule(
timestep if timestep is not None else self.last_timestep
)
# Mask out actions, whose Q-values are -inf, so that we don't
# even consider them for exploration.
random_valid_action_logits = tf.where(
tf.equal(per_slate_q_values, tf.float32.min),
tf.ones_like(per_slate_q_values) * tf.float32.min,
tf.ones_like(per_slate_q_values),
)
# A random action.
random_indices = tf.squeeze(
tf.random.categorical(random_valid_action_logits, 1), axis=1
random_indices = tf.random.uniform(
(batch_size,),
minval=0,
maxval=num_slates,
dtype=tf.dtypes.int32,
)
random_actions = tf.gather(all_slates, random_indices)
@@ -92,16 +90,9 @@ class SlateEpsilonGreedy(EpsilonGreedy):
if explore:
# Get the current epsilon.
epsilon = self.epsilon_schedule(self.last_timestep)
# Mask out actions, whose Q-values are -inf, so that we don't
# even consider them for exploration.
random_valid_action_logits = torch.where(
per_slate_q_values <= FLOAT_MIN,
torch.ones_like(per_slate_q_values) * 0.0,
torch.ones_like(per_slate_q_values),
)
# A random action.
random_indices = torch.squeeze(
torch.multinomial(random_valid_action_logits, 1), axis=1
random_indices = torch.randint(
0, per_slate_q_values.shape[1], (per_slate_q_values.shape[0],)
)
random_actions = all_slates[random_indices]
@@ -111,7 +102,6 @@ class SlateEpsilonGreedy(EpsilonGreedy):
random_actions,
exploit_action,
)
return action, action_logp
# Return the deterministic "sample" (argmax) over the logits.
else:
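After this change, the torch branch mirrors the tf branch above: with probability epsilon a slate index is drawn uniformly over all enumerated slates (replacing the previous masked-multinomial sampling), otherwise the argmax slate is returned. A compact standalone sketch of that behavior is below; it is illustrative only, not the RLlib class.

import torch

def slate_epsilon_greedy(per_slate_q_values, all_slates, epsilon):
    # per_slate_q_values: [batch, num_slates];
    # all_slates: [num_slates, slate_size].
    batch_size, num_slates = per_slate_q_values.shape
    exploit = all_slates[per_slate_q_values.argmax(dim=1)]
    random_idx = torch.randint(0, num_slates, (batch_size,))
    explore = all_slates[random_idx]
    # Per-row coin flip: explore with probability epsilon.
    coin = torch.rand(batch_size) < epsilon
    return torch.where(coin.unsqueeze(1), explore, exploit)

q = torch.randn(4, 10)
slates = torch.randint(0, 5, (10, 2))
print(slate_epsilon_greedy(q, slates, epsilon=0.1).shape)  # torch.Size([4, 2])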