[rllib] Shuffle RNN sequences in PPO as well (#5129)

* shuffle seq
* fix test
parent c04b69902c
commit c15ed3ac55

5 changed files with 48 additions and 12 deletions
@@ -27,6 +27,8 @@ DEFAULT_CONFIG = with_common_config({
     "train_batch_size": 4000,
     # Total SGD batch size across all devices for SGD
     "sgd_minibatch_size": 128,
+    # Whether to shuffle sequences in the batch when training (recommended)
+    "shuffle_sequences": True,
     # Number of SGD iterations in each outer loop
     "num_sgd_iter": 30,
     # Stepsize of SGD
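For orientation, a minimal usage sketch of the new flag from user code. This is not part of the commit: the PPOTrainer import path and the CartPole environment are assumed from the Ray/RLlib APIs of this era, and every config value other than "shuffle_sequences" is a placeholder.

    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(
        env="CartPole-v0",  # illustrative environment, not from this commit
        config={
            "num_workers": 0,
            # New flag from this commit; leave True for training, set it to
            # False only when deterministic minibatch ordering is needed.
            "shuffle_sequences": True,
        })
    trainer.train()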
@@ -79,7 +81,8 @@ def choose_policy_optimizer(workers, config):
         num_envs_per_worker=config["num_envs_per_worker"],
         train_batch_size=config["train_batch_size"],
         standardize_fields=["advantages"],
-        straggler_mitigation=config["straggler_mitigation"])
+        straggler_mitigation=config["straggler_mitigation"],
+        shuffle_sequences=config["shuffle_sequences"])
 
 
 def update_kl(trainer, fetches):
@@ -128,6 +128,7 @@ def chop_into_sequences(episode_ids,
                         state_columns,
                         max_seq_len,
                         dynamic_max=True,
+                        shuffle=False,
                         _extra_padding=0):
     """Truncate and pad experiences into fixed-length sequences.
 
@@ -143,6 +144,7 @@ def chop_into_sequences(episode_ids,
         dynamic_max (bool): Whether to dynamically shrink the max seq len.
             For example, if max len is 20 and the actual max seq len in the
             data is 7, it will be shrunk to 7.
+        shuffle (bool): Whether to shuffle the sequence outputs.
         _extra_padding (int): Add extra padding to the end of sequences.
 
     Returns:
@@ -186,6 +188,7 @@ def chop_into_sequences(episode_ids,
     if seq_len:
         seq_lens.append(seq_len)
     assert sum(seq_lens) == len(unique_ids)
+    seq_lens = np.array(seq_lens)
 
     # Dynamically shrink max len as needed to optimize memory usage
     if dynamic_max:
@@ -215,4 +218,17 @@ def chop_into_sequences(episode_ids,
             i += l
         initial_states.append(np.array(s_init))
 
-    return feature_sequences, initial_states, np.array(seq_lens)
+    if shuffle:
+        permutation = np.random.permutation(len(seq_lens))
+        for i, f in enumerate(feature_sequences):
+            orig_shape = f.shape
+            f = np.reshape(f, (len(seq_lens), -1) + f.shape[2:])
+            f = f[permutation]
+            f = np.reshape(f, orig_shape)
+            feature_sequences[i] = f
+        for i, s in enumerate(initial_states):
+            s = s[permutation]
+            initial_states[i] = s
+        seq_lens = seq_lens[permutation]
+
+    return feature_sequences, initial_states, seq_lens
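The block above shuffles at sequence granularity: padded feature rows are regrouped by sequence, whole sequences are permuted with one shared permutation, and the per-sequence initial states and sequence lengths are permuted the same way. A standalone sketch of the same idea with made-up shapes (not RLlib source):

    import numpy as np

    num_seqs, max_len = 3, 4
    features = np.arange(num_seqs * max_len, dtype=np.float32)  # flat [num_seqs * max_len]
    initial_state = np.zeros((num_seqs, 2))  # one state row per sequence
    seq_lens = np.array([4, 2, 3])

    permutation = np.random.permutation(num_seqs)

    # Group timesteps by sequence, permute whole sequences, then restore the
    # original flat layout, so timesteps never cross sequence boundaries.
    grouped = features.reshape((num_seqs, -1))
    features = grouped[permutation].reshape(features.shape)
    initial_state = initial_state[permutation]
    seq_lens = seq_lens[permutation]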
@@ -50,7 +50,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                  train_batch_size=1024,
                  num_gpus=0,
                  standardize_fields=[],
-                 straggler_mitigation=False):
+                 straggler_mitigation=False,
+                 shuffle_sequences=True):
         PolicyOptimizer.__init__(self, workers)
 
         self.batch_size = sgd_batch_size
@@ -59,6 +60,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
         self.sample_batch_size = sample_batch_size
         self.train_batch_size = train_batch_size
         self.straggler_mitigation = straggler_mitigation
+        self.shuffle_sequences = shuffle_sequences
         if not num_gpus:
             self.devices = ["/cpu:0"]
         else:
@@ -157,10 +159,6 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                 standardized = (value - value.mean()) / max(1e-4, value.std())
                 batch[field] = standardized
 
-            # Important: don't shuffle RNN sequence elements
-            if not policy._state_inputs:
-                batch.shuffle()
-
         num_loaded_tuples = {}
         with self.load_timer:
             for policy_id, batch in samples.policy_batches.items():
@@ -168,7 +166,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
                     continue
 
                 policy = self.policies[policy_id]
-                tuples = policy._get_loss_inputs_dict(batch)
+                tuples = policy._get_loss_inputs_dict(
+                    batch, shuffle=self.shuffle_sequences)
                 data_keys = [ph for _, ph in policy._loss_inputs]
                 if policy._state_inputs:
                     state_keys = policy._state_inputs + [policy._seq_lens]
@@ -438,7 +438,8 @@ class TFPolicy(Policy):
     def _build_compute_gradients(self, builder, postprocessed_batch):
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict({self._is_training: True})
-        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
+        builder.add_feed_dict(
+            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
         fetches = builder.add_fetches(
             [self._grads, self._get_grad_and_stats_fetches()])
         return fetches[0], fetches[1]
@@ -455,7 +456,8 @@ class TFPolicy(Policy):
 
     def _build_learn_on_batch(self, builder, postprocessed_batch):
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
-        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
+        builder.add_feed_dict(
+            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
         builder.add_feed_dict({self._is_training: True})
         fetches = builder.add_fetches([
             self._apply_op,
@@ -473,7 +475,19 @@ class TFPolicy(Policy):
             **fetches[LEARNER_STATS_KEY])
         return fetches
 
-    def _get_loss_inputs_dict(self, batch):
+    def _get_loss_inputs_dict(self, batch, shuffle):
+        """Return a feed dict from a batch.
+
+        Arguments:
+            batch (SampleBatch): batch of data to derive inputs from
+            shuffle (bool): whether to shuffle batch sequences. Shuffle may
+                be done in-place. This only makes sense if you're further
+                applying minibatch SGD after getting the outputs.
+
+        Returns:
+            feed dict of data
+        """
+
         feed_dict = {}
         if self._batch_divisibility_req > 1:
             meets_divisibility_reqs = (
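The docstring's caveat is that shuffling only pays off when the resulting feed dict is later cut into SGD minibatches; a one-shot gradient computation over the whole batch (the shuffle=False callers above) gains nothing from it. A rough sketch of that minibatching pattern, not taken from RLlib:

    import numpy as np

    def sgd_minibatches(columns, minibatch_size, shuffle=True):
        """Yield aligned minibatches from equally sized arrays (sketch only)."""
        n = len(columns[0])
        order = np.random.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, minibatch_size):
            rows = order[start:start + minibatch_size]
            yield [col[rows] for col in columns]

Without the shuffle, every SGD epoch would visit the same contiguous, temporally correlated slices in the same order.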
@@ -485,6 +499,8 @@ class TFPolicy(Policy):
 
         # Simple case: not RNN nor do we need to pad
         if not self._state_inputs and meets_divisibility_reqs:
+            if shuffle:
+                batch.shuffle()
             for k, ph in self._loss_inputs:
                 feed_dict[ph] = batch[k]
             return feed_dict
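For this non-RNN path, batch.shuffle() only needs to permute every column of the flat batch with one shared row permutation. A conceptual stand-in (not the SampleBatch implementation):

    import numpy as np

    def shuffle_flat_batch(batch_dict):
        """Permute all columns of a flat dict-of-arrays batch consistently."""
        num_rows = len(next(iter(batch_dict.values())))
        perm = np.random.permutation(num_rows)
        return {key: np.asarray(col)[perm] for key, col in batch_dict.items()}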
@@ -507,7 +523,8 @@ class TFPolicy(Policy):
             batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
             [batch[k] for k in state_keys],
             max_seq_len,
-            dynamic_max=dynamic_max)
+            dynamic_max=dynamic_max,
+            shuffle=shuffle)
         for k, v in zip(feature_keys, feature_sequences):
             feed_dict[self._loss_input_dict[k]] = v
         for k, v in zip(state_keys, initial_states):
@@ -229,6 +229,7 @@ class RNNSequencing(unittest.TestCase):
         ppo = PPOTrainer(
             env="counter",
             config={
+                "shuffle_sequences": False,  # for deterministic testing
                 "num_workers": 0,
                 "sample_batch_size": 20,
                 "train_batch_size": 20,