[rllib] Replay buffer size inaccurate with replay_seq_len option (#10988)
* support replay seq len
* update
* fix warn
* add test
* test
Parent: 8abe13023f
Commit: 8f79b4e45e

2 changed files with 43 additions and 16 deletions
@@ -33,9 +33,9 @@ def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
         total_gb = psutil_mem.total / 1e9
         mem_size = num_items * item_size / 1e9
         msg = ("Estimated max memory usage for replay buffer is {} GB "
-               "({} batches of {} bytes each), "
+               "({} batches of size {}, {} bytes each), "
                "available system memory is {} GB".format(
-                   mem_size, num_items, item_size, total_gb))
+                   mem_size, num_items, item.count, item_size, total_gb))
         if mem_size > total_gb:
             raise ValueError(msg)
         elif mem_size > 0.2 * total_gb:
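Note: this hunk only corrects the warning text; the companion fix in add() below passes num_items = self._maxsize / item.count, since the buffer's capacity is counted in timesteps while each stored item spans item.count of them. A worked example of the estimate, with made-up numbers:

# All numbers hypothetical, chosen only to show the scale of the error.
buffer_size_timesteps = 100_000  # self._maxsize (counted in timesteps)
batch_timesteps = 5              # item.count, e.g. replay_sequence_length=5
item_size = 40_000               # item.size_bytes() of one 5-step batch

# Old estimate: every timestep slot was assumed to hold a full batch.
old_estimate_gb = buffer_size_timesteps * item_size / 1e9  # 4.0 GB

# New estimate: the buffer can only hold maxsize / item.count batches.
num_items = buffer_size_timesteps / batch_timesteps        # 20,000 batches
new_estimate_gb = num_items * item_size / 1e9              # 0.8 GB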
@@ -51,15 +51,16 @@ class ReplayBuffer:
         """Create Prioritized Replay buffer.
 
         Args:
-            size (int): Max number of items to store in the FIFO buffer.
+            size (int): Max number of timesteps to store in the FIFO buffer.
         """
         self._storage = []
         self._maxsize = size
         self._next_idx = 0
         self._hit_count = np.zeros(size)
         self._eviction_started = False
-        self._num_added = 0
-        self._num_sampled = 0
+        self._num_timesteps_added = 0
+        self._num_timesteps_added_wrap = 0
+        self._num_timesteps_sampled = 0
         self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
         self._est_size_bytes = 0
 
@@ -68,18 +69,26 @@ class ReplayBuffer:
 
     @DeveloperAPI
     def add(self, item: SampleBatchType, weight: float):
-        warn_replay_buffer_size(item=item, num_items=self._maxsize)
+        warn_replay_buffer_size(
+            item=item, num_items=self._maxsize / item.count)
         assert item.count > 0, item
-        self._num_added += 1
+        self._num_timesteps_added += item.count
+        self._num_timesteps_added_wrap += item.count
 
         if self._next_idx >= len(self._storage):
             self._storage.append(item)
             self._est_size_bytes += item.size_bytes()
         else:
             self._storage[self._next_idx] = item
-        if self._next_idx + 1 >= self._maxsize:
+
+        # Wrap around storage as a circular buffer once we hit maxsize.
+        if self._num_timesteps_added_wrap >= self._maxsize:
             self._eviction_started = True
-        self._next_idx = (self._next_idx + 1) % self._maxsize
+            self._num_timesteps_added_wrap = 0
+            self._next_idx = 0
+        else:
+            self._next_idx += 1
+
         if self._eviction_started:
             self._evicted_hit_stats.push(self._hit_count[self._next_idx])
             self._hit_count[self._next_idx] = 0
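This is the core fix: eviction now triggers on timesteps added since the last wrap (_num_timesteps_added_wrap) instead of on item slots, so a buffer of `size` timesteps holds at most size / item.count items. A self-contained toy sketch of the same wrap rule (illustrative, not RLlib code):

class TinyBuffer:
    """Toy version of the new wrap rule: capacity counts timesteps."""

    def __init__(self, size):
        self._storage = []
        self._maxsize = size          # capacity in timesteps, not items
        self._next_idx = 0
        self._timesteps_since_wrap = 0

    def add(self, item_count):
        self._timesteps_since_wrap += item_count
        if self._next_idx >= len(self._storage):
            self._storage.append(item_count)
        else:
            self._storage[self._next_idx] = item_count  # evict oldest
        if self._timesteps_since_wrap >= self._maxsize:
            self._timesteps_since_wrap = 0
            self._next_idx = 0        # wrap: next add overwrites index 0
        else:
            self._next_idx += 1


buf = TinyBuffer(size=100)
for _ in range(40):
    buf.add(5)                        # forty 5-timestep sequences
assert len(buf._storage) == 20        # 20 items x 5 steps = 100 timesteps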
@@ -109,8 +118,8 @@ class ReplayBuffer:
     @DeveloperAPI
     def stats(self, debug=False):
         data = {
-            "added_count": self._num_added,
-            "sampled_count": self._num_sampled,
+            "added_count": self._num_timesteps_added,
+            "sampled_count": self._num_timesteps_sampled,
             "est_size_bytes": self._est_size_bytes,
             "num_entries": len(self._storage),
         }
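With the renamed counters, stats() now reports timesteps rather than items ("num_entries" still counts stored items). A usage sketch; the import paths are assumptions based on this era of RLlib, not shown in the diff:

from ray.rllib.execution.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buffer = PrioritizedReplayBuffer(size=100, alpha=0.6)
batch = SampleBatch({"obs": [1, 2, 3, 4, 5]})  # batch.count == 5
buffer.add(batch, weight=None)

stats = buffer.stats()
assert stats["added_count"] == 5   # timesteps, not batches
assert stats["num_entries"] == 1   # one stored item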
@@ -179,7 +188,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             transition and original idxes in buffer of sampled experiences.
         """
         assert beta >= 0.0
-        self._num_sampled += num_items
 
         idxes = self._sample_proportional(num_items)
 
@@ -194,6 +202,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             count = self._storage[idx].count
             weights.extend([weight / max_weight] * count)
             batch_indexes.extend([idx] * count)
+            self._num_timesteps_sampled += count
         batch = self._encode_sample(idxes)
 
         # Note: prioritization is not supported in lockstep replay mode.
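The sampled counter moves into this loop so it advances by each item's count, keeping "sampled_count" in the same timestep units as "added_count". An illustrative snippet (hypothetical values) of how one 5-step item expands into per-timestep weights and indexes:

weights, batch_indexes = [], []
num_timesteps_sampled = 0
idx, weight, max_weight = 0, 0.8, 1.0
count = 5                                  # self._storage[idx].count
weights.extend([weight / max_weight] * count)
batch_indexes.extend([idx] * count)
num_timesteps_sampled += count
assert len(weights) == len(batch_indexes) == num_timesteps_sampled == 5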
@@ -251,10 +260,10 @@ class LocalReplayBuffer(ParallelIteratorWorker):
     may be created to increase parallelism."""
 
     def __init__(self,
-                 num_shards,
-                 learning_starts,
-                 buffer_size,
-                 replay_batch_size,
+                 num_shards=1,
+                 learning_starts=1000,
+                 buffer_size=10000,
+                 replay_batch_size=1,
                  prioritized_replay_alpha=0.6,
                  prioritized_replay_beta=0.4,
                  prioritized_replay_eps=1e-6,
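The first four LocalReplayBuffer arguments gain defaults. A usage sketch, assuming the remaining constructor arguments (cut off in this hunk) also have defaults:

from ray.rllib.execution.replay_buffer import LocalReplayBuffer

local_buffer = LocalReplayBuffer()
# Equivalent to LocalReplayBuffer(num_shards=1, learning_starts=1000,
#                                 buffer_size=10000, replay_batch_size=1, ...)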
@@ -26,6 +26,24 @@ class TestPrioritizedReplayBuffer(unittest.TestCase):
             "done": [np.random.choice([False, True])],
         })
 
+    def test_sequence_size(self):
+        # seq len 1
+        memory = PrioritizedReplayBuffer(size=100, alpha=0.1)
+        for _ in range(200):
+            memory.add(self._generate_data(), weight=None)
+        assert len(memory._storage) == 100, len(memory._storage)
+        assert memory.stats()["added_count"] == 200, memory.stats()
+
+        # seq len 5
+        memory = PrioritizedReplayBuffer(size=100, alpha=0.1)
+        for _ in range(40):
+            memory.add(
+                SampleBatch.concat_samples(
+                    [self._generate_data() for _ in range(5)]),
+                weight=None)
+        assert len(memory._storage) == 20, len(memory._storage)
+        assert memory.stats()["added_count"] == 200, memory.stats()
+
     def test_add(self):
         memory = PrioritizedReplayBuffer(
             size=2,
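Both test cases add exactly 200 timesteps; only the number of retained items differs. The before/after arithmetic, using the test's numbers (illustrative):

size = 100      # configured capacity, in timesteps
seq_len = 5     # timesteps per stored item (the "seq len 5" case)
n_adds = 40     # items the test adds

# Before this commit the index wrapped after `size` items, so all 40
# items stayed resident (up to 100 items = 500 timesteps).
old_items_retained = min(n_adds, size)             # 40

# Now the index wraps after `size` timesteps.
new_items_retained = min(n_adds, size // seq_len)  # 20
assert new_items_retained * seq_len == size        # exactly 100 timesteps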