# ray/rllib/utils/replay_buffers/tests/test_reservoir_buffer.py
# (source snapshot: 2022-06-10 16:47:51 +02:00, 98 lines, 3.8 KiB, Python)
import unittest
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer
class TestReservoirBuffer(unittest.TestCase):
    """Unit tests for ReservoirReplayBuffer.

    Both tests overfill a capacity-100 buffer with 1000 sequentially
    numbered batches, then check that sampled "batch_id" values average
    ~500 — i.e. reservoir sampling retained early and late additions
    with (approximately) equal probability.
    """

    def _fill_buffer(self, buffer, generate_fn, batch_size, num_batches, **kwargs):
        """Add `num_batches` batches to `buffer`.

        Each batch is the concatenation of `batch_size` samples produced
        by `generate_fn`; `self.batch_id` is bumped once per batch so
        every sample is tagged with its insertion order. Extra kwargs are
        forwarded to `buffer.add()`.
        """
        for _ in range(num_batches):
            data = [generate_fn() for _ in range(batch_size)]
            self.batch_id += 1
            buffer.add(SampleBatch.concat_samples(data), **kwargs)

    def test_timesteps_unit(self):
        """Tests adding, sampling, and eviction with experiences stored
        by timesteps."""
        self.batch_id = 0

        def _generate_data():
            # Single-timestep batch; "batch_id" records when it was added.
            return SampleBatch(
                {
                    SampleBatch.T: [np.random.random((4,))],
                    SampleBatch.ACTIONS: [np.random.choice([0, 1])],
                    SampleBatch.OBS: [np.random.random((4,))],
                    SampleBatch.NEXT_OBS: [np.random.random((4,))],
                    SampleBatch.REWARDS: [np.random.rand()],
                    SampleBatch.DONES: [np.random.choice([False, True])],
                    "batch_id": [self.batch_id],
                }
            )

        buffer = ReservoirReplayBuffer(capacity=100)
        # Put 1000 batches in a buffer with capacity 100.
        self._fill_buffer(buffer, _generate_data, batch_size=1, num_batches=1000)

        # Expect the sampled batch ids to be ~500 on average.
        batch_id_sum = 0
        for _ in range(200):
            num_ts_sampled = np.random.randint(1, 10)
            sample = buffer.sample(num_ts_sampled)
            batch_id_sum += sum(sample["batch_id"]) / num_ts_sampled
        self.assertAlmostEqual(batch_id_sum / 200, 500, delta=100)

    def test_episodes_unit(self):
        """Tests adding, sampling, and eviction with experiences stored
        as whole (two-timestep) episode fragments."""
        self.batch_id = 0

        def _generate_data():
            # One complete two-step episode (DONES=[False, True]).
            return SampleBatch(
                {
                    SampleBatch.T: [0, 1],
                    SampleBatch.ACTIONS: 2 * [np.random.choice([0, 1])],
                    SampleBatch.REWARDS: 2 * [np.random.rand()],
                    SampleBatch.OBS: 2 * [np.random.random((4,))],
                    SampleBatch.NEXT_OBS: 2 * [np.random.random((4,))],
                    SampleBatch.DONES: [False, True],
                    SampleBatch.AGENT_INDEX: 2 * [0],
                    "batch_id": 2 * [self.batch_id],
                }
            )

        buffer = ReservoirReplayBuffer(capacity=100, storage_unit="fragments")
        # Put 1000 batches in a buffer with capacity 100.
        self._fill_buffer(buffer, _generate_data, batch_size=1, num_batches=1000)

        # Expect the sampled batch ids to be ~500 on average.
        batch_id_sum = 0
        for _ in range(200):
            num_episodes_sampled = np.random.randint(1, 10)
            sample = buffer.sample(num_episodes_sampled)
            # Every sampled episode contributes exactly 2 timesteps.
            num_ts_sampled = num_episodes_sampled * 2
            batch_id_sum += sum(sample["batch_id"]) / num_ts_sampled
        self.assertAlmostEqual(batch_id_sum / 200, 500, delta=100)
if __name__ == "__main__":
    import sys

    import pytest

    # Run this test module through pytest in verbose mode and propagate
    # its exit code to the shell.
    sys.exit(pytest.main(["-v", __file__]))