{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "VaFMt6AIhYbK"
},
"source": [
"# Tabular data training and serving with Keras and Ray AIR\n",
"\n",
"This notebook is adapted from [a Keras tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/components_keras).\n",
"It uses [Chicago Taxi dataset](https://data.cityofchicago.org/Transportation/Taxi-Trips/wrvz-psew) and a DNN Keras model to predict whether a trip may generate a big tip.\n",
"\n",
"In this example, we showcase how to achieve the same tasks as the Keras Tutorial using [Ray AIR](https://docs.ray.io/en/latest/ray-air/getting-started.html), covering\n",
"every step from data ingestion to pushing a model to serving.\n",
"\n",
"1. Read a CSV into [Ray Dataset](https://docs.ray.io/en/latest/data/dataset.html).\n",
"2. Process the dataset by chaining [Ray AIR preprocessors](https://docs.ray.io/en/latest/ray-air/getting-started.html#preprocessors).\n",
"3. Train the model using the TensorflowTrainer from AIR.\n",
"4. Serve the model using Ray Serve and the above preprocessors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQbdfyWQhYbO"
},
"source": [
"Uncomment and run the following line in order to install all the necessary dependencies:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YajFzmkthYbO",
"outputId": "cd4f1959-4ef4-465e-9e9d-71dfc3de28ff"
},
"outputs": [],
"source": [
"# ! pip install \"tensorflow>=2.8.0\" \"ray[tune, data, serve]>=1.12.1\"\n",
"# ! pip install fastapi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pvSRaEHChYbP"
},
"source": [
"## Set up Ray"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LRdL3kWBhYbQ"
},
"source": [
"We will use `ray.init()` to initialize a local cluster. By default, this cluster will be composed of only the machine you are running this notebook on. If you wish to attach to an existing Ray cluster, you can do so through `ray.init(address=\"auto\")`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MOsHUjgdIrIW",
"outputId": "8a21ead5-bb2d-4a3d-ae41-17a313688b24"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-20 18:45:28,814\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
" <div style=\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\">\n",
" <h3 style=\"color: var(--jp-ui-font-color0)\">Ray</h3>\n",
" <svg version=\"1.1\" id=\"ray\" width=\"3em\" viewBox=\"0 0 144.5 144.6\" style=\"margin-left: 3em;margin-right: 3em\">\n",
" <g id=\"layer-1\">\n",
" <path fill=\"#00a2e9\" class=\"st0\" d=\"M97.3,77.2c-3.8-1.1-6.2,0.9-8.3,5.1c-3.5,6.8-9.9,9.9-17.4,9.6S58,88.1,54.8,81.2c-1.4-3-3-4-6.3-4.1\n",
" c-5.6-0.1-9.9,0.1-13.1,6.4c-3.8,7.6-13.6,10.2-21.8,7.6C5.2,88.4-0.4,80.5,0,71.7c0.1-8.4,5.7-15.8,13.8-18.2\n",
" c8.4-2.6,17.5,0.7,22.3,8c1.3,1.9,1.3,5.2,3.6,5.6c3.9,0.6,8,0.2,12,0.2c1.8,0,1.9-1.6,2.4-2.8c3.5-7.8,9.7-11.8,18-11.9\n",
" c8.2-0.1,14.4,3.9,17.8,11.4c1.3,2.8,2.9,3.6,5.7,3.3c1-0.1,2,0.1,3,0c2.8-0.5,6.4,1.7,8.1-2.7s-2.3-5.5-4.1-7.5\n",
" c-5.1-5.7-10.9-10.8-16.1-16.3C84,38,81.9,37.1,78,38.3C66.7,42,56.2,35.7,53,24.1C50.3,14,57.3,2.8,67.7,0.5\n",
" C78.4-2,89,4.7,91.5,15.3c0.1,0.3,0.1,0.5,0.2,0.8c0.7,3.4,0.7,6.9-0.8,9.8c-1.7,3.2-0.8,5,1.5,7.2c6.7,6.5,13.3,13,19.8,19.7\n",
" c1.8,1.8,3,2.1,5.5,1.2c9.1-3.4,17.9-0.6,23.4,7c4.8,6.9,4.6,16.1-0.4,22.9c-5.4,7.2-14.2,9.9-23.1,6.5c-2.3-0.9-3.5-0.6-5.1,1.1\n",
" c-6.7,6.9-13.6,13.7-20.5,20.4c-1.8,1.8-2.5,3.2-1.4,5.9c3.5,8.7,0.3,18.6-7.7,23.6c-7.9,5-18.2,3.8-24.8-2.9\n",
" c-6.4-6.4-7.4-16.2-2.5-24.3c4.9-7.8,14.5-11,23.1-7.8c3,1.1,4.7,0.5,6.9-1.7C91.7,98.4,98,92.3,104.2,86c1.6-1.6,4.1-2.7,2.6-6.2\n",
" c-1.4-3.3-3.8-2.5-6.2-2.6C99.8,77.2,98.9,77.2,97.3,77.2z M72.1,29.7c5.5,0.1,9.9-4.3,10-9.8c0-0.1,0-0.2,0-0.3\n",
" C81.8,14,77,9.8,71.5,10.2c-5,0.3-9,4.2-9.3,9.2c-0.2,5.5,4,10.1,9.5,10.3C71.8,29.7,72,29.7,72.1,29.7z M72.3,62.3\n",
" c-5.4-0.1-9.9,4.2-10.1,9.7c0,0.2,0,0.3,0,0.5c0.2,5.4,4.5,9.7,9.9,10c5.1,0.1,9.9-4.7,10.1-9.8c0.2-5.5-4-10-9.5-10.3\n",
" C72.6,62.3,72.4,62.3,72.3,62.3z M115,72.5c0.1,5.4,4.5,9.7,9.8,9.9c5.6-0.2,10-4.8,10-10.4c-0.2-5.4-4.6-9.7-10-9.7\n",
" c-5.3-0.1-9.8,4.2-9.9,9.5C115,72.1,115,72.3,115,72.5z M19.5,62.3c-5.4,0.1-9.8,4.4-10,9.8c-0.1,5.1,5.2,10.4,10.2,10.3\n",
" c5.6-0.2,10-4.9,9.8-10.5c-0.1-5.4-4.5-9.7-9.9-9.6C19.6,62.3,19.5,62.3,19.5,62.3z M71.8,134.6c5.9,0.2,10.3-3.9,10.4-9.6\n",
" c0.5-5.5-3.6-10.4-9.1-10.8c-5.5-0.5-10.4,3.6-10.8,9.1c0,0.5,0,0.9,0,1.4c-0.2,5.3,4,9.8,9.3,10\n",
" C71.6,134.6,71.7,134.6,71.8,134.6z\"/>\n",
" </g>\n",
" </svg>\n",
" <table>\n",
" <tr>\n",
" <td style=\"text-align: left\"><b>Python version:</b></td>\n",
" <td style=\"text-align: left\"><b>3.7.10</b></td>\n",
" </tr>\n",
" <tr>\n",
" <td style=\"text-align: left\"><b>Ray version:</b></td>\n",
" <td style=\"text-align: left\"><b> 3.0.0.dev0</b></td>\n",
" </tr>\n",
" <tr>\n",
" <td style=\"text-align: left\"><b>Dashboard:</b></td>\n",
" <td style=\"text-align: left\"><b><a href=\"http://127.0.0.1:8266\" target=\"_blank\">http://127.0.0.1:8266</a></b></td>\n",
"</tr>\n",
"\n",
" </table>\n",
" </div>\n",
"</div>\n"
],
"text/plain": [
"RayContext(dashboard_url='127.0.0.1:8266', python_version='3.7.10', ray_version='3.0.0.dev0', ray_commit='{{RAY_COMMIT_SHA}}', address_info={'node_ip_address': '127.0.0.1', 'raylet_ip_address': '127.0.0.1', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-07-20_18-45-26_127581_21006/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-07-20_18-45-26_127581_21006/sockets/raylet', 'webui_url': '127.0.0.1:8266', 'session_dir': '/tmp/ray/session_2022-07-20_18-45-26_127581_21006', 'metrics_export_port': 63884, 'gcs_address': '127.0.0.1:63685', 'address': '127.0.0.1:63685', 'dashboard_agent_listen_port': 52365, 'node_id': 'c21f810137e56bd967ab3f246c66aadc5262e00bdbe19c34c23456e7'})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pprint import pprint\n",
"import ray\n",
"\n",
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oJiSdWy2hYbR"
},
"source": [
"We can check the resources our cluster is composed of. If you are running this notebook on your local machine or Google Colab, you should see the number of CPU cores and GPUs available on the said machine."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KlMz0dt9hYbS",
"outputId": "e7234b52-08b4-49fc-e14c-72f283b893f2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'CPU': 16.0,\n",
" 'memory': 30436675994.0,\n",
" 'node:127.0.0.1': 1.0,\n",
" 'object_store_memory': 2147483648.0}\n"
]
}
],
"source": [
"pprint(ray.cluster_resources())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jAgvLbhT8nB0"
},
"source": [
"## Getting the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IXQb4--97_Cf"
},
"source": [
"Let's start with defining a helper function to get the data to work with. Some columns are dropped for simplicity."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "gAbhv9OqhYbT"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"INPUT = \"input\"\n",
"LABEL = \"is_big_tip\"\n",
"\n",
"def get_data() -> pd.DataFrame:\n",
" \"\"\"Fetch the taxi fare data to work on.\"\"\"\n",
" _data = pd.read_csv(\n",
" \"https://raw.githubusercontent.com/tensorflow/tfx/master/\"\n",
" \"tfx/examples/chicago_taxi_pipeline/data/simple/data.csv\"\n",
" )\n",
" _data[LABEL] = _data[\"tips\"] / _data[\"fare\"] > 0.2\n",
" # We drop some columns here for the sake of simplicity.\n",
" return _data.drop(\n",
" [\n",
" \"tips\",\n",
" \"fare\",\n",
" \"dropoff_latitude\",\n",
" \"dropoff_longitude\",\n",
" \"pickup_latitude\",\n",
" \"pickup_longitude\",\n",
" \"pickup_census_tract\",\n",
" ],\n",
" axis=1,\n",
" )"
]
},
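{
"cell_type": "markdown",
"metadata": {},
"source": [
"The label rule in `get_data` marks a trip as a big tip when `tips / fare > 0.2`, i.e. when the tip is more than 20% of the fare. The minimal sketch below applies the same rule to a small, made-up DataFrame (the numbers are hypothetical, not taken from the dataset). It also shows an edge case: a zero fare with a zero tip gives `0 / 0 = NaN`, and `NaN > 0.2` evaluates to `False`.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy trips, not rows from the Chicago Taxi dataset.\n",
"toy = pd.DataFrame({\"tips\": [0.0, 3.0, 2.0, 0.0], \"fare\": [10.0, 10.0, 10.0, 0.0]})\n",
"\n",
"# Same rule as get_data(): a big tip is strictly more than 20% of the fare.\n",
"toy[\"is_big_tip\"] = toy[\"tips\"] / toy[\"fare\"] > 0.2\n",
"print(toy[\"is_big_tip\"].tolist())  # [False, True, False, False]\n",
"```\n",
"\n",
"A tip of exactly 20% (the third row) is not counted as a big tip, because the comparison is strict."
]
},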
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "FbeYf1aF8ISK"
},
"outputs": [],
"source": [
"data = get_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1WALC3kT8WgL"
},
"source": [
"Now let's take a look at the data. Notice that some values are missing. This is exactly where preprocessing comes into the picture. We will come back to this in the preprocessing session below."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "8tugpr5S8gPq",
"outputId": "3c57a348-12a7-4b6c-f9b2-fabdcb7a7c88"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>pickup_community_area</th>\n",
" <th>trip_start_month</th>\n",
" <th>trip_start_hour</th>\n",
" <th>trip_start_day</th>\n",
" <th>trip_start_timestamp</th>\n",
" <th>trip_miles</th>\n",
" <th>dropoff_census_tract</th>\n",
" <th>payment_type</th>\n",
" <th>company</th>\n",
" <th>trip_seconds</th>\n",
" <th>dropoff_community_area</th>\n",
" <th>is_big_tip</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NaN</td>\n",
" <td>5</td>\n",
" <td>19</td>\n",
" <td>6</td>\n",
" <td>1400269500</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>Credit Card</td>\n",
" <td>Chicago Elite Cab Corp. (Chicago Carriag</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NaN</td>\n",
" <td>3</td>\n",
" <td>19</td>\n",
" <td>5</td>\n",
" <td>1362683700</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>Unknown</td>\n",
" <td>Chicago Elite Cab Corp.</td>\n",
" <td>300.0</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>60.0</td>\n",
" <td>10</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>1380593700</td>\n",
" <td>12.6</td>\n",
" <td>NaN</td>\n",
" <td>Cash</td>\n",
" <td>Taxi Affiliation Services</td>\n",
" <td>1380.0</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>10.0</td>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>1382319000</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>Cash</td>\n",
" <td>Taxi Affiliation Services</td>\n",
" <td>180.0</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>14.0</td>\n",
" <td>5</td>\n",
" <td>7</td>\n",
" <td>5</td>\n",
" <td>1369897200</td>\n",
" <td>0.0</td>\n",
" <td>NaN</td>\n",
" <td>Cash</td>\n",
" <td>Dispatch Taxi Affiliation</td>\n",
" <td>1080.0</td>\n",
" <td>NaN</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" pickup_community_area trip_start_month trip_start_hour trip_start_day \\\n",
"0 NaN 5 19 6 \n",
"1 NaN 3 19 5 \n",
"2 60.0 10 2 3 \n",
"3 10.0 10 1 2 \n",
"4 14.0 5 7 5 \n",
"\n",
" trip_start_timestamp trip_miles dropoff_census_tract payment_type \\\n",
"0 1400269500 0.0 NaN Credit Card \n",
"1 1362683700 0.0 NaN Unknown \n",
"2 1380593700 12.6 NaN Cash \n",
"3 1382319000 0.0 NaN Cash \n",
"4 1369897200 0.0 NaN Cash \n",
"\n",
" company trip_seconds \\\n",
"0 Chicago Elite Cab Corp. (Chicago Carriag 0.0 \n",
"1 Chicago Elite Cab Corp. 300.0 \n",
"2 Taxi Affiliation Services 1380.0 \n",
"3 Taxi Affiliation Services 180.0 \n",
"4 Dispatch Taxi Affiliation 1080.0 \n",
"\n",
" dropoff_community_area is_big_tip \n",
"0 NaN False \n",
"1 NaN False \n",
"2 NaN False \n",
"3 NaN False \n",
"4 NaN False "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xzNQKJMA9YV-"
},
"source": [
"We continue to split the data into training and test data.\n",
"For the test data, we separate out the features to run serving on as well as labels to compare serving results with."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "YSLvrBMC9aRv"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from typing import Tuple\n",
"\n",
"\n",
"def split_data(data: pd.DataFrame) -> Tuple[ray.data.Dataset, pd.DataFrame, np.array]:\n",
" \"\"\"Split the data in a stratified way.\n",
"\n",
" Returns:\n",
" A tuple containing train dataset, test data and test label.\n",
" \"\"\"\n",
" # There is a native offering in Ray Dataset for split as well.\n",
" # However, supporting stratification is a TODO there. So use\n",
" # scikit-learn equivalent here.\n",
" train_data, test_data = train_test_split(\n",
" data, stratify=data[[LABEL]], random_state=1113\n",
" )\n",
" _train_ds = ray.data.from_pandas(train_data)\n",
" _test_label = test_data[LABEL].values\n",
" _test_df = test_data.drop([LABEL], axis=1)\n",
" return _train_ds, _test_df, _test_label\n",
"\n",
"train_ds, test_df, test_label = split_data(data)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xfhRl7eO981w",
"outputId": "f80d90ff-fc8a-4a7d-b544-31633823d596"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 11251 samples for training and 3751 samples for testing.\n"
]
}
],
"source": [
"print(f\"There are {train_ds.count()} samples for training and {test_df.shape[0]} samples for testing.\")"
]
},
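{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train_test_split` uses a default `test_size` of 0.25 and rounds the test size up, so the 15002 rows here become `ceil(0.25 * 15002) = 3751` test rows and 11251 training rows, matching the output above. Passing `stratify` keeps the share of big tips the same in both splits. A minimal sketch on a hypothetical, imbalanced toy dataset:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical toy data: 20 positives and 80 negatives.\n",
"toy = pd.DataFrame({\"x\": np.arange(100), \"is_big_tip\": [True] * 20 + [False] * 80})\n",
"\n",
"train, test = train_test_split(toy, stratify=toy[[\"is_big_tip\"]], random_state=1113)\n",
"print(len(train), len(test))  # 75 25\n",
"\n",
"# With stratification, both splits keep the 20% positive rate.\n",
"print(train[\"is_big_tip\"].mean(), test[\"is_big_tip\"].mean())\n",
"```"
]
},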
{
"cell_type": "markdown",
"metadata": {
"id": "N7tiwqdP-zVS"
},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4RRkXuteIrIh"
},
"source": [
"Let's focus on preprocessing first.\n",
"Usually, input data needs to go through some preprocessing before being\n",
"fed into model. It is a good idea to package preprocessing logic into\n",
"a modularized component so that the same logic can be applied to both\n",
"training data as well as data for online serving or offline batch prediction.\n",
"\n",
"In AIR, this component is a [`Preprocessor`](https://docs.ray.io/en/latest/ray-air/getting-started.html#preprocessors).\n",
"It is constructed in a way that allows easy composition.\n",
"\n",
"Now let's construct a chained preprocessor composed of simple preprocessors, including\n",
"1. Imputer for filling missing features;\n",
"2. OneHotEncoder for encoding categorical features;\n",
"3. BatchMapper where arbitrary user-defined function can be applied to batches of records;\n",
"and so on. Take a look at [`Preprocessor`](https://docs.ray.io/en/latest/ray-air/getting-started.html#preprocessors).\n",
"The output of the preprocessing step goes into model for training."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "zVvslsfMIrIh"
},
"outputs": [],
"source": [
"from ray.data.preprocessors import (\n",
" BatchMapper,\n",
" Chain,\n",
" OneHotEncoder,\n",
" SimpleImputer,\n",
")\n",
"\n",
"def get_preprocessor():\n",
" \"\"\"Construct a chain of preprocessors.\"\"\"\n",
" imputer1 = SimpleImputer(\n",
" [\"dropoff_census_tract\"], strategy=\"most_frequent\"\n",
" )\n",
" imputer2 = SimpleImputer(\n",
" [\"pickup_community_area\", \"dropoff_community_area\"],\n",
" strategy=\"most_frequent\",\n",
" )\n",
" imputer3 = SimpleImputer([\"payment_type\"], strategy=\"most_frequent\")\n",
" imputer4 = SimpleImputer(\n",
" [\"company\"], strategy=\"most_frequent\")\n",
" imputer5 = SimpleImputer(\n",
" [\"trip_start_timestamp\", \"trip_miles\", \"trip_seconds\"], strategy=\"mean\"\n",
" )\n",
"\n",
" ohe = OneHotEncoder(\n",
" columns=[\n",
" \"trip_start_hour\",\n",
" \"trip_start_day\",\n",
" \"trip_start_month\",\n",
" \"dropoff_census_tract\",\n",
" \"pickup_community_area\",\n",
" \"dropoff_community_area\",\n",
" \"payment_type\",\n",
" \"company\",\n",
" ],\n",
" max_categories={\n",
" \"dropoff_census_tract\": 25,\n",
" \"pickup_community_area\": 20,\n",
" \"dropoff_community_area\": 20,\n",
" \"payment_type\": 2,\n",
" \"company\": 7,\n",
" },\n",
" )\n",
"\n",
" def batch_mapper_fn(df):\n",
" df[\"trip_start_year\"] = pd.to_datetime(df[\"trip_start_timestamp\"], unit=\"s\").dt.year\n",
" df = df.drop([\"trip_start_timestamp\"], axis=1)\n",
" return df\n",
"\n",
" def concat_for_tensor(dataframe):\n",
" from ray.data.extensions import TensorArray\n",
" result = {}\n",
" feature_cols = [col for col in dataframe.columns if col != LABEL]\n",
" result[INPUT] = TensorArray(dataframe[feature_cols].to_numpy(dtype=np.float32))\n",
" if LABEL in dataframe.columns:\n",
" result[LABEL] = dataframe[LABEL]\n",
" return pd.DataFrame(result)\n",
"\n",
" chained_pp = Chain(\n",
" imputer1,\n",
" imputer2,\n",
" imputer3,\n",
" imputer4,\n",
" imputer5,\n",
" ohe,\n",
" BatchMapper(batch_mapper_fn),\n",
" BatchMapper(concat_for_tensor)\n",
" )\n",
" return chained_pp\n"
]
},
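{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the chain more concrete, here is a plain-pandas approximation of what a `most_frequent` imputer followed by a `max_categories`-capped one-hot encoder do to a single column. This is an illustrative assumption about the behavior, not Ray's implementation: the toy values and the output column names are hypothetical, and in this sketch a category outside the top k is encoded as all zeros. In AIR, statistics such as the most frequent value are computed when the preprocessor is fit on the training data and are then reused at serving time.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy column with one missing value.\n",
"df = pd.DataFrame(\n",
"    {\"payment_type\": [\"Cash\", None, \"Cash\", \"Credit Card\", \"Cash\", \"Credit Card\", \"Unknown\"]}\n",
")\n",
"\n",
"# most_frequent imputation: fill missing entries with the column's mode.\n",
"mode = df[\"payment_type\"].mode()[0]\n",
"df[\"payment_type\"] = df[\"payment_type\"].fillna(mode)\n",
"\n",
"# One-hot encode, keeping only the k most frequent categories (k=2 here,\n",
"# like max_categories={\"payment_type\": 2}); other categories become all zeros.\n",
"top_k = df[\"payment_type\"].value_counts().nlargest(2).index\n",
"for category in top_k:\n",
"    df[f\"payment_type_{category}\"] = (df[\"payment_type\"] == category).astype(int)\n",
"df = df.drop(columns=[\"payment_type\"])\n",
"print(df)\n",
"```"
]
},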
{
"cell_type": "markdown",
"metadata": {
"id": "V2BIiegi_brE"
},
"source": [
"Now let's define some constants for clarity."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "ejGVU-uN_dVP"
},
"outputs": [],
"source": [
"# Note that `INPUT_SIZE` here is corresponding to the output dimension\n",
"# of the previously defined processing steps.\n",
"# This is used to specify the input shape of Keras model as well as\n",
"# when converting from training data from `ray.data.Dataset` to `tf.Tensor`.\n",
"INPUT_SIZE = 120\n",
"# The training batch size. Based on `NUM_WORKERS`, each worker\n",
"# will get its own share of this batch size. For example, if\n",
"# `NUM_WORKERS = 2`, each worker will work on 4 samples per batch.\n",
"BATCH_SIZE = 8\n",
"# Number of epoch. Adjust it based on how quickly you want the run to be.\n",
"EPOCH = 1\n",
"# Number of training workers.\n",
"# Adjust this accordingly based on the resources you have!\n",
"NUM_WORKERS = 2"
]
},
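{
"cell_type": "markdown",
"metadata": {},
"source": [
"`INPUT_SIZE = 120` can be reconstructed by counting the columns that `concat_for_tensor` packs into the input tensor: the one-hot columns plus the three numeric columns `trip_miles`, `trip_seconds` and `trip_start_year`. The sketch below assumes that every hour (24), weekday (7) and month (12) appears in the training data, that the other encoded columns have at least as many distinct values as their `max_categories` caps, and that the one-hot encoder replaces each original column with its indicator columns.\n",
"\n",
"```python\n",
"# Assumed number of one-hot columns produced for each encoded feature.\n",
"one_hot_widths = {\n",
"    \"trip_start_hour\": 24,\n",
"    \"trip_start_day\": 7,\n",
"    \"trip_start_month\": 12,\n",
"    \"dropoff_census_tract\": 25,  # capped by max_categories\n",
"    \"pickup_community_area\": 20,  # capped by max_categories\n",
"    \"dropoff_community_area\": 20,  # capped by max_categories\n",
"    \"payment_type\": 2,  # capped by max_categories\n",
"    \"company\": 7,  # capped by max_categories\n",
"}\n",
"# Numeric features passed through as single columns (trip_start_timestamp\n",
"# is replaced by trip_start_year in batch_mapper_fn).\n",
"numeric_features = [\"trip_miles\", \"trip_seconds\", \"trip_start_year\"]\n",
"\n",
"input_size = sum(one_hot_widths.values()) + len(numeric_features)\n",
"print(input_size)  # 120\n",
"```\n",
"\n",
"If you change the preprocessing, for example a different `max_categories`, `INPUT_SIZE` has to be updated to match."
]
},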
{
"cell_type": "markdown",
"metadata": {
"id": "whPRbBNbIrIl"
},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7QYTpxXIrIl"
},
"source": [
"Let's starting with defining a simple Keras model for the classification task."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "MwhAeEOuhYbV"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def build_model():\n",
" model = tf.keras.models.Sequential()\n",
" model.add(tf.keras.Input(shape=(INPUT_SIZE,)))\n",
" model.add(tf.keras.layers.Dense(50, activation=\"relu\"))\n",
" model.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVVji2YKADrh"
},
"source": [
"Now let's define the training loop. This code will be run on each training\n",
"worker in a distributed fashion. See more details [here](https://docs.ray.io/en/latest/train/train.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U5pdjIzoAGRd"
},
"outputs": [],
"source": [
"from ray.air import session, Checkpoint\n",
"from ray.train.tensorflow import prepare_dataset_shard\n",
"\n",
"def train_loop_per_worker():\n",
" dataset_shard = session.get_dataset_shard(\"train\")\n",
"\n",
" strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()\n",
" with strategy.scope():\n",
" model = build_model()\n",
" model.compile(\n",
" loss=\"binary_crossentropy\",\n",
" optimizer=\"adam\",\n",
" metrics=[\"accuracy\"],\n",
" )\n",
"\n",
" def to_tf_dataset(dataset, batch_size):\n",
" def to_tensor_iterator():\n",
" for batch in dataset.iter_tf_batches(\n",
" batch_size=batch_size, dtypes=tf.float32, drop_last=True,\n",
" ):\n",
" yield batch[INPUT], batch[LABEL]\n",
"\n",
" output_signature = (\n",
" tf.TensorSpec(shape=(BATCH_SIZE, INPUT_SIZE), dtype=tf.float32),\n",
" tf.TensorSpec(shape=(BATCH_SIZE,), dtype=tf.int64),\n",
" )\n",
" tf_dataset = tf.data.Dataset.from_generator(\n",
" to_tensor_iterator, output_signature=output_signature\n",
" )\n",
" return prepare_dataset_shard(tf_dataset)\n",
"\n",
" for epoch in range(EPOCH): \n",
" # This will make sure that the training workers will get their own\n",
" # share of batch to work on.\n",
" # See `ray.train.tensorflow.prepare_dataset_shard` for more information.\n",
" tf_dataset = to_tf_dataset(\n",
" dataset=dataset_shard,\n",
" batch_size=BATCH_SIZE,\n",
" )\n",
"\n",
" model.fit(tf_dataset, verbose=0)\n",
" # This saves checkpoint in a way that can be used by Ray Serve coherently.\n",
" session.report(\n",
" {},\n",
" checkpoint=Checkpoint.from_dict(\n",
" dict(epoch=epoch, model=model.get_weights())\n",
" ),\n",
" )"
]
},
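{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `to_tf_dataset`, the `output_signature` fixes every batch to the shape `(BATCH_SIZE, INPUT_SIZE)`, so a shorter final batch would not match it; that is why `iter_tf_batches` is called with `drop_last=True`. The NumPy sketch below mimics that batching rule with a hypothetical helper, `iter_fixed_batches`, which is not part of Ray.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"BATCH_SIZE, INPUT_SIZE = 8, 120\n",
"\n",
"def iter_fixed_batches(features, labels, batch_size, drop_last=True):\n",
"    \"\"\"Yield (features, labels) batches, optionally dropping a short final batch.\"\"\"\n",
"    for start in range(0, len(features), batch_size):\n",
"        end = start + batch_size\n",
"        if drop_last and end > len(features):\n",
"            break  # a partial batch would not match (batch_size, INPUT_SIZE)\n",
"        yield features[start:end], labels[start:end]\n",
"\n",
"# Hypothetical shard of 20 rows: two full batches of 8; the remaining 4 rows are dropped.\n",
"x = np.zeros((20, INPUT_SIZE), dtype=np.float32)\n",
"y = np.zeros(20, dtype=np.float32)\n",
"batches = list(iter_fixed_batches(x, y, BATCH_SIZE))\n",
"print(len(batches), batches[0][0].shape)  # 2 (8, 120)\n",
"```"
]
},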
{
"cell_type": "markdown",
"metadata": {
"id": "RzfPtOMoIrIu"
},
"source": [
"Now let's define a trainer that takes in the training loop,\n",
"the training dataset as well the preprocessor that we just defined.\n",
"\n",
"And run it!\n",
"\n",
"Notice that you can tune how long you want the run to be by changing ``EPOCH``."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
},
"id": "fzpWK7nuTJmN",
"outputId": "10020cb8-35ec-4f81-a528-0c99f7bdffea"
},
"outputs": [],
"source": [
"from ray.train.tensorflow import TensorflowTrainer\n",
"from ray.air.config import ScalingConfig\n",
"\n",
"trainer = TensorflowTrainer(\n",
" train_loop_per_worker=train_loop_per_worker,\n",
" scaling_config=ScalingConfig(num_workers=NUM_WORKERS),\n",
" datasets={\"train\": train_ds},\n",
" preprocessor=get_preprocessor(),\n",
")\n",
"result = trainer.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nb0HkOV2R4uL"
},
"source": [
"## Moving on to Serve"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OlzjlW8QR_q6"
},
"source": [
"We will use Ray Serve to serve the trained model. A core concept of Ray Serve is [Deployment](https://docs.ray.io/en/latest/serve/core-apis.html). It allows you to define and update your business logic or models that will handle incoming requests as well as how this is exposed over HTTP or in Python.\n",
"\n",
"In the case of serving model, `ray.serve.air_integrations.Predictor` and `ray.serve.air_integrations.PredictorDeployment` wrap a `ray.air.checkpoint.Checkpoint` into a Ray Serve deployment that can readily serve HTTP requests.\n",
"Note, ``Checkpoint`` captures both model and preprocessing steps in a way compatible with Ray Serve and ensures that ml workload can transition seamlessly between training and\n",
"serving.\n",
"\n",
"This removes the boilerplate code and minimizes the effort to bring your model to production!\n",
"\n",
"Generally speaking the http request can either send in json or data.\n",
"Upon receiving this payload, Ray Serve would need some \"adapter\" to convert the request payload into some shape or form that can be accepted as input to the preprocessing steps. In this case, we send in a json request and converts it to a pandas DataFrame through `dataframe_adapter`, defined below."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "BBbcMwc9Rz66"
},
"outputs": [],
"source": [
"from fastapi import Request\n",
"\n",
"async def dataframe_adapter(request: Request):\n",
" \"\"\"Serve HTTP Adapter that reads JSON and converts to pandas DataFrame.\"\"\"\n",
" content = await request.json()\n",
" return pd.DataFrame.from_dict(content)"
]
},
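{
"cell_type": "markdown",
"metadata": {},
"source": [
"On the client side, `send_requests` (defined later in this notebook) posts `df.iloc[[i]].to_dict()`, a nested mapping of `{column: {row_index: value}}`. The sketch below, using a hypothetical one-row frame, shows that `pd.DataFrame.from_dict` in `dataframe_adapter` rebuilds the same single row; the JSON round trip turns the integer row index into a string. Here `json.dumps`/`json.loads` stand in for the HTTP transfer.\n",
"\n",
"```python\n",
"import json\n",
"\n",
"import pandas as pd\n",
"\n",
"# Hypothetical one-row frame standing in for a row of test_df.\n",
"row = pd.DataFrame({\"trip_miles\": [12.6], \"payment_type\": [\"Cash\"]}, index=[2])\n",
"payload = row.to_dict()  # {column: {row_index: value}}\n",
"\n",
"# requests.post(..., json=payload) serializes the payload to JSON;\n",
"# json.dumps/json.loads simulate that round trip (int keys become strings).\n",
"content = json.loads(json.dumps(payload))\n",
"rebuilt = pd.DataFrame.from_dict(content)\n",
"print(rebuilt)\n",
"```"
]
},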
{
"cell_type": "markdown",
"metadata": {
"id": "SOnl90IuRywD"
},
"source": [
"Now let's wrap everything in a serve endpoint that exposes a URL to where requests can be sent to."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "ujmwT8ZhScq1"
},
"outputs": [],
"source": [
"from ray import serve\n",
"from ray.air.checkpoint import Checkpoint\n",
"from ray.train.tensorflow import TensorflowPredictor\n",
"from ray.serve import PredictorDeployment\n",
"\n",
"\n",
"def serve_model(checkpoint: Checkpoint, model_definition, adapter, name=\"Model\") -> str:\n",
" \"\"\"Expose a serve endpoint.\n",
"\n",
" Returns:\n",
" serve URL.\n",
" \"\"\"\n",
" serve.run(\n",
" PredictorDeployment.options(name=name).bind(\n",
" TensorflowPredictor,\n",
" checkpoint,\n",
" batching_params=dict(max_batch_size=2, batch_wait_timeout_s=5),\n",
" model_definition=model_definition,\n",
" http_adapter=adapter,\n",
" )\n",
" )\n",
" return f\"http://localhost:8000/\""
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "uRe9a8947pl9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-20 18:46:11,759\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n",
"\u001b[2m\u001b[36m(ServeController pid=21308)\u001b[0m INFO 2022-07-20 18:46:15,348 controller 21308 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.\n",
"\u001b[2m\u001b[36m(ServeController pid=21308)\u001b[0m INFO 2022-07-20 18:46:15,350 controller 21308 http_state.py:126 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-58fb3ee046cdce5c602369291de78f60c65dcbd7c5c5a8af57ec3a26' on node '58fb3ee046cdce5c602369291de78f60c65dcbd7c5c5a8af57ec3a26' listening on '127.0.0.1:8000'\n",
"\u001b[2m\u001b[36m(HTTPProxyActor pid=21311)\u001b[0m INFO: Started server process [21311]\n",
"/Users/jiaodong/anaconda3/envs/ray3.7/lib/python3.7/site-packages/ipykernel_launcher.py:23: UserWarning: From /var/folders/1s/wy6f3ytn3q726p5hl8fw8d780000gn/T/ipykernel_21006/609683685.py:23: deploy (from ray.serve.deployment) is deprecated and will be removed in a future version Please see https://docs.ray.io/en/latest/serve/index.html\n",
"\u001b[2m\u001b[36m(ServeController pid=21308)\u001b[0m INFO 2022-07-20 18:46:17,658 controller 21308 deployment_state.py:1281 - Adding 1 replicas to deployment 'Model'.\n",
"\u001b[2m\u001b[36m(ServeReplica:Model pid=21314)\u001b[0m 2022-07-20 18:46:23,199\tWARNING compression.py:18 -- lz4 not available, disabling sample compression. This will significantly impact RLlib performance. To install lz4, run `pip install lz4`.\n"
]
}
],
"source": [
"import ray\n",
"# Generally speaking, training and serving are done in totally different ray clusters.\n",
"# To simulate that, let's shutdown the old ray cluster in preparation for serving.\n",
"ray.shutdown()\n",
"\n",
"endpoint_uri = serve_model(result.checkpoint, build_model, dataframe_adapter)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rzHSwa2bSyee"
},
"source": [
"Let's write a helper function to send requests to this endpoint and compare the results with labels."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "E9m80HDmSz66"
},
"outputs": [],
"source": [
"import requests\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"NUM_SERVE_REQUESTS = 10\n",
"\n",
"def send_requests(df: pd.DataFrame, label: np.array):\n",
" for i in range(NUM_SERVE_REQUESTS):\n",
" one_row = df.iloc[[i]].to_dict()\n",
" serve_result = requests.post(endpoint_uri, json=one_row).json()\n",
" print(\n",
" f\"request{i} prediction: {serve_result[0]['predictions']} \"\n",
" f\"- label: {str(label[i])}\"\n",
" )"
]
},
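{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model ends in a sigmoid, so each prediction is the probability of a big tip rather than a class. To compare predictions with the boolean labels, one option is to threshold the probabilities. The sketch below assumes a 0.5 cut-off and uses hypothetical probabilities and labels, not the serving results.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical sigmoid outputs and labels, for illustration only.\n",
"probabilities = np.array([0.91, 0.12, 0.40, 0.73])\n",
"labels = np.array([True, False, True, True])\n",
"\n",
"# Threshold at 0.5 (an assumed cut-off) to turn probabilities into class predictions.\n",
"predicted = probabilities > 0.5\n",
"accuracy = (predicted == labels).mean()\n",
"print(predicted.tolist(), accuracy)  # [True, False, False, True] 0.75\n",
"```"
]
},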
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "GFPwKc5JTgnI"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"request0 prediction: 0.004963837098330259 - label: True\n",
"request1 prediction: 6.652726733591408e-05 - label: False\n",
"request2 prediction: 0.00018405025184620172 - label: False\n",
"request3 prediction: 0.00016512417641934007 - label: False\n",
"request4 prediction: 0.00015515758423134685 - label: False\n",
"request5 prediction: 5.948602483840659e-05 - label: False\n",
"request6 prediction: 9.51739348238334e-05 - label: False\n",
"request7 prediction: 3.4787988170137396e-06 - label: False\n",
"request8 prediction: 0.00010751552326837555 - label: False\n",
"request9 prediction: 0.060329731553792953 - label: True\n"
]
}
],
"source": [
"send_requests(test_df, test_label)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"jAgvLbhT8nB0"
],
"name": "tfx (1) (1) (1).ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"vscode": {
"interpreter": {
"hash": "99d89bfe98f3aa2d7facda0d08d31ff2a0af9559e5330d719288ce64a1966273"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}