From 9de391b70e4531e60a560ccbf77681d520030856 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Tue, 19 Apr 2022 17:56:56 +0200 Subject: [PATCH] [RLlib] Issue 23897: `add_time_dimension()` causes returned shape to be completely unknown. (#24006) --- rllib/policy/rnn_sequencing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 575955caf..9bd83d669 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -197,7 +197,9 @@ def add_time_dimension( axis=0, ) ) - return tf.reshape(padded_inputs, new_shape) + ret = tf.reshape(padded_inputs, new_shape) + ret.set_shape([None, None] + padded_inputs.shape[1:].as_list()) + return ret else: assert framework == "torch", "`framework` must be either tf or torch!" padded_batch_size = padded_inputs.shape[0]