mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Automate sequences in timeslice_along_seq_lens_with_overlap()
. (#24561)
This commit is contained in:
parent
bc8742792c
commit
bd2fdf4752
2 changed files with 31 additions and 8 deletions
|
@ -22,6 +22,7 @@ from ray.rllib.utils.debug import summarize
|
|||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType, ViewRequirementsDict
|
||||
from ray.util import log_once
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -360,16 +361,14 @@ def chop_into_sequences(
|
|||
|
||||
|
||||
def timeslice_along_seq_lens_with_overlap(
|
||||
sample_batch,
|
||||
seq_lens=None,
|
||||
zero_pad_max_seq_len=0,
|
||||
pre_overlap=0,
|
||||
zero_init_states=True,
|
||||
sample_batch: SampleBatchType,
|
||||
seq_lens: Optional[List[int]] = None,
|
||||
zero_pad_max_seq_len: int = 0,
|
||||
pre_overlap: int = 0,
|
||||
zero_init_states: bool = True,
|
||||
) -> List["SampleBatch"]:
|
||||
"""Slices batch along `seq_lens` (each seq-len item produces one batch).
|
||||
|
||||
Asserts that seq_lens is given or sample_batch["seq_lens"] is not None.
|
||||
|
||||
Args:
|
||||
sample_batch (SampleBatch): The SampleBatch to timeslice.
|
||||
seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
|
||||
|
@ -391,7 +390,8 @@ def timeslice_along_seq_lens_with_overlap(
|
|||
assert seq_lens == [5, 5, 2]
|
||||
assert sample_batch.count == 12
|
||||
# self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
|
||||
slices = timeslices_along_seq_lens(
|
||||
slices = timeslice_along_seq_lens_with_overlap(
|
||||
sample_batch=sample_batch.
|
||||
zero_pad_max_seq_len=10,
|
||||
pre_overlap=3)
|
||||
# Z = zero padding (at beginning or end).
|
||||
|
@ -404,6 +404,21 @@ def timeslice_along_seq_lens_with_overlap(
|
|||
"""
|
||||
if seq_lens is None:
|
||||
seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
|
||||
if seq_lens is None:
|
||||
max_seq_len = zero_pad_max_seq_len - pre_overlap
|
||||
if log_once("no_sequence_lengths_available_for_time_slicing"):
|
||||
logger.warning(
|
||||
"Trying to slice a batch along sequences without "
|
||||
"sequence lengths being provided in the batch. Batch will "
|
||||
"be sliced into slices of size "
|
||||
"{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
|
||||
max_seq_len, zero_pad_max_seq_len, pre_overlap
|
||||
)
|
||||
)
|
||||
num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
|
||||
seq_lens = [zero_pad_max_seq_len] * num_seq_lens + (
|
||||
[last_seq_len] if last_seq_len else []
|
||||
)
|
||||
assert (
|
||||
seq_lens is not None and len(seq_lens) > 0
|
||||
), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"
|
||||
|
|
|
@ -192,6 +192,14 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
batch : The batch to be added.
|
||||
**kwargs: Forward compatibility kwargs.
|
||||
"""
|
||||
if batch is None:
|
||||
if log_once("empty_batch_added_to_buffer"):
|
||||
logger.info(
|
||||
"A batch that is `None` was added to {}. This can be "
|
||||
"normal at the beginning of execution but might "
|
||||
"indicate an issue.".format(type(self).__name__)
|
||||
)
|
||||
return
|
||||
# Make a copy so the replay buffer doesn't pin plasma memory.
|
||||
batch = batch.copy()
|
||||
# Handle everything as if multi-agent.
|
||||
|
|
Loading…
Add table
Reference in a new issue