.. _train-dl-guide:
Deep Learning User Guide
========================
This guide explains how to use Ray Train to scale PyTorch, TensorFlow, and Horovod training.
In this guide, we cover examples for the following use cases:
* How do I :ref:`port my code <train-porting-code>` to use Ray Train?
* How do I use Ray Train to :ref:`train with a large dataset <train-datasets>`?
* How do I :ref:`monitor <train-monitoring>` my training?
* How do I run my training on pre-emptible instances
(:ref:`fault tolerance <train-fault-tolerance>`)?
* How do I :ref:`tune <train-tune>` my Ray Train model?
.. _train-backends:
Backends
--------
Ray Train provides a thin API around different backend frameworks for
distributed deep learning. At the moment, Ray Train allows you to perform
training with:
* **PyTorch:** Ray Train initializes your distributed process group, allowing
you to run your ``DistributedDataParallel`` training script. See `PyTorch
Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`_
for more information.
* **TensorFlow:** Ray Train configures ``TF_CONFIG`` for you, allowing you to run
your ``MultiWorkerMirroredStrategy`` training script. See `Distributed
training with TensorFlow <https://www.tensorflow.org/guide/distributed_training>`_
for more information.
* **Horovod:** Ray Train configures the Horovod environment and Rendezvous
server for you, allowing you to run your ``DistributedOptimizer`` training
script. See `Horovod documentation <https://horovod.readthedocs.io/en/stable/index.html>`_
for more information.
.. _train-porting-code:
Porting code to Ray Train
-------------------------
The following instructions assume you have a training function
that can already be run on a single worker for one of the supported
:ref:`backend <train-backends>` frameworks.
Update training function
~~~~~~~~~~~~~~~~~~~~~~~~
First, you'll want to update your training function to support distributed
training.
.. tabbed:: PyTorch
Ray Train will set up your distributed process group for you and also provides utility methods
to automatically prepare your model and data for distributed training.
.. note::
Ray Train will still work even if you don't use the ``prepare_model`` and ``prepare_data_loader`` utilities below,
and instead handle the logic directly inside your training function.
First, use the ``prepare_model`` function to automatically move your model to the right device and wrap it in
``DistributedDataParallel``:
.. code-block:: diff
import torch
from torch.nn.parallel import DistributedDataParallel
+from ray import train
+import ray.train.torch
def train_func():
- device = torch.device(f"cuda:{train.local_rank()}" if
- torch.cuda.is_available() else "cpu")
- torch.cuda.set_device(device)
# Create model.
model = NeuralNetwork()
- model = model.to(device)
- model = DistributedDataParallel(model,
- device_ids=[train.local_rank()] if torch.cuda.is_available() else None)
+ model = train.torch.prepare_model(model)
...
Then, use the ``prepare_data_loader`` function to automatically add a ``DistributedSampler`` to your ``DataLoader``
and move the batches to the right device.
.. code-block:: diff
import torch
from torch.utils.data import DataLoader, DistributedSampler
+from ray import train
+import ray.train.torch
def train_func():
- device = torch.device(f"cuda:{train.local_rank()}" if
- torch.cuda.is_available() else "cpu")
- torch.cuda.set_device(device)
...
- data_loader = DataLoader(my_dataset, batch_size=worker_batch_size, sampler=DistributedSampler(dataset))
+ data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
+ data_loader = train.torch.prepare_data_loader(data_loader)
for X, y in data_loader:
- X = X.to(device)
- y = y.to(device)
.. tip::
Keep in mind that ``DataLoader`` takes in a ``batch_size`` which is the batch size for each worker.
The global batch size can be calculated from the worker batch size (and vice-versa) with the following equation:
.. code-block::
global_batch_size = worker_batch_size * train.world_size()
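For illustration, here is a minimal sketch (not part of the tutorial above) that derives the per-worker batch size from a desired global batch size inside the training function; ``my_dataset`` and the target batch size are placeholders.

.. code-block:: python

    from torch.utils.data import DataLoader
    from ray import train
    import ray.train.torch

    def train_func():
        # Hypothetical target: a global batch size of 128 across all workers.
        global_batch_size = 128
        worker_batch_size = global_batch_size // train.world_size()

        data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
        data_loader = train.torch.prepare_data_loader(data_loader)
        ...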
.. tabbed:: TensorFlow
.. note::
The current TensorFlow implementation supports
``MultiWorkerMirroredStrategy`` (and ``MirroredStrategy``). If there are
other strategies you wish to see supported by Ray Train, please let us know
by submitting a `feature request on GitHub <https://github.com/ray-project/ray/issues>`_.
These instructions closely follow TensorFlow's `Multi-worker training
with Keras <https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras>`_
tutorial. One key difference is that Ray Train will handle the environment
variable setup for you.
**Step 1:** Wrap your model in ``MultiWorkerMirroredStrategy``.
The `MultiWorkerMirroredStrategy <https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy>`_
enables synchronous distributed training. The ``Model`` *must* be built and
compiled within the scope of the strategy.
.. code-block:: python
with tf.distribute.MultiWorkerMirroredStrategy().scope():
model = ... # build model
model.compile()
**Step 2:** Update your ``Dataset`` batch size to the *global* batch
size.
The `batch <https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch>`_
will be split evenly across worker processes, so ``batch_size`` should be
set appropriately.
.. code-block:: diff
-batch_size = worker_batch_size
+batch_size = worker_batch_size * train.world_size()
.. tabbed:: Horovod
If you have a training function that already runs with the `Horovod Ray
Executor <https://horovod.readthedocs.io/en/stable/ray_include.html#horovod-ray-executor>`_,
you should not need to make any additional changes!
To onboard onto Horovod, please visit the `Horovod guide
<https://horovod.readthedocs.io/en/stable/index.html#get-started>`_.
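For reference, here is a minimal sketch of what a Horovod (PyTorch) training function might look like when run with Ray Train; the toy model and hyperparameters are illustrative only.

.. code-block:: python

    import horovod.torch as hvd
    import torch
    import torch.nn as nn

    def train_func():
        # Ray Train sets up the Horovod environment and Rendezvous server;
        # each worker still calls ``hvd.init()`` as in a regular Horovod script.
        hvd.init()

        model = nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        # Standard Horovod boilerplate: broadcast initial state and wrap the optimizer.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters()
        )

        # ... regular PyTorch training loop ...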
Create Ray Train Trainer
~~~~~~~~~~~~~~~~~~~~~~~~
``Trainer``\s are the primary Ray Train classes that are used to manage state and
execute training. You can create a simple ``Trainer`` for the backend of choice
with one of the following:
.. tabbed:: PyTorch
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.torch import TorchTrainer
# For GPU Training, set `use_gpu` to True.
use_gpu = False
trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
)
.. tabbed:: TensorFlow
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
# For GPU Training, set `use_gpu` to True.
use_gpu = False
trainer = TensorflowTrainer(
train_func,
scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
)
.. tabbed:: Horovod
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.horovod import HorovodTrainer
# For GPU Training, set `use_gpu` to True.
use_gpu = False
trainer = HorovodTrainer(
train_func,
scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
)
To customize the backend setup, you can use a :ref:`train-api-backend-config` object.
.. tabbed:: PyTorch
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.torch import TorchTrainer, TorchConfig
trainer = TorchTrainer(
train_func,
torch_backend=TorchConfig(...),
scaling_config=ScalingConfig(num_workers=2),
)
.. tabbed:: TensorFlow
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer, TensorflowConfig
trainer = TensorflowTrainer(
train_func,
tensorflow_backend=TensorflowConfig(...),
scaling_config=ScalingConfig(num_workers=2),
)
.. tabbed:: Horovod
.. code-block:: python
from ray.air import ScalingConfig
from ray.train.horovod import HorovodTrainer, HorovodConfig
trainer = HorovodTrainer(
train_func,
horovod_backend=HorovodConfig(...),
scaling_config=ScalingConfig(num_workers=2),
)
For more configurability, please refer to the :class:`BaseTrainer` API.
Run training function
~~~~~~~~~~~~~~~~~~~~~
With a distributed training function and a Ray Train ``Trainer``, you are now
ready to start training!
.. code-block:: python
trainer.fit()
Configuring Training
--------------------
With Ray Train, you can execute a training function (``train_func``) in a
distributed manner by calling ``Trainer.fit``. To pass arguments
into the training function, you can expose a single ``config`` dictionary parameter:
.. code-block:: diff
-def train_func():
+def train_func(config):
Then, you can pass in the config dictionary as an argument to ``Trainer``:
.. code-block:: diff
+config = {} # This should be populated.
trainer = TorchTrainer(
train_func,
+ train_loop_config=config,
scaling_config=ScalingConfig(num_workers=2)
)
Putting this all together, you can run your training function with different
configurations. As an example:
.. code-block:: python
from ray.air import session, ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
for i in range(config["num_epochs"]):
session.report({"epoch": i})
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 2},
scaling_config=ScalingConfig(num_workers=2)
)
result = trainer.fit()
print(result.metrics["epoch"])
# 1
A primary use-case for ``config`` is to try different hyperparameters. To
perform hyperparameter tuning with Ray Train, please refer to the
:ref:`Ray Tune integration <train-tune>`.
.. TODO add support for with_parameters
.. _train-result-object:
Accessing Training Results
--------------------------
.. TODO(ml-team) Flesh this section out.
The return value of ``Trainer.fit`` is a :class:`Result` object, containing
information about the training run. You can access it to obtain saved checkpoints,
metrics and other relevant data.
For example, you can:
* Print the metrics for the last training iteration:
.. code-block:: python
from pprint import pprint
pprint(result.metrics)
# {'_time_this_iter_s': 0.001016855239868164,
# '_timestamp': 1657829125,
# '_training_iteration': 2,
# 'config': {},
# 'date': '2022-07-14_20-05-25',
# 'done': True,
# 'episodes_total': None,
# 'epoch': 1,
# 'experiment_id': '5a3f8b9bf875437881a8ddc7e4dd3340',
# 'experiment_tag': '0',
# 'hostname': 'ip-172-31-43-110',
# 'iterations_since_restore': 2,
# 'node_ip': '172.31.43.110',
# 'pid': 654068,
# 'time_since_restore': 3.4353830814361572,
# 'time_this_iter_s': 0.00809168815612793,
# 'time_total_s': 3.4353830814361572,
# 'timestamp': 1657829125,
# 'timesteps_since_restore': 0,
# 'timesteps_total': None,
# 'training_iteration': 2,
# 'trial_id': '4913f_00000',
# 'warmup_time': 0.003167867660522461}
* View the dataframe containing the metrics from all iterations:
.. code-block:: python
print(result.metrics_dataframe)
* Obtain the :class:`Checkpoint`, used for resuming training, prediction and serving.
.. code-block:: python
result.checkpoint # last saved checkpoint
result.best_checkpoints # N best saved checkpoints, as configured in run_config
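For instance, a minimal sketch of inspecting these checkpoints, assuming they were saved as dictionaries (as in the checkpointing examples later in this guide):

.. code-block:: python

    # Latest checkpoint, restored as a dictionary.
    print(result.checkpoint.to_dict())

    # ``best_checkpoints`` is a list of (Checkpoint, metrics) tuples.
    for checkpoint, metrics in result.best_checkpoints:
        print(metrics, checkpoint.to_dict())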
.. _train-log-dir:
Log Directory Structure
~~~~~~~~~~~~~~~~~~~~~~~
Each ``Trainer`` will have a local directory created for logs and checkpoints.
You can obtain the path to the directory by accessing the ``log_dir`` attribute
of the :class:`Result` object returned by ``Trainer.fit``.
.. code-block:: python
print(result.log_dir)
# '/home/ubuntu/ray_results/TorchTrainer_2022-06-13_20-31-06/checkpoint_000003'
.. _train-datasets:
Distributed Data Ingest with Ray Datasets
-----------------------------------------
:ref:`Ray Datasets <datasets>` are the recommended way to work with large datasets in Ray Train. Datasets provides automatic loading, sharding, and (optionally) pipelined ingest of data across multiple Train workers.
To get started, pass in one or more datasets under the ``datasets`` keyword argument for Trainer (e.g., ``Trainer(datasets={...})``).
Here's a simple code overview of the Datasets integration:
.. code-block:: python
from ray.air import session
# Datasets can be accessed in your train_func via ``get_dataset_shard``.
def train_func(config):
train_data_shard = session.get_dataset_shard("train")
validation_data_shard = session.get_dataset_shard("validation")
...
# Randomly split the dataset into 80% training data and 20% validation data.
dataset = ray.data.read_csv("...")
train_dataset, validation_dataset = dataset.train_test_split(
test_size=0.2, shuffle=True,
)
trainer = TorchTrainer(
train_func,
datasets={"train": train_dataset, "validation": validation_dataset},
scaling_config=ScalingConfig(num_workers=8),
)
trainer.fit()
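Inside the training function, each worker can then iterate over its shard. A minimal sketch, assuming the dataset has ``"x"`` and ``"y"`` columns and that your Ray version exposes ``iter_torch_batches`` on the returned shard (the column names and batch size are placeholders):

.. code-block:: python

    from ray.air import session

    def train_func(config):
        train_data_shard = session.get_dataset_shard("train")
        for epoch in range(config["num_epochs"]):
            # Each worker only iterates over its own shard of the data.
            for batch in train_data_shard.iter_torch_batches(batch_size=32):
                X, y = batch["x"], batch["y"]
                ...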
For more details on how to configure data ingest for Train, please refer to :ref:`air-ingest`.
.. TODO link to Training Run Iterator API as a 3rd option for logging.
.. _train-monitoring:
Logging, Checkpointing and Callbacks
------------------------------------
Ray Train provides mechanisms to collect intermediate results from the training workers during the training run,
as well as a :ref:`Callback interface <train-callbacks>` to perform actions on these intermediate results (such as logging or aggregation).
You can use either the :ref:`built-in callbacks <air-builtin-callbacks>` that Ray AIR provides,
or implement a :ref:`custom callback <train-custom-callbacks>` for your use case. The callback API
is shared with Ray Tune.
.. _train-checkpointing:
Ray Train also provides a way to save :ref:`Checkpoints <air-checkpoints-doc>` during the training process. This is
useful for:
1. :ref:`Integration with Ray Tune <train-tune>` to use certain Ray Tune
schedulers.
2. Running a long-running training job on a cluster of pre-emptible machines/pods.
3. Persisting trained model state to later use for serving/inference.
4. In general, storing any model artifacts.
Reporting intermediate results and handling checkpoints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Ray AIR provides a *Session* API for reporting intermediate
results and checkpoints from the training function (run on distributed workers) up to the
``Trainer`` (where your python script is executed) by calling ``session.report(metrics)``.
The results will be collected from the distributed workers and passed to the driver to
be logged and displayed.
.. warning::
Only the results from the rank 0 worker will be used. However, in order to ensure
consistency, ``session.report()`` has to be called on each worker.
The primary use-case for reporting is for metrics (accuracy, loss, etc.) at
the end of each training epoch.
.. code-block:: python
from ray.air import session
def train_func():
...
for i in range(num_epochs):
result = model.train(...)
session.report({"result": result})
The session concept exists on two levels: the execution layer (called the `Tune Session`) and the Data Parallel training layer
(called the `Train Session`).
The following figure shows what these two sessions look like in a Data Parallel training scenario.
.. image:: ../ray-air/images/session.svg
:width: 650px
:align: center
..
https://docs.google.com/drawings/d/1g0pv8gqgG29aPEPTcd4BC0LaRNbW1sAkv3H6W1TCp0c/edit
Saving checkpoints
++++++++++++++++++
:ref:`Checkpoints <air-checkpoints-doc>` can be saved by calling ``session.report(metrics, checkpoint=Checkpoint(...))`` in the
training function. This will cause the checkpoint state from the distributed
workers to be saved on the ``Trainer`` (where your python script is executed).
The latest saved checkpoint can be accessed through the ``checkpoint`` attribute of
the :class:`Result`, and the best saved checkpoints can be accessed by the ``best_checkpoints``
attribute.
The following examples demonstrate how to save checkpoints (model weights, but not the full model)
appropriately in distributed training.
.. tabbed:: PyTorch
.. code-block:: python
:emphasize-lines: 36, 37, 38, 39, 40, 41
import ray.train.torch
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np
def train_func(config):
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
# toy neural network : 1-layer
# wrap the model in DDP
model = ray.train.torch.prepare_model(nn.Linear(4, 1))
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=3e-4)
for epoch in range(config["num_epochs"]):
y = model.forward(X)
# compute loss
loss = criterion(y, Y)
# back-propagate loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
# To fetch non-DDP state_dict
# w/o DDP: model.state_dict()
# w/ DDP: model.module.state_dict()
# See: https://github.com/ray-project/ray/issues/20915
state_dict = model.state_dict()
consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
session.report({}, checkpoint=checkpoint)
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 5},
scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.checkpoint.to_dict())
# {'epoch': 4, 'model_weights': OrderedDict([('bias', tensor([-0.1215])), ('weight', tensor([[0.3253, 0.1979, 0.4525, 0.2850]]))]), '_timestamp': 1656107095, '_preprocessor': None, '_current_checkpoint_id': 4}
.. tabbed:: TensorFlow
.. code-block:: python
:emphasize-lines: 23
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
import numpy as np
def train_func(config):
import tensorflow as tf
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = np.random.normal(0, 1, size=(n, 4))
Y = np.random.uniform(0, 1, size=(n, 1))
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
# toy neural network : 1-layer
model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
for epoch in range(config["num_epochs"]):
model.fit(X, Y, batch_size=20)
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=model.get_weights())
)
session.report({}, checkpoint=checkpoint)
trainer = TensorflowTrainer(
train_func,
train_loop_config={"num_epochs": 5},
scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.checkpoint.to_dict())
# {'epoch': 4, 'model_weights': [array([[-0.31858477],
# [ 0.03747174],
# [ 0.28266194],
# [ 0.8626015 ]], dtype=float32), array([0.02230084], dtype=float32)], '_timestamp': 1656107383, '_preprocessor': None, '_current_checkpoint_id': 4}
By default, checkpoints will be persisted to local disk in the :ref:`log
directory <train-log-dir>` of each run.
.. code-block:: python
print(result.checkpoint.get_internal_representation())
# ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000003')
Configuring checkpoints
+++++++++++++++++++++++
For more configurability of checkpointing behavior (specifically saving
checkpoints to disk), a :class:`CheckpointConfig` can be passed into
``Trainer``.
As an example, to completely disable writing checkpoints to disk:
.. code-block:: python
:emphasize-lines: 9,14
from ray.air import session, Checkpoint
from ray.air import RunConfig, CheckpointConfig, ScalingConfig
from ray.train.torch import TorchTrainer
def train_func():
for epoch in range(3):
session.report({}, checkpoint=Checkpoint.from_dict(dict(epoch=epoch)))
checkpoint_config = CheckpointConfig(num_to_keep=0)
trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(num_workers=2),
run_config=RunConfig(checkpoint_config=checkpoint_config)
)
trainer.fit()
You may also configure ``CheckpointConfig`` to keep the "N best" checkpoints persisted to disk. The following example shows how you could keep the 2 checkpoints with the lowest "loss" value:
.. code-block:: python
from ray.air import session, Checkpoint, RunConfig, CheckpointConfig, ScalingConfig
from ray.train.torch import TorchTrainer
def train_func():
# first checkpoint
session.report(dict(loss=2), checkpoint=Checkpoint.from_dict(dict(loss=2)))
# second checkpoint
session.report(dict(loss=4), checkpoint=Checkpoint.from_dict(dict(loss=4)))
# third checkpoint
session.report(dict(loss=1), checkpoint=Checkpoint.from_dict(dict(loss=1)))
# fourth checkpoint
session.report(dict(loss=3), checkpoint=Checkpoint.from_dict(dict(loss=3)))
# Keep the 2 checkpoints with the smallest "loss" value.
checkpoint_config = CheckpointConfig(
num_to_keep=2, checkpoint_score_attribute="loss", checkpoint_score_order="min"
)
trainer = TorchTrainer(
train_func,
scaling_config=ScalingConfig(num_workers=2),
run_config=RunConfig(checkpoint_config=checkpoint_config),
)
result = trainer.fit()
print(result.best_checkpoints[0][0].get_internal_representation())
# ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000000')
print(result.best_checkpoints[1][0].get_internal_representation())
# ('local_path', '/home/ubuntu/ray_results/TorchTrainer_2022-06-24_21-34-49/TorchTrainer_7988b_00000_0_2022-06-24_21-34-49/checkpoint_000002')
Loading checkpoints
+++++++++++++++++++
Checkpoints can be loaded into the training function in 2 steps:
1. From the training function, ``session.get_checkpoint`` can be used to access
the most recently saved :class:`Checkpoint`. This is useful to continue training even
if there's a worker failure.
2. The checkpoint to start training with can be bootstrapped by passing in a
:class:`Checkpoint` to ``Trainer`` as the ``resume_from_checkpoint`` argument.
.. tabbed:: PyTorch
.. code-block:: python
:emphasize-lines: 23, 25, 26, 29, 30, 31, 35
import ray.train.torch
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from torch.optim import Adam
import numpy as np
def train_func(config):
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
# toy neural network : 1-layer
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=3e-4)
start_epoch = 0
checkpoint = session.get_checkpoint()
if checkpoint:
# assume that we have run the session.report() example
# and successfully save some model weights
checkpoint_dict = checkpoint.to_dict()
model.load_state_dict(checkpoint_dict.get("model_weights"))
start_epoch = checkpoint_dict.get("epoch", -1) + 1
# wrap the model in DDP
model = ray.train.torch.prepare_model(model)
for epoch in range(start_epoch, config["num_epochs"]):
y = model.forward(X)
# compute loss
loss = criterion(y, Y)
# back-propagate loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
state_dict = model.state_dict()
consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=state_dict)
)
session.report({}, checkpoint=checkpoint)
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 2},
scaling_config=ScalingConfig(num_workers=2),
)
# save a checkpoint
result = trainer.fit()
# load checkpoint
trainer = TorchTrainer(
train_func,
train_loop_config={"num_epochs": 4},
scaling_config=ScalingConfig(num_workers=2),
resume_from_checkpoint=result.checkpoint,
)
result = trainer.fit()
print(result.checkpoint.to_dict())
# {'epoch': 3, 'model_weights': OrderedDict([('bias', tensor([0.0902])), ('weight', tensor([[-0.1549, -0.0861, 0.4353, -0.4116]]))]), '_timestamp': 1656108265, '_preprocessor': None, '_current_checkpoint_id': 2}
.. tabbed:: TensorFlow
.. code-block:: python
:emphasize-lines: 15, 21, 22, 25, 26, 27, 30
from ray.air import session, Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
import numpy as np
def train_func(config):
import tensorflow as tf
n = 100
# create a toy dataset
# data : X - dim = (n, 4)
# target : Y - dim = (n, 1)
X = np.random.normal(0, 1, size=(n, 4))
Y = np.random.uniform(0, 1, size=(n, 1))
start_epoch = 0
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
# toy neural network : 1-layer
model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
checkpoint = session.get_checkpoint()
if checkpoint:
# assume that we have run the session.report() example
# and successfully save some model weights
checkpoint_dict = checkpoint.to_dict()
model.set_weights(checkpoint_dict.get("model_weights"))
start_epoch = checkpoint_dict.get("epoch", -1) + 1
model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
for epoch in range(start_epoch, config["num_epochs"]):
model.fit(X, Y, batch_size=20)
checkpoint = Checkpoint.from_dict(
dict(epoch=epoch, model_weights=model.get_weights())
)
session.report({}, checkpoint=checkpoint)
trainer = TensorflowTrainer(
train_func,
train_loop_config={"num_epochs": 2},
scaling_config=ScalingConfig(num_workers=2),
)
# save a checkpoint
result = trainer.fit()
# load a checkpoint
trainer = TensorflowTrainer(
train_func,
train_loop_config={"num_epochs": 5},
scaling_config=ScalingConfig(num_workers=2),
resume_from_checkpoint=result.checkpoint,
)
result = trainer.fit()
print(result.checkpoint.to_dict())
# {'epoch': 4, 'model_weights': [array([[-0.70056134],
# [-0.8839263 ],
# [-1.0043601 ],
# [-0.61634773]], dtype=float32), array([0.01889327], dtype=float32)], '_timestamp': 1656108446, '_preprocessor': None, '_current_checkpoint_id': 3}
.. _train-callbacks:
Callbacks
~~~~~~~~~
You may want to plug your training code into your favorite experiment management framework.
Ray AIR provides an interface to fetch intermediate results, along with callbacks to process or log those results
(the values passed into ``session.report(...)``).
Ray AIR contains :ref:`built-in callbacks <air-builtin-callbacks>` for popular tracking frameworks, or you can implement your own callback via the :ref:`Callback <tune-callbacks-docs>` interface.
Example: Logging to MLflow and TensorBoard
++++++++++++++++++++++++++++++++++++++++++
**Step 1: Install the necessary packages**
.. code-block:: bash
$ pip install mlflow
$ pip install tensorboardX
**Step 2: Run the following training script**
.. literalinclude:: /../../python/ray/train/examples/mlflow_simple_example.py
:language: python
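For reference, built-in callbacks are passed to the ``Trainer`` through ``RunConfig``. Below is a minimal sketch, assuming the MLflow callback import path used in recent Ray releases (older releases export it from ``ray.air.callbacks.mlflow``):

.. code-block:: python

    from ray.air import RunConfig, ScalingConfig
    from ray.air.integrations.mlflow import MLflowLoggerCallback
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(
            callbacks=[MLflowLoggerCallback(experiment_name="train_experiment")]
        ),
    )
    trainer.fit()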
.. _train-custom-callbacks:
Custom Callbacks
++++++++++++++++
If the provided callbacks do not cover your desired integrations or use-cases,
you may always implement a custom callback by subclassing ``Callback``. If
the callback is general enough, please feel welcome to :ref:`add it <getting-involved>`
to the ``ray`` `repository <https://github.com/ray-project/ray>`_.
A simple example for creating a callback that will print out results:
.. code-block:: python
from typing import List, Dict
from ray.air import session, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.logger import LoggerCallback
# LoggerCallback is a higher level API of Callback.
class LoggingCallback(LoggerCallback):
def __init__(self) -> None:
self.results = []
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
self.results.append(trial.last_result)
def train_func():
for i in range(3):
session.report({"epoch": i})
callback = LoggingCallback()
trainer = TorchTrainer(
train_func,
run_config=RunConfig(callbacks=[callback]),
scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()
print("\n".join([str(x) for x in callback.results]))
# {'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-28', 'timestamp': 1656349408, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}}
# {'epoch': 0, '_timestamp': 1656349412, '_time_this_iter_s': 0.0026497840881347656, '_training_iteration': 1, 'time_this_iter_s': 3.433483362197876, 'done': False, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 1, 'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-32', 'timestamp': 1656349412, 'time_total_s': 3.433483362197876, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 3.433483362197876, 'timesteps_since_restore': 0, 'iterations_since_restore': 1, 'warmup_time': 0.003779172897338867, 'experiment_tag': '0'}
# {'epoch': 1, '_timestamp': 1656349412, '_time_this_iter_s': 0.0013833045959472656, '_training_iteration': 2, 'time_this_iter_s': 0.016670703887939453, 'done': False, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 2, 'trial_id': '0f1d0_00000', 'experiment_id': '494a1d050b4a4d11aeabd87ba475fcd3', 'date': '2022-06-27_17-03-32', 'timestamp': 1656349412, 'time_total_s': 3.4501540660858154, 'pid': 23018, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 3.4501540660858154, 'timesteps_since_restore': 0, 'iterations_since_restore': 2, 'warmup_time': 0.003779172897338867, 'experiment_tag': '0'}
Example: PyTorch Distributed metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In real applications, you may want to calculate optimization metrics besides
accuracy and loss: recall, precision, Fbeta, etc.
Ray Train natively supports `TorchMetrics <https://torchmetrics.readthedocs.io/en/latest/>`_, which provides a collection of machine learning metrics for distributed, scalable PyTorch models.
Here is an example:
.. code-block:: python
from typing import List, Dict
from ray.air import session, ScalingConfig
from ray.train.torch import TorchTrainer
import torch
import torchmetrics
def train_func(config):
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
accuracy = torchmetrics.functional.accuracy(preds, target).item()
session.report({"accuracy": accuracy})
trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
result = trainer.fit()
print(result.metrics["accuracy"])
# 0.20000000298023224
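TorchMetrics also provides module-based metrics, which accumulate state across batches and synchronize it across workers when ``compute()`` is called. A minimal sketch (note that newer TorchMetrics releases require a ``task`` argument, e.g. ``task="multiclass", num_classes=5``):

.. code-block:: python

    from ray.air import session, ScalingConfig
    from ray.train.torch import TorchTrainer
    import torch
    import torchmetrics

    def train_func(config):
        # Module-based metrics accumulate state across ``update()`` calls and
        # sync that state across workers when ``compute()`` is called.
        accuracy = torchmetrics.Accuracy()
        preds = torch.randn(10, 5).softmax(dim=-1)
        target = torch.randint(5, (10,))
        accuracy.update(preds, target)
        session.report({"accuracy": accuracy.compute().item()})

    trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
    result = trainer.fit()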
.. Running on the cloud
.. --------------------
.. Use Ray Train with the Ray cluster launcher by changing the following:
.. .. code-block:: bash
.. ray up cluster.yaml
.. TODO.
.. _train-fault-tolerance:
Fault Tolerance & Elastic Training
----------------------------------
Ray Train has built-in fault tolerance to recover from worker failures (i.e.
``RayActorError``\s). When a failure is detected, the workers will be shut
down and new workers will be added in. The training function will be
restarted, but progress from the previous execution can be resumed through
checkpointing.
.. warning:: In order to retain progress when recovering, your training function
**must** implement logic for both saving *and* loading :ref:`checkpoints
<train-checkpointing>`.
Each instance of recovery from a worker failure is considered a retry. The
number of retries is configurable through the ``max_failures`` attribute of the
``FailureConfig`` passed as the ``failure_config`` argument of the ``RunConfig``
given to the ``Trainer``.
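For example, a minimal sketch that allows up to 3 recovery attempts (``train_func`` is assumed to be a training function that saves and loads checkpoints as described above):

.. code-block:: python

    from ray.air import FailureConfig, RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=2),
        # Retry up to 3 times on worker failures.
        # Setting max_failures=-1 retries indefinitely.
        run_config=RunConfig(failure_config=FailureConfig(max_failures=3)),
    )
    trainer.fit()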
.. note:: Elastic Training is not yet supported.
.. Running on pre-emptible machines
.. --------------------------------
.. You may want to
.. TODO.
.. We do not have a profiling callback in AIR as the execution engine has changed to Tune. The behavior of the callback can be replicated with checkpoints (do a trace, save it to checkpoint, it gets downloaded to driver every iteration).
.. .. _train-profiling:
.. Profiling
.. ---------
.. Ray Train comes with an integration with `PyTorch Profiler <https://pytorch.org/blog/introducing-pytorch-profiler-the-new-and-improved-performance-tool/>`_.
.. Specifically, it comes with a :ref:`TorchWorkerProfiler <train-api-torch-worker-profiler>` utility class and :ref:`train-api-torch-tensorboard-profiler-callback` callback
.. that allow you to use the PyTorch Profiler as you would in a non-distributed PyTorch script, and synchronize the generated Tensorboard traces onto
.. the disk that from which your script was executed from.
.. **Step 1: Update training function with** ``TorchWorkerProfiler``
.. .. code-block:: bash
.. from ray.train.torch import TorchWorkerProfiler
.. def train_func():
.. twp = TorchWorkerProfiler()
.. with profile(..., on_trace_ready=twp.trace_handler) as p:
.. ...
.. profile_results = twp.get_and_clear_profile_traces()
.. train.report(..., **profile_results)
.. ...
.. **Step 2: Run training function with** ``TorchTensorboardProfilerCallback``
.. .. code-block:: python
.. from ray.train import Trainer
.. from ray.train.callbacks import TorchTensorboardProfilerCallback
.. trainer = Trainer(backend="torch", num_workers=2)
.. trainer.start()
.. trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
.. trainer.shutdown()
.. **Step 3: Visualize the logs**
.. .. code-block:: bash
.. # Navigate to the run directory of the trainer.
.. # For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001/pytorch_profiler`
.. $ cd <TRAINER_RUN_DIR>/pytorch_profiler
.. # Install the PyTorch Profiler TensorBoard Plugin.
.. $ pip install torch_tb_profiler
.. # Star the TensorBoard UI.
.. $ tensorboard --logdir .
.. # View the PyTorch Profiler traces.
.. $ open http://localhost:6006/#pytorch_profiler
.. _train-tune:
Hyperparameter tuning (Ray Tune)
--------------------------------
Hyperparameter tuning with :ref:`Ray Tune <tune-main>` is natively supported
with Ray Train. Specifically, you can take an existing ``Trainer`` and simply
pass it into a :class:`Tuner`.
.. code-block:: python
from ray import tune
from ray.air import session, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.tuner import Tuner, TuneConfig
def train_func(config):
# In this example, nothing is expected to change over epochs,
# and the output metric is equivalent to the input value.
for _ in range(config["num_epochs"]):
session.report(dict(output=config["input"]))
trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
tuner = Tuner(
trainer,
param_space={
"train_loop_config": {
"num_epochs": 2,
"input": tune.grid_search([1, 2, 3]),
}
},
tune_config=TuneConfig(num_samples=5, metric="output", mode="max"),
)
result_grid = tuner.fit()
print(result_grid.get_best_result().metrics["output"])
# 3
.. _torch-amp:
Automatic Mixed Precision
-------------------------
Automatic mixed precision (AMP) lets you train your models faster by using a lower
precision datatype for operations like linear layers and convolutions.
.. tabbed:: PyTorch
You can train your Torch model with AMP by:
1. Adding ``train.torch.accelerate(amp=True)`` to the top of your training function.
2. Wrapping your optimizer with ``train.torch.prepare_optimizer``.
3. Replacing your backward call with ``train.torch.backward``.
.. code-block:: diff
def train_func():
+ train.torch.accelerate(amp=True)
model = NeuralNetwork()
model = train.torch.prepare_model(model)
data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
data_loader = train.torch.prepare_data_loader(data_loader)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ optimizer = train.torch.prepare_optimizer(optimizer)
model.train()
for epoch in range(90):
for images, targets in data_loader:
optimizer.zero_grad()
outputs = model(images)
loss = torch.nn.functional.cross_entropy(outputs, targets)
- loss.backward()
+ train.torch.backward(loss)
optimizer.step()
...
.. note:: The performance of AMP varies based on GPU architecture, model type,
and data shape. For certain workflows, AMP may perform worse than
full-precision training.
.. _train-reproducibility:
Reproducibility
---------------
.. tabbed:: PyTorch
To limit sources of nondeterministic behavior, add
``train.torch.enable_reproducibility()`` to the top of your training
function.
.. code-block:: diff
def train_func():
+ train.torch.enable_reproducibility()
model = NeuralNetwork()
model = train.torch.prepare_model(model)
...
.. warning:: ``train.torch.enable_reproducibility`` can't guarantee
completely reproducible results across executions. To learn more, read
the `PyTorch notes on randomness <https://pytorch.org/docs/stable/notes/randomness.html>`_.
..
import ray
from ray import tune
def training_func(config):
dataloader = ray.train.get_dataset()\
.get_shard(torch.rank())\
.iter_torch_batches(batch_size=config["batch_size"])
for i in config["epochs"]:
ray.train.report(...) # use same intermediate reporting API
# Declare the specification for training.
trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
dataset = ray.dataset.window()
# Convert this to a trainable.
trainable = trainer.to_tune_trainable(training_func, dataset=dataset)
tuner = tune.Tuner(trainable,
param_space={"lr": tune.uniform(), "batch_size": tune.randint(1, 2, 3)},
tune_config=tune.TuneConfig(num_samples=12))
results = tuner.fit()
..
Advanced APIs
-------------
TODO
Training Run Iterator API
~~~~~~~~~~~~~~~~~~~~~~~~~
TODO
Stateful Class API
~~~~~~~~~~~~~~~~~~
TODO