{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "548805e2",
|
|
"metadata": {
|
|
"tags": [
|
|
"remove-cell"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# flake8: noqa\n",
|
|
"import warnings\n",
|
|
"import os\n",
|
|
"\n",
|
|
"# Suppress noisy requests warnings.\n",
|
|
"warnings.filterwarnings(\"ignore\")\n",
|
|
"os.environ[\"PYTHONWARNINGS\"] = \"ignore\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "af627a74",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Processing NYC taxi data using Ray Datasets\n",
|
|
"\n",
|
|
"The [NYC Taxi dataset](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page) is a popular tabular dataset. In this example, we demonstrate some basic data processing on this dataset using Ray Datasets.\n",
|
|
"\n",
|
|
"## Overview\n",
|
|
"\n",
|
|
"This tutorial will cover:\n",
|
|
" - Reading Parquet data\n",
|
|
" - Inspecting the metadata and first few rows of a large Ray {class}`Dataset <ray.data.Dataset>`\n",
|
|
" - Calculating some common global and grouped statistics on the dataset\n",
|
|
" - Dropping columns and rows\n",
|
|
" - Adding a derived column\n",
|
|
" - Shuffling the dataset\n",
|
|
" - Sharding the dataset and feeding it to parallel consumers (trainers)\n",
|
|
" - Applying batch (offline) inference to the data\n",
|
|
"\n",
|
|
"## Walkthrough\n",
|
|
"\n",
|
|
"Let's start by importing Ray and initializing a local Ray cluster."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "be863f26",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": []
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
" <div style=\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\">\n",
|
|
" <h3 style=\"color: var(--jp-ui-font-color0)\">Ray</h3>\n",
|
|
" <svg version=\"1.1\" id=\"ray\" width=\"3em\" viewBox=\"0 0 144.5 144.6\" style=\"margin-left: 3em;margin-right: 3em\">\n",
|
|
" <g id=\"layer-1\">\n",
|
|
" <path fill=\"#00a2e9\" class=\"st0\" d=\"M97.3,77.2c-3.8-1.1-6.2,0.9-8.3,5.1c-3.5,6.8-9.9,9.9-17.4,9.6S58,88.1,54.8,81.2c-1.4-3-3-4-6.3-4.1\n",
|
|
" c-5.6-0.1-9.9,0.1-13.1,6.4c-3.8,7.6-13.6,10.2-21.8,7.6C5.2,88.4-0.4,80.5,0,71.7c0.1-8.4,5.7-15.8,13.8-18.2\n",
|
|
" c8.4-2.6,17.5,0.7,22.3,8c1.3,1.9,1.3,5.2,3.6,5.6c3.9,0.6,8,0.2,12,0.2c1.8,0,1.9-1.6,2.4-2.8c3.5-7.8,9.7-11.8,18-11.9\n",
|
|
" c8.2-0.1,14.4,3.9,17.8,11.4c1.3,2.8,2.9,3.6,5.7,3.3c1-0.1,2,0.1,3,0c2.8-0.5,6.4,1.7,8.1-2.7s-2.3-5.5-4.1-7.5\n",
|
|
" c-5.1-5.7-10.9-10.8-16.1-16.3C84,38,81.9,37.1,78,38.3C66.7,42,56.2,35.7,53,24.1C50.3,14,57.3,2.8,67.7,0.5\n",
|
|
" C78.4-2,89,4.7,91.5,15.3c0.1,0.3,0.1,0.5,0.2,0.8c0.7,3.4,0.7,6.9-0.8,9.8c-1.7,3.2-0.8,5,1.5,7.2c6.7,6.5,13.3,13,19.8,19.7\n",
|
|
" c1.8,1.8,3,2.1,5.5,1.2c9.1-3.4,17.9-0.6,23.4,7c4.8,6.9,4.6,16.1-0.4,22.9c-5.4,7.2-14.2,9.9-23.1,6.5c-2.3-0.9-3.5-0.6-5.1,1.1\n",
|
|
" c-6.7,6.9-13.6,13.7-20.5,20.4c-1.8,1.8-2.5,3.2-1.4,5.9c3.5,8.7,0.3,18.6-7.7,23.6c-7.9,5-18.2,3.8-24.8-2.9\n",
|
|
" c-6.4-6.4-7.4-16.2-2.5-24.3c4.9-7.8,14.5-11,23.1-7.8c3,1.1,4.7,0.5,6.9-1.7C91.7,98.4,98,92.3,104.2,86c1.6-1.6,4.1-2.7,2.6-6.2\n",
|
|
" c-1.4-3.3-3.8-2.5-6.2-2.6C99.8,77.2,98.9,77.2,97.3,77.2z M72.1,29.7c5.5,0.1,9.9-4.3,10-9.8c0-0.1,0-0.2,0-0.3\n",
|
|
" C81.8,14,77,9.8,71.5,10.2c-5,0.3-9,4.2-9.3,9.2c-0.2,5.5,4,10.1,9.5,10.3C71.8,29.7,72,29.7,72.1,29.7z M72.3,62.3\n",
|
|
" c-5.4-0.1-9.9,4.2-10.1,9.7c0,0.2,0,0.3,0,0.5c0.2,5.4,4.5,9.7,9.9,10c5.1,0.1,9.9-4.7,10.1-9.8c0.2-5.5-4-10-9.5-10.3\n",
|
|
" C72.6,62.3,72.4,62.3,72.3,62.3z M115,72.5c0.1,5.4,4.5,9.7,9.8,9.9c5.6-0.2,10-4.8,10-10.4c-0.2-5.4-4.6-9.7-10-9.7\n",
|
|
" c-5.3-0.1-9.8,4.2-9.9,9.5C115,72.1,115,72.3,115,72.5z M19.5,62.3c-5.4,0.1-9.8,4.4-10,9.8c-0.1,5.1,5.2,10.4,10.2,10.3\n",
|
|
" c5.6-0.2,10-4.9,9.8-10.5c-0.1-5.4-4.5-9.7-9.9-9.6C19.6,62.3,19.5,62.3,19.5,62.3z M71.8,134.6c5.9,0.2,10.3-3.9,10.4-9.6\n",
|
|
" c0.5-5.5-3.6-10.4-9.1-10.8c-5.5-0.5-10.4,3.6-10.8,9.1c0,0.5,0,0.9,0,1.4c-0.2,5.3,4,9.8,9.3,10\n",
|
|
" C71.6,134.6,71.7,134.6,71.8,134.6z\"/>\n",
|
|
" </g>\n",
|
|
" </svg>\n",
|
|
" <table>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Python version:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b>3.8.5</b></td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Ray version:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b> 3.0.0.dev0</b></td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Dashboard:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b><a href=\"http://127.0.0.1:8265\" target=\"_blank\">http://127.0.0.1:8265</a></b></td>\n",
|
|
"</tr>\n",
|
|
"\n",
|
|
" </table>\n",
|
|
" </div>\n",
|
|
"</div>\n"
|
|
],
|
|
"text/plain": [
|
|
"RayContext(dashboard_url='127.0.0.1:8265', python_version='3.8.5', ray_version='3.0.0.dev0', ray_commit='c01bb831d4fcf6066c8bd60f73999115b315148a', address_info={'node_ip_address': '172.31.95.92', 'raylet_ip_address': '172.31.95.92', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-07-25_13-46-24_910980_201/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-07-25_13-46-24_910980_201/sockets/raylet', 'webui_url': '127.0.0.1:8265', 'session_dir': '/tmp/ray/session_2022-07-25_13-46-24_910980_201', 'metrics_export_port': 43248, 'gcs_address': '172.31.95.92:9031', 'address': '172.31.95.92:9031', 'dashboard_agent_listen_port': 52365, 'node_id': '723fbaf30ca70b1ba739386bf2fae31a4f620c113fdca47729204709'})"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Import ray and initialize a local Ray cluster.\n",
|
|
"import ray\n",
|
|
"ray.init()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f1f7ea00",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Reading and Inspecting the Data\n",
|
|
"\n",
|
|
"Next, we read a few of the files from the dataset. This read is semi-lazy, where reading of the first file is eagerly executed, but reading of all other files is delayed until the underlying data is needed by downstream operations (e.g. consuming the data with {meth}`ds.take() <ray.data.Dataset.take>`, or transforming the data with {meth}`ds.map_batches() <ray.data.Dataset.map_batches>`).\n",
|
|
"\n",
|
|
"We could process the entire Dataset in a streaming fashion using {ref}`pipelining <dataset_pipeline_concept>` or all of it in parallel using a multi-node Ray cluster, but we'll save that for our large-scale examples. :)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ee6fe392",
|
|
"metadata": {
|
|
"tags": [
|
|
"remove-output"
|
|
]
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"⚠️ The number of blocks in this dataset (2) limits its parallelism to 2 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read two Parquet files in parallel.\n",
|
|
"ds = ray.data.read_parquet([\n",
|
|
" \"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_01_data.parquet\",\n",
|
|
" \"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_02_data.parquet\"\n",
|
|
"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a4d4769c",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can easily inspect the schema of this dataset. For Parquet files, we don't even have to read the actual data to get the schema; we can read it from the lightweight Parquet metadata!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "8df10660",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"vendor_id: string\n",
|
|
"pickup_at: timestamp[us]\n",
|
|
"dropoff_at: timestamp[us]\n",
|
|
"passenger_count: int8\n",
|
|
"trip_distance: float\n",
|
|
"pickup_longitude: float\n",
|
|
"pickup_latitude: float\n",
|
|
"rate_code_id: null\n",
|
|
"store_and_fwd_flag: string\n",
|
|
"dropoff_longitude: float\n",
|
|
"dropoff_latitude: float\n",
|
|
"payment_type: string\n",
|
|
"fare_amount: float\n",
|
|
"extra: float\n",
|
|
"mta_tax: float\n",
|
|
"tip_amount: float\n",
|
|
"tolls_amount: float\n",
|
|
"total_amount: float\n",
|
|
"-- schema metadata --\n",
|
|
"pandas: '{\"index_columns\": [{\"kind\": \"range\", \"name\": null, \"start\": 0, \"' + 2524"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fetch the schema from the underlying Parquet metadata.\n",
|
|
"ds.schema()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d918b6bd",
|
|
"metadata": {},
|
|
"source": [
|
|
"Parquet even stores the number of rows per file in the Parquet metadata, so we can get the number of rows in ``ds`` without triggering a full data read."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "80549d69",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"2749936"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.count()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "87fd9a17",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can get a nice, cheap summary of the ``Dataset`` by leveraging it's informative repr:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "96eceee0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Dataset(num_blocks=2, num_rows=2749936, schema={vendor_id: string, pickup_at: timestamp[us], dropoff_at: timestamp[us], passenger_count: int8, trip_distance: float, pickup_longitude: float, pickup_latitude: float, rate_code_id: null, store_and_fwd_flag: string, dropoff_longitude: float, dropoff_latitude: float, payment_type: string, fare_amount: float, extra: float, mta_tax: float, tip_amount: float, tolls_amount: float, total_amount: float})"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Display some metadata about the dataset.\n",
|
|
"ds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2c96f0cf",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can also poke at the actual data, taking a peek at a single row. Since this is only returning a row from the first file, reading of the second file is **not** triggered yet."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "e6b6eb72",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[ArrowRow({'vendor_id': 'VTS',\n",
|
|
" 'pickup_at': datetime.datetime(2009, 1, 21, 14, 58),\n",
|
|
" 'dropoff_at': datetime.datetime(2009, 1, 21, 15, 3),\n",
|
|
" 'passenger_count': 1,\n",
|
|
" 'trip_distance': 0.5299999713897705,\n",
|
|
" 'pickup_longitude': -73.99270629882812,\n",
|
|
" 'pickup_latitude': 40.7529411315918,\n",
|
|
" 'rate_code_id': None,\n",
|
|
" 'store_and_fwd_flag': None,\n",
|
|
" 'dropoff_longitude': -73.98814392089844,\n",
|
|
" 'dropoff_latitude': 40.75956344604492,\n",
|
|
" 'payment_type': 'CASH',\n",
|
|
" 'fare_amount': 4.5,\n",
|
|
" 'extra': 0.0,\n",
|
|
" 'mta_tax': None,\n",
|
|
" 'tip_amount': 0.0,\n",
|
|
" 'tolls_amount': 0.0,\n",
|
|
" 'total_amount': 4.5})]"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.take(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a3fb551b",
|
|
"metadata": {},
|
|
"source": [
|
|
"To get a better sense of the data size, we can calculate the size in bytes of the full dataset. Note that for Parquet files, this size-in-bytes will be pulled from the Parquet metadata (not triggering a data read) and will therefore be the on-disk size of the data; this might be significantly smaller than the in-memory size!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "3da22d56",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"237029648"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.size_bytes()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cb4515bf",
|
|
"metadata": {},
|
|
"source": [
|
|
"In order to get the in-memory size, we can trigger full reading of the dataset and inspect the size in bytes."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "a7971d4e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Read progress: 100%|██████████| 2/2 [00:00<00:00, 2.50it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"226524489"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.fully_executed().size_bytes()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "934afde3",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Advanced Aside - Reading Partitioned Parquet Datasets\n",
|
|
"\n",
|
|
"In addition to being able to read lists of individual files, {func}`ray.data.read_parquet() <ray.data.read_parquet>` (as well as other ``ray.data.read_*()`` APIs) can read directories containing multiple Parquet files. For Parquet in particular, reading Parquet datasets partitioned by a particular column is supported, allowing for path-based (zero-read) partition filtering and (optionally) including the partition column value specified in the file paths directly in the read table data.\n",
|
|
"\n",
|
|
"For the NYC taxi dataset, instead of reading individual per-month Parquet files, we can read the entire 2009 directory.\n",
|
|
"\n",
|
|
"```{warning}\n",
|
|
"This could be a lot of data (downsampled with 0.01 ratio leads to ~50.2 MB on disk, ~147 MB in memory), so be careful triggering full reads on a limited-memory machine! This is one place where Datasets' semi-lazy reading comes in handy: Datasets will only read one file eagerly, which allows us to inspect a subset of the data without having to read the entire dataset.\n",
|
|
"```"
|
|
]
|
|
},
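{
"cell_type": "markdown",
"id": "partitioned-read-sketch",
"metadata": {},
"source": [
"For example, if the taxi data were laid out as a Hive-partitioned directory, a partition-filtered read might look like the following sketch. The bucket path and the ``year`` partition column are hypothetical here, purely to illustrate the API.\n",
"\n",
"```python\n",
"import pyarrow.dataset as pds\n",
"\n",
"# Hypothetical layout: s3://my-bucket/taxi/year=2009/..., s3://my-bucket/taxi/year=2010/..., etc.\n",
"partitioned = ray.data.read_parquet(\n",
"    \"s3://my-bucket/taxi/\",\n",
"    # Filtering on the partition column prunes non-matching paths without reading them.\n",
"    filter=(pds.field(\"year\") == 2009),\n",
")\n",
"```"
]
},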
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "4a1fa8ab",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": []
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read all Parquet data for the year 2009.\n",
|
|
"year_ds = ray.data.read_parquet(\"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_full_year_data.parquet\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6616a15d",
|
|
"metadata": {},
|
|
"source": [
|
|
"The metadata that Datasets prints in its repr is guaranteed to not trigger reads of all files; data such as the row count and the schema is pulled directly from the Parquet metadata."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "0b2239a1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"146863542"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"year_ds.size_bytes()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e61dd6d7",
|
|
"metadata": {},
|
|
"source": [
|
|
"That's a lot of rows! Since we're not going to use this full-year dataset, let's now delete this dataset to free up some memory in our Ray cluster."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "d62e74ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"del year_ds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b864efab",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Data Exploration and Cleaning\n",
|
|
"\n",
|
|
"Let's calculate some stats to get a better picture of our data."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "174f28a1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Shuffle Map: 100%|██████████| 2/2 [00:00<00:00, 50.69it/s]\n",
|
|
"Shuffle Reduce: 100%|██████████| 1/1 [00:00<00:00, 114.04it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"ArrowRow({'max(trip_distance)': 50.0,\n",
|
|
" 'max(tip_amount)': 100.0,\n",
|
|
" 'max(passenger_count)': 6})"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# What's the longets trip distance, largest tip amount, and most number of passengers?\n",
|
|
"ds.max([\"trip_distance\", \"tip_amount\", \"passenger_count\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "31789322",
|
|
"metadata": {},
|
|
"source": [
|
|
"Whoa, there was a trip with 113 people in the taxi!? Let's check out these kind of many-passenger records by filtering to just these records using our {meth}`ds.map_batches() <ray.data.Dataset.map_batches>` batch mapping API.\n",
|
|
"\n",
|
|
":::{note}\n",
|
|
"Our filtering UDF receives a Pandas DataFrame, which is the default batch format for tabular data, and returns a Pandas DataFrame, which keeps the Dataset in a tabular format.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "1d4d5cce",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map_Batches: 100%|██████████| 2/2 [00:01<00:00, 1.96it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[]"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Whoa, 113 passengers? I need to see this record and other ones with lots of passengers.\n",
|
|
"ds.map_batches(lambda df: df[df[\"passenger_count\"] > 10]).take()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ba7323a3",
|
|
"metadata": {},
|
|
"source": [
|
|
"That seems weird, probably bad data, or at least data points that I'm not interested in. We should filter these out!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "ff56e5ee",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map_Batches: 100%|██████████| 2/2 [00:04<00:00, 2.14s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Filter out all records with over 10 passengers.\n",
|
|
"ds = ds.map_batches(lambda df: df[df[\"passenger_count\"] <= 10])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8c935738",
|
|
"metadata": {},
|
|
"source": [
|
|
"We don't have any use for the ``store_and_fwd_flag`` or ``mta_tax`` columns, so let's drop those."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "4037d398",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map_Batches: 100%|██████████| 2/2 [00:03<00:00, 1.59s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Drop some columns.\n",
|
|
"ds = ds.drop_columns([\"store_and_fwd_flag\", \"mta_tax\"])"
|
|
]
|
|
},
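{
"cell_type": "markdown",
"id": "derived-column-sketch",
"metadata": {},
"source": [
"This is also a natural place to add a derived column. As a minimal sketch (the ``trip_duration`` column and the ``ds_with_duration`` variable are purely illustrative), the trip duration in seconds can be computed from the pickup and dropoff timestamps with another batch mapping:\n",
"\n",
"```python\n",
"# Derive the trip duration (in seconds) from the pickup/dropoff timestamps.\n",
"ds_with_duration = ds.map_batches(\n",
"    lambda df: df.assign(\n",
"        trip_duration=(df[\"dropoff_at\"] - df[\"pickup_at\"]).dt.total_seconds()\n",
"    )\n",
")\n",
"```"
]
},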
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "936d3a3e",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's say we want to know how many trips there are for each passenger count."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "66854349",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Sort Sample: 100%|██████████| 2/2 [00:00<00:00, 5.01it/s]\n",
|
|
"Shuffle Map: 100%|██████████| 2/2 [03:21<00:00, 100.61s/it]\n",
|
|
"Shuffle Reduce: 0%| | 0/2 [00:00<?, ?it/s](map pid=9272) E0725 14:35:16.665638988 9301 chttp2_transport.cc:1103] Received a GOAWAY with error code ENHANCE_YOUR_CALM and debug data equal to \"too_many_pings\"\n",
|
|
"Shuffle Reduce: 100%|██████████| 2/2 [00:01<00:00, 1.97it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[PandasRow({'passenger_count': -48, 'count()': 3}),\n",
|
|
" PandasRow({'passenger_count': 0, 'count()': 91}),\n",
|
|
" PandasRow({'passenger_count': 1, 'count()': 1865548}),\n",
|
|
" PandasRow({'passenger_count': 2, 'count()': 451452}),\n",
|
|
" PandasRow({'passenger_count': 3, 'count()': 119406}),\n",
|
|
" PandasRow({'passenger_count': 4, 'count()': 55547}),\n",
|
|
" PandasRow({'passenger_count': 5, 'count()': 245332}),\n",
|
|
" PandasRow({'passenger_count': 6, 'count()': 12557})]"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.groupby(\"passenger_count\").count().take()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "36493c33",
|
|
"metadata": {},
|
|
"source": [
|
|
"Again, it looks like there are some more nonsensical passenger counts, i.e. the negative ones. Let's filter those out too."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "730687c6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map_Batches: 100%|██████████| 2/2 [00:03<00:00, 1.60s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Filter our records with negative passenger counts.\n",
|
|
"ds = ds.map_batches(lambda df: df[df[\"passenger_count\"] > 0])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0d1e2106",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Advanced Aside - Projection and Filter Pushdown\n",
|
|
"\n",
|
|
"Note that Ray Datasets' Parquet reader supports projection (column selection) and row filter pushdown, where we can push the above column selection and the row-based filter to the Parquet read. If we specify column selection at Parquet read time, the unselected columns won't even be read from disk!\n",
|
|
"\n",
|
|
"The row-based filter is specified via\n",
|
|
"[Arrow's dataset field expressions](https://arrow.apache.org/docs/6.0/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression). See the {ref}`feature guide for reading Parquet data <dataset_supported_file_formats>` for more information."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "8dda3095",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"⚠️ The number of blocks in this dataset (2) limits its parallelism to 2 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
|
|
"Read progress: 100%|██████████| 2/2 [00:00<00:00, 9.19it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Dataset(num_blocks=2, num_rows=2749842, schema={passenger_count: int8, trip_distance: float})"
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Only read the passenger_count and trip_distance columns.\n",
|
|
"import pyarrow as pa\n",
|
|
"filter_expr = (\n",
|
|
" (pa.dataset.field(\"passenger_count\") <= 10)\n",
|
|
" & (pa.dataset.field(\"passenger_count\") > 0)\n",
|
|
")\n",
|
|
"\n",
|
|
"pushdown_ds = ray.data.read_parquet(\n",
|
|
" [\n",
|
|
" \"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_01_data.parquet\",\n",
|
|
" \"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_02_data.parquet\",\n",
|
|
" ],\n",
|
|
" columns=[\"passenger_count\", \"trip_distance\"],\n",
|
|
" filter=filter_expr,\n",
|
|
")\n",
|
|
"\n",
|
|
"# Force full execution of both of the file reads.\n",
|
|
"pushdown_ds = pushdown_ds.fully_executed()\n",
|
|
"pushdown_ds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "72ad8acc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Delete the pushdown dataset. Deleting the Dataset object\n",
|
|
"# will release the underlying memory in the cluster.\n",
|
|
"del pushdown_ds"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8f687553",
|
|
"metadata": {},
|
|
"source": [
|
|
"Do the passenger counts influences the typical trip distance?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "17d2904b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Sort Sample: 100%|██████████| 2/2 [00:00<00:00, 4.57it/s]\n",
|
|
"Shuffle Map: 100%|██████████| 2/2 [03:23<00:00, 101.59s/it]\n",
|
|
"Shuffle Reduce: 100%|██████████| 2/2 [00:00<00:00, 178.79it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[PandasRow({'passenger_count': 1, 'mean(trip_distance)': 2.543288084787955}),\n",
|
|
" PandasRow({'passenger_count': 2, 'mean(trip_distance)': 2.7043459216040686}),\n",
|
|
" PandasRow({'passenger_count': 3, 'mean(trip_distance)': 2.6233412684454716}),\n",
|
|
" PandasRow({'passenger_count': 4, 'mean(trip_distance)': 2.642096445352584}),\n",
|
|
" PandasRow({'passenger_count': 5, 'mean(trip_distance)': 2.6286944833939314}),\n",
|
|
" PandasRow({'passenger_count': 6, 'mean(trip_distance)': 2.5848625579855855})]"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Mean trip distance grouped by passenger count.\n",
|
|
"ds.groupby(\"passenger_count\").mean(\"trip_distance\").take()"
|
|
]
|
|
},
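{
"cell_type": "markdown",
"id": "grouped-aggregate-sketch",
"metadata": {},
"source": [
"Other global and grouped statistics follow the same pattern. As a small sketch, the mean and standard deviation of the trip distance per passenger count could be computed in a single pass using the ``Mean`` and ``Std`` aggregations from ``ray.data.aggregate``:\n",
"\n",
"```python\n",
"from ray.data.aggregate import Mean, Std\n",
"\n",
"# Compute two aggregations per passenger_count group in one pass.\n",
"ds.groupby(\"passenger_count\").aggregate(\n",
"    Mean(\"trip_distance\"),\n",
"    Std(\"trip_distance\"),\n",
").take()\n",
"```"
]
},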
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0ade2a72",
|
|
"metadata": {},
|
|
"source": [
|
|
"See the feature guides for {ref}`transforming data <transforming_datasets>` and {ref}`ML preprocessing <datasets-ml-preprocessing>` for more information on how we can process our data with Ray Datasets."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "08e10163",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Ingesting into Model Trainers\n",
|
|
"\n",
|
|
"Now that we've learned more about our data and we have cleaned up our dataset a bit, we now look at how we can feed this dataset into some dummy model trainers.\n",
|
|
"\n",
|
|
"First, let's do a full global random shuffle of the dataset to decorrelate these samples."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "a850863d",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Shuffle Map: 100%|██████████| 2/2 [00:01<00:00, 1.34it/s]\n",
|
|
"Shuffle Reduce: 100%|██████████| 2/2 [00:01<00:00, 1.09it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"ds = ds.random_shuffle()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ff05b6ea",
|
|
"metadata": {},
|
|
"source": [
|
|
"We define a dummy ``Trainer`` actor, where each trainer will consume a dataset shard in batches and simulate model training.\n",
|
|
"\n",
|
|
":::{note}\n",
|
|
"In a real training workflow, we would feed ``ds`` to {ref}`Ray Train <train-docs>`, which would do this sharding and creation of training actors for us, under the hood.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "16d34523",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[Actor(Trainer, 9326d43345699213608f324003000000),\n",
|
|
" Actor(Trainer, f0ce2ce44528fbf748c9c1a103000000),\n",
|
|
" Actor(Trainer, 7ba39c8f82ebd78c68e92ec903000000),\n",
|
|
" Actor(Trainer, b95fe3494b7bc2d8f42abbba03000000)]"
|
|
]
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"@ray.remote\n",
|
|
"class Trainer:\n",
|
|
" def __init__(self, rank: int):\n",
|
|
" pass\n",
|
|
"\n",
|
|
" def train(self, shard: ray.data.Dataset) -> int:\n",
|
|
" for batch in shard.iter_batches(batch_size=256):\n",
|
|
" pass\n",
|
|
" return shard.count()\n",
|
|
"\n",
|
|
"trainers = [Trainer.remote(i) for i in range(4)]\n",
|
|
"trainers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9a1afb70",
|
|
"metadata": {},
|
|
"source": [
|
|
"Next, we split the dataset into ``len(trainers)`` shards, ensuring that the shards are of equal size."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "2594a815",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[Dataset(num_blocks=1, num_rows=687460, schema={vendor_id: object, pickup_at: datetime64[ns], dropoff_at: datetime64[ns], passenger_count: int8, trip_distance: float32, pickup_longitude: float32, pickup_latitude: float32, rate_code_id: object, dropoff_longitude: float32, dropoff_latitude: float32, payment_type: object, fare_amount: float32, extra: float32, tip_amount: float32, tolls_amount: float32, total_amount: float32}),\n",
|
|
" Dataset(num_blocks=1, num_rows=687460, schema={vendor_id: object, pickup_at: datetime64[ns], dropoff_at: datetime64[ns], passenger_count: int8, trip_distance: float32, pickup_longitude: float32, pickup_latitude: float32, rate_code_id: object, dropoff_longitude: float32, dropoff_latitude: float32, payment_type: object, fare_amount: float32, extra: float32, tip_amount: float32, tolls_amount: float32, total_amount: float32}),\n",
|
|
" Dataset(num_blocks=2, num_rows=687460, schema={vendor_id: object, pickup_at: datetime64[ns], dropoff_at: datetime64[ns], passenger_count: int8, trip_distance: float32, pickup_longitude: float32, pickup_latitude: float32, rate_code_id: object, dropoff_longitude: float32, dropoff_latitude: float32, payment_type: object, fare_amount: float32, extra: float32, tip_amount: float32, tolls_amount: float32, total_amount: float32}),\n",
|
|
" Dataset(num_blocks=1, num_rows=687460, schema={vendor_id: object, pickup_at: datetime64[ns], dropoff_at: datetime64[ns], passenger_count: int8, trip_distance: float32, pickup_longitude: float32, pickup_latitude: float32, rate_code_id: object, dropoff_longitude: float32, dropoff_latitude: float32, payment_type: object, fare_amount: float32, extra: float32, tip_amount: float32, tolls_amount: float32, total_amount: float32})]"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"shards = ds.split(n=len(trainers), equal=True)\n",
|
|
"shards"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3ef20846",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally, we simulate training, passing each shard to the corresponding trainer. The number of rows per shard is returned."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "fbfe7da9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[687460, 687460, 687460, 687460]"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ray.get([w.train.remote(s) for w, s in zip(trainers, shards)])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "53fb2190",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Delete trainer actor handle references, which should terminate the actors.\n",
|
|
"del trainers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "75ccd122",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Parallel Batch Inference\n",
|
|
"\n",
|
|
"After we've trained a model, we may want to perform batch (offline) inference on such a tabular dataset. With Ray Datasets, this is as easy as a {meth}`ds.map_batches() <ray.data.Dataset.map_batches>` call!\n",
|
|
"\n",
|
|
"First, we define a callable class that will cache the loading of the model in its constructor."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "c29c52fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"def load_model():\n",
|
|
" # A dummy model.\n",
|
|
" def model(batch: pd.DataFrame) -> pd.DataFrame:\n",
|
|
" return pd.DataFrame({\"score\": batch[\"passenger_count\"] % 2 == 0})\n",
|
|
" \n",
|
|
" return model\n",
|
|
"\n",
|
|
"class BatchInferModel:\n",
|
|
" def __init__(self):\n",
|
|
" self.model = load_model()\n",
|
|
" def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:\n",
|
|
" return self.model(batch)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0c1ba955",
|
|
"metadata": {},
|
|
"source": [
|
|
"``BatchInferModel``'s constructor will only be called once per actor worker when using the actor pool compute strategy in {meth}`ds.map_batches() <ray.data.Dataset.map_batches>`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "4729f147",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (2 actors 1 pending): 100%|██████████| 2/2 [00:05<00:00, 2.57s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False})]"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.map_batches(BatchInferModel, batch_size=2048, compute=\"actors\").take()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b154f758",
|
|
"metadata": {},
|
|
"source": [
|
|
"If wanting to perform batch inference on GPUs, simply specify the number of GPUs you wish to provision for each batch inference worker.\n",
|
|
"\n",
|
|
":::{warning}\n",
|
|
"This will only run successfully if your cluster has nodes with GPUs!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "b99ff2f7",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (15 actors 4 pending): 100%|██████████| 2/2 [00:21<00:00, 10.67s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False})]"
|
|
]
|
|
},
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ds.map_batches(\n",
|
|
" BatchInferModel,\n",
|
|
" batch_size=256,\n",
|
|
" #num_gpus=1, # Uncomment this to run this on GPUs!\n",
|
|
" compute=\"actors\",\n",
|
|
").take()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "061ca8b4",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can also configure the autoscaling actor pool that this inference stage uses, setting upper and lower bounds on the actor pool size, and even tweak the batch prefetching vs. inference task queueing tradeoff."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "b8f6920e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Map Progress (8 actors 0 pending): 100%|██████████| 2/2 [00:21<00:00, 10.71s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': False}),\n",
|
|
" PandasRow({'score': True}),\n",
|
|
" PandasRow({'score': False})]"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from ray.data import ActorPoolStrategy\n",
|
|
"\n",
|
|
"# The actor pool will have at least 2 workers and at most 8 workers.\n",
|
|
"strategy = ActorPoolStrategy(min_size=2, max_size=8)\n",
|
|
"\n",
|
|
"ds.map_batches(\n",
|
|
" BatchInferModel,\n",
|
|
" batch_size=256,\n",
|
|
" #num_gpus=1, # Uncomment this to run this on GPUs!\n",
|
|
" compute=strategy,\n",
|
|
").take()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"celltoolbar": "Tags",
|
|
"kernelspec": {
|
|
"display_name": "Python 3.7.10 ('ray3.7')",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.10"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "99d89bfe98f3aa2d7facda0d08d31ff2a0af9559e5330d719288ce64a1966273"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |