import collections
import numpy as np
import sys
import itertools
import tree  # pip install dm_tree
from typing import Any, Dict, Iterator, List, Optional, Set, Union

from ray.util import log_once
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import concat_aligned
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
    PolicyID,
    TensorType,
    SampleBatchType,
    ViewRequirementsDict,
)

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

# Default policy id for single-agent environments.
DEFAULT_POLICY_ID = "default_policy"


@PublicAPI
class SampleBatch(dict):
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.
    """

    # Outputs from interacting with the environment.
    OBS = "obs"
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"
    SEQ_LENS = "seq_lens"
    # This is only computed and used when the RE3 exploration strategy is enabled.
    OBS_EMBEDS = "obs_embeds"
    T = "t"

    # Decision Transformer.
    RETURNS_TO_GO = "returns_to_go"
    ATTENTION_MASKS = "attention_masks"

    # Extra action fetches keys.
    ACTION_DIST_INPUTS = "action_dist_inputs"
    ACTION_PROB = "action_prob"
    ACTION_LOGP = "action_logp"

    # Uniquely identifies an episode.
    EPS_ID = "eps_id"
    # An env ID (e.g. the index for a vectorized sub-env).
    ENV_ID = "env_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode.
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy.
    VF_PREDS = "vf_preds"

    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Note: All *args and those **kwargs not listed below will be passed
        as-is to the parent dict constructor.

        Keyword Args:
            _time_major (Optional[bool]): Whether data in this sample batch
                is time-major. This is False by default and only relevant
                if the data contains sequences.
            _max_seq_len (Optional[int]): The max sequence chunk length
                if the data contains sequences.
            _zero_padded (Optional[bool]): Whether the data in this batch
                contains sequences AND these sequences are right-zero-padded
                according to the `_max_seq_len` setting.
            _is_training (Optional[bool]): Whether this batch is used for
                training. If False, batch may be used for e.g. action
                computations (inference).
        """

        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        # Maximum seq len value.
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        # Is already right-zero-padded?
        self.zero_padded = kwargs.pop("_zero_padded", False)
        # Whether this batch is used for training (vs inference).
        self._is_training = kwargs.pop("_is_training", None)

        # Call super constructor. This will make the actual data accessible
        # by column name (str) via e.g. self["some-col"].
        dict.__init__(self, *args, **kwargs)

        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
        self.intercepted_values = {}
        self.get_interceptor = None

        # Clear out None seq-lens.
        seq_lens_ = self.get(SampleBatch.SEQ_LENS)
        if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
            self.pop(SampleBatch.SEQ_LENS, None)
        # Numpyfy seq_lens if list.
        elif isinstance(seq_lens_, list):
            self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32)

        if (
            self.max_seq_len is None
            and seq_lens_ is not None
            and not (tf and tf.is_tensor(seq_lens_))
            and len(seq_lens_) > 0
        ):
            self.max_seq_len = max(seq_lens_)

        if self._is_training is None:
            self._is_training = self.pop("is_training", False)

        lengths = []
        copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
        for k, v in copy_.items():
            assert isinstance(k, str), self

            # TODO: Drop support for lists as values.
            # Convert lists of int|float into numpy arrays and make sure all
            # data has the same length.
            if isinstance(v, list):
                self[k] = np.array(v)

            # Try to infer the "length" of the SampleBatch by finding the first
            # value that is actually an ndarray/tensor. This would fail if
            # all values are nested dicts/tuples of more complex underlying
            # structures.
            try:
                len_ = len(v) if not isinstance(v, (dict, tuple)) else None
                if len_:
                    lengths.append(len_)
            except Exception:
                pass

        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and not (tf and tf.is_tensor(self[SampleBatch.SEQ_LENS]))
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            self.count = sum(self[SampleBatch.SEQ_LENS])
        else:
            self.count = lengths[0] if lengths else 0

        # A convenience map for slicing this batch into sub-batches along
        # the time axis. This helps reduce repeated iterations through the
        # batch's seq_lens array to find good slicing points. Built lazily
        # when needed.
        self._slice_map = []

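    # Usage sketch (illustrative only; values assumed, not from the original
    # file): when `seq_lens` is given, `count` is inferred as the sum of the
    # sequence lengths rather than the raw length of any one column.
    #
    #   >>> b = SampleBatch({"a": np.array([1, 2, 3])}, seq_lens=[1, 2])
    #   >>> (len(b), b.max_seq_len)  # doctest: +SKIP
    #   (3, 2)
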
    @PublicAPI
    def __len__(self) -> int:
        """Returns the number of samples in this batch."""
        return self.count

    @PublicAPI
    def agent_steps(self) -> int:
        """Returns the same as len(self) (number of steps in this batch).

        To make this compatible with `MultiAgentBatch.agent_steps()`.
        """
        return len(self)

    @PublicAPI
    def env_steps(self) -> int:
        """Returns the same as len(self) (number of steps in this batch).

        To make this compatible with `MultiAgentBatch.env_steps()`.
        """
        return len(self)

    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
    def concat_samples(
        samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
    ) -> Union["SampleBatch", "MultiAgentBatch"]:
        return concat_samples(samples)

    @PublicAPI
    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Concatenates `other` to this one and returns a new SampleBatch.

        Args:
            other: The other SampleBatch object to concat to this one.

        Returns:
            The new SampleBatch, resulting from concatenating `other` to
            `self`.

        Examples:
            >>> import numpy as np
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> b1 = SampleBatch({"a": np.array([1, 2])}) # doctest: +SKIP
            >>> b2 = SampleBatch({"a": np.array([3, 4, 5])}) # doctest: +SKIP
            >>> print(b1.concat(b2)) # doctest: +SKIP
            {"a": np.array([1, 2, 3, 4, 5])}
        """
        return self.concat_samples([self, other])

    @PublicAPI
    def copy(self, shallow: bool = False) -> "SampleBatch":
        """Creates a deep or shallow copy of this SampleBatch and returns it.

        Args:
            shallow: Whether the copying should be done shallowly.

        Returns:
            A deep or shallow copy of this SampleBatch object.
        """
        copy_ = {k: v for k, v in self.items()}
        data = tree.map_structure(
            lambda v: (
                np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
            ),
            copy_,
        )
        copy_ = SampleBatch(data)
        copy_.set_get_interceptor(self.get_interceptor)
        copy_.added_keys = self.added_keys
        copy_.deleted_keys = self.deleted_keys
        copy_.accessed_keys = self.accessed_keys
        return copy_

    @PublicAPI
    def rows(self) -> Iterator[Dict[str, TensorType]]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Note that if `seq_lens` is set in self, we set it to 1 in the rows.

        Yields:
            The column values of the row in this iteration.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch({ # doctest: +SKIP
            ...     "a": [1, 2, 3],
            ...     "b": [4, 5, 6],
            ...     "seq_lens": [1, 2]
            ... })
            >>> for row in batch.rows(): # doctest: +SKIP
            ...     print(row) # doctest: +SKIP
            {"a": 1, "b": 4, "seq_lens": 1}
            {"a": 2, "b": 5, "seq_lens": 1}
            {"a": 3, "b": 6, "seq_lens": 1}
        """

        seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1

        self_as_dict = {k: v for k, v in self.items()}

        for i in range(self.count):
            yield tree.map_structure_with_path(
                lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens,
                self_as_dict,
            )

    @PublicAPI
    def columns(self, keys: List[str]) -> List[Any]:
        """Returns a list of the batch data in the specified columns.

        Args:
            keys: List of column names for which to return the data.

        Returns:
            The list of data items, ordered by the order of column
            names in `keys`.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) # doctest: +SKIP
            >>> print(batch.columns(["a", "b"])) # doctest: +SKIP
            [[1], [2]]
        """

        # TODO: (sven) Make this work for nested data as well.
        out = []
        for k in keys:
            out.append(self[k])
        return out

    @PublicAPI
    def shuffle(self) -> "SampleBatch":
        """Shuffles the rows of this batch in-place.

        Returns:
            This very (now shuffled) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is defined.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch({"a": [1, 2, 3, 4]}) # doctest: +SKIP
            >>> print(batch.shuffle()) # doctest: +SKIP
            {"a": [4, 1, 3, 2]}
        """

        # Shuffling the data when we have `seq_lens` defined is probably
        # a bad idea!
        if self.get(SampleBatch.SEQ_LENS) is not None:
            raise ValueError(
                "SampleBatch.shuffle not possible when your data has "
                "`seq_lens` defined!"
            )

        # Get a permutation over the single items once and use the same
        # permutation for all the data (otherwise, data would become
        # meaningless).
        permutation = np.random.permutation(self.count)

        self_as_dict = {k: v for k, v in self.items()}
        shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
        self.update(shuffled)
        # Flush cache such that intercepted values are recalculated after the
        # shuffling.
        self.intercepted_values = {}
        return self

    @PublicAPI
    def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
        """Splits by `eps_id` column and returns list of new batches.
        If `eps_id` is not present, splits by `dones` instead.

        Args:
            key: If specified, overwrite default and use key to split.

        Returns:
            List of batches, one per distinct episode.

        Raises:
            KeyError: If the `eps_id` AND `dones` columns are not present.

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> # "eps_id" is present
            >>> batch = SampleBatch( # doctest: +SKIP
            ...     {"a": [1, 2, 3], "eps_id": [0, 0, 1]})
            >>> print(batch.split_by_episode()) # doctest: +SKIP
            [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
            >>>
            >>> # "eps_id" not present, split by "dones" instead
            >>> batch = SampleBatch( # doctest: +SKIP
            ...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
            >>> print(batch.split_by_episode()) # doctest: +SKIP
            [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
            >>>
            >>> # The last episode is appended even if it does not end with done
            >>> batch = SampleBatch( # doctest: +SKIP
            ...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
            >>> print(batch.split_by_episode()) # doctest: +SKIP
            [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
            >>> batch = SampleBatch( # doctest: +SKIP
            ...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
            >>> print(batch.split_by_episode()) # doctest: +SKIP
            [{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
        """

        def slice_by_eps_id():
            slices = []
            # Produce a new slice whenever we find a new episode ID.
            cur_eps_id = self[SampleBatch.EPS_ID][0]
            offset = 0
            for i in range(self.count):
                next_eps_id = self[SampleBatch.EPS_ID][i]
                if next_eps_id != cur_eps_id:
                    slices.append(self[offset:i])
                    offset = i
                    cur_eps_id = next_eps_id
            # Add final slice.
            slices.append(self[offset : self.count])
            return slices

        def slice_by_dones():
            slices = []
            offset = 0
            for i in range(self.count):
                if self[SampleBatch.DONES][i]:
                    # Since self[i] is the last timestep of the episode,
                    # append it to the batch, then set offset to the start
                    # of the next batch.
                    slices.append(self[offset : i + 1])
                    offset = i + 1
            # Add final slice.
            if offset != self.count:
                slices.append(self[offset:])
            return slices

        key_to_method = {
            SampleBatch.EPS_ID: slice_by_eps_id,
            SampleBatch.DONES: slice_by_dones,
        }

        # If key not specified, default to this order.
        key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]

        slices = None
        if key is not None:
            # If key specified, directly use it.
            if key not in self:
                raise KeyError(f"{self} does not have key `{key}`!")
            slices = key_to_method[key]()
        else:
            # If key not specified, go in order.
            for key in key_resolve_order:
                if key in self:
                    slices = key_to_method[key]()
                    break
            if slices is None:
                raise KeyError(f"{self} does not have keys {key_resolve_order}!")

        assert sum(s.count for s in slices) == self.count, (
            f"Calling split_by_episode on {self} returns {slices}, "
            f"which should both have {self.count} timesteps!"
        )
        return slices

    def slice(
        self, start: int, end: int, state_start=None, state_end=None
    ) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start: Starting index. If < 0, will left-zero-pad.
            end: Ending index.
            state_start: Optional start index into the per-sequence data
                (the `seq_lens` and `state_in_x` columns).
            state_end: Optional end index into the per-sequence data
                (the `seq_lens` and `state_in_x` columns).

        Returns:
            A new SampleBatch, which has a slice of this batch's data.
        """
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            if start < 0:
                data = {
                    k: np.concatenate(
                        [
                            np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype),
                            v[0:end],
                        ]
                    )
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            else:
                data = {
                    k: tree.map_structure(lambda s: s[start:end], v)
                    for k, v in self.items()
                    if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
                }
            if state_start is not None:
                assert state_end is not None
                state_idx = 0
                state_key = "state_in_{}".format(state_idx)
                while state_key in self:
                    data[state_key] = self[state_key][state_start:state_end]
                    state_idx += 1
                    state_key = "state_in_{}".format(state_idx)
                seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
                # Adjust seq_lens if necessary.
                data_len = len(data[next(iter(data))])
                if sum(seq_lens) != data_len:
                    assert sum(seq_lens) > data_len
                    seq_lens[-1] = data_len - sum(seq_lens[:-1])
            else:
                # Fix state_in_x data.
                count = 0
                state_start = None
                seq_lens = None
                for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
                    count += seq_len
                    if count >= end:
                        state_idx = 0
                        state_key = "state_in_{}".format(state_idx)
                        if state_start is None:
                            state_start = i
                        while state_key in self:
                            data[state_key] = self[state_key][state_start : i + 1]
                            state_idx += 1
                            state_key = "state_in_{}".format(state_idx)
                        seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [
                            seq_len - (count - end)
                        ]
                        if start < 0:
                            seq_lens[0] += -start
                        diff = sum(seq_lens) - (end - start)
                        if diff > 0:
                            seq_lens[0] -= diff
                        assert sum(seq_lens) == (end - start)
                        break
                    elif state_start is None and count > start:
                        state_start = i

            return SampleBatch(
                data,
                seq_lens=seq_lens,
                _is_training=self.is_training,
                _time_major=self.time_major,
            )
        else:
            return SampleBatch(
                tree.map_structure(lambda value: value[start:end], self),
                _is_training=self.is_training,
                _time_major=self.time_major,
            )

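    # Usage sketch (illustrative only; values assumed): without `seq_lens`,
    # `slice` behaves like plain row slicing. With `seq_lens` present, a
    # negative `start` left-zero-pads the data instead.
    #
    #   >>> b = SampleBatch({"a": np.array([1, 2, 3, 4])})
    #   >>> b.slice(1, 3)["a"]  # doctest: +SKIP
    #   array([2, 3])
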
    @PublicAPI
    def timeslices(
        self,
        size: Optional[int] = None,
        num_slices: Optional[int] = None,
        k: Optional[int] = None,
    ) -> List["SampleBatch"]:
        """Returns SampleBatches, each one representing a time-slice of this one.

        Will start from timestep 0 and produce slices of the given size.

        Args:
            size: The size (in timesteps) of each returned SampleBatch.
            num_slices: The number of slices to produce.
            k: Deprecated: Use size or num_slices instead. The size
                (in timesteps) of each returned SampleBatch.

        Returns:
            A list of `num_slices` new SampleBatches (of roughly equal size),
            or a list of new SampleBatches, each one of size `size`.
        """
        if size is None and num_slices is None:
            deprecation_warning("k", "size or num_slices")
            assert k is not None
            size = k

        if size is None:
            assert isinstance(num_slices, int)

            slices = []
            left = len(self)
            start = 0
            while left > 0:
                len_ = left // (num_slices - len(slices))
                stop = start + len_
                slices.append(self[start:stop])
                left -= len_
                start = stop

            return slices

        else:
            assert isinstance(size, int)

            slices = []
            left = len(self)
            start = 0
            while left > 0:
                stop = start + size
                slices.append(self[start:stop])
                left -= size
                start = stop

            return slices

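    # Usage sketch (illustrative only; values assumed): a 10-step batch cut
    # into windows of size 4 yields slices of sizes 4, 4, and 2 (the last
    # slice holds whatever remains).
    #
    #   >>> b = SampleBatch({"a": np.arange(10)})
    #   >>> [len(s) for s in b.timeslices(size=4)]  # doctest: +SKIP
    #   [4, 4, 2]
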
    @Deprecated(new="SampleBatch.right_zero_pad", error=False)
    def zero_pad(self, max_seq_len, exclude_states=True):
        return self.right_zero_pad(max_seq_len, exclude_states)

    def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
        """Right (adding zeros at end) zero-pads this SampleBatch in-place.

        This will set the `self.zero_padded` flag to True and
        `self.max_seq_len` to the given `max_seq_len` value.

        Args:
            max_seq_len: The max (total) length to zero pad to.
            exclude_states: If False, also right-zero-pad all
                `state_in_x` data. If True, leave `state_in_x` keys
                as-is.

        Returns:
            This very (now right-zero-padded) SampleBatch.

        Raises:
            ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

        Examples:
            >>> from ray.rllib.policy.sample_batch import SampleBatch
            >>> batch = SampleBatch( # doctest: +SKIP
            ...     {"a": [1, 2, 3], "seq_lens": [1, 2]})
            >>> print(batch.right_zero_pad(max_seq_len=4)) # doctest: +SKIP
            {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}

            >>> batch = SampleBatch({"a": [1, 2, 3], # doctest: +SKIP
            ...     "state_in_0": [1.0, 3.0],
            ...     "seq_lens": [1, 2]})
            >>> print(batch.right_zero_pad(max_seq_len=5)) # doctest: +SKIP
            {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
             "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
             "seq_lens": [1, 2]}
        """
        seq_lens = self.get(SampleBatch.SEQ_LENS)
        if seq_lens is None:
            raise ValueError(
                "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
                f"present! SampleBatch={self}"
            )

        length = len(seq_lens) * max_seq_len

        def _zero_pad_in_place(path, value):
            # Skip "state_in_..." columns and "seq_lens".
            if (
                exclude_states is True and path[0].startswith("state_in_")
            ) or path[0] == SampleBatch.SEQ_LENS:
                return
            # Generate zero-filled primer of len=max_seq_len.
            if value.dtype == object or value.dtype.type is np.str_:
                f_pad = [None] * length
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype)
            # Fill primer with data.
            f_pad_base = f_base = 0
            for len_ in self[SampleBatch.SEQ_LENS]:
                f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
            assert f_base == len(value), value

            # Update our data in-place.
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    curr[p] = f_pad
                curr = curr[p]

        self_as_dict = {k: v for k, v in self.items()}
        tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)

        # Set flags to indicate we are now zero-padded (and to what extent).
        self.zero_padded = True
        self.max_seq_len = max_seq_len

        return self

    @ExperimentalAPI
    def to_device(self, device, framework="torch"):
        """TODO: transfer batch to given device as framework tensor."""
        if framework == "torch":
            assert torch is not None
            for k, v in self.items():
                self[k] = convert_to_torch_tensor(v, device)
        else:
            raise NotImplementedError
        return self

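    # Usage sketch (illustrative only; requires a working torch install):
    #
    #   >>> b = SampleBatch({"a": np.array([1.0, 2.0])})
    #   >>> b.to_device("cpu")["a"].device  # doctest: +SKIP
    #   device(type='cpu')
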
    @PublicAPI
    def size_bytes(self) -> int:
        """Returns the sum of the byte sizes of all data buffers.

        For numpy arrays, we use `.nbytes`. For all other value types, we use
        sys.getsizeof(...).

        Returns:
            The overall size in bytes of the data buffer (all columns).
        """
        return sum(
            v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
            for v in tree.flatten(self)
        )

    def get(self, key, default=None):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    @PublicAPI
    def as_multi_agent(self) -> "MultiAgentBatch":
        """Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.

        Returns:
            The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding
            to this SampleBatch.
        """
        return MultiAgentBatch({DEFAULT_POLICY_ID: self}, self.count)

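    # Usage sketch (illustrative only): the wrapped batch is keyed under
    # DEFAULT_POLICY_ID and shares its data with `self`.
    #
    #   >>> ma = SampleBatch({"a": np.array([1])}).as_multi_agent()
    #   >>> list(ma.policy_batches.keys())  # doctest: +SKIP
    #   ['default_policy']
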
    @PublicAPI
    def __getitem__(self, key: Union[str, slice]) -> TensorType:
        """Returns one column (by key) from the data or a sliced new batch.

        Args:
            key: The key (column name) to return or
                a slice object for slicing this SampleBatch.

        Returns:
            The data under the given key or a sliced version of this batch.
        """
        if isinstance(key, slice):
            return self._slice(key)

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            return self.is_training

        if not hasattr(self, key) and key in self:
            self.accessed_keys.add(key)

        value = dict.__getitem__(self, key)
        if self.get_interceptor is not None:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value

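    # Usage sketch (illustrative only; values assumed): a string key returns
    # a column, while a slice returns a new (data-linked) SampleBatch.
    #
    #   >>> b = SampleBatch({"a": np.array([1, 2, 3])})
    #   >>> (b["a"][0], len(b[1:]))  # doctest: +SKIP
    #   (1, 2)
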
    @PublicAPI
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key: The column name to set a value for.
            item: The data to insert.
        """
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        if not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False,
                )
            self._is_training = item
            return

        if key not in self:
            self.added_keys.add(key)

        dict.__setitem__(self, key, item)
        if key in self.intercepted_values:
            self.intercepted_values[key] = item

    @property
    def is_training(self):
        if self.get_interceptor is not None and isinstance(self._is_training, bool):
            if "_is_training" not in self.intercepted_values:
                self.intercepted_values["_is_training"] = self.get_interceptor(
                    self._is_training
                )
            return self.intercepted_values["_is_training"]
        return self._is_training

    def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
        self._is_training = training
        self.intercepted_values.pop("_is_training", None)

    @PublicAPI
    def __delitem__(self, key):
        self.deleted_keys.add(key)
        dict.__delitem__(self, key)

    @DeveloperAPI
    def compress(
        self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Compresses the data buffers (by column) in place.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False, will compress n separate list items, where
                n is the batch size.
            columns: The columns to compress. Default: Only
                compress the obs and new_obs columns.

        Returns:
            This very (now compressed) SampleBatch.
        """

        def _compress_in_place(path, value):
            if path[0] not in columns:
                return
            curr = self
            for i, p in enumerate(path):
                if i == len(path) - 1:
                    if bulk:
                        curr[p] = pack(value)
                    else:
                        curr[p] = np.array([pack(o) for o in value])
                curr = curr[p]

        tree.map_structure_with_path(_compress_in_place, self)

        return self

    @DeveloperAPI
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "SampleBatch":
        """Decompresses data buffers (per column, if compressed) in place.

        Args:
            columns: The columns to decompress. Default: Only
                decompress the obs and new_obs columns.

        Returns:
            This very (now uncompressed) SampleBatch.
        """

        def _decompress_in_place(path, value):
            if path[0] not in columns:
                return
            curr = self
            for p in path[:-1]:
                curr = curr[p]
            # Bulk compressed.
            if is_compressed(value):
                curr[path[-1]] = unpack(value)
            # Non-bulk compressed.
            elif len(value) > 0 and is_compressed(value[0]):
                curr[path[-1]] = np.array([unpack(o) for o in value])

        tree.map_structure_with_path(_decompress_in_place, self)

        return self

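    # Usage sketch (illustrative only; values assumed): a compress/decompress
    # round-trip on the "obs" column restores the original data.
    #
    #   >>> b = SampleBatch({"obs": np.random.random((4, 2))})
    #   >>> b.compress(bulk=True).decompress_if_needed()["obs"].shape  # doctest: +SKIP
    #   (4, 2)
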
    @DeveloperAPI
    def set_get_interceptor(self, fn):
        # If get-interceptor changes, must erase old intercepted values.
        if fn is not self.get_interceptor:
            self.intercepted_values = {}
        self.get_interceptor = fn

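    # Usage sketch (illustrative only; the interceptor below is hypothetical):
    # an interceptor lazily transforms column values on first access and
    # caches the result in `intercepted_values`.
    #
    #   >>> b = SampleBatch({"a": np.array([1.0])})
    #   >>> b.set_get_interceptor(lambda v: v * 2)
    #   >>> b["a"]  # doctest: +SKIP
    #   array([2.])
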
    def __repr__(self):
        keys = list(self.keys())
        if self.get(SampleBatch.SEQ_LENS) is None:
            return f"SampleBatch({self.count}: {keys})"
        else:
            keys.remove(SampleBatch.SEQ_LENS)
            return f"SampleBatch({self.count} (seqs={len(self['seq_lens'])}): {keys})"

    def _slice(self, slice_: slice) -> "SampleBatch":
        """Helper method to handle SampleBatch slicing using a slice object.

        The returned SampleBatch uses the same underlying data object as
        `self`, so changing the slice will also change `self`.

        Note that only zero or positive bounds are allowed for both start
        and stop values. The slice step must be 1 (or None, which is the
        same).

        Args:
            slice_: The python slice object to slice by.

        Returns:
            A new SampleBatch, however "linking" into the same data
            (sliced) as self.
        """
        start = slice_.start or 0
        stop = slice_.stop or len(self)
        # If stop goes beyond the length of this batch -> Make it go till the
        # end only (including last item).
        # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
        if stop > len(self):
            stop = len(self)
        assert start >= 0 and stop >= 0 and slice_.step in [1, None]

        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            # Build our slice-map, if not done already.
            if not self._slice_map:
                sum_ = 0
                for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
                    self._slice_map.extend([(i, sum_)] * l)
                    sum_ = sum_ + l
                # In case `stop` points to the very end (length of this
                # batch), this extra entry makes sure we never go beyond the
                # map (which would result in an index error below).
                self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

            start_seq_len, start_unpadded = self._slice_map[start]
            stop_seq_len, stop_unpadded = self._slice_map[stop]
            start_padded = start_unpadded
            stop_padded = stop_unpadded
            if self.zero_padded:
                start_padded = start_seq_len * self.max_seq_len
                stop_padded = stop_seq_len * self.max_seq_len

            def map_(path, value):
                if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                    "state_in_"
                ):
                    if path[0] != SampleBatch.INFOS:
                        return value[start_padded:stop_padded]
                    else:
                        return value[start_unpadded:stop_unpadded]
                else:
                    return value[start_seq_len:stop_seq_len]

            data = tree.map_structure_with_path(map_, self)
            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
                _zero_padded=self.zero_padded,
                _max_seq_len=self.max_seq_len if self.zero_padded else None,
            )
        else:
            data = tree.map_structure(lambda value: value[start:stop], self)
            return SampleBatch(
                data,
                _is_training=self.is_training,
                _time_major=self.time_major,
            )

    @Deprecated(error=False)
    def _get_slice_indices(self, slice_size):
        data_slices = []
        data_slices_states = []
        if (
            self.get(SampleBatch.SEQ_LENS) is not None
            and len(self[SampleBatch.SEQ_LENS]) > 0
        ):
            assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), (
                "ERROR: `slice_size` must be larger than the max. seq-len "
                "in the batch!"
            )
            start_pos = 0
            current_slice_size = 0
            actual_slice_idx = 0
            start_idx = 0
            idx = 0
            while idx < len(self[SampleBatch.SEQ_LENS]):
                seq_len = self[SampleBatch.SEQ_LENS][idx]
                current_slice_size += seq_len
                actual_slice_idx += (
                    seq_len if not self.zero_padded else self.max_seq_len
                )
                # Complete minibatch -> Append to data_slices.
                if current_slice_size >= slice_size:
                    end_idx = idx + 1
                    # We are not zero-padded yet; all sequences are
                    # back-to-back.
                    if not self.zero_padded:
                        data_slices.append((start_pos, start_pos + slice_size))
                        start_pos += slice_size
                        if current_slice_size > slice_size:
                            overhead = current_slice_size - slice_size
                            start_pos -= seq_len - overhead
                            idx -= 1
                    # We are already zero-padded: Cut in chunks of max_seq_len.
                    else:
                        data_slices.append((start_pos, actual_slice_idx))
                        start_pos = actual_slice_idx

                    data_slices_states.append((start_idx, end_idx))
                    current_slice_size = 0
                    start_idx = idx + 1
                idx += 1
        else:
            i = 0
            while i < self.count:
                data_slices.append((i, i + slice_size))
                i += slice_size
        return data_slices, data_slices_states

    @ExperimentalAPI
    def get_single_step_input_dict(
        self,
        view_requirements: ViewRequirementsDict,
        index: Union[str, int] = "last",
    ) -> "SampleBatch":
        """Creates a single-timestep SampleBatch at the given index from `self`.

        For usage as input-dict for model (action or value function) calls.

        Args:
            view_requirements: A view requirements dict from the model for
                which to produce the input_dict.
            index: An integer index value indicating the
                position in the trajectory for which to generate the
                compute_actions input dict. Set to "last" to generate the dict
                at the very end of the trajectory (e.g. for value estimation).
                Note that "last" is different from -1, as "last" will use the
                final NEXT_OBS as observation input.

        Returns:
            The (single-timestep) input dict for ModelV2 calls.
        """
        last_mappings = {
            SampleBatch.OBS: SampleBatch.NEXT_OBS,
            SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
            SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
        }

        input_dict = {}
        for view_col, view_req in view_requirements.items():
            if view_req.used_for_compute_actions is False:
                continue

            # Create batches of size 1 (single-agent input-dict).
            data_col = view_req.data_col or view_col
            if index == "last":
                data_col = last_mappings.get(data_col, data_col)
                # Range needed.
                if view_req.shift_from is not None:
                    # Batch repeat value > 1: We have single frames in the
                    # batch at each timestep (for the `data_col`).
                    data = self[view_col][-1]
                    traj_len = len(self[data_col])
                    missing_at_end = traj_len % view_req.batch_repeat_value
                    # Index into the observations column must be shifted by
                    # -1 b/c index=0 for observations means the current (last
                    # seen) observation (after having taken an action).
                    obs_shift = (
                        -1 if data_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] else 0
                    )
                    from_ = view_req.shift_from + obs_shift
                    to_ = view_req.shift_to + obs_shift + 1
                    if to_ == 0:
                        to_ = None
                    input_dict[view_col] = np.array(
                        [
                            np.concatenate([data, self[data_col][-missing_at_end:]])[
                                from_:to_
                            ]
                        ]
                    )
                # Single index.
                else:
                    input_dict[view_col] = tree.map_structure(
                        lambda v: v[-1:],  # keep as array (w/ 1 element)
                        self[data_col],
                    )
            # Single index somewhere inside the trajectory (non-last).
            else:
                input_dict[view_col] = self[data_col][
                    index : index + 1 if index != -1 else None
                ]

        return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))


@PublicAPI
class MultiAgentBatch:
    """A batch of experiences from multiple agents in the environment.

    Attributes:
        policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
            ids to SampleBatches of experiences.
        count: The number of env steps in this batch.
    """

    @PublicAPI
    def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
        """Initializes a MultiAgentBatch instance.

        Args:
            policy_batches: Mapping from policy
                ids to SampleBatches of experiences.
            env_steps: The number of environment steps in the environment
                this batch contains. This will be at most the number of
                transitions this batch contains across all policies in total.
        """

        for v in policy_batches.values():
            assert isinstance(v, SampleBatch)
        self.policy_batches = policy_batches
        # Called "count" for uniformity with SampleBatch.
        # Prefer to access this via the `env_steps()` method when possible
        # for clarity.
        self.count = env_steps

    @PublicAPI
    def env_steps(self) -> int:
        """The number of env steps (there are >= 1 agent steps per env step).

        Returns:
            The number of environment steps contained in this batch.
        """
        return self.count

    @PublicAPI
    def __len__(self) -> int:
        """Same as `self.env_steps()`."""
        return self.count

    @PublicAPI
    def agent_steps(self) -> int:
        """The number of agent steps (there are >= 1 agent steps per env step).

        Returns:
            The number of agent steps total in this batch.
        """
        ct = 0
        for batch in self.policy_batches.values():
            ct += batch.count
        return ct

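    # Usage sketch (illustrative only; values assumed): with two policies
    # holding 2 and 3 transitions collected over 3 env steps, agent_steps()
    # is 5 while env_steps() is 3.
    #
    #   >>> ma = MultiAgentBatch(
    #   ...     {"p0": SampleBatch({"a": np.zeros(2)}),
    #   ...      "p1": SampleBatch({"a": np.zeros(3)})},
    #   ...     env_steps=3)
    #   >>> (ma.agent_steps(), ma.env_steps())  # doctest: +SKIP
    #   (5, 3)
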
    @PublicAPI
    def timeslices(self, k: int) -> List["MultiAgentBatch"]:
        """Returns k-step batches holding data for each agent at those steps.

        For example, suppose we have agent1 observations [a1t1, a1t2, a1t3],
        for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

        Calling timeslices(1) would return three MultiAgentBatches containing
        [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

        Calling timeslices(2) would return two MultiAgentBatches containing
        [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

        This method is used to implement "lockstep" replay mode. Note that this
        method does not guarantee each batch contains only data from a single
        unroll. Batches might contain data from multiple different envs.
        """
        from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

        # Build a sorted list of (eps_id, t, agent_index, policy_id, row)
        # tuples.
        steps = []
        for policy_id, batch in self.policy_batches.items():
            for row in batch.rows():
                steps.append(
                    (
                        row[SampleBatch.EPS_ID],
                        row[SampleBatch.T],
                        row[SampleBatch.AGENT_INDEX],
                        policy_id,
                        row,
                    )
                )
        steps.sort()

        finished_slices = []
        cur_slice = collections.defaultdict(SampleBatchBuilder)
        cur_slice_size = 0

        def finish_slice():
            nonlocal cur_slice_size
            assert cur_slice_size > 0
            batch = MultiAgentBatch(
                {k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
            )
            cur_slice_size = 0
            cur_slice.clear()
            finished_slices.append(batch)

        # For each unique env timestep.
        for _, group in itertools.groupby(steps, lambda x: x[:2]):
            # Accumulate into the current slice.
            for _, _, _, policy_id, row in group:
                cur_slice[policy_id].add_values(**row)
            cur_slice_size += 1
            # Slice has reached target number of env steps.
            if cur_slice_size >= k:
                finish_slice()
                assert cur_slice_size == 0

        if cur_slice_size > 0:
            finish_slice()

        assert len(finished_slices) > 0, finished_slices
        return finished_slices

    @staticmethod
    @PublicAPI
    def wrap_as_needed(
        policy_batches: Dict[PolicyID, SampleBatch], env_steps: int
    ) -> Union[SampleBatch, "MultiAgentBatch"]:
        """Returns a SampleBatch or MultiAgentBatch, depending on given policies.

        If `policy_batches` is empty (i.e. {}), an empty MultiAgentBatch is
        returned.

        Args:
            policy_batches: Mapping from policy ids to SampleBatch.
            env_steps: Number of env steps in the batch.

        Returns:
            The single default policy's SampleBatch or a MultiAgentBatch
            (more than one policy).
        """
        if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
            return policy_batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps)

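    # Usage sketch (illustrative only): a dict that contains only the default
    # policy is unwrapped back to the plain SampleBatch itself.
    #
    #   >>> b = SampleBatch({"a": np.array([1])})
    #   >>> MultiAgentBatch.wrap_as_needed({"default_policy": b}, 1) is b  # doctest: +SKIP
    #   True
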
    @staticmethod
    @PublicAPI
    @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
    def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
        return concat_samples_into_ma_batch(samples)

    @PublicAPI
    def copy(self) -> "MultiAgentBatch":
        """Deep-copies self into a new MultiAgentBatch.

        Returns:
            The copy of self with deep-copied data.
        """
        return MultiAgentBatch(
            {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count
        )

    @PublicAPI
    def size_bytes(self) -> int:
        """Returns the overall byte size of all policy batches.

        Returns:
            The overall size in bytes of all policy batches (all columns).
        """
        return sum(b.size_bytes() for b in self.policy_batches.values())

    @DeveloperAPI
    def compress(
        self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> None:
        """Compresses each policy batch (per column) in place.

        Args:
            bulk: Whether to compress across the batch dimension (0)
                as well. If False, will compress n separate list items, where
                n is the batch size.
            columns: Set of column names to compress.
        """
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(
        self, columns: Set[str] = frozenset(["obs", "new_obs"])
    ) -> "MultiAgentBatch":
        """Decompresses each policy batch (per column), if already compressed.

        Args:
            columns: Set of column names to decompress.

        Returns:
            Self.
        """
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)
        return self

    @DeveloperAPI
    def as_multi_agent(self) -> "MultiAgentBatch":
        """Simply returns `self` (already a MultiAgentBatch).

        Returns:
            This very instance of MultiAgentBatch.
        """
        return self

    def __getitem__(self, key: str) -> SampleBatch:
        """Returns the SampleBatch for the given policy id."""
        return self.policy_batches[key]

    def __str__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count
        )

    def __repr__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count
        )


@PublicAPI
def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
    """Concatenates a list of SampleBatches or MultiAgentBatches.

    If all items in the list are of SampleBatch type, the output will be
    a SampleBatch type. Otherwise, the output will be a MultiAgentBatch type.
    If the input is a mixture of SampleBatch and MultiAgentBatch types, the
    SampleBatch objects are treated as MultiAgentBatch objects with a
    'default_policy' key and concatenated with the rest of the
    MultiAgentBatch objects.
    Empty samples are simply ignored.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) SampleBatch or MultiAgentBatch.

    Examples:
        >>> import numpy as np
        >>> from ray.rllib.policy.sample_batch import SampleBatch
        >>> b1 = SampleBatch({"a": np.array([1, 2]), # doctest: +SKIP
        ...     "b": np.array([10, 11])})
        >>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
        ...     "b": np.array([12])})
        >>> print(concat_samples([b1, b2])) # doctest: +SKIP
        {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}

        >>> c1 = MultiAgentBatch({'default_policy': { # doctest: +SKIP
        ...     "a": np.array([1, 2]),
        ...     "b": np.array([10, 11])
        ... }}, env_steps=2)
        >>> c2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
        ...     "b": np.array([12])})
        >>> print(concat_samples([c1, c2])) # doctest: +SKIP
        MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
            "b": np.array([10, 11, 12])}}
    """

    if any(isinstance(s, MultiAgentBatch) for s in samples):
        return concat_samples_into_ma_batch(samples)

    # The output will be a SampleBatch type.
    concatd_seq_lens = []
    concated_samples = []
    # Make sure these settings are consistent amongst all batches.
    zero_padded = max_seq_len = time_major = None
    for s in samples:
        if s.count > 0:
            if max_seq_len is None:
                zero_padded = s.zero_padded
                max_seq_len = s.max_seq_len
                time_major = s.time_major

            # Make sure these settings are consistent amongst all batches.
            if s.zero_padded != zero_padded or s.time_major != time_major:
                raise ValueError(
                    "All SampleBatches' `zero_padded` and `time_major` settings "
                    "must be consistent!"
                )
            if (
                s.max_seq_len is None or max_seq_len is None
            ) and s.max_seq_len != max_seq_len:
                raise ValueError(
                    "Samples must consistently either provide or omit "
                    "`max_seq_len`!"
                )
            elif zero_padded and s.max_seq_len != max_seq_len:
                raise ValueError(
                    "For `zero_padded` SampleBatches, the values of `max_seq_len` "
                    "must be consistent!"
                )

            if max_seq_len is not None:
                max_seq_len = max(max_seq_len, s.max_seq_len)
            concated_samples.append(s)
            if s.get(SampleBatch.SEQ_LENS) is not None:
                concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])

    # If we don't have any samples (0 or only empty SampleBatches),
    # return an empty SampleBatch here.
    if len(concated_samples) == 0:
        return SampleBatch()

    # Collect the concat'd data.
    concatd_data = {}

    for k in concated_samples[0].keys():
        try:
            if k == "infos":
                concatd_data[k] = concat_aligned(
                    [s[k] for s in concated_samples], time_major=time_major
                )
            else:
                concatd_data[k] = tree.map_structure(
                    _concat_key, *[c[k] for c in concated_samples]
                )
        except Exception:
            raise ValueError(
                f"Cannot concat data under key '{k}', b/c "
                "sub-structures under that key don't match. "
                f"`samples`={samples}"
            )

    # Return a new (concat'd) SampleBatch.
    return SampleBatch(
        concatd_data,
        seq_lens=concatd_seq_lens,
        _time_major=time_major,
        _zero_padded=zero_padded,
        _max_seq_len=max_seq_len,
    )


@PublicAPI
def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch":
    """Concatenates a list of SampleBatchTypes to a single MultiAgentBatch.

    This function, as opposed to concat_samples(), forces the output to always
    be a MultiAgentBatch, which is more generic than SampleBatch.

    Args:
        samples: List of SampleBatches or MultiAgentBatches to be
            concatenated.

    Returns:
        A new (concatenated) MultiAgentBatch.

    Examples:
        >>> import numpy as np
        >>> from ray.rllib.policy.sample_batch import SampleBatch
        >>> b1 = MultiAgentBatch({'default_policy': { # doctest: +SKIP
        ...     "a": np.array([1, 2]),
        ...     "b": np.array([10, 11])
        ... }}, env_steps=2)
        >>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
        ...     "b": np.array([12])})
        >>> print(concat_samples([b1, b2])) # doctest: +SKIP
        MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
            "b": np.array([10, 11, 12])}}
    """

    policy_batches = collections.defaultdict(list)
    env_steps = 0
    for s in samples:
        # Some batches in `samples` may be SampleBatch.
        if isinstance(s, SampleBatch):
            # If empty SampleBatch: ok (just ignore).
            if len(s) <= 0:
                continue
            else:
                # If non-empty: Convert to a MultiAgentBatch and move forward.
                s = s.as_multi_agent()
        elif not isinstance(s, MultiAgentBatch):
            # Otherwise: Error.
            raise ValueError(
                "`concat_samples_into_ma_batch` can only concat "
                "SampleBatch|MultiAgentBatch objects, not {}!".format(type(s).__name__)
            )

        for key, batch in s.policy_batches.items():
            policy_batches[key].append(batch)
        env_steps += s.env_steps()

    out = {}
    for key, batches in policy_batches.items():
        out[key] = concat_samples(batches)

    return MultiAgentBatch(out, env_steps)


def _concat_key(*values, time_major=None):
    return concat_aligned(list(values), time_major)