mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Only use TBXLoggerCallback when torch is installed (#16695)
* [tune] Only use TBXLoggerCallback when torch is installed * Fix lint * fix * Update python/ray/tune/utils/callback.py Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com> Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
parent
0811ae4231
commit
e5dfa4cfb9
2 changed files with 12 additions and 1 deletions
|
@ -11,6 +11,15 @@ from ray.tune.logger import CSVLoggerCallback, CSVLogger, LoggerCallback, \
|
|||
TBXLoggerCallback, TBXLogger
|
||||
from ray.tune.syncer import SyncerCallback
|
||||
|
||||
try:
|
||||
if "TUNE_TEST_NO_TORCH_IMPORT" in os.environ:
|
||||
_HAS_TORCH = False
|
||||
else:
|
||||
import torch # noqa: F401
|
||||
_HAS_TORCH = True
|
||||
except ImportError:
|
||||
_HAS_TORCH = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -108,7 +117,7 @@ def create_default_callbacks(callbacks: Optional[List[Callback]],
|
|||
if not has_json_logger:
|
||||
callbacks.append(JsonLoggerCallback())
|
||||
last_logger_index = len(callbacks) - 1
|
||||
if not has_tbx_logger:
|
||||
if not has_tbx_logger and _HAS_TORCH:
|
||||
callbacks.append(TBXLoggerCallback())
|
||||
last_logger_index = len(callbacks) - 1
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import sys
|
|||
if __name__ == "__main__":
|
||||
# Do not import torch for testing purposes.
|
||||
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
|
||||
os.environ["TUNE_TEST_NO_TORCH_IMPORT"] = "1"
|
||||
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
assert "torch" not in sys.modules, \
|
||||
|
@ -31,5 +32,6 @@ if __name__ == "__main__":
|
|||
|
||||
# Clean up.
|
||||
del os.environ["RLLIB_TEST_NO_TORCH_IMPORT"]
|
||||
del os.environ["TUNE_TEST_NO_TORCH_IMPORT"]
|
||||
|
||||
print("ok")
|
||||
|
|
Loading…
Add table
Reference in a new issue