ray/rllib/policy/sample_batch.py

478 lines
15 KiB
Python
Raw Normal View History

import collections
import numpy as np
import sys
import itertools
from typing import Dict, List, Any
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned
from ray.rllib.utils.deprecation import deprecation_warning
# Type alias for policy identifiers; intentionally loose (`Any`) for now.
# TODO(ekl) reuse the other id def once we fix imports
PolicyID = Any

# Policy id used for single agent environments.
DEFAULT_POLICY_ID = "default_policy"
@PublicAPI
class SampleBatch:
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.
    """

    # Outputs from interacting with the environment
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"

    # Extra action fetches keys.
    ACTION_DIST_INPUTS = "action_dist_inputs"
    ACTION_PROB = "action_prob"
    ACTION_LOGP = "action_logp"

    # Uniquely identifies an episode
    EPS_ID = "eps_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy
    VF_PREDS = "vf_preds"

    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Every column value is converted to a numpy array (without copying if
        it already is one). All columns must have the same length, which
        becomes `self.count`.

        Raises:
            ValueError: If no columns are given.
            AssertionError: If a key is not a string or if the column lengths
                differ.
        """
        self.data = dict(*args, **kwargs)
        lengths = []
        # Iterate over a snapshot, since values are replaced in the loop.
        for k, v in list(self.data.items()):
            assert isinstance(k, str), self
            lengths.append(len(v))
            # `np.asarray` copies only when it must. The previous
            # `np.array(v, copy=False)` raises under NumPy >= 2.0 whenever a
            # copy is unavoidable (e.g. for plain Python lists).
            self.data[k] = np.asarray(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, ("data columns must be same length",
                                        self.data, lengths)
        self.count = lengths[0]

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Concatenates n SampleBatches or MultiAgentBatches.

        Args:
            samples (List[Union[SampleBatch, MultiAgentBatch]]): The batches
                to concatenate. If the first one is a MultiAgentBatch, the
                work is delegated to `MultiAgentBatch.concat_samples`.
                Otherwise the columns of the first non-empty batch determine
                the columns of the result.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new SampleBatch (or a new
                MultiAgentBatch for MultiAgentBatch inputs).

        Raises:
            ValueError: If `samples` is empty.
        """
        if not samples:
            raise ValueError(
                "`SampleBatch.concat_samples()` requires at least one batch!")
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        # Empty batches contribute no rows; skip them.
        non_empty = [s for s in samples if s.count > 0]
        if not non_empty:
            # Every input is empty: return an (empty) copy of the first one
            # rather than failing with an IndexError below.
            return samples[0].copy()
        out = {}
        for k in non_empty[0].keys():
            out[k] = concat_aligned([s[k] for s in non_empty])
        return SampleBatch(out)

    @PublicAPI
    def concat(self, other):
        """Returns a new SampleBatch with each data column concatenated.

        Args:
            other (SampleBatch): The batch to append; must have exactly the
                same columns as this one.

        Returns:
            SampleBatch: A new batch holding this batch's rows followed by
                `other`'s rows.

        Raises:
            ValueError: If the two batches have different columns.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> b1.concat(b2)["a"]
            array([1, 2, 3, 4, 5])
        """
        if self.keys() != other.keys():
            raise ValueError(
                "SampleBatches to concat must have same columns! {} vs {}".
                format(list(self.keys()), list(other.keys())))
        out = {}
        for k in self.keys():
            out[k] = concat_aligned([self[k], other[k]])
        return SampleBatch(out)

    @PublicAPI
    def copy(self):
        """Returns a new SampleBatch whose column arrays are copies."""
        return SampleBatch(
            {k: np.array(v, copy=True)
             for (k, v) in self.data.items()})

    @PublicAPI
    def rows(self):
        """Returns an iterator over data rows, i.e. dicts with column values.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> for row in batch.rows():
            ...     print(row)
            {"a": 1, "b": 4}
            {"a": 2, "b": 5}
            {"a": 3, "b": 6}
        """
        for i in range(self.count):
            yield {k: self[k][i] for k in self.keys()}

    @PublicAPI
    def columns(self, keys):
        """Returns a list of the batch-data in the specified columns.

        Args:
            keys (List[str]): List of column names for which to return the
                data.

        Returns:
            List[any]: The list of data items ordered by the order of column
                names in `keys`.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> print(batch.columns(["a", "b"]))
            [[1], [2]]
        """
        return [self[k] for k in keys]

    @PublicAPI
    def shuffle(self):
        """Shuffles the rows of this batch in-place.

        The same random permutation is applied to every column.
        """
        permutation = np.random.permutation(self.count)
        for key, val in self.items():
            self[key] = val[permutation]

    @PublicAPI
    def split_by_episode(self):
        """Splits this batch's data by `eps_id`.

        A new slice starts wherever `eps_id` changes between consecutive rows.

        Returns:
            List[SampleBatch]: List of batches, one per contiguous run of a
                distinct episode id. An empty list for a zero-row batch.
        """
        if self.count == 0:
            # No rows means no episodes (previously an IndexError).
            return []
        eps_ids = self.data[SampleBatch.EPS_ID]
        slices = []
        cur_eps_id = eps_ids[0]
        offset = 0
        for i in range(1, self.count):
            next_eps_id = eps_ids[i]
            if next_eps_id != cur_eps_id:
                slices.append(self.slice(offset, i))
                offset = i
                cur_eps_id = next_eps_id
        slices.append(self.slice(offset, self.count))
        # Sanity checks: each slice holds exactly one episode id and no row
        # was lost.
        for s in slices:
            slen = len(set(s[SampleBatch.EPS_ID]))
            assert slen == 1, (s, slen)
        assert sum(s.count for s in slices) == self.count, (slices, self.count)
        return slices

    @PublicAPI
    def slice(self, start, end):
        """Returns a slice of the row data of this batch.

        Args:
            start (int): Starting index.
            end (int): Ending index (exclusive).

        Returns:
            SampleBatch which has a slice of this batch's data.
        """
        return SampleBatch({k: v[start:end] for k, v in self.data.items()})

    @PublicAPI
    def timeslices(self, k: int) -> List["SampleBatch"]:
        """Splits this batch into consecutive slices of (at most) k rows.

        Args:
            k (int): Number of rows per slice; the last slice may be shorter.

        Returns:
            List[SampleBatch]: The slices, in row order.

        Raises:
            ValueError: If `k` is not positive (the loop would never end).
        """
        if k <= 0:
            raise ValueError(
                "`timeslices()` requires a positive slice size, got {}!".
                format(k))
        return [self.slice(i, i + k) for i in range(0, self.count, k)]

    @PublicAPI
    def keys(self):
        """Returns the column names (a view of the underlying dict keys)."""
        return self.data.keys()

    @PublicAPI
    def items(self):
        """Returns (column name, column data) pairs."""
        return self.data.items()

    @PublicAPI
    def get(self, key, default=None):
        """Returns the column `key`, or `default` if it doesn't exist."""
        return self.data.get(key, default)

    @PublicAPI
    def size_bytes(self) -> int:
        """Returns the summed `sys.getsizeof` of all column values.

        Note that `sys.getsizeof` is shallow: objects referenced from
        object-dtype columns are not counted.
        """
        # Measure the column values, not the (string) column names.
        return sum(sys.getsizeof(d) for d in self.data.values())

    @PublicAPI
    def __getitem__(self, key):
        return self.data[key]

    @PublicAPI
    def __setitem__(self, key, item):
        self.data[key] = item

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Compresses the given columns in-place using `pack`.

        Args:
            bulk (bool): If True, pack each whole column at once. If False,
                pack every row of the column separately.
            columns (Set[str]): Column names to compress; missing ones are
                skipped.
        """
        for key in columns:
            if key in self.data:
                if bulk:
                    self.data[key] = pack(self.data[key])
                else:
                    self.data[key] = np.array(
                        [pack(o) for o in self.data[key]])

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Unpacks the given columns in-place if they are compressed.

        Handles both bulk-compressed columns and per-row compression (as
        detected on the column's first element).

        Args:
            columns (Set[str]): Column names to decompress.

        Returns:
            SampleBatch: self (for chaining).
        """
        for key in columns:
            if key in self.data:
                arr = self.data[key]
                if is_compressed(arr):
                    self.data[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self.data[key] = np.array(
                        [unpack(o) for o in self.data[key]])
        return self

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

    def __repr__(self):
        return "SampleBatch({})".format(str(self.data))

    def __iter__(self):
        return self.data.__iter__()

    def __contains__(self, x):
        return x in self.data
@PublicAPI
class MultiAgentBatch:
    """A batch of experiences from multiple agents in the environment."""

    @PublicAPI
    def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
                 env_steps: int):
        """Initialize a MultiAgentBatch object.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatches of experiences.
            env_steps (int): The number of timesteps in the environment this
                batch contains. This will be less than the number of
                transitions this batch contains across all policies in total.

        Attributes:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatches of experiences.
            count (int): the number of env steps in this batch.
        """
        for v in policy_batches.values():
            assert isinstance(v, SampleBatch)
        self.policy_batches = policy_batches
        # Called count for uniformity with SampleBatch. Prefer to access this
        # via the env_steps() method when possible for clarity.
        self.count = env_steps

    @PublicAPI
    def env_steps(self) -> int:
        """The number of env steps (there are >= 1 agent steps per env step).

        Returns:
            int: the number of environment steps contained in this batch.
        """
        return self.count

    @PublicAPI
    def agent_steps(self) -> int:
        """The number of agent steps (there are >= 1 agent steps per env step).

        Returns:
            int: the number of agent steps total in this batch, i.e. the sum
                of the row counts of all policy batches.
        """
        return sum(batch.count for batch in self.policy_batches.values())

    @PublicAPI
    def timeslices(self, k: int) -> List["MultiAgentBatch"]:
        """Returns k-step batches holding data for each agent at those steps.

        For example, suppose we have agent1 observations [a1t1, a1t2, a1t3],
        for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

        Calling timeslices(1) would return three MultiAgentBatches containing
        [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

        Calling timeslices(2) would return two MultiAgentBatches containing
        [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

        This method is used to implement "lockstep" replay mode. Note that this
        method does not guarantee each batch contains only data from a single
        unroll. Batches might contain data from multiple different envs.

        An env step is identified by the (`eps_id`, `t`) pair of a row, so
        every policy batch must contain a "t" column.
        """
        from ray.rllib.evaluation.sample_batch_builder import \
            SampleBatchBuilder

        # Build a sorted list of (eps_id, t, policy_id, row).
        steps = []
        for policy_id, batch in self.policy_batches.items():
            for row in batch.rows():
                steps.append((row[SampleBatch.EPS_ID], row["t"], policy_id,
                              row))
        # Sort on (eps_id, t, policy_id) only. Sorting the full tuples falls
        # through to comparing the row dicts (TypeError) whenever two agents
        # share a policy and act in the same env step. `list.sort` is stable,
        # so such rows keep their original relative order.
        steps.sort(key=lambda step: step[:3])

        finished_slices = []
        cur_slice = collections.defaultdict(SampleBatchBuilder)
        cur_slice_size = 0

        def finish_slice():
            nonlocal cur_slice_size
            assert cur_slice_size > 0
            batch = MultiAgentBatch(
                {pid: builder.build_and_reset()
                 for pid, builder in cur_slice.items()}, cur_slice_size)
            cur_slice_size = 0
            # Forget the drained builders. Otherwise a policy with no rows in
            # the next slice would still have an (empty) builder built for it.
            # Presumably that fails, since SampleBatch rejects batches without
            # columns (verify against SampleBatchBuilder).
            cur_slice.clear()
            finished_slices.append(batch)

        # For each unique env timestep.
        for _, group in itertools.groupby(steps, lambda x: x[:2]):
            # Accumulate into the current slice.
            for _, _, policy_id, row in group:
                cur_slice[policy_id].add_values(**row)
            cur_slice_size += 1
            # Slice has reached target number of env steps.
            if cur_slice_size >= k:
                finish_slice()
                assert cur_slice_size == 0

        # Flush the trailing, partially filled slice.
        if cur_slice_size > 0:
            finish_slice()

        assert len(finished_slices) > 0, finished_slices
        return finished_slices

    @staticmethod
    @PublicAPI
    def wrap_as_needed(policy_batches: Dict[PolicyID, SampleBatch],
                       env_steps: int) -> Any:
        """Returns SampleBatch or MultiAgentBatch, depending on given policies.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatch.
            env_steps (int): Number of env steps in the batch.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: The single default policy's
                SampleBatch or a MultiAgentBatch (more than one policy).
        """
        if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
            return policy_batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(policy_batches, env_steps)

    @staticmethod
    @PublicAPI
    def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
        """Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.

        Args:
            samples (List[MultiAgentBatch]): List of MultiagentBatch objects
                to concatenate.

        Returns:
            MultiAgentBatch: A new MultiAgentBatch consisting of the
                concatenated inputs. Its env step count is the sum of the
                inputs' env step counts.

        Raises:
            ValueError: If any element of `samples` is not a MultiAgentBatch.
        """
        policy_batches = collections.defaultdict(list)
        env_steps = 0
        for s in samples:
            if not isinstance(s, MultiAgentBatch):
                raise ValueError(
                    "`MultiAgentBatch.concat_samples()` can only concat "
                    "MultiAgentBatch types, not {}!".format(type(s).__name__))
            for key, batch in s.policy_batches.items():
                policy_batches[key].append(batch)
            env_steps += s.env_steps()
        out = {}
        for key, batches in policy_batches.items():
            out[key] = SampleBatch.concat_samples(batches)
        return MultiAgentBatch(out, env_steps)

    @PublicAPI
    def copy(self) -> "MultiAgentBatch":
        """Deep-copies self into a new MultiAgentBatch.

        Returns:
            MultiAgentBatch: The copy of self with deep-copied data.
        """
        return MultiAgentBatch(
            {k: v.copy()
             for (k, v) in self.policy_batches.items()}, self.count)

    @PublicAPI
    def size_bytes(self) -> int:
        """Returns the summed `size_bytes()` of all policy batches."""
        return sum(b.size_bytes() for b in self.policy_batches.values())

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Compresses each policy batch.

        Args:
            bulk (bool): Whether to compress across the batch dimension (0)
                as well. If False will compress n separate list items, where n
                is the batch size.
            columns (Set[str]): Set of column names to compress.
        """
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Decompresses each policy batch, if already compressed.

        Args:
            columns (Set[str]): Set of column names to decompress.

        Returns:
            MultiAgentBatch: self (for chaining).
        """
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)
        return self

    def __str__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)

    def __repr__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)

    # Deprecated.
    def total(self):
        """Deprecated alias of `agent_steps()` (emits a deprecation warning)."""
        deprecation_warning("batch.total()", "batch.agent_steps()")
        return self.agent_steps()