ray/rllib/policy/sample_batch.py


import collections
import numpy as np
import sys
import itertools
import tree # pip install dm_tree
from typing import Any, Dict, Iterator, List, Optional, Set, Union
from ray.util import log_once
from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import concat_aligned
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
PolicyID,
TensorType,
SampleBatchType,
ViewRequirementsDict,
)
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default_policy"
@PublicAPI
class SampleBatch(dict):
"""Wrapper around a dictionary with string keys and array-like values.
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
samples, each with an "obs" and "reward" attribute.
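A small, illustrative usage sketch (note that list values are
converted to numpy arrays internally):
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": [1, 2, 3], # doctest: +SKIP
... "rewards": [0.0, -1.0, 1.0]})
>>> len(batch) # doctest: +SKIP
3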
"""
# Outputs from interacting with the environment
OBS = "obs"
CUR_OBS = "obs"
NEXT_OBS = "new_obs"
ACTIONS = "actions"
REWARDS = "rewards"
PREV_ACTIONS = "prev_actions"
PREV_REWARDS = "prev_rewards"
DONES = "dones"
INFOS = "infos"
SEQ_LENS = "seq_lens"
# This is only computed and used when the RE3 exploration strategy is enabled.
OBS_EMBEDS = "obs_embeds"
T = "t"
# Decision Transformer columns.
RETURNS_TO_GO = "returns_to_go"
ATTENTION_MASKS = "attention_masks"
# Extra action fetches keys.
ACTION_DIST_INPUTS = "action_dist_inputs"
ACTION_PROB = "action_prob"
ACTION_LOGP = "action_logp"
# Uniquely identifies an episode.
EPS_ID = "eps_id"
# An env ID (e.g. the index for a vectorized sub-env).
ENV_ID = "env_id"
# Uniquely identifies a sample batch. This is important to distinguish RNN
# sequences from the same episode when multiple sample batches are
# concatenated (fusing sequences across batches can be unsafe).
UNROLL_ID = "unroll_id"
# Uniquely identifies an agent within an episode.
AGENT_INDEX = "agent_index"
# Value function predictions emitted by the behaviour policy.
VF_PREDS = "vf_preds"
@PublicAPI
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor).
Note: All *args and those **kwargs not listed below will be passed
as-is to the parent dict constructor.
Keyword Args:
_time_major (Optional[bool]): Whether data in this sample batch
is time-major. This is False by default and only relevant
if the data contains sequences.
_max_seq_len (Optional[int]): The max sequence chunk length
if the data contains sequences.
_zero_padded (Optional[bool]): Whether the data in this batch
contains sequences AND these sequences are right-zero-padded
according to the `_max_seq_len` setting.
_is_training (Optional[bool]): Whether this batch is used for
training. If False, batch may be used for e.g. action
computations (inference).
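Examples:
>>> # Illustrative sketch: 3 timesteps grouped into 2 sequences
>>> # (lengths 1 and 2); `count` is inferred from `seq_lens` and
>>> # `max_seq_len` from the longest sequence.
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3], # doctest: +SKIP
... "seq_lens": [1, 2]})
>>> batch.count # doctest: +SKIP
3
>>> batch.max_seq_len # doctest: +SKIP
2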
"""
# Possible seq_lens (TxB or BxT) setup.
self.time_major = kwargs.pop("_time_major", None)
# Maximum seq len value.
self.max_seq_len = kwargs.pop("_max_seq_len", None)
# Is already right-zero-padded?
self.zero_padded = kwargs.pop("_zero_padded", False)
# Whether this batch is used for training (vs inference).
self._is_training = kwargs.pop("_is_training", None)
# Call super constructor. This will make the actual data accessible
# by column name (str) via e.g. self["some-col"].
dict.__init__(self, *args, **kwargs)
self.accessed_keys = set()
self.added_keys = set()
self.deleted_keys = set()
self.intercepted_values = {}
self.get_interceptor = None
# Clear out None seq-lens.
seq_lens_ = self.get(SampleBatch.SEQ_LENS)
if seq_lens_ is None or (isinstance(seq_lens_, list) and len(seq_lens_) == 0):
self.pop(SampleBatch.SEQ_LENS, None)
# Numpyfy seq_lens if list.
elif isinstance(seq_lens_, list):
self[SampleBatch.SEQ_LENS] = seq_lens_ = np.array(seq_lens_, dtype=np.int32)
if (
self.max_seq_len is None
and seq_lens_ is not None
and not (tf and tf.is_tensor(seq_lens_))
and len(seq_lens_) > 0
):
self.max_seq_len = max(seq_lens_)
if self._is_training is None:
self._is_training = self.pop("is_training", False)
lengths = []
copy_ = {k: v for k, v in self.items() if k != SampleBatch.SEQ_LENS}
for k, v in copy_.items():
assert isinstance(k, str), self
# TODO: Drop support for lists as values.
# Convert lists of int|float into numpy arrays and make sure
# all data has the same length.
if isinstance(v, list):
self[k] = np.array(v)
# Try to infer the "length" of the SampleBatch by finding the first
# value that is actually a ndarray/tensor. This would fail if
# all values are nested dicts/tuples of more complex underlying
# structures.
try:
len_ = len(v) if not isinstance(v, (dict, tuple)) else None
if len_:
lengths.append(len_)
except Exception:
pass
if (
self.get(SampleBatch.SEQ_LENS) is not None
and not (tf and tf.is_tensor(self[SampleBatch.SEQ_LENS]))
and len(self[SampleBatch.SEQ_LENS]) > 0
):
self.count = sum(self[SampleBatch.SEQ_LENS])
else:
self.count = lengths[0] if lengths else 0
# A convenience map for slicing this batch into sub-batches along
# the time axis. This helps reduce repeated iterations through the
# batch's seq_lens array to find good slicing points. Built lazily
# when needed.
self._slice_map = []
@PublicAPI
def __len__(self) -> int:
"""Returns the amount of samples in the sample batch."""
return self.count
@PublicAPI
def agent_steps(self) -> int:
"""Returns the same as len(self) (number of steps in this batch).
To make this compatible with `MultiAgentBatch.agent_steps()`.
"""
return len(self)
@PublicAPI
def env_steps(self) -> int:
"""Returns the same as len(self) (number of steps in this batch).
To make this compatible with `MultiAgentBatch.env_steps()`.
"""
return len(self)
@staticmethod
@PublicAPI
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
def concat_samples(
samples: Union[List["SampleBatch"], List["MultiAgentBatch"]],
) -> Union["SampleBatch", "MultiAgentBatch"]:
return concat_samples(samples)
@PublicAPI
def concat(self, other: "SampleBatch") -> "SampleBatch":
"""Concatenates `other` to this one and returns a new SampleBatch.
Args:
other: The other SampleBatch object to concat to this one.
Returns:
The new SampleBatch, resulting from concating `other` to `self`.
Examples:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> b1 = SampleBatch({"a": np.array([1, 2])}) # doctest: +SKIP
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])}) # doctest: +SKIP
>>> print(b1.concat(b2)) # doctest: +SKIP
{"a": np.array([1, 2, 3, 4, 5])}
"""
return self.concat_samples([self, other])
@PublicAPI
def copy(self, shallow: bool = False) -> "SampleBatch":
"""Creates a deep or shallow copy of this SampleBatch and returns it.
Args:
shallow: Whether the copying should be done shallowly.
Returns:
A deep or shallow copy of this SampleBatch object.
"""
copy_ = {k: v for k, v in self.items()}
data = tree.map_structure(
lambda v: (
np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
),
copy_,
)
copy_ = SampleBatch(data)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_
@PublicAPI
def rows(self) -> Iterator[Dict[str, TensorType]]:
"""Returns an iterator over data rows, i.e. dicts with column values.
Note that if `seq_lens` is set in self, we set it to 1 in the rows.
Yields:
The column values of the row in this iteration.
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({ # doctest: +SKIP
... "a": [1, 2, 3],
... "b": [4, 5, 6],
... "seq_lens": [1, 2]
... })
>>> for row in batch.rows(): # doctest: +SKIP
... print(row) # doctest: +SKIP
{"a": 1, "b": 4, "seq_lens": 1}
{"a": 2, "b": 5, "seq_lens": 1}
{"a": 3, "b": 6, "seq_lens": 1}
"""
seq_lens = None if self.get(SampleBatch.SEQ_LENS, 1) is None else 1
self_as_dict = {k: v for k, v in self.items()}
for i in range(self.count):
yield tree.map_structure_with_path(
lambda p, v: v[i] if p[0] != self.SEQ_LENS else seq_lens,
self_as_dict,
)
@PublicAPI
def columns(self, keys: List[str]) -> List[Any]:
"""Returns a list of the batch-data in the specified columns.
Args:
keys: List of column names for which to return the data.
Returns:
The list of data items ordered by the order of column
names in `keys`.
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) # doctest: +SKIP
>>> print(batch.columns(["a", "b"])) # doctest: +SKIP
[[1], [2]]
"""
# TODO: (sven) Make this work for nested data as well.
out = []
for k in keys:
out.append(self[k])
return out
@PublicAPI
def shuffle(self) -> "SampleBatch":
"""Shuffles the rows of this batch in-place.
Returns:
This very (now shuffled) SampleBatch.
Raises:
ValueError: If self[SampleBatch.SEQ_LENS] is defined.
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4]}) # doctest: +SKIP
>>> print(batch.shuffle()) # doctest: +SKIP
{"a": [4, 1, 3, 2]}
"""
# Shuffling the data when we have `seq_lens` defined is probably
# a bad idea!
if self.get(SampleBatch.SEQ_LENS) is not None:
raise ValueError(
"SampleBatch.shuffle not possible when your data has "
"`seq_lens` defined!"
)
# Get a permutation over the single items once and use the same
# permutation for all the data (otherwise, data would become
# meaningless).
permutation = np.random.permutation(self.count)
self_as_dict = {k: v for k, v in self.items()}
shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
self.update(shuffled)
# Flush cache such that intercepted values are recalculated after the
# shuffling.
self.intercepted_values = {}
return self
@PublicAPI
def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
"""Splits by `eps_id` column and returns list of new batches.
If `eps_id` is not present, splits by `dones` instead.
Args:
key: If specified, overrides the default and uses this key to split.
Returns:
List of batches, one per distinct episode.
Raises:
KeyError: If neither the `eps_id` nor the `dones` column is present.
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> # "eps_id" is present
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode()) # doctest: +SKIP
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
>>>
>>> # "eps_id" not present, split by "dones" instead
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
>>> print(batch.split_by_episode()) # doctest: +SKIP
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
>>>
>>> # The last episode is appended even if it does not end with done
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
>>> print(batch.split_by_episode()) # doctest: +SKIP
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
>>> print(batch.split_by_episode()) # doctest: +SKIP
[{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
"""
def slice_by_eps_id():
slices = []
# Produce a new slice whenever we find a new episode ID.
cur_eps_id = self[SampleBatch.EPS_ID][0]
offset = 0
for i in range(self.count):
next_eps_id = self[SampleBatch.EPS_ID][i]
if next_eps_id != cur_eps_id:
slices.append(self[offset:i])
offset = i
cur_eps_id = next_eps_id
# Add final slice.
slices.append(self[offset : self.count])
return slices
def slice_by_dones():
slices = []
offset = 0
for i in range(self.count):
if self[SampleBatch.DONES][i]:
# Since self[i] is the last timestep of the episode,
# append it to the batch, then set offset to the start
# of the next batch
slices.append(self[offset : i + 1])
offset = i + 1
# Add final slice.
if offset != self.count:
slices.append(self[offset:])
return slices
key_to_method = {
SampleBatch.EPS_ID: slice_by_eps_id,
SampleBatch.DONES: slice_by_dones,
}
# If key not specified, default to this order.
key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]
slices = None
if key is not None:
# If key specified, directly use it.
if key not in self:
raise KeyError(f"{self} does not have key `{key}`!")
slices = key_to_method[key]()
else:
# If key not specified, go in order.
for key in key_resolve_order:
if key in self:
slices = key_to_method[key]()
break
if slices is None:
raise KeyError(f"{self} does not have keys {key_resolve_order}!")
assert sum(s.count for s in slices) == self.count, (
f"Calling split_by_episode on {self} returns {slices} "
f"which should both have {self.count} timesteps!"
)
return slices
def slice(
self, start: int, end: int, state_start=None, state_end=None
) -> "SampleBatch":
"""Returns a slice of the row data of this batch (w/o copying).
Args:
start: Starting index. If < 0, will left-zero-pad.
end: Ending index.
state_start: If given (together with `state_end`), slice the
`state_in_x` columns and `seq_lens` directly by this index
pair, instead of deriving the state slice from `start`/`end`.
state_end: See `state_start`.
Returns:
A new SampleBatch, which has a slice of this batch's data.
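Examples:
>>> # Illustrative sketch (batch without `seq_lens`):
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4]}) # doctest: +SKIP
>>> print(batch.slice(1, 3)) # doctest: +SKIP
{"a": [2, 3]}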
"""
if (
self.get(SampleBatch.SEQ_LENS) is not None
and len(self[SampleBatch.SEQ_LENS]) > 0
):
if start < 0:
data = {
k: np.concatenate(
[
np.zeros(shape=(-start,) + v.shape[1:], dtype=v.dtype),
v[0:end],
]
)
for k, v in self.items()
if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
}
else:
data = {
k: tree.map_structure(lambda s: s[start:end], v)
for k, v in self.items()
if k != SampleBatch.SEQ_LENS and not k.startswith("state_in_")
}
if state_start is not None:
assert state_end is not None
state_idx = 0
state_key = "state_in_{}".format(state_idx)
while state_key in self:
data[state_key] = self[state_key][state_start:state_end]
state_idx += 1
state_key = "state_in_{}".format(state_idx)
seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
# Adjust seq_lens if necessary.
data_len = len(data[next(iter(data))])
if sum(seq_lens) != data_len:
assert sum(seq_lens) > data_len
seq_lens[-1] = data_len - sum(seq_lens[:-1])
else:
# Fix state_in_x data.
count = 0
state_start = None
seq_lens = None
for i, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
count += seq_len
if count >= end:
state_idx = 0
state_key = "state_in_{}".format(state_idx)
if state_start is None:
state_start = i
while state_key in self:
data[state_key] = self[state_key][state_start : i + 1]
state_idx += 1
state_key = "state_in_{}".format(state_idx)
seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:i]) + [
seq_len - (count - end)
]
if start < 0:
seq_lens[0] += -start
diff = sum(seq_lens) - (end - start)
if diff > 0:
seq_lens[0] -= diff
assert sum(seq_lens) == (end - start)
break
elif state_start is None and count > start:
state_start = i
return SampleBatch(
data,
seq_lens=seq_lens,
_is_training=self.is_training,
_time_major=self.time_major,
)
else:
return SampleBatch(
tree.map_structure(lambda value: value[start:end], self),
_is_training=self.is_training,
_time_major=self.time_major,
)
@PublicAPI
def timeslices(
self,
size: Optional[int] = None,
num_slices: Optional[int] = None,
k: Optional[int] = None,
) -> List["SampleBatch"]:
"""Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
Args:
size: The size (in timesteps) of each returned SampleBatch.
num_slices: The number of slices to produce.
k: Deprecated: Use size or num_slices instead. The size
(in timesteps) of each returned SampleBatch.
Returns:
The list of `num_slices` (new) SampleBatches or n (new)
SampleBatches each one of size `size`.
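Examples:
>>> # Illustrative sketch: 5 timesteps cut into slices of size 2;
>>> # the last slice may be shorter.
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4, 5]}) # doctest: +SKIP
>>> print(batch.timeslices(size=2)) # doctest: +SKIP
[{"a": [1, 2]}, {"a": [3, 4]}, {"a": [5]}]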
"""
if size is None and num_slices is None:
deprecation_warning("k", "size or num_slices")
assert k is not None
size = k
if size is None:
assert isinstance(num_slices, int)
slices = []
left = len(self)
start = 0
while left:
len_ = left // (num_slices - len(slices))
stop = start + len_
slices.append(self[start:stop])
left -= len_
start = stop
return slices
else:
assert isinstance(size, int)
slices = []
left = len(self)
start = 0
# Guard with `> 0`: `size` may not evenly divide `len(self)`,
# in which case `left` ends up negative on the last iteration.
while left > 0:
stop = start + size
slices.append(self[start:stop])
left -= size
start = stop
return slices
@Deprecated(new="SampleBatch.right_zero_pad", error=False)
def zero_pad(self, max_seq_len, exclude_states=True):
return self.right_zero_pad(max_seq_len, exclude_states)
def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
"""Right (adding zeros at end) zero-pads this SampleBatch in-place.
This will set the `self.zero_padded` flag to True and
`self.max_seq_len` to the given `max_seq_len` value.
Args:
max_seq_len: The max (total) length to zero pad to.
exclude_states: If False, also right-zero-pad all
`state_in_x` data. If True, leave `state_in_x` keys
as-is.
Returns:
This very (now right-zero-padded) SampleBatch.
Raises:
ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).
Examples:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4)) # doctest: +SKIP
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3], # doctest: +SKIP
... "state_in_0": [1.0, 3.0],
... "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5)) # doctest: +SKIP
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
"state_in_0": [1.0, 3.0], # <- all state-ins remain as-is
"seq_lens": [1, 2]}
"""
seq_lens = self.get(SampleBatch.SEQ_LENS)
if seq_lens is None:
raise ValueError(
"Cannot right-zero-pad SampleBatch if no `seq_lens` field "
f"present! SampleBatch={self}"
)
length = len(seq_lens) * max_seq_len
def _zero_pad_in_place(path, value):
# Skip "state_in_..." columns and "seq_lens".
if (
exclude_states is True and path[0].startswith("state_in_")
) or path[0] == SampleBatch.SEQ_LENS:
return
# Generate zero-filled primer of len=max_seq_len.
if value.dtype == object or value.dtype.type is np.str_:
f_pad = [None] * length
else:
# Make sure type doesn't change.
f_pad = np.zeros((length,) + np.shape(value)[1:], dtype=value.dtype)
# Fill primer with data.
f_pad_base = f_base = 0
for len_ in self[SampleBatch.SEQ_LENS]:
f_pad[f_pad_base : f_pad_base + len_] = value[f_base : f_base + len_]
f_pad_base += max_seq_len
f_base += len_
assert f_base == len(value), value
# Update our data in-place.
curr = self
for i, p in enumerate(path):
if i == len(path) - 1:
curr[p] = f_pad
curr = curr[p]
self_as_dict = {k: v for k, v in self.items()}
tree.map_structure_with_path(_zero_pad_in_place, self_as_dict)
# Set flags to indicate that we are now zero-padded (and to what
# extent).
self.zero_padded = True
self.max_seq_len = max_seq_len
return self
@ExperimentalAPI
def to_device(self, device, framework="torch"):
"""TODO: transfer batch to given device as framework tensor."""
if framework == "torch":
assert torch is not None
for k, v in self.items():
self[k] = convert_to_torch_tensor(v, device)
else:
raise NotImplementedError
return self
@PublicAPI
def size_bytes(self) -> int:
"""Returns sum over number of bytes of all data buffers.
For numpy arrays, we use `.nbytes`. For all other value types, we use
sys.getsizeof(...).
Returns:
The overall size in bytes of the data buffer (all columns).
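Examples:
>>> # Illustrative sketch: 10 float32 values -> 40 bytes.
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch( # doctest: +SKIP
... {"a": np.zeros(10, dtype=np.float32)})
>>> batch.size_bytes() # doctest: +SKIP
40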
"""
return sum(
v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
for v in tree.flatten(self)
)
def get(self, key, default=None):
try:
return self.__getitem__(key)
except KeyError:
return default
@PublicAPI
def as_multi_agent(self) -> "MultiAgentBatch":
"""Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.
Returns:
The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding
to this SampleBatch.
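Examples:
>>> # Illustrative sketch: the batch is wrapped under
>>> # DEFAULT_POLICY_ID ("default_policy").
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3]}) # doctest: +SKIP
>>> batch.as_multi_agent() # doctest: +SKIP
MultiAgentBatch({'default_policy': SampleBatch(3: ['a'])}, env_steps=3)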
"""
return MultiAgentBatch({DEFAULT_POLICY_ID: self}, self.count)
@PublicAPI
def __getitem__(self, key: Union[str, slice]) -> TensorType:
"""Returns one column (by key) from the data or a sliced new batch.
Args:
key: The key (column name) to return or
a slice object for slicing this SampleBatch.
Returns:
The data under the given key or a sliced version of this batch.
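Examples:
>>> # Illustrative sketch: key access vs. slice access.
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3]}) # doctest: +SKIP
>>> print(batch["a"]) # doctest: +SKIP
[1 2 3]
>>> print(batch[:2]) # doctest: +SKIP
{"a": [1, 2]}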
"""
if isinstance(key, slice):
return self._slice(key)
# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
deprecation_warning(
old="SampleBatch['is_training']",
new="SampleBatch.is_training",
error=False,
)
return self.is_training
if not hasattr(self, key) and key in self:
self.accessed_keys.add(key)
value = dict.__getitem__(self, key)
if self.get_interceptor is not None:
if key not in self.intercepted_values:
self.intercepted_values[key] = self.get_interceptor(value)
value = self.intercepted_values[key]
return value
@PublicAPI
def __setitem__(self, key, item) -> None:
"""Inserts (overrides) an entire column (by key) in the data buffer.
Args:
key: The column name to set a value for.
item: The data to insert.
"""
# Defend against SampleBatch creation via pickle: during
# unpickling, items are set before `added_keys` exists.
if not hasattr(self, "added_keys"):
dict.__setitem__(self, key, item)
return
# Backward compatibility for when "input-dicts" were used.
if key == "is_training":
if log_once("SampleBatch['is_training']"):
deprecation_warning(
old="SampleBatch['is_training']",
new="SampleBatch.is_training",
error=False,
)
self._is_training = item
return
if key not in self:
self.added_keys.add(key)
dict.__setitem__(self, key, item)
if key in self.intercepted_values:
self.intercepted_values[key] = item
@property
def is_training(self):
if self.get_interceptor is not None and isinstance(self._is_training, bool):
if "_is_training" not in self.intercepted_values:
self.intercepted_values["_is_training"] = self.get_interceptor(
self._is_training
)
return self.intercepted_values["_is_training"]
return self._is_training
def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
self._is_training = training
self.intercepted_values.pop("_is_training", None)
@PublicAPI
def __delitem__(self, key):
self.deleted_keys.add(key)
dict.__delitem__(self, key)
@DeveloperAPI
def compress(
self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
) -> "SampleBatch":
"""Compresses the data buffers (by column) in place.
Args:
bulk: Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns: The columns to compress. Default: Only
compress the obs and new_obs columns.
Returns:
This very (now compressed) SampleBatch.
"""
def _compress_in_place(path, value):
if path[0] not in columns:
return
curr = self
for i, p in enumerate(path):
if i == len(path) - 1:
if bulk:
curr[p] = pack(value)
else:
curr[p] = np.array([pack(o) for o in value])
curr = curr[p]
tree.map_structure_with_path(_compress_in_place, self)
return self
@DeveloperAPI
def decompress_if_needed(
self, columns: Set[str] = frozenset(["obs", "new_obs"])
) -> "SampleBatch":
"""Decompresses data buffers (per column if not compressed) in place.
Args:
columns: The columns to decompress. Default: Only
decompress the obs and new_obs columns.
Returns:
This very (now uncompressed) SampleBatch.
"""
def _decompress_in_place(path, value):
if path[0] not in columns:
return
curr = self
for p in path[:-1]:
curr = curr[p]
# Bulk compressed.
if is_compressed(value):
curr[path[-1]] = unpack(value)
# Non bulk compressed.
elif len(value) > 0 and is_compressed(value[0]):
curr[path[-1]] = np.array([unpack(o) for o in value])
tree.map_structure_with_path(_decompress_in_place, self)
return self
@DeveloperAPI
def set_get_interceptor(self, fn):
# If get-interceptor changes, must erase old intercepted values.
if fn is not self.get_interceptor:
self.intercepted_values = {}
self.get_interceptor = fn
def __repr__(self):
keys = list(self.keys())
if self.get(SampleBatch.SEQ_LENS) is None:
return f"SampleBatch({self.count}: {keys})"
else:
keys.remove(SampleBatch.SEQ_LENS)
return (
f"SampleBatch({self.count} (seqs={len(self['seq_lens'])}): {keys})"
)
def _slice(self, slice_: slice) -> "SampleBatch":
"""Helper method to handle SampleBatch slicing using a slice object.
The returned SampleBatch uses the same underlying data object as
`self`, so changing the slice will also change `self`.
Note that only zero or positive bounds are allowed for both start
and stop values. The slice step must be 1 (or None, which is the
same).
Args:
slice_: The python slice object to slice by.
Returns:
A new SampleBatch, however "linking" into the same data
(sliced) as self.
"""
start = slice_.start or 0
stop = slice_.stop or len(self)
# If stop goes beyond the length of this batch -> Make it go till the
# end only (including last item).
# Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
if stop > len(self):
stop = len(self)
assert start >= 0 and stop >= 0 and slice_.step in [1, None]
if (
self.get(SampleBatch.SEQ_LENS) is not None
and len(self[SampleBatch.SEQ_LENS]) > 0
):
# Build our slice-map, if not done already.
if not self._slice_map:
sum_ = 0
for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
self._slice_map.extend([(i, sum_)] * l)
sum_ = sum_ + l
# In case `stop` points to the very end (lengths of this
# batch), return the last sequence (the -1 here makes sure we
# never go beyond it; would result in an index error below).
self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))
start_seq_len, start_unpadded = self._slice_map[start]
stop_seq_len, stop_unpadded = self._slice_map[stop]
start_padded = start_unpadded
stop_padded = stop_unpadded
if self.zero_padded:
start_padded = start_seq_len * self.max_seq_len
stop_padded = stop_seq_len * self.max_seq_len
def map_(path, value):
if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
"state_in_"
):
if path[0] != SampleBatch.INFOS:
return value[start_padded:stop_padded]
else:
return value[start_unpadded:stop_unpadded]
else:
return value[start_seq_len:stop_seq_len]
data = tree.map_structure_with_path(map_, self)
return SampleBatch(
data,
_is_training=self.is_training,
_time_major=self.time_major,
_zero_padded=self.zero_padded,
_max_seq_len=self.max_seq_len if self.zero_padded else None,
)
else:
data = tree.map_structure(lambda value: value[start:stop], self)
return SampleBatch(
data,
_is_training=self.is_training,
_time_major=self.time_major,
)
@Deprecated(error=False)
def _get_slice_indices(self, slice_size):
data_slices = []
data_slices_states = []
if (
self.get(SampleBatch.SEQ_LENS) is not None
and len(self[SampleBatch.SEQ_LENS]) > 0
):
assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), (
"ERROR: `slice_size` must be larger than the max. seq-len "
"in the batch!"
)
start_pos = 0
current_slice_size = 0
actual_slice_idx = 0
start_idx = 0
idx = 0
while idx < len(self[SampleBatch.SEQ_LENS]):
seq_len = self[SampleBatch.SEQ_LENS][idx]
current_slice_size += seq_len
actual_slice_idx += (
seq_len if not self.zero_padded else self.max_seq_len
)
# Complete minibatch -> Append to data_slices.
if current_slice_size >= slice_size:
end_idx = idx + 1
# We are not zero-padded yet; all sequences are
# back-to-back.
if not self.zero_padded:
data_slices.append((start_pos, start_pos + slice_size))
start_pos += slice_size
if current_slice_size > slice_size:
overhead = current_slice_size - slice_size
start_pos -= seq_len - overhead
idx -= 1
# We are already zero-padded: Cut in chunks of max_seq_len.
else:
data_slices.append((start_pos, actual_slice_idx))
start_pos = actual_slice_idx
data_slices_states.append((start_idx, end_idx))
current_slice_size = 0
start_idx = idx + 1
idx += 1
else:
i = 0
while i < self.count:
data_slices.append((i, i + slice_size))
i += slice_size
return data_slices, data_slices_states
@ExperimentalAPI
def get_single_step_input_dict(
self,
view_requirements: ViewRequirementsDict,
index: Union[str, int] = "last",
) -> "SampleBatch":
"""Creates single ts SampleBatch at given index from `self`.
For usage as input-dict for model (action or value function) calls.
Args:
view_requirements: A view requirements dict from the model for
which to produce the input_dict.
index: An integer index value indicating the
position in the trajectory for which to generate the
compute_actions input dict. Set to "last" to generate the dict
at the very end of the trajectory (e.g. for value estimation).
Note that "last" is different from -1, as "last" will use the
final NEXT_OBS as observation input.
Returns:
The (single-timestep) input dict for ModelV2 calls.
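Examples:
>>> # Illustrative sketch; assumes default `ViewRequirement`
>>> # settings (used_for_compute_actions=True, no shift range).
>>> # With index="last", OBS is taken from the final NEXT_OBS.
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> from ray.rllib.policy.view_requirement import ViewRequirement
>>> batch = SampleBatch({"obs": [1, 2, 3], # doctest: +SKIP
... "new_obs": [2, 3, 4]})
>>> batch.get_single_step_input_dict( # doctest: +SKIP
... {"obs": ViewRequirement("obs")}, index="last")
{"obs": [4], "seq_lens": [1]}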
"""
last_mappings = {
SampleBatch.OBS: SampleBatch.NEXT_OBS,
SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
}
input_dict = {}
for view_col, view_req in view_requirements.items():
if view_req.used_for_compute_actions is False:
continue
# Create batches of size 1 (single-agent input-dict).
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
# Range needed.
if view_req.shift_from is not None:
# Batch repeat value > 1: We have single frames in the
# batch at each timestep (for the `data_col`).
data = self[view_col][-1]
traj_len = len(self[data_col])
missing_at_end = traj_len % view_req.batch_repeat_value
# Index into the observations column must be shifted by
# -1 b/c index=0 for observations means the current (last
# seen) observation (after having taken an action).
obs_shift = (
-1 if data_col in [SampleBatch.OBS, SampleBatch.NEXT_OBS] else 0
)
from_ = view_req.shift_from + obs_shift
to_ = view_req.shift_to + obs_shift + 1
if to_ == 0:
to_ = None
input_dict[view_col] = np.array(
[
np.concatenate([data, self[data_col][-missing_at_end:]])[
from_:to_
]
]
)
# Single index.
else:
input_dict[view_col] = tree.map_structure(
lambda v: v[-1:], # keep as array (w/ 1 element)
self[data_col],
)
# Single index somewhere inside the trajectory (non-last).
else:
input_dict[view_col] = self[data_col][
index : index + 1 if index != -1 else None
]
return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
@PublicAPI
class MultiAgentBatch:
"""A batch of experiences from multiple agents in the environment.
Attributes:
policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
ids to SampleBatches of experiences.
count: The number of env steps in this batch.
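Examples:
>>> # Illustrative sketch: two policies stepping together for
>>> # 2 env steps -> 4 agent steps in total.
>>> from ray.rllib.policy.sample_batch import (
... MultiAgentBatch, SampleBatch)
>>> ma_batch = MultiAgentBatch( # doctest: +SKIP
... {"pol0": SampleBatch({"a": [1, 2]}),
... "pol1": SampleBatch({"a": [3, 4]})},
... env_steps=2)
>>> (ma_batch.env_steps(), ma_batch.agent_steps()) # doctest: +SKIP
(2, 4)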
"""
@PublicAPI
def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
"""Initialize a MultiAgentBatch instance.
Args:
policy_batches: Mapping from policy
ids to SampleBatches of experiences.
env_steps: The number of environment steps this batch
contains. This will be at most the number of transitions
this batch contains across all policies in total.
"""
for v in policy_batches.values():
assert isinstance(v, SampleBatch)
self.policy_batches = policy_batches
# Called "count" for uniformity with SampleBatch.
# Prefer to access this via the `env_steps()` method when possible
# for clarity.
self.count = env_steps
@PublicAPI
def env_steps(self) -> int:
"""The number of env steps (there are >= 1 agent steps per env step).
Returns:
The number of environment steps contained in this batch.
"""
return self.count
@PublicAPI
def __len__(self) -> int:
"""Same as `self.env_steps()`."""
return self.count
@PublicAPI
def agent_steps(self) -> int:
"""The number of agent steps (there are >= 1 agent steps per env step).
Returns:
The number of agent steps total in this batch.
"""
ct = 0
for batch in self.policy_batches.values():
ct += batch.count
return ct
@PublicAPI
def timeslices(self, k: int) -> List["MultiAgentBatch"]:
"""Returns k-step batches holding data for each agent at those steps.
For example, suppose we have agent1 observations [a1t1, a1t2, a1t3],
for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.
Calling timeslices(1) would return three MultiAgentBatches containing
[a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].
Calling timeslices(2) would return two MultiAgentBatches containing
[a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].
This method is used to implement "lockstep" replay mode. Note that this
method does not guarantee each batch contains only data from a single
unroll. Batches might contain data from multiple different envs.
"""
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
# Build a sorted set of (eps_id, t, policy_id, data...)
steps = []
for policy_id, batch in self.policy_batches.items():
for row in batch.rows():
steps.append(
(
row[SampleBatch.EPS_ID],
row[SampleBatch.T],
row[SampleBatch.AGENT_INDEX],
policy_id,
row,
)
)
steps.sort()
finished_slices = []
cur_slice = collections.defaultdict(SampleBatchBuilder)
cur_slice_size = 0
def finish_slice():
nonlocal cur_slice_size
assert cur_slice_size > 0
batch = MultiAgentBatch(
{k: v.build_and_reset() for k, v in cur_slice.items()}, cur_slice_size
)
cur_slice_size = 0
cur_slice.clear()
finished_slices.append(batch)
# For each unique env timestep.
for _, group in itertools.groupby(steps, lambda x: x[:2]):
# Accumulate into the current slice.
for _, _, _, policy_id, row in group:
cur_slice[policy_id].add_values(**row)
cur_slice_size += 1
# Slice has reached target number of env steps.
if cur_slice_size >= k:
finish_slice()
assert cur_slice_size == 0
if cur_slice_size > 0:
finish_slice()
assert len(finished_slices) > 0, finished_slices
return finished_slices
@staticmethod
@PublicAPI
def wrap_as_needed(
policy_batches: Dict[PolicyID, SampleBatch], env_steps: int
) -> Union[SampleBatch, "MultiAgentBatch"]:
"""Returns SampleBatch or MultiAgentBatch, depending on given policies.
If `policy_batches` is empty (i.e. {}), it returns an empty MultiAgentBatch.
Args:
policy_batches: Mapping from policy ids to SampleBatch.
env_steps: Number of env steps in the batch.
Returns:
The single default policy's SampleBatch or a MultiAgentBatch
(more than one policy).
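Examples:
>>> # Illustrative sketch: a lone DEFAULT_POLICY_ID entry is
>>> # unwrapped back into the plain SampleBatch itself.
>>> from ray.rllib.policy.sample_batch import (
... DEFAULT_POLICY_ID, MultiAgentBatch, SampleBatch)
>>> batch = SampleBatch({"a": [1, 2]}) # doctest: +SKIP
>>> out = MultiAgentBatch.wrap_as_needed( # doctest: +SKIP
... {DEFAULT_POLICY_ID: batch}, env_steps=2)
>>> out is batch # doctest: +SKIP
True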
"""
if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
return policy_batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps)
@staticmethod
@PublicAPI
@Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=False)
def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
return concat_samples_into_ma_batch(samples)
@PublicAPI
def copy(self) -> "MultiAgentBatch":
"""Deep-copies self into a new MultiAgentBatch.
Returns:
The copy of self with deep-copied data.
"""
return MultiAgentBatch(
{k: v.copy() for (k, v) in self.policy_batches.items()}, self.count
)
@PublicAPI
def size_bytes(self) -> int:
"""
Returns:
The overall size in bytes of all policy batches (all columns).
"""
return sum(b.size_bytes() for b in self.policy_batches.values())
@DeveloperAPI
def compress(
self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
) -> None:
"""Compresses each policy batch (per column) in place.
Args:
bulk: Whether to compress across the batch dimension (0)
as well. If False will compress n separate list items, where n
is the batch size.
columns: Set of column names to compress.
"""
for batch in self.policy_batches.values():
batch.compress(bulk=bulk, columns=columns)
@DeveloperAPI
def decompress_if_needed(
self, columns: Set[str] = frozenset(["obs", "new_obs"])
) -> "MultiAgentBatch":
"""Decompresses each policy batch (per column), if already compressed.
Args:
columns: Set of column names to decompress.
Returns:
Self.
"""
for batch in self.policy_batches.values():
batch.decompress_if_needed(columns)
return self
@DeveloperAPI
def as_multi_agent(self) -> "MultiAgentBatch":
"""Simply returns `self` (already a MultiAgentBatch).
Returns:
This very instance of MultiAgentBatch.
"""
return self
def __getitem__(self, key: str) -> SampleBatch:
"""Returns the SampleBatch for the given policy id."""
return self.policy_batches[key]
def __str__(self):
return "MultiAgentBatch({}, env_steps={})".format(
str(self.policy_batches), self.count
)
def __repr__(self):
return "MultiAgentBatch({}, env_steps={})".format(
str(self.policy_batches), self.count
)
@PublicAPI
def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
"""Concatenates a list of SampleBatches or MultiAgentBatches.
If all items in the list are of SampleBatch type, the output will be
a SampleBatch as well. Otherwise, the output will be a MultiAgentBatch.
If the input is a mixture of SampleBatch and MultiAgentBatch types,
SampleBatch objects are treated as MultiAgentBatches with a
'default_policy' key and concatenated with the rest of the
MultiAgentBatch objects.
Empty samples are simply ignored.
Args:
samples: List of SampleBatches or MultiAgentBatches to be
concatenated.
Returns:
A new (concatenated) SampleBatch or MultiAgentBatch.
Examples:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
>>> b1 = SampleBatch({"a": np.array([1, 2]), # doctest: +SKIP
... "b": np.array([10, 11])})
>>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
... "b": np.array([12])})
>>> print(concat_samples([b1, b2])) # doctest: +SKIP
{"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
>>> c1 = MultiAgentBatch({'default_policy': { # doctest: +SKIP
... "a": np.array([1, 2]),
... "b": np.array([10, 11])
... }}, env_steps=2)
>>> c2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
... "b": np.array([12])})
>>> print(concat_samples([c1, c2])) # doctest: +SKIP
MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
"b": np.array([10, 11, 12])}}
"""
if any(isinstance(s, MultiAgentBatch) for s in samples):
return concat_samples_into_ma_batch(samples)
# Otherwise, the output will be a plain SampleBatch.
concatd_seq_lens = []
concated_samples = []
# Make sure these settings are consistent amongst all batches.
zero_padded = max_seq_len = time_major = None
for s in samples:
if s.count > 0:
if max_seq_len is None:
zero_padded = s.zero_padded
max_seq_len = s.max_seq_len
time_major = s.time_major
# Make sure these settings are consistent amongst all batches.
if s.zero_padded != zero_padded or s.time_major != time_major:
raise ValueError(
"All SampleBatches' `zero_padded` and `time_major` settings "
"must be consistent!"
)
if (
s.max_seq_len is None or max_seq_len is None
) and s.max_seq_len != max_seq_len:
raise ValueError(
"Samples must consistently either provide or omit " "`max_seq_len`!"
)
elif zero_padded and s.max_seq_len != max_seq_len:
raise ValueError(
"For `zero_padded` SampleBatches, the values of `max_seq_len` "
"must be consistent!"
)
if max_seq_len is not None:
max_seq_len = max(max_seq_len, s.max_seq_len)
concated_samples.append(s)
if s.get(SampleBatch.SEQ_LENS) is not None:
concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])
# If we don't have any samples (0 or only empty SampleBatches),
# return an empty SampleBatch here.
if len(concated_samples) == 0:
return SampleBatch()
# Collect the concat'd data.
concatd_data = {}
for k in concated_samples[0].keys():
try:
if k == "infos":
concatd_data[k] = concat_aligned(
[s[k] for s in concated_samples], time_major=time_major
)
else:
concatd_data[k] = tree.map_structure(
_concat_key, *[c[k] for c in concated_samples]
)
except Exception:
raise ValueError(
f"Cannot concat data under key '{k}', b/c "
"sub-structures under that key don't match. "
f"`samples`={samples}"
)
# Return a new (concat'd) SampleBatch.
return SampleBatch(
concatd_data,
seq_lens=concatd_seq_lens,
_time_major=time_major,
_zero_padded=zero_padded,
_max_seq_len=max_seq_len,
)
@PublicAPI
def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch":
"""Concatenates a list of SampleBatchTypes to a single MultiAgentBatch type.
This function, as opposed to concat_samples(), forces the output to
always be a MultiAgentBatch, which is more generic than a SampleBatch.
Args:
samples: List of SampleBatches or MultiAgentBatches to be
concatenated.
Returns:
A new (concatenated) MultiAgentBatch.
Examples:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
>>> b1 = MultiAgentBatch({'default_policy': { # doctest: +SKIP
... "a": np.array([1, 2]),
... "b": np.array([10, 11])
... }}, env_steps=2)
>>> b2 = SampleBatch({"a": np.array([3]), # doctest: +SKIP
... "b": np.array([12])})
>>> print(concat_samples_into_ma_batch([b1, b2])) # doctest: +SKIP
MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
"b": np.array([10, 11, 12])}}
"""
policy_batches = collections.defaultdict(list)
env_steps = 0
for s in samples:
# Some batches in `samples` may be SampleBatch.
if isinstance(s, SampleBatch):
# If empty SampleBatch: ok (just ignore).
if len(s) <= 0:
continue
else:
# if non-empty: just convert to MA-batch and move forward
s = s.as_multi_agent()
elif not isinstance(s, MultiAgentBatch):
# Otherwise: Error.
raise ValueError(
"`concat_samples_into_ma_batch` can only concat "
"SampleBatch|MultiAgentBatch objects, not {}!".format(type(s).__name__)
)
for key, batch in s.policy_batches.items():
policy_batches[key].append(batch)
env_steps += s.env_steps()
out = {}
for key, batches in policy_batches.items():
out[key] = concat_samples(batches)
return MultiAgentBatch(out, env_steps)
def _concat_key(*values, time_major=None):
return concat_aligned(list(values), time_major)