ray/rllib/utils/replay_buffers/replay_buffer.py
2022-06-10 16:47:51 +02:00

445 lines
17 KiB
Python

import logging
import platform
from typing import Any, Dict, List, Optional, Callable, Union
import numpy as np
import random
from enum import Enum
# Importing ray before psutil makes sure we use psutil's bundled version.
import ray # noqa F401
import psutil # noqa E402
from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType, T
from ray.util.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
logger = logging.getLogger(__name__)

# Special policy key meaning "all policies" when replaying in lockstep mode.
_ALL_POLICIES = "__all__"
@DeveloperAPI
class StorageUnit(Enum):
    """Specifies how batches are structured in a ReplayBuffer.

    timesteps: One buffer slot per timestep.
    sequences: One buffer slot per sequence.
    episodes: One buffer slot per episode.
    fragments: One buffer slot per incoming batch.

    Members can also be looked up by their string value, e.g.
    ``StorageUnit("episodes") is StorageUnit.EPISODES``.
    """

    TIMESTEPS = "timesteps"
    SEQUENCES = "sequences"
    EPISODES = "episodes"
    FRAGMENTS = "fragments"
@DeveloperAPI
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
    """Checks a replay buffer's projected memory footprint against system RAM.

    The check runs at most once per process, because it is gated by
    ``log_once("replay_capacity")``. The projected footprint is
    ``num_items`` times the byte size of ``item``.

    Args:
        item: A representative batch. Its ``size_bytes()`` and ``count``
            feed the estimate and the message.
        num_items: How many batches like ``item`` the buffer may hold.
            Callers may pass a float here.

    Raises:
        ValueError: If the projected footprint exceeds the total system
            memory reported by psutil.
    """
    if not log_once("replay_capacity"):
        return

    bytes_per_item = item.size_bytes()
    system_gb = psutil.virtual_memory().total / 1e9
    estimated_gb = num_items * bytes_per_item / 1e9
    message = (
        f"Estimated max memory usage for replay buffer is {estimated_gb} GB "
        f"({num_items} batches of size {item.count}, "
        f"{bytes_per_item} bytes each), "
        f"available system memory is {system_gb} GB"
    )

    if estimated_gb > system_gb:
        raise ValueError(message)
    # More than 20% of total memory deserves a warning; otherwise just log it.
    emit = logger.warning if estimated_gb > 0.2 * system_gb else logger.info
    emit(message)
# TODO (artur): Remove ParallelIteratorWorker once we no longer support
#  execution plans.
@DeveloperAPI
class ReplayBuffer(ParallelIteratorWorker):
    """The lowest-level replay buffer interface used by RLlib.

    This class implements a basic ring-type of buffer with random sampling.
    ReplayBuffer is the base class for advanced types that add functionality
    while retaining compatibility through inheritance.

    The following examples show how buffers behave with different
    storage_units and capacities. This behaviour is generally similar for
    other buffers, although they might not implement all storage_units.
    Sampling is random, so the printed values below are only illustrative.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import (  # doctest: +SKIP
        ...     ReplayBuffer, StorageUnit)
        >>> from ray.rllib.policy.sample_batch import SampleBatch # doctest: +SKIP
        >>> # Store any batch as a whole
        >>> buffer = ReplayBuffer(capacity=10,  # doctest: +SKIP
        ...     storage_unit=StorageUnit.FRAGMENTS)
        >>> buffer.add(SampleBatch({"a": [1], "b": [2, 3, 4]})) # doctest: +SKIP
        >>> print(buffer.sample(1)) # doctest: +SKIP
        >>> # SampleBatch(1: ['a', 'b'])
        >>> # Store only complete episodes
        >>> buffer = ReplayBuffer(capacity=10,  # doctest: +SKIP
        ...     storage_unit=StorageUnit.EPISODES)
        >>> buffer.add(SampleBatch({"c": [1, 2, 3, 4],  # doctest: +SKIP
        ...     SampleBatch.T: [0, 1, 0, 1],
        ...     SampleBatch.DONES: [False, True, False, True],
        ...     SampleBatch.EPS_ID: [0, 0, 1, 1]}))
        >>> eps_n = buffer.sample(1) # doctest: +SKIP
        >>> print(eps_n[SampleBatch.EPS_ID]) # doctest: +SKIP
        >>> # [1 1]
        >>> # Store single timesteps
        >>> buffer = ReplayBuffer(capacity=2,  # doctest: +SKIP
        ...     storage_unit=StorageUnit.TIMESTEPS)
        >>> buffer.add(SampleBatch({"a": [1, 2],  # doctest: +SKIP
        ...     SampleBatch.T: [0, 1]}))
        >>> t_n = buffer.sample(1) # doctest: +SKIP
        >>> print(t_n["a"]) # doctest: +SKIP
        >>> # [2]
        >>> buffer.add(SampleBatch({"a": [3], SampleBatch.T: [2]})) # doctest: +SKIP
        >>> print(buffer._eviction_started) # doctest: +SKIP
        >>> # True
        >>> t_n = buffer.sample(1) # doctest: +SKIP
        >>> print(t_n["a"]) # doctest: +SKIP
        >>> # [3]
        >>> # Store sequences as delimited by SampleBatch.SEQ_LENS
        >>> buffer = ReplayBuffer(capacity=10,  # doctest: +SKIP
        ...     storage_unit=StorageUnit.SEQUENCES)
        >>> buffer.add(SampleBatch({"c": [1, 2, 3],  # doctest: +SKIP
        ...     SampleBatch.SEQ_LENS: [1, 2]}))
        >>> seq_n = buffer.sample(1) # doctest: +SKIP
        >>> print(seq_n["c"]) # doctest: +SKIP
        >>> # [1]
    """

    def __init__(
        self,
        capacity: int = 10000,
        storage_unit: Union[str, StorageUnit] = "timesteps",
        **kwargs,
    ):
        """Initializes a (FIFO) ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in this FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            storage_unit: Either a StorageUnit member or one of the strings
                'timesteps', 'sequences', 'episodes' or 'fragments'.
                Specifies how experiences are stored.
            ``**kwargs``: Forward compatibility kwargs.

        Raises:
            ValueError: If `storage_unit` is not recognized or `capacity` is
                not greater than zero.
        """
        # Accept either the enum member itself or its string value.
        for unit in StorageUnit:
            if storage_unit in (unit.value, unit):
                self.storage_unit = unit
                break
        else:
            raise ValueError(
                "storage_unit must be either 'timesteps', 'sequences' or 'episodes' "
                "or 'fragments', but is {}".format(storage_unit)
            )

        # The actual storage (list of SampleBatches or MultiAgentBatches).
        self._storage = []

        # Caps the number of timesteps stored in this buffer.
        if capacity <= 0:
            raise ValueError(
                "Capacity of replay buffer has to be greater than zero "
                "but was set to {}.".format(capacity)
            )
        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # One sample-hit counter per storage slot. Slot indices stay below
        # `capacity` as long as every stored item holds at least one timestep,
        # because the write index wraps after `capacity` timesteps.
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        # Timesteps added since the write index last wrapped around.
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        # Hit counts of evicted items (a sliding window of the last 1000).
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        self._est_size_bytes = 0

        self.batch_size = None

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @DeveloperAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences to this buffer.

        Splits batch into chunks of timesteps, sequences or episodes, depending on
        `self.storage_unit`. Calls `self._add_single_batch` to add resulting slices
        to the buffer storage.

        Args:
            batch: Batch to add.
            ``**kwargs``: Forward compatibility kwargs.

        Raises:
            ValueError: If the storage unit is 'sequences' and `batch` has no
                `SampleBatch.SEQ_LENS` column.
        """
        # Empty batches are silently ignored.
        if batch.count <= 0:
            return

        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        if self.storage_unit == StorageUnit.TIMESTEPS:
            for timestep in batch.timeslices(1):
                self._add_single_batch(timestep, **kwargs)

        elif self.storage_unit == StorageUnit.SEQUENCES:
            seq_lens = batch.get(SampleBatch.SEQ_LENS)
            if seq_lens is None:
                raise ValueError(
                    "A buffer with storage_unit 'sequences' requires batches "
                    "that contain a `{}` column.".format(SampleBatch.SEQ_LENS)
                )
            # Consecutive sequences are laid out back to back in `batch`.
            start = 0
            for seq_len in seq_lens:
                end = start + seq_len
                self._add_single_batch(batch[start:end], **kwargs)
                start = end

        elif self.storage_unit == StorageUnit.EPISODES:
            for eps in batch.split_by_episode():
                if (
                    eps.get(SampleBatch.T)[0] == 0
                    and eps.get(SampleBatch.DONES)[-1] == True  # noqa E712
                ):
                    # Only add full episodes to the buffer
                    self._add_single_batch(eps, **kwargs)
                else:
                    if log_once("only_full_episodes"):
                        logger.info(
                            "This buffer uses episodes as a storage "
                            "unit and thus allows only full episodes "
                            "to be added to it. Some samples may be "
                            "dropped."
                        )

        elif self.storage_unit == StorageUnit.FRAGMENTS:
            self._add_single_batch(batch, **kwargs)

    @DeveloperAPI
    def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
        """Add a SampleBatch of experiences to self._storage.

        An item consists of either one or more timesteps, a sequence or an
        episode. Differs from add() in that it does not consider the storage
        unit or type of batch and simply stores it.

        Args:
            item: The batch to be added.
            ``**kwargs``: Forward compatibility kwargs.
        """
        self._num_timesteps_added += item.count
        self._num_timesteps_added_wrap += item.count

        if self._next_idx >= len(self._storage):
            # This slot has never been filled: grow the storage.
            self._storage.append(item)
        else:
            # Overwrite the item currently occupying this slot.
            evicted = self._storage[self._next_idx]
            self._est_size_bytes -= evicted.size_bytes()
            self._storage[self._next_idx] = item
            # Eviction of older samples has already started (buffer is
            # "full"): record how often the evicted item had been sampled.
            # Only done here, since appending to a fresh slot evicts nothing.
            if self._eviction_started:
                self._evicted_hit_stats.push(self._hit_count[self._next_idx])
        self._est_size_bytes += item.size_bytes()
        # The item in this slot is new and has not been sampled yet.
        self._hit_count[self._next_idx] = 0

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

    @DeveloperAPI
    def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
        """Samples `num_items` items from this buffer.

        The items depend on the buffer's storage_unit.
        Items are drawn uniformly with replacement, so samples in the result
        may be repeated.

        Examples for sampling results:

        1) If storage unit 'timesteps' has been chosen and batches of
        size 5 have been added, sample(5) will yield a concatenated batch of
        5 timesteps, because every slot holds exactly one timestep.

        2) If storage unit 'sequences' has been chosen and sequences of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled sequences.

        3) If storage unit 'episodes' has been chosen and episodes of
        different lengths have been added, sample(5) will yield a concatenated
        batch with a number of timesteps equal to the sum of timesteps in
        the 5 sampled episodes.

        4) If storage unit 'fragments' has been chosen, sample(5) will yield
        the concatenation of 5 of the batches that were added as a whole.

        Args:
            num_items: Number of items to sample from this buffer.
            ``**kwargs``: Forward compatibility kwargs.

        Returns:
            Concatenated batch of items.

        Raises:
            ValueError: If the buffer is empty.
        """
        if len(self) == 0:
            raise ValueError("Trying to sample from an empty buffer.")
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        batch = self._encode_sample(idxes)
        self._num_timesteps_sampled += batch.count
        return batch

    @DeveloperAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @DeveloperAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Note that the returned dict references (does not copy) the internal
        storage list.

        Returns:
            The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @DeveloperAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Per-slot hit counters and eviction statistics are not part of the
        state and are left untouched.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    @DeveloperAPI
    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Fetches concatenated samples at given indices from the storage.

        Increments the hit counter of every fetched slot (repeated indices
        count repeatedly). An empty `idxes` yields an empty SampleBatch.
        """
        samples = []
        for i in idxes:
            self._hit_count[i] += 1
            samples.append(self._storage[i])

        if samples:
            # We assume all samples are of same type
            sample_type = type(samples[0])
            out = sample_type.concat_samples(samples)
        else:
            out = SampleBatch()
        out.decompress_if_needed()
        return out

    @DeveloperAPI
    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's networks name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()

    @DeveloperAPI
    def apply(
        self,
        func: Callable[["ReplayBuffer", Optional[Any], Optional[Any]], T],
        *args,
        **kwargs,
    ) -> T:
        """Calls the given function with this ReplayBuffer instance.

        This is useful if we want to apply a function to a set of remote actors.

        Args:
            func: A callable that accepts the replay buffer itself, args and kwargs
            ``*args``: Any args to pass to func
            ``**kwargs``: Any kwargs to pass to func

        Returns:
            Return value of the induced function call
        """
        return func(self, *args, **kwargs)

    @Deprecated(old="ReplayBuffer.add_batch()", new="ReplayBuffer.add()", error=False)
    def add_batch(self, *args, **kwargs):
        return self.add(*args, **kwargs)

    @Deprecated(
        old="ReplayBuffer.replay(num_items)",
        new="ReplayBuffer.sample(num_items)",
        error=False,
    )
    def replay(self, num_items):
        return self.sample(num_items)

    @Deprecated(
        help="ReplayBuffers could be iterated over by default before. "
        "Making a buffer an iterator will soon "
        "be deprecated altogether. Consider switching to the training "
        "iteration API to resolve this.",
        error=False,
    )
    def make_iterator(self, num_items_to_replay: int):
        """Make this buffer a ParallelIteratorWorker to retain compatibility.

        Execution plans have made heavy use of buffers as ParallelIteratorWorkers.
        This method provides an easy way to support this for now.
        """

        def gen_replay():
            # Endless stream of samples of `num_items_to_replay` items each.
            while True:
                yield self.sample(num_items_to_replay)

        ParallelIteratorWorker.__init__(self, gen_replay, False)