From 86b79b68be40e2494cf903bebb6f07251fead554 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Sun, 13 Mar 2022 18:09:50 -0700
Subject: [PATCH] [ml/train] Training Interfaces [2/4]: Update interface for `Trainer` (#22986)

---
 python/ray/ml/train/examples/__init__.py    |   0
 .../ray/ml/train/examples/custom_trainer.py |  59 +++++++
 python/ray/ml/trainer.py                    | 155 +++++++++++++++++-
 3 files changed, 210 insertions(+), 4 deletions(-)
 create mode 100644 python/ray/ml/train/examples/__init__.py
 create mode 100644 python/ray/ml/train/examples/custom_trainer.py

diff --git a/python/ray/ml/train/examples/__init__.py b/python/ray/ml/train/examples/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/ray/ml/train/examples/custom_trainer.py b/python/ray/ml/train/examples/custom_trainer.py
new file mode 100644
index 000000000..88cd95335
--- /dev/null
+++ b/python/ray/ml/train/examples/custom_trainer.py
@@ -0,0 +1,59 @@
+# flake8: noqa
+# TODO(amog): Add this to CI once Trainer has been implemented.
+# TODO(rliaw): Include this in the docs.
+
+# fmt: off
+# __custom_trainer_begin__
+import torch
+
+from ray.ml.trainer import Trainer
+from ray import tune
+
+
+class MyPytorchTrainer(Trainer):
+    def setup(self):
+        self.model = torch.nn.Linear(1, 1)
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
+
+    def training_loop(self):
+        # You can access any Trainer attributes directly in this method.
+        # self.train_dataset has already been preprocessed by self.preprocessor
+        dataset = self.train_dataset
+
+        torch_ds = dataset.to_torch()
+
+        for epoch_idx in range(10):
+            loss = 0
+            num_batches = 0
+            for X, y in iter(torch_ds):
+                # Compute prediction error
+                pred = self.model(X)
+                batch_loss = torch.nn.functional.mse_loss(pred, y)
+
+                # Backpropagation
+                self.optimizer.zero_grad()
+                batch_loss.backward()
+                self.optimizer.step()
+
+                loss += batch_loss.item()
+                num_batches += 1
+            loss /= num_batches
+
+            # Use Tune functions to report intermediate
+            # results.
+            tune.report(loss=loss, epoch=epoch_idx)
+
+
+# __custom_trainer_end__
+# fmt: on
+
+
+# fmt: off
+# __custom_trainer_usage_begin__
+import ray
+
+train_dataset = ray.data.from_items([1, 2, 3])
+my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
+result = my_trainer.fit()
+# __custom_trainer_usage_end__
+# fmt: on
diff --git a/python/ray/ml/trainer.py b/python/ray/ml/trainer.py
index 522ae59af..27d0a7022 100644
--- a/python/ray/ml/trainer.py
+++ b/python/ray/ml/trainer.py
@@ -8,6 +8,7 @@ from ray.ml.result import Result
 from ray.ml.config import RunConfig, ScalingConfig
 from ray.tune import Trainable
 from ray.util import PublicAPI
+from ray.util.annotations import DeveloperAPI
 
 if TYPE_CHECKING:
     from ray.data import Dataset
@@ -27,17 +28,103 @@ class TrainingFailedError(RuntimeError):
     pass
 
 
-@PublicAPI(stability="alpha")
+@DeveloperAPI
 class Trainer(abc.ABC):
     """Defines interface for distributed training on Ray.
 
+    Note: The base ``Trainer`` class cannot be instantiated directly. Only
+    one of its subclasses can be used.
+
+    How does a trainer work?
+
+    - First, initialize the Trainer. The initialization runs locally,
+      so heavyweight setup should not be done in ``__init__``.
+    - Then, when you call ``trainer.fit()``, the Trainer is serialized
+      and copied to a remote Ray actor. The following methods are then
+      called in sequence on the remote actor.
+    - ``trainer.setup()``: Any heavyweight Trainer setup should be
+      specified here.
+    - ``trainer.preprocess_datasets()``: The provided
+      ``ray.data.Dataset`` objects are preprocessed with the provided
+      ``ray.ml.preprocessor``.
+    - ``trainer.training_loop()``: Executes the main training logic.
+    - Calling ``trainer.fit()`` will return a ``ray.result.Result``
+      object where you can access metrics from your training run, as well
+      as any checkpoints that may have been saved.
+
+    How do I create a new ``Trainer``?
+
+    Subclass ``ray.ml.trainer.Trainer``, override the ``training_loop``
+    method, and optionally ``setup``.
+
+    Example:
+
+        .. code-block:: python
+
+            import torch
+
+            from ray.ml.trainer import Trainer
+            from ray import tune
+
+
+            class MyPytorchTrainer(Trainer):
+                def setup(self):
+                    self.model = torch.nn.Linear(1, 1)
+                    self.optimizer = torch.optim.SGD(
+                        self.model.parameters(), lr=0.1)
+
+                def training_loop(self):
+                    # You can access any Trainer attributes directly in
+                    # this method.
+                    # self.train_dataset has already been preprocessed by
+                    # self.preprocessor
+                    dataset = self.train_dataset
+
+                    torch_ds = dataset.to_torch()
+
+                    for epoch_idx in range(10):
+                        loss = 0
+                        num_batches = 0
+                        for X, y in iter(torch_ds):
+                            # Compute prediction error
+                            pred = self.model(X)
+                            batch_loss = torch.nn.functional.mse_loss(pred, y)
+
+                            # Backpropagation
+                            self.optimizer.zero_grad()
+                            batch_loss.backward()
+                            self.optimizer.step()
+
+                            loss += batch_loss.item()
+                            num_batches += 1
+                        loss /= num_batches
+
+                        # Use Tune functions to report intermediate
+                        # results.
+                        tune.report(loss=loss, epoch=epoch_idx)
+
+    How do I use an existing ``Trainer`` or one of my custom Trainers?
+
+    Initialize the Trainer, and call ``Trainer.fit()``.
+
+    Example:
+
+        .. code-block:: python
+
+            import ray
+
+            train_dataset = ray.data.from_items([1, 2, 3])
+            my_trainer = MyPytorchTrainer(train_dataset=train_dataset)
+            result = my_trainer.fit()
+
     Args:
         scaling_config: Configuration for how to scale training.
         run_config: Configuration for the execution of the training run.
         train_dataset: Either a distributed Ray :ref:`Dataset `
-            or a Callable that returns a Dataset to use for training.
+            or a Callable that returns a Dataset, to use for training. If a
+            ``preprocessor`` is also provided, it will be fit on this
+            dataset and the dataset will be transformed.
         extra_datasets: Any extra Datasets (such as validation or test
-            datasets) to use for training.
+            datasets) to use for training. If a ``preprocessor`` is
+            provided, the datasets specified here will only be transformed,
+            not fit on.
         preprocessor: A preprocessor to preprocess the provided datasets.
         resume_from_checkpoint: A checkpoint to resume training from.
     """
@@ -54,6 +141,67 @@ class Trainer(abc.ABC):
 
         raise NotImplementedError
 
+    def setup(self) -> None:
+        """Called during fit() to perform initial setup on the Trainer.
+
+        Note: this method is run on a remote process.
+
+        This method will not be called on the driver, so any expensive setup
+        operations should be placed here and not in ``__init__``.
+
+        This method is called prior to ``preprocess_datasets`` and
+        ``training_loop``.
+        """
+        raise NotImplementedError
+
+    def preprocess_datasets(self) -> None:
+        """Called during fit() to preprocess dataset attributes with preprocessor.
+
+        Note: this method is run on a remote process.
+
+        This method is called prior to entering ``training_loop``.
+
+        If the ``Trainer`` has both a train_dataset and preprocessor, and
+        the preprocessor has not yet been fit, then it will be fit on the
+        train_dataset.
+
+        Then, the Trainer's train_dataset and any extra_datasets
+        will be transformed by its preprocessor.
+
+        The transformed datasets will be set back on the
+        ``self.train_dataset`` and ``self.extra_datasets`` attributes to be
+        used when overriding ``training_loop``.
+        """
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def training_loop(self) -> None:
+        """Loop called by fit() to run training and report results to Tune.
+
+        Note: this method runs on a remote process.
+
+        ``self.train_dataset`` and the Dataset values in
+        ``self.extra_datasets`` have already been preprocessed by
+        ``self.preprocessor``.
+
+        You can use the :ref:`Tune Function API functions `
+        (``tune.report()`` and ``tune.save_checkpoint()``) inside
+        this training loop.
+
+        Example:
+
+            .. code-block:: python
+
+                from ray import tune
+                from ray.ml.trainer import Trainer
+
+                class MyTrainer(Trainer):
+                    def training_loop(self):
+                        for epoch_idx in range(5):
+                            ...
+                            tune.report(epoch=epoch_idx)
+
+        """
+        raise NotImplementedError
+
+    @PublicAPI(stability="alpha")
     def fit(self) -> Result:
         """Runs training.
 
@@ -66,7 +214,6 @@
         """
         raise NotImplementedError
 
-    @abc.abstractmethod
     def as_trainable(self) -> Type[Trainable]:
         """Convert self to a ``tune.Trainable`` class."""
         raise NotImplementedError
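
For orientation, the snippet below sketches how the interface documented above is meant to be driven end to end. It is not part of the patch: ``fit()`` and ``as_trainable()`` still raise ``NotImplementedError`` in this change and are filled in by later PRs in this stack, and the ``result.metrics`` attribute name is assumed from the ``ray.result.Result`` description in the docstring.

# A minimal usage sketch, assuming later PRs implement fit()/as_trainable();
# nothing below is added by this patch itself.
import ray
from ray import tune

from ray.ml.train.examples.custom_trainer import MyPytorchTrainer

# Small in-memory dataset; ray.data.from_items is an existing Datasets API.
train_dataset = ray.data.from_items([1, 2, 3])

# Construction runs locally and stays lightweight; heavyweight setup belongs
# in MyPytorchTrainer.setup(), which fit() runs on a remote actor.
trainer = MyPytorchTrainer(train_dataset=train_dataset)

# Run training and inspect the returned Result object
# (`metrics` is an assumed attribute name, per the docstring's description).
result = trainer.fit()
print(result.metrics)

# Alternatively, hand the Trainer to Tune as a Trainable class, e.g. to run
# several trials; tune.run() accepting a Trainable class is existing Tune
# behavior, only as_trainable() itself is new here.
analysis = tune.run(trainer.as_trainable(), num_samples=2)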