[RLlib] Issue 24530: Fix add_time_dimension (#24531)

Co-authored-by: Daewoo Lee <dwlee@rtst.co.kr>
This commit is contained in:
Daewoo Lee 2022-05-06 22:21:42 +09:00 committed by GitHub
parent f48f1b252c
commit fee35444ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -187,19 +187,15 @@ def add_time_dimension(
padded_batch_size = tf.shape(padded_inputs)[0]
# Dynamically reshape the padded batch to introduce a time dimension.
new_batch_size = padded_batch_size // max_seq_len
new_shape = tf.squeeze(
tf.stack(
[
tf.expand_dims(new_batch_size, axis=0),
tf.expand_dims(max_seq_len, axis=0),
tf.shape(padded_inputs)[1:],
],
axis=0,
)
new_shape = tf.concat(
[
tf.expand_dims(new_batch_size, axis=0),
tf.expand_dims(max_seq_len, axis=0),
tf.shape(padded_inputs)[1:],
],
axis=0,
)
ret = tf.reshape(padded_inputs, new_shape)
ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
return ret
return tf.reshape(padded_inputs, new_shape)
else:
assert framework == "torch", "`framework` must be either tf or torch!"
padded_batch_size = padded_inputs.shape[0]