mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
195 lines
6.8 KiB
Python
195 lines
6.8 KiB
Python
import numpy as np
|
|
import unittest
|
|
|
|
import ray
|
|
from ray.rllib.policy.rnn_sequencing import (
|
|
pad_batch_to_sequences_of_same_size,
|
|
add_time_dimension,
|
|
chop_into_sequences,
|
|
)
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
from ray.rllib.utils.test_utils import check
|
|
|
|
|
|
# Resolve the optional deep-learning frameworks once at import time.
# `tf` and `torch` are compared against None in the tests below, so the
# framework-specific checks are skipped when a framework is unavailable.
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
|
|
|
|
|
|
class TestRNNSequencing(unittest.TestCase):
    """Tests for chop_into_sequences, pad_batch_to_sequences_of_same_size and
    add_time_dimension."""

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once for all tests in this class."""
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test in this class."""
        ray.shutdown()

    @staticmethod
    def _state_view_requirements(shift):
        """Return view requirements mapping "state_in_0" to "state_out_0".

        Only `shift` differs between the padding tests; every other
        ViewRequirement argument is fixed here.
        """
        return {
            "state_in_0": ViewRequirement(
                "state_out_0",
                shift=shift,
                used_for_training=False,
                used_for_compute_actions=True,
                batch_repeat_value=1,
            )
        }

    @staticmethod
    def _make_unpadded_batch(seq_lens, max_seq_len):
        """Build a SampleBatch with one row per timestep described by `seq_lens`.

        Columns "a" and "b" both hold 0..sum(seq_lens)-1, and "state_in_0"
        holds one `[0]` entry per sequence.
        """
        total_steps = np.sum(seq_lens)
        columns = {
            "a": np.arange(total_steps),
            "b": np.arange(total_steps),
            "seq_lens": seq_lens,
            "state_in_0": [[0]] * len(seq_lens),
        }
        return SampleBatch(columns, _max_seq_len=max_seq_len)

    def test_chop_into_sequences_long_seq(self):
        """Check that chop_into_sequences splits a sequence longer than
        max_seq_len into shorter ones, with and without dynamic_max."""
        max_seq_len = 2
        # The inputs describe three sequences of lengths [2, 3, 1]
        # (matching ep_ids, unroll_ids, etc.); the middle one exceeds
        # max_seq_len.
        chop_kwargs = dict(
            feature_columns=[[1, 1, 2, 2, 2, 3]],
            # Input states, ie states[3] is the input state at
            # t = 3 and the output state at t = 2
            state_columns=[[1, 2, 3, 4, 5, 6]],
            max_seq_len=max_seq_len,
            episode_ids=[0, 0, 1, 1, 1, 2],
            unroll_ids=[2, 2, 3, 3, 3, 4],
            agent_indices=[0, 0, 0, 0, 0, 0],
        )

        want_features = [[1, 1, 2, 2, 2, 0, 3, 0]]
        want_seq_lens = [2, 2, 1, 1]
        want_init_states = [[1, 3, 5, 6]]

        # Fixed max first, then dynamic max: both must give the same result.
        for dynamic_max in (False, True):
            padded, init_states, lens = chop_into_sequences(
                dynamic_max=dynamic_max, **chop_kwargs
            )
            check(padded, want_features)
            check(lens, want_seq_lens)
            check(init_states, want_init_states)

    def test_pad_batch_dynamic_max(self):
        """Test pad_batch_to_sequences_of_same_size when dynamic_max = True"""
        view_reqs = self._state_view_requirements(shift=[-1])
        max_seq_len = 20
        num_seqs = np.random.randint(1, 20)
        seq_lens = np.random.randint(1, max_seq_len, size=(num_seqs))
        # Longest sampled sequence; recorded before the batch is padded.
        longest = np.max(seq_lens)

        batch = self._make_unpadded_batch(seq_lens, max_seq_len)

        pad_batch_to_sequences_of_same_size(
            batch,
            max_seq_len=max_seq_len,
            feature_keys=["a", "b"],
            view_requirements=view_reqs,
        )
        # Every sequence is expected to be padded to the longest one.
        check(batch.max_seq_len, longest)
        for key in ("a", "b"):
            check(batch[key].shape[0], longest * num_seqs)

    def test_pad_batch_fixed_max(self):
        """Test pad_batch_to_sequences_of_same_size when dynamic_max = False"""
        view_reqs = self._state_view_requirements(shift="-3:-1")
        max_seq_len = 20
        num_seqs = np.random.randint(1, 20)
        seq_lens = np.random.randint(1, max_seq_len, size=(num_seqs))

        batch = self._make_unpadded_batch(seq_lens, max_seq_len)

        pad_batch_to_sequences_of_same_size(
            batch,
            max_seq_len=max_seq_len,
            feature_keys=["a", "b"],
            view_requirements=view_reqs,
        )
        # Every sequence is expected to be padded to the full max_seq_len.
        check(batch.max_seq_len, max_seq_len)
        for key in ("a", "b"):
            check(batch[key].shape[0], max_seq_len * num_seqs)

    def test_add_time_dimension(self):
        """Test add_time_dimension gives sequential data along the time dimension"""

        # Three distinct sizes in [8, 32); use int32 for seq_lens.
        B, T, F = np.random.choice(
            np.arange(8, 32, dtype=np.int32),
            size=3,
            replace=False,
        )

        # Row i holds the value i repeated F times, so consecutive rows
        # differ by exactly one.
        row_ids = np.arange(B * T)[:, np.newaxis]
        inputs_numpy = np.repeat(row_ids, repeats=F, axis=-1).astype(np.int32)
        check(inputs_numpy.shape, (B * T, F))

        # Expected difference between neighbouring time steps, per layout.
        unit_steps_batch_major = np.ones(shape=(B, T - 1, F), dtype=np.int32)
        unit_steps_time_major = np.ones(shape=(T - 1, B, F), dtype=np.int32)

        if tf is not None:
            # Test tensorflow batch-major
            tf_outputs = add_time_dimension(
                tf.constant(inputs_numpy),
                max_seq_len=T,
                framework="tf",
                time_major=False,
            )
            check(tf_outputs.shape.as_list(), [B, T, F])
            step_diff = tf_outputs[:, 1:] - tf_outputs[:, :-1]
            check(step_diff, unit_steps_batch_major)

        if torch is not None:
            # Test torch batch-major first, then torch time-major.
            torch_cases = (
                (False, (B, T, F), unit_steps_batch_major),
                (True, (T, B, F), unit_steps_time_major),
            )
            for time_major, want_shape, want_diff in torch_cases:
                torch_outputs = add_time_dimension(
                    torch.from_numpy(inputs_numpy),
                    max_seq_len=T,
                    framework="torch",
                    time_major=time_major,
                )
                check(torch_outputs.shape, want_shape)
                if time_major:
                    step_diff = torch_outputs[1:] - torch_outputs[:-1]
                else:
                    step_diff = torch_outputs[:, 1:] - torch_outputs[:, :-1]
                check(step_diff, want_diff)
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: only needed when this file is run as a script.
    import pytest
    import sys

    # Run this file's tests verbosely and use pytest's result as the
    # process exit status.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|