mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
147 lines
5.8 KiB
Python
147 lines
5.8 KiB
Python
import numpy as np
|
|
import unittest
|
|
|
|
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
|
|
|
|
class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
|
|
"""Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer."""
|
|
|
|
capacity = 10
|
|
|
|
def _generate_data(self):
|
|
return SampleBatch(
|
|
{
|
|
"obs": [np.random.random((4,))],
|
|
"action": [np.random.choice([0, 1])],
|
|
"reward": [np.random.rand()],
|
|
"new_obs": [np.random.random((4,))],
|
|
"done": [np.random.choice([False, True])],
|
|
}
|
|
)
|
|
|
|
def test_mixin_sampling(self):
|
|
# 50% replay ratio.
|
|
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.5)
|
|
# Add a new batch.
|
|
batch = self._generate_data()
|
|
buffer.add_batch(batch)
|
|
# Expect at least 1 sample to be returned.
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) >= 1)
|
|
# If we insert and replay n times, expect roughly return batches of
|
|
# len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and 1 old sample
|
|
# on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
buffer.add_batch(batch)
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 2.0)
|
|
|
|
# 33% replay ratio.
|
|
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.333)
|
|
# Expect exactly 0 samples to be returned (buffer empty).
|
|
sample = buffer.replay()
|
|
self.assertTrue(sample is None)
|
|
# Add a new batch.
|
|
batch = self._generate_data()
|
|
buffer.add_batch(batch)
|
|
# Expect at least 1 sample to be returned.
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) >= 1)
|
|
# If we insert-2x and replay n times, expect roughly return batches of
|
|
# len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and 1 old sample
|
|
# on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
buffer.add_batch(batch)
|
|
buffer.add_batch(batch)
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)
|
|
|
|
# If we insert-1x and replay n times, expect roughly return batches of
|
|
# len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and 0.5 old
|
|
# samples on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
buffer.add_batch(batch)
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)
|
|
|
|
# 90% replay ratio.
|
|
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.9)
|
|
# Expect exactly 0 samples to be returned (buffer empty).
|
|
sample = buffer.replay()
|
|
self.assertTrue(sample is None)
|
|
# Add a new batch.
|
|
batch = self._generate_data()
|
|
buffer.add_batch(batch)
|
|
# Expect at least 2 samples to be returned (new one plus at least one
|
|
# replay sample).
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) >= 2)
|
|
# If we insert and replay n times, expect roughly return batches of
|
|
# len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
|
|
# samples on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
buffer.add_batch(batch)
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)
|
|
|
|
# 0% replay ratio -> Only new samples.
|
|
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.0)
|
|
# Add a new batch.
|
|
batch = self._generate_data()
|
|
buffer.add_batch(batch)
|
|
# Expect exactly 1 sample to be returned.
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) == 1)
|
|
# Expect exactly 0 sample to be returned (nothing new to be returned;
|
|
# no replay allowed (replay_ratio=0.0)).
|
|
sample = buffer.replay()
|
|
self.assertTrue(sample is None)
|
|
# If we insert and replay n times, expect roughly return batches of
|
|
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
|
|
# on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
buffer.add_batch(batch)
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 1.0)
|
|
|
|
# 100% replay ratio -> Only new samples.
|
|
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=1.0)
|
|
# Expect exactly 0 samples to be returned (buffer empty).
|
|
sample = buffer.replay()
|
|
self.assertTrue(sample is None)
|
|
# Add a new batch.
|
|
batch = self._generate_data()
|
|
buffer.add_batch(batch)
|
|
# Expect exactly 1 sample to be returned (the new batch).
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) == 1)
|
|
# Another replay -> Expect exactly 1 sample to be returned.
|
|
sample = buffer.replay()
|
|
self.assertTrue(len(sample) == 1)
|
|
# If we replay n times, expect roughly return batches of
|
|
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
|
|
# on average in each returned value).
|
|
results = []
|
|
for _ in range(100):
|
|
sample = buffer.replay()
|
|
results.append(len(sample))
|
|
self.assertAlmostEqual(np.mean(results), 1.0)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import pytest
|
|
import sys
|
|
|
|
sys.exit(pytest.main(["-v", __file__]))
|