[ml/train] Training Interfaces [2/4]: Update interface for Trainer (#22986)

This commit is contained in:
Amog Kamsetty 2022-03-13 18:09:50 -07:00 committed by GitHub
parent f673acb0ad
commit 86b79b68be
3 changed files with 210 additions and 4 deletions

@@ -0,0 +1,59 @@
# flake8: noqa
# TODO(amog): Add this to CI once Trainer has been implemented.
# TODO(rliaw): Include this in the docs.
# fmt: off
# __custom_trainer_begin__
import torch
from ray.ml.trainer import Trainer
from ray import tune
class MyPytorchTrainer(Trainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.train_dataset has already been preprocessed by self.preprocessor.
        dataset = self.train_dataset

        torch_ds = dataset.to_torch()

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for X, y in iter(torch_ds):
                # Compute prediction error.
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation.
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate results.
            tune.report(loss=loss, epoch=epoch_idx)
# __custom_trainer_end__
# fmt: on
# fmt: off
# __custom_trainer_usage_begin__
import ray
train_dataset = ray.data.from_items([1, 2, 3])
my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
result = my_trainer.fit()
# __custom_trainer_usage_end__
# fmt: on
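
As a brief, hedged follow-up to the usage snippet above (not part of this commit): the object returned by ``fit()`` is documented elsewhere in this change as a ``Result``; the ``metrics`` and ``checkpoint`` attribute names below are assumptions about that API, used only for illustration.

# Hypothetical inspection of the fit() result; attribute names are assumed.
print(result.metrics)     # e.g. the loss/epoch values reported via tune.report()
print(result.checkpoint)  # final checkpoint from the run, if any was saved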

@@ -8,6 +8,7 @@ from ray.ml.result import Result
from ray.ml.config import RunConfig, ScalingConfig
from ray.tune import Trainable
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.data import Dataset
@@ -27,17 +28,103 @@ class TrainingFailedError(RuntimeError):
pass
@PublicAPI(stability="alpha")
@DeveloperAPI
class Trainer(abc.ABC):
"""Defines interface for distributed training on Ray.
Note: The base ``Trainer`` class cannot be instantiated directly. Only
one of its subclasses can be used.
How does a trainer work?
- First, initialize the Trainer. The initialization runs locally,
so heavyweight setup should not be done in __init__.
- Then, when you call ``trainer.fit()``, the Trainer is serialized
and copied to a remote Ray actor. The following methods are then
called in sequence on the remote actor.
- ``trainer.setup()``: Any heavyweight Trainer setup should be
specified here.
- ``trainer.preprocess_datasets()``: The provided
ray.data.Dataset instances are preprocessed with the provided
ray.ml.preprocessor.
- ``trainer.training_loop()``: Executes the main training logic.
- Calling ``trainer.fit()`` will return a ``ray.ml.result.Result``
object where you can access metrics from your training run, as well
as any checkpoints that may have been saved.
How do I create a new ``Trainer``?
Subclass ``ray.ml.trainer.Trainer`` and override the ``training_loop``
method, and optionally ``setup``.
Example:
.. code-block:: python
import torch
from ray.ml.trainer import Trainer
from ray import tune
class MyPytorchTrainer(Trainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this
        # method. self.train_dataset has already been preprocessed
        # by self.preprocessor.
        dataset = self.train_dataset

        torch_ds = dataset.to_torch()

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for X, y in iter(torch_ds):
                # Compute prediction error.
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation.
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate results.
            tune.report(loss=loss, epoch=epoch_idx)
How do I use an existing ``Trainer`` or one of my custom Trainers?
Initialize the Trainer, and call ``Trainer.fit()``.
Example:
.. code-block:: python
import ray
train_dataset = ray.data.from_items([1, 2, 3])
my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
result = my_trainer.fit()
Args:
scaling_config: Configuration for how to scale training.
run_config: Configuration for the execution of the training run.
train_dataset: Either a distributed Ray :ref:`Dataset <dataset-api>`
    or a Callable that returns a Dataset to use for training. If a
    ``preprocessor`` is also provided, it will be fit on this
    dataset and this dataset will be transformed.
extra_datasets: Any extra Datasets (such as validation or test
    datasets) to use for training. If a ``preprocessor`` is
    provided, the datasets specified here will only be transformed,
    and not fit on.
preprocessor: A preprocessor to preprocess the provided datasets.
resume_from_checkpoint: A checkpoint to resume training from.
"""
@@ -54,6 +141,67 @@ class Trainer(abc.ABC):
raise NotImplementedError
def setup(self) -> None:
"""Called during fit() to perform initial setup on the Trainer.
Note: this method is run on a remote process.
This method will not be called on the driver, so any expensive setup
operations should be placed here and not in ``__init__``.
This method is called prior to ``preprocess_datasets`` and
``training_loop``.
"""
raise NotImplementedError
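# Illustrative sketch, not part of this diff: roughly the order in which the
# remote actor created by ``fit()`` is described to call the methods defined
# here. The helper name ``_example_remote_run`` is hypothetical.
def _example_remote_run(trainer: "Trainer") -> None:
    trainer.setup()                # heavyweight setup, runs on the remote actor
    trainer.preprocess_datasets()  # fit/transform datasets with the preprocessor
    trainer.training_loop()        # main training logic, reports via Tune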
def preprocess_datasets(self) -> None:
"""Called during fit() to preprocess dataset attributes with preprocessor.
Note: This method is run on a remote process.
This method is called prior to entering the training_loop.
If the ``Trainer`` has both a train_dataset and
preprocessor, and the preprocessor has not yet been fit, then it
will be fit on the train_dataset.
Then, the Trainer's train_dataset and any extra_datasets
will be transformed by its preprocessor.
The transformed datasets will be assigned back to the
``self.train_dataset`` and ``self.extra_datasets`` attributes, to be
used when overriding ``training_loop``.
"""
raise NotImplementedError
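# Illustrative sketch, not part of this diff: the preprocessing semantics
# described in the docstring above, written out as code. The preprocessor
# methods ``fit_transform``/``transform``, the ``check_is_fitted`` helper, and
# ``extra_datasets`` being a name-to-Dataset dict are all assumptions.
def _example_preprocess_datasets(trainer: "Trainer") -> None:
    if trainer.preprocessor and trainer.train_dataset:
        if not trainer.preprocessor.check_is_fitted():
            # Fit the preprocessor on the train dataset, then transform it.
            trainer.train_dataset = trainer.preprocessor.fit_transform(
                trainer.train_dataset)
        else:
            trainer.train_dataset = trainer.preprocessor.transform(
                trainer.train_dataset)
        # Extra datasets are only transformed, never fit on.
        trainer.extra_datasets = {
            name: trainer.preprocessor.transform(ds)
            for name, ds in trainer.extra_datasets.items()
        }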
@abc.abstractmethod
def training_loop(self) -> None:
"""Loop called by fit() to run training and report results to Tune.
Note: this method runs on a remote process.
`self.train_dataset` and the Dataset values in `self.extra_datasets`
have already been preprocessed by `self.preprocessor`.
You can use the :ref:`Tune Function API functions <tune-function-docstring>`
(``tune.report()`` and ``tune.save_checkpoint()``) inside
this training loop.
Example:
.. code-block:: python

from ray import tune
from ray.ml.trainer import Trainer

class MyTrainer(Trainer):
    def training_loop(self):
        for epoch_idx in range(5):
            ...
            tune.report(epoch=epoch_idx)
"""
raise NotImplementedError
@PublicAPI(stability="alpha")
def fit(self) -> Result:
"""Runs training.
@@ -66,7 +214,6 @@ class Trainer(abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def as_trainable(self) -> Type[Trainable]:
"""Convert self to a ``tune.Trainable`` class."""
raise NotImplementedError
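
As a hedged illustration of how ``as_trainable`` might be used (not shown in this commit): the returned ``tune.Trainable`` class can be handed to Ray Tune, e.g. for hyperparameter search. The exact ``tune.run`` arguments below are only an example, and ``my_trainer`` refers to the usage snippet earlier in this change.

# Hypothetical usage sketch; not taken from this commit.
from ray import tune

trainable_cls = my_trainer.as_trainable()
analysis = tune.run(trainable_cls, num_samples=2)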