[rllib] Add test case that we don't have a hard torch dependency (#7926)

This commit is contained in:
Eric Liang 2020-04-07 18:07:39 -07:00 committed by GitHub
parent 85481d635d
commit e8c19aba41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 8 deletions

View file

@ -1005,6 +1005,13 @@ py_test(
srcs = ["tests/test_dependency.py"]
)
py_test(
name = "tests/test_dependency_torch",
tags = ["tests_dir", "tests_dir_D"],
size = "small",
srcs = ["tests/test_dependency_torch.py"]
)
py_test(
name = "tests/test_eager_support",
tags = ["tests_dir", "tests_dir_E"],

View file

@ -1,8 +1,6 @@
from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
from torch.optim import RMSprop
from torch.distributions import Categorical
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
@ -244,6 +242,7 @@ class QMixTorchPolicy(Policy):
self.loss = QMixLoss(self.model, self.target_model, self.mixer,
self.target_mixer, self.n_agents, self.n_actions,
self.config["double_q"], self.config["gamma"])
from torch.optim import RMSprop
self.optimiser = RMSprop(
params=self.params,
lr=config["lr"],
@ -283,6 +282,7 @@ class QMixTorchPolicy(Policy):
random_numbers = torch.rand_like(q_values[:, :, 0])
pick_random = (random_numbers < (self.cur_epsilon
if explore else 0.0)).long()
from torch.distributions import Categorical
random_actions = Categorical(avail).sample().long()
actions = (pick_random * random_actions +
(1 - pick_random) * masked_q_values.argmax(dim=2))

View file

@ -0,0 +1,22 @@
#!/usr/bin/env python
import os
import sys
if __name__ == "__main__":
# Do not import torch for testing purposes.
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
from ray.rllib.agents.a3c import A2CTrainer
assert "torch" not in sys.modules, \
"Torch initially present, when it shouldn't."
# note: no ray.init(), to test it works without Ray
trainer = A2CTrainer(
env="CartPole-v0", config={
"use_pytorch": False,
"num_workers": 0
})
trainer.train()
assert "torch" not in sys.modules, "Torch should not be imported"

View file

@ -101,7 +101,10 @@ def try_import_tfp(error=False):
# Fake module for torch.nn.
class NNStub:
pass
def __init__(self, *a, **kw):
# Fake nn.functional module within torch.nn.
self.functional = None
self.Module = ModuleStub
# Fake class for torch.nn.Module to allow it to be inherited from.
@ -120,7 +123,7 @@ def try_import_torch(error=False):
"""
if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
logger.warning("Not importing Torch for test purposes.")
return None, None
return _torch_stubs()
try:
import torch
@ -129,10 +132,12 @@ def try_import_torch(error=False):
except ImportError as e:
if error:
raise e
return _torch_stubs()
nn = NNStub()
nn.Module = ModuleStub
return None, nn
def _torch_stubs():
nn = NNStub()
return None, nn
def get_variable(value,
@ -165,7 +170,7 @@ def get_variable(value,
return tf.compat.v1.get_variable(
tf_name, initializer=value, dtype=dtype, trainable=trainable)
elif framework == "torch" and torch_tensor is True:
import torch
torch, _ = try_import_torch()
var_ = torch.from_numpy(value)
var_.requires_grad = trainable
return var_