[RLlib] Update alpha_zero_policy.py (#15042)

This commit is contained in:
Yeachan-Heo 2021-05-04 20:20:24 +09:00 committed by GitHub
parent 40fdedd3de
commit 0552f6e886
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -138,9 +138,9 @@ class AlphaZeroPolicy(TorchPolicy):
grad_info = self.extra_grad_info(train_batch)
grad_info.update(grad_process_info)
grad_info.update({
-"total_loss": loss_out.detach().numpy(),
-"policy_loss": policy_loss.detach().numpy(),
-"value_loss": value_loss.detach().numpy()
+"total_loss": loss_out.detach().cpu().numpy(),
+"policy_loss": policy_loss.detach().cpu().numpy(),
+"value_loss": value_loss.detach().cpu().numpy()
})
return {LEARNER_STATS_KEY: grad_info}