Mirror of https://github.com/vale981/ray (synced 2025-03-07 02:51:39 -05:00)
The [original PR](https://github.com/ray-project/ray/pull/21864) was [reverted](https://github.com/ray-project/ray/pull/22117) because it caused `torch` (more specifically, `torch>=1.8.1`) to be required to use `ray.train`.

```
File "ray_sgd_training.py", line 18, in <module>
  from ray import train
File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/__init__.py", line 2, in <module>
  from ray.train.callbacks import TrainingCallback
File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/__init__.py", line 8, in <module>
  from ray.train.callbacks.profile import TorchTensorboardProfilerCallback
File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/profile.py", line 6, in <module>
  from torch.profiler import profile
ModuleNotFoundError: No module named 'torch.profiler'
```

A [minimal installation test suite](https://github.com/ray-project/ray/pull/22300) was added to detect this.

Further, in this PR we make the following changes:

1. Move `TorchWorkerProfiler` to `ray.train.torch` so all torch imports are centralized.
2. Add import validation logic to `TorchWorkerProfiler.__init__` so an exception will only be raised if the user tries to initialize a `TorchWorkerProfiler` without having a valid version of `torch` installed:

   ```
   >>> import ray
   >>> import ray.train
   >>> import ray.train.torch
   >>> from ray.train.torch import TorchWorkerProfiler
   >>> twp = TorchWorkerProfiler()
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
     File "/Users/matt/workspace/ray/python/ray/train/torch.py", line 365, in __init__
       "Torch Profiler requires torch>=1.8.1. "
   ImportError: Torch Profiler requires torch>=1.8.1. Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler.
   ```
- _includes
- _static
- _templates
- cluster
- data
- images
- ray-contribute
- ray-core
- ray-design-patterns
- ray-more-libs
- ray-observability
- ray-overview
- ray-references
- raysgd
- rllib
- serve
- train
- tune
- workflows
- _toc.yml
- conf.py
- custom_directives.py
- index.md
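
The deferred import check described in the commit message (validate the dependency in `__init__` rather than at module import time) can be illustrated with a minimal sketch. This is not Ray's actual implementation: the class name `OptionalDependencyProfiler` and its `module_name` and `install_hint` parameters are hypothetical and exist only to show the pattern.

```python
import importlib


class OptionalDependencyProfiler:
    """Hypothetical stand-in for a class that needs an optional dependency.

    Importing the module that defines this class never touches the optional
    dependency; the check only runs when an instance is created.
    """

    def __init__(
        self,
        module_name="torch.profiler",  # assumption: the submodule to require
        install_hint="pip install 'torch>=1.8.1'",  # assumption: hint text
    ):
        try:
            # Resolve the dependency lazily, at construction time.
            self._backend = importlib.import_module(module_name)
        except ImportError as exc:
            # ModuleNotFoundError is a subclass of ImportError, so this also
            # covers an installed package that lacks the requested submodule.
            raise ImportError(
                f"{type(self).__name__} requires {module_name}. "
                f"Run `{install_hint}` to use it."
            ) from exc
```

With this shape, users who never construct the profiler are unaffected by a missing or outdated dependency, while users who do construct it get an actionable error message instead of a failure deep inside an unrelated import chain.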