mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Tf-eager policy bug fix: Duplicate model call in compute_gradients. (#12682)
This commit is contained in:
parent
cab46b7931
commit
28108c905b
1 changed files with 0 additions and 8 deletions
|
@ -594,14 +594,6 @@ def build_eager_tf_policy(name,
|
|||
self._is_training = True
|
||||
|
||||
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
|
||||
# TODO: set seq len and state-in properly
|
||||
state_in = []
|
||||
for i in range(self.num_state_tensors()):
|
||||
state_in.append(samples["state_in_{}".format(i)])
|
||||
self._state_in = state_in
|
||||
|
||||
model_out, _ = self.model(samples, self._state_in,
|
||||
samples.get("seq_lens"))
|
||||
loss = loss_fn(self, self.model, self.dist_class, samples)
|
||||
|
||||
variables = self.model.trainable_variables()
|
||||
|
|
Loading…
Add table
Reference in a new issue