[ml/train] Training Interfaces [2/4]: Update interface for Trainer (#22986)

parent: f673acb0ad
commit: 86b79b68be

3 changed files with 210 additions and 4 deletions
python/ray/ml/train/examples/__init__.py        (new file, 0 lines)
python/ray/ml/train/examples/custom_trainer.py  (new file, 59 lines)
python/ray/ml/train/examples/custom_trainer.py (new file):

@@ -0,0 +1,59 @@
# flake8: noqa

# TODO(amog): Add this to CI once Trainer has been implemented.
# TODO(rliaw): Include this in the docs.

# fmt: off
# __custom_trainer_begin__
import torch

from ray.ml.trainer import Trainer
from ray import tune


class MyPytorchTrainer(Trainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.train_dataset has already been preprocessed by self.preprocessor
        dataset = self.train_dataset

        torch_ds = dataset.to_torch()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for X, y in iter(torch_ds):
                # Compute prediction error. Note: MSELoss is a module, so
                # instantiate it before calling it on (pred, y).
                pred = self.model(X)
                batch_loss = torch.nn.MSELoss()(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate results.
            tune.report(loss=loss, epoch=epoch_idx)
# __custom_trainer_end__
# fmt: on


# fmt: off
# __custom_trainer_usage_begin__
import ray

train_dataset = ray.data.from_items([1, 2, 3])
my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
result = my_trainer.fit()
# __custom_trainer_usage_end__
# fmt: on
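For orientation, the loop above can be exercised without a Ray cluster. Below is a minimal sketch, assuming a toy in-memory TensorDataset in place of ``dataset.to_torch()``; the data, shapes, and the print call are illustrative stand-ins, not part of the PR.

    # Sketch: the same per-epoch loss averaging as training_loop() above,
    # in plain PyTorch. The toy dataset (y = 2x) is assumed for illustration.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()

    # Toy regression data shaped (N, 1) to match the Linear(1, 1) layer.
    X = torch.arange(8, dtype=torch.float32).reshape(-1, 1)
    y = 2 * X
    loader = DataLoader(TensorDataset(X, y), batch_size=4)

    for epoch_idx in range(10):
        loss = 0.0
        num_batches = 0
        for X_batch, y_batch in loader:
            pred = model(X_batch)
            batch_loss = loss_fn(pred, y_batch)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss += batch_loss.item()
            num_batches += 1
        loss /= num_batches
        print(f"epoch={epoch_idx} loss={loss:.4f}")  # stands in for tune.report()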
@@ -8,6 +8,7 @@ from ray.ml.result import Result
 from ray.ml.config import RunConfig, ScalingConfig
 from ray.tune import Trainable
 from ray.util import PublicAPI
+from ray.util.annotations import DeveloperAPI

 if TYPE_CHECKING:
     from ray.data import Dataset
@@ -27,17 +28,103 @@ class TrainingFailedError(RuntimeError):
     pass


-@PublicAPI(stability="alpha")
+@DeveloperAPI
 class Trainer(abc.ABC):
     """Defines interface for distributed training on Ray.
+
+    Note: The base ``Trainer`` class cannot be instantiated directly. Only
+    one of its subclasses can be used.
+
+    How does a trainer work?
+
+    - First, initialize the Trainer. The initialization runs locally,
+      so heavyweight setup should not be done in ``__init__``.
+    - Then, when you call ``trainer.fit()``, the Trainer is serialized
+      and copied to a remote Ray actor. The following methods are then
+      called in sequence on the remote actor:
+
+      - ``trainer.setup()``: Any heavyweight Trainer setup should be
+        specified here.
+      - ``trainer.preprocess_datasets()``: The provided
+        ray.data.Dataset instances are preprocessed with the provided
+        ray.ml.preprocessor.
+      - ``trainer.training_loop()``: Executes the main training logic.
+
+    - Calling ``trainer.fit()`` will return a ``ray.result.Result``
+      object where you can access metrics from your training run, as well
+      as any checkpoints that may have been saved.
+
+    How do I create a new ``Trainer``?
+
+    Subclass ``ray.ml.trainer.Trainer``, override the ``training_loop``
+    method, and optionally ``setup``.
+
+    Example:
+
+    .. code-block:: python
+
+        import torch
+
+        from ray.ml.trainer import Trainer
+        from ray import tune
+
+
+        class MyPytorchTrainer(Trainer):
+            def setup(self):
+                self.model = torch.nn.Linear(1, 1)
+                self.optimizer = torch.optim.SGD(
+                    self.model.parameters(), lr=0.1)
+
+            def training_loop(self):
+                # You can access any Trainer attributes directly in this
+                # method. self.train_dataset has already been preprocessed
+                # by self.preprocessor.
+                dataset = self.train_dataset
+
+                torch_ds = dataset.to_torch()
+
+                for epoch_idx in range(10):
+                    loss = 0
+                    num_batches = 0
+                    for X, y in iter(torch_ds):
+                        # Compute prediction error
+                        pred = self.model(X)
+                        batch_loss = torch.nn.MSELoss()(pred, y)
+
+                        # Backpropagation
+                        self.optimizer.zero_grad()
+                        batch_loss.backward()
+                        self.optimizer.step()
+
+                        loss += batch_loss.item()
+                        num_batches += 1
+                    loss /= num_batches
+
+                    # Use Tune functions to report intermediate results.
+                    tune.report(loss=loss, epoch=epoch_idx)
+
+    How do I use an existing ``Trainer`` or one of my custom Trainers?
+
+    Initialize the Trainer and call ``Trainer.fit()``.
+
+    Example:
+
+    .. code-block:: python
+
+        import ray
+
+        train_dataset = ray.data.from_items([1, 2, 3])
+        my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
+        result = my_trainer.fit()
+
     Args:
         scaling_config: Configuration for how to scale training.
         run_config: Configuration for the execution of the training run.
         train_dataset: Either a distributed Ray :ref:`Dataset <dataset-api>`
-            or a Callable that returns a Dataset to use for training.
+            or a Callable that returns a Dataset, to use for training. If a
+            ``preprocessor`` is also provided, it will be fit on this
+            dataset and this dataset will be transformed.
         extra_datasets: Any extra Datasets (such as validation or test
-            datasets) to use for training.
+            datasets) to use for training. If a ``preprocessor`` is
+            provided, the datasets specified here will only be transformed,
+            and not fit on.
         preprocessor: A preprocessor to preprocess the provided datasets.
         resume_from_checkpoint: A checkpoint to resume training from.
     """
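The lifecycle in the docstring above is worth seeing end to end. Below is a minimal local sketch of the documented hook order; ``SketchTrainer`` is hypothetical and only demonstrates the sequence, whereas Ray's actual ``fit()`` runs these methods on a remote actor and returns a ``ray.result.Result``.

    # Hypothetical stand-in, not Ray code: demonstrates the documented order
    # setup() -> preprocess_datasets() -> training_loop(), driven by fit().
    class SketchTrainer:
        def fit(self):
            self.setup()                # heavyweight initialization
            self.preprocess_datasets()  # fit/transform datasets
            self.training_loop()        # main training logic

        def setup(self):
            print("1. setup()")

        def preprocess_datasets(self):
            print("2. preprocess_datasets()")

        def training_loop(self):
            print("3. training_loop()")


    SketchTrainer().fit()
    # Prints: 1. setup(), then 2. preprocess_datasets(), then 3. training_loop()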
@@ -54,6 +141,67 @@ class Trainer(abc.ABC):
 
         raise NotImplementedError
 
+    def setup(self) -> None:
+        """Called during fit() to perform initial setup on the Trainer.
+
+        Note: this method is run on a remote process.
+
+        This method will not be called on the driver, so any expensive setup
+        operations should be placed here and not in ``__init__``.
+
+        This method is called prior to ``preprocess_datasets`` and
+        ``training_loop``.
+        """
+        raise NotImplementedError
+
+    def preprocess_datasets(self) -> None:
+        """Called during fit() to preprocess dataset attributes with preprocessor.
+
+        Note: this method is run on a remote process.
+
+        This method is called prior to entering the training loop.
+
+        If the ``Trainer`` has both a train_dataset and a preprocessor, and
+        the preprocessor has not yet been fit, then it will be fit on the
+        train_dataset.
+
+        Then, the Trainer's train_dataset and any extra_datasets will be
+        transformed by its preprocessor.
+
+        The transformed datasets will be set back in the
+        ``self.train_dataset`` and ``self.extra_datasets`` attributes to be
+        used when overriding ``training_loop``.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def training_loop(self) -> None:
+        """Loop called by fit() to run training and report results to Tune.
+
+        Note: this method runs on a remote process.
+
+        ``self.train_dataset`` and the Dataset values in
+        ``self.extra_datasets`` have already been preprocessed by
+        ``self.preprocessor``.
+
+        You can use the :ref:`Tune Function API functions
+        <tune-function-docstring>` (``tune.report()`` and
+        ``tune.save_checkpoint()``) inside this training loop.
+
+        Example:
+
+        .. code-block:: python
+
+            from ray.ml.trainer import Trainer
+            from ray import tune
+
+
+            class MyTrainer(Trainer):
+                def training_loop(self):
+                    for epoch_idx in range(5):
+                        ...
+                        tune.report(epoch=epoch_idx)
+        """
+        raise NotImplementedError
+
+    @PublicAPI(stability="alpha")
     def fit(self) -> Result:
         """Runs training.
 
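The fit-then-transform contract in the ``preprocess_datasets`` docstring can be summarized in code. This is a hedged sketch under assumptions: the minimal ``SimplePreprocessor`` (with a ``fitted`` flag) and the dict shape of ``extra_datasets`` are illustrative stand-ins, not the actual ``ray.ml.preprocessor`` API.

    import types

    class SimplePreprocessor:
        """Hypothetical preprocessor stand-in with fit/transform."""

        def __init__(self):
            self.fitted = False

        def fit(self, dataset):
            self.fitted = True  # e.g. compute normalization stats here
            return self

        def transform(self, dataset):
            return dataset  # e.g. apply the fitted stats here

    def preprocess_datasets(trainer):
        # Fit the preprocessor on the train dataset only, and only once.
        if trainer.preprocessor is not None and not trainer.preprocessor.fitted:
            trainer.preprocessor.fit(trainer.train_dataset)

        # Transform the train dataset and set it back on the Trainer.
        trainer.train_dataset = trainer.preprocessor.transform(trainer.train_dataset)

        # Extra datasets are transformed only, never fit on.
        trainer.extra_datasets = {
            name: trainer.preprocessor.transform(ds)
            for name, ds in trainer.extra_datasets.items()
        }

    trainer = types.SimpleNamespace(
        train_dataset=["raw"],
        extra_datasets={"valid": ["raw"]},
        preprocessor=SimplePreprocessor(),
    )
    preprocess_datasets(trainer)
    assert trainer.preprocessor.fitted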
@@ -66,7 +214,6 @@ class Trainer(abc.ABC):
         """
         raise NotImplementedError
 
-    @abc.abstractmethod
     def as_trainable(self) -> Type[Trainable]:
         """Convert self to a ``tune.Trainable`` class."""
         raise NotImplementedError
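Dropping ``@abc.abstractmethod`` makes ``as_trainable`` optional to override, but it remains the bridge into Ray Tune. A hedged usage sketch, assuming ``MyPytorchTrainer`` and ``train_dataset`` from the example file above and that the returned class is accepted by ``tune.run()`` like any other Trainable:

    from ray import tune

    # Assumes MyPytorchTrainer and train_dataset from custom_trainer.py above.
    trainer = MyPytorchTrainer(train_dataset=train_dataset)

    # Convert to a tune.Trainable subclass and let Tune schedule trials over it,
    # e.g. for hyperparameter search.
    trainable_cls = trainer.as_trainable()
    analysis = tune.run(trainable_cls, num_samples=2)
    print(analysis.results_df)  # metrics reported via tune.report()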