ray/doc/source/ray-air/examples/torch_incremental_learning.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "TsniIjjg2Pym"
},
"source": [
"*This example is adapted from Continual AI Avalanche quick start https://avalanche.continualai.org/*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VsUrzVm1W-h"
},
"source": [
"# Incremental Learning with Ray AIR\n",
"\n",
"In this example, we show how to use Ray AIR to incrementally train a simple image classification PyTorch model\n",
"on a stream of incoming tasks.\n",
"\n",
"Each task is a random permutation of the MNIST Dataset, which is a common benchmark\n",
"used for continual training. After training on all the\n",
"tasks, the model is expected to be able to make predictions on data from any task.\n",
"\n",
"In this example, we use just a naive finetuning strategy, where the model is trained\n",
"on each task, without any special methods to prevent [catastrophic forgetting](\n",
"https://en.wikipedia.org/wiki/Catastrophic_interference). Model performance is\n",
"expected to be poor.\n",
"\n",
"More precisely, this example showcases domain incremental training, in which during\n",
"prediction/testing\n",
"time, the model is asked to predict on data from tasks trained on so far with the\n",
"task ID not provided. This is opposed to task incremental training, where the task ID is\n",
"provided during prediction/testing time.\n",
"\n",
"For more information on the 3 different categories for incremental/continual\n",
"learning, please see [\"Three scenarios for continual learning\" by van de Ven and Tolias](https://arxiv.org/pdf/1904.07734.pdf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q3oGiuqYfj9_"
},
"source": [
"This example will cover the following:\n",
"1. Loading a PyTorch Dataset to Ray Datasets\n",
"2. Create an `Iterator[ray.data.Datasets]` abstraction to represent a stream of data to train on for incremental training.\n",
"3. Implement a custom Ray AIR preprocessor to preprocess the Dataset.\n",
"4. Incrementally train a model using data parallel training.\n",
"5. Use our trained model to perform batch prediction on test data.\n",
"6. Incrementally deploying our trained model with Ray Serve and performing online prediction queries."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z52Y8O4q1bIk"
},
"source": [
"# Step 1: Installations and Initializing Ray\n",
"\n",
"To get started, let's first install the necessary packages: Ray AIR, torch, and torchvision. Uncomment the below lines and run the cell to install the necessary packages."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kWr6BRMk1Y1j",
"outputId": "dad49a31-a602-4e44-b5fe-932de603925e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: ray[data,serve,tune] in /usr/local/lib/python3.7/dist-packages (2.0.0.dev0)\n",
"Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.21.6)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (3.7.0)\n",
"Requirement already satisfied: grpcio!=1.44.0,>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.43.0)\n",
"Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.0.3)\n",
"Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (3.17.3)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.2.0)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.3.0)\n",
"Requirement already satisfied: virtualenv in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (20.14.1)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (3.13)\n",
"Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (21.4.0)\n",
"Requirement already satisfied: click<=8.0.4,>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (7.1.2)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (2.23.0)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (4.3.3)\n",
"Requirement already satisfied: pyarrow<7.0.0,>=6.0.1 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (6.0.1)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (2022.5.0)\n",
"Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.3.5)\n",
"Requirement already satisfied: tensorboardX>=1.9 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (2.5)\n",
"Requirement already satisfied: tabulate in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.8.9)\n",
"Requirement already satisfied: aiorwlock in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.3.0)\n",
"Requirement already satisfied: starlette in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.19.1)\n",
"Requirement already satisfied: prometheus-client<0.14.0,>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.13.1)\n",
"Requirement already satisfied: py-spy>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.3.12)\n",
"Requirement already satisfied: smart-open in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (6.0.0)\n",
"Requirement already satisfied: gpustat>=1.0.0b1 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (1.0.0b1)\n",
"Requirement already satisfied: colorful in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.5.4)\n",
"Requirement already satisfied: aiohttp>=3.7 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (3.8.1)\n",
"Requirement already satisfied: fastapi in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.78.0)\n",
"Requirement already satisfied: aiohttp-cors in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.7.0)\n",
"Requirement already satisfied: uvicorn==0.16.0 in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.16.0)\n",
"Requirement already satisfied: opencensus in /usr/local/lib/python3.7/dist-packages (from ray[data,serve,tune]) (0.9.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from uvicorn==0.16.0->ray[data,serve,tune]) (4.2.0)\n",
"Requirement already satisfied: asgiref>=3.4.0 in /usr/local/lib/python3.7/dist-packages (from uvicorn==0.16.0->ray[data,serve,tune]) (3.5.2)\n",
"Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.7/dist-packages (from uvicorn==0.16.0->ray[data,serve,tune]) (0.13.0)\n",
"Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp>=3.7->ray[data,serve,tune]) (0.13.0)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp>=3.7->ray[data,serve,tune]) (2.0.12)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp>=3.7->ray[data,serve,tune]) (6.0.2)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp>=3.7->ray[data,serve,tune]) (4.0.2)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp>=3.7->ray[data,serve,tune]) (1.7.2)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from gpustat>=1.0.0b1->ray[data,serve,tune]) (5.4.8)\n",
"Requirement already satisfied: six>=1.7 in /usr/local/lib/python3.7/dist-packages (from gpustat>=1.0.0b1->ray[data,serve,tune]) (1.15.0)\n",
"Requirement already satisfied: blessed>=1.17.1 in /usr/local/lib/python3.7/dist-packages (from gpustat>=1.0.0b1->ray[data,serve,tune]) (1.19.1)\n",
"Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.7/dist-packages (from gpustat>=1.0.0b1->ray[data,serve,tune]) (7.352.0)\n",
"Requirement already satisfied: wcwidth>=0.1.4 in /usr/local/lib/python3.7/dist-packages (from blessed>=1.17.1->gpustat>=1.0.0b1->ray[data,serve,tune]) (0.2.5)\n",
"Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.7/dist-packages (from yarl<2.0,>=1.0->aiohttp>=3.7->ray[data,serve,tune]) (2.10)\n",
"Requirement already satisfied: pydantic!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0,>=1.6.2 in /usr/local/lib/python3.7/dist-packages (from fastapi->ray[data,serve,tune]) (1.9.1)\n",
"Requirement already satisfied: anyio<5,>=3.4.0 in /usr/local/lib/python3.7/dist-packages (from starlette->ray[data,serve,tune]) (3.6.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.7/dist-packages (from anyio<5,>=3.4.0->starlette->ray[data,serve,tune]) (1.2.0)\n",
"Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[data,serve,tune]) (5.7.1)\n",
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[data,serve,tune]) (0.18.1)\n",
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray[data,serve,tune]) (4.11.3)\n",
"Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from importlib-resources>=1.4.0->jsonschema->ray[data,serve,tune]) (3.8.0)\n",
"Requirement already satisfied: google-api-core<3.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from opencensus->ray[data,serve,tune]) (1.31.5)\n",
"Requirement already satisfied: opencensus-context>=0.1.2 in /usr/local/lib/python3.7/dist-packages (from opencensus->ray[data,serve,tune]) (0.1.2)\n",
"Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (2022.1)\n",
"Requirement already satisfied: google-auth<2.0dev,>=1.25.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (1.35.0)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (1.56.1)\n",
"Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (57.4.0)\n",
"Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (21.3)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.25.0->google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (0.2.8)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.25.0->google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (4.8)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2.0dev,>=1.25.0->google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (4.2.4)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (3.0.9)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2.0dev,>=1.25.0->google-api-core<3.0.0,>=1.0.0->opencensus->ray[data,serve,tune]) (0.4.8)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray[data,serve,tune]) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray[data,serve,tune]) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray[data,serve,tune]) (2022.5.18.1)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->ray[data,serve,tune]) (2.8.2)\n",
"Requirement already satisfied: distlib<1,>=0.3.1 in /usr/local/lib/python3.7/dist-packages (from virtualenv->ray[data,serve,tune]) (0.3.4)\n",
"Requirement already satisfied: platformdirs<3,>=2 in /usr/local/lib/python3.7/dist-packages (from virtualenv->ray[data,serve,tune]) (2.5.2)\n",
"Found existing installation: ray 2.0.0.dev0\n",
"Uninstalling ray-2.0.0.dev0:\n",
" Successfully uninstalled ray-2.0.0.dev0\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting ray==3.0.0.dev0\n",
" Downloading https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl (54.9 MB)\n",
"\u001b[K |████████████████████████████████| 54.9 MB 74.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: msgpack<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (1.0.3)\n",
"Requirement already satisfied: virtualenv in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (20.14.1)\n",
"Requirement already satisfied: protobuf>=3.15.3 in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (3.17.3)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (3.13)\n",
"Requirement already satisfied: click<=8.0.4,>=7.0 in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (7.1.2)\n",
"Requirement already satisfied: attrs in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (21.4.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (2.23.0)\n",
"Requirement already satisfied: frozenlist in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (1.3.0)\n",
"Requirement already satisfied: aiosignal in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (1.2.0)\n",
"Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (1.21.6)\n",
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (4.3.3)\n",
"Requirement already satisfied: grpcio<=1.43.0,>=1.28.1 in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (1.43.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from ray==3.0.0.dev0) (3.7.0)\n",
"Requirement already satisfied: six>=1.5.2 in /usr/local/lib/python3.7/dist-packages (from grpcio<=1.43.0,>=1.28.1->ray==3.0.0.dev0) (1.15.0)\n",
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==3.0.0.dev0) (0.18.1)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==3.0.0.dev0) (4.2.0)\n",
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==3.0.0.dev0) (4.11.3)\n",
"Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema->ray==3.0.0.dev0) (5.7.1)\n",
"Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from importlib-resources>=1.4.0->jsonschema->ray==3.0.0.dev0) (3.8.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->ray==3.0.0.dev0) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->ray==3.0.0.dev0) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->ray==3.0.0.dev0) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->ray==3.0.0.dev0) (2022.5.18.1)\n",
"Requirement already satisfied: platformdirs<3,>=2 in /usr/local/lib/python3.7/dist-packages (from virtualenv->ray==3.0.0.dev0) (2.5.2)\n",
"Requirement already satisfied: distlib<1,>=0.3.1 in /usr/local/lib/python3.7/dist-packages (from virtualenv->ray==3.0.0.dev0) (0.3.4)\n",
"Installing collected packages: ray\n",
"Successfully installed ray-3.0.0.dev0\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.11.0+cu113)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch) (4.2.0)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (0.12.0+cu113)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision) (2.23.0)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (7.1.2)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.21.6)\n",
"Requirement already satisfied: torch==1.11.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.11.0+cu113)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torchvision) (4.2.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision) (2022.5.18.1)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision) (1.24.3)\n"
]
}
],
"source": [
"# !pip install -q \"ray[air]\"\n",
"# !pip install -q torch\n",
"# !pip install -q torchvision"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RpD4STX3g1dq"
},
"source": [
"Then, let's initialize Ray! We can just import and call `ray.init()`. If you are running on a Ray cluster, then you can do `ray.init(\"auto\")` to connect to the cluster instead of initiailzing a new local Ray instance."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "72fEFqL4T7iA",
"outputId": "9cae25f2-c712-4baa-f66b-337049e1b565"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:25:31,150\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
]
},
{
"data": {
"text/plain": [
"RayContext(dashboard_url='127.0.0.1:8265', python_version='3.7.13', ray_version='3.0.0.dev0', ray_commit='ac620aeec0c0f68c92328ace0b2a5835f5b14b26', address_info={'node_ip_address': '172.28.0.2', 'raylet_ip_address': '172.28.0.2', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-05-25_22-25-28_641559_1518/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-05-25_22-25-28_641559_1518/sockets/raylet', 'webui_url': '127.0.0.1:8265', 'session_dir': '/tmp/ray/session_2022-05-25_22-25-28_641559_1518', 'metrics_export_port': 61030, 'gcs_address': '172.28.0.2:62940', 'address': '172.28.0.2:62940', 'node_id': '97455d0de12f3393126427ed2b1ef0a009f0bd3fb97177cb86b42d92'})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ray\n",
"ray.init()\n",
"# If runnning on a cluster, use the below line instead.\n",
"# ray.init(\"auto\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AedcxD_FClQL"
},
"source": [
"# Step 2: Define our PyTorch Model\n",
"\n",
"Now that we have the necessary installations, let's define our PyTorch model. For this example to classify MNIST images, we will use a simple multi-layer perceptron."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "3TVkSmFFCHhI"
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class SimpleMLP(nn.Module):\n",
" def __init__(self, num_classes=10, input_size=28 * 28):\n",
" super(SimpleMLP, self).__init__()\n",
"\n",
" self.features = nn.Sequential(\n",
" nn.Linear(input_size, 512),\n",
" nn.ReLU(inplace=True),\n",
" nn.Dropout(),\n",
" )\n",
" self.classifier = nn.Linear(512, num_classes)\n",
" self._input_size = input_size\n",
"\n",
" def forward(self, x):\n",
" x = x.contiguous()\n",
" x = x.view(-1, self._input_size)\n",
" x = self.features(x)\n",
" x = self.classifier(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L2N1U22VC_N9"
},
"source": [
"# Step 3: Create the Stream of tasks\n",
"\n",
"We can now create a stream of tasks (where each task contains a dataset to train on). For this example, we will create an artificial stream of tasks consisting of\n",
"permuted variations of MNIST, which is a classic benchmark in continual learning\n",
"research.\n",
"\n",
"For real-world scenarios, this step is not necessary as fresh data will already be\n",
"arriving as a stream of tasks. It does not need to be artificially created."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3SVSrkqrDJuc"
},
"source": [
"## 3a: Load MNIST Dataset to a Ray Dataset\n",
"\n",
"Let's first define a simple function that will return the original MNIST Dataset as a distributed Ray Dataset. Ray Datasets are the standard way to load and exchange data in Ray libraries and applications, read more about them [here](https://docs.ray.io/en/latest/data/dataset.html)!\n",
"\n",
"The function in the below code snippet does the following:\n",
"1. Downloads the MNIST Dataset from torchvision in-memory\n",
"2. Loads the in-memory Torch Dataset into a Ray Dataset\n",
"3. Converts the Ray Dataset into a Pandas format. Instead of the Ray Dataset iterating over tuples, it will have 2 columns: \"image\" & \"label\". \n",
"<!-- TODO: Figure out when and how to use TensorArray extension -->\n",
"<!-- The image will be stored as a multi-dimensional tensor (via the [TensorArray format](https://docs.ray.io/en/latest/data/dataset-tensor-support.html) instead of a PIL image). -->\n",
"This will allow us to apply built-in preprocessors to the Ray Dataset and allow Ray Datasets to be used with Ray AIR Predictors.\n",
" <!-- and also means that any transformations done to the images can be done in a zero-copy fashion. -->\n",
"\n",
"For this example, since we are just working with MNIST dataset, which is small, we use the [`SimpleTorchDataSource`](https://docs.ray.io/en/master/data/package-ref.html?highlight=SimpleTorchDatasource#ray.data.datasource.SimpleTorchDatasource) which just loads the full MNIST dataset into memory.\n",
"\n",
"For loading larger datasets in a parallel fashion, you should use [Ray Dataset's additional read APIs](https://docs.ray.io/en/master/data/dataset.html#supported-input-formats) to load data from parquet, csv, image files, and more!"
]
},
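{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a larger dataset stored as Parquet could be read in parallel with one of those APIs. The bucket path below is just a hypothetical placeholder:\n",
"\n",
"```python\n",
"import ray\n",
"\n",
"# Hypothetical location; any Parquet/CSV/image data readable by Ray Datasets works here.\n",
"larger_dataset = ray.data.read_parquet(\"s3://my-bucket/my-larger-dataset/\")\n",
"```"
]
},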
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "0XKwJKrNCxg4"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"import torchvision\n",
"from torchvision.transforms import RandomCrop\n",
"\n",
"import ray\n",
"from ray.data.datasource.torch_datasource import SimpleTorchDatasource\n",
"\n",
"\n",
"def get_mnist_dataset(train: bool = True) -> ray.data.Dataset:\n",
" \"\"\"Returns MNIST Dataset as a ray.data.Dataset.\n",
" \n",
" Args:\n",
" train: Whether to return the train dataset or test dataset.\n",
" \"\"\"\n",
"\n",
" def mnist_dataset_factory():\n",
" if train:\n",
" # Only perform random cropping on the Train dataset.\n",
" transform = RandomCrop(28, padding=4)\n",
" else:\n",
" transform = None\n",
" return torchvision.datasets.MNIST(\"./data\", download=True, train=train, transform=transform)\n",
"\n",
" def convert_batch_to_pandas(batch):\n",
" images = [np.array(item[0]) for item in batch]\n",
" labels = [item[1] for item in batch]\n",
"\n",
" df = pd.DataFrame({\"image\": images, \"label\": labels})\n",
"\n",
" return df\n",
"\n",
" mnist_dataset = ray.data.read_datasource(\n",
" SimpleTorchDatasource(), dataset_factory=mnist_dataset_factory\n",
" )\n",
" mnist_dataset = mnist_dataset.map_batches(convert_batch_to_pandas)\n",
" return mnist_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vqrfgfl9YnVe"
},
"source": [
"## 3b: Create our Stream abstraction\n",
"\n",
"Now we can create our \"stream\" abstraction. This abstraction provides two\n",
"methods (`generate_train_stream` and `generate_test_stream`) that each returns an Iterator\n",
"over Ray Datasets. Each item in this iterator contains a unique permutation of\n",
"MNIST, and is one task that we want to train on.\n",
"\n",
"In this example, \"the stream of tasks\" is contrived since all the data for all tasks exist already in an offline setting. For true online continual learning, you would want to implement a custom dataset iterator that reads from some stream datasource to produce new tasks. The only abstraction that's needed is `Iterator[ray.data.Dataset]`.\n",
"\n",
"Note that the test dataset stream has the same permutations that are used for the training dataset stream. In general for continual learning, it is expected that the data distribution of the test/prediction data follows what the model was trained on. If you notice that the distribution of new prediction queries is changing compared to the distribution of the training data, then you should probably trigger training of a new task."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "f2EagMWCN3he"
},
"outputs": [],
"source": [
"from typing import Iterator, List\n",
"import random\n",
"import numpy as np\n",
"\n",
"from ray.data import ActorPoolStrategy\n",
"\n",
"\n",
"class PermutedMNISTStream:\n",
" \"\"\"Generates streams of permuted MNIST Datasets.\n",
" \n",
" Example:\n",
" \n",
" permuted_mnist = PermutedMNISTStream(n_tasks=3)\n",
" train_stream = permuted_mnist.generate_train_stream()\n",
" \n",
" # Iterate through the train_stream\n",
" for train_dataset in train_stream:\n",
" ...\n",
" \n",
" Args:\n",
" n_tasks: The number of tasks to generate.\n",
" \"\"\"\n",
"\n",
" def __init__(self, n_tasks: int = 3):\n",
" self.n_tasks = n_tasks\n",
" self.permutations = [\n",
" np.random.permutation(28 * 28) for _ in range(self.n_tasks)\n",
" ]\n",
"\n",
" self.train_mnist_dataset = get_mnist_dataset(train=True)\n",
" self.test_mnist_dataset = get_mnist_dataset(train=False)\n",
"\n",
" def random_permute_dataset(\n",
" self, dataset: ray.data.Dataset, permutation: np.ndarray\n",
" ):\n",
" \"\"\"Randomly permutes the pixels for each image in the dataset.\"\"\"\n",
"\n",
" class PixelsPermutation(object):\n",
" def __call__(self, batch):\n",
" batch[\"image\"] = batch[\"image\"].map(lambda image: image.reshape(-1)[permutation].reshape(28, 28))\n",
" return batch\n",
"\n",
" return dataset.map_batches(PixelsPermutation, compute=ActorPoolStrategy(), batch_format=\"pandas\")\n",
"\n",
" def generate_train_stream(self) -> Iterator[ray.data.Dataset]:\n",
" for permutation in self.permutations:\n",
" permuted_mnist_dataset = self.random_permute_dataset(\n",
" self.train_mnist_dataset, permutation\n",
" )\n",
" yield permuted_mnist_dataset\n",
"\n",
" def generate_test_stream(self) -> Iterator[ray.data.Dataset]:\n",
" for permutation in self.permutations:\n",
" mnist_dataset = get_mnist_dataset(train=False)\n",
" permuted_mnist_dataset = self.random_permute_dataset(\n",
" self.test_mnist_dataset, permutation\n",
" )\n",
" yield permuted_mnist_dataset\n",
"\n",
" def generate_test_samples(self, num_samples: int = 10) -> List[np.ndarray]:\n",
" \"\"\"Generates num_samples permuted MNIST images.\"\"\"\n",
" random_permutation = random.choice(self.permutations)\n",
" return self.random_permute_dataset(self.test_mnist_dataset.random_shuffle().limit(num_samples), random_permutation).to_pandas()[\"image\"].to_list()\n",
"\n"
]
},
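{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reference point, a production stream would usually not precompute every task up front. Below is a minimal sketch of a custom stream that satisfies the same `Iterator[ray.data.Dataset]` contract, assuming each incoming task has already been written out as its own Parquet directory. The `task_paths` argument and the paths in its docstring are hypothetical:\n",
"\n",
"```python\n",
"from typing import Iterator, List\n",
"\n",
"import ray\n",
"\n",
"\n",
"def generate_train_stream_from_storage(task_paths: List[str]) -> Iterator[ray.data.Dataset]:\n",
"    \"\"\"Yield one Ray Dataset per task directory as tasks arrive.\n",
"\n",
"    `task_paths` is a hypothetical list such as\n",
"    [\"s3://my-bucket/tasks/task-0/\", \"s3://my-bucket/tasks/task-1/\"].\n",
"    \"\"\"\n",
"    for path in task_paths:\n",
"        # Each task is read in parallel from storage instead of being\n",
"        # generated in memory like PermutedMNISTStream above.\n",
"        yield ray.data.read_parquet(path)\n",
"```"
]
},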
{
"cell_type": "markdown",
"metadata": {
"id": "HDGHgtb699kd"
},
"source": [
"# Step 4: Define the logic for Training and Inference/Prediction\n",
"\n",
"Now that we can get an Iterator over Ray Datasets, we can incrementally train our model in a data parallel fashion via Ray Train, while incrementally deploying our model via Ray Serve. Let's define some helper functions to allow us to do this!\n",
"\n",
"If you are not familiar with data parallel training, it is a form of distributed training strategies, where we have multiple model replicas, and each replica trains on a different batch of data. After each batch, the gradients are synchronized across the replicas. This effecitively allows us to train on more data in a shorter amount of time."
]
},
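{
"cell_type": "markdown",
"metadata": {},
"source": [
"How much we scale out is expressed through a scaling configuration that is passed to the Trainer in Step 5. As a minimal sketch, assuming the Ray AIR `ScalingConfig` API (the values here are illustrative, not the ones used later in this notebook):\n",
"\n",
"```python\n",
"from ray.air.config import ScalingConfig\n",
"\n",
"# Two data parallel workers, each holding its own model replica.\n",
"# Set use_gpu=True to assign one GPU to each worker.\n",
"scaling_config = ScalingConfig(num_workers=2, use_gpu=False)\n",
"```"
]
},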
{
"cell_type": "markdown",
"metadata": {
"id": "SBWxP1sP-G-o"
},
"source": [
"## 4a: Define our training logic for each Data Parallel worker\n",
"\n",
"The first thing we need to do is to define the training loop that will be run on each training worker. \n",
"\n",
"The training loop takes in a `config` Dict as an argument that we can use to pass in any configurations for training.\n",
"\n",
"This is just standard PyTorch training, with the difference being that we can leverage [Ray Train's utility functions](https://docs.ray.io/en/master/train/api.html#training-function-utilities) and [Ray AIR Sesssion](https://docs.ray.io/en/master/ray-air/package-ref.html#module-ray.air.session):\n",
"- `ray.train.torch.prepare_model(...)`: This will prepare the model for distributed training by wrapping it in PyTorch `DistributedDataParallel` and moving it to the correct accelerator device.\n",
"- `ray.air.session.get_dataset_shard(...)`: This will get the Ray Dataset shard for this particular Data Parallel worker.\n",
"- `ray.air.session.report({}, checkpoint=...)`: This will tell Ray Train to persist the provided `Checkpoint` object.\n",
"- `ray.air.session.get_checkpoint()`: Returns a checkpoint to resume from. This is useful for either fault tolerance purposes, or for our purposes, to continue training the same model on a new incoming dataset."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Y9IRDMec-GZ9"
},
"outputs": [],
"source": [
"from ray import train\n",
"from ray.air import session, Checkpoint\n",
"\n",
"from torch.optim import SGD\n",
"from torch.nn import CrossEntropyLoss\n",
"\n",
"from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present\n",
"\n",
"def train_loop_per_worker(config: dict):\n",
" num_epochs = config[\"num_epochs\"]\n",
" learning_rate = config[\"learning_rate\"]\n",
" momentum = config[\"momentum\"]\n",
" batch_size = config[\"batch_size\"]\n",
"\n",
" model = SimpleMLP(num_classes=10)\n",
"\n",
" # Load model from checkpoint if there is a checkpoint to load from.\n",
" checkpoint_to_load = session.get_checkpoint()\n",
" if checkpoint_to_load:\n",
" state_dict_to_resume_from = checkpoint_to_load.to_dict()[\"model\"]\n",
" model.load_state_dict(state_dict=state_dict_to_resume_from)\n",
"\n",
" model = train.torch.prepare_model(model)\n",
"\n",
" optimizer = SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
" criterion = CrossEntropyLoss()\n",
"\n",
" # Get the Ray Dataset shard for this data parallel worker, and convert it to a PyTorch Dataset.\n",
" dataset_shard = session.get_dataset_shard(\"train\").to_torch(\n",
" label_column=\"label\",\n",
" batch_size=batch_size,\n",
" unsqueeze_feature_tensors=False,\n",
" unsqueeze_label_tensor=False,\n",
" )\n",
"\n",
" for epoch_idx in range(num_epochs):\n",
" running_loss = 0\n",
" for iteration, (train_mb_x, train_mb_y) in enumerate(dataset_shard):\n",
" optimizer.zero_grad()\n",
" train_mb_x = train_mb_x.to(train.torch.get_device())\n",
" train_mb_y = train_mb_y.to(train.torch.get_device())\n",
"\n",
" # Forward\n",
" logits = model(train_mb_x)\n",
" # Loss\n",
" loss = criterion(logits, train_mb_y)\n",
" # Backward\n",
" loss.backward()\n",
" # Update\n",
" optimizer.step()\n",
"\n",
" running_loss += loss.item()\n",
" if session.get_world_rank() == 0 and iteration % 500 == 0:\n",
" print(f\"loss: {loss.item():>7f}, epoch: {epoch_idx}, iteration: {iteration}\")\n",
"\n",
" # Checkpoint model after every epoch.\n",
" state_dict = model.state_dict()\n",
" consume_prefix_in_state_dict_if_present(state_dict, \"module.\")\n",
" checkpoint = Checkpoint.from_dict(dict(model=state_dict))\n",
" session.report({\"loss\": running_loss}, checkpoint=checkpoint)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9HUciluylZbX"
},
"source": [
"## 4b: Define our Preprocessor\n",
"\n",
"Next, we define our `Preprocessor` to preprocess our data before training and prediction. Our preprocessor will normalize the MNIST Images by the mean and standard deviation of the MNIST training dataset. This is a common operation to do on MNIST to improve training: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "yHzQZTlAlY-9"
},
"outputs": [],
"source": [
"from ray.data.preprocessors import BatchMapper\n",
"\n",
"from torchvision import transforms\n",
"\n",
"def preprocess_images(df: pd.DataFrame) -> pd.DataFrame:\n",
" \"\"\"Preprocess images by scaling each channel in the image.\"\"\"\n",
"\n",
" torchvision_transforms = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
" )\n",
"\n",
" df[\"image\"] = df[\"image\"].map(torchvision_transforms)\n",
" return df\n",
"\n",
"mnist_normalize_preprocessor = BatchMapper(fn=preprocess_images)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uto3v90Hagni"
},
"source": [
"## 4c: Define logic for Batch/Offline Prediction.\n",
"\n",
"After training on each task, we want to use our trained model to do batch (i.e. offline) inference on a test dataset. \n",
"\n",
"To do this, we leverage the built-in `ray.air.BatchPredictor`. We define a `batch_predict` function that will take in a Checkpoint and a Test Dataset and outputs the accuracy our model achieves on the test dataset."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "DM2lFHzFa6uI"
},
"outputs": [],
"source": [
"from ray.train.batch_predictor import BatchPredictor\n",
"from ray.train.torch import TorchPredictor\n",
"\n",
"def batch_predict(checkpoint: ray.air.Checkpoint, test_dataset: ray.data.Dataset) -> float:\n",
" \"\"\"Perform batch prediction on the provided test dataset, and return accuracy results.\"\"\"\n",
"\n",
" batch_predictor = BatchPredictor.from_checkpoint(checkpoint, predictor_cls=TorchPredictor, model=SimpleMLP(num_classes=10))\n",
" model_output = batch_predictor.predict(\n",
" data=test_dataset, feature_columns=[\"image\"], keep_columns=[\"label\"]\n",
" )\n",
" \n",
" # Postprocess model outputs.\n",
" # Convert logits outputted from model into actual class predictions.\n",
" def convert_logits_to_classes(df):\n",
" best_class = df[\"predictions\"].map(lambda x: np.array(x).argmax())\n",
" df[\"predictions\"] = best_class\n",
" return df\n",
" \n",
" prediction_results = model_output.map_batches(convert_logits_to_classes, batch_format=\"pandas\")\n",
" \n",
" # Then, for each prediction output, see if it matches with the ground truth\n",
" # label.\n",
" def calculate_prediction_scores(df):\n",
" return pd.DataFrame({\"correct\": df[\"predictions\"] == df[\"label\"]})\n",
"\n",
" correct_dataset = prediction_results.map_batches(\n",
" calculate_prediction_scores, batch_format=\"pandas\"\n",
" )\n",
"\n",
" return correct_dataset.sum(on=\"correct\") / correct_dataset.count()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GWiTtsmVbIZP"
},
"source": [
"## 4d: Define logic for Deploying and Querying our model\n",
"\n",
"In addition to batch inference, we also want to deploy our model so that we can submit live queries to it for online inference. We use Ray Serve's `PredictorDeployment` utility to deploy our trained model. \n",
"\n",
"Once we deploy the model, we can send HTTP requests to our deployment."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "ZC3JCWz7bhR-"
},
"outputs": [],
"source": [
"from typing import List\n",
"import requests\n",
"from requests import Response\n",
"import numpy as np\n",
"\n",
"from ray.serve.http_adapters import NdArray\n",
"\n",
"\n",
"def deploy_model(checkpoint: ray.air.Checkpoint) -> str:\n",
" \"\"\"Deploys the model from the provided Checkpoint and returns the URL for the endpoint of the model deployment.\"\"\"\n",
" def json_to_pandas(payload: NdArray) -> pd.DataFrame:\n",
" \"\"\"Accepts an NdArray JSON from an HTTP body and converts it to a Pandas dataframe.\"\"\"\n",
" # Have to explicitly convert to float since np.array reads as a double.\n",
" arr = np.array(payload.array, dtype=np.float32)\n",
" # We have to specify an image column since our preprocessor requires it.\n",
" df = pd.DataFrame({\"image\": [arr]})\n",
" return df\n",
"\n",
" deployment = PredictorDeployment.options(name=\"mnist_model\", route_prefix=\"/mnist_predict\", version=f\"v{task_idx}\", num_replicas=2)\n",
" deployment.deploy(batching_params=False, http_adapter=json_to_pandas, predictor_cls=TorchPredictor, checkpoint=latest_checkpoint, model=SimpleMLP(num_classes=10))\n",
" return deployment.url\n",
"\n",
"# Function that queries our deployed model\n",
"def query_deployment(test_samples: List[np.ndarray], endpoint_uri: str) -> List[Response]:\n",
" \"\"\"Given a set of test samples, queries the model deployment at the provided endpoint and returns the results.\"\"\"\n",
" results = []\n",
" # Have to convert to Python List since Numpy arrays are not Json serializable.\n",
" for sample in test_samples:\n",
" results.append(requests.post(endpoint_uri, json={\"array\": sample.tolist()}))\n",
" # TODO: Figure out how Serve deals with Pandas DataFrame returned by Predictors.\n",
" return results"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-NQDj0rFVUX3"
},
"source": [
"# Step 5: Putting it all together\n",
"\n",
"Once we have defined our training logic and our preprocessor, we can put everything together!\n",
"\n",
"For each dataset in our stream, we do the following:\n",
"1. Train on the dataset in Data Parallel fashion. We create a `TorchTrainer`, specify the config for the training loop we defined above, the dataset to train on, and how much we want to scale. `TorchTrainer` also accepts a `checkpoint` arg to continue training from a previously saved checkpoint.\n",
"2. Get the saved checkpoint from the training run.\n",
"3. Test our trained model on a test set containing test data from all the tasks trained on so far.\n",
"3. After training on each task, we deploy our model so we can query it for predictions.\n",
"\n",
"In this example, the training and test data for each task is well-defined beforehand by the benchmark. For real-world scenarios, this probably will not be the case. It is very likely that the prediction requests after training on one task will become the training data for the next task. \n"
]
},
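{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell runs this loop end to end and produces the output shown below. As a condensed, structural sketch of what each iteration does, assuming the Ray AIR `TorchTrainer` and `ScalingConfig` APIs (the hyperparameter values are illustrative, and the checkpoint-resumption keyword, referred to as the `checkpoint` arg above, is written here as `resume_from_checkpoint`):\n",
"\n",
"```python\n",
"from ray.air.config import ScalingConfig\n",
"from ray.train.torch import TorchTrainer\n",
"\n",
"permuted_mnist = PermutedMNISTStream(n_tasks=3)\n",
"train_stream = permuted_mnist.generate_train_stream()\n",
"test_stream = permuted_mnist.generate_test_stream()\n",
"\n",
"latest_checkpoint = None\n",
"accumulated_test_dataset = None\n",
"\n",
"for task_idx, (train_dataset, test_dataset) in enumerate(zip(train_stream, test_stream)):\n",
"    print(f\"Starting training for task: {task_idx}\")\n",
"\n",
"    trainer = TorchTrainer(\n",
"        train_loop_per_worker=train_loop_per_worker,\n",
"        train_loop_config={\"num_epochs\": 4, \"learning_rate\": 0.001, \"momentum\": 0.9, \"batch_size\": 32},\n",
"        datasets={\"train\": train_dataset},\n",
"        preprocessor=mnist_normalize_preprocessor,\n",
"        scaling_config=ScalingConfig(num_workers=1, use_gpu=True),\n",
"        resume_from_checkpoint=latest_checkpoint,  # continue from the previous task's model\n",
"    )\n",
"    result = trainer.fit()\n",
"    latest_checkpoint = result.checkpoint\n",
"\n",
"    # Accumulate the test data from all tasks seen so far and measure accuracy.\n",
"    accumulated_test_dataset = (\n",
"        test_dataset\n",
"        if accumulated_test_dataset is None\n",
"        else accumulated_test_dataset.union(test_dataset)\n",
"    )\n",
"    accuracy = batch_predict(latest_checkpoint, accumulated_test_dataset)\n",
"    print(f\"Accuracy for task {task_idx + 1}: {accuracy}\")\n",
"\n",
"    # Redeploy the updated model and send a few live queries to it.\n",
"    endpoint_uri = deploy_model(latest_checkpoint)\n",
"    test_samples = permuted_mnist.generate_test_samples()\n",
"    responses = query_deployment(test_samples, endpoint_uri)\n",
"```"
]
},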
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "I_OrfQTqNYRk",
"outputId": "a89da8b8-1acf-4796-cc88-9ee889a32123"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_prepare_read pid=1772)\u001b[0m 2022-05-25 22:25:35,236\tWARNING torch_datasource.py:56 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n",
"Read->Map_Batches: 100%|██████████| 1/1 [00:05<00:00, 5.92s/it]\n",
"\u001b[2m\u001b[36m(_prepare_read pid=1772)\u001b[0m 2022-05-25 22:25:53,593\tWARNING torch_datasource.py:56 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n",
"Read->Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 2.51it/s]\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00, 2.72s/it]\n",
"\u001b[2m\u001b[36m(_prepare_read pid=1978)\u001b[0m 2022-05-25 22:25:58,761\tWARNING torch_datasource.py:56 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n",
"Read->Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 2.41it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.37s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training for task: 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-05-25 22:27:16 (running for 00:01:14.46)<br>Memory usage on this node: 4.7/12.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/7.31 GiB heap, 0.0/3.66 GiB objects (0.0/1.0 accelerator_type:T4)<br>Result logdir: /root/ray_results/TorchTrainer_2022-05-25_22-26-01<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>TorchTrainer_a8585_00000</td><td>TERMINATED</td><td>172.28.0.2:2126</td></tr>\n",
"</tbody>\n",
"</table><br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_map_block_nosplit pid=2159)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(_map_block_nosplit pid=2159)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m 2022-05-25 22:26:19,944\tINFO torch.py:347 -- Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m 2022-05-25 22:26:20,033\tINFO torch.py:98 -- Moving model to device: cuda:0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 2.315190, epoch: 0, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 1.464406, epoch: 0, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 1.279081, epoch: 0, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 1.052461, epoch: 0, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.816213, epoch: 1, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 1.019127, epoch: 1, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.525613, epoch: 1, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.570595, epoch: 1, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.572004, epoch: 2, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.543432, epoch: 2, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.350156, epoch: 2, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.443743, epoch: 2, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.438318, epoch: 3, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.342512, epoch: 3, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.302048, epoch: 3, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2197)\u001b[0m loss: 0.414025, epoch: 3, iteration: 1500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:27:16,013\tERROR checkpoint_manager.py:193 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial TorchTrainer_a8585_00000 completed. Last result: \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:27:16,138\tINFO tune.py:753 -- Total run time: 74.68 seconds (74.45 seconds for the tuning loop).\n",
"Map Progress (1 actors 1 pending): 0%| | 0/1 [00:01<?, ?it/s]\u001b[2m\u001b[36m(BlockWorker pid=2267)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(BlockWorker pid=2267)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:04<00:00, 4.18s/it]\n",
"Map_Batches: 100%|██████████| 1/1 [00:01<00:00, 1.63s/it]\n",
"Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 13.60it/s]\n",
"Shuffle Map: 100%|██████████| 1/1 [00:00<00:00, 24.76it/s]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 49.17it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for task 1: 0.946\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(ServeController pid=2382)\u001b[0m INFO 2022-05-25 22:27:23,467 controller 2382 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.\n",
"\u001b[2m\u001b[36m(ServeController pid=2382)\u001b[0m INFO 2022-05-25 22:27:23,470 controller 2382 http_state.py:115 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:XnXlnS:SERVE_PROXY_ACTOR-node:172.28.0.2-0' on node 'node:172.28.0.2-0' listening on '127.0.0.1:8000'\n",
"Shuffle Map: 0%| | 0/1 [00:00<?, ?it/s]\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO: Started server process [2415]\n",
"Shuffle Map: 100%|██████████| 1/1 [00:01<00:00, 1.40s/it]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 7.72it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.18s/it]\n",
"\u001b[2m\u001b[36m(ServeController pid=2382)\u001b[0m INFO 2022-05-25 22:27:28,825 controller 2382 deployment_state.py:1219 - Adding 2 replicas to deployment 'mnist_model'.\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:32,954 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 4.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:32,977 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 21.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:32,985 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 4.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:32,976 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 15.5ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:32,992 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 5.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:32,952 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:32,982 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:32,997 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 11.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,008 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 6.1ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,017 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.1ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,022 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,031 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,036 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,044 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,048 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,057 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.9ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,061 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,070 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,074 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,082 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,088 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,016 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,029 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,043 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,056 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,068 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,081 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,007 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,021 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,035 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,047 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,060 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,073 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,086 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,122 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 25.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,134 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.0ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=2415)\u001b[0m INFO 2022-05-25 22:27:33,142 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,117 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 14.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2573)\u001b[0m INFO 2022-05-25 22:27:33,141 mnist_model mnist_model#vDEhSp replica.py:483 - HANDLE __call__ OK 4.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=2575)\u001b[0m INFO 2022-05-25 22:27:33,133 mnist_model mnist_model#QdDxIB replica.py:483 - HANDLE __call__ OK 0.4ms\n",
"\u001b[2m\u001b[36m(ServeController pid=2382)\u001b[0m INFO 2022-05-25 22:27:33,225 controller 2382 deployment_state.py:1243 - Removing 2 replicas from deployment 'mnist_model'.\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00, 2.58s/it]\n",
"\u001b[2m\u001b[36m(_prepare_read pid=2726)\u001b[0m 2022-05-25 22:27:40,353\tWARNING torch_datasource.py:56 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n",
"Read->Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 2.20it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.41s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training for task: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-05-25 22:28:52 (running for 00:01:09.00)<br>Memory usage on this node: 5.0/12.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/7.31 GiB heap, 0.0/3.66 GiB objects (0.0/1.0 accelerator_type:T4)<br>Result logdir: /root/ray_results/TorchTrainer_2022-05-25_22-27-43<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>TorchTrainer_e4f66_00000</td><td>TERMINATED</td><td>172.28.0.2:2875</td></tr>\n",
"</tbody>\n",
"</table><br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_map_block_nosplit pid=2909)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(_map_block_nosplit pid=2909)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m 2022-05-25 22:28:01,917\tINFO torch.py:347 -- Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m 2022-05-25 22:28:02,063\tINFO torch.py:98 -- Moving model to device: cuda:0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 3.347775, epoch: 0, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 1.343975, epoch: 0, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.768560, epoch: 0, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.607410, epoch: 0, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.578952, epoch: 1, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.473788, epoch: 1, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.609530, epoch: 1, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.741895, epoch: 1, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.417272, epoch: 2, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.510404, epoch: 2, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.422137, epoch: 2, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.403623, epoch: 2, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.384720, epoch: 3, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.414567, epoch: 3, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.274302, epoch: 3, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=2948)\u001b[0m loss: 0.348169, epoch: 3, iteration: 1500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:28:52,221\tERROR checkpoint_manager.py:193 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial TorchTrainer_e4f66_00000 completed. Last result: \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:28:52,344\tINFO tune.py:753 -- Total run time: 69.20 seconds (68.99 seconds for the tuning loop).\n",
"Map Progress (1 actors 1 pending): 0%| | 0/2 [00:01<?, ?it/s]\u001b[2m\u001b[36m(BlockWorker pid=3027)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(BlockWorker pid=3027)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 2/2 [00:05<00:00, 2.64s/it]\n",
"Map_Batches: 100%|██████████| 2/2 [00:01<00:00, 1.07it/s]\n",
"Map_Batches: 100%|██████████| 2/2 [00:01<00:00, 1.55it/s]\n",
"Shuffle Map: 100%|██████████| 2/2 [00:00<00:00, 3.78it/s]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 72.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for task 2: 0.9261\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(ServeController pid=3209)\u001b[0m INFO 2022-05-25 22:29:02,797 controller 3209 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.\n",
"\u001b[2m\u001b[36m(ServeController pid=3209)\u001b[0m INFO 2022-05-25 22:29:02,802 controller 3209 http_state.py:115 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:lsPTvu:SERVE_PROXY_ACTOR-node:172.28.0.2-0' on node 'node:172.28.0.2-0' listening on '127.0.0.1:8000'\n",
"Shuffle Map: 0%| | 0/1 [00:00<?, ?it/s]\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO: Started server process [3241]\n",
"Shuffle Map: 100%|██████████| 1/1 [00:01<00:00, 1.54s/it]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 8.17it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.15s/it]\n",
"\u001b[2m\u001b[36m(ServeController pid=3209)\u001b[0m INFO 2022-05-25 22:29:08,327 controller 3209 deployment_state.py:1219 - Adding 2 replicas to deployment 'mnist_model'.\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,440 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 5.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,438 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,460 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 15.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,466 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 24.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,471 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,481 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.6ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,487 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.9ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,496 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.9ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,501 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,509 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,514 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.6ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,523 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,528 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,537 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,542 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.5ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,550 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 6.7ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,556 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.7ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,564 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.0ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,480 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 5.1ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,495 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.5ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,508 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,522 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,536 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.7ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,549 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,563 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.7ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,470 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,485 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,500 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,513 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,527 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,540 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,554 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,586 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 4.6ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,596 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 9.3ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,601 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.7ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=3241)\u001b[0m INFO 2022-05-25 22:29:12,610 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.0ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,594 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 6.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3401)\u001b[0m INFO 2022-05-25 22:29:12,609 mnist_model mnist_model#uumYOV replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,583 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=3402)\u001b[0m INFO 2022-05-25 22:29:12,600 mnist_model mnist_model#Egafuf replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(ServeController pid=3209)\u001b[0m INFO 2022-05-25 22:29:12,699 controller 3209 deployment_state.py:1243 - Removing 2 replicas from deployment 'mnist_model'.\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00, 2.56s/it]\n",
"\u001b[2m\u001b[36m(_prepare_read pid=3556)\u001b[0m 2022-05-25 22:29:19,825\tWARNING torch_datasource.py:56 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n",
"Read->Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 2.44it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.41s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting training for task: 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-05-25 22:30:31 (running for 00:01:09.12)<br>Memory usage on this node: 5.0/12.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/7.31 GiB heap, 0.0/3.66 GiB objects (0.0/1.0 accelerator_type:T4)<br>Result logdir: /root/ray_results/TorchTrainer_2022-05-25_22-29-22<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>TorchTrainer_2040e_00000</td><td>TERMINATED</td><td>172.28.0.2:3703</td></tr>\n",
"</tbody>\n",
"</table><br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_map_block_nosplit pid=3738)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(_map_block_nosplit pid=3738)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m 2022-05-25 22:29:41,392\tINFO torch.py:347 -- Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m 2022-05-25 22:29:41,549\tINFO torch.py:98 -- Moving model to device: cuda:0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 4.353125, epoch: 0, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 1.147782, epoch: 0, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.609233, epoch: 0, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.606812, epoch: 0, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.494777, epoch: 1, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.776362, epoch: 1, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.376833, epoch: 1, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.478181, epoch: 1, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.413856, epoch: 2, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.668218, epoch: 2, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.318078, epoch: 2, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.427121, epoch: 2, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.369263, epoch: 3, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.479945, epoch: 3, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.457482, epoch: 3, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=3778)\u001b[0m loss: 0.318416, epoch: 3, iteration: 1500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:30:31,831\tERROR checkpoint_manager.py:193 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial TorchTrainer_2040e_00000 completed. Last result: \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-25 22:30:31,953\tINFO tune.py:753 -- Total run time: 69.33 seconds (69.12 seconds for the tuning loop).\n",
"Map Progress (1 actors 1 pending): 0%| | 0/3 [00:01<?, ?it/s]\u001b[2m\u001b[36m(BlockWorker pid=3857)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(BlockWorker pid=3857)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Map Progress (2 actors 1 pending): 33%|███▎ | 1/3 [00:04<00:08, 4.24s/it]\u001b[2m\u001b[36m(BlockWorker pid=3886)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(BlockWorker pid=3886)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 3/3 [00:06<00:00, 2.16s/it]\n",
"Map_Batches: 100%|██████████| 3/3 [00:01<00:00, 1.53it/s]\n",
"Map_Batches: 100%|██████████| 3/3 [00:00<00:00, 19.25it/s]\n",
"Shuffle Map: 100%|██████████| 3/3 [00:00<00:00, 97.56it/s]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 64.24it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for task 3: 0.9001333333333333\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(ServeController pid=4011)\u001b[0m INFO 2022-05-25 22:30:43,081 controller 4011 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.\n",
"\u001b[2m\u001b[36m(ServeController pid=4011)\u001b[0m INFO 2022-05-25 22:30:43,084 controller 4011 http_state.py:115 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:viEsyL:SERVE_PROXY_ACTOR-node:172.28.0.2-0' on node 'node:172.28.0.2-0' listening on '127.0.0.1:8000'\n",
"Shuffle Map: 0%| | 0/1 [00:00<?, ?it/s]\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO: Started server process [4043]\n",
"Shuffle Map: 100%|██████████| 1/1 [00:01<00:00, 1.61s/it]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 7.16it/s]\n",
"Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:01<00:00, 1.36s/it]\n",
"\u001b[2m\u001b[36m(ServeController pid=4011)\u001b[0m INFO 2022-05-25 22:30:48,663 controller 4011 deployment_state.py:1219 - Adding 2 replicas to deployment 'mnist_model'.\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,754 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 5.0ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,771 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 15.8ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,777 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.1ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,788 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 9.0ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,794 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.5ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,803 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.0ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,808 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.5ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,817 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,822 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,770 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 11.5ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,787 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 6.1ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,802 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 4.8ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,815 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 4.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,752 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,776 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,793 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,807 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,821 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,848 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 24.9ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,853 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.6ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,869 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 13.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,847 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 8.4ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,867 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 6.6ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,852 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,984 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.5ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:52,995 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 9.0ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,001 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 3.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,011 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 8.1ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,016 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.7ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,025 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 7.4ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,030 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 307 2.5ms\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=4043)\u001b[0m INFO 2022-05-25 22:30:53,045 http_proxy 172.28.0.2 http_proxy.py:320 - POST /mnist_predict 200 11.9ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:52,993 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 5.9ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:53,010 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 5.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:53,024 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 4.9ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4199)\u001b[0m INFO 2022-05-25 22:30:53,043 mnist_model mnist_model#kzOVuE replica.py:483 - HANDLE __call__ OK 4.9ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,982 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:52,999 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:53,015 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.2ms\n",
"\u001b[2m\u001b[36m(mnist_model pid=4200)\u001b[0m INFO 2022-05-25 22:30:53,029 mnist_model mnist_model#QFllkk replica.py:483 - HANDLE __call__ OK 0.3ms\n",
"\u001b[2m\u001b[36m(ServeController pid=4011)\u001b[0m INFO 2022-05-25 22:30:53,125 controller 4011 deployment_state.py:1243 - Removing 2 replicas from deployment 'mnist_model'.\n"
]
}
],
"source": [
"from ray.train.torch import TorchTrainer\n",
"from ray.train.torch import TorchPredictor\n",
"from ray import serve\n",
"from ray.serve import PredictorDeployment\n",
"from ray.serve.http_adapters import json_to_ndarray\n",
"\n",
"# The number of tasks (i.e. datasets in our stream) that we want to use for this example.\n",
"n_tasks = 3\n",
"\n",
"# Number of epochs to train each task for.\n",
"num_epochs = 4\n",
"# Batch size.\n",
"batch_size = 32\n",
"# Optimizer args.\n",
"learning_rate = 0.001\n",
"momentum = 0.9\n",
"\n",
"# Number of data parallel workers to use for training.\n",
"num_workers = 1\n",
"# Whether to use GPU or not.\n",
"use_gpu = True\n",
"\n",
"permuted_mnist = PermutedMNISTStream(n_tasks=n_tasks)\n",
"train_stream = permuted_mnist.generate_train_stream()\n",
"test_stream = permuted_mnist.generate_test_stream()\n",
"\n",
"latest_checkpoint = None\n",
"\n",
"accuracy_for_all_tasks = []\n",
"task_idx = 0\n",
"all_test_datasets_seen_so_far = []\n",
"for train_dataset, test_dataset in zip(train_stream, test_stream):\n",
" print(f\"Starting training for task: {task_idx}\")\n",
" task_idx += 1\n",
"\n",
" # *********Training*****************\n",
"\n",
" trainer = TorchTrainer(\n",
" train_loop_per_worker=train_loop_per_worker,\n",
" train_loop_config={\n",
" \"num_epochs\": num_epochs,\n",
" \"learning_rate\": learning_rate,\n",
" \"momentum\": momentum,\n",
" \"batch_size\": batch_size,\n",
" },\n",
" # Have to specify trainer_resources as 0 so that the example works on Colab. \n",
" scaling_config={\"num_workers\": num_workers, \"use_gpu\": use_gpu, \"trainer_resources\": {\"CPU\": 0}},\n",
" datasets={\"train\": train_dataset},\n",
" preprocessor=BatchMapper(fn=preprocess_images),\n",
" resume_from_checkpoint=latest_checkpoint,\n",
" )\n",
" result = trainer.fit()\n",
" latest_checkpoint = result.checkpoint\n",
"\n",
" # **************Batch Prediction**************************\n",
"\n",
" # We can do batch prediction on the test data for the tasks seen so far.\n",
" # TODO: Fix type signature in Ray Datasets\n",
" # TODO: Fix dataset.union when used with empty list.\n",
" if len(all_test_datasets_seen_so_far) > 0:\n",
" full_test_dataset = test_dataset.union(*all_test_datasets_seen_so_far)\n",
" else:\n",
" full_test_dataset = test_dataset\n",
"\n",
" all_test_datasets_seen_so_far.append(test_dataset)\n",
"\n",
" accuracy_for_this_task = batch_predict(latest_checkpoint, full_test_dataset)\n",
" print(f\"Accuracy for task {task_idx}: {accuracy_for_this_task}\")\n",
" accuracy_for_all_tasks.append(accuracy_for_this_task)\n",
"\n",
" # *************Model Deployment & Online Inference***************************\n",
" \n",
" # We can also deploy our model to do online inference with Ray Serve.\n",
" # Start Ray Serve.\n",
" serve.start()\n",
" test_samples = permuted_mnist.generate_test_samples()\n",
" endpoint_uri = deploy_model(latest_checkpoint)\n",
" online_inference_results = query_deployment(test_samples, endpoint_uri)\n",
"\n",
" if ray.available_resources().get(\"CPU\", 0) < num_workers+1:\n",
" # If there are no more CPUs left, then shutdown the Serve replicas so we can continue training on the next task.\n",
" serve.shutdown()\n",
"\n",
" \n",
"serve.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ORWpRkPjcPbD"
},
"source": [
"Now that we have finished all of our training, let's see the accuracy of our model after training on each task. \n",
"\n",
"We should see the accuracy decrease over time. This is to be expected since we are using just a naive fine-tuning strategy so our model is prone to catastrophic forgetting.\n",
"\n",
"As we increase the number of tasks, the model performance on all the tasks trained on so far should decrease."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "thpeB0KGmr99",
"outputId": "59fdbb6d-eaf4-4c2a-d350-5ff6b48e96a3"
},
"outputs": [
{
"data": {
"text/plain": [
"[0.946, 0.9261, 0.9001333333333333]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_for_all_tasks"
]
},
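  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, informal summary of the degradation (a minimal sketch, assuming `accuracy_for_all_tasks` is populated as above), we could compute the accuracy drop from the first task to the most recent one:\n",
    "\n",
    "```python\n",
    "# How much did accuracy drop from the first task to the most recent one?\n",
    "accuracy_drop = accuracy_for_all_tasks[0] - accuracy_for_all_tasks[-1]\n",
    "print(f\"Accuracy dropped by {accuracy_drop:.4f} over {len(accuracy_for_all_tasks)} tasks\")\n",
    "```"
   ]
  },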
{
"cell_type": "markdown",
"metadata": {
"id": "xLLAvsTk8LoV"
},
"source": [
"# [Optional] Step 6: Compare against full training.\n",
"\n",
"We have now incrementally trained our simple multi-layer perceptron. Let's compare the incrementally trained model via fine tuning against a model that is trained on all the tasks up front.\n",
"\n",
"Since we are using a naive fine-tuning strategy, we should expect that our incrementally trained model will perform worse than the the one that is fully trained! However, there's various other strategies that have been developed and are actively being researched to improve accuracy for incremental training. And overall, incremental/continual learning allows you to train in many real world settings where the entire dataset is not available up front, but new data is arriving at a relatively high rate."
]
},
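  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One such strategy (not used in this example) is simple rehearsal: mix a small sample of previously seen training data into each new task before fine-tuning on it. Below is a minimal sketch of how this could be wired into the training loop above, assuming the same `permuted_mnist` stream and `TorchTrainer` setup; `replay_size` is a hypothetical parameter.\n",
    "\n",
    "```python\n",
    "# Minimal rehearsal sketch (an assumption, not part of this example):\n",
    "# keep a small \"replay buffer\" of rows from earlier tasks and union it\n",
    "# with the current task's dataset before fitting the trainer.\n",
    "replay_size = 1000  # hypothetical number of rows to keep per previous task\n",
    "\n",
    "replay_datasets = []\n",
    "for train_dataset in permuted_mnist.generate_train_stream():\n",
    "    if replay_datasets:\n",
    "        train_dataset_with_replay = train_dataset.union(*replay_datasets)\n",
    "    else:\n",
    "        train_dataset_with_replay = train_dataset\n",
    "\n",
    "    # ... construct and fit a TorchTrainer on `train_dataset_with_replay`,\n",
    "    # exactly as in the training loop above ...\n",
    "\n",
    "    # Keep only a small slice of this task's data for future replay.\n",
    "    replay_datasets.append(train_dataset.limit(replay_size))\n",
    "```"
   ]
  },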
{
"cell_type": "markdown",
"metadata": {
"id": "RNHsEVBHc0p2"
},
"source": [
"Let's first combine all of our datasets for each task into a single, unified Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pU2fVH068lfF",
"outputId": "fd6a3b56-dda1-4fa6-cebd-d0ee8784e698"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:02<00:00, 2.93s/it]\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:03<00:00, 3.11s/it]\n",
"Map Progress: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[1m\u001b[36m(scheduler +8m58s)\u001b[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.\n",
"\u001b[2m\u001b[1m\u001b[33m(scheduler +8m58s)\u001b[0m Warning: The following resource request cannot be scheduled right now: {'CPU': 1.0}. This is likely due to all cluster resources being claimed by actors. Consider creating fewer actors or adding more nodes to this Ray cluster.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map Progress (2 actors 1 pending): 100%|██████████| 1/1 [00:03<00:00, 3.06s/it]\n",
"Shuffle Map: 100%|██████████| 3/3 [00:04<00:00, 1.64s/it]\n",
"Shuffle Reduce: 100%|██████████| 3/3 [00:02<00:00, 1.07it/s]\n"
]
}
],
"source": [
"train_stream = permuted_mnist.generate_train_stream()\n",
"\n",
"# Collect all datasets in the stream into a single dataset.\n",
"all_training_datasets = []\n",
"for train_dataset in train_stream:\n",
" all_training_datasets.append(train_dataset)\n",
"combined_training_dataset = all_training_datasets[0].union(*all_training_datasets[1:])\n",
"\n",
"\n",
"combined_training_dataset = combined_training_dataset.random_shuffle()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tJ6Oqdgvc5dn"
},
"source": [
"Then, we train a new model on the unified Dataset using the same configurations as before."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "PmH9c0-z9KME",
"outputId": "653b4dfc-ed47-4307-fa84-e4c4ea3ec354"
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Current time: 2022-05-18 23:52:49 (running for 00:03:27.40)<br>Memory usage on this node: 7.0/12.7 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/7.34 GiB heap, 0.0/3.67 GiB objects (0.0/1.0 accelerator_type:T4)<br>Result logdir: /root/ray_results/TorchTrainer_2022-05-18_23-49-22<br>Number of trials: 1/1 (1 TERMINATED)<br><table>\n",
"<thead>\n",
"<tr><th>Trial name </th><th>status </th><th>loc </th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><td>TorchTrainer_24496_00000</td><td>TERMINATED</td><td>172.28.0.2:4630</td></tr>\n",
"</tbody>\n",
"</table><br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_map_block_nosplit pid=4666)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(_map_block_nosplit pid=4666)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m 2022-05-18 23:50:06,950\tINFO torch.py:347 -- Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m 2022-05-18 23:50:07,011\tINFO torch.py:98 -- Moving model to device: cuda:0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 2.373475, epoch: 0, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 1.699985, epoch: 0, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 1.636039, epoch: 0, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 1.334987, epoch: 0, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 1.152312, epoch: 0, iteration: 2000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.998297, epoch: 0, iteration: 2500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 1.434949, epoch: 0, iteration: 3000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.971171, epoch: 0, iteration: 3500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.796480, epoch: 0, iteration: 4000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.802282, epoch: 0, iteration: 4500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.731363, epoch: 0, iteration: 5000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.847772, epoch: 0, iteration: 5500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.879676, epoch: 1, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.564319, epoch: 1, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.714444, epoch: 1, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.565163, epoch: 1, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.739525, epoch: 1, iteration: 2000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.510878, epoch: 1, iteration: 2500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.814798, epoch: 1, iteration: 3000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.473765, epoch: 1, iteration: 3500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.557866, epoch: 1, iteration: 4000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.674371, epoch: 1, iteration: 4500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.532800, epoch: 1, iteration: 5000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.832442, epoch: 1, iteration: 5500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.557547, epoch: 2, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.355255, epoch: 2, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.426749, epoch: 2, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.484543, epoch: 2, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.360856, epoch: 2, iteration: 2000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.444718, epoch: 2, iteration: 2500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.596777, epoch: 2, iteration: 3000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.289816, epoch: 2, iteration: 3500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.407941, epoch: 2, iteration: 4000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.438239, epoch: 2, iteration: 4500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.379983, epoch: 2, iteration: 5000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.527786, epoch: 2, iteration: 5500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.598584, epoch: 3, iteration: 0\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.355202, epoch: 3, iteration: 500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.392683, epoch: 3, iteration: 1000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.415264, epoch: 3, iteration: 1500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.417230, epoch: 3, iteration: 2000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.289974, epoch: 3, iteration: 2500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.648514, epoch: 3, iteration: 3000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.369468, epoch: 3, iteration: 3500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.378548, epoch: 3, iteration: 4000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.392761, epoch: 3, iteration: 4500\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.555575, epoch: 3, iteration: 5000\n",
"\u001b[2m\u001b[36m(BaseWorkerMixin pid=4709)\u001b[0m loss: 0.394487, epoch: 3, iteration: 5500\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-18 23:52:49,915\tERROR checkpoint_manager.py:193 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial TorchTrainer_24496_00000 completed. Last result: \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-05-18 23:52:50,042\tINFO tune.py:753 -- Total run time: 207.53 seconds (207.39 seconds for the tuning loop).\n"
]
}
],
"source": [
"# Now we do training with the same configurations as before\n",
"trainer = TorchTrainer(\n",
" train_loop_per_worker=train_loop_per_worker,\n",
" train_loop_config={\n",
" \"num_epochs\": num_epochs,\n",
" \"learning_rate\": learning_rate,\n",
" \"momentum\": momentum,\n",
" \"batch_size\": batch_size,\n",
" },\n",
" # Have to specify trainer_resources as 0 so that the example works on Colab. \n",
" scaling_config={\"num_workers\": num_workers, \"use_gpu\": use_gpu, \"trainer_resources\": {\"CPU\": 0}},\n",
" datasets={\"train\": combined_training_dataset},\n",
" preprocessor=BatchMapper(fn=preprocess_images),\n",
" )\n",
"result = trainer.fit()\n",
"full_training_checkpoint = result.checkpoint"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jLaOcmBddRqB"
},
"source": [
"Then, let's test model that was trained on all the tasks up front."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WC7zV_Cw9TAi",
"outputId": "12a86f2b-be90-47b6-e252-25e3199689f9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map Progress (1 actors 1 pending): 0%| | 0/3 [00:01<?, ?it/s]\u001b[2m\u001b[36m(BlockWorker pid=4840)\u001b[0m /usr/local/lib/python3.7/dist-packages/torchvision/transforms/functional.py:133: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)\n",
"\u001b[2m\u001b[36m(BlockWorker pid=4840)\u001b[0m img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Map Progress (2 actors 1 pending): 100%|██████████| 3/3 [00:06<00:00, 2.25s/it]\n",
"Map Progress: 100%|██████████| 3/3 [00:01<00:00, 1.51it/s]\n",
"Map Progress: 100%|██████████| 3/3 [00:01<00:00, 1.94it/s]\n",
"Shuffle Map: 100%|██████████| 3/3 [00:00<00:00, 5.53it/s]\n",
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 65.42it/s]\n"
]
}
],
"source": [
"# Then, we used the fully trained model and do batch prediction on the entire test set.\n",
"\n",
"# `full_test_dataset` should already contain the combined test datasets.\n",
"fully_trained_accuracy = batch_predict(full_training_checkpoint, full_test_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pn5LJ4CUdZgI"
},
"source": [
"Finally, let's compare the accuracies between the incrementally trained model and the fully trained model. We should see that the fully trained model performs better."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UFhRf_8e-vgA",
"outputId": "056ff06f-ff87-4f3a-d740-4cc556bde3dd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fully trained model accuracy: 0.9468\n",
"Incrementally trained model accuracy: 0.9207666666666666\n"
]
}
],
"source": [
"print(\"Fully trained model accuracy: \", fully_trained_accuracy)\n",
"print(\"Incrementally trained model accuracy: \", accuracy_for_all_tasks[-1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FuqKePrYe-Fz"
},
"source": [
"# Next Steps\n",
"\n",
"Once you've completed this notebook, you should be set to play around with scalable incremental training using Ray, either by trying more fancy algorithms for incremental learning other than naive fine-tuning, or attempting to scale out to larger datasets!\n",
"\n",
"If you run into any issues, or have any feature requests, please file an issue on the [Ray Github](https://github.com/ray-project/ray/issues).\n",
"\n",
"\n"
]
},
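  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the scaling-out direction, a minimal sketch (assuming a larger Ray cluster is available; the worker count below is purely illustrative) is to increase `num_workers` in the `scaling_config` passed to `TorchTrainer`, keeping the rest of the training loop unchanged:\n",
    "\n",
    "```python\n",
    "# Hypothetical scaling config for a larger cluster: 8 data-parallel\n",
    "# GPU workers instead of 1. Everything else stays the same.\n",
    "trainer = TorchTrainer(\n",
    "    train_loop_per_worker=train_loop_per_worker,\n",
    "    train_loop_config={\n",
    "        \"num_epochs\": num_epochs,\n",
    "        \"learning_rate\": learning_rate,\n",
    "        \"momentum\": momentum,\n",
    "        \"batch_size\": batch_size,\n",
    "    },\n",
    "    scaling_config={\"num_workers\": 8, \"use_gpu\": True},\n",
    "    datasets={\"train\": train_dataset},\n",
    "    preprocessor=BatchMapper(fn=preprocess_images),\n",
    ")\n",
    "```"
   ]
  },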
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2GdLZD4od3oI"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "ray_air_incremental_learning (1).ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}