mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[rllib] Add test case that we don't have a hard torch dependency (#7926)
This commit is contained in:
parent
85481d635d
commit
e8c19aba41
4 changed files with 42 additions and 8 deletions
|
@ -1005,6 +1005,13 @@ py_test(
|
|||
srcs = ["tests/test_dependency.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_dependency_torch",
|
||||
tags = ["tests_dir", "tests_dir_D"],
|
||||
size = "small",
|
||||
srcs = ["tests/test_dependency_torch.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_eager_support",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from gym.spaces import Tuple, Discrete, Dict
|
||||
import logging
|
||||
import numpy as np
|
||||
from torch.optim import RMSprop
|
||||
from torch.distributions import Categorical
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
|
@ -244,6 +242,7 @@ class QMixTorchPolicy(Policy):
|
|||
self.loss = QMixLoss(self.model, self.target_model, self.mixer,
|
||||
self.target_mixer, self.n_agents, self.n_actions,
|
||||
self.config["double_q"], self.config["gamma"])
|
||||
from torch.optim import RMSprop
|
||||
self.optimiser = RMSprop(
|
||||
params=self.params,
|
||||
lr=config["lr"],
|
||||
|
@ -283,6 +282,7 @@ class QMixTorchPolicy(Policy):
|
|||
random_numbers = torch.rand_like(q_values[:, :, 0])
|
||||
pick_random = (random_numbers < (self.cur_epsilon
|
||||
if explore else 0.0)).long()
|
||||
from torch.distributions import Categorical
|
||||
random_actions = Categorical(avail).sample().long()
|
||||
actions = (pick_random * random_actions +
|
||||
(1 - pick_random) * masked_q_values.argmax(dim=2))
|
||||
|
|
22
rllib/tests/test_dependency_torch.py
Executable file
22
rllib/tests/test_dependency_torch.py
Executable file
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Do not import torch for testing purposes.
|
||||
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
|
||||
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
assert "torch" not in sys.modules, \
|
||||
"Torch initially present, when it shouldn't."
|
||||
|
||||
# note: no ray.init(), to test it works without Ray
|
||||
trainer = A2CTrainer(
|
||||
env="CartPole-v0", config={
|
||||
"use_pytorch": False,
|
||||
"num_workers": 0
|
||||
})
|
||||
trainer.train()
|
||||
|
||||
assert "torch" not in sys.modules, "Torch should not be imported"
|
|
@ -101,7 +101,10 @@ def try_import_tfp(error=False):
|
|||
|
||||
# Fake module for torch.nn.
|
||||
class NNStub:
|
||||
pass
|
||||
def __init__(self, *a, **kw):
|
||||
# Fake nn.functional module within torch.nn.
|
||||
self.functional = None
|
||||
self.Module = ModuleStub
|
||||
|
||||
|
||||
# Fake class for torch.nn.Module to allow it to be inherited from.
|
||||
|
@ -120,7 +123,7 @@ def try_import_torch(error=False):
|
|||
"""
|
||||
if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
|
||||
logger.warning("Not importing Torch for test purposes.")
|
||||
return None, None
|
||||
return _torch_stubs()
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
@ -129,10 +132,12 @@ def try_import_torch(error=False):
|
|||
except ImportError as e:
|
||||
if error:
|
||||
raise e
|
||||
return _torch_stubs()
|
||||
|
||||
nn = NNStub()
|
||||
nn.Module = ModuleStub
|
||||
return None, nn
|
||||
|
||||
def _torch_stubs():
|
||||
nn = NNStub()
|
||||
return None, nn
|
||||
|
||||
|
||||
def get_variable(value,
|
||||
|
@ -165,7 +170,7 @@ def get_variable(value,
|
|||
return tf.compat.v1.get_variable(
|
||||
tf_name, initializer=value, dtype=dtype, trainable=trainable)
|
||||
elif framework == "torch" and torch_tensor is True:
|
||||
import torch
|
||||
torch, _ = try_import_torch()
|
||||
var_ = torch.from_numpy(value)
|
||||
var_.requires_grad = trainable
|
||||
return var_
|
||||
|
|
Loading…
Add table
Reference in a new issue