import collections
import numpy as np
import sys
import itertools
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned
from ray.rllib.utils.typing import TensorType

# Policy id that is used when an environment only has a single agent (and
# therefore only a single policy).
DEFAULT_POLICY_ID: str = "default_policy"

# Type alias for policy identifiers (any hashable value works as a dict key).
# TODO(ekl) reuse the other id def once we fix imports
PolicyID = Any


class SampleBatch:
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.
    """

    # Outputs from interacting with the environment
    OBS = "obs"
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"

    # Extra action fetches keys.
    ACTION_DIST_INPUTS = "action_dist_inputs"
    ACTION_PROB = "action_prob"
    ACTION_LOGP = "action_logp"

    # Uniquely identifies an episode.
    EPS_ID = "eps_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode.
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy.
    VF_PREDS = "vf_preds"

    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Two private keyword arguments are popped from `kwargs` before the
        remaining arguments are handed to `dict()`:

        Args:
            _seq_lens (Optional[Sequence[int]]): Sequence lengths. If given
                and non-empty, `count` is `sum(_seq_lens)` and `max_seq_len`
                is `max(_seq_lens)`; otherwise `count` is the column length.
            _time_major (Optional[bool]): Layout flag for sequence data
                (TxB vs BxT); used by `slice()` and `concat_samples()`.

        Raises:
            ValueError: If the batch has no columns at all.
        """
        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)
        self.seq_lens = kwargs.pop("_seq_lens", None)
        self.max_seq_len = None
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            self.max_seq_len = max(self.seq_lens)

        # The actual data, accessible by column name (str).
        self.data = dict(*args, **kwargs)

        lengths = []
        # Iterate over a copy because list-valued columns are replaced below.
        for k, v in self.data.copy().items():
            assert isinstance(k, str), self
            lengths.append(len(v))
            # Plain lists are converted to numpy arrays; other array-likes
            # are stored as-is.
            if isinstance(v, list):
                self.data[k] = np.array(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, \
            "Data columns must be same length, but lens are {}".format(lengths)
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            self.count = sum(self.seq_lens)
        else:
            # Length of the last column (all lengths are asserted equal).
            self.count = lengths[-1]

        # Keeps track of new columns added after initial ones.
        self.new_columns = []

    @staticmethod
    def concat_samples(samples: List["SampleBatch"]) -> \
            Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        Batches with `count == 0` are skipped. The `seq_lens` of all remaining
        batches that have them are concatenated, and the `time_major` flag of
        the first remaining batch is used for the result.

        Args:
            samples (List[Union[SampleBatch, MultiAgentBatch]]): The batches
                to concatenate. If the first one is a MultiAgentBatch, the
                whole list is handed to `MultiAgentBatch.concat_samples()`.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new SampleBatch or
                MultiAgentBatch.
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        seq_lens = []
        non_empty = []
        for s in samples:
            if s.count > 0:
                non_empty.append(s)
                if s.seq_lens is not None:
                    seq_lens.extend(s.seq_lens)

        time_major = non_empty[0].time_major
        out = {}
        for k in non_empty[0].keys():
            out[k] = concat_aligned(
                [s[k] for s in non_empty], time_major=time_major)
        return SampleBatch(out, _seq_lens=seq_lens, _time_major=time_major)

    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Returns a new SampleBatch with each data column concatenated.

        Only the data columns are concatenated; the result does not carry
        over `seq_lens` or `time_major` from either input.

        Args:
            other (SampleBatch): The other SampleBatch object to concat to this
                one.

        Returns:
            SampleBatch: The new SampleBatch, resulting from concating `other`
                to `self`.

        Raises:
            ValueError: If the two batches do not have the same columns.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            SampleBatch({'a': array([1, 2, 3, 4, 5])})
        """
        if self.keys() != other.keys():
            raise ValueError(
                "SampleBatches to concat must have same columns! {} vs {}".
                format(list(self.keys()), list(other.keys())))
        out = {}
        for k in self.keys():
            out[k] = concat_aligned([self[k], other[k]])
        return SampleBatch(out)

    def copy(self) -> "SampleBatch":
        """Creates a (deep) copy of this SampleBatch and returns it.

        Besides the data columns, `seq_lens` and `time_major` are carried
        over, so the copy computes the same `count` as the original (which
        depends on `seq_lens`) and `slice()` treats it the same way.

        Returns:
            SampleBatch: A (deep) copy of this SampleBatch object.
        """
        seq_lens = None if self.seq_lens is None else \
            np.array(self.seq_lens, copy=True)
        return SampleBatch(
            {k: np.array(v, copy=True)
             for (k, v) in self.data.items()},
            _seq_lens=seq_lens,
            _time_major=self.time_major)

    def rows(self) -> Iterable[Dict[str, TensorType]]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Yields:
            Dict[str, TensorType]: The column values of the row in this
                iteration, keyed by column name.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> for row in batch.rows():
            ...     print(row)
            {"a": 1, "b": 4}
            {"a": 2, "b": 5}
            {"a": 3, "b": 6}
        """
        for i in range(self.count):
            yield {k: self[k][i] for k in self.keys()}

    def columns(self, keys: List[str]) -> List[Any]:
        """Returns a list of the batch-data in the specified columns.

        Args:
            keys (List[str]): List of column names for which to return the
                data.

        Returns:
            List[Any]: The list of data items ordered by the order of column
                names in `keys`.

        Raises:
            KeyError: If one of `keys` is not a column of this batch.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> print(batch.columns(["a", "b"]))
            [[1], [2]]
        """
        return [self[k] for k in keys]

    def shuffle(self) -> None:
        """Shuffles the rows of this batch in-place."""
        permutation = np.random.permutation(self.count)
        # Overwriting existing keys does not resize the dict, so replacing
        # values while iterating `items()` is safe.
        for key, val in self.items():
            self[key] = val[permutation]

    def split_by_episode(self) -> List["SampleBatch"]:
        """Splits this batch's data by `eps_id`.

        Consecutive rows sharing the same `eps_id` form one (non-copying)
        slice. If the rows of an episode are not contiguous, that episode
        yields more than one slice.

        Returns:
            List[SampleBatch]: List of batches, one per contiguous run of
                equal episode ids.
        """
        slices = []
        cur_eps_id = self.data[SampleBatch.EPS_ID][0]
        offset = 0
        for i in range(self.count):
            next_eps_id = self.data[SampleBatch.EPS_ID][i]
            if next_eps_id != cur_eps_id:
                slices.append(self.slice(offset, i))
                offset = i
                cur_eps_id = next_eps_id
        slices.append(self.slice(offset, self.count))
        for s in slices:
            slen = len(set(s[SampleBatch.EPS_ID]))
            assert slen == 1, (s, slen)
        assert sum(s.count for s in slices) == self.count, (slices, self.count)
        return slices

    def slice(self, start: int, end: int) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start (int): Starting index.
            end (int): Ending index.

        Returns:
            SampleBatch: A new SampleBatch, which has a slice of this batch's
                data.
        """
        # NOTE(review): any non-None `time_major` (including False) slices
        # along axis 1 and slices `seq_lens` with the same indices; confirm
        # this is intended for batch-major (BxT) data.
        if self.time_major is not None:
            return SampleBatch(
                {k: v[:, start:end]
                 for k, v in self.data.items()},
                _seq_lens=self.seq_lens[start:end],
                _time_major=self.time_major)
        else:
            return SampleBatch(
                {k: v[start:end]
                 for k, v in self.data.items()},
                _seq_lens=None,
                _time_major=self.time_major)

    def timeslices(self, k: int) -> List["SampleBatch"]:
        """Returns SampleBatches, each one representing a k-slice of this one.

        Will start from timestep 0 and produce slices of size=k (the last
        slice may be shorter).

        Args:
            k (int): The size (in timesteps) of each returned SampleBatch.
                Must be >= 1.

        Returns:
            List[SampleBatch]: The list of (new) SampleBatches (each one of
                size k).

        Raises:
            ValueError: If `k` is smaller than 1 (which would otherwise never
                advance past timestep 0).
        """
        if k < 1:
            raise ValueError(
                "Timeslice size `k` must be >= 1, but is {}!".format(k))
        return [self.slice(i, i + k) for i in range(0, self.count, k)]

    def keys(self) -> Iterable[str]:
        """Returns the column names of this batch.

        Returns:
            Iterable[str]: The keys() iterable over `self.data`.
        """
        return self.data.keys()

    def items(self) -> Iterable[TensorType]:
        """Returns (column name, data) pairs of this batch.

        Returns:
            Iterable[TensorType]: The items() iterable over `self.data`,
                yielding (column name, data) tuples.
        """
        return self.data.items()

    def get(self, key: str) -> Optional[TensorType]:
        """Returns one column (by key) from the data or None if key not found.

        Args:
            key (str): The key (column name) to return.

        Returns:
            Optional[TensorType]: The data under the given key. None if key
                not found in data.
        """
        return self.data.get(key)

    def size_bytes(self) -> int:
        """Returns the summed `sys.getsizeof()` of all column objects.

        NOTE: `sys.getsizeof` is shallow; for columns that are numpy views
        (e.g. created by `slice()`) the shared buffer is presumably not
        counted, so treat this as an estimate.

        Returns:
            int: The overall size in bytes of the data buffer (all columns).
        """
        return sum(sys.getsizeof(d) for d in self.data.values())

    def __getitem__(self, key: str) -> TensorType:
        """Returns one column (by key) from the data.

        Args:
            key (str): The key (column name) to return.

        Returns:
            TensorType: The data under the given key.

        Raises:
            KeyError: If `key` is not a column of this batch.
        """
        return self.data[key]

    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Columns that did not exist before are recorded in `new_columns`.

        Args:
            key (str): The column name to set a value for.
            item (TensorType): The data to insert.
        """
        if key not in self.data:
            self.new_columns.append(key)
        self.data[key] = item

    def compress(self,
                 bulk: bool = False,
                 columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
        """Compresses the data buffers (by column) in place.

        Args:
            bulk (bool): Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns (Set[str]): The columns to compress. Default: Only
                compress the obs and new_obs columns. Names that are not
                columns of this batch are ignored.
        """
        for key in columns:
            if key in self.data:
                if bulk:
                    self.data[key] = pack(self.data[key])
                else:
                    self.data[key] = np.array(
                        [pack(o) for o in self.data[key]])

    def decompress_if_needed(self,
                             columns: Set[str] = frozenset(
                                 ["obs", "new_obs"])) -> "SampleBatch":
        """Decompresses data buffers (per column if not compressed) in place.

        Handles both bulk-compressed columns and columns whose individual
        items were compressed; other columns are left untouched.

        Args:
            columns (Set[str]): The columns to decompress. Default: Only
                decompress the obs and new_obs columns.

        Returns:
            SampleBatch: This very SampleBatch.
        """
        for key in columns:
            if key in self.data:
                arr = self.data[key]
                if is_compressed(arr):
                    self.data[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self.data[key] = np.array(
                        [unpack(o) for o in self.data[key]])
        return self

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

    def __repr__(self):
        return "SampleBatch({})".format(str(self.data))

    def __iter__(self):
        # Iterates over column names (like a dict).
        return self.data.__iter__()

    def __contains__(self, x):
        # Membership test on column names.
        return x in self.data


class MultiAgentBatch:
    """A batch of experiences from multiple agents in the environment.

    Attributes:
        policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
            ids to SampleBatches of experiences.
        count (int): The number of env steps in this batch.
    """

    def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
                 env_steps: int):
        """Initialize a MultiAgentBatch object.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatches of experiences.
            env_steps (int): The number of timesteps in the environment this
                batch contains. This may be less than the total number of
                transitions this batch contains across all policies (e.g.
                when several agents act in the same env step).
        """
        for v in policy_batches.values():
            assert isinstance(v, SampleBatch)
        self.policy_batches = policy_batches
        # Called count for uniformity with SampleBatch. Prefer to access this
        # via the env_steps() method when possible for clarity.
        self.count = env_steps

    def env_steps(self) -> int:
        """The number of env steps (there are >= 1 agent steps per env step).

        Returns:
            int: The number of environment steps contained in this batch.
        """
        return self.count

    def agent_steps(self) -> int:
        """The number of agent steps (there are >= 1 agent steps per env step).

        Returns:
            int: The number of agent steps total in this batch, i.e. the sum
                of all policy batches' counts.
        """
        return sum(batch.count for batch in self.policy_batches.values())

    def timeslices(self, k: int) -> List["MultiAgentBatch"]:
        """Returns k-step batches holding data for each agent at those steps.

        For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3],
        for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

        Calling timeslices(1) would return three MultiAgentBatches containing
        [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

        Calling timeslices(2) would return two MultiAgentBatches containing
        [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

        This method is used to implement "lockstep" replay mode. Note that this
        method does not guarantee each batch contains only data from a single
        unroll. Batches might contain data from multiple different envs.

        Args:
            k (int): The number of distinct (eps_id, t) env steps per slice.
                The last slice may contain fewer.

        Returns:
            List[MultiAgentBatch]: The time-sliced batches (at least one).
        """
        from ray.rllib.evaluation.sample_batch_builder import \
            SampleBatchBuilder

        # Build a sorted list of (eps_id, t, agent_index, policy_id, row).
        # Requires every policy batch to have "eps_id", "t" and
        # "agent_index" columns.
        steps = []
        for policy_id, batch in self.policy_batches.items():
            for row in batch.rows():
                steps.append((row[SampleBatch.EPS_ID], row["t"],
                              row["agent_index"], policy_id, row))
        # Sort on the identifying fields only: a plain tuple sort falls
        # through to comparing the `row` dicts on a full tie, which raises
        # TypeError. Python's sort is stable, so ties keep insertion order.
        steps.sort(key=lambda step: step[:4])

        finished_slices = []
        cur_slice = collections.defaultdict(SampleBatchBuilder)
        cur_slice_size = 0

        def finish_slice():
            nonlocal cur_slice_size
            assert cur_slice_size > 0
            batch = MultiAgentBatch(
                {pid: builder.build_and_reset()
                 for pid, builder in cur_slice.items()}, cur_slice_size)
            # Drop the (now reset) builders so the next slice only holds
            # policies that actually receive rows in it; otherwise an empty
            # builder would be built again (presumably raising "Empty sample
            # batch" in SampleBatch.__init__).
            cur_slice.clear()
            cur_slice_size = 0
            finished_slices.append(batch)

        # For each unique env timestep, i.e. each (eps_id, t) group.
        for _, group in itertools.groupby(steps, lambda step: step[:2]):
            # Accumulate into the current slice.
            for _, _, _, policy_id, row in group:
                cur_slice[policy_id].add_values(**row)
            cur_slice_size += 1
            # Slice has reached target number of env steps.
            if cur_slice_size >= k:
                finish_slice()
                assert cur_slice_size == 0

        if cur_slice_size > 0:
            finish_slice()

        assert len(finished_slices) > 0, finished_slices
        return finished_slices

    @staticmethod
    def wrap_as_needed(
            policy_batches: Dict[PolicyID, SampleBatch],
            env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
        """Returns SampleBatch or MultiAgentBatch, depending on given policies.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatch.
            env_steps (int): Number of env steps in the batch.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: The single default policy's
                SampleBatch (if `policy_batches` contains only
                DEFAULT_POLICY_ID) or a MultiAgentBatch otherwise.
        """
        if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
            return policy_batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(policy_batches, env_steps)

    @staticmethod
    def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
        """Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.

        The per-policy batches are concatenated policy by policy and the env
        step counts are summed.

        Args:
            samples (List[MultiAgentBatch]): List of MultiagentBatch objects
                to concatenate.

        Returns:
            MultiAgentBatch: A new MultiAgentBatch consisting of the
                concatenated inputs.

        Raises:
            ValueError: If any item in `samples` is not a MultiAgentBatch.
        """
        policy_batches = collections.defaultdict(list)
        env_steps = 0
        for s in samples:
            if not isinstance(s, MultiAgentBatch):
                raise ValueError(
                    "`MultiAgentBatch.concat_samples()` can only concat "
                    "MultiAgentBatch types, not {}!".format(type(s).__name__))
            for key, batch in s.policy_batches.items():
                policy_batches[key].append(batch)
            env_steps += s.env_steps()

        out = {}
        for key, batches in policy_batches.items():
            out[key] = SampleBatch.concat_samples(batches)
        return MultiAgentBatch(out, env_steps)

    def copy(self) -> "MultiAgentBatch":
        """Deep-copies self into a new MultiAgentBatch.

        Returns:
            MultiAgentBatch: The copy of self with deep-copied data.
        """
        return MultiAgentBatch(
            {k: v.copy()
             for (k, v) in self.policy_batches.items()}, self.count)

    def size_bytes(self) -> int:
        """Returns the summed `size_bytes()` of all policy batches.

        Returns:
            int: The overall size in bytes of all policy batches (all columns).
        """
        return sum(b.size_bytes() for b in self.policy_batches.values())

    def compress(self,
                 bulk: bool = False,
                 columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
        """Compresses each policy batch (per column) in place.

        Args:
            bulk (bool): Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns (Set[str]): Set of column names to compress.
        """
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    def decompress_if_needed(self,
                             columns: Set[str] = frozenset(
                                 ["obs", "new_obs"])) -> "MultiAgentBatch":
        """Decompresses each policy batch (per column), if already compressed.

        Args:
            columns (Set[str]): Set of column names to decompress.

        Returns:
            MultiAgentBatch: This very MultiAgentBatch.
        """
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)
        return self

    def __str__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)

    def __repr__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)