import collections
import itertools
import sys
from typing import Dict, List, Set, Union

import numpy as np

from ray.util import log_once
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.memory import concat_aligned
from ray.rllib.utils.typing import PolicyID, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default_policy"


@PublicAPI
class SampleBatch(dict):
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.
    """

    # Outputs from interacting with the environment
    OBS = "obs"
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"

    # Extra action fetches keys.
    ACTION_DIST_INPUTS = "action_dist_inputs"
    ACTION_PROB = "action_prob"
    ACTION_LOGP = "action_logp"

    # Uniquely identifies an episode.
    EPS_ID = "eps_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode.
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy.
    VF_PREDS = "vf_preds"
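
    # Illustrative usage (a minimal sketch, not part of the class itself):
    # a SampleBatch behaves like a dict keyed by the column names above,
    # with array-like values of equal length.
    # >>> batch = SampleBatch({
    # ...     SampleBatch.OBS: np.array([0, 1, 2]),
    # ...     SampleBatch.ACTIONS: np.array([1, 0, 1]),
    # ...     SampleBatch.REWARDS: np.array([0.5, -1.0, 1.0]),
    # ... })
    # >>> batch[SampleBatch.ACTIONS]
    # array([1, 0, 1])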

    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor)."""

        # Possible seq_lens (TxB or BxT) setup.
        self.time_major = kwargs.pop("_time_major", None)

        self.seq_lens = kwargs.pop("_seq_lens", kwargs.pop("seq_lens", None))
        if self.seq_lens is None and len(args) > 0 and isinstance(
                args[0], dict):
            self.seq_lens = args[0].pop("_seq_lens", args[0].pop(
                "seq_lens", None))
        if isinstance(self.seq_lens, list):
            self.seq_lens = np.array(self.seq_lens, dtype=np.int32)

        self.dont_check_lens = kwargs.pop("_dont_check_lens", False)
        self.max_seq_len = kwargs.pop("_max_seq_len", None)
        if self.max_seq_len is None and self.seq_lens is not None and \
                not (tf and tf.is_tensor(self.seq_lens)) and \
                len(self.seq_lens) > 0:
            self.max_seq_len = max(self.seq_lens)
        self.zero_padded = kwargs.pop("_zero_padded", False)
        self.is_training = kwargs.pop("_is_training", None)

        # Call super constructor. This will make the actual data accessible
        # by column name (str) via e.g. self["some-col"].
        dict.__init__(self, *args, **kwargs)

        self.accessed_keys = set()
        self.added_keys = set()
        self.deleted_keys = set()
        self.intercepted_values = {}

        self.get_interceptor = None

        if self.is_training is None:
            self.is_training = self.pop("is_training", False)

        lengths = []
        copy_ = {k: v for k, v in self.items()}
        for k, v in copy_.items():
            assert isinstance(k, str), self
            len_ = len(v) if isinstance(
                v,
                (list, np.ndarray)) or (torch and torch.is_tensor(v)) else None
            lengths.append(len_)
            if isinstance(v, list):
                self[k] = np.array(v)

        if not lengths:
            raise ValueError("Empty sample batch")

        if not self.dont_check_lens:
            assert len(set(lengths)) == 1, \
                "Data columns must be same length, but lens are " \
                "{}".format(lengths)

        if self.seq_lens is not None and \
                not (tf and tf.is_tensor(self.seq_lens)) and \
                len(self.seq_lens) > 0:
            self.count = sum(self.seq_lens)
        else:
            self.count = lengths[0]

    @PublicAPI
    def __len__(self):
        """Returns the number of samples in the sample batch."""
        return self.count
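
    # A minimal sketch of how `len()` relates to `seq_lens` (illustrative
    # values): without `seq_lens`, the length is the common column length;
    # with `seq_lens`, it is the sum of the sequence lengths.
    # >>> len(SampleBatch({"a": np.array([1, 2, 3])}))
    # 3
    # >>> len(SampleBatch({"a": np.array([1, 2, 3])}, _seq_lens=[1, 2]))
    # 3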

    @staticmethod
    @PublicAPI
    def concat_samples(samples: List["SampleBatch"]) -> \
            Union["SampleBatch", "MultiAgentBatch"]:
        """Concatenates n SampleBatches or MultiAgentBatches.

        Args:
            samples (List[SampleBatch]): List of SampleBatches (or
                MultiAgentBatches) to concatenate.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: A new (concatenated)
                SampleBatch or MultiAgentBatch.
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        seq_lens = []
        concat_samples = []
        zero_padded = samples[0].zero_padded
        max_seq_len = samples[0].max_seq_len
        for s in samples:
            if s.count > 0:
                assert s.zero_padded == zero_padded
                if zero_padded:
                    assert s.max_seq_len == max_seq_len
                concat_samples.append(s)
                if s.seq_lens is not None:
                    seq_lens.extend(s.seq_lens)

        out = {}
        for k in concat_samples[0].keys():
            out[k] = concat_aligned(
                [s[k] for s in concat_samples],
                time_major=concat_samples[0].time_major)
        return SampleBatch(
            out,
            _seq_lens=np.array(seq_lens, dtype=np.int32),
            _time_major=concat_samples[0].time_major,
            _dont_check_lens=True,
            _zero_padded=zero_padded,
            _max_seq_len=max_seq_len,
        )
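
    # A minimal sketch of `concat_samples()` (illustrative values only):
    # >>> b1 = SampleBatch({"a": np.array([1, 2])})
    # >>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
    # >>> SampleBatch.concat_samples([b1, b2])["a"]
    # array([1, 2, 3, 4, 5])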

    @PublicAPI
    def concat(self, other: "SampleBatch") -> "SampleBatch":
        """Returns a new SampleBatch with each data column concatenated.

        Args:
            other (SampleBatch): The other SampleBatch object to concat to
                this one.

        Returns:
            SampleBatch: The new SampleBatch, resulting from concatenating
                `other` to `self`.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """
        if self.keys() != other.keys():
            raise ValueError(
                "SampleBatches to concat must have same columns! {} vs {}".
                format(list(self.keys()), list(other.keys())))
        out = {}
        for k in self.keys():
            out[k] = concat_aligned([self[k], other[k]])
        return SampleBatch(out)

    @PublicAPI
    def copy(self, shallow: bool = False) -> "SampleBatch":
        """Creates a (deep) copy of this SampleBatch and returns it.

        Args:
            shallow (bool): Whether the copying should be done shallowly.

        Returns:
            SampleBatch: A (deep) copy of this SampleBatch object.
        """
        copy_ = SampleBatch(
            {
                k: np.array(v, copy=not shallow)
                if isinstance(v, np.ndarray) else v
                for (k, v) in self.items()
            },
            _seq_lens=self.seq_lens,
            _dont_check_lens=self.dont_check_lens)
        copy_.set_get_interceptor(self.get_interceptor)
        return copy_
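
    # A minimal sketch of deep vs. shallow copying (illustrative values):
    # a deep copy duplicates the underlying numpy buffers, while a shallow
    # copy shares them with the original batch.
    # >>> b = SampleBatch({"a": np.array([1, 2, 3])})
    # >>> shallow = b.copy(shallow=True)
    # >>> shallow["a"][0] = 99
    # >>> b["a"][0]
    # 99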

    @PublicAPI
    def rows(self) -> Dict[str, TensorType]:
        """Returns an iterator over data rows, i.e. dicts with column values.

        Yields:
            Dict[str, TensorType]: The column values of the row in this
                iteration.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> for row in batch.rows():
            ...     print(row)
            {"a": 1, "b": 4}
            {"a": 2, "b": 5}
            {"a": 3, "b": 6}
        """
        for i in range(self.count):
            row = {}
            for k in self.keys():
                row[k] = self[k][i]
            yield row

    @PublicAPI
    def columns(self, keys: List[str]) -> List[any]:
        """Returns a list of the batch-data in the specified columns.

        Args:
            keys (List[str]): List of column names for which to return the
                data.

        Returns:
            List[any]: The list of data items ordered by the order of column
                names in `keys`.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> print(batch.columns(["a", "b"]))
            [[1], [2]]
        """
        out = []
        for k in keys:
            out.append(self[k])
        return out

    @PublicAPI
    def shuffle(self) -> None:
        """Shuffles the rows of this batch in-place."""
        permutation = np.random.permutation(self.count)
        for key, val in self.items():
            self[key] = val[permutation]
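
    # A minimal sketch of shuffling (illustrative): the same permutation is
    # applied to every column, so rows stay aligned across columns.
    # >>> batch = SampleBatch({"obs": np.arange(4), "rewards": np.arange(4)})
    # >>> batch.shuffle()
    # >>> bool(np.all(batch["obs"] == batch["rewards"]))
    # True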

    @PublicAPI
    def split_by_episode(self) -> List["SampleBatch"]:
        """Splits this batch's data by `eps_id`.

        Returns:
            List[SampleBatch]: List of batches, one per distinct episode.
        """

        # No eps_id in data -> Make sure there are no "dones" in the middle
        # and add eps_id automatically.
        if SampleBatch.EPS_ID not in self:
            if SampleBatch.DONES in self:
                assert not any(self[SampleBatch.DONES][:-1])
            self[SampleBatch.EPS_ID] = np.repeat(0, self.count)
            return [self]

        slices = []
        cur_eps_id = self[SampleBatch.EPS_ID][0]
        offset = 0
        for i in range(self.count):
            next_eps_id = self[SampleBatch.EPS_ID][i]
            if next_eps_id != cur_eps_id:
                slices.append(self.slice(offset, i))
                offset = i
                cur_eps_id = next_eps_id
        slices.append(self.slice(offset, self.count))
        for s in slices:
            slen = len(set(s[SampleBatch.EPS_ID]))
            assert slen == 1, (s, slen)
        assert sum(s.count for s in slices) == self.count, (slices, self.count)
        return slices
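
    # A minimal sketch of splitting by episode (illustrative episode ids):
    # >>> batch = SampleBatch({
    # ...     "eps_id": np.array([0, 0, 1]),
    # ...     "obs": np.array([1, 2, 3]),
    # ... })
    # >>> [b.count for b in batch.split_by_episode()]
    # [2, 1]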

    @PublicAPI
    def slice(self, start: int, end: int) -> "SampleBatch":
        """Returns a slice of the row data of this batch (w/o copying).

        Args:
            start (int): Starting index. If < 0, will zero-pad.
            end (int): Ending index.

        Returns:
            SampleBatch: A new SampleBatch, which has a slice of this batch's
                data.
        """
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            if start < 0:
                data = {
                    k: np.concatenate([
                        np.zeros(
                            shape=(-start, ) + v.shape[1:], dtype=v.dtype),
                        v[0:end]
                    ])
                    for k, v in self.items()
                }
            else:
                data = {k: v[start:end] for k, v in self.items()}
            # Fix state_in_x data.
            count = 0
            state_start = None
            seq_lens = None
            for i, seq_len in enumerate(self.seq_lens):
                count += seq_len
                if count >= end:
                    state_idx = 0
                    state_key = "state_in_{}".format(state_idx)
                    if state_start is None:
                        state_start = i
                    while state_key in self:
                        data[state_key] = self[state_key][state_start:i + 1]
                        state_idx += 1
                        state_key = "state_in_{}".format(state_idx)
                    seq_lens = list(self.seq_lens[state_start:i]) + [
                        seq_len - (count - end)
                    ]
                    if start < 0:
                        seq_lens[0] += -start
                    assert sum(seq_lens) == (end - start)
                    break
                elif state_start is None and count > start:
                    state_start = i

            return SampleBatch(
                data,
                _seq_lens=np.array(seq_lens, dtype=np.int32),
                _time_major=self.time_major,
                _dont_check_lens=True)
        else:
            return SampleBatch(
                {k: v[start:end]
                 for k, v in self.items()},
                _seq_lens=None,
                _is_training=self.is_training,
                _time_major=self.time_major)
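
    # A minimal sketch of slicing (illustrative): without `seq_lens`, this
    # is a plain row slice over all columns.
    # >>> batch = SampleBatch({"a": np.array([10, 20, 30, 40])})
    # >>> batch.slice(1, 3)["a"]
    # array([20, 30])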

    @PublicAPI
    def timeslices(self, k: int) -> List["SampleBatch"]:
        """Returns SampleBatches, each one representing a k-slice of this one.

        Will start from timestep 0 and produce slices of size=k.

        Args:
            k (int): The size (in timesteps) of each returned SampleBatch.

        Returns:
            List[SampleBatch]: The list of (new) SampleBatches (each one of
                size k).
        """
        slices = self._get_slice_indices(k)
        timeslices = [self.slice(i, j) for i, j in slices]
        return timeslices
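
    # A minimal sketch of time-slicing (illustrative): a batch of 5 timesteps
    # cut into slices of size 2 (the last slice may be shorter).
    # >>> batch = SampleBatch({"a": np.arange(5)})
    # >>> [list(s["a"]) for s in batch.timeslices(2)]
    # [[0, 1], [2, 3], [4]]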

    def zero_pad(self, max_seq_len: int, exclude_states: bool = True):
        """Left zero-pad the data in this SampleBatch in place.

        This will set the `self.zero_padded` flag to True and
        `self.max_seq_len` to the given `max_seq_len` value.

        Args:
            max_seq_len (int): The max (total) length to zero pad to.
            exclude_states (bool): If False, also zero-pad all `state_in_x`
                data. If True, leave `state_in_x` keys as-is.
        """
        for col in self.keys():
            # Skip state in columns.
            if exclude_states is True and col.startswith("state_in_"):
                continue

            f = self[col]
            # Save unnecessary copy.
            if not isinstance(f, np.ndarray):
                f = np.array(f)
            # Already good length, can skip.
            if f.shape[0] == max_seq_len:
                continue
            # Generate zero-filled primer of len=max_seq_len.
            length = len(self.seq_lens) * max_seq_len
            if f.dtype == np.object or f.dtype.type is np.str_:
                f_pad = [None] * length
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length, ) + np.shape(f)[1:], dtype=f.dtype)
            # Fill primer with data.
            f_pad_base = f_base = 0
            for len_ in self.seq_lens:
                f_pad[f_pad_base:f_pad_base + len_] = f[f_base:f_base + len_]
                f_pad_base += max_seq_len
                f_base += len_
            assert f_base == len(f), f
            # Update our data.
            self[col] = f_pad

        # Set flags to indicate that we are now zero-padded (and to what
        # extent).
        self.zero_padded = True
        self.max_seq_len = max_seq_len
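
    # A minimal sketch of zero-padding (illustrative): two sequences of
    # lengths 1 and 2, each padded to max_seq_len=4 (total length 8).
    # >>> batch = SampleBatch({"a": np.array([1, 2, 3])}, _seq_lens=[1, 2])
    # >>> batch.zero_pad(max_seq_len=4)
    # >>> list(batch["a"])
    # [1, 0, 0, 0, 2, 3, 0, 0]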

    @PublicAPI
    def size_bytes(self) -> int:
        """
        Returns:
            int: The overall size in bytes of the data buffer (all columns).
        """
        return sum(
            v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
            for v in self.values())

    def get(self, key, default=None):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    @PublicAPI
    def __getitem__(self, key: str) -> TensorType:
        """Returns one column (by key) from the data.

        Args:
            key (str): The key (column name) to return.

        Returns:
            TensorType: The data under the given key.
        """
        if not hasattr(self, key):
            self.accessed_keys.add(key)

        # Backward compatibility for when "input-dicts" were used.
        if key == "is_training":
            if log_once("SampleBatch['is_training']"):
                deprecation_warning(
                    old="SampleBatch['is_training']",
                    new="SampleBatch.is_training",
                    error=False)
            return self.is_training
        elif key == "seq_lens":
            if self.get_interceptor is not None and self.seq_lens is not None:
                if "seq_lens" not in self.intercepted_values:
                    self.intercepted_values["seq_lens"] = self.get_interceptor(
                        self.seq_lens)
                return self.intercepted_values["seq_lens"]
            return self.seq_lens

        value = dict.__getitem__(self, key)
        if self.get_interceptor is not None:
            if key not in self.intercepted_values:
                self.intercepted_values[key] = self.get_interceptor(value)
            value = self.intercepted_values[key]
        return value

    @PublicAPI
    def __setitem__(self, key, item) -> None:
        """Inserts (overrides) an entire column (by key) in the data buffer.

        Args:
            key (str): The column name to set a value for.
            item (TensorType): The data to insert.
        """
        if key == "seq_lens":
            self.seq_lens = item
            return
        # Defend against creating SampleBatch via pickle (no property
        # `added_keys` and first item is already set).
        elif not hasattr(self, "added_keys"):
            dict.__setitem__(self, key, item)
            return

        if key not in self:
            self.added_keys.add(key)

        dict.__setitem__(self, key, item)
        if key in self.intercepted_values:
            self.intercepted_values[key] = item

    @PublicAPI
    def __delitem__(self, key):
        self.deleted_keys.add(key)
        dict.__delitem__(self, key)

    @DeveloperAPI
    def compress(self,
                 bulk: bool = False,
                 columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
        """Compresses the data buffers (by column) in place.

        Args:
            bulk (bool): Whether to compress across the batch dimension (0)
                as well. If False, will compress n separate list items, where
                n is the batch size.
            columns (Set[str]): The columns to compress. Default: Only
                compress the obs and new_obs columns.
        """
        for key in columns:
            if key in self.keys():
                if bulk:
                    self[key] = pack(self[key])
                else:
                    self[key] = np.array([pack(o) for o in self[key]])

    @DeveloperAPI
    def decompress_if_needed(self,
                             columns: Set[str] = frozenset(
                                 ["obs", "new_obs"])) -> "SampleBatch":
        """Decompresses data buffers (per column), if previously compressed,
        in place.

        Args:
            columns (Set[str]): The columns to decompress. Default: Only
                decompress the obs and new_obs columns.

        Returns:
            SampleBatch: This very SampleBatch.
        """
        for key in columns:
            if key in self.keys():
                arr = self[key]
                if is_compressed(arr):
                    self[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self[key] = np.array([unpack(o) for o in self[key]])
        return self
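
    # A minimal sketch of the compress/decompress round trip (illustrative):
    # only the configured columns ("obs"/"new_obs" by default) are affected,
    # and `decompress_if_needed()` is a no-op for uncompressed data.
    # >>> batch = SampleBatch({"obs": np.random.rand(10, 84, 84, 3)})
    # >>> batch.compress(bulk=True)
    # >>> batch = batch.decompress_if_needed()
    # >>> batch["obs"].shape
    # (10, 84, 84, 3)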

    @DeveloperAPI
    def set_get_interceptor(self, fn):
        self.get_interceptor = fn
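
    # A minimal sketch of a get-interceptor (illustrative): lazily convert
    # columns on first access; results are cached in
    # `self.intercepted_values`.
    # >>> batch = SampleBatch({"obs": np.array([[1.0], [2.0]])})
    # >>> if torch:
    # ...     batch.set_get_interceptor(lambda v: torch.from_numpy(v))
    # ...     print(type(batch["obs"]))
    # <class 'torch.Tensor'>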

    def __repr__(self):
        return "SampleBatch({})".format(list(self.keys()))

    def _get_slice_indices(self, slice_size):
        i = 0
        slices = []
        if self.seq_lens is not None and len(self.seq_lens) > 0:
            assert np.all(self.seq_lens < slice_size), \
                "ERROR: `slice_size` must be larger than the max. seq-len " \
                "in the batch!"
            start_pos = 0
            current_slice_size = 0
            idx = 0
            while idx < len(self.seq_lens):
                seq_len = self.seq_lens[idx]
                current_slice_size += seq_len
                # Complete minibatch -> Append to slices.
                if current_slice_size >= slice_size:
                    slices.append((start_pos, start_pos + slice_size))
                    start_pos += slice_size
                    if current_slice_size > slice_size:
                        overhead = current_slice_size - slice_size
                        start_pos -= (seq_len - overhead)
                        idx -= 1
                    current_slice_size = 0
                idx += 1
        else:
            while i < self.count:
                slices.append((i, i + slice_size))
                i += slice_size
        return slices

    # TODO: deprecate
    @property
    def data(self):
        if log_once("SampleBatch.data"):
            deprecation_warning(
                old="SampleBatch.data[..]", new="SampleBatch[..]", error=False)
        return self


@PublicAPI
class MultiAgentBatch:
    """A batch of experiences from multiple agents in the environment.

    Attributes:
        policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
            ids to SampleBatches of experiences.
        count (int): The number of env steps in this batch.
    """

    @PublicAPI
    def __init__(self, policy_batches: Dict[PolicyID, SampleBatch],
                 env_steps: int):
        """Initialize a MultiAgentBatch object.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatches of experiences.
            env_steps (int): The number of environment steps in the
                environment this batch contains. This will be less than or
                equal to the number of transitions this batch contains across
                all policies in total.
        """
        for v in policy_batches.values():
            assert isinstance(v, SampleBatch)
        self.policy_batches = policy_batches
        # Called "count" for uniformity with SampleBatch.
        # Prefer to access this via the `env_steps()` method when possible
        # for clarity.
        self.count = env_steps

    @PublicAPI
    def env_steps(self) -> int:
        """The number of env steps (there are >= 1 agent steps per env step).

        Returns:
            int: The number of environment steps contained in this batch.
        """
        return self.count

    @PublicAPI
    def agent_steps(self) -> int:
        """The number of agent steps (there are >= 1 agent steps per env
        step).

        Returns:
            int: The number of agent steps total in this batch.
        """
        ct = 0
        for batch in self.policy_batches.values():
            ct += batch.count
        return ct
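
    # A minimal sketch (illustrative policy ids): two policies with 3
    # transitions each over 3 env steps -> 3 env steps, 6 agent steps.
    # >>> mab = MultiAgentBatch({
    # ...     "policy_1": SampleBatch({"obs": np.arange(3)}),
    # ...     "policy_2": SampleBatch({"obs": np.arange(3)}),
    # ... }, env_steps=3)
    # >>> mab.env_steps(), mab.agent_steps()
    # (3, 6)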

    @PublicAPI
    def timeslices(self, k: int) -> List["MultiAgentBatch"]:
        """Returns k-step batches holding data for each agent at those steps.

        For example, suppose we have agent1 observations [a1t1, a1t2, a1t3],
        for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

        Calling timeslices(1) would return three MultiAgentBatches containing
        [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

        Calling timeslices(2) would return two MultiAgentBatches containing
        [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

        This method is used to implement "lockstep" replay mode. Note that
        this method does not guarantee each batch contains only data from a
        single unroll. Batches might contain data from multiple different
        envs.
        """
        from ray.rllib.evaluation.sample_batch_builder import \
            SampleBatchBuilder

        # Build a sorted set of (eps_id, t, policy_id, data...)
        steps = []
        for policy_id, batch in self.policy_batches.items():
            for row in batch.rows():
                steps.append((row[SampleBatch.EPS_ID], row["t"],
                              row["agent_index"], policy_id, row))
        steps.sort()

        finished_slices = []
        cur_slice = collections.defaultdict(SampleBatchBuilder)
        cur_slice_size = 0

        def finish_slice():
            nonlocal cur_slice_size
            assert cur_slice_size > 0
            batch = MultiAgentBatch(
                {k: v.build_and_reset()
                 for k, v in cur_slice.items()}, cur_slice_size)
            cur_slice_size = 0
            finished_slices.append(batch)

        # For each unique env timestep.
        for _, group in itertools.groupby(steps, lambda x: x[:2]):
            # Accumulate into the current slice.
            for _, _, _, policy_id, row in group:
                cur_slice[policy_id].add_values(**row)
            cur_slice_size += 1
            # Slice has reached target number of env steps.
            if cur_slice_size >= k:
                finish_slice()
                assert cur_slice_size == 0

        if cur_slice_size > 0:
            finish_slice()

        assert len(finished_slices) > 0, finished_slices
        return finished_slices

    @staticmethod
    @PublicAPI
    def wrap_as_needed(
            policy_batches: Dict[PolicyID, SampleBatch],
            env_steps: int) -> Union[SampleBatch, "MultiAgentBatch"]:
        """Returns SampleBatch or MultiAgentBatch, depending on given policies.

        Args:
            policy_batches (Dict[PolicyID, SampleBatch]): Mapping from policy
                ids to SampleBatch.
            env_steps (int): Number of env steps in the batch.

        Returns:
            Union[SampleBatch, MultiAgentBatch]: The single default policy's
                SampleBatch or a MultiAgentBatch (more than one policy).
        """
        if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
            return policy_batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(
            policy_batches=policy_batches, env_steps=env_steps)
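
    # A minimal sketch of `wrap_as_needed()` (illustrative): a single
    # "default_policy" entry is unwrapped to a plain SampleBatch, anything
    # else stays a MultiAgentBatch.
    # >>> sb = SampleBatch({"obs": np.arange(2)})
    # >>> out = MultiAgentBatch.wrap_as_needed({"default_policy": sb}, 2)
    # >>> isinstance(out, SampleBatch)
    # True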

    @staticmethod
    @PublicAPI
    def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
        """Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.

        Args:
            samples (List[MultiAgentBatch]): List of MultiAgentBatch objects
                to concatenate.

        Returns:
            MultiAgentBatch: A new MultiAgentBatch consisting of the
                concatenated inputs.
        """
        policy_batches = collections.defaultdict(list)
        env_steps = 0
        for s in samples:
            if not isinstance(s, MultiAgentBatch):
                raise ValueError(
                    "`MultiAgentBatch.concat_samples()` can only concat "
                    "MultiAgentBatch types, not {}!".format(type(s).__name__))
            for key, batch in s.policy_batches.items():
                policy_batches[key].append(batch)
            env_steps += s.env_steps()
        out = {}
        for key, batches in policy_batches.items():
            out[key] = SampleBatch.concat_samples(batches)
        return MultiAgentBatch(out, env_steps)

    @PublicAPI
    def copy(self) -> "MultiAgentBatch":
        """Deep-copies self into a new MultiAgentBatch.

        Returns:
            MultiAgentBatch: The copy of self with deep-copied data.
        """
        return MultiAgentBatch(
            {k: v.copy()
             for (k, v) in self.policy_batches.items()}, self.count)

    @PublicAPI
    def size_bytes(self) -> int:
        """
        Returns:
            int: The overall size in bytes of all policy batches (all
                columns).
        """
        return sum(b.size_bytes() for b in self.policy_batches.values())

    @DeveloperAPI
    def compress(self,
                 bulk: bool = False,
                 columns: Set[str] = frozenset(["obs", "new_obs"])) -> None:
        """Compresses each policy batch (per column) in place.

        Args:
            bulk (bool): Whether to compress across the batch dimension (0)
                as well. If False, will compress n separate list items, where
                n is the batch size.
            columns (Set[str]): Set of column names to compress.
        """
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self,
                             columns: Set[str] = frozenset(
                                 ["obs", "new_obs"])) -> "MultiAgentBatch":
        """Decompresses each policy batch (per column), if already compressed.

        Args:
            columns (Set[str]): Set of column names to decompress.

        Returns:
            MultiAgentBatch: This very MultiAgentBatch.
        """
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)
        return self

    def __str__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)

    def __repr__(self):
        return "MultiAgentBatch({}, env_steps={})".format(
            str(self.policy_batches), self.count)