.. _train-api:

Ray Train API
=============

.. _train-api-trainer:

Trainer
-------

.. autoclass:: ray.train.Trainer
    :members:

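This is a minimal sketch of the ``Trainer`` lifecycle (``start``, ``run``, ``shutdown``). It assumes a Ray 1.x release that still ships this API and a local PyTorch install for ``backend="torch"``. The names ``train_func`` and ``main`` are illustrative and are not part of Ray Train.

.. code-block:: python

    from ray import train
    from ray.train import Trainer


    def train_func(config):
        # Runs once on every worker; config is the dict passed to Trainer.run.
        for epoch in range(config["epochs"]):
            train.report(epoch=epoch)
        # Each worker's return value is collected by Trainer.run.
        return train.world_rank()


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()  # create the worker actors and set up the backend
        results = trainer.run(train_func, config={"epochs": 3})
        trainer.shutdown()  # release the workers
        print(results)  # one entry per worker


    if __name__ == "__main__":
        main()
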
.. _train-api-iterator:

TrainingIterator
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.TrainingIterator
    :members:

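``Trainer.run_iterator`` returns a ``TrainingIterator`` instead of blocking until training finishes. The sketch below makes the same Ray 1.x and PyTorch assumptions as a plain ``Trainer.run`` call. It consumes intermediate results as the workers report them. ``train_func`` and ``main`` are illustrative names.

.. code-block:: python

    from ray import train
    from ray.train import Trainer


    def train_func():
        for step in range(3):
            train.report(step=step, loss=1.0 / (step + 1))
        return "done"


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        iterator = trainer.run_iterator(train_func)
        for worker_results in iterator:
            # One dict per worker, holding the kwargs of a single train.report call.
            print(worker_results[0]["loss"])
        # Return values of train_func, available once all results are consumed.
        print(iterator.get_final_results())
        trainer.shutdown()


    if __name__ == "__main__":
        main()
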
.. _train-api-backend-config:

Backend Configurations
----------------------

.. _train-api-torch-config:

TorchConfig
~~~~~~~~~~~

.. autoclass:: ray.train.torch.TorchConfig
    :noindex:

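Wherever ``Trainer`` accepts a backend string such as ``"torch"``, it also accepts a backend configuration object. This sketch assumes the ``backend`` and ``timeout_s`` fields of ``TorchConfig`` from the Ray 1.x source. ``make_trainer`` is a hypothetical helper.

.. code-block:: python

    from ray.train import Trainer
    from ray.train.torch import TorchConfig

    # Force the Gloo process group and use a 10-minute process-group timeout.
    torch_config = TorchConfig(backend="gloo", timeout_s=600)


    def make_trainer(num_workers=2):
        # Hypothetical helper: the config object replaces the "torch" string.
        return Trainer(backend=torch_config, num_workers=num_workers)
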
.. _train-api-tensorflow-config:

TensorflowConfig
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.tensorflow.TensorflowConfig
    :noindex:

.. _train-api-horovod-config:

HorovodConfig
~~~~~~~~~~~~~

.. autoclass:: ray.train.horovod.HorovodConfig
    :noindex:

.. _train-api-backend-interfaces:

Backend interfaces (for developers only)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Backend
+++++++

.. autoclass:: ray.train.backend.Backend

BackendConfig
+++++++++++++

.. autoclass:: ray.train.backend.BackendConfig

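A new backend pairs a ``Backend`` subclass with a ``BackendConfig`` subclass whose ``backend_cls`` property returns it. The sketch below is hypothetical. ``EnvVarBackend``, ``EnvVarBackendConfig`` and the ``MY_BACKEND_FLAG`` variable are illustrative names. It also assumes the ``on_start``/``on_shutdown`` hooks and ``WorkerGroup.execute`` found in the Ray 1.x source.

.. code-block:: python

    import os
    from dataclasses import dataclass

    from ray.train.backend import Backend, BackendConfig


    def _set_env_var(key: str, value: str) -> None:
        # Executed inside each worker process.
        os.environ[key] = value


    def _unset_env_var(key: str) -> None:
        os.environ.pop(key, None)


    class EnvVarBackend(Backend):
        """Hypothetical backend that exports one environment variable per worker."""

        def on_start(self, worker_group, backend_config):
            # Called after the workers are created, before training starts.
            worker_group.execute(_set_env_var, backend_config.key, backend_config.value)

        def on_shutdown(self, worker_group, backend_config):
            # Called when the Trainer shuts the workers down.
            worker_group.execute(_unset_env_var, backend_config.key)


    @dataclass
    class EnvVarBackendConfig(BackendConfig):
        key: str = "MY_BACKEND_FLAG"  # hypothetical variable name
        value: str = "1"

        @property
        def backend_cls(self):
            # Tells Ray Train which Backend implementation to use.
            return EnvVarBackend

With this in place, ``Trainer(backend=EnvVarBackendConfig(), num_workers=2)`` would run the hooks on every worker.
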
Callbacks
---------

.. _train-api-callback:

TrainingCallback
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.TrainingCallback
    :members:

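A custom callback subclasses ``TrainingCallback`` and overrides ``handle_result``. That method receives one dict per worker each time the workers call ``train.report``. ``MaxLossCallback`` below is a hypothetical example, and it assumes every report includes a ``loss`` key.

.. code-block:: python

    from typing import Dict, List

    from ray.train import TrainingCallback


    class MaxLossCallback(TrainingCallback):
        """Hypothetical callback recording the largest loss across workers."""

        def __init__(self):
            self.history: List[float] = []

        def handle_result(self, results: List[Dict], **info):
            # results holds one dict per worker for a single train.report call.
            self.history.append(max(result["loss"] for result in results))

To use it, pass it as ``trainer.run(train_func, callbacks=[MaxLossCallback()])``.
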
.. _train-api-print-callback:

PrintCallback
~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.PrintCallback

.. _train-api-json-logger-callback:

JsonLoggerCallback
~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.JsonLoggerCallback

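Built-in callbacks are passed to ``Trainer.run`` through its ``callbacks`` argument. The sketch below combines two of them and assumes Ray 1.x and PyTorch. When ``JsonLoggerCallback`` is given no arguments, it takes its log directory from the ``Trainer``.

.. code-block:: python

    from ray import train
    from ray.train import Trainer
    from ray.train.callbacks import JsonLoggerCallback, PrintCallback


    def train_func():
        for step in range(3):
            train.report(step=step, loss=1.0 / (step + 1))


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        # PrintCallback prints each round of results; JsonLoggerCallback
        # writes them to a JSON file under the trainer's log directory.
        trainer.run(train_func, callbacks=[PrintCallback(), JsonLoggerCallback()])
        trainer.shutdown()


    if __name__ == "__main__":
        main()
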
.. _train-api-tbx-logger-callback:

TBXLoggerCallback
~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.TBXLoggerCallback

.. _train-api-mlflow-logger-callback:

MLflowLoggerCallback
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.MLflowLoggerCallback

.. _train-api-torch-tensorboard-profiler-callback:

TorchTensorboardProfilerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.TorchTensorboardProfilerCallback

.. _train-api-func-utils:

Training Function Utilities
---------------------------

train.report
~~~~~~~~~~~~

.. autofunction:: ray.train.report

train.load_checkpoint
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.load_checkpoint

train.save_checkpoint
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.save_checkpoint

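The sketch below shows resumable training with ``train.report``, ``train.save_checkpoint`` and ``train.load_checkpoint``. It assumes the Ray 1.x behavior. ``load_checkpoint`` returns the most recently saved checkpoint, or else the one passed to ``Trainer.run(checkpoint=...)``, or else ``None``. On the driver, ``Trainer.latest_checkpoint`` exposes the last checkpoint. The ``epoch`` key is an arbitrary choice.

.. code-block:: python

    from ray import train
    from ray.train import Trainer


    def train_func(config):
        # Empty dict when no checkpoint exists yet.
        checkpoint = train.load_checkpoint() or {}
        start_epoch = checkpoint.get("epoch", -1) + 1
        for epoch in range(start_epoch, config["epochs"]):
            train.save_checkpoint(epoch=epoch)
            train.report(epoch=epoch)


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        trainer.run(train_func, config={"epochs": 2})
        # Resume from the last saved checkpoint and train two more epochs.
        trainer.run(
            train_func, config={"epochs": 4}, checkpoint=trainer.latest_checkpoint
        )
        trainer.shutdown()


    if __name__ == "__main__":
        main()
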
train.get_dataset_shard
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.get_dataset_shard

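The sketch below shows sharded data loading and assumes Ray 1.x with Ray Datasets. The dataset passed to ``Trainer.run(dataset=...)`` is split across the workers, and each worker retrieves its part with ``train.get_dataset_shard()``. If a dict of datasets is passed, each shard is looked up by key, for example ``train.get_dataset_shard("train")``.

.. code-block:: python

    import ray
    from ray import train
    from ray.train import Trainer


    def train_func():
        # This worker's portion of the dataset passed to Trainer.run.
        shard = train.get_dataset_shard()
        return shard.count()


    def main():
        dataset = ray.data.range(100)
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        counts = trainer.run(train_func, dataset=dataset)
        trainer.shutdown()
        print(counts)  # rows seen by each worker


    if __name__ == "__main__":
        main()
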
train.world_rank
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.world_rank

train.local_rank
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.local_rank

train.world_size
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.world_size

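The rank helpers are commonly used to limit side effects such as logging to a single worker. The sketch below assumes it runs inside a training function launched by ``Trainer.run``.

.. code-block:: python

    from ray import train


    def train_func():
        rank = train.world_rank()  # 0 .. world_size - 1 across all workers
        if rank == 0:
            # Only one worker prints, avoiding duplicated logs.
            print(f"Training with {train.world_size()} workers")
        # local_rank is the worker's index among workers on the same node.
        return {"world_rank": rank, "local_rank": train.local_rank()}
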
.. _train-api-torch-utils:

PyTorch Training Function Utilities
-----------------------------------

.. _train-api-torch-prepare-model:

train.torch.prepare_model
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_model
    :noindex:

.. _train-api-torch-prepare-data-loader:

train.torch.prepare_data_loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_data_loader
    :noindex:

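The sketch below is a data-parallel PyTorch training function built on ``prepare_model`` and ``prepare_data_loader``. It assumes the Ray 1.x defaults. ``prepare_model`` moves the model to the worker's device and wraps it in ``DistributedDataParallel``. ``prepare_data_loader`` adds a ``DistributedSampler`` and moves each batch to that device. The random regression data and hyperparameters are placeholders.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from ray import train
    from ray.train import Trainer
    from ray.train.torch import prepare_data_loader, prepare_model


    def train_func(config):
        # Placeholder regression data, generated independently on each worker.
        dataset = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))
        # Adds a DistributedSampler and moves every batch to this worker's device.
        loader = prepare_data_loader(DataLoader(dataset, batch_size=32))

        # Moves the model to this worker's device and wraps it in DDP.
        model = prepare_model(torch.nn.Linear(4, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
        loss_fn = torch.nn.MSELoss()

        for epoch in range(config["epochs"]):
            for features, labels in loader:
                loss = loss_fn(model(features), labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            train.report(epoch=epoch, loss=loss.item())


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        trainer.run(train_func, config={"lr": 0.1, "epochs": 2})
        trainer.shutdown()


    if __name__ == "__main__":
        main()
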
train.torch.prepare_optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_optimizer
    :noindex:

train.torch.backward
~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.backward
    :noindex:

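``prepare_optimizer`` and ``backward`` are needed mainly when automatic mixed precision (AMP) is enabled through ``train.torch.accelerate``. The sketch below assumes GPU workers (``use_gpu=True``). It also assumes the Ray 1.x rule that ``accelerate`` is called once, before any other ``train.torch`` utility. The model and random data are placeholders.

.. code-block:: python

    import torch

    from ray import train
    from ray.train import Trainer
    from ray.train.torch import (
        accelerate,
        backward,
        get_device,
        prepare_model,
        prepare_optimizer,
    )


    def train_func():
        accelerate(amp=True)  # call once, before the other train.torch utilities
        model = prepare_model(torch.nn.Linear(4, 1))
        # The wrapped optimizer steps through the AMP gradient scaler.
        optimizer = prepare_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))
        loss_fn = torch.nn.MSELoss()
        device = get_device()

        for step in range(10):
            features = torch.randn(32, 4, device=device)
            labels = torch.randn(32, 1, device=device)
            # .float() guards against a float16 output from the autocast forward.
            loss = loss_fn(model(features).float(), labels)
            model.zero_grad()  # clear accumulated gradients
            backward(loss)  # used instead of loss.backward(); scales the loss under AMP
            optimizer.step()
            train.report(step=step, loss=loss.item())


    def main():
        trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
        trainer.start()
        trainer.run(train_func)
        trainer.shutdown()


    if __name__ == "__main__":
        main()
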
.. _train-api-torch-get-device:

train.torch.get_device
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.get_device
    :noindex:

train.torch.enable_reproducibility
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.enable_reproducibility
    :noindex:

train.torch.accelerate
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.accelerate
    :noindex:

.. _train-api-torch-worker-profiler:

train.torch.TorchWorkerProfiler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.torch.TorchWorkerProfiler
    :members:
    :noindex:

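The sketch below pairs ``TorchWorkerProfiler`` inside the training function with ``TorchTensorboardProfilerCallback`` on the driver. It requires ``torch>=1.8.1``. The loop body is a placeholder, and the pattern of reporting ``get_and_clear_profile_traces()`` through ``train.report`` is assumed from the Ray 1.x examples.

.. code-block:: python

    from ray import train
    from ray.train import Trainer
    from ray.train.callbacks import TorchTensorboardProfilerCallback
    from ray.train.torch import TorchWorkerProfiler


    def train_func():
        twp = TorchWorkerProfiler()
        with twp.profile() as profiler:
            for step in range(3):
                # ... one training step would go here ...
                profiler.step()
                # Send this worker's traces to the driver with the results.
                train.report(**twp.get_and_clear_profile_traces())


    def main():
        trainer = Trainer(backend="torch", num_workers=2)
        trainer.start()
        # Writes the collected traces to disk for TensorBoard's profiler plugin.
        trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
        trainer.shutdown()


    if __name__ == "__main__":
        main()
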
.. _train-api-tensorflow-utils:

TensorFlow Training Function Utilities
--------------------------------------

train.tensorflow.prepare_dataset_shard
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.tensorflow.prepare_dataset_shard
    :noindex: