.. _train-api:

Ray Train API
=============

.. _train-api-trainer:

Trainer
-------

.. autoclass:: ray.train.Trainer
    :members:

.. _train-api-iterator:

TrainingIterator
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.TrainingIterator
    :members:

.. _train-api-backend-config:

Backend Configurations
----------------------

.. _train-api-torch-config:

TorchConfig
~~~~~~~~~~~

.. autoclass:: ray.train.torch.TorchConfig
    :noindex:

.. _train-api-tensorflow-config:

TensorflowConfig
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.tensorflow.TensorflowConfig
    :noindex:

.. _train-api-horovod-config:

HorovodConfig
~~~~~~~~~~~~~

.. autoclass:: ray.train.horovod.HorovodConfig
    :noindex:

.. _train-api-backend-interfaces:

Backend interfaces (for developers only)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Backend
+++++++

.. autoclass:: ray.train.backend.Backend

BackendConfig
+++++++++++++

.. autoclass:: ray.train.backend.BackendConfig

Callbacks
---------

.. _train-api-callback:

TrainingCallback
~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.TrainingCallback
    :members:

.. _train-api-print-callback:

PrintCallback
~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.PrintCallback

.. _train-api-json-logger-callback:

JsonLoggerCallback
~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.JsonLoggerCallback

.. _train-api-tbx-logger-callback:

TBXLoggerCallback
~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.TBXLoggerCallback

.. _train-api-mlflow-logger-callback:

MLflowLoggerCallback
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.MLflowLoggerCallback

.. _train-api-torch-tensorboard-profiler-callback:

TorchTensorboardProfilerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.callbacks.TorchTensorboardProfilerCallback

.. _train-api-func-utils:

Training Function Utilities
---------------------------

train.report
~~~~~~~~~~~~

.. autofunction:: ray.train.report

train.load_checkpoint
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.load_checkpoint

train.save_checkpoint
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.save_checkpoint

train.get_dataset_shard
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.get_dataset_shard

train.world_rank
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.world_rank

train.local_rank
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.local_rank

train.world_size
~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.world_size

.. _train-api-torch-utils:

PyTorch Training Function Utilities
-----------------------------------

.. _train-api-torch-prepare-model:

train.torch.prepare_model
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_model
    :noindex:

.. _train-api-torch-prepare-data-loader:

train.torch.prepare_data_loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_data_loader
    :noindex:

train.torch.prepare_optimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.prepare_optimizer
    :noindex:

train.torch.backward
~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.backward
    :noindex:

.. _train-api-torch-get-device:

train.torch.get_device
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.get_device
    :noindex:

train.torch.enable_reproducibility
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.enable_reproducibility
    :noindex:

.. _train-api-torch-worker-profiler:

train.torch.accelerate
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.torch.accelerate
    :noindex:

train.torch.TorchWorkerProfiler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ray.train.torch.TorchWorkerProfiler
    :members:
    :noindex:

.. _train-api-tensorflow-utils:

TensorFlow Training Function Utilities
--------------------------------------

train.tensorflow.prepare_dataset_shard
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: ray.train.tensorflow.prepare_dataset_shard
    :noindex: