[RLlib] Issue 24530: Fix add_time_dimension (#24531)
Co-authored-by: Daewoo Lee <dwlee@rtst.co.kr>
commit fee35444ab
parent f48f1b252c
1 changed file with 8 additions and 12 deletions
@@ -187,19 +187,15 @@ def add_time_dimension(
         padded_batch_size = tf.shape(padded_inputs)[0]
         # Dynamically reshape the padded batch to introduce a time dimension.
         new_batch_size = padded_batch_size // max_seq_len
-        new_shape = tf.squeeze(
-            tf.stack(
-                [
-                    tf.expand_dims(new_batch_size, axis=0),
-                    tf.expand_dims(max_seq_len, axis=0),
-                    tf.shape(padded_inputs)[1:],
-                ],
-                axis=0,
-            )
-        )
-        ret = tf.reshape(padded_inputs, new_shape)
-        ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
-        return ret
+        new_shape = tf.concat(
+            [
+                tf.expand_dims(new_batch_size, axis=0),
+                tf.expand_dims(max_seq_len, axis=0),
+                tf.shape(padded_inputs)[1:],
+            ],
+            axis=0,
+        )
+        return tf.reshape(padded_inputs, new_shape)
     else:
         assert framework == "torch", "`framework` must be either tf or torch!"
         padded_batch_size = padded_inputs.shape[0]
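Note (not part of the commit): the likely reason the old shape construction fails is that tf.stack requires all stacked tensors to have the same shape, so combining the two length-1 batch/time entries with tf.shape(padded_inputs)[1:] breaks as soon as the padded inputs carry more than one trailing feature dimension. tf.concat along axis 0 only requires 1-D inputs, so it works for inputs of any rank. Below is a minimal sketch of the concat-based reshape; the function name and test values are illustrative, not taken from RLlib.

import tensorflow as tf

def add_time_dimension_sketch(padded_inputs, max_seq_len):
    # Number of padded rows in the flat batch.
    padded_batch_size = tf.shape(padded_inputs)[0]
    # Each sequence occupies max_seq_len consecutive rows.
    new_batch_size = padded_batch_size // max_seq_len
    # Build [new_batch_size, max_seq_len, *feature_dims] dynamically.
    # All three pieces are 1-D, so tf.concat accepts them regardless of
    # how many feature dimensions padded_inputs has; tf.stack would also
    # require them to have equal lengths and would fail here.
    new_shape = tf.concat(
        [
            tf.expand_dims(new_batch_size, axis=0),
            tf.expand_dims(max_seq_len, axis=0),
            tf.shape(padded_inputs)[1:],
        ],
        axis=0,
    )
    return tf.reshape(padded_inputs, new_shape)

# A rank-3 input (6 padded rows, two feature dimensions) reshapes cleanly:
x = tf.zeros([6, 4, 5])
y = add_time_dimension_sketch(x, max_seq_len=3)
print(y.shape)  # (2, 3, 4, 5)

With the removed tf.squeeze(tf.stack(...)) construction, the same rank-3 input would raise a shape-mismatch error inside tf.stack.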