[rllib] Replay buffer size inaccurate with replay_seq_len option (#10988)

* support replay seq len

* update

* fix warn

* add test

* test
This commit is contained in:
Eric Liang 2020-09-25 13:47:23 -07:00 committed by GitHub
parent 8abe13023f
commit 8f79b4e45e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 16 deletions

View file

@ -33,9 +33,9 @@ def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
total_gb = psutil_mem.total / 1e9
mem_size = num_items * item_size / 1e9
msg = ("Estimated max memory usage for replay buffer is {} GB "
"({} batches of {} bytes each), "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, num_items, item_size, total_gb))
mem_size, num_items, item.count, item_size, total_gb))
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
@ -51,15 +51,16 @@ class ReplayBuffer:
"""Create Prioritized Replay buffer.
Args:
size (int): Max number of items to store in the FIFO buffer.
size (int): Max number of timesteps to store in the FIFO buffer.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
self._hit_count = np.zeros(size)
self._eviction_started = False
self._num_added = 0
self._num_sampled = 0
self._num_timesteps_added = 0
self._num_timesteps_added_wrap = 0
self._num_timesteps_sampled = 0
self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
self._est_size_bytes = 0
@ -68,18 +69,26 @@ class ReplayBuffer:
@DeveloperAPI
def add(self, item: SampleBatchType, weight: float):
warn_replay_buffer_size(item=item, num_items=self._maxsize)
warn_replay_buffer_size(
item=item, num_items=self._maxsize / item.count)
assert item.count > 0, item
self._num_added += 1
self._num_timesteps_added += item.count
self._num_timesteps_added_wrap += item.count
if self._next_idx >= len(self._storage):
self._storage.append(item)
self._est_size_bytes += item.size_bytes()
else:
self._storage[self._next_idx] = item
if self._next_idx + 1 >= self._maxsize:
# Wrap around storage as a circular buffer once we hit maxsize.
if self._num_timesteps_added_wrap >= self._maxsize:
self._eviction_started = True
self._next_idx = (self._next_idx + 1) % self._maxsize
self._num_timesteps_added_wrap = 0
self._next_idx = 0
else:
self._next_idx += 1
if self._eviction_started:
self._evicted_hit_stats.push(self._hit_count[self._next_idx])
self._hit_count[self._next_idx] = 0
@ -109,8 +118,8 @@ class ReplayBuffer:
@DeveloperAPI
def stats(self, debug=False):
data = {
"added_count": self._num_added,
"sampled_count": self._num_sampled,
"added_count": self._num_timesteps_added,
"sampled_count": self._num_timesteps_sampled,
"est_size_bytes": self._est_size_bytes,
"num_entries": len(self._storage),
}
@ -179,7 +188,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
transition and original idxes in buffer of sampled experiences.
"""
assert beta >= 0.0
self._num_sampled += num_items
idxes = self._sample_proportional(num_items)
@ -194,6 +202,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
count = self._storage[idx].count
weights.extend([weight / max_weight] * count)
batch_indexes.extend([idx] * count)
self._num_timesteps_sampled += count
batch = self._encode_sample(idxes)
# Note: prioritization is not supported in lockstep replay mode.
@ -251,10 +260,10 @@ class LocalReplayBuffer(ParallelIteratorWorker):
may be created to increase parallelism."""
def __init__(self,
num_shards,
learning_starts,
buffer_size,
replay_batch_size,
num_shards=1,
learning_starts=1000,
buffer_size=10000,
replay_batch_size=1,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=1e-6,

View file

@ -26,6 +26,24 @@ class TestPrioritizedReplayBuffer(unittest.TestCase):
"done": [np.random.choice([False, True])],
})
def test_sequence_size(self):
# seq len 1
memory = PrioritizedReplayBuffer(size=100, alpha=0.1)
for _ in range(200):
memory.add(self._generate_data(), weight=None)
assert len(memory._storage) == 100, len(memory._storage)
assert memory.stats()["added_count"] == 200, memory.stats()
# seq len 5
memory = PrioritizedReplayBuffer(size=100, alpha=0.1)
for _ in range(40):
memory.add(
SampleBatch.concat_samples(
[self._generate_data() for _ in range(5)]),
weight=None)
assert len(memory._storage) == 20, len(memory._storage)
assert memory.stats()["added_count"] == 200, memory.stats()
def test_add(self):
memory = PrioritizedReplayBuffer(
size=2,