mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Fix incorrect sequence length for rnn (#23830)
Update the torch policy to find the seq_lens using state_batches instead of input_dict. This helps handle the complex inputs to the model when the inbuilt preprocessing API is disabled.
This commit is contained in:
parent
4cb6205726
commit
758e758c32
1 changed files with 2 additions and 2 deletions
|
@ -317,9 +317,9 @@ class TorchPolicy(Policy):
|
|||
# Calculate RNN sequence lengths.
|
||||
seq_lens = (
|
||||
torch.tensor(
|
||||
[1] * len(input_dict["obs"]),
|
||||
[1] * len(state_batches[0]),
|
||||
dtype=torch.long,
|
||||
device=input_dict["obs"].device,
|
||||
device=state_batches[0].device,
|
||||
)
|
||||
if state_batches
|
||||
else None
|
||||
|
|
Loading…
Add table
Reference in a new issue