[RLLib] Add missing .to() for MARWIL on PyTorch (#10685)

There was a missing .to() that caused a device mismatch error on PyTorch with MARWIL.
This commit is contained in:
Julius Frost 2020-09-09 21:52:55 -04:00 committed by GitHub
parent f7d5aa46a3
commit e72838c03d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -70,7 +70,7 @@ def stats(policy, train_batch):
def setup_mixins(policy, obs_space, action_space, config):
# Create a var.
policy.ma_adv_norm = torch.tensor(
[100.0], dtype=torch.float32, requires_grad=False)
[100.0], dtype=torch.float32, requires_grad=False).to(policy.device)
# Setup Value branch of our NN.
ValueNetworkMixin.__init__(policy)