[RLlib] Issue 23897: add_time_dimension() causes returned shape to be completely unknown. (#24006)

This commit is contained in:
Sven Mika 2022-04-19 17:56:56 +02:00 committed by GitHub
parent de9e143938
commit 9de391b70e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -197,7 +197,9 @@ def add_time_dimension(
axis=0,
)
)
return tf.reshape(padded_inputs, new_shape)
ret = tf.reshape(padded_inputs, new_shape)
ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
return ret
else:
assert framework == "torch", "`framework` must be either tf or torch!"
padded_batch_size = padded_inputs.shape[0]