{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb2a049",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24af9556",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "# flake8: noqa"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70ac1fe8",
   "metadata": {},
   "source": [
    "# Model selection and serving with Ray Tune and Ray Serve\n",
    "\n",
    "```{image} /images/serve.svg\n",
    ":align: center\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "This tutorial shows you an end-to-end example of how to train a\n",
    "model using Ray Tune on incrementally arriving data and deploy\n",
    "the model using Ray Serve.\n",
    "\n",
    "A machine learning workflow can be quite simple: You decide on\n",
    "the objective you're trying to solve, collect and annotate the\n",
    "data, and build a model to hopefully solve your problem. But\n",
    "usually the work is not over yet. First, you would likely continue\n",
    "to do some hyperparameter optimization to obtain the best possible\n",
    "model (called *model selection*). Second, your trained model\n",
    "somehow has to be moved to production - in other words, users\n",
    "or services should be enabled to use your model to actually make\n",
    "predictions. This part is called *model serving*.\n",
    "\n",
    "Fortunately, Ray includes two libraries that help you with these\n",
    "two steps: Ray Tune and Ray Serve. Better yet, they complement\n",
    "each other nicely. Most notably, both are able to scale up your\n",
    "workloads easily - so both your model training and serving benefit\n",
    "from additional resources and can adapt to your environment. If you\n",
    "need to train on more data or have more hyperparameters to tune,\n",
    "Ray Tune can leverage your whole cluster for training. If you have\n",
    "many users doing inference on your served models, Ray Serve can\n",
    "automatically distribute the inference to multiple nodes.\n",
    "\n",
    "Concretely, we will train an MNIST image classifier on\n",
    "incrementally arriving data and automatically serve an updated\n",
    "model on an HTTP endpoint.\n",
    "\n",
    "By the end of this tutorial you will be able to\n",
    "\n",
    "1. Do hyperparameter optimization on a simple MNIST classifier\n",
    "2. Continue to train this classifier from an existing model with\n",
    "   newly arriving data\n",
    "3. Automatically create and serve model deployments with Ray Serve\n",
    "\n",
    "## Roadmap and desired functionality\n",
    "\n",
    "The general idea of this example is that we simulate newly arriving\n",
    "data each day. At day 0 we might already have some initial data\n",
    "available, and on each following day, new data arrives.\n",
    "\n",
    "We offer two ways to train: from scratch and from an existing model.\n",
    "Maybe you would like to train and select models from scratch each\n",
    "week with all data available until then, e.g. each Sunday, like this:\n",
    "\n",
    "```{code-block} bash\n",
    "# Train with all data available at day 0\n",
    "python tune-serve-integration-mnist.py --from_scratch --day 0\n",
    "```\n",
    "\n",
    "On the other days you might want to improve your model without\n",
    "training everything from scratch, saving some cluster resources.\n",
    "\n",
    "```{code-block} bash\n",
    "# Train with data arriving between day 0 and day 1\n",
    "python tune-serve-integration-mnist.py --from_existing --day 1\n",
    "# Train with incremental data on the other days, too\n",
    "python tune-serve-integration-mnist.py --from_existing --day 2\n",
    "python tune-serve-integration-mnist.py --from_existing --day 3\n",
    "python tune-serve-integration-mnist.py --from_existing --day 4\n",
    "python tune-serve-integration-mnist.py --from_existing --day 5\n",
    "python tune-serve-integration-mnist.py --from_existing --day 6\n",
    "# Retrain from scratch every 7th day:\n",
    "python tune-serve-integration-mnist.py --from_scratch --day 7\n",
    "```\n",
    "\n",
    "This example supports both modes. After each model selection run,\n",
    "we tell Ray Serve to serve an updated model. We also include a\n",
    "small utility to query the served model to see if it works as it should.\n",
    "\n",
    "```{code-block} bash\n",
    "$ python tune-serve-integration-mnist.py --query 6\n",
    "Querying model with example #6. Label = 1, Response = 1, Correct = True\n",
    "```\n",
    "\n",
    "## Imports\n",
    "\n",
    "Let's start with our dependencies. Most of these should be familiar\n",
    "if you have worked with PyTorch before. The most notable import for\n",
    "Ray is the ``from ray import tune, serve`` statement, which covers\n",
    "almost everything we need from the Ray side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0376c9c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "from functools import partial\n",
    "from math import ceil\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import ray\n",
    "from ray import tune, serve\n",
    "from ray.serve.exceptions import RayServeException\n",
    "from ray.tune import CLIReporter\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "\n",
    "from torch.utils.data import random_split, Subset\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58eaafa1",
   "metadata": {},
   "source": [
    "## Data interface\n",
    "\n",
    "Let's start with a simulated data interface. This class acts as the\n",
    "interface between your training code and your database. We simulate\n",
    "new data arriving each day via a ``day`` parameter: calling\n",
    "``get_data(day=3)`` returns all data we received up to day 3.\n",
    "We also implement an incremental data method, so calling\n",
    "``get_incremental_data(day=3)`` returns all data collected\n",
    "between day 2 and day 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf4de94",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNISTDataInterface(object):\n",
    "    \"\"\"Data interface. Simulates that new data arrives every day.\"\"\"\n",
    "\n",
    "    def __init__(self, data_dir, max_days=10):\n",
    "        self.data_dir = data_dir\n",
    "        self.max_days = max_days\n",
    "\n",
    "        transform = transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        )\n",
    "        self.dataset = MNIST(\n",
    "            self.data_dir, train=True, download=True, transform=transform\n",
    "        )\n",
    "\n",
    "    def _get_day_slice(self, day=0):\n",
    "        if day < 0:\n",
    "            return 0\n",
    "        n = len(self.dataset)\n",
    "        # Start with 30% of the data, get more data each day\n",
    "        return min(n, ceil(n * (0.3 + 0.7 * day / self.max_days)))\n",
    "\n",
    "    def get_data(self, day=0):\n",
    "        \"\"\"Get complete normalized train and validation data to date.\"\"\"\n",
    "        end = self._get_day_slice(day)\n",
    "\n",
    "        available_data = Subset(self.dataset, list(range(end)))\n",
    "        train_n = int(0.8 * end)  # 80% train data, 20% validation data\n",
    "\n",
    "        return random_split(available_data, [train_n, end - train_n])\n",
    "\n",
    "    def get_incremental_data(self, day=0):\n",
    "        \"\"\"Get next normalized train and validation data day slice.\"\"\"\n",
    "        start = self._get_day_slice(day - 1)\n",
    "        end = self._get_day_slice(day)\n",
    "\n",
    "        available_data = Subset(self.dataset, list(range(start, end)))\n",
    "        train_n = int(0.8 * (end - start))  # 80% train data, 20% validation data\n",
    "\n",
    "        return random_split(available_data, [train_n, end - start - train_n])"
   ]
  },
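  {
   "cell_type": "markdown",
   "id": "5a1ce0f1",
   "metadata": {},
   "source": [
    "As a quick sanity check - this snippet is an illustrative addition,\n",
    "not part of the original script - we can replicate the\n",
    "``_get_day_slice`` arithmetic to see how the available data grows,\n",
    "assuming MNIST's 60,000 training images and ``max_days=10``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a1ce0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: mirrors MNISTDataInterface._get_day_slice without\n",
    "# downloading the dataset. We start with 30% of the 60,000 images and\n",
    "# add another 7% per day.\n",
    "n, max_days = 60000, 10\n",
    "for day in range(4):\n",
    "    end = min(n, ceil(n * (0.3 + 0.7 * day / max_days)))\n",
    "    print(f\"day {day}: {end} examples available in total\")"
   ]
  },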
  {
   "cell_type": "markdown",
   "id": "13612bb2",
   "metadata": {},
   "source": [
    "## PyTorch neural network classifier\n",
    "\n",
    "Next, we will introduce our PyTorch neural network model and the\n",
    "train and test functions. These are adapted directly from\n",
    "our {doc}`PyTorch MNIST example </tune/examples/includes/mnist_pytorch>`.\n",
    "We only introduced an additional neural network layer with a configurable\n",
    "layer size. This is not strictly needed to achieve good performance on\n",
    "MNIST, but it is useful for demonstrating scenarios where your\n",
    "hyperparameter search space affects the model complexity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c21aa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    def __init__(self, layer_size=192):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.layer_size = layer_size\n",
    "        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n",
    "        self.fc = nn.Linear(192, self.layer_size)\n",
    "        self.out = nn.Linear(self.layer_size, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 3))\n",
    "        x = x.view(-1, 192)\n",
    "        x = self.fc(x)\n",
    "        x = self.out(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "\n",
    "def train(model, optimizer, train_loader, device=None):\n",
    "    device = device or torch.device(\"cpu\")\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def test(model, data_loader, device=None):\n",
    "    device = device or torch.device(\"cpu\")\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(data_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            outputs = model(data)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += target.size(0)\n",
    "            correct += (predicted == target).sum().item()\n",
    "\n",
    "    return correct / total"
   ]
  },
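  {
   "cell_type": "markdown",
   "id": "6b2df0a1",
   "metadata": {},
   "source": [
    "Before we start tuning, it can help to sanity-check the tensor\n",
    "shapes. This short snippet (an illustrative addition, not part of\n",
    "the original script) pushes a random batch of MNIST-sized images\n",
    "through ``ConvNet``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2df0a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative shape check: a batch of 8 single-channel 28x28 images\n",
    "# should yield log-probabilities over the 10 digit classes.\n",
    "sanity_model = ConvNet(layer_size=64)\n",
    "dummy_batch = torch.randn(8, 1, 28, 28)\n",
    "print(sanity_model(dummy_batch).shape)  # expected: torch.Size([8, 10])"
   ]
  },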
  {
   "cell_type": "markdown",
   "id": "677ded46",
   "metadata": {},
   "source": [
    "## Tune trainable for model selection\n",
    "\n",
    "We'll now define our Tune trainable function. This function takes\n",
    "a ``config`` parameter containing the hyperparameters to train the\n",
    "model with, and starts a full training run. This means it takes\n",
    "care of creating the model and optimizer and repeatedly calls the\n",
    "``train`` function to train the model. It also reports the\n",
    "training progress back to Tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c29de4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_mnist(\n",
    "    config,\n",
    "    start_model=None,\n",
    "    checkpoint_dir=None,\n",
    "    num_epochs=10,\n",
    "    use_gpus=False,\n",
    "    data_fn=None,\n",
    "    day=0,\n",
    "):\n",
    "    # Create model\n",
    "    use_cuda = use_gpus and torch.cuda.is_available()\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "    model = ConvNet(layer_size=config[\"layer_size\"]).to(device)\n",
    "\n",
    "    # Create optimizer\n",
    "    optimizer = optim.SGD(\n",
    "        model.parameters(), lr=config[\"lr\"], momentum=config[\"momentum\"]\n",
    "    )\n",
    "\n",
    "    # Load checkpoint, or load start model if no checkpoint has been\n",
    "    # passed and a start model is specified\n",
    "    load_dir = None\n",
    "    if checkpoint_dir:\n",
    "        load_dir = checkpoint_dir\n",
    "    elif start_model:\n",
    "        load_dir = start_model\n",
    "\n",
    "    if load_dir:\n",
    "        model_state, optimizer_state = torch.load(os.path.join(load_dir, \"checkpoint\"))\n",
    "        model.load_state_dict(model_state)\n",
    "        optimizer.load_state_dict(optimizer_state)\n",
    "\n",
    "    # Get full training datasets\n",
    "    train_dataset, validation_dataset = data_fn(day=day)\n",
    "\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_dataset, batch_size=config[\"batch_size\"], shuffle=True\n",
    "    )\n",
    "\n",
    "    validation_loader = torch.utils.data.DataLoader(\n",
    "        validation_dataset, batch_size=config[\"batch_size\"], shuffle=True\n",
    "    )\n",
    "\n",
    "    for i in range(num_epochs):\n",
    "        train(model, optimizer, train_loader, device)\n",
    "        acc = test(model, validation_loader, device)\n",
    "        if i == num_epochs - 1:\n",
    "            with tune.checkpoint_dir(step=i) as checkpoint_dir:\n",
    "                torch.save(\n",
    "                    (model.state_dict(), optimizer.state_dict()),\n",
    "                    os.path.join(checkpoint_dir, \"checkpoint\"),\n",
    "                )\n",
    "            tune.report(mean_accuracy=acc, done=True)\n",
    "        else:\n",
    "            tune.report(mean_accuracy=acc)"
   ]
  },
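  {
   "cell_type": "markdown",
   "id": "7c3ef0b1",
   "metadata": {},
   "source": [
    "The checkpoint written by ``train_mnist`` is simply a tuple of the\n",
    "model and optimizer state dicts saved with ``torch.save``. Here is a\n",
    "minimal sketch of that round trip (an illustrative addition using a\n",
    "hypothetical ``/tmp/example_ckpt`` directory, not a path Tune creates):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c3ef0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: write and restore a checkpoint in the same format\n",
    "# train_mnist uses. /tmp/example_ckpt is a hypothetical directory.\n",
    "ckpt_dir = \"/tmp/example_ckpt\"\n",
    "os.makedirs(ckpt_dir, exist_ok=True)\n",
    "\n",
    "example_model = ConvNet(layer_size=64)\n",
    "example_optimizer = optim.SGD(example_model.parameters(), lr=0.01, momentum=0.5)\n",
    "torch.save(\n",
    "    (example_model.state_dict(), example_optimizer.state_dict()),\n",
    "    os.path.join(ckpt_dir, \"checkpoint\"),\n",
    ")\n",
    "\n",
    "# Restoring mirrors the load_dir branch in train_mnist.\n",
    "model_state, optimizer_state = torch.load(os.path.join(ckpt_dir, \"checkpoint\"))\n",
    "restored_model = ConvNet(layer_size=64)\n",
    "restored_model.load_state_dict(model_state)\n",
    "example_optimizer.load_state_dict(optimizer_state)"
   ]
  },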
  {
   "cell_type": "markdown",
   "id": "513f8db0",
   "metadata": {},
   "source": [
    "## Configuring the search space and starting Ray Tune\n",
    "\n",
    "We would like to support two modes of training the model: training\n",
    "a model from scratch, and continuing to train a model from an\n",
    "existing one.\n",
    "\n",
    "This is our function to train a number of models with different\n",
    "hyperparameters from scratch, i.e. from all data that is available\n",
    "until the given day. Our search space can thus also contain parameters\n",
    "that affect the model complexity (such as the layer size), since it\n",
    "does not have to be compatible with an existing model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82fcbf6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tune_from_scratch(num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0):\n",
    "    data_interface = MNISTDataInterface(\"~/data\", max_days=10)\n",
    "    num_examples = data_interface._get_day_slice(day)\n",
    "\n",
    "    config = {\n",
    "        \"batch_size\": tune.choice([16, 32, 64]),\n",
    "        \"layer_size\": tune.choice([32, 64, 128, 192]),\n",
    "        \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "        \"momentum\": tune.uniform(0.1, 0.9),\n",
    "    }\n",
    "\n",
    "    scheduler = ASHAScheduler(\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        max_t=num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2,\n",
    "    )\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=[\"layer_size\", \"lr\", \"momentum\", \"batch_size\"],\n",
    "        metric_columns=[\"mean_accuracy\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        partial(\n",
    "            train_mnist,\n",
    "            start_model=None,\n",
    "            data_fn=data_interface.get_data,\n",
    "            num_epochs=num_epochs,\n",
    "            use_gpus=True if gpus_per_trial > 0 else False,\n",
    "            day=day,\n",
    "        ),\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose=0,\n",
    "        name=\"tune_serve_mnist_fromscratch\",\n",
    "    )\n",
    "\n",
    "    best_trial = analysis.get_best_trial(\"mean_accuracy\", \"max\", \"last\")\n",
    "    best_accuracy = best_trial.metric_analysis[\"mean_accuracy\"][\"last\"]\n",
    "    best_trial_config = best_trial.config\n",
    "    best_checkpoint = best_trial.checkpoint.value\n",
    "\n",
    "    return best_accuracy, best_trial_config, best_checkpoint, num_examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f051b634",
   "metadata": {},
   "source": [
    "To continue training from an existing model, we can use this function\n",
    "instead. It takes a starting model (a checkpoint) and the old config\n",
    "as parameters.\n",
    "\n",
    "Note that this time the search space does _not_ contain the\n",
    "layer size parameter. Since we continue to train an existing model,\n",
    "we cannot change the layer size mid-training, so we just continue\n",
    "to use the existing one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b26451",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tune_from_existing(\n",
    "    start_model, start_config, num_samples=10, num_epochs=10, gpus_per_trial=0.0, day=0\n",
    "):\n",
    "    data_interface = MNISTDataInterface(\"/tmp/mnist_data\", max_days=10)\n",
    "    num_examples = data_interface._get_day_slice(day) - data_interface._get_day_slice(\n",
    "        day - 1\n",
    "    )\n",
    "\n",
    "    config = start_config.copy()\n",
    "    config.update(\n",
    "        {\n",
    "            \"batch_size\": tune.choice([16, 32, 64]),\n",
    "            \"lr\": tune.loguniform(1e-4, 1e-1),\n",
    "            \"momentum\": tune.uniform(0.1, 0.9),\n",
    "        }\n",
    "    )\n",
    "\n",
    "    scheduler = ASHAScheduler(\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        max_t=num_epochs,\n",
    "        grace_period=1,\n",
    "        reduction_factor=2,\n",
    "    )\n",
    "\n",
    "    reporter = CLIReporter(\n",
    "        parameter_columns=[\"lr\", \"momentum\", \"batch_size\"],\n",
    "        metric_columns=[\"mean_accuracy\", \"training_iteration\"],\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        partial(\n",
    "            train_mnist,\n",
    "            start_model=start_model,\n",
    "            data_fn=data_interface.get_incremental_data,\n",
    "            num_epochs=num_epochs,\n",
    "            use_gpus=True if gpus_per_trial > 0 else False,\n",
    "            day=day,\n",
    "        ),\n",
    "        resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
    "        config=config,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        progress_reporter=reporter,\n",
    "        verbose=0,\n",
    "        name=\"tune_serve_mnist_fromexisting\",\n",
    "    )\n",
    "\n",
    "    best_trial = analysis.get_best_trial(\"mean_accuracy\", \"max\", \"last\")\n",
    "    best_accuracy = best_trial.metric_analysis[\"mean_accuracy\"][\"last\"]\n",
    "    best_trial_config = best_trial.config\n",
    "    best_checkpoint = best_trial.checkpoint.value\n",
    "\n",
    "    return best_accuracy, best_trial_config, best_checkpoint, num_examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a25629c1",
   "metadata": {},
   "source": [
    "## Serving tuned models with Ray Serve\n",
    "\n",
    "Let's now turn to the model serving part with Ray Serve. Serve allows\n",
    "you to deploy your models as multiple _deployments_. Broadly speaking,\n",
    "a deployment handles incoming requests and replies with a result. For\n",
    "instance, our MNIST deployment takes an image as input and outputs the\n",
    "digit it recognized from it. This deployment can be exposed over HTTP.\n",
    "\n",
    "First, we will define our deployment. This loads our PyTorch\n",
    "MNIST model from a checkpoint, takes an image as an input and\n",
    "outputs our digit prediction according to our trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d6a4ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "@serve.deployment(name=\"mnist\", route_prefix=\"/mnist\")\n",
    "class MNISTDeployment:\n",
    "    def __init__(self, checkpoint_dir, config, metrics, use_gpu=False):\n",
    "        self.checkpoint_dir = checkpoint_dir\n",
    "        self.config = config\n",
    "        self.metrics = metrics\n",
    "\n",
    "        use_cuda = use_gpu and torch.cuda.is_available()\n",
    "        self.device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "        model = ConvNet(layer_size=self.config[\"layer_size\"]).to(self.device)\n",
    "\n",
    "        model_state, optimizer_state = torch.load(\n",
    "            os.path.join(self.checkpoint_dir, \"checkpoint\"), map_location=self.device\n",
    "        )\n",
    "        model.load_state_dict(model_state)\n",
    "\n",
    "        self.model = model\n",
    "\n",
    "    def __call__(self, flask_request):\n",
    "        images = torch.tensor(flask_request.json[\"images\"])\n",
    "        images = images.to(self.device)\n",
    "        outputs = self.model(images)\n",
    "        predicted = torch.max(outputs.data, 1)[1]\n",
    "        return {\"result\": predicted.cpu().numpy().tolist()}"
   ]
  },
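  {
   "cell_type": "markdown",
   "id": "8d4ff0c1",
   "metadata": {},
   "source": [
    "For reference, the deployment expects a JSON payload with an\n",
    "``images`` key containing a list of images as nested lists. A minimal\n",
    "query sketch (an illustrative addition; it assumes a model has already\n",
    "been deployed via this tutorial on ``localhost:8000``):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4ff0c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "# Illustrative query: one blank 1x28x28 image, matching the\n",
    "# (batch, channel, height, width) layout the model expects.\n",
    "blank_image = torch.zeros(1, 28, 28)\n",
    "payload = {\"images\": [blank_image.numpy().tolist()]}\n",
    "response = requests.post(\"http://localhost:8000/mnist\", json=payload)\n",
    "print(response.json())  # e.g. {\"result\": [5]} - the predicted digit"
   ]
  },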
  {
   "cell_type": "markdown",
   "id": "2ba14c4a",
   "metadata": {},
   "source": [
    "We would like to have a fixed location where we store the currently\n",
    "active model. We call this directory ``model_dir``. Every time we\n",
    "would like to update our model, we copy the checkpoint of the new\n",
    "model to this directory. We then update the deployment to the new version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba77923",
   "metadata": {},
   "outputs": [],
   "source": [
    "def serve_new_model(model_dir, checkpoint, config, metrics, day, use_gpu=False):\n",
    "    print(\"Serving checkpoint: {}\".format(checkpoint))\n",
    "\n",
    "    checkpoint_path = _move_checkpoint_to_model_dir(\n",
    "        model_dir, checkpoint, config, metrics\n",
    "    )\n",
    "\n",
    "    serve.start(detached=True)\n",
    "    MNISTDeployment.deploy(checkpoint_path, config, metrics, use_gpu)\n",
    "\n",
    "\n",
    "def _move_checkpoint_to_model_dir(model_dir, checkpoint, config, metrics):\n",
    "    \"\"\"Move backend checkpoint to a central `model_dir` on the head node.\n",
    "    If you would like to run Serve on multiple nodes, you might want to\n",
    "    move the checkpoint to a shared storage, like Amazon S3, instead.\"\"\"\n",
    "    os.makedirs(model_dir, 0o755, exist_ok=True)\n",
    "\n",
    "    checkpoint_path = os.path.join(model_dir, \"checkpoint\")\n",
    "    meta_path = os.path.join(model_dir, \"meta.json\")\n",
    "\n",
    "    if os.path.exists(checkpoint_path):\n",
    "        shutil.rmtree(checkpoint_path)\n",
    "\n",
    "    shutil.copytree(checkpoint, checkpoint_path)\n",
    "\n",
    "    with open(meta_path, \"wt\") as fp:\n",
    "        json.dump(dict(config=config, metrics=metrics), fp)\n",
    "\n",
    "    return checkpoint_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f779c7bc",
   "metadata": {},
   "source": [
    "Since we would like to continue training from the currently served\n",
    "model, we introduce a utility function that fetches the currently\n",
    "served checkpoint as well as the hyperparameter config and achieved\n",
    "accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005f2787",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_current_model(model_dir):\n",
    "    checkpoint_path = os.path.join(model_dir, \"checkpoint\")\n",
    "    meta_path = os.path.join(model_dir, \"meta.json\")\n",
    "\n",
    "    if not os.path.exists(checkpoint_path) or not os.path.exists(meta_path):\n",
    "        return None, None, None\n",
    "\n",
    "    with open(meta_path, \"rt\") as fp:\n",
    "        meta = json.load(fp)\n",
    "\n",
    "    return checkpoint_path, meta[\"config\"], meta[\"metrics\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c55a5d3",
   "metadata": {},
   "source": [
    "## Putting everything together\n",
    "\n",
    "Now we only need to glue this code together. This is the main\n",
    "entrypoint of the script, which supports three modes:\n",
    "\n",
    "1. Train a new model from scratch with all data\n",
    "2. Continue training from an existing model with new data only\n",
    "3. Query the model with test data\n",
    "\n",
    "Internally, this just calls the ``tune_from_scratch()`` and\n",
    "``tune_from_existing()`` functions.\n",
    "Both training functions then call ``serve_new_model()`` to serve\n",
    "the newly trained or updated model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053bbbfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The query function will send an HTTP request to Serve with some\n",
    "# test data obtained from the MNIST dataset.\n",
    "if __name__ == \"__main__\":\n",
    "    \"\"\"\n",
    "    This script offers training a new model from scratch with all\n",
    "    available data, or continuing to train an existing model\n",
    "    with newly available data.\n",
    "\n",
    "    For instance, we might get new data every day. Every Sunday, we\n",
    "    would like to train a new model from scratch.\n",
    "\n",
    "    Naturally, we would like to use hyperparameter optimization to\n",
    "    find the best model for our data.\n",
    "\n",
    "    First, we might train a model with all data available at this day:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --from_scratch --day 0\n",
    "    ```\n",
    "\n",
    "    On the following days, we want to continue to train this model with\n",
    "    newly available data:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 1\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 2\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 3\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 4\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 5\n",
    "    python tune-serve-integration-mnist.py --from_existing --day 6\n",
    "    # Retrain from scratch every 7th day:\n",
    "    python tune-serve-integration-mnist.py --from_scratch --day 7\n",
    "    ```\n",
    "\n",
    "    We can also use this script to query our served model\n",
    "    with some test data:\n",
    "\n",
    "    ```{code-block} bash\n",
    "    python tune-serve-integration-mnist.py --query 6\n",
    "    Querying model with example #6. Label = 1, Response = 1, Correct = True\n",
    "    python tune-serve-integration-mnist.py --query 28\n",
    "    Querying model with example #28. Label = 2, Response = 7, Correct = False\n",
    "    ```\n",
    "\n",
    "    \"\"\"\n",
    "    parser = argparse.ArgumentParser(description=\"MNIST Tune/Serve example\")\n",
    "    parser.add_argument(\"--model_dir\", type=str, default=\"~/mnist_tune_serve\")\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--from_scratch\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Train and select best model from scratch\",\n",
    "        default=True,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--from_existing\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Train and select best model from existing model\",\n",
    "        default=False,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--day\",\n",
    "        help=\"Indicate the day to simulate the amount of data available to us\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--query\", help=\"Query endpoint with example\", type=int, default=-1\n",
    "    )\n",
    "\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Finish quickly for testing\",\n",
    "        default=True,\n",
    "    )\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    if args.smoke_test:\n",
    "        ray.init(num_cpus=3, namespace=\"tune-serve-integration\")\n",
    "    else:\n",
    "        ray.init(namespace=\"tune-serve-integration\")\n",
    "\n",
    "    model_dir = os.path.expanduser(args.model_dir)\n",
    "\n",
    "    if args.query >= 0:\n",
    "        import requests\n",
    "\n",
    "        dataset = MNISTDataInterface(\"/tmp/mnist_data\", max_days=0).dataset\n",
    "        data = dataset[args.query]\n",
    "        label = data[1]\n",
    "\n",
    "        # Query our model\n",
    "        response = requests.post(\n",
    "            \"http://localhost:8000/mnist\", json={\"images\": [data[0].numpy().tolist()]}\n",
    "        )\n",
    "\n",
    "        try:\n",
    "            pred = response.json()[\"result\"][0]\n",
    "        except:  # noqa: E722\n",
    "            pred = -1\n",
    "\n",
    "        print(\n",
    "            \"Querying model with example #{}. \"\n",
    "            \"Label = {}, Response = {}, Correct = {}\".format(\n",
    "                args.query, label, pred, label == pred\n",
    "            )\n",
    "        )\n",
    "        sys.exit(0)\n",
    "\n",
    "    gpus_per_trial = 0.5 if not args.smoke_test else 0.0\n",
    "    serve_gpu = True if gpus_per_trial > 0 else False\n",
    "    num_samples = 8 if not args.smoke_test else 1\n",
    "    num_epochs = 10 if not args.smoke_test else 1\n",
    "\n",
    "    if args.from_scratch:  # train every day from scratch\n",
    "        print(\"Start training job from scratch on day {}.\".format(args.day))\n",
    "        acc, config, best_checkpoint, num_examples = tune_from_scratch(\n",
    "            num_samples, num_epochs, gpus_per_trial, day=args.day\n",
    "        )\n",
    "        print(\n",
    "            \"Trained day {} from scratch on {} samples. \"\n",
    "            \"Best accuracy: {:.4f}. Best config: {}\".format(\n",
    "                args.day, num_examples, acc, config\n",
    "            )\n",
    "        )\n",
    "        serve_new_model(\n",
    "            model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu\n",
    "        )\n",
    "\n",
    "    if args.from_existing:\n",
    "        old_checkpoint, old_config, old_acc = get_current_model(model_dir)\n",
    "        if not old_checkpoint or not old_config or not old_acc:\n",
    "            print(\"No existing model found. Train one with --from_scratch first.\")\n",
    "            sys.exit(1)\n",
    "        acc, config, best_checkpoint, num_examples = tune_from_existing(\n",
    "            old_checkpoint,\n",
    "            old_config,\n",
    "            num_samples,\n",
    "            num_epochs,\n",
    "            gpus_per_trial,\n",
    "            day=args.day,\n",
    "        )\n",
    "        print(\n",
    "            \"Trained day {} from existing on {} samples. \"\n",
    "            \"Best accuracy: {:.4f}. Best config: {}\".format(\n",
    "                args.day, num_examples, acc, config\n",
    "            )\n",
    "        )\n",
    "        serve_new_model(\n",
    "            model_dir, best_checkpoint, config, acc, args.day, use_gpu=serve_gpu\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c8be26a",
   "metadata": {},
   "source": [
    "That's it! We now have an end-to-end workflow to train and update a\n",
    "model every day with newly arrived data. Every week we might retrain\n",
    "the whole model. At every point in time we make sure to serve the\n",
    "model that achieved the best validation set accuracy.\n",
    "\n",
    "There are some ways we might extend this example. For instance, right\n",
    "now we only serve the latest trained model. We could also choose to\n",
    "route only a certain percentage of users to the new model, maybe to\n",
    "see if the new model really does its job right. This kind of\n",
    "deployment is called a canary deployment.\n",
    "It would also require us to keep more than one\n",
    "model in our ``model_dir`` - which should be quite easy: We could just\n",
    "create subdirectories for each training day.\n",
    "\n",
    "Still, this example should show you how easy it is to integrate the\n",
    "Ray libraries Ray Tune and Ray Serve in your workflow. While both tools\n",
    "also work independently of each other, they complement each other\n",
    "nicely and support a large number of use cases."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}