{ "cells": [ { "cell_type": "markdown", "id": "0c5705f2", "metadata": {}, "source": [ "(gradio-serve-tutorial)=\n", "\n", "# Building a Gradio demo with Ray Serve\n", "\n", "In this example, we will show you how to wrap a machine learning model served\n", "by Ray Serve in a [Gradio demo](https://gradio.app/).\n", "\n", "Specifically, we're going to download a GPT-2 model from the `transformers` library,\n", "define a Ray Serve deployment with it, and then define and launch a Gradio `Interface`.\n", "Let's take a look." ] }, { "cell_type": "code", "execution_count": null, "id": "c017f8c4", "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# Install all dependencies for this example.\n", "! pip install ray gradio transformers requests" ] }, { "cell_type": "markdown", "id": "6245b4c3", "metadata": {}, "source": [ "## Deploying a model with Ray Serve\n", "\n", "To start off, we import Ray Serve, Gradio, and the `transformers` and `requests` libraries,\n", "and then start Ray Serve:" ] }, { "cell_type": "code", "execution_count": null, "id": "79d354ae", "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "from ray import serve\n", "from transformers import pipeline\n", "import requests\n", "\n", "\n", "serve.start()" ] }, { "cell_type": "markdown", "id": "85b1eba9", "metadata": {}, "source": [ "Next, we define a Ray Serve deployment with a GPT-2 model, by using the `@serve.deployment` decorator on a `model`\n", "function that takes a `request` argument.\n", "In this function we load a GPT-2 model with a call to `pipeline` and return the result of querying the model." 
] }, { "cell_type": "code", "execution_count": null, "id": "6ef8e2c7", "metadata": {}, "outputs": [], "source": [ "@serve.deployment\n", "def model(request):\n", "    # The pipeline is rebuilt on every request, which keeps the example short.\n", "    language_model = pipeline(\"text-generation\", model=\"gpt2\")\n", "    query = request.query_params[\"query\"]\n", "    return language_model(query, max_length=100)" ] }, { "cell_type": "markdown", "id": "ba7be609", "metadata": {}, "source": [ "This `model` can now be deployed with a `model.deploy()` call.\n", "To test this deployment we use a simple `example` query to get a `response` from the model running\n", "on `localhost:8000/model`.\n", "The first request to this endpoint downloads the model, which can take a while to complete.\n", "Subsequent calls will be faster." ] }, { "cell_type": "code", "execution_count": null, "id": "c278dfb7", "metadata": {}, "outputs": [], "source": [ "model.deploy()\n", "example = \"What's the meaning of life?\"\n", "# Pass the prompt via `params` so that `requests` URL-encodes it.\n", "response = requests.get(\"http://localhost:8000/model\", params={\"query\": example})\n", "print(response.text)" ] }, { "cell_type": "markdown", "id": "0b11e675", "metadata": {}, "source": [ "## Defining and launching a Gradio interface\n", "\n", "Defining a Gradio interface is now straightforward.\n", "All we need is a function that Gradio can call to get the response from the model.\n", "That's just a thin wrapper around our previous `requests` call:" ] }, { "cell_type": "code", "execution_count": null, "id": "61c3ab00", "metadata": {}, "outputs": [], "source": [ "def gpt2(query):\n", "    # Query the deployment and return only the generated text.\n", "    response = requests.get(\"http://localhost:8000/model\", params={\"query\": query})\n", "    return response.json()[0][\"generated_text\"]" ] }, { "cell_type": "markdown", "id": "53b4a5ef", "metadata": {}, "source": [ "Apart from our `gpt2` function, the only other thing that we need to define a Gradio interface is\n", "a description of the model inputs and outputs that Gradio understands.\n", "Since our model takes text as input and output, this turns out to be pretty simple:" 
] }, { "cell_type": "code", "execution_count": null, "id": "115fb25f", "metadata": {}, "outputs": [], "source": [ "iface = gr.Interface(\n", "    fn=gpt2,\n", "    inputs=[gr.inputs.Textbox(\n", "        default=example, label=\"Input prompt\"\n", "    )],\n", "    outputs=[gr.outputs.Textbox(label=\"Model output\")]\n", ")" ] }, { "cell_type": "markdown", "id": "2e998109", "metadata": {}, "source": [ "For more complex models served with Ray, you might need multiple `gr.inputs`\n", "and `gr.outputs` of different types.\n", "\n", "```{margin}\n", "The [Gradio documentation](https://gradio.app/docs/) covers all available input and output components in detail.\n", "```\n", "\n", "Finally, we can launch the interface using `iface.launch()`:" ] }, { "cell_type": "code", "execution_count": null, "id": "203ce70e", "metadata": {}, "outputs": [], "source": [ "iface.launch()" ] }, { "cell_type": "markdown", "id": "5e5638a9", "metadata": {}, "source": [ "This should launch an interface that you can interact with and that looks like this:\n", "\n", "```{image} https://raw.githubusercontent.com/ray-project/images/master/docs/serve/gradio_serve_gpt.png\n", "```\n", "\n", "You can run this example directly in the browser, for instance by launching this notebook\n", "in Google Colab or Binder via the _rocket icon_ at the top right of this page.\n", "If you run this code locally in Python, this Gradio app will be served on a local URL such as `http://127.0.0.1:7861/`, which `iface.launch()` prints when it starts.\n", "\n", "## Building a Gradio app from a Scikit-Learn model\n", "\n", "Let's take a look at another example, so that you can directly compare it\n", "with the first one." ] }, { "cell_type": "code", "execution_count": null, "id": "0fdc6b92", "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# Install all dependencies for this example.\n", "! 
pip install ray gradio requests scikit-learn" ] }, { "cell_type": "markdown", "id": "257744c8", "metadata": {}, "source": [ "This time we're going to use a [Scikit-Learn](https://scikit-learn.org/) model that we quickly train\n", "ourselves on the famous Iris dataset.\n", "To do this, we'll load the Iris dataset using the built-in `load_iris` function from the `sklearn` library,\n", "and use the `GradientBoostingClassifier` from the `sklearn.ensemble` module for training.\n", "\n", "This time we'll use the `@serve.deployment` decorator on a _class_ called `BoostingModel`, which has an\n", "asynchronous `__call__` method that Ray Serve needs to define your deployment.\n", "All else remains the same as in the first example." ] }, { "cell_type": "code", "execution_count": null, "id": "cb92f167", "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import requests\n", "from sklearn.datasets import load_iris\n", "from sklearn.ensemble import GradientBoostingClassifier\n", "\n", "from ray import serve\n", "\n", "# Train your model.\n", "iris_dataset = load_iris()\n", "model = GradientBoostingClassifier()\n", "model.fit(iris_dataset[\"data\"], iris_dataset[\"target\"])\n", "\n", "# Start Ray Serve.\n", "serve.start()\n", "\n", "# Define your deployment.\n", "@serve.deployment(route_prefix=\"/iris\")\n", "class BoostingModel:\n", "    def __init__(self, model):\n", "        self.model = model\n", "        self.label_list = iris_dataset[\"target_names\"].tolist()\n", "\n", "    async def __call__(self, request):\n", "        payload = (await request.json())[\"vector\"]\n", "        print(f\"Received http request with data {payload}\")\n", "\n", "        prediction = self.model.predict([payload])[0]\n", "        human_name = self.label_list[prediction]\n", "        return {\"result\": human_name}\n", "\n", "\n", "# Deploy your model.\n", "BoostingModel.deploy(model)" ] }, { "cell_type": "markdown", "id": "30c3ef21", "metadata": {}, "source": [ "Equipped with our `BoostingModel` class, we can now define 
and launch a Gradio interface as follows.\n", "The Iris dataset has a total of four features, namely the four numeric values _sepal length_, _sepal width_,\n", "_petal length_, and _petal width_.\n", "We use this fact to define an `iris` function that takes these four features and returns the predicted class,\n", "using our deployed model.\n", "This time, the Gradio interface takes four `Number` inputs and returns the predicted class as `text`.\n", "Go ahead and try it out in the browser yourself." ] }, { "cell_type": "code", "execution_count": null, "id": "733fb4f5", "metadata": {}, "outputs": [], "source": [ "# Define gradio function\n", "def iris(sl, sw, pl, pw):\n", "    request_input = {\"vector\": [sl, sw, pl, pw]}\n", "    response = requests.get(\n", "        \"http://localhost:8000/iris\", json=request_input)\n", "    # The deployment returns a JSON object of the form {\"result\": <class name>}.\n", "    return response.json()[\"result\"]\n", "\n", "\n", "# Define gradio interface\n", "iface = gr.Interface(\n", "    fn=iris,\n", "    inputs=[\n", "        gr.inputs.Number(default=1.0, label=\"sepal length (cm)\"),\n", "        gr.inputs.Number(default=1.0, label=\"sepal width (cm)\"),\n", "        gr.inputs.Number(default=1.0, label=\"petal length (cm)\"),\n", "        gr.inputs.Number(default=1.0, label=\"petal width (cm)\"),\n", "    ],\n", "    outputs=\"text\")\n", "\n", "# Launch the gradio interface\n", "iface.launch()" ] }, { "cell_type": "markdown", "id": "a3e47ff7", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "After launching, you should see an interactive interface that looks like this:\n", "\n", "```{image} https://raw.githubusercontent.com/ray-project/images/master/docs/serve/gradio_serve_iris.png\n", "```\n", "\n", "## Conclusion\n", "\n", "To summarize, it's easy to build Gradio apps from Ray Serve deployments.\n", "You only need to properly encode your model's inputs and outputs in a Gradio interface, and you're good to go!" 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 5 }