mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Tf-eager policy bug fix: Duplicate model call in compute_gradients. (#12682)
This commit is contained in:
parent
cab46b7931
commit
28108c905b
1 changed file with 0 additions and 8 deletions
|
@@ -594,14 +594,6 @@ def build_eager_tf_policy(name,
|
||||||
self._is_training = True
|
self._is_training = True
|
||||||
|
|
||||||
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
|
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
|
||||||
# TODO: set seq len and state-in properly
|
|
||||||
state_in = []
|
|
||||||
for i in range(self.num_state_tensors()):
|
|
||||||
state_in.append(samples["state_in_{}".format(i)])
|
|
||||||
self._state_in = state_in
|
|
||||||
|
|
||||||
model_out, _ = self.model(samples, self._state_in,
|
|
||||||
samples.get("seq_lens"))
|
|
||||||
loss = loss_fn(self, self.model, self.dist_class, samples)
|
loss = loss_fn(self, self.model, self.dist_class, samples)
|
||||||
|
|
||||||
variables = self.model.trainable_variables()
|
variables = self.model.trainable_variables()
|
||||||
|
|
Loading…
Add table
Reference in a new issue