{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "586737af",
   "metadata": {},
   "source": [
    "# How to use Tune with PyTorch\n",
    "\n",
    "(tune-pytorch-cifar-ref)=\n",
    "\n",
    "In this walkthrough, we will show you how to integrate Tune into your PyTorch\n",
    "training workflow. We will follow [this tutorial from the PyTorch documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)\n",
    "for training a CIFAR10 image classifier.\n",
    "\n",
    "```{image} /images/pytorch_logo.png\n",
    ":align: center\n",
    "```\n",
    "\n",
    "Hyperparameter tuning can make the difference between an average model and a highly\n",
    "accurate one. Often simple things like choosing a different learning rate or changing\n",
    "a network layer size can have a dramatic impact on your model performance. Fortunately,\n",
    "Tune makes exploring these optimal parameter combinations easy - and works nicely\n",
    "together with PyTorch.\n",
    "\n",
    "As you will see, we only need to make some slight modifications. In particular, we\n",
    "need to:\n",
    "\n",
    "1. wrap data loading and training in functions,\n",
    "2. make some network parameters configurable,\n",
    "3. add checkpointing (optional),\n",
    "4. and define the search space for the model tuning.\n",
    "\n",
    ":::{note}\n",
    "To run this example, you will need to install the following:\n",
    "\n",
    "```bash\n",
    "$ pip install ray torch torchvision\n",
    "```\n",
    ":::\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e8650d1",
   "metadata": {},
   "source": [
    "## Setup / Imports\n",
    "\n",
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55529285",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from filelock import FileLock\n",
    "from torch.utils.data import random_split\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import ASHAScheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f59e551d",
   "metadata": {},
   "source": [
    "Most of the imports are needed for building the PyTorch model. Only the last three\n",
    "imports are for Ray Tune.\n",
    "\n",
    "## Data loaders\n",
    "\n",
    "We wrap the data loaders in their own function and pass a global data directory.\n",
    "This way we can share a data directory between different trials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01471556",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(data_dir=\"./data\"):\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "    # We add FileLock here because multiple workers will want to\n",
    "    # download the data, and concurrent downloads to the same\n",
    "    # directory may overwrite each other.\n",
    "    with FileLock(os.path.expanduser(\"~/.data.lock\")):\n",
    "        trainset = torchvision.datasets.CIFAR10(\n",
    "            root=data_dir, train=True, download=True, transform=transform)\n",
    "\n",
    "        testset = torchvision.datasets.CIFAR10(\n",
    "            root=data_dir, train=False, download=True, transform=transform)\n",
    "\n",
    "    return trainset, testset"
   ]
  },
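  {
   "cell_type": "markdown",
   "id": "1f2d9c3e",
   "metadata": {},
   "source": [
    "As a quick sanity check (this cell is a small addition, not part of the original PyTorch\n",
    "tutorial), we can call `load_data` once and inspect what it returns. CIFAR10 ships with\n",
    "50,000 training and 10,000 test images, each a 3x32x32 tensor after the `ToTensor`\n",
    "transform:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b7a6e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small sanity check: download/load the datasets once and look at their sizes.\n",
    "trainset, testset = load_data()\n",
    "\n",
    "print(\"Training set size:\", len(trainset))  # expected: 50000\n",
    "print(\"Test set size:\", len(testset))  # expected: 10000\n",
    "\n",
    "# Each sample is an (image, label) pair.\n",
    "image, label = trainset[0]\n",
    "print(\"Image shape:\", tuple(image.shape), \"- label:\", label)"
   ]
  },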
  {
   "cell_type": "markdown",
   "id": "80958cf3",
   "metadata": {},
   "source": [
    "## Configurable neural network\n",
    "\n",
    "We can only tune those parameters that are configurable. In this example, we can specify\n",
    "the layer sizes of the fully connected layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff6bd0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, l1=120, l2=84):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, l1)\n",
    "        self.fc2 = nn.Linear(l1, l2)\n",
    "        self.fc3 = nn.Linear(l2, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
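  {
   "cell_type": "markdown",
   "id": "9a0c5d17",
   "metadata": {},
   "source": [
    "To see that the configurable layer sizes wire up correctly, here is a small shape check\n",
    "(again an added sketch, not part of the original tutorial): we instantiate `Net` with two\n",
    "arbitrary example values for `l1` and `l2` and push a batch of random CIFAR10-sized inputs\n",
    "through it. We expect one logit per class, i.e. an output of shape `(batch_size, 10)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e8b4f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick shape check with arbitrary example layer sizes.\n",
    "net = Net(l1=32, l2=16)\n",
    "\n",
    "# A fake batch of 4 CIFAR10-sized images (3 channels, 32x32 pixels).\n",
    "dummy_batch = torch.randn(4, 3, 32, 32)\n",
    "outputs = net(dummy_batch)\n",
    "\n",
    "print(\"Output shape:\", tuple(outputs.shape))  # expected: (4, 10)"
   ]
  },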
  {
   "cell_type": "markdown",
   "id": "fb619875",
   "metadata": {},
   "source": [
    "## The train function\n",
    "\n",
    "Now it gets interesting, because we introduce some changes to the example [from the PyTorch\n",
    "documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).\n",
    "\n",
    "(communicating-with-ray-tune)=\n",
    "\n",
    "The full code example looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0bdae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_cifar(config, checkpoint_dir=None):\n",
    "    net = Net(config[\"l1\"], config[\"l2\"])\n",
    "\n",
    "    device = \"cpu\"\n",
    "    if torch.cuda.is_available():\n",
    "        device = \"cuda:0\"\n",
    "        if torch.cuda.device_count() > 1:\n",
    "            net = nn.DataParallel(net)\n",
    "    net.to(device)\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n",
    "\n",
    "    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint\n",
    "    # should be restored.\n",
    "    if checkpoint_dir:\n",
    "        checkpoint = os.path.join(checkpoint_dir, \"checkpoint\")\n",
    "        model_state, optimizer_state = torch.load(checkpoint)\n",
    "        net.load_state_dict(model_state)\n",
    "        optimizer.load_state_dict(optimizer_state)\n",
    "\n",
    "    data_dir = os.path.abspath(\"./data\")\n",
    "    trainset, testset = load_data(data_dir)\n",
    "\n",
    "    test_abs = int(len(trainset) * 0.8)\n",
    "    train_subset, val_subset = random_split(\n",
    "        trainset, [test_abs, len(trainset) - test_abs])\n",
    "\n",
    "    trainloader = torch.utils.data.DataLoader(\n",
    "        train_subset,\n",
    "        batch_size=int(config[\"batch_size\"]),\n",
    "        shuffle=True,\n",
    "        num_workers=8)\n",
    "    valloader = torch.utils.data.DataLoader(\n",
    "        val_subset,\n",
    "        batch_size=int(config[\"batch_size\"]),\n",
    "        shuffle=True,\n",
    "        num_workers=8)\n",
    "\n",
    "    for epoch in range(10):  # loop over the dataset multiple times\n",
    "        running_loss = 0.0\n",
    "        epoch_steps = 0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            # get the inputs; data is a list of [inputs, labels]\n",
    "            inputs, labels = data\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # print statistics\n",
    "            running_loss += loss.item()\n",
    "            epoch_steps += 1\n",
    "            if i % 2000 == 1999:  # print every 2000 mini-batches\n",
    "                print(\"[%d, %5d] loss: %.3f\" % (epoch + 1, i + 1,\n",
    "                                                running_loss / epoch_steps))\n",
    "                running_loss = 0.0\n",
    "\n",
    "        # Validation loss\n",
    "        val_loss = 0.0\n",
    "        val_steps = 0\n",
    "        total = 0\n",
    "        correct = 0\n",
    "        for i, data in enumerate(valloader, 0):\n",
    "            with torch.no_grad():\n",
    "                inputs, labels = data\n",
    "                inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "                outputs = net(inputs)\n",
    "                _, predicted = torch.max(outputs.data, 1)\n",
    "                total += labels.size(0)\n",
    "                correct += (predicted == labels).sum().item()\n",
    "\n",
    "                loss = criterion(outputs, labels)\n",
    "                val_loss += loss.cpu().numpy()\n",
    "                val_steps += 1\n",
    "\n",
    "        # Here we save a checkpoint. It is automatically registered with\n",
    "        # Ray Tune and will potentially be passed as the `checkpoint_dir`\n",
    "        # parameter in future iterations.\n",
    "        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:\n",
    "            path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
    "            torch.save(\n",
    "                (net.state_dict(), optimizer.state_dict()), path)\n",
    "\n",
    "        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)\n",
    "    print(\"Finished Training\")"
   ]
  },
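  {
   "cell_type": "markdown",
   "id": "7c3d1a95",
   "metadata": {},
   "source": [
    "The two Tune-specific pieces in `train_cifar` are saving checkpoints via\n",
    "`tune.checkpoint_dir` and reporting metrics via `tune.report`. Stripped of all the model\n",
    "code, the pattern looks roughly like the sketch below (an added illustration; the function\n",
    "and metric names are made up, and the cell only defines the function without running it):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6e2b48",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dummy_trainable(config, checkpoint_dir=None):\n",
    "    # If Tune restores this trial, it passes the directory of the last checkpoint.\n",
    "    start_step = 0\n",
    "    if checkpoint_dir:\n",
    "        with open(os.path.join(checkpoint_dir, \"checkpoint\")) as f:\n",
    "            start_step = int(f.read())\n",
    "\n",
    "    for step in range(start_step, 5):\n",
    "        # ... one iteration of training would go here ...\n",
    "\n",
    "        # Save a checkpoint; Tune registers it automatically.\n",
    "        with tune.checkpoint_dir(step=step) as new_checkpoint_dir:\n",
    "            with open(os.path.join(new_checkpoint_dir, \"checkpoint\"), \"w\") as f:\n",
    "                f.write(str(step))\n",
    "\n",
    "        # Everything passed to tune.report becomes a metric of this trial.\n",
    "        tune.report(my_metric=1.0 / (step + 1))"
   ]
  },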
  {
   "cell_type": "markdown",
   "id": "918d8baf",
   "metadata": {},
   "source": [
    "As you can see, most of the code is adapted directly from the example.\n",
    "\n",
    "## Test set accuracy\n",
    "\n",
    "Commonly, the performance of a machine learning model is tested on a hold-out test\n",
    "set with data that has not been used for training the model. We also wrap this in a\n",
    "function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b5b4af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_best_model(best_trial):\n",
    "    best_trained_model = Net(best_trial.config[\"l1\"], best_trial.config[\"l2\"])\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "    best_trained_model.to(device)\n",
    "\n",
    "    checkpoint_path = os.path.join(best_trial.checkpoint.value, \"checkpoint\")\n",
    "\n",
    "    model_state, optimizer_state = torch.load(checkpoint_path)\n",
    "    best_trained_model.load_state_dict(model_state)\n",
    "\n",
    "    trainset, testset = load_data()\n",
    "\n",
    "    testloader = torch.utils.data.DataLoader(\n",
    "        testset, batch_size=4, shuffle=False, num_workers=2)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            images, labels = data\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            outputs = best_trained_model(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print(\"Best trial test set accuracy: {}\".format(correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85f8230e",
   "metadata": {},
   "source": [
    "As you can see, the function determines the `device` to use on its own, so we can do the\n",
    "test set validation on a GPU if one is available.\n",
    "\n",
    "## Configuring the search space\n",
    "\n",
    "Lastly, we need to define Tune's search space. Here is an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5416cece",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"l1\": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),\n",
    "    \"l2\": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),\n",
    "    \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "    \"batch_size\": tune.choice([2, 4, 8, 16]),\n",
    "}"
   ]
  },
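  {
   "cell_type": "markdown",
   "id": "8f1a3c70",
   "metadata": {},
   "source": [
    "As a brief aside (an added sketch, not part of the original tutorial): since\n",
    "`2**np.random.randint(2, 9)` only ever produces the powers of two from 4 to 256, the `l1`\n",
    "and `l2` entries could equivalently be written by listing those values explicitly with\n",
    "`tune.choice`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b9e0d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Equivalent formulation of the same search space: enumerate the powers of two\n",
    "# from 4 to 256 instead of sampling the exponent.\n",
    "equivalent_config = {\n",
    "    \"l1\": tune.choice([4, 8, 16, 32, 64, 128, 256]),\n",
    "    \"l2\": tune.choice([4, 8, 16, 32, 64, 128, 256]),\n",
    "    \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "    \"batch_size\": tune.choice([2, 4, 8, 16]),\n",
    "}"
   ]
  },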
  {
   "cell_type": "markdown",
   "id": "20af95cc",
   "metadata": {},
   "source": [
    "The `tune.sample_from()` function makes it possible to define your own sampling\n",
    "methods to obtain hyperparameters. In this example, the `l1` and `l2` parameters\n",
    "should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256.\n",
    "The `lr` (learning rate) is sampled log-uniformly between 0.0001 and 0.1. Lastly,\n",
    "the batch size is a choice between 2, 4, 8, and 16.\n",
    "\n",
    "For each trial, Tune will now randomly sample a combination of parameters from these\n",
    "search spaces. It will then train a number of models in parallel and find the best\n",
    "performing one among them. We also use the `ASHAScheduler`, which terminates poorly\n",
    "performing trials early.\n",
    "\n",
    "Via `resources_per_trial` you can specify the number of CPUs per trial, which are then\n",
    "available, e.g., to increase the `num_workers` of the PyTorch `DataLoader` instances. The\n",
    "selected number of GPUs is made visible to PyTorch in each trial. Trials do not have\n",
    "access to GPUs that haven't been requested for them - so you don't have to worry about\n",
    "two trials using the same set of resources.\n",
    "\n",
    "Here we can also specify fractional GPUs, so something like `gpus_per_trial=0.5` is\n",
    "completely valid. The trials will then share GPUs among each other.\n",
    "You just have to make sure that the models still fit in the GPU memory\n",
    "(see the short sketch after the main function below).\n",
    "\n",
    "After training the models, we will find the best performing one and load the trained\n",
    "network from the checkpoint file. We then obtain the test set accuracy and report\n",
    "everything by printing.\n",
    "\n",
    "The full main function looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d83380",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):\n",
    "    config = {\n",
    "        \"l1\": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),\n",
    "        \"l2\": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),\n",
    "        \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "        \"batch_size\": tune.choice([2, 4, 8, 16])\n",
    "    }\n",
    "    scheduler = ASHAScheduler(\n",
    "        max_t=max_num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2)\n",
    "    result = tune.run(\n",
    "        tune.with_parameters(train_cifar),\n",
    "        resources_per_trial={\"cpu\": 2, \"gpu\": gpus_per_trial},\n",
    "        config=config,\n",
    "        metric=\"loss\",\n",
    "        mode=\"min\",\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler\n",
    "    )\n",
    "\n",
    "    best_trial = result.get_best_trial(\"loss\", \"min\", \"last\")\n",
    "    print(\"Best trial config: {}\".format(best_trial.config))\n",
    "    print(\"Best trial final validation loss: {}\".format(\n",
    "        best_trial.last_result[\"loss\"]))\n",
    "    print(\"Best trial final validation accuracy: {}\".format(\n",
    "        best_trial.last_result[\"accuracy\"]))\n",
    "\n",
    "    if ray.util.client.ray.is_connected():\n",
    "        # If using Ray Client, we want to make sure checkpoint access\n",
    "        # happens on the server. So we wrap `test_best_model` in a Ray task.\n",
    "        # We have to make sure it gets executed on the same node that\n",
    "        # ``tune.run`` is called on.\n",
    "        from ray.util.ml_utils.node import force_on_current_node\n",
    "        remote_fn = force_on_current_node(ray.remote(test_best_model))\n",
    "        ray.get(remote_fn.remote(best_trial))\n",
    "    else:\n",
    "        test_best_model(best_trial)\n",
    "\n",
    "main(num_samples=2, max_num_epochs=2, gpus_per_trial=0)"
   ]
  },
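  {
   "cell_type": "markdown",
   "id": "0a4c7e39",
   "metadata": {},
   "source": [
    "As a side note on the fractional GPUs mentioned above, here is a sketch of what the\n",
    "`tune.run` call would look like with half a GPU per trial (added for illustration and not\n",
    "executed here, since this run used `gpus_per_trial=0`; everything else mirrors `main()`):\n",
    "\n",
    "```python\n",
    "result = tune.run(\n",
    "    tune.with_parameters(train_cifar),\n",
    "    resources_per_trial={\"cpu\": 2, \"gpu\": 0.5},  # two trials share one GPU\n",
    "    config=config,\n",
    "    metric=\"loss\",\n",
    "    mode=\"min\",\n",
    "    num_samples=num_samples,\n",
    "    scheduler=scheduler,\n",
    ")\n",
    "```"
   ]
  },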
  {
   "cell_type": "markdown",
   "id": "b702b4ce",
   "metadata": {},
   "source": [
    "If you run the code, an example output could look like this:\n",
    "\n",
    "```{code-block} bash\n",
    ":emphasize-lines: 7\n",
    "\n",
    " Number of trials: 10 (10 TERMINATED)\n",
    " +-------------------------+------------+-----+-----+-----+-------------+------------+---------+----------+---------------------+\n",
    " | Trial name              | status     | loc | l1  | l2  | lr          | batch_size | loss    | accuracy | training_iteration |\n",
    " |-------------------------+------------+-----+-----+-----+-------------+------------+---------+----------+---------------------|\n",
    " | train_cifar_87d1f_00000 | TERMINATED |     |  64 |   4 |  0.00011629 |          2 | 1.87273 |    0.244 |                   2 |\n",
    " | train_cifar_87d1f_00001 | TERMINATED |     |  32 |  64 | 0.000339763 |          8 | 1.23603 |    0.567 |                   8 |\n",
    " | train_cifar_87d1f_00002 | TERMINATED |     |   8 |  16 |  0.00276249 |         16 |  1.1815 |   0.5836 |                  10 |\n",
    " | train_cifar_87d1f_00003 | TERMINATED |     |   4 |  64 | 0.000648721 |          4 | 1.31131 |   0.5224 |                   8 |\n",
    " | train_cifar_87d1f_00004 | TERMINATED |     |  32 |  16 | 0.000340753 |          8 | 1.26454 |   0.5444 |                   8 |\n",
    " | train_cifar_87d1f_00005 | TERMINATED |     |   8 |   4 | 0.000699775 |          8 | 1.99594 |   0.1983 |                   2 |\n",
    " | train_cifar_87d1f_00006 | TERMINATED |     | 256 |   8 |   0.0839654 |         16 |  2.3119 |   0.0993 |                   1 |\n",
    " | train_cifar_87d1f_00007 | TERMINATED |     |  16 | 128 |   0.0758154 |         16 | 2.33575 |   0.1327 |                   1 |\n",
    " | train_cifar_87d1f_00008 | TERMINATED |     |  16 |   8 |   0.0763312 |         16 | 2.31129 |   0.1042 |                   4 |\n",
    " | train_cifar_87d1f_00009 | TERMINATED |     | 128 |  16 | 0.000124903 |          4 | 2.26917 |   0.1945 |                   1 |\n",
    " +-------------------------+------------+-----+-----+-----+-------------+------------+---------+----------+---------------------+\n",
    "\n",
    "\n",
    " Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'}\n",
    " Best trial final validation loss: 1.1815014744281769\n",
    " Best trial final validation accuracy: 0.5836\n",
    " Best trial test set accuracy: 0.5806\n",
    "```\n",
    "\n",
    "As you can see, most trials have been stopped early in order to avoid wasting resources.\n",
    "The best performing trial achieved a validation accuracy of about 58%, which could\n",
    "be confirmed on the test set.\n",
    "\n",
    "So that's it! You can now tune the parameters of your PyTorch models.\n",
    "\n",
    "## See More PyTorch Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/mnist_pytorch`: Converts the PyTorch MNIST example to use Tune with the function-based API.\n",
    "  Also shows how to easily convert something relying on argparse to use Tune.\n",
    "- {doc}`/tune/examples/includes/pbt_convnet_function_example`: Example training a ConvNet with checkpointing in function API.\n",
    "- {doc}`/tune/examples/includes/mnist_pytorch_trainable`: Converts the PyTorch MNIST example to use Tune with Trainable API.\n",
    "  Also uses the HyperBandScheduler and checkpoints the model at the end."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}