[rllib] Fix incorrect sequence length for rnn (#23830)

Update the torch policy to derive seq_lens from state_batches instead of input_dict. This handles complex model inputs correctly when the built-in preprocessing API is disabled.
Kinal Mehta 2022-04-13 01:37:18 +05:30 committed by GitHub
parent 4cb6205726
commit 758e758c32


@@ -317,9 +317,9 @@ class TorchPolicy(Policy):
             # Calculate RNN sequence lengths.
             seq_lens = (
                 torch.tensor(
-                    [1] * len(input_dict["obs"]),
+                    [1] * len(state_batches[0]),
                     dtype=torch.long,
-                    device=input_dict["obs"].device,
+                    device=state_batches[0].device,
                 )
                 if state_batches
                 else None
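
To see why the old lookup could mis-count the batch, here is a minimal sketch (the observation shapes and key names are hypothetical, not RLlib's actual call site). With the built-in preprocessor disabled, `input_dict["obs"]` can itself be a dict of tensors, so `len()` returns the number of keys rather than the batch size; the first RNN state tensor in `state_batches` always carries the true batch dimension:

```python
import torch

# Hypothetical complex observation: a dict of tensors, batch size 4.
# When preprocessing is disabled, input_dict["obs"] may look like this.
obs = {
    "camera": torch.zeros(4, 3, 84, 84),
    "vector": torch.zeros(4, 8),
}
# One RNN state tensor per layer; first dim is the batch size.
state_batches = [torch.zeros(4, 256)]

wrong_batch = len(obs)               # number of dict keys (2), not batch size
right_batch = len(state_batches[0])  # true batch size (4)

# Sequence lengths computed as in the patched code path.
seq_lens = torch.tensor(
    [1] * len(state_batches[0]),
    dtype=torch.long,
    device=state_batches[0].device,
)
```

The same `len()` call is safe on a plain observation tensor, which is why the bug only surfaced once dict/tuple observations reached the policy unflattened.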