[RLlib] Fix time dimension shaping for PyTorch RNN models. (#21735)
This commit is contained in:
parent de0c6f6132
commit 377a522ce2

2 changed files with 61 additions and 5 deletions
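Background: for the PyTorch path, add_time_dimension previously reshaped the
flat (B*T, F) batch directly into a time-major (T, B, F) shape. A row-major
reshape to that shape interleaves the sequences rather than transposing them;
the fix is to view the batch as (B, T, F) first and then swap the batch and
time axes. A minimal standalone sketch of the difference (the tensors and
shapes here are illustrative, not part of this commit):

    import torch

    # Two padded sequences of length 3, one feature each:
    # rows 0-2 belong to sequence 0, rows 3-5 to sequence 1.
    padded = torch.arange(6).reshape(6, 1)  # (B*T, F) = (6, 1)
    B, T = 2, 3

    # Old behavior: reshaping straight to time-major scrambles sequences.
    wrong = torch.reshape(padded, (T, B, 1))
    print(wrong[:, 0, 0])  # tensor([0, 2, 4]) -- not a contiguous sequence

    # Fixed behavior: view batch-major, then transpose batch and time.
    right = padded.view(B, T, 1).transpose(0, 1)
    print(right[:, 0, 0])  # tensor([0, 1, 2]) -- sequence 0 stays intact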
rllib/policy/rnn_sequencing.py
@@ -206,11 +206,13 @@ def add_time_dimension(
 
         # Dynamically reshape the padded batch to introduce a time dimension.
         new_batch_size = padded_batch_size // max_seq_len
+        batch_major_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
+        padded_outputs = padded_inputs.view(batch_major_shape)
+
         if time_major:
-            new_shape = (max_seq_len, new_batch_size) + padded_inputs.shape[1:]
-        else:
-            new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
-        return torch.reshape(padded_inputs, new_shape)
+            # Swap the batch and time dimensions
+            padded_outputs = padded_outputs.transpose(0, 1)
+        return padded_outputs
 
 
 @DeveloperAPI
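For context, add_time_dimension is typically called from RLlib's recurrent
model forward passes on the flattened padded batch before it is fed to an
RNN. A quick sketch of the fixed behavior, using the same keyword API the
test below exercises (the input values are illustrative):

    import torch
    from ray.rllib.policy.rnn_sequencing import add_time_dimension

    flat_inputs = torch.randn(8 * 4, 16)  # (B*T, F) padded batch

    batch_major = add_time_dimension(
        flat_inputs, max_seq_len=4, framework="torch", time_major=False
    )  # (B, T, F) = (8, 4, 16)
    time_major = add_time_dimension(
        flat_inputs, max_seq_len=4, framework="torch", time_major=True
    )  # (T, B, F) = (4, 8, 16)

    # After this fix, the two layouts are transposes of each other,
    # rather than differently-scrambled reshapes of the flat batch.
    assert torch.equal(batch_major.transpose(0, 1), time_major)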
rllib/policy/tests/test_rnn_sequencing.py
@@ -2,12 +2,20 @@ import numpy as np
 import unittest
 
 import ray
-from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
+from ray.rllib.policy.rnn_sequencing import (
+    pad_batch_to_sequences_of_same_size,
+    add_time_dimension,
+)
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import check
 
+tf1, tf, tfv = try_import_tf()
+torch, nn = try_import_torch()
+
 
 class TestRNNSequencing(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
@@ -89,6 +97,52 @@ class TestRNNSequencing(unittest.TestCase):
         check(s1["a"].shape[0], max_seq_len * num_seqs)
         check(s1["b"].shape[0], max_seq_len * num_seqs)
 
+    def test_add_time_dimension(self):
+        """Test add_time_dimension gives sequential data along the time dimension"""
+
+        B, T, F = np.random.choice(
+            np.asarray(list(range(8, 32)), dtype=np.int32),  # use int32 for seq_lens
+            size=3,
+            replace=False,
+        )
+
+        inputs_numpy = np.repeat(
+            np.arange(B * T)[:, np.newaxis], repeats=F, axis=-1
+        ).astype(np.int32)
+        check(inputs_numpy.shape, (B * T, F))
+
+        time_shift_diff_batch_major = np.ones(shape=(B, T - 1, F), dtype=np.int32)
+        time_shift_diff_time_major = np.ones(shape=(T - 1, B, F), dtype=np.int32)
+
+        if tf is not None:
+            # Test tensorflow batch-major
+            padded_inputs = tf.constant(inputs_numpy)
+            batch_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="tf", time_major=False
+            )
+            check(batch_major_outputs.shape.as_list(), [B, T, F])
+            time_shift_diff = batch_major_outputs[:, 1:] - batch_major_outputs[:, :-1]
+            check(time_shift_diff, time_shift_diff_batch_major)
+
+        if torch is not None:
+            # Test torch batch-major
+            padded_inputs = torch.from_numpy(inputs_numpy)
+            batch_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="torch", time_major=False
+            )
+            check(batch_major_outputs.shape, (B, T, F))
+            time_shift_diff = batch_major_outputs[:, 1:] - batch_major_outputs[:, :-1]
+            check(time_shift_diff, time_shift_diff_batch_major)
+
+            # Test torch time-major
+            padded_inputs = torch.from_numpy(inputs_numpy)
+            time_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="torch", time_major=True
+            )
+            check(time_major_outputs.shape, (T, B, F))
+            time_shift_diff = time_major_outputs[1:] - time_major_outputs[:-1]
+            check(time_shift_diff, time_shift_diff_time_major)
+
 
 if __name__ == "__main__":
     import pytest
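Note on the test's invariant: inputs_numpy is np.arange(B * T) broadcast
across the feature axis, so within a correctly ordered sequence consecutive
timesteps differ by exactly 1, and the test asserts that the first difference
along the time axis is all-ones in both layouts. Under the old buggy
time-major reshape, element [t, b] came from flat row t * B + b, so that
difference would have been B instead of 1. A tiny numpy check of this
invariant (standalone, not part of the commit):

    import numpy as np

    B, T, F = 2, 3, 1
    inputs = np.repeat(np.arange(B * T)[:, None], F, axis=-1)  # (B*T, F)

    # Correct time-major layout: unit steps along the time axis.
    right = inputs.reshape(B, T, F).transpose(1, 0, 2)
    assert (right[1:] - right[:-1] == 1).all()

    # Old buggy layout: steps of size B along the time axis.
    wrong = inputs.reshape(T, B, F)
    assert (wrong[1:] - wrong[:-1] == B).all()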