ray/rllib/policy/rnn_sequencing.py

"""RNN utils for RLlib.
[rllib] General RNN support (#2299) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * fix catalog * remove prep
2018-06-27 22:51:04 -07:00
The main trick here is that we add the time dimension at the last moment.
The non-LSTM layers of the model see their inputs as one flat batch. Before
the LSTM cell, we reshape the input to add the expected time dimension. During
postprocessing, we dynamically pad the experience batches so that this
reshaping is possible.
Note that this padding strategy only works out if we assume zero inputs don't
meaningfully affect the loss function. This happens to be true for all the
current algorithms: https://github.com/ray-project/ray/issues/2992
"""
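As a rough illustration of the padding and late reshape described in the module docstring, the NumPy-only sketch below zero-pads variable-length sequences into one flat batch and adds the time dimension only at the end. The helper name `pad_and_reshape` is hypothetical and not part of RLlib; it handles a single feature column and assumes the sequence lengths are already known.

```python
import numpy as np


def pad_and_reshape(flat_values, seq_lens, max_seq_len):
    """Zero-pad each sequence to `max_seq_len`, then reshape to [B, T, ...].

    Hypothetical helper for illustration only (not an RLlib function).
    """
    flat_values = np.asarray(flat_values)
    trailing_shape = flat_values.shape[1:]
    # Flat, padded batch of shape [B * T, ...]; padding slots stay zero.
    padded = np.zeros(
        (len(seq_lens) * max_seq_len, ) + trailing_shape,
        dtype=flat_values.dtype)
    read = 0
    for i, length in enumerate(seq_lens):
        start = i * max_seq_len
        padded[start:start + length] = flat_values[read:read + length]
        read += length
    # The time dimension is introduced only here, at the last moment.
    return padded.reshape((len(seq_lens), max_seq_len) + trailing_shape)


result = pad_and_reshape([4, 4, 8, 8, 8, 8], seq_lens=[2, 3, 1], max_seq_len=3)
print(result)
```

Each row of the result is one sequence; the zeros are padding that a loss function would have to ignore or be insensitive to, as the docstring notes.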
import logging
import numpy as np
from typing import List, Optional
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
from ray.util import log_once
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)


@DeveloperAPI
def pad_batch_to_sequences_of_same_size(
        batch: SampleBatch,
        max_seq_len: int,
        shuffle: bool = False,
        batch_divisibility_req: int = 1,
        feature_keys: Optional[List[str]] = None,
        view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether
            e.g. dynamic max'ing should be applied over the seq_lens.
    """
    if batch.zero_padded:
        return

    batch.zero_padded = True

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.get("seq_lens") is not None and \
                len(batch["state_in_0"]) == len(batch["seq_lens"]):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=batch.get("seq_lens"),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
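The way sequence boundaries, and hence the amount of padding, depend on the episodes in the batch and on `max_seq_len` can be sketched as follows. This is a simplified illustration, not the RLlib implementation: the function name `compute_seq_lens` is hypothetical, and it looks only at episode IDs, whereas the module also takes unroll IDs and agent indices into account.

```python
def compute_seq_lens(episode_ids, max_seq_len):
    """Return the length of each chunk (hypothetical helper).

    A new chunk starts whenever the episode ID changes or the current
    chunk has reached `max_seq_len` steps.
    """
    seq_lens = []
    seq_len = 0
    prev_id = None
    for eps_id in episode_ids:
        if (prev_id is not None and eps_id != prev_id) or \
                seq_len >= max_seq_len:
            seq_lens.append(seq_len)
            seq_len = 0
        seq_len += 1
        prev_id = eps_id
    # Flush the last, possibly shorter, chunk.
    if seq_len:
        seq_lens.append(seq_len)
    return seq_lens


lens = compute_seq_lens([1, 1, 5, 5, 5, 5], max_seq_len=3)
print(lens)

# With "dynamic max'ing", the padded length shrinks to the longest chunk
# actually present, even if the configured maximum is larger.
short_lens = compute_seq_lens([1, 1, 2, 2, 2], max_seq_len=20)
dynamic_max_len = max(short_lens)
print(short_lens, dynamic_max_len)
```

In the second call the configured maximum of 20 is never reached, so padding every chunk to the longest observed length avoids allocating mostly empty sequences.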


@DeveloperAPI
def add_time_dimension(padded_inputs: TensorType,
                       *,
                       max_seq_len: int,
                       framework: str = "tf",
                       time_major: bool = False):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs (TensorType): A padded batch of sequences. For example,
            for seq_lens=[1, 2, 2], inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = ([new_batch_size, max_seq_len] +
                     padded_inputs.get_shape().as_list()[1:])
        return tf.reshape(padded_inputs, new_shape)
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        # The flat input is laid out sequence by sequence, so always reshape
        # to batch-major [B, T, ...] first. Reshaping directly to
        # [T, B, ...] would mix steps of different sequences.
        new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
        padded_outputs = torch.reshape(padded_inputs, new_shape)
        if time_major:
            # Swap the batch and time axes to obtain [T, B, ...].
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
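The reshape can be checked on a tiny example using the docstring's layout (seq_lens=[1, 2, 2], max_seq_len=2). The sketch below uses NumPy as a stand-in for the tf/torch tensors; the function name `add_time_dimension_np` is hypothetical. It assumes the flat input stores each sequence contiguously, so a time-major result is produced by reshaping to batch-major first and then swapping the two leading axes.

```python
import numpy as np


def add_time_dimension_np(padded_inputs, max_seq_len, time_major=False):
    """NumPy illustration of adding a time axis (hypothetical helper)."""
    padded_inputs = np.asarray(padded_inputs)
    new_batch_size = padded_inputs.shape[0] // max_seq_len
    # Batch-major view: row b holds the steps of sequence b.
    batch_major = padded_inputs.reshape(
        (new_batch_size, max_seq_len) + padded_inputs.shape[1:])
    if time_major:
        # Swap batch and time axes: row t holds step t of every sequence.
        return np.swapaxes(batch_major, 0, 1)
    return batch_major


# Sequences A=1, B=2, C=3 with seq_lens=[1, 2, 2]; 0 marks padding.
flat = np.array([1, 0, 2, 2, 3, 3])
batch_major = add_time_dimension_np(flat, max_seq_len=2)
time_major = add_time_dimension_np(flat, max_seq_len=2, time_major=True)
print(batch_major)
print(time_major)
```

Note that the padded batch size must equal the number of sequences times `max_seq_len`; otherwise the reshape fails.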
@DeveloperAPI
2020-12-21 02:22:32 +01:00
def chop_into_sequences(*,
feature_columns,
state_columns,
max_seq_len,
2020-12-21 02:22:32 +01:00
episode_ids=None,
unroll_ids=None,
agent_indices=None,
dynamic_max=True,
shuffle=False,
2020-12-21 02:22:32 +01:00
seq_lens=None,
states_already_reduced_to_init=False,
_extra_padding=0):
[rllib] General RNN support (#2299) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * fix catalog * remove prep
2018-06-27 22:51:04 -07:00
"""Truncate and pad experiences into fixed-length sequences.
Args:
2021-02-25 12:18:11 +01:00
feature_columns (list): List of arrays containing features.
state_columns (list): List of arrays containing LSTM state values.
max_seq_len (int): Max length of sequences before truncation.
episode_ids (List[EpisodeID]): List of episode ids for each step.
unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
This is used to make sure sequences are cut between sample batches.
agent_indices (List[AgentID]): List of agent ids for each step. Note
that this has to be combined with episode_ids for uniqueness.
dynamic_max (bool): Whether to dynamically shrink the max seq len.
For example, if max len is 20 and the actual max seq len in the
data is 7, it will be shrunk to 7.
shuffle (bool): Whether to shuffle the sequence outputs.
_extra_padding (int): Add extra padding to the end of sequences.
[rllib] General RNN support (#2299) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * fix catalog * remove prep
2018-06-27 22:51:04 -07:00
Returns:
f_pad (list): Padded feature columns. These will be of shape
[NUM_SEQUENCES * MAX_SEQ_LEN, ...].
s_init (list): Initial states for each sequence, of shape
[NUM_SEQUENCES, ...].
seq_lens (list): List of sequence lengths, of shape [NUM_SEQUENCES].
Examples:
>>> f_pad, s_init, seq_lens = chop_into_sequences(
episode_ids=[1, 1, 5, 5, 5, 5],
unroll_ids=[4, 4, 4, 4, 4, 4],
agent_indices=[0, 0, 0, 0, 0, 0],
[rllib] General RNN support (#2299) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * fix catalog * remove prep
2018-06-27 22:51:04 -07:00
feature_columns=[[4, 4, 8, 8, 8, 8],
[1, 1, 0, 1, 1, 0]],
state_columns=[[4, 5, 4, 5, 5, 5]],
max_seq_len=3)
>>> print(f_pad)
[[4, 4, 0, 8, 8, 8, 8, 0, 0],
[1, 1, 0, 0, 1, 1, 0, 0, 0]]
>>> print(s_init)
[[4, 4, 5]]
>>> print(seq_lens)
[2, 3, 1]
"""
2020-12-21 02:22:32 +01:00
if seq_lens is None or len(seq_lens) == 0:
prev_id = None
seq_lens = []
seq_len = 0
unique_ids = np.add(
np.add(episode_ids, agent_indices),
np.array(unroll_ids, dtype=np.int64) << 32)
for uid in unique_ids:
if (prev_id is not None and uid != prev_id) or \
seq_len >= max_seq_len:
seq_lens.append(seq_len)
seq_len = 0
seq_len += 1
prev_id = uid
if seq_len:
[rllib] General RNN support (#2299) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * wip * tf * update * fix * cleanup * cleanup * spacing * model * fix * dqn * fix ddpg * doc * keep names * update * fix * com * docs * clarify model outputs * Update torch_policy_graph.py * fix obs filter * pass thru worker index * fix * rename * vlad torch comments * fix log action * debug name * fix lstm * remove unused ddpg net * remove conv net * revert lstm * wip * wip * cast * wip * works * fix a3c * works * lstm util test * doc * clean up * update * fix lstm check * move to end * fix sphinx * fix cmd * remove bad doc * clarify * copy * async sa * fix * comments * fix a3c conf * tune lstm * fix reshape * fix * back to 16 * tuned a3c update * update * tuned * optional * fix catalog * remove prep
2018-06-27 22:51:04 -07:00
seq_lens.append(seq_len)
2020-12-21 02:22:32 +01:00
seq_lens = np.array(seq_lens, dtype=np.int32)
    assert sum(seq_lens) == len(feature_columns[0])

    # Dynamically shrink max len as needed to optimize memory usage.
    if dynamic_max:
        max_seq_len = max(seq_lens) + _extra_padding

    feature_sequences = []
    for f in feature_columns:
        # Avoid an unnecessary copy.
        if not isinstance(f, np.ndarray):
            f = np.array(f)
        length = len(seq_lens) * max_seq_len
        if f.dtype == object or f.dtype.type is np.str_:
            f_pad = [None] * length
        else:
            # Make sure type doesn't change.
            f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
        seq_base = 0
        i = 0
        for len_ in seq_lens:
            for seq_offset in range(len_):
                f_pad[seq_base + seq_offset] = f[i]
                i += 1
            seq_base += max_seq_len
        assert i == len(f), f
        feature_sequences.append(f_pad)

    if states_already_reduced_to_init:
        initial_states = state_columns
    else:
        initial_states = []
        for s in state_columns:
            # Skip unnecessary copy.
            if not isinstance(s, np.ndarray):
                s = np.array(s)
            # Keep only the state at the first timestep of each sequence.
            s_init = []
            i = 0
            for len_ in seq_lens:
                s_init.append(s[i])
                i += len_
            initial_states.append(np.array(s_init))

    if shuffle:
        permutation = np.random.permutation(len(seq_lens))
        for i, f in enumerate(feature_sequences):
            orig_shape = f.shape
            f = np.reshape(f, (len(seq_lens), -1) + f.shape[1:])
            f = f[permutation]
            f = np.reshape(f, orig_shape)
            feature_sequences[i] = f
        for i, s in enumerate(initial_states):
            s = s[permutation]
            initial_states[i] = s
        seq_lens = seq_lens[permutation]

    return feature_sequences, initial_states, seq_lens
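
# A minimal standalone sketch (not part of RLlib) of the zero-padding step
# performed above: each sequence of a flat feature column is copied into its
# own fixed-size slot of `max_seq_len` timesteps. The helper name
# `pad_to_max_seq_len` and the toy data are assumptions made for illustration,
# and the sketch only handles numeric dtypes (no object/str columns).

```python
import numpy as np


def pad_to_max_seq_len(column, seq_lens, max_seq_len):
    """Hypothetical helper: right-pads every sequence in `column` with zeros."""
    column = np.asarray(column)
    # One slot of `max_seq_len` timesteps per sequence, same dtype as input.
    padded = np.zeros(
        (len(seq_lens) * max_seq_len, ) + column.shape[1:],
        dtype=column.dtype)
    src = 0
    for row, len_ in enumerate(seq_lens):
        dst = row * max_seq_len
        # Copy the sequence to the start of its slot; the rest stays zero.
        padded[dst:dst + len_] = column[src:src + len_]
        src += len_
    # Every input timestep must have been consumed exactly once.
    assert src == len(column)
    return padded


obs = np.array([1, 2, 3, 4, 5, 6])
padded = pad_to_max_seq_len(obs, seq_lens=[2, 3, 1], max_seq_len=3)
# padded.reshape(3, 3) -> [[1, 2, 0], [3, 4, 5], [6, 0, 0]]
```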
def timeslice_along_seq_lens_with_overlap(
        sample_batch,
        seq_lens=None,
        zero_pad_max_seq_len=0,
        pre_overlap=0,
        zero_init_states=True) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Asserts that seq_lens is given or sample_batch["seq_lens"] is not None.

    Args:
        sample_batch (SampleBatch): The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch["seq_lens"]`.
        zero_pad_max_seq_len (int): If >0, zero-pad the resulting slices up
            to this length. NOTE: This max-len includes the additional
            timesteps gained via setting pre_overlap (see Examples).
        pre_overlap (int): If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states (bool): Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values.

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # sample_batch = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlap
        # timesteps (this makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get("seq_lens")
    assert seq_lens is not None and len(seq_lens) > 0, \
        "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

    # Generate n slices based on seq_lens.
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            # The overlap reaches before the start of the batch: zero-pad it.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin:end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            # The slice starts in an earlier episode: zero out those
            # timesteps. (`is not True` would always hold for a NumPy bool.)
            if not is_last_episode_ids[0]:
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate([
                    np.zeros(
                        shape=(zero_length, ) + v.shape[1:], dtype=v.dtype),
                    v[data_begin:end]
                ])
                for k, v in sample_batch.items() if k != "seq_lens"
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items() if k != "seq_lens"
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)
        else:
            # TODO: This will not work with attention nets as their
            # state_outs are not compatible with state_ins.
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = sample_batch["state_out_{}".format(i)][begin -
                                                                   1:begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)
        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
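
# A minimal standalone sketch (not part of RLlib) of how the
# `(pre_begin, slice_begin, end)` bounds above are derived from `seq_lens`
# and `pre_overlap`. The helper name `overlap_slice_bounds` is an assumption
# made for illustration; the numbers reuse the docstring example. A negative
# `pre_begin` marks a slice whose overlap region has to be zero-padded.

```python
def overlap_slice_bounds(seq_lens, pre_overlap=0):
    """Hypothetical helper: returns (pre_begin, slice_begin, end) tuples."""
    bounds = []
    start = 0
    for seq_len in seq_lens:
        # The slice reaches `pre_overlap` timesteps back into the previous
        # sequence (or before the batch start, if `pre_begin` is negative).
        bounds.append((start - pre_overlap, start, start + seq_len))
        start += seq_len
    return bounds


bounds = overlap_slice_bounds([5, 5, 2], pre_overlap=3)
# bounds -> [(-3, 0, 5), (2, 5, 10), (7, 10, 12)]
```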