mirror of
https://github.com/vale981/ray
synced 2025-03-10 13:26:39 -04:00
195 lines
6.8 KiB
Python
195 lines
6.8 KiB
Python
import logging
|
|
import platform
|
|
from typing import Any, Dict, List, Optional
|
|
import numpy as np
|
|
import random
|
|
|
|
# Importing ray before psutil makes sure we use the bundled version of psutil.
|
|
import ray # noqa F401
|
|
import psutil # noqa E402
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.utils.annotations import ExperimentalAPI
|
|
from ray.rllib.utils.metrics.window_stat import WindowStat
|
|
from ray.rllib.utils.typing import SampleBatchType
|
|
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@ExperimentalAPI
class ReplayBuffer:
    """FIFO replay buffer that keeps sample batches in a list used as a ring.

    Batches are appended until the number of timesteps added since the last
    wrap-around reaches `capacity`; from then on the write index restarts at
    0 and older batches are overwritten ("evicted") slot by slot.
    """

    def __init__(
        self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs
    ):
        """Initializes a ReplayBuffer instance.

        Args:
            capacity: Max number of timesteps to store in the FIFO
                buffer. After reaching this number, older samples will be
                dropped to make space for new ones.
            storage_unit: Either 'sequences' or 'timesteps'. Specifies how
                experiences are stored.
            **kwargs: Forward compatibility kwargs.

        Raises:
            ValueError: If `storage_unit` is neither 'timesteps' nor
                'sequences'.
        """
        if storage_unit == "timesteps":
            self._store_as_sequences = False
        elif storage_unit == "sequences":
            self._store_as_sequences = True
        else:
            raise ValueError("storage_unit must be either 'sequences' or 'timesteps'")

        # The actual storage (list of SampleBatches).
        self._storage = []

        self.capacity = capacity
        # The next index to override in the buffer.
        self._next_idx = 0
        # One sample-hit counter per possible storage slot.
        self._hit_count = np.zeros(self.capacity)

        # Whether we have already hit our capacity (and have therefore
        # started to evict older samples).
        self._eviction_started = False

        # Number of (single) timesteps that have been added to the buffer
        # over its lifetime. Note that each added item (batch) may contain
        # more than one timestep.
        self._num_timesteps_added = 0
        # Same count, but reset to 0 on every wrap-around of the write index.
        self._num_timesteps_added_wrap = 0

        # Number of (single) timesteps that have been sampled from the buffer
        # over its lifetime.
        self._num_timesteps_sampled = 0

        # Windowed statistic over the hit counts of evicted slots.
        self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
        # Running sum of `size_bytes()` of the batches currently stored.
        self._est_size_bytes = 0

    def __len__(self) -> int:
        """Returns the number of items currently stored in this buffer."""
        return len(self._storage)

    @ExperimentalAPI
    def add(self, batch: SampleBatchType, **kwargs) -> None:
        """Adds a batch of experiences.

        Args:
            batch: SampleBatch to add to this buffer's storage. Must contain
                at least one timestep (`batch.count > 0`).
            **kwargs: Forward compatibility kwargs.
        """
        assert batch.count > 0, batch
        warn_replay_capacity(item=batch, num_items=self.capacity / batch.count)

        # Update our timesteps counts.
        self._num_timesteps_added += batch.count
        self._num_timesteps_added_wrap += batch.count

        if self._next_idx >= len(self._storage):
            # Write index points past the end: grow the storage.
            self._storage.append(batch)
            self._est_size_bytes += batch.size_bytes()
        else:
            # Overwriting an occupied slot: the size estimate must drop the
            # evicted batch and account for the new one, otherwise it keeps
            # describing batches that are no longer stored.
            evicted = self._storage[self._next_idx]
            self._est_size_bytes += batch.size_bytes() - evicted.size_bytes()
            self._storage[self._next_idx] = batch

        # Wrap around storage as a circular buffer once we hit capacity.
        if self._num_timesteps_added_wrap >= self.capacity:
            self._eviction_started = True
            self._num_timesteps_added_wrap = 0
            self._next_idx = 0
        else:
            self._next_idx += 1

        # Eviction of older samples has already started (buffer is "full"):
        # record how often the slot that will be overwritten next was sampled,
        # then reset its counter for the batch that will replace it.
        if self._eviction_started:
            self._evicted_hit_stats.push(self._hit_count[self._next_idx])
            self._hit_count[self._next_idx] = 0

    @ExperimentalAPI
    def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
        """Samples a batch of size `num_items` from this buffer.

        Items are drawn uniformly at random with replacement, so samples in
        the result may be repeated (in particular if less than `num_items`
        records are in this buffer).

        Args:
            num_items: Number of items to sample from this buffer.
            **kwargs: Forward compatibility kwargs.

        Returns:
            Concatenated batch of items. None if buffer is empty.
        """
        # If we don't have any samples yet in this buffer, return None.
        if len(self) == 0:
            return None

        last_idx = len(self) - 1
        idxes = [random.randint(0, last_idx) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        # Update our timesteps counters.
        self._num_timesteps_sampled += len(sample)
        return sample

    @ExperimentalAPI
    def stats(self, debug: bool = False) -> dict:
        """Returns the stats of this buffer.

        Args:
            debug: If True, adds sample eviction statistics to the returned
                stats dict.

        Returns:
            A dictionary of stats about this buffer.
        """
        data = {
            "added_count": self._num_timesteps_added,
            "added_count_wrapped": self._num_timesteps_added_wrap,
            "eviction_started": self._eviction_started,
            "sampled_count": self._num_timesteps_sampled,
            "est_size_bytes": self._est_size_bytes,
            "num_entries": len(self._storage),
        }
        if debug:
            data.update(self._evicted_hit_stats.stats())
        return data

    @ExperimentalAPI
    def get_state(self) -> Dict[str, Any]:
        """Returns all local state.

        The state consists of the storage, the write index and the
        non-debug stats. Hit counters and eviction statistics are not
        included.

        Returns:
            The serializable local state.
        """
        state = {"_storage": self._storage, "_next_idx": self._next_idx}
        state.update(self.stats(debug=False))
        return state

    @ExperimentalAPI
    def set_state(self, state: Dict[str, Any]) -> None:
        """Restores all local state to the provided `state`.

        Args:
            state: The new state to set this buffer. Can be
                obtained by calling `self.get_state()`.
        """
        # The actual storage.
        self._storage = state["_storage"]
        self._next_idx = state["_next_idx"]
        # Stats and counts.
        self._num_timesteps_added = state["added_count"]
        self._num_timesteps_added_wrap = state["added_count_wrapped"]
        self._eviction_started = state["eviction_started"]
        self._num_timesteps_sampled = state["sampled_count"]
        self._est_size_bytes = state["est_size_bytes"]

    def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
        """Fetches and concatenates the stored batches at the given indices.

        Also counts one hit per drawn index, so that the eviction statistics
        pushed in `add()` reflect how often a slot was actually sampled.

        Args:
            idxes: Storage indices to fetch; may contain repeats.

        Returns:
            The concatenated (and, if needed, decompressed) batch.
        """
        picked = []
        for i in idxes:
            self._hit_count[i] += 1
            picked.append(self._storage[i])
        out = SampleBatch.concat_samples(picked)
        out.decompress_if_needed()
        return out

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
|