[tune] Only use TBXLoggerCallback when torch is installed (#16695)

* [tune] Only use TBXLoggerCallback when torch is installed

* Fix lint

* fix

* Update python/ray/tune/utils/callback.py

Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Travis Addair 2021-06-28 16:34:20 -07:00 committed by GitHub
parent 0811ae4231
commit e5dfa4cfb9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 1 deletions

View file

@ -11,6 +11,15 @@ from ray.tune.logger import CSVLoggerCallback, CSVLogger, LoggerCallback, \
TBXLoggerCallback, TBXLogger
from ray.tune.syncer import SyncerCallback
try:
if "TUNE_TEST_NO_TORCH_IMPORT" in os.environ:
_HAS_TORCH = False
else:
import torch # noqa: F401
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
logger = logging.getLogger(__name__)
@ -108,7 +117,7 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
if not has_json_logger:
callbacks.append(JsonLoggerCallback())
last_logger_index = len(callbacks) - 1
if not has_tbx_logger:
if not has_tbx_logger and _HAS_TORCH:
callbacks.append(TBXLoggerCallback())
last_logger_index = len(callbacks) - 1

View file

@ -6,6 +6,7 @@ import sys
if __name__ == "__main__":
# Do not import torch for testing purposes.
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
os.environ["TUNE_TEST_NO_TORCH_IMPORT"] = "1"
from ray.rllib.agents.a3c import A2CTrainer
assert "torch" not in sys.modules, \
@ -31,5 +32,6 @@ if __name__ == "__main__":
# Clean up.
del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
del os.environ["TUNE_TEST_NO_TORCH_IMPORT"]
print("ok")