mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
Added custom LSTM detection (#4087)
* Added autodetection of custom LSTM usage * Reverted line separators * Added check for LSTM * Update vtrace_policy_graph.py * Update appo_policy_graph.py
This commit is contained in:
parent
692bb336a1
commit
a54386e499
2 changed files with 2 additions and 2 deletions
|
@ -148,7 +148,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
|||
tf.get_variable_scope().name)
|
||||
|
||||
def to_batches(tensor):
|
||||
if self.config["model"]["use_lstm"]:
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
|
|
|
@ -230,7 +230,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
|||
tf.get_variable_scope().name)
|
||||
|
||||
def to_batches(tensor):
|
||||
if self.config["model"]["use_lstm"]:
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue