diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py
index fa2452d92..f10e71c2d 100644
--- a/rllib/agents/marwil/marwil_torch_policy.py
+++ b/rllib/agents/marwil/marwil_torch_policy.py
@@ -70,7 +70,7 @@ def stats(policy, train_batch):
 def setup_mixins(policy, obs_space, action_space, config):
     # Create a var.
     policy.ma_adv_norm = torch.tensor(
-        [100.0], dtype=torch.float32, requires_grad=False)
+        [100.0], dtype=torch.float32, requires_grad=False).to(policy.device)
     # Setup Value branch of our NN.
     ValueNetworkMixin.__init__(policy)
 