[ml/train] Training Interfaces [2/4]: Update interface for Trainer (#22986)

Amog Kamsetty 2022-03-13 18:09:50 -07:00 committed by GitHub
parent f673acb0ad
commit 86b79b68be
3 changed files with 210 additions and 4 deletions

@@ -0,0 +1,59 @@
# flake8: noqa
# TODO(amog): Add this to CI once Trainer has been implemented.
# TODO(rliaw): Include this in the docs.

# fmt: off
# __custom_trainer_begin__
import torch

from ray.ml.trainer import Trainer
from ray import tune


class MyPytorchTrainer(Trainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.train_dataset has already been preprocessed by self.preprocessor
        dataset = self.train_dataset

        torch_ds = dataset.to_torch()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for X, y in iter(torch_ds):
                # Compute prediction error
                pred = self.model(X)
                batch_loss = torch.nn.MSELoss()(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate
            # results.
            tune.report(loss=loss, epoch=epoch_idx)
# __custom_trainer_end__
# fmt: on


# fmt: off
# __custom_trainer_usage_begin__
import ray

train_dataset = ray.data.from_items([1, 2, 3])

my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
result = my_trainer.fit()
# __custom_trainer_usage_end__
# fmt: on
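
As a follow-up to the usage example above, a minimal sketch of inspecting the returned ``Result`` object; the ``metrics`` and ``checkpoint`` attribute names are assumptions about ``ray.ml.result.Result`` rather than something shown in this commit:

# Hypothetical continuation of the usage example above (not part of this diff).
# The attribute names on Result are assumptions, not confirmed by this commit.
print(result.metrics)      # e.g. the last reported values, such as {"loss": ..., "epoch": 9}
print(result.checkpoint)   # the last checkpoint saved during training, if any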

@@ -8,6 +8,7 @@ from ray.ml.result import Result
from ray.ml.config import RunConfig, ScalingConfig
from ray.tune import Trainable
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.data import Dataset
@@ -27,17 +28,103 @@ class TrainingFailedError(RuntimeError):
    pass


@DeveloperAPI
class Trainer(abc.ABC):
    """Defines interface for distributed training on Ray.
    Note: The base ``Trainer`` class cannot be instantiated directly. Only
    one of its subclasses can be used.

    How does a trainer work?

    - First, initialize the Trainer. The initialization runs locally,
      so heavyweight setup should not be done in ``__init__``.
    - Then, when you call ``trainer.fit()``, the Trainer is serialized
      and copied to a remote Ray actor. The following methods are then
      called in sequence on the remote actor:

      - ``trainer.setup()``: Any heavyweight Trainer setup should be
        specified here.
      - ``trainer.preprocess_datasets()``: The provided
        ray.data.Dataset instances are preprocessed with the provided
        ray.ml.preprocessor.
      - ``trainer.training_loop()``: Executes the main training logic.

    - Calling ``trainer.fit()`` will return a ``ray.ml.result.Result``
      object where you can access metrics from your training run, as well
      as any checkpoints that may have been saved.

    How do I create a new ``Trainer``?

    Subclass ``ray.ml.trainer.Trainer`` and override the ``training_loop``
    method, and optionally ``setup``.
    Example:

    .. code-block:: python

        import torch

        from ray.ml.trainer import Trainer
        from ray import tune

        class MyPytorchTrainer(Trainer):
            def setup(self):
                self.model = torch.nn.Linear(1, 1)
                self.optimizer = torch.optim.SGD(
                    self.model.parameters(), lr=0.1)

            def training_loop(self):
                # You can access any Trainer attributes directly in this method.
                # self.train_dataset has already been preprocessed by
                # self.preprocessor
                dataset = self.train_dataset

                torch_ds = dataset.to_torch()

                for epoch_idx in range(10):
                    loss = 0
                    num_batches = 0
                    for X, y in iter(torch_ds):
                        # Compute prediction error
                        pred = self.model(X)
                        batch_loss = torch.nn.MSELoss()(pred, y)

                        # Backpropagation
                        self.optimizer.zero_grad()
                        batch_loss.backward()
                        self.optimizer.step()

                        loss += batch_loss.item()
                        num_batches += 1
                    loss /= num_batches

                    # Use Tune functions to report intermediate
                    # results.
                    tune.report(loss=loss, epoch=epoch_idx)
    How do I use an existing ``Trainer`` or one of my custom Trainers?

    Initialize the Trainer, and call ``Trainer.fit()``.

    Example:

    .. code-block:: python

        import ray

        train_dataset = ray.data.from_items([1, 2, 3])

        my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
        result = my_trainer.fit()
    Args:
        scaling_config: Configuration for how to scale training.
        run_config: Configuration for the execution of the training run.
        train_dataset: Either a distributed Ray :ref:`Dataset <dataset-api>`
            or a Callable that returns a Dataset, to use for training. If a
            ``preprocessor`` is also provided, it will be fit on this
            dataset and this dataset will be transformed.
        extra_datasets: Any extra Datasets (such as validation or test
            datasets) to use for training. If a ``preprocessor`` is
            provided, the datasets specified here will only be transformed,
            and not fit on.
        preprocessor: A preprocessor to preprocess the provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.
    """
@@ -54,6 +141,67 @@ class Trainer(abc.ABC):
        raise NotImplementedError
    def setup(self) -> None:
        """Called during fit() to perform initial setup on the Trainer.

        Note: this method is run on a remote process.

        This method will not be called on the driver, so any expensive setup
        operations should be placed here and not in ``__init__``.

        This method is called prior to ``preprocess_datasets`` and
        ``training_loop``.
        """
        raise NotImplementedError
    def preprocess_datasets(self) -> None:
        """Called during fit() to preprocess dataset attributes with preprocessor.

        Note: This method is run on a remote process.

        This method is called prior to entering the training_loop.

        If the ``Trainer`` has both a train_dataset and preprocessor, and the
        preprocessor has not yet been fit, then it will be fit on the
        train_dataset.

        Then, the Trainer's train_dataset and any extra_datasets will be
        transformed by its preprocessor.

        The transformed datasets will be set back in the
        ``self.train_dataset`` and ``self.extra_datasets`` attributes to be
        used when overriding ``training_loop``.
        """
        raise NotImplementedError
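    # A possible concrete implementation of the behavior described in the
    # docstring above, shown only as a commented sketch: the ``fit_transform``
    # and ``transform`` method names on the preprocessor, and ``extra_datasets``
    # being a dict of name -> Dataset, are assumptions not confirmed by this
    # diff.
    #
    #   def preprocess_datasets(self) -> None:
    #       if self.preprocessor and self.train_dataset:
    #           # Fit the preprocessor on the training dataset and transform it.
    #           self.train_dataset = self.preprocessor.fit_transform(
    #               self.train_dataset)
    #       if self.preprocessor:
    #           # Extra datasets are only transformed, never fit on.
    #           self.extra_datasets = {
    #               name: self.preprocessor.transform(ds)
    #               for name, ds in self.extra_datasets.items()
    #           }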
    @abc.abstractmethod
    def training_loop(self) -> None:
        """Loop called by fit() to run training and report results to Tune.

        Note: this method runs on a remote process.

        ``self.train_dataset`` and the Dataset values in ``self.extra_datasets``
        have already been preprocessed by ``self.preprocessor``.

        You can use the :ref:`Tune Function API functions <tune-function-docstring>`
        (``tune.report()`` and ``tune.save_checkpoint()``) inside
        this training loop.

        Example:

        .. code-block:: python

            from ray import tune
            from ray.ml.trainer import Trainer

            class MyTrainer(Trainer):
                def training_loop(self):
                    for epoch_idx in range(5):
                        ...
                        tune.report(epoch=epoch_idx)
        """
        raise NotImplementedError
    @PublicAPI(stability="alpha")
    def fit(self) -> Result:
        """Runs training.
@@ -66,7 +214,6 @@ class Trainer(abc.ABC):
        """
        raise NotImplementedError
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""
        raise NotImplementedError
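
To illustrate how ``as_trainable()`` is meant to bridge into the existing Tune API, here is a minimal sketch building on the ``MyPytorchTrainer`` example above; passing the returned Trainable class directly to ``tune.run()`` is an assumption about intended usage, not something shown in this diff.

import ray
from ray import tune

train_dataset = ray.data.from_items([1, 2, 3])
my_trainer = MyPytorchTrainer(train_dataset=train_dataset)

# Convert the Trainer into a Tune Trainable class and launch it through Tune,
# which makes Tune features (for example, running multiple trials) available.
trainable = my_trainer.as_trainable()
analysis = tune.run(trainable)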