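"""Unit tests for ray.rllib.policy.sample_batch.SampleBatch."""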
import unittest

import numpy as np

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.compression import is_compressed
from ray.rllib.utils.test_utils import check


class TestSampleBatch(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_len_and_size_bytes(self):
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3]),
                "b": {"c": np.array([4, 5, 6])},
                SampleBatch.SEQ_LENS: [1, 2],
            }
        )
        check(len(s1), 3)
        check(
            s1.size_bytes(),
            s1["a"].nbytes + s1["b"]["c"].nbytes + s1[SampleBatch.SEQ_LENS].nbytes,
        )
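        # Sanity check on the expectation itself (plain numpy, not
        # RLlib-specific): nbytes is always size * itemsize, so "a" with
        # three elements must report 3 * itemsize bytes.
        check(s1["a"].nbytes, s1["a"].size * s1["a"].itemsize)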

    def test_dict_properties_of_sample_batches(self):
        base_dict = {
            "a": np.array([1, 2, 3]),
            "b": np.array([[0.1, 0.2], [0.3, 0.4]]),
            "c": True,
        }
        batch = SampleBatch(base_dict)
        keys_ = list(base_dict.keys())
        values_ = list(base_dict.values())
        items_ = list(base_dict.items())
        assert list(batch.keys()) == keys_
        assert list(batch.values()) == values_
        assert list(batch.items()) == items_

        # Add an item and check whether it shows up in the "added" set.
        batch["d"] = np.array(1)
        assert batch.added_keys == {"d"}, batch.added_keys
        # Access two keys and check whether they show up in the
        # "accessed" set.
        print(batch["a"], batch["b"])
        assert batch.accessed_keys == {"a", "b"}, batch.accessed_keys
        # Delete a key and check whether it shows up in the "deleted" set.
        del batch["c"]
        assert batch.deleted_keys == {"c"}, batch.deleted_keys
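        # Note: ``added_keys``, ``accessed_keys`` and ``deleted_keys``
        # compare equal to set literals above, so they behave like plain
        # Python sets and support the usual set operations.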

    def test_right_zero_padding(self):
        """Tests whether right-zero-padding works properly."""
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3]),
                "b": {"c": np.array([4, 5, 6])},
                SampleBatch.SEQ_LENS: [1, 2],
            }
        )
        s1.right_zero_pad(max_seq_len=5)
        check(
            s1,
            {
                "a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
                "b": {"c": [4, 0, 0, 0, 0, 5, 6, 0, 0, 0]},
                SampleBatch.SEQ_LENS: [1, 2],
            },
        )
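        # Reading the expected layout: the two sequences [1] and [2, 3]
        # are each right-padded with zeros to max_seq_len=5, yielding
        # [1, 0, 0, 0, 0] + [2, 3, 0, 0, 0], while SEQ_LENS keeps the
        # true (unpadded) lengths.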

    def test_concat(self):
        """Tests SampleBatch.concat() and SampleBatch.concat_samples()."""
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3]),
                "b": {"c": np.array([4, 5, 6])},
            }
        )
        s2 = SampleBatch(
            {
                "a": np.array([2, 3, 4]),
                "b": {"c": np.array([5, 6, 7])},
            }
        )
        concatd = SampleBatch.concat_samples([s1, s2])
        check(concatd["a"], [1, 2, 3, 2, 3, 4])
        check(concatd["b"]["c"], [4, 5, 6, 5, 6, 7])
        check(next(concatd.rows()), {"a": 1, "b": {"c": 4}})

        concatd_2 = s1.concat(s2)
        check(concatd, concatd_2)
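        # As verified above, the classmethod form (concat_samples, any
        # number of batches) and the instance form (concat, self plus one
        # other batch) produce equivalent results for two batches.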

    def test_concat_max_seq_len(self):
        """Tests SampleBatch.concat_samples() w.r.t. max_seq_len."""
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3]),
                "b": {"c": np.array([4, 5, 6])},
                SampleBatch.SEQ_LENS: [1, 2],
            }
        )
        s2 = SampleBatch(
            {
                "a": np.array([2, 3, 4]),
                "b": {"c": np.array([5, 6, 7])},
                SampleBatch.SEQ_LENS: [3],
            }
        )

        s3 = SampleBatch(
            {
                "a": np.array([2, 3, 4]),
                "b": {"c": np.array([5, 6, 7])},
            }
        )

        concatd = SampleBatch.concat_samples([s1, s2])
        check(concatd.max_seq_len, s2.max_seq_len)

        with self.assertRaises(ValueError):
            SampleBatch.concat_samples([s1, s2, s3])
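        # The ValueError is expected: s3 carries no SEQ_LENS entry, and
        # mixing seq-len batches (s1, s2) with non-seq-len batches in one
        # concat_samples() call is rejected. Note also that the check
        # above confirms the concatenated batch ends up with s2's (here
        # the larger) max_seq_len.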

    def test_rows(self):
        s1 = SampleBatch(
            {
                "a": np.array([[1, 1], [2, 2], [3, 3]]),
                "b": {"c": np.array([[4, 4], [5, 5], [6, 6]])},
                SampleBatch.SEQ_LENS: np.array([1, 2]),
            }
        )
        check(
            next(s1.rows()),
            {"a": [1, 1], "b": {"c": [4, 4]}, SampleBatch.SEQ_LENS: 1},
        )
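        # ``rows()`` yields one dict per row and preserves the nested
        # structure of the batch, so the first row keeps "c" nested
        # under "b".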

    def test_compression(self):
        """Tests whether compression and decompression work properly."""
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3, 2, 3, 4]),
                "b": {"c": np.array([4, 5, 6, 5, 6, 7])},
            }
        )
        # Test whether compressing happens in-place.
        s1.compress(columns={"a", "b"}, bulk=True)
        self.assertTrue(is_compressed(s1["a"]))
        self.assertTrue(is_compressed(s1["b"]["c"]))
        self.assertTrue(isinstance(s1["b"], dict))

        # Test whether de-compressing happens in-place.
        s1.decompress_if_needed(columns={"a", "b"})
        check(s1["a"], [1, 2, 3, 2, 3, 4])
        check(s1["b"]["c"], [4, 5, 6, 5, 6, 7])
        it = s1.rows()
        next(it)
        check(next(it), {"a": 2, "b": {"c": 5}})
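        # Takeaway: ``compress``/``decompress_if_needed`` mutate the batch
        # in place and recurse into nested columns ("b" stays a dict; only
        # its leaf array is compressed), and decompression restores the
        # original values exactly.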

    def test_slicing(self):
        """Tests whether slicing works on SampleBatches."""
        s1 = SampleBatch(
            {
                "a": np.array([1, 2, 3, 2, 3, 4]),
                "b": {"c": np.array([4, 5, 6, 5, 6, 7])},
            }
        )
        check(
            s1[:3],
            {
                "a": [1, 2, 3],
                "b": {"c": [4, 5, 6]},
            },
        )
        check(
            s1[0:3],
            {
                "a": [1, 2, 3],
                "b": {"c": [4, 5, 6]},
            },
        )
        check(
            s1[1:4],
            {
                "a": [2, 3, 2],
                "b": {"c": [5, 6, 5]},
            },
        )
        check(
            s1[1:],
            {
                "a": [2, 3, 2, 3, 4],
                "b": {"c": [5, 6, 5, 6, 7]},
            },
        )
        check(
            s1[3:4],
            {
                "a": [2],
                "b": {"c": [5]},
            },
        )

        # When we change the slice, the original SampleBatch should also
        # change (shared underlying data).
        s1[:3]["a"][0] = 100
        s1[1:2]["a"][0] = 200
        check(s1["a"][0], 100)
        check(s1["a"][1], 200)

        # Seq-len batches should be auto-sliced along sequence boundaries,
        # no matter what.
        s2 = SampleBatch(
            {
                "a": np.array([1, 2, 3, 2, 3, 4]),
                "b": {"c": np.array([4, 5, 6, 5, 6, 7])},
                SampleBatch.SEQ_LENS: [2, 3, 1],
                "state_in_0": [1.0, 3.0, 4.0],
            }
        )
        # We would expect a=[1, 2, 3] now, but due to the sequence
        # boundary, we stop earlier.
        check(
            s2[:3],
            {
                "a": [1, 2],
                "b": {"c": [4, 5]},
                SampleBatch.SEQ_LENS: [2],
                "state_in_0": [1.0],
            },
        )
        # Split exactly at a seq-len boundary.
        check(
            s2[:5],
            {
                "a": [1, 2, 3, 2, 3],
                "b": {"c": [4, 5, 6, 5, 6]},
                SampleBatch.SEQ_LENS: [2, 3],
                "state_in_0": [1.0, 3.0],
            },
        )
        # Split above the last seq-len boundary (slice is clipped to the
        # batch's actual size).
        check(
            s2[:50],
            {
                "a": [1, 2, 3, 2, 3, 4],
                "b": {"c": [4, 5, 6, 5, 6, 7]},
                SampleBatch.SEQ_LENS: [2, 3, 1],
                "state_in_0": [1.0, 3.0, 4.0],
            },
        )
        check(
            s2[:],
            {
                "a": [1, 2, 3, 2, 3, 4],
                "b": {"c": [4, 5, 6, 5, 6, 7]},
                SampleBatch.SEQ_LENS: [2, 3, 1],
                "state_in_0": [1.0, 3.0, 4.0],
            },
        )
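        # Why s2[:3] held only two rows: index 3 would cut the second
        # sequence (length 3) in half, so the slice stops at the last
        # complete sequence. "state_in_0" carries one entry per sequence,
        # hence one value per kept sequence.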

    def test_split_by_episode(self):
        s = SampleBatch(
            {
                "a": np.array([0, 1, 2, 3, 4, 5]),
                "eps_id": np.array([0, 0, 0, 0, 1, 1]),
                "dones": np.array([0, 0, 0, 1, 0, 1]),
            }
        )
        true_split = [np.array([0, 1, 2, 3]), np.array([4, 5])]

        # Check that splitting by EPS_ID works correctly.
        eps_split = [b["a"] for b in s.split_by_episode()]
        check(true_split, eps_split)

        # Check that splitting by DONES works correctly.
        del s["eps_id"]
        dones_split = [b["a"] for b in s.split_by_episode()]
        check(true_split, dones_split)

        # Check that splitting without the EPS_ID or DONES key raises an
        # error.
        del s["dones"]
        with self.assertRaises(KeyError):
            s.split_by_episode()

        # Check that splitting with DONES being all-False returns the
        # whole batch.
        s["dones"] = np.array([0, 0, 0, 0, 0, 0])
        batch_split = [b["a"] for b in s.split_by_episode()]
        check(s["a"], batch_split[0])

    def test_copy(self):
        s = SampleBatch(
            {
                "a": np.array([1, 2, 3, 2, 3, 4]),
                "b": {"c": np.array([4, 5, 6, 5, 6, 7])},
                SampleBatch.SEQ_LENS: [2, 3, 1],
                "state_in_0": [1.0, 3.0, 4.0],
            }
        )
        s_copy = s.copy(shallow=False)
        s_copy["a"][0] = 100
        s_copy["b"]["c"][0] = 200
        s_copy[SampleBatch.SEQ_LENS][0] = 3
        s_copy[SampleBatch.SEQ_LENS][1] = 2
        s_copy["state_in_0"][0] = 400.0
        self.assertNotEqual(s["a"][0], s_copy["a"][0])
        self.assertNotEqual(s["b"]["c"][0], s_copy["b"]["c"][0])
        self.assertNotEqual(s[SampleBatch.SEQ_LENS][0], s_copy[SampleBatch.SEQ_LENS][0])
        self.assertNotEqual(s[SampleBatch.SEQ_LENS][1], s_copy[SampleBatch.SEQ_LENS][1])
        self.assertNotEqual(s["state_in_0"][0], s_copy["state_in_0"][0])

        s_copy = s.copy(shallow=True)
        s_copy["a"][0] = 100
        s_copy["b"]["c"][0] = 200
        s_copy[SampleBatch.SEQ_LENS][0] = 3
        s_copy[SampleBatch.SEQ_LENS][1] = 2
        s_copy["state_in_0"][0] = 400.0
        self.assertEqual(s["a"][0], s_copy["a"][0])
        self.assertEqual(s["b"]["c"][0], s_copy["b"]["c"][0])
        self.assertEqual(s[SampleBatch.SEQ_LENS][0], s_copy[SampleBatch.SEQ_LENS][0])
        self.assertEqual(s[SampleBatch.SEQ_LENS][1], s_copy[SampleBatch.SEQ_LENS][1])
        self.assertEqual(s["state_in_0"][0], s_copy["state_in_0"][0])

    def test_shuffle_with_interceptor(self):
        """Tests whether `shuffle()` clears the `intercepted_values` cache."""
        s = SampleBatch(
            {
                "a": np.array([1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7]),
            }
        )
        # Set a dummy get-interceptor (returning all values, but plus 1).
        s.set_get_interceptor(lambda v: v + 1)
        # Make sure the interceptor works.
        check(s["a"], [2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8])
        s.shuffle()
        # Make sure the intercepted values are NOT the original ones
        # (before the shuffle), but have also been shuffled.
        check(s["a"], [2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8], false=True)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))