Added custom LSTM detection (#4087)

* Added autodetection of custom LSTM usage

* Reverted line separators

* Added check for LSTM

* Update vtrace_policy_graph.py

* Update appo_policy_graph.py
This commit is contained in:
Stefan Pantic 2019-02-22 06:07:48 +01:00 committed by Eric Liang
parent 692bb336a1
commit a54386e499
2 changed files with 2 additions and 2 deletions

View file

@ -148,7 +148,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
tf.get_variable_scope().name)
def to_batches(tensor):
if self.config["model"]["use_lstm"]:
if self.model.state_init:
B = tf.shape(self.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B
else:

View file

@ -230,7 +230,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
tf.get_variable_scope().name)
def to_batches(tensor):
if self.config["model"]["use_lstm"]:
if self.model.state_init:
B = tf.shape(self.model.seq_lens)[0]
T = tf.shape(tensor)[0] // B
else: