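"""Tests for ReservoirReplayBuffer: adding, sampling, and reservoir-based
eviction for experiences stored by timesteps and by episodes."""
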
import unittest

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import (
    ReservoirReplayBuffer,
)


class TestReservoirBuffer(unittest.TestCase):
    def test_timesteps_unit(self):
"""Tests adding, sampling, get-/set state, and eviction with
|
|
experiences stored by timesteps."""
|
|
        self.batch_id = 0

        def _add_data_to_buffer(_buffer, batch_size, num_batches=5, **kwargs):
            def _generate_data():
                return SampleBatch(
                    {
                        SampleBatch.T: [np.random.random((4,))],
                        SampleBatch.ACTIONS: [np.random.choice([0, 1])],
                        SampleBatch.OBS: [np.random.random((4,))],
                        SampleBatch.NEXT_OBS: [np.random.random((4,))],
                        SampleBatch.REWARDS: [np.random.rand()],
                        SampleBatch.DONES: [np.random.choice([False, True])],
                        "batch_id": [self.batch_id],
                    }
                )

            for i in range(num_batches):
                data = [_generate_data() for _ in range(batch_size)]
                self.batch_id += 1
                batch = SampleBatch.concat_samples(data)
                _buffer.add(batch, **kwargs)

        batch_size = 1
        buffer_size = 100

        buffer = ReservoirReplayBuffer(capacity=buffer_size)

        # Put 1000 batches in a buffer with capacity 100.
        _add_data_to_buffer(buffer, batch_size=batch_size, num_batches=1000)
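
        # Reservoir sampling keeps each of the 1000 added batches with equal
        # probability (capacity / num_batches = 100 / 1000), so the retained
        # batch ids should be roughly uniform over [1, 1000].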
        # Expect the batch id to be ~500 on average.
        batch_id_sum = 0
        for i in range(200):
            num_ts_sampled = np.random.randint(1, 10)
            sample = buffer.sample(num_ts_sampled)
            batch_id_sum += sum(sample["batch_id"]) / num_ts_sampled

        self.assertAlmostEqual(batch_id_sum / 200, 500, delta=100)

    def test_episodes_unit(self):
        """Tests adding, sampling, and eviction with
        experiences stored by episodes."""
        self.batch_id = 0

        def _add_data_to_buffer(_buffer, batch_size, num_batches=5, **kwargs):
            def _generate_data():
                return SampleBatch(
                    {
                        SampleBatch.T: [0, 1],
                        SampleBatch.ACTIONS: 2 * [np.random.choice([0, 1])],
                        SampleBatch.REWARDS: 2 * [np.random.rand()],
                        SampleBatch.OBS: 2 * [np.random.random((4,))],
                        SampleBatch.NEXT_OBS: 2 * [np.random.random((4,))],
                        SampleBatch.DONES: [False, True],
                        SampleBatch.AGENT_INDEX: 2 * [0],
                        "batch_id": 2 * [self.batch_id],
                    }
                )

            for i in range(num_batches):
                data = [_generate_data() for _ in range(batch_size)]
                self.batch_id += 1
                batch = SampleBatch.concat_samples(data)
                _buffer.add(batch, **kwargs)

        batch_size = 1
        buffer_size = 100

        buffer = ReservoirReplayBuffer(capacity=buffer_size, storage_unit="fragments")

        # Put 1000 batches in a buffer with capacity 100.
        _add_data_to_buffer(buffer, batch_size=batch_size, num_batches=1000)
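
        # With storage_unit="fragments", each added batch (here a 2-timestep
        # episode) is stored as one item, and reservoir eviction keeps each of
        # the 1000 episodes with equal probability, so sampled batch ids should
        # again average ~500 (each sampled episode contributes two timesteps).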
        # Expect the batch id to be ~500 on average.
        batch_id_sum = 0
        for i in range(200):
            num_episodes_sampled = np.random.randint(1, 10)
            sample = buffer.sample(num_episodes_sampled)
            num_ts_sampled = num_episodes_sampled * 2
            batch_id_sum += sum(sample["batch_id"]) / num_ts_sampled

        self.assertAlmostEqual(batch_id_sum / 200, 500, delta=100)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))