[RLlib] Tf-eager policy bug fix: Duplicate model call in compute_gradients. (#12682)

This commit is contained in:
Sven Mika 2020-12-09 08:03:58 +01:00 committed by GitHub
parent cab46b7931
commit 28108c905b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -594,14 +594,6 @@ def build_eager_tf_policy(name,
self._is_training = True
with tf.GradientTape(persistent=gradients_fn is not None) as tape:
# TODO: set seq len and state-in properly
state_in = []
for i in range(self.num_state_tensors()):
state_in.append(samples["state_in_{}".format(i)])
self._state_in = state_in
model_out, _ = self.model(samples, self._state_in,
samples.get("seq_lens"))
loss = loss_fn(self, self.model, self.dist_class, samples)
variables = self.model.trainable_variables()