mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
97 lines
3 KiB
Python
97 lines
3 KiB
Python
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
|
from ray.rllib.utils.test_utils import check
|
|
|
|
|
|
class TestRNNSequencing(unittest.TestCase):
    """Tests for ``pad_batch_to_sequences_of_same_size``.

    Batches are generated randomly, but from a fixed set of seeds: any
    failure can be reproduced exactly, while several seeds per test still
    cover a variety of batch sizes and sequence lengths.
    """

    # Padding target passed to the function under test. It is also the
    # (exclusive) upper bound for the drawn sequence lengths, so generated
    # sequences are always strictly shorter than it.
    _MAX_SEQ_LEN = 20
    # Seeds used to generate one random batch each, per test.
    _SEEDS = (0, 1, 2, 3, 4)

    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    @staticmethod
    def _make_state_view_requirements(shift):
        """Return view requirements holding a single ``state_in_0`` entry.

        Args:
            shift: The ``shift`` given to the ``state_in_0`` ViewRequirement,
                e.g. ``[-1]`` (single step) or ``"-3:-1"`` (a range).

        Returns:
            Dict mapping "state_in_0" to a ViewRequirement on "state_out_0".
        """
        return {
            "state_in_0": ViewRequirement(
                "state_out_0",
                shift=shift,
                used_for_training=False,
                used_for_compute_actions=True,
                batch_repeat_value=1,
            )
        }

    @staticmethod
    def _make_batch(rng, max_seq_len):
        """Build a SampleBatch of randomly sized, not-yet-padded sequences.

        Args:
            rng: ``numpy.random.Generator`` used to draw the sizes.
            max_seq_len: Stored as the batch's ``_max_seq_len``. Drawn
                sequence lengths lie in ``[1, max_seq_len)``.

        Returns:
            Tuple ``(batch, num_seqs, seq_lens)``.
            - ``batch`` has feature columns "a" and "b", each of length
              ``sum(seq_lens)``, plus one ``state_in_0`` row per sequence.
            - ``num_seqs`` is the number of sequences.
            - ``seq_lens`` is the array of per-sequence lengths.
        """
        # Cast to a plain int so the list repetition below is unambiguous.
        num_seqs = int(rng.integers(1, 20))
        seq_lens = rng.integers(1, max_seq_len, size=num_seqs)
        sum_seq_lens = np.sum(seq_lens)
        batch = SampleBatch(
            {
                "a": np.arange(sum_seq_lens),
                "b": np.arange(sum_seq_lens),
                "seq_lens": seq_lens,
                "state_in_0": [[0]] * num_seqs,
            },
            _max_seq_len=max_seq_len,
        )
        return batch, num_seqs, seq_lens

    def test_pad_batch_dynamic_max(self):
        """Test pad_batch_to_sequences_of_same_size when dynamic_max = True.

        With a single-step state shift (``shift=[-1]``), the batch is
        expected to be padded only up to its longest sequence. Its
        ``max_seq_len`` then becomes ``max(seq_lens)``, not the fixed
        ``max_seq_len`` argument.
        """
        max_seq_len = self._MAX_SEQ_LEN
        for seed in self._SEEDS:
            with self.subTest(seed=seed):
                rng = np.random.default_rng(seed)
                s1, num_seqs, seq_lens = self._make_batch(rng, max_seq_len)
                max_len = np.max(seq_lens)

                pad_batch_to_sequences_of_same_size(
                    s1,
                    max_seq_len=max_seq_len,
                    feature_keys=["a", "b"],
                    view_requirements=self._make_state_view_requirements(
                        [-1]
                    ),
                )

                check(s1.max_seq_len, max_len)
                check(s1["a"].shape[0], max_len * num_seqs)
                check(s1["b"].shape[0], max_len * num_seqs)

    def test_pad_batch_fixed_max(self):
        """Test pad_batch_to_sequences_of_same_size when dynamic_max = False.

        With a ranged state shift (``shift="-3:-1"``), every sequence is
        expected to be padded to the fixed ``max_seq_len`` argument, even
        though all generated sequences are shorter than it.
        """
        max_seq_len = self._MAX_SEQ_LEN
        for seed in self._SEEDS:
            with self.subTest(seed=seed):
                rng = np.random.default_rng(seed)
                s1, num_seqs, _ = self._make_batch(rng, max_seq_len)

                pad_batch_to_sequences_of_same_size(
                    s1,
                    max_seq_len=max_seq_len,
                    feature_keys=["a", "b"],
                    view_requirements=self._make_state_view_requirements(
                        "-3:-1"
                    ),
                )

                check(s1.max_seq_len, max_seq_len)
                check(s1["a"].shape[0], max_seq_len * num_seqs)
                check(s1["b"].shape[0], max_seq_len * num_seqs)
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and hand pytest's status to the shell.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|