[ml/train] Training Interfaces [2/4]: Update interface for Trainer (#22986)

parent f673acb0ad
commit 86b79b68be

3 changed files with 210 additions and 4 deletions
python/ray/ml/train/examples/__init__.py        (new file, 0 lines)
python/ray/ml/train/examples/custom_trainer.py  (new file, 59 lines)
python/ray/ml/trainer.py                        (modified)
python/ray/ml/train/examples/custom_trainer.py

@@ -0,0 +1,59 @@
# flake8: noqa
# TODO(amog): Add this to CI once Trainer has been implemented.
# TODO(rliaw): Include this in the docs.

# fmt: off
# __custom_trainer_begin__
import torch

from ray.ml.trainer import Trainer
from ray import tune


class MyPytorchTrainer(Trainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.train_dataset has already been preprocessed by self.preprocessor
        dataset = self.train_dataset

        torch_ds = dataset.to_torch()

        # MSELoss is a module: instantiate it once, then call it on
        # (prediction, target) pairs.
        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for X, y in iter(torch_ds):
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate results.
            tune.report(loss=loss, epoch=epoch_idx)
# __custom_trainer_end__
# fmt: on


# fmt: off
# __custom_trainer_usage_begin__
import ray

train_dataset = ray.data.from_items([1, 2, 3])
my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
result = my_trainer.fit()
# __custom_trainer_usage_end__
# fmt: on
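A note on the usage block above: fit() returns a Result object. A hedged
sketch of inspecting it, assuming the ``metrics`` and ``checkpoint``
attribute names (the trainer.py docstring below only promises access to
metrics and any saved checkpoints, not these exact names):

    print(result.metrics)     # assumed attribute: last values passed to tune.report()
    print(result.checkpoint)  # assumed attribute: None here, since this loop never checkpoints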
python/ray/ml/trainer.py

@@ -8,6 +8,7 @@ from ray.ml.result import Result
from ray.ml.config import RunConfig, ScalingConfig
from ray.tune import Trainable
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    from ray.data import Dataset
@@ -27,17 +28,103 @@ class TrainingFailedError(RuntimeError):
    pass


@PublicAPI(stability="alpha")
@DeveloperAPI
class Trainer(abc.ABC):
    """Defines interface for distributed training on Ray.

    Note: The base ``Trainer`` class cannot be instantiated directly. Only
    one of its subclasses can be used.

    How does a trainer work?

    - First, initialize the Trainer. The initialization runs locally,
      so heavyweight setup should not be done in __init__.
    - Then, when you call ``trainer.fit()``, the Trainer is serialized
      and copied to a remote Ray actor. The following methods are then
      called in sequence on the remote actor.
        - ``trainer.setup()``: Any heavyweight Trainer setup should be
          specified here.
        - ``trainer.preprocess_datasets()``: The provided
          ray.data.Datasets are preprocessed with the provided
          ray.ml.preprocessor.
        - ``trainer.training_loop()``: Executes the main training logic.
    - Calling ``trainer.fit()`` will return a ``ray.ml.result.Result``
      object where you can access metrics from your training run, as well
      as any checkpoints that may have been saved.

    How do I create a new ``Trainer``?

    Subclass ``ray.ml.trainer.Trainer``, override the ``training_loop``
    method, and optionally ``setup``.

    Example:

        .. code-block:: python

            import torch

            from ray.ml.trainer import Trainer
            from ray import tune


            class MyPytorchTrainer(Trainer):
                def setup(self):
                    self.model = torch.nn.Linear(1, 1)
                    self.optimizer = torch.optim.SGD(
                        self.model.parameters(), lr=0.1)

                def training_loop(self):
                    # You can access any Trainer attributes directly in
                    # this method. self.train_dataset has already been
                    # preprocessed by self.preprocessor
                    dataset = self.train_dataset

                    torch_ds = dataset.to_torch()

                    # Instantiate the loss module once, then call it on
                    # (prediction, target) pairs.
                    loss_fn = torch.nn.MSELoss()

                    for epoch_idx in range(10):
                        loss = 0
                        num_batches = 0
                        for X, y in iter(torch_ds):
                            # Compute prediction error
                            pred = self.model(X)
                            batch_loss = loss_fn(pred, y)

                            # Backpropagation
                            self.optimizer.zero_grad()
                            batch_loss.backward()
                            self.optimizer.step()

                            loss += batch_loss.item()
                            num_batches += 1
                        loss /= num_batches

                        # Use Tune functions to report intermediate
                        # results.
                        tune.report(loss=loss, epoch=epoch_idx)

    How do I use an existing ``Trainer`` or one of my custom Trainers?

    Initialize the Trainer and call ``Trainer.fit()``.

    Example:
        .. code-block:: python

            import ray

            train_dataset = ray.data.from_items([1, 2, 3])
            my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
            result = my_trainer.fit()

    Args:
        scaling_config: Configuration for how to scale training.
        run_config: Configuration for the execution of the training run.
        train_dataset: Either a distributed Ray :ref:`Dataset <dataset-api>`
            or a Callable that returns a Dataset, to use for training. If a
            ``preprocessor`` is also provided, it will be fit on this
            dataset and this dataset will be transformed.
        extra_datasets: Any extra Datasets (such as validation or test
            datasets) to use for training. If a ``preprocessor`` is
            provided, the datasets specified here will only be transformed,
            and not fit on.
        preprocessor: A preprocessor to preprocess the provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.
    """
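The Args above map one-to-one onto construction. A minimal sketch of wiring
them together, assuming ``extra_datasets`` accepts a name-to-Dataset mapping
and using a hypothetical ``my_preprocessor`` (neither shape is pinned down by
this diff):

    import ray

    train_dataset = ray.data.from_items([1, 2, 3])
    validation_dataset = ray.data.from_items([4, 5])
    trainer = MyPytorchTrainer(
        train_dataset=train_dataset,
        # Assumed dict form; per the Args, only transformed, never fit on.
        extra_datasets={"validation": validation_dataset},
        # Hypothetical preprocessor; would be fit on train_dataset.
        preprocessor=my_preprocessor,
    )
    result = trainer.fit()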
@@ -54,6 +141,67 @@ class Trainer(abc.ABC):

        raise NotImplementedError

    def setup(self) -> None:
        """Called during fit() to perform initial setup on the Trainer.

        Note: This method is run on a remote process.

        This method will not be called on the driver, so any expensive setup
        operations should be placed here and not in ``__init__``.

        This method is called prior to ``preprocess_datasets`` and
        ``training_loop``.
        """
        raise NotImplementedError
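Since setup() runs on the remote actor, it is the natural home for
heavyweight state. A hedged sketch under that contract (the class name and
layer sizes are illustrative, not from this diff):

    import torch

    from ray.ml.trainer import Trainer

    class MyHeavyTrainer(Trainer):
        def setup(self):
            # Runs remotely after fit() ships the Trainer to an actor, so
            # this large allocation never happens on the driver.
            self.model = torch.nn.Sequential(
                torch.nn.Linear(4096, 4096),
                torch.nn.ReLU(),
                torch.nn.Linear(4096, 1),
            )
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)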
    def preprocess_datasets(self) -> None:
        """Called during fit() to preprocess dataset attributes with preprocessor.

        Note: This method is run on a remote process.

        This method is called prior to entering the training_loop.

        If the ``Trainer`` has both a train_dataset and preprocessor, and
        the preprocessor has not yet been fit, then it will be fit on the
        train_dataset.

        Then, the Trainer's train_dataset and any extra_datasets will be
        transformed by its preprocessor.

        The transformed datasets will be set back in the
        ``self.train_dataset`` and ``self.extra_datasets`` attributes to be
        used when overriding ``training_loop``.
        """
        raise NotImplementedError
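The fit-then-transform contract above implies roughly the following shape.
This is a hedged sketch only: the real implementation is not part of this
diff, and the ``fit``/``transform`` methods and dict-shaped
``extra_datasets`` are assumptions:

    def preprocess_datasets(self) -> None:
        # Sketch; the "has not yet been fit" check from the docstring is
        # elided here.
        if self.preprocessor:
            if self.train_dataset:
                # Fit on the train dataset only, then transform it.
                self.preprocessor.fit(self.train_dataset)
                self.train_dataset = self.preprocessor.transform(
                    self.train_dataset)
            # Extra datasets are transformed but never fit on.
            self.extra_datasets = {
                name: self.preprocessor.transform(ds)
                for name, ds in self.extra_datasets.items()
            }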
    @abc.abstractmethod
    def training_loop(self) -> None:
        """Loop called by fit() to run training and report results to Tune.

        Note: This method runs on a remote process.

        ``self.train_dataset`` and the Dataset values in ``self.extra_datasets``
        have already been preprocessed by ``self.preprocessor``.

        You can use the :ref:`Tune Function API functions <tune-function-docstring>`
        (``tune.report()`` and ``tune.save_checkpoint()``) inside
        this training loop.

        Example:
            .. code-block:: python

                from ray.ml.trainer import Trainer

                class MyTrainer(Trainer):
                    def training_loop(self):
                        for epoch_idx in range(5):
                            ...
                            tune.report(epoch=epoch_idx)
        """
        raise NotImplementedError
    @PublicAPI(stability="alpha")
    def fit(self) -> Result:
        """Runs training.

@@ -66,7 +214,6 @@ class Trainer(abc.ABC):
        """
        raise NotImplementedError

    @abc.abstractmethod
    def as_trainable(self) -> Type[Trainable]:
        """Convert self to a ``tune.Trainable`` class."""
        raise NotImplementedError
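as_trainable() is the bridge into Tune. A minimal usage sketch, assuming the
returned class is passed to ``tune.run`` the way any ``Trainable`` subclass
is:

    from ray import tune

    trainer = MyPytorchTrainer(train_dataset=train_dataset)
    # Run under Tune's trial management instead of calling trainer.fit();
    # tune.report() calls inside training_loop() become trial results.
    analysis = tune.run(trainer.as_trainable())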