import collections
import logging
import numpy as np
import platform
import random
from typing import List, Dict

# Importing ray before psutil makes sure that we use the version of psutil
# that is bundled with ray.
import ray  # noqa F401
import psutil  # noqa E402

from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.rnn_sequencing import \
    timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
    DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
from ray.util.debug import log_once
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType

# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"

logger = logging.getLogger(__name__)

def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
    """Warn if the configured replay buffer size is too large."""
    if log_once("replay_buffer_size"):
        item_size = item.size_bytes()
        psutil_mem = psutil.virtual_memory()
        total_gb = psutil_mem.total / 1e9
        mem_size = num_items * item_size / 1e9
        msg = ("Estimated max memory usage for replay buffer is {} GB "
               "({} batches of size {}, {} bytes each), "
               "available system memory is {} GB".format(
                   mem_size, num_items, item.count, item_size, total_gb))
        if mem_size > total_gb:
            raise ValueError(msg)
        elif mem_size > 0.2 * total_gb:
            logger.warning(msg)
        else:
            logger.info(msg)
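
# A rough worked example of the estimate above (hypothetical numbers): with
# `num_items=50000` and stored batches of ~100000 bytes each, the estimated
# max memory usage reported would be 50000 * 1e5 / 1e9 = 5.0 GB.
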
@DeveloperAPI
class ReplayBuffer:
    @DeveloperAPI
    def __init__(self, size: int):
        """Initializes a ReplayBuffer instance.

        Args:
            size (int): Max number of timesteps to store in the FIFO buffer.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._hit_count = np.zeros(size)
        self._eviction_started = False
        self._num_timesteps_added = 0
        self._num_timesteps_added_wrap = 0
        self._num_timesteps_sampled = 0
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

    def __len__(self) -> int:
        return len(self._storage)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        warn_replay_buffer_size(
            item=item, num_items=self._maxsize / item.count)
        assert item.count > 0, item

        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            self._storage.append(item)
            self._est_size_bytes += item.size_bytes()
        else:
            self._storage[self._next_idx] = item

        # Wrap around storage as a circular buffer once we hit maxsize.
        if self._num_timesteps_added_wrap >= self._maxsize:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def sample(self, num_items: int) -> SampleBatchType:
        """Sample a batch of experiences.

        Args:
            num_items (int): Number of items to sample from this buffer.

        Returns:
            SampleBatchType: Concatenated batch of items.
        """
        idxes = [
            random.randint(0,
                           len(self._storage) - 1) for _ in range(num_items)
        ]
        self._num_timesteps_sampled += num_items
        return self._encode_sample(idxes)

    @DeveloperAPI
    def stats(self, debug=False) -> dict:
        data = {
            "added_count": self._num_timesteps_added,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data
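
# A minimal usage sketch of the FIFO buffer above (comment only; the field
# names and sizes are illustrative, not part of the RLlib API):
#
#     buf = ReplayBuffer(size=1000)
#     batch = SampleBatch({"obs": np.array([0, 1]),
#                          "rewards": np.array([0.0, 1.0])})
#     buf.add(batch, weight=None)  # `weight` is ignored by this base class.
#     replayed = buf.sample(num_items=2)  # -> concatenated SampleBatch
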
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
    @DeveloperAPI
    def __init__(self, size: int, alpha: float):
        """Create Prioritized Replay buffer.

        Args:
            size (int): Max number of items to store in the FIFO buffer.
            alpha (float): How much prioritization is used
                (0 - no prioritization, 1 - full prioritization).

        See also:
            ReplayBuffer.__init__()
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0
        self._prio_change_stats = WindowStat("reprio", 1000)

    @DeveloperAPI
    def add(self, item: SampleBatchType, weight: float) -> None:
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).add(item, weight)
        if weight is None:
            weight = self._max_priority
        self._it_sum[idx] = weight**self._alpha
        self._it_min[idx] = weight**self._alpha

    def _sample_proportional(self, num_items: int) -> List[int]:
        res = []
        for _ in range(num_items):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * self._it_sum.sum(0, len(self._storage))
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res
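
    # Note on `_sample_proportional()` above: item i is drawn with
    # probability priority_i**alpha / sum_j(priority_j**alpha). The
    # SumSegmentTree holds the priority**alpha values, so drawing a uniform
    # `mass` in [0, total sum) and looking up its prefix-sum index realizes
    # exactly this distribution in O(log capacity) time.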

    @DeveloperAPI
    def sample(self, num_items: int, beta: float) -> SampleBatchType:
        """Sample a batch of experiences and return priority weights, indices.

        Args:
            num_items (int): Number of items to sample from this buffer.
            beta (float): To what degree to use importance weights
                (0 - no corrections, 1 - full correction).

        Returns:
            SampleBatchType: Concatenated batch of items including "weights"
                and "batch_indexes" fields. These denote the
                importance-sampling (IS) weight of each sampled transition
                and the original index of each sampled experience in the
                buffer, respectively.
        """
        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)

        weights = []
        batch_indexes = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage))**(-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self._storage))**(-beta)
            count = self._storage[idx].count
            # If zero-padded, count will not be the actual batch size of the
            # data.
            if isinstance(self._storage[idx], SampleBatch) and \
                    self._storage[idx].zero_padded:
                actual_size = self._storage[idx].max_seq_len
            else:
                actual_size = count
            weights.extend([weight / max_weight] * actual_size)
            batch_indexes.extend([idx] * actual_size)
            self._num_timesteps_sampled += count
        batch = self._encode_sample(idxes)

        # Note: prioritization is not supported in lockstep replay mode.
        if isinstance(batch, SampleBatch):
            batch["weights"] = np.array(weights)
            batch["batch_indexes"] = np.array(batch_indexes)

        return batch
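
    # Note on the "weights" computed above: each sampled transition carries
    # the standard prioritized-replay importance-sampling weight
    # (N * P(i)) ** (-beta), normalized by the largest possible weight
    # (derived from the minimum priority) so that all weights are <= 1.0.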

    @DeveloperAPI
    def update_priorities(self, idxes: List[int],
                          priorities: List[float]) -> None:
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Args:
            idxes (List[int]): Indices of the sampled transitions.
            priorities (List[float]): Updated priorities corresponding to the
                transitions at the sampled indices (`idxes`).
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            delta = priority**self._alpha - self._it_sum[idx]
            self._prio_change_stats.push(delta)
            self._it_sum[idx] = priority**self._alpha
            self._it_min[idx] = priority**self._alpha

            self._max_priority = max(self._max_priority, priority)

    @DeveloperAPI
    def stats(self, debug: bool = False) -> Dict:
        parent = ReplayBuffer.stats(self, debug)
        if debug:
            parent.update(self._prio_change_stats.stats())
        return parent
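
# A minimal usage sketch for the prioritized buffer (comment only; `batch`
# and `td_errors` are placeholders for real data):
#
#     buf = PrioritizedReplayBuffer(size=1000, alpha=0.6)
#     buf.add(batch, weight=None)  # None -> use the current max priority.
#     out = buf.sample(num_items=2, beta=0.4)
#     # After computing TD-errors for `out`, re-prioritize those items:
#     buf.update_priorities(out["batch_indexes"], np.abs(td_errors) + 1e-6)
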

# Visible for testing.
_local_replay_buffer = None


class LocalReplayBuffer(ParallelIteratorWorker):
    """A replay buffer shard storing data for all policies (in multiagent setup).

    Ray actors are single-threaded, so for scalability, multiple replay actors
    may be created to increase parallelism."""

    def __init__(self,
                 num_shards: int = 1,
                 learning_starts: int = 1000,
                 buffer_size: int = 10000,
                 replay_batch_size: int = 1,
                 prioritized_replay_alpha: float = 0.6,
                 prioritized_replay_beta: float = 0.4,
                 prioritized_replay_eps: float = 1e-6,
                 replay_mode: str = "independent",
                 replay_sequence_length: int = 1,
                 replay_burn_in: int = 0,
                 replay_zero_init_states: bool = True):
        """Initializes a LocalReplayBuffer instance.

        Args:
            num_shards (int): The number of buffer shards that exist in total
                (including this one).
            learning_starts (int): Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            buffer_size (int): The size of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size (int): The batch size to be sampled (in
                timesteps). Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha (float): Alpha parameter for a
                prioritized replay buffer.
            prioritized_replay_beta (float): Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps (float): Epsilon parameter for a
                prioritized replay buffer.
            replay_mode (str): One of "independent" or "lockstep". Determines
                whether, in the multiagent case, sampling is done across all
                agents/policies equally (lockstep) or independently per
                policy.
            replay_sequence_length (int): The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in (int): The burn-in length in case
                `replay_sequence_length` > 1. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 for each RNN rollout.
            replay_zero_init_states (bool): Whether the initial states in the
                buffer (if replay_sequence_length > 1) are always 0.0 or
                should be updated with the previous train_batch state outputs.
        """
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            return PrioritizedReplayBuffer(
                self.buffer_size, alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None

    @staticmethod
    def get_instance_for_testing():
        global _local_replay_buffer
        return _local_replay_buffer

    def get_host(self) -> str:
        return platform.node()

    def add_batch(self, batch: SampleBatchType) -> None:
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        with self.add_batch_timer:
            # Lockstep mode: Store under _ALL_POLICIES key (we will always
            # only sample from all policies at the same time).
            if self.replay_mode == "lockstep":
                # Note that prioritization is not supported in this mode.
                for s in batch.timeslices(self.replay_sequence_length):
                    self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
            else:
                for policy_id, sample_batch in batch.policy_batches.items():
                    if self.replay_sequence_length == 1:
                        timeslices = sample_batch.timeslices(1)
                    else:
                        timeslices = timeslice_along_seq_lens_with_overlap(
                            sample_batch=sample_batch,
                            zero_pad_max_seq_len=self.replay_sequence_length,
                            pre_overlap=self.replay_burn_in,
                            zero_init_states=self.replay_zero_init_states,
                        )
                    for time_slice in timeslices:
                        # If SampleBatch has prio-replay weights, average
                        # over these to use as a weight for the entire
                        # sequence.
                        if "weights" in time_slice:
                            weight = np.mean(time_slice["weights"])
                        else:
                            weight = None
                        self.replay_buffers[policy_id].add(
                            time_slice, weight=weight)
        self.num_added += batch.count

    def replay(self) -> SampleBatchType:
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({
                DEFAULT_POLICY_ID: fake_batch
            }, fake_batch.count)

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
            if self.replay_mode == "lockstep":
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():
                    samples[policy_id] = replay_buffer.sample(
                        self.replay_batch_size,
                        beta=self.prioritized_replay_beta)
                return MultiAgentBatch(samples, self.replay_batch_size)

    def update_priorities(self, prio_dict: Dict) -> None:
        with self.update_priorities_timer:
            for policy_id, (batch_indexes, td_errors) in prio_dict.items():
                new_priorities = (
                    np.abs(td_errors) + self.prioritized_replay_eps)
                self.replay_buffers[policy_id].update_priorities(
                    batch_indexes, new_priorities)
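
    # `prio_dict` maps policy_id -> (batch_indexes, td_errors) as produced by
    # a learner; the new priority of each sampled item is |TD-error| + eps,
    # mirroring the standard prioritized-replay update.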

    def stats(self, debug: bool = False) -> Dict:
        stat = {
            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
            "update_priorities_time_ms": round(
                1000 * self.update_priorities_timer.mean, 3),
        }
        for policy_id, replay_buffer in self.replay_buffers.items():
            stat.update({
                "policy_{}".format(policy_id): replay_buffer.stats(debug=debug)
            })
        return stat


ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
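
# A hedged sketch of how this module is typically wired up (constructor
# arguments and `some_batch` below are illustrative only):
#
#     replay_actors = [
#         ReplayActor.remote(
#             num_shards=2,
#             learning_starts=1000,
#             buffer_size=50000,
#             replay_batch_size=32)
#         for _ in range(2)
#     ]
#     ray.get(replay_actors[0].add_batch.remote(some_batch))
#     replayed = ray.get(replay_actors[0].replay.remote())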