[tune] pytorch lightning template and walkthrough (#9151)
Co-authored-by: Kai Fricke <kai@anyscale.com>
Parent: 1e79730eee
Commit: e0b6984dce
7 changed files with 563 additions and 0 deletions
@@ -112,6 +112,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
    python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
    python /ray/python/ray/tune/examples/mnist_pytorch_lightning.py \
    --smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
    python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \
    --smoke-test
BIN  doc/source/images/pytorch_lightning_full.png  (new file, binary not shown, 5.2 KiB)
BIN  doc/source/images/pytorch_lightning_small.png  (new file, binary not shown, 5.6 KiB)
@@ -39,6 +39,7 @@ Take a look at any of the below tutorials to get started with Tune.

    tune-60-seconds.rst
    tune-tutorial.rst
    tune-pytorch-lightning.rst
    tune-xgboost.rst

@@ -66,6 +67,12 @@ These pages will demonstrate the various features and configurations of Tune.

    :figure: /images/tune.png
    :description: :doc:`A guide to distributed hyperparameter tuning <tune-distributed>`

.. customgalleryitem::
    :tooltip: Tuning PyTorch Lightning modules
    :figure: /images/pytorch_lightning_small.png
    :description: :doc:`Tuning PyTorch Lightning modules <tune-pytorch-lightning>`


.. raw:: html

    </div>
297  doc/source/tune/_tutorials/tune-pytorch-lightning.rst  (new file)

@@ -0,0 +1,297 @@
.. _tune-pytorch-lightning:

Using PyTorch Lightning with Tune
=================================

PyTorch Lightning is a framework that brings structure to training PyTorch models. It
aims to avoid boilerplate code, so you don't have to write the same training
loops all over again when building a new model.

.. image:: /images/pytorch_lightning_full.png

The main abstraction of PyTorch Lightning is the ``LightningModule`` class, which
should be extended by your application. There is `a great post on how to transfer
your models from vanilla PyTorch to Lightning <https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09>`_.

The class structure of PyTorch Lightning makes it very easy to define and tune model
parameters. This tutorial shows you how to use Tune to find the best set of
parameters for your application, using the example of training an MNIST classifier. Notably,
the ``LightningModule`` does not have to be altered at all for this, so you can
use it plug-and-play with your existing models, as long as their parameters are configurable!

.. note::

    To run this example, you will need to install the following:

    .. code-block:: bash

        $ pip install ray torch torchvision pytorch-lightning

.. contents::
    :local:
    :backlinks: none

PyTorch Lightning classifier for MNIST
--------------------------------------
Let's start with the basic PyTorch Lightning implementation of an MNIST classifier.
This classifier does not include any tuning code at this point.

Our example builds on the MNIST example from the `blog post we talked about
earlier <https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09>`_.

First, we run some imports:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __import_lightning_begin__
    :end-before: __import_lightning_end__

Then comes the Lightning model, adapted from the blog post.
Note that we left out the test set validation and made the model parameters
configurable through a ``config`` dict that is passed on initialization.
Also, we specify a ``data_dir`` where the MNIST data will be stored.
Lastly, we added a new metric, the validation accuracy, to the logs.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __lightning_begin__
    :end-before: __lightning_end__
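
The key change compared to the original blog post code is that all hyperparameters are
now read from the ``config`` dict passed to the constructor, roughly like this
(abridged from the full example below):

.. code-block:: python

    class LightningMNISTClassifier(pl.LightningModule):
        def __init__(self, config, data_dir=None):
            super(LightningMNISTClassifier, self).__init__()
            self.data_dir = data_dir or os.getcwd()

            # all tunable parameters come from the config dict
            self.layer_1_size = config["layer_1_size"]
            self.layer_2_size = config["layer_2_size"]
            self.lr = config["lr"]
            self.batch_size = config["batch_size"]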

And that's it! You can now run ``train_mnist(config)`` to train the classifier, e.g.
like so:

.. code-block:: python

    config = {
        "layer_1_size": 128,
        "layer_2_size": 256,
        "lr": 1e-3,
        "batch_size": 64
    }
    train_mnist(config)

Tuning the model parameters
---------------------------
The parameters above should already give you a good accuracy of over 90%. However,
we might improve on this simply by changing some of the hyperparameters. For instance,
maybe we would get an even higher accuracy if we used a larger batch size.

Instead of guessing the parameter values, let's use Tune to systematically try out
parameter combinations and find the best performing set.

First, we need some additional imports:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __import_tune_begin__
    :end-before: __import_tune_end__

Talking to Tune with a PyTorch Lightning callback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

PyTorch Lightning introduced `Callbacks <https://pytorch-lightning.readthedocs.io/en/latest/callbacks.html>`_
that can be used to plug custom functions into the training loop. This way the original
``LightningModule`` does not have to be altered at all. We can also reuse the same
callback for multiple modules.

The callback just reports some metrics back to Tune after each validation epoch:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_callback_begin__
    :end-before: __tune_callback_end__
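
Concretely, the callback referenced above looks like this:

.. code-block:: python

    class TuneReportCallback(Callback):
        def on_validation_end(self, trainer, pl_module):
            # send the aggregated validation metrics to Tune
            tune.report(
                loss=trainer.callback_metrics["avg_val_loss"],
                mean_accuracy=trainer.callback_metrics["avg_val_accuracy"])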

Adding the Tune training function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Then we specify our training function. Note that we added the ``data_dir`` as a config
parameter here, even though it should not be tuned. We just need to specify it to prevent
each training run from downloading the full MNIST dataset. Instead, we want to access
a shared data location.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_train_begin__
    :end-before: __tune_train_end__
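
The training function itself stays short: it instantiates the model and a ``Trainer``
with the reporting callback attached:

.. code-block:: python

    def train_mnist_tune(config):
        model = LightningMNISTClassifier(config, config["data_dir"])
        trainer = pl.Trainer(
            max_epochs=10,
            progress_bar_refresh_rate=0,
            callbacks=[TuneReportCallback()])

        trainer.fit(model)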

Sharing the data
~~~~~~~~~~~~~~~~

All our trials use the MNIST data. To avoid each training instance downloading
its own copy of the MNIST dataset, we download it once and share the ``data_dir`` between runs.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
    :lines: 2-3
    :dedent: 4
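
These are the two lines in question. We create a temporary directory and download the
dataset into it once, before any trials are started:

.. code-block:: python

    data_dir = mkdtemp(prefix="mnist_data_")
    LightningMNISTClassifier.download_data(data_dir)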

We also delete this data after training to avoid filling up our disk or memory space.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
    :lines: 27
    :dedent: 4
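
That cleanup is a single call at the end of the tuning function:

.. code-block:: python

    shutil.rmtree(data_dir)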

Configuring the search space
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now we configure the parameter search space. We would like to choose between three
different layer sizes and batch sizes. The learning rate should be sampled between
``0.0001`` and ``0.1``. The ``tune.loguniform()`` function samples uniformly on a log
scale, which makes it easy to cover several orders of magnitude, so small values are
tried just as often as large ones.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
    :lines: 4-10
    :dedent: 4
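
The resulting search space, as defined in the example:

.. code-block:: python

    config = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "data_dir": data_dir
    }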

Selecting a scheduler
~~~~~~~~~~~~~~~~~~~~~

In this example, we use an `Asynchronous Hyperband <https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/>`_
scheduler. This scheduler decides at each iteration which trials are likely to perform
badly and stops them early, so we don't waste resources on bad hyperparameter
configurations.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
    :lines: 11-16
    :dedent: 4
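
In the example, the scheduler compares trials on the reported ``loss``, lets every
trial run for at least one iteration (``grace_period=1``) and for at most ten
(``max_t=10``):

.. code-block:: python

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=10,
        grace_period=1,
        reduction_factor=2)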

Changing the CLI output
~~~~~~~~~~~~~~~~~~~~~~~

We instantiate a ``CLIReporter`` to specify which metrics we would like to see in our
output tables on the command line. If we didn't specify this, Tune would print all
hyperparameters by default. Since ``data_dir`` is not a real hyperparameter, we
avoid printing it by omitting it from the ``parameter_columns`` list.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
    :lines: 17-19
    :dedent: 4
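
The reporter lists every real hyperparameter, leaves out ``data_dir``, and adds the
metrics reported from the callback:

.. code-block:: python

    reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"])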

Putting it together
~~~~~~~~~~~~~~~~~~~

Lastly, we need to start Tune with ``tune.run()``.

The full code looks like this:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_asha_begin__
    :end-before: __tune_asha_end__
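
The part we have not shown yet is the ``tune.run()`` call itself, which ties the
training function, search space, scheduler and reporter together:

.. code-block:: python

    tune.run(
        train_mnist_tune,
        resources_per_trial={"cpu": 1},
        config=config,
        num_samples=10,
        scheduler=scheduler,
        progress_reporter=reporter)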

In the example above, Tune runs 10 trials with different hyperparameter configurations.
An example output could look like this:

.. code-block::
    :emphasize-lines: 12

    +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+
    | Trial name | status | loc | layer_1_size | layer_2_size | lr | batch_size | loss | mean_accuracy | training_iteration |
    |------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------|
    | train_mnist_tune_63ecc_00000 | TERMINATED | | 128 | 64 | 0.00121197 | 128 | 0.120173 | 0.972461 | 10 |
    | train_mnist_tune_63ecc_00001 | TERMINATED | | 64 | 128 | 0.0301395 | 128 | 0.454836 | 0.868164 | 4 |
    | train_mnist_tune_63ecc_00002 | TERMINATED | | 64 | 128 | 0.0432097 | 128 | 0.718396 | 0.718359 | 1 |
    | train_mnist_tune_63ecc_00003 | TERMINATED | | 32 | 128 | 0.000294669 | 32 | 0.111475 | 0.965764 | 10 |
    | train_mnist_tune_63ecc_00004 | TERMINATED | | 32 | 256 | 0.000386664 | 64 | 0.133538 | 0.960839 | 8 |
    | train_mnist_tune_63ecc_00005 | TERMINATED | | 128 | 128 | 0.0837395 | 32 | 2.32628 | 0.0991242 | 1 |
    | train_mnist_tune_63ecc_00006 | TERMINATED | | 64 | 128 | 0.000158761 | 128 | 0.134595 | 0.959766 | 10 |
    | train_mnist_tune_63ecc_00007 | TERMINATED | | 64 | 64 | 0.000672126 | 64 | 0.118182 | 0.972903 | 10 |
    | train_mnist_tune_63ecc_00008 | TERMINATED | | 128 | 64 | 0.000502428 | 32 | 0.11082 | 0.975518 | 10 |
    | train_mnist_tune_63ecc_00009 | TERMINATED | | 64 | 256 | 0.00112894 | 32 | 0.13472 | 0.971935 | 8 |
    +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+

As you can see in the ``training_iteration`` column, trials with a high loss
(and low accuracy) have been terminated early. The best performing trial used
``layer_1_size=128``, ``layer_2_size=64``, ``lr=0.000502428`` and
``batch_size=32``.

Using Population Based Training to find the best parameters
-----------------------------------------------------------
The ``ASHAScheduler`` terminates trials early if they show bad performance.
Sometimes, this stops trials that would have improved with more training steps,
and which might eventually even have outperformed the other configurations.

Another popular method for hyperparameter tuning, called
`Population Based Training <https://deepmind.com/blog/article/population-based-training-neural-networks>`_,
instead perturbs hyperparameters during the training run. Tune implements PBT, and
we only need to make some slight adjustments to our code.

Adding checkpoints to the PyTorch Lightning module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, we need to introduce another callback to save model checkpoints:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_checkpoint_callback_begin__
    :end-before: __tune_checkpoint_callback_end__
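
The checkpoint callback asks Tune for a checkpoint directory, saves the Lightning
checkpoint there, and registers it with Tune:

.. code-block:: python

    class CheckpointCallback(Callback):
        def on_validation_end(self, trainer, pl_module):
            path = tune.make_checkpoint_dir(trainer.global_step)
            trainer.save_checkpoint(os.path.join(path, "checkpoint"))
            tune.save_checkpoint(path)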

We also include checkpoint loading in our training function:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_train_checkpoint_begin__
    :end-before: __tune_train_checkpoint_end__


Configuring and running Population Based Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We need to call Tune slightly differently:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
    :language: python
    :start-after: __tune_pbt_begin__
    :end-before: __tune_pbt_end__
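
The main difference is the scheduler: instead of the ``ASHAScheduler`` we construct a
``PopulationBasedTraining`` scheduler that perturbs ``lr`` and ``batch_size`` every four
training iterations (the ``lambda`` draws a fresh learning rate sample at each perturbation):

.. code-block:: python

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="loss",
        mode="min",
        perturbation_interval=4,
        hyperparam_mutations={
            "lr": lambda: tune.loguniform(1e-4, 1e-1).func(None),
            "batch_size": [32, 64, 128]
        })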

Instead of passing Tune sampling functions for all parameters in the ``config`` dict, we start
with fixed values, though we are still able to sample some of them, like the
layer sizes. Additionally, we have to tell PBT how to perturb the hyperparameters.
Note that the layer sizes are not perturbed here: we cannot simply
change layer sizes in the middle of a training run, which is what would happen in PBT.

An example output could look like this:

.. code-block::

    +-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+
    | Trial name | status | loc | layer_1_size | layer_2_size | lr | batch_size | loss | mean_accuracy | training_iteration |
    |-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------|
    | train_mnist_tune_checkpoint_85489_00000 | TERMINATED | | 128 | 128 | 0.001 | 64 | 0.108734 | 0.973101 | 10 |
    | train_mnist_tune_checkpoint_85489_00001 | TERMINATED | | 128 | 128 | 0.001 | 64 | 0.093577 | 0.978639 | 10 |
    | train_mnist_tune_checkpoint_85489_00002 | TERMINATED | | 128 | 256 | 0.0008 | 32 | 0.0922348 | 0.979299 | 10 |
    | train_mnist_tune_checkpoint_85489_00003 | TERMINATED | | 64 | 256 | 0.001 | 64 | 0.124648 | 0.973892 | 10 |
    | train_mnist_tune_checkpoint_85489_00004 | TERMINATED | | 128 | 64 | 0.001 | 64 | 0.101717 | 0.975079 | 10 |
    | train_mnist_tune_checkpoint_85489_00005 | TERMINATED | | 64 | 64 | 0.001 | 64 | 0.121467 | 0.969146 | 10 |
    | train_mnist_tune_checkpoint_85489_00006 | TERMINATED | | 128 | 256 | 0.00064 | 32 | 0.053446 | 0.987062 | 10 |
    | train_mnist_tune_checkpoint_85489_00007 | TERMINATED | | 128 | 256 | 0.001 | 64 | 0.129804 | 0.973497 | 10 |
    | train_mnist_tune_checkpoint_85489_00008 | TERMINATED | | 64 | 256 | 0.0285125 | 128 | 0.363236 | 0.913867 | 10 |
    | train_mnist_tune_checkpoint_85489_00009 | TERMINATED | | 32 | 256 | 0.001 | 64 | 0.150946 | 0.964201 | 10 |
    +-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+

As you can see, each sample ran for the full 10 iterations.
All trials ended with quite good parameter combinations and showed relatively good performance.
In some runs, the parameters were perturbed, and the best configuration even reached a
mean validation accuracy of ``0.987062``!

In summary, PyTorch Lightning modules are easy to extend for use with Tune. It only took
one or two callbacks and a small wrapper function to arrive at great performing
parameter configurations.
@@ -18,6 +18,7 @@ opencv-python-headless
 pandas
 pytest-remotedata>=0.3.1
 pytest-timeout
+pytorch-lightning
 scikit-learn==0.22.2
 scikit-optimize
 sigopt
254  python/ray/tune/examples/mnist_pytorch_lightning.py  (new file)

@@ -0,0 +1,254 @@
# flake8: noqa
# yapf: disable

# __import_lightning_begin__
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import os
# __import_lightning_end__

# __import_tune_begin__
import shutil
from tempfile import mkdtemp
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
# __import_tune_end__


# __lightning_begin__
class LightningMNISTClassifier(pl.LightningModule):
    """
    This has been adapted from
    https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09
    """

    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()

        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        _, predicted = torch.max(logits.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / len(labels)
        return torch.tensor(accuracy)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        logs = {"train_loss": loss, "train_accuracy": accuracy}
        return {"loss": loss, "log": logs}

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss, "val_accuracy": avg_acc}

        return {
            "avg_val_loss": avg_loss,
            "avg_val_accuracy": avg_acc,
            "log": tensorboard_logs
        }

    @staticmethod
    def download_data(data_dir):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        return MNIST(data_dir, train=True, download=True, transform=transform)

    def prepare_data(self):
        mnist_train = self.download_data(self.data_dir)

        self.mnist_train, self.mnist_val = random_split(
            mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=int(self.batch_size))

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=int(self.batch_size))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def train_mnist(config):
    model = LightningMNISTClassifier(config)
    trainer = pl.Trainer(max_epochs=10, show_progress_bar=False)

    trainer.fit(model)
# __lightning_end__


# __tune_callback_begin__
class TuneReportCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        tune.report(
            loss=trainer.callback_metrics["avg_val_loss"],
            mean_accuracy=trainer.callback_metrics["avg_val_accuracy"])
# __tune_callback_end__


# __tune_train_begin__
def train_mnist_tune(config):
    model = LightningMNISTClassifier(config, config["data_dir"])
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback()])

    trainer.fit(model)
# __tune_train_end__


# __tune_checkpoint_callback_begin__
class CheckpointCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        path = tune.make_checkpoint_dir(trainer.global_step)
        trainer.save_checkpoint(os.path.join(path, "checkpoint"))
        tune.save_checkpoint(path)
# __tune_checkpoint_callback_end__


# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(config, checkpoint=None):
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=[CheckpointCallback(),
                   TuneReportCallback()])
    if checkpoint:
        # Currently, this leads to errors:
        # model = LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # Workaround:
        ckpt = pl_load(
            os.path.join(checkpoint, "checkpoint"),
            map_location=lambda storage, loc: storage)
        model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt["epoch"]
    else:
        model = LightningMNISTClassifier(
            config=config, data_dir=config["data_dir"])

    trainer.fit(model)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_mnist_asha():
    data_dir = mkdtemp(prefix="mnist_data_")
    LightningMNISTClassifier.download_data(data_dir)
    config = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "data_dir": data_dir
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=10,
        grace_period=1,
        reduction_factor=2)
    reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"])
    tune.run(
        train_mnist_tune,
        resources_per_trial={"cpu": 1},
        config=config,
        num_samples=10,
        scheduler=scheduler,
        progress_reporter=reporter)
    shutil.rmtree(data_dir)
# __tune_asha_end__


# __tune_pbt_begin__
def tune_mnist_pbt():
    data_dir = mkdtemp(prefix="mnist_data_")
    LightningMNISTClassifier.download_data(data_dir)
    config = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": 1e-3,
        "batch_size": 64,
        "data_dir": data_dir
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="loss",
        mode="min",
        perturbation_interval=4,
        hyperparam_mutations={
            "lr": lambda: tune.loguniform(1e-4, 1e-1).func(None),
            "batch_size": [32, 64, 128]
        })
    reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"])
    tune.run(
        train_mnist_tune_checkpoint,
        resources_per_trial={"cpu": 1},
        config=config,
        num_samples=10,
        scheduler=scheduler,
        progress_reporter=reporter)
    shutil.rmtree(data_dir)
# __tune_pbt_end__


if __name__ == "__main__":
    # tune_mnist_asha()  # ASHA scheduler
    tune_mnist_pbt()  # population based training