mirror of
https://github.com/vale981/ray
synced 2025-03-09 04:46:38 -04:00
![]() The [original PR](https://github.com/ray-project/ray/pull/21864) was [reverted](https://github.com/ray-project/ray/pull/22117) because it caused `torch` (more specifically, `torch>=1.8.1`) to be required to use `ray.train`. ``` | File "ray_sgd_training.py", line 18, in <module> | from ray import train | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/__init__.py", line 2, in <module> | from ray.train.callbacks import TrainingCallback | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/__init__.py", line 8, in <module> | from ray.train.callbacks.profile import TorchTensorboardProfilerCallback | File "/home/ray/anaconda3/lib/python3.7/site-packages/ray/train/callbacks/profile.py", line 6, in <module> | from torch.profiler import profile | ModuleNotFoundError: No module named 'torch.profiler' ``` A [minimal installation test suite](https://github.com/ray-project/ray/pull/22300) was added to detect this. Further, in this PR we make the following changes: 1. Move `TorchWorkerProfiler` to `ray.train.torch` so all torch imports are centralized. 2. Add import validation logic to `TorchWorkerProfiler.__init__` so an exception will only be raised if the user tries to initialize a `TorchWorkerProfiler` without having a valid version of `torch` installed: ``` >>> import ray >>> import ray.train >>> import ray.train.torch >>> from ray.train.torch import TorchWorkerProfiler >>> twp = TorchWorkerProfiler() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/matt/workspace/ray/python/ray/train/torch.py", line 365, in __init__ "Torch Profiler requires torch>=1.8.1. " ImportError: Torch Profiler requires torch>=1.8.1. Run `pip install 'torch>=1.8.1'` to use TorchWorkerProfiler. ``` |
||
---|---|---|
.. | ||
examples | ||
api.rst | ||
architecture.rst | ||
examples.rst | ||
migration-guide.rst | ||
train-arch.svg | ||
train.rst | ||
user_guide.rst |