[RLlib] Automate sequences in timeslice_along_seq_lens_with_overlap(). (#24561)

This commit is contained in:
Artur Niederfahrenhorst 2022-05-09 11:55:06 +02:00 committed by GitHub
parent bc8742792c
commit bd2fdf4752
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 8 deletions

View file

@ -22,6 +22,7 @@ from ray.rllib.utils.debug import summarize
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
from ray.util import log_once
from ray.rllib.utils.typing import SampleBatchType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -360,16 +361,14 @@ def chop_into_sequences(
def timeslice_along_seq_lens_with_overlap(
sample_batch,
seq_lens=None,
zero_pad_max_seq_len=0,
pre_overlap=0,
zero_init_states=True,
sample_batch: SampleBatchType,
seq_lens: Optional[List[int]] = None,
zero_pad_max_seq_len: int = 0,
pre_overlap: int = 0,
zero_init_states: bool = True,
) -> List["SampleBatch"]:
"""Slices batch along `seq_lens` (each seq-len item produces one batch).
Asserts that seq_lens is given or sample_batch["seq_lens"] is not None.
Args:
sample_batch (SampleBatch): The SampleBatch to timeslice.
seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
@ -391,7 +390,8 @@ def timeslice_along_seq_lens_with_overlap(
assert seq_lens == [5, 5, 2]
assert sample_batch.count == 12
# self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
slices = timeslices_along_seq_lens(
slices = timeslice_along_seq_lens_with_overlap(
sample_batch=sample_batch,
zero_pad_max_seq_len=10,
pre_overlap=3)
# Z = zero padding (at beginning or end).
@ -404,6 +404,21 @@ def timeslice_along_seq_lens_with_overlap(
"""
if seq_lens is None:
seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
if seq_lens is None:
max_seq_len = zero_pad_max_seq_len - pre_overlap
if log_once("no_sequence_lengths_available_for_time_slicing"):
logger.warning(
"Trying to slice a batch along sequences without "
"sequence lengths being provided in the batch. Batch will "
"be sliced into slices of size "
"{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
max_seq_len, zero_pad_max_seq_len, pre_overlap
)
)
num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
[last_seq_len] if last_seq_len else []
)
assert (
seq_lens is not None and len(seq_lens) > 0
), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

View file

@ -192,6 +192,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
batch : The batch to be added.
**kwargs: Forward compatibility kwargs.
"""
if batch is None:
if log_once("empty_batch_added_to_buffer"):
logger.info(
"A batch that is `None` was added to {}. This can be "
"normal at the beginning of execution but might "
"indicate an issue.".format(type(self).__name__)
)
return
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
# Handle everything as if multi-agent.