ray/rllib/utils/replay_buffers/reservoir_buffer.py
2022-02-09 15:04:43 +01:00

109 lines
3.7 KiB
Python

from typing import Any, Dict
import random
# Importing ray before psutil makes sure we use the psutil version bundled with ray.
import ray # noqa F401
import psutil # noqa E402
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
@ExperimentalAPI
class ReservoirBuffer(ReplayBuffer):
    """This buffer implements reservoir sampling.

    The algorithm has been described by Jeffrey S. Vitter in "Random sampling
    with a reservoir". See https://www.cs.umd.edu/~samir/498/vitter.pdf for
    the full paper.
    """

    def __init__(self, capacity: int = 10000, storage_unit: str = "timesteps"):
        """Initializes a ReservoirBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the buffer. Once
                this many timesteps have been added, each new batch either
                overwrites a randomly chosen stored batch or is discarded
                (see `add()`).
            storage_unit: Either 'sequences' or 'timesteps'. Specifies how
                experiences are stored.
        """
        ReplayBuffer.__init__(self, capacity, storage_unit)
        # Number of `add()` calls so far, including batches that were
        # discarded.
        self._num_add_calls = 0
        # Number of stored batches that were overwritten by a newer batch.
        self._num_evicted = 0

    @ExperimentalAPI
    @override(ReplayBuffer)
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences.

        While fewer than `capacity` timesteps have been added, the batch is
        handed to `ReplayBuffer.add()`. After that, an index is drawn
        uniformly from `[0, num_add_calls - 1]`: if it points at an occupied
        storage slot, the batch overwrites that slot, otherwise the batch is
        discarded.

        Args:
            batch: SampleBatch to add to this buffer's storage.
            **kwargs: Accepted but not used by this method.
        """
        # Update add counts.
        self._num_add_calls += 1
        # Update our timesteps counts.
        # NOTE(review): `ReplayBuffer.add()` is invoked below after these
        # counters were already bumped here. If the parent updates the same
        # counters, timesteps are counted twice during the fill phase -
        # confirm against `ReplayBuffer.add()`.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._num_timesteps_added < self.capacity:
            ReplayBuffer.add(self, batch)
            return

        # Eviction of older samples has already started (buffer is "full").
        self._eviction_started = True
        # `random.randint` is inclusive on both ends.
        idx = random.randint(0, self._num_add_calls - 1)
        # `capacity` counts timesteps, whereas `_storage` is indexed per
        # stored batch and may therefore hold fewer than `capacity` entries.
        # Bound the check by the actual storage length so that the
        # replacement below can never index past the end of `_storage`.
        if idx < len(self._storage):
            # Validate before mutating any state, so that a failing
            # assertion does not leave an invalid batch in the storage.
            assert batch.count > 0, batch

            self._num_evicted += 1
            # Record how often the evicted slot was hit, then reset its
            # counter for the new occupant.
            self._evicted_hit_stats.push(self._hit_count[idx])
            self._hit_count[idx] = 0
            self._storage[idx] = batch

            warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

    @ExperimentalAPI
    @override(ReplayBuffer)
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer: the stats of the parent
            `ReplayBuffer` plus the keys "num_evicted" and "num_add_calls".
        """
        data = {
            "num_evicted": self._num_evicted,
            "num_add_calls": self._num_add_calls,
        }
        parent = ReplayBuffer.stats(self, debug)
        # Keys of this class take precedence over equally named parent keys.
        parent.update(data)
        return parent

    @ExperimentalAPI
    @override(ReplayBuffer)
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        Returns:
            The serializable local state: the parent's state, extended by
            the (non-debug) output of `self.stats()`.
        """
        parent = ReplayBuffer.get_state(self)
        parent.update(self.stats())
        return parent

    @ExperimentalAPI
    @override(ReplayBuffer)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.

        Raises:
            KeyError: If `state` lacks "num_evicted" or "num_add_calls".
        """
        self._num_evicted = state["num_evicted"]
        self._num_add_calls = state["num_add_calls"]
        ReplayBuffer.set_state(self, state)