diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index af9f0397f..d8ce1be6e 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -148,7 +148,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): tf.get_variable_scope().name) def to_batches(tensor): - if self.config["model"]["use_lstm"]: + if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B else: diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index ace8f39be..e0716f274 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -230,7 +230,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): tf.get_variable_scope().name) def to_batches(tensor): - if self.config["model"]["use_lstm"]: + if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B else: