mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLLib] Add missing .to() for MARWIL on PyTorch (#10685)
There was a missing .to() that caused a device mismatch error on PyTorch with MARWIL.
This commit is contained in:
parent
f7d5aa46a3
commit
e72838c03d
1 changed files with 1 additions and 1 deletions
|
@ -70,7 +70,7 @@ def stats(policy, train_batch):
|
|||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
# Create a var.
|
||||
policy.ma_adv_norm = torch.tensor(
|
||||
[100.0], dtype=torch.float32, requires_grad=False)
|
||||
[100.0], dtype=torch.float32, requires_grad=False).to(policy.device)
|
||||
# Setup Value branch of our NN.
|
||||
ValueNetworkMixin.__init__(policy)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue