diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index a7355b85b..8d98e1397 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -33,9 +33,9 @@ def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None: total_gb = psutil_mem.total / 1e9 mem_size = num_items * item_size / 1e9 msg = ("Estimated max memory usage for replay buffer is {} GB " - "({} batches of {} bytes each), " + "({} batches of size {}, {} bytes each), " "available system memory is {} GB".format( - mem_size, num_items, item_size, total_gb)) + mem_size, num_items, item.count, item_size, total_gb)) if mem_size > total_gb: raise ValueError(msg) elif mem_size > 0.2 * total_gb: @@ -51,15 +51,16 @@ class ReplayBuffer: """Create Prioritized Replay buffer. Args: - size (int): Max number of items to store in the FIFO buffer. + size (int): Max number of timesteps to store in the FIFO buffer. """ self._storage = [] self._maxsize = size self._next_idx = 0 self._hit_count = np.zeros(size) self._eviction_started = False - self._num_added = 0 - self._num_sampled = 0 + self._num_timesteps_added = 0 + self._num_timesteps_added_wrap = 0 + self._num_timesteps_sampled = 0 self._evicted_hit_stats = WindowStat("evicted_hit", 1000) self._est_size_bytes = 0 @@ -68,18 +69,26 @@ class ReplayBuffer: @DeveloperAPI def add(self, item: SampleBatchType, weight: float): - warn_replay_buffer_size(item=item, num_items=self._maxsize) + warn_replay_buffer_size( + item=item, num_items=self._maxsize / item.count) assert item.count > 0, item - self._num_added += 1 + self._num_timesteps_added += item.count + self._num_timesteps_added_wrap += item.count if self._next_idx >= len(self._storage): self._storage.append(item) self._est_size_bytes += item.size_bytes() else: self._storage[self._next_idx] = item - if self._next_idx + 1 >= self._maxsize: + + # Wrap around storage as a circular buffer once we hit maxsize. + if self._num_timesteps_added_wrap >= self._maxsize: self._eviction_started = True - self._next_idx = (self._next_idx + 1) % self._maxsize + self._num_timesteps_added_wrap = 0 + self._next_idx = 0 + else: + self._next_idx += 1 + if self._eviction_started: self._evicted_hit_stats.push(self._hit_count[self._next_idx]) self._hit_count[self._next_idx] = 0 @@ -109,8 +118,8 @@ class ReplayBuffer: @DeveloperAPI def stats(self, debug=False): data = { - "added_count": self._num_added, - "sampled_count": self._num_sampled, + "added_count": self._num_timesteps_added, + "sampled_count": self._num_timesteps_sampled, "est_size_bytes": self._est_size_bytes, "num_entries": len(self._storage), } @@ -179,7 +188,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): transition and original idxes in buffer of sampled experiences. """ assert beta >= 0.0 - self._num_sampled += num_items idxes = self._sample_proportional(num_items) @@ -194,6 +202,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): count = self._storage[idx].count weights.extend([weight / max_weight] * count) batch_indexes.extend([idx] * count) + self._num_timesteps_sampled += count batch = self._encode_sample(idxes) # Note: prioritization is not supported in lockstep replay mode. @@ -251,10 +260,10 @@ class LocalReplayBuffer(ParallelIteratorWorker): may be created to increase parallelism.""" def __init__(self, - num_shards, - learning_starts, - buffer_size, - replay_batch_size, + num_shards=1, + learning_starts=1000, + buffer_size=10000, + replay_batch_size=1, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6, diff --git a/rllib/execution/tests/test_prioritized_replay_buffer.py b/rllib/execution/tests/test_prioritized_replay_buffer.py index 5f3f3084c..2fa114735 100644 --- a/rllib/execution/tests/test_prioritized_replay_buffer.py +++ b/rllib/execution/tests/test_prioritized_replay_buffer.py @@ -26,6 +26,24 @@ class TestPrioritizedReplayBuffer(unittest.TestCase): "done": [np.random.choice([False, True])], }) + def test_sequence_size(self): + # seq len 1 + memory = PrioritizedReplayBuffer(size=100, alpha=0.1) + for _ in range(200): + memory.add(self._generate_data(), weight=None) + assert len(memory._storage) == 100, len(memory._storage) + assert memory.stats()["added_count"] == 200, memory.stats() + + # seq len 5 + memory = PrioritizedReplayBuffer(size=100, alpha=0.1) + for _ in range(40): + memory.add( + SampleBatch.concat_samples( + [self._generate_data() for _ in range(5)]), + weight=None) + assert len(memory._storage) == 20, len(memory._storage) + assert memory.stats()["added_count"] == 200, memory.stats() + def test_add(self): memory = PrioritizedReplayBuffer( size=2,