[rllib] Shuffle RNN sequences in PPO as well (#5129)

* shuffle seq

* fix test
Eric Liang 2019-07-06 20:40:49 -07:00 committed by GitHub
parent c04b69902c
commit c15ed3ac55
5 changed files with 48 additions and 12 deletions


@ -27,6 +27,8 @@ DEFAULT_CONFIG = with_common_config({
"train_batch_size": 4000,
# Total SGD batch size across all devices for SGD
"sgd_minibatch_size": 128,
+# Whether to shuffle sequences in the batch when training (recommended)
+"shuffle_sequences": True,
# Number of SGD iterations in each outer loop
"num_sgd_iter": 30,
# Stepsize of SGD
@ -79,7 +81,8 @@ def choose_policy_optimizer(workers, config):
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
standardize_fields=["advantages"],
-straggler_mitigation=config["straggler_mitigation"])
+straggler_mitigation=config["straggler_mitigation"],
+shuffle_sequences=config["shuffle_sequences"])
def update_kl(trainer, fetches):
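For orientation, a minimal usage sketch of the option added above. The import path, ray.init() call, environment name, and LSTM model setting are assumptions for the example, not part of this diff:

# Hypothetical usage sketch only: shows where "shuffle_sequences" sits in a
# user-supplied PPO config. Environment and model settings are made up.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "model": {"use_lstm": True},  # RNN policy; its sequences now get shuffled too
        "shuffle_sequences": True,    # new option; defaults to True
        "sgd_minibatch_size": 128,
        "num_sgd_iter": 30,
    })
trainer.train()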


@ -128,6 +128,7 @@ def chop_into_sequences(episode_ids,
state_columns,
max_seq_len,
dynamic_max=True,
+shuffle=False,
_extra_padding=0):
"""Truncate and pad experiences into fixed-length sequences.
@ -143,6 +144,7 @@ def chop_into_sequences(episode_ids,
dynamic_max (bool): Whether to dynamically shrink the max seq len.
For example, if max len is 20 and the actual max seq len in the
data is 7, it will be shrunk to 7.
+shuffle (bool): Whether to shuffle the sequence outputs.
_extra_padding (int): Add extra padding to the end of sequences.
Returns:
@ -186,6 +188,7 @@ def chop_into_sequences(episode_ids,
if seq_len:
seq_lens.append(seq_len)
assert sum(seq_lens) == len(unique_ids)
+seq_lens = np.array(seq_lens)
# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
@ -215,4 +218,17 @@ def chop_into_sequences(episode_ids,
i += l
initial_states.append(np.array(s_init))
-return feature_sequences, initial_states, np.array(seq_lens)
+if shuffle:
+    permutation = np.random.permutation(len(seq_lens))
+    for i, f in enumerate(feature_sequences):
+        orig_shape = f.shape
+        f = np.reshape(f, (len(seq_lens), -1) + f.shape[2:])
+        f = f[permutation]
+        f = np.reshape(f, orig_shape)
+        feature_sequences[i] = f
+    for i, s in enumerate(initial_states):
+        s = s[permutation]
+        initial_states[i] = s
+    seq_lens = seq_lens[permutation]
+return feature_sequences, initial_states, seq_lens
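The new branch above can be read as: view each flat feature column sequence-major, permute whole padded sequences, and apply the same permutation to the initial states and seq_lens. A standalone numpy sketch of that pattern, with made-up shapes:

import numpy as np

# Standalone illustration of the shuffle above; shapes are invented.
# 4 padded sequences of max length 3, each timestep a 2-dim feature.
num_seqs, max_len = 4, 3
features = np.arange(num_seqs * max_len * 2).reshape(num_seqs * max_len, 2)
seq_lens = np.array([3, 1, 2, 3])
h_init = np.zeros((num_seqs, 5))  # one initial RNN state row per sequence

permutation = np.random.permutation(len(seq_lens))

# Reshape to (num_seqs, max_len * feature_dim), permute rows (whole sequences),
# then restore the flat shape. Timestep order inside a sequence is untouched,
# so per-sequence initial states stay aligned with their timesteps.
orig_shape = features.shape
f = np.reshape(features, (len(seq_lens), -1) + features.shape[2:])
features = np.reshape(f[permutation], orig_shape)

h_init = h_init[permutation]
seq_lens = seq_lens[permutation]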


@ -50,7 +50,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
train_batch_size=1024,
num_gpus=0,
standardize_fields=[],
-straggler_mitigation=False):
+straggler_mitigation=False,
+shuffle_sequences=True):
PolicyOptimizer.__init__(self, workers)
self.batch_size = sgd_batch_size
@ -59,6 +60,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.sample_batch_size = sample_batch_size
self.train_batch_size = train_batch_size
self.straggler_mitigation = straggler_mitigation
+self.shuffle_sequences = shuffle_sequences
if not num_gpus:
self.devices = ["/cpu:0"]
else:
@ -157,10 +159,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
standardized = (value - value.mean()) / max(1e-4, value.std())
batch[field] = standardized
-# Important: don't shuffle RNN sequence elements
-if not policy._state_inputs:
-    batch.shuffle()
num_loaded_tuples = {}
with self.load_timer:
for policy_id, batch in samples.policy_batches.items():
@ -168,7 +166,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
continue
policy = self.policies[policy_id]
-tuples = policy._get_loss_inputs_dict(batch)
+tuples = policy._get_loss_inputs_dict(
+    batch, shuffle=self.shuffle_sequences)
data_keys = [ph for _, ph in policy._loss_inputs]
if policy._state_inputs:
state_keys = policy._state_inputs + [policy._seq_lens]
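One way to see why shuffling is now safe to delegate to the policy: once the data is laid out sequence-major, downstream minibatches can be taken as contiguous slices, so permuting whole sequences decorrelates minibatches without ever splitting an RNN sequence. A purely illustrative loop, not the optimizer's actual code; names and shapes are assumptions:

import numpy as np

# Purely illustrative, not RLlib code: contiguous minibatch slices over
# sequence-major, already-shuffled data never cut a sequence in half as long
# as the slice size is a multiple of the padded max sequence length.
num_seqs, max_len, feat = 8, 4, 3
features = np.random.randn(num_seqs * max_len, feat)  # padded, sequence-major
seq_lens = np.random.randint(1, max_len + 1, size=num_seqs)

seqs_per_minibatch = 2
for start in range(0, num_seqs, seqs_per_minibatch):
    rows = slice(start * max_len, (start + seqs_per_minibatch) * max_len)
    minibatch_features = features[rows]
    minibatch_seq_lens = seq_lens[start:start + seqs_per_minibatch]
    # ...feed minibatch_features / minibatch_seq_lens into the RNN loss here...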


@ -438,7 +438,8 @@ class TFPolicy(Policy):
def _build_compute_gradients(self, builder, postprocessed_batch):
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
-builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
+builder.add_feed_dict(
+    self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
fetches = builder.add_fetches(
[self._grads, self._get_grad_and_stats_fetches()])
return fetches[0], fetches[1]
@ -455,7 +456,8 @@ class TFPolicy(Policy):
def _build_learn_on_batch(self, builder, postprocessed_batch):
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
-builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
+builder.add_feed_dict(
+    self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
builder.add_feed_dict({self._is_training: True})
fetches = builder.add_fetches([
self._apply_op,
@ -473,7 +475,19 @@ class TFPolicy(Policy):
**fetches[LEARNER_STATS_KEY])
return fetches
-def _get_loss_inputs_dict(self, batch):
+def _get_loss_inputs_dict(self, batch, shuffle):
+    """Return a feed dict from a batch.
+
+    Arguments:
+        batch (SampleBatch): batch of data to derive inputs from
+        shuffle (bool): whether to shuffle batch sequences. Shuffle may
+            be done in-place. This only makes sense if you're further
+            applying minibatch SGD after getting the outputs.
+
+    Returns:
+        feed dict of data
+    """
feed_dict = {}
if self._batch_divisibility_req > 1:
meets_divisibility_reqs = (
@ -485,6 +499,8 @@ class TFPolicy(Policy):
# Simple case: not RNN nor do we need to pad
if not self._state_inputs and meets_divisibility_reqs:
+if shuffle:
+    batch.shuffle()
for k, ph in self._loss_inputs:
feed_dict[ph] = batch[k]
return feed_dict
@ -507,7 +523,8 @@ class TFPolicy(Policy):
batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
[batch[k] for k in state_keys],
max_seq_len,
-dynamic_max=dynamic_max)
+dynamic_max=dynamic_max,
+shuffle=shuffle)
for k, v in zip(feature_keys, feature_sequences):
feed_dict[self._loss_input_dict[k]] = v
for k, v in zip(state_keys, initial_states):
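For the simple (non-RNN) branch above, batch.shuffle() amounts to permuting every column of the sample batch with one shared row permutation, so each timestep's fields stay aligned. A dict-of-arrays sketch of that idea (illustrative only, not SampleBatch's implementation):

import numpy as np

# Illustrative only: a row-aligned shuffle over a columnar batch. Every column
# is indexed with the same permutation, so obs/action/advantage for a given
# timestep remain matched after shuffling.
batch = {
    "obs": np.arange(10).reshape(5, 2),
    "actions": np.arange(5),
    "advantages": np.linspace(0.0, 1.0, 5),
}
perm = np.random.permutation(5)
shuffled = {key: column[perm] for key, column in batch.items()}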


@ -229,6 +229,7 @@ class RNNSequencing(unittest.TestCase):
ppo = PPOTrainer(
env="counter",
config={
"shuffle_sequences": False, # for deterministic testing
"num_workers": 0,
"sample_batch_size": 20,
"train_batch_size": 20,