mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Issue 23897: add_time_dimension()
causes returned shape to be completely unknown. (#24006)
This commit is contained in:
parent
de9e143938
commit
9de391b70e
1 changed files with 3 additions and 1 deletions
|
@ -197,7 +197,9 @@ def add_time_dimension(
|
|||
axis=0,
|
||||
)
|
||||
)
|
||||
return tf.reshape(padded_inputs, new_shape)
|
||||
ret = tf.reshape(padded_inputs, new_shape)
|
||||
ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
|
||||
return ret
|
||||
else:
|
||||
assert framework == "torch", "`framework` must be either tf or torch!"
|
||||
padded_batch_size = padded_inputs.shape[0]
|
||||
|
|
Loading…
Add table
Reference in a new issue