ray/doc/source/ray-overview/doc_test/ray_serve.py
Max Pumperla d6bff736f3
[docs] test ray.io snippets (#22822)
Tests all snippets we have on ray.io. There were some minor issues, which I'll fix upstream.

Signed-off-by: Max Pumperla <max.pumperla@googlemail.com>
2022-03-08 15:50:57 +00:00

43 lines
1.1 KiB
Python

# TODO: actually use this to predict something
import ray
from ray import serve
from fastapi import FastAPI
from transformers import pipeline
# FastAPI application object; the route and lifecycle hooks in this file
# are registered on it through its decorators.
app = FastAPI()
# Define our deployment.
@serve.deployment(num_replicas=2)
class GPT2:
    """Ray Serve deployment wrapping a GPT-2 text-generation pipeline.

    The deployment is declared with ``num_replicas=2``.
    """

    def __init__(self):
        # Build the text-generation pipeline once, when the instance is created.
        self.nlp_model = pipeline("text-generation", model="gpt2")

    async def predict(self, query: str):
        """Generate text for ``query``.

        Args:
            query: Prompt text handed to the pipeline.

        Returns:
            The pipeline's output for the prompt, produced with
            ``max_length=50``.
        """
        return self.nlp_model(query, max_length=50)

    async def __call__(self, request):
        """Generate text from the body of an incoming request.

        Args:
            request: Object exposing an awaitable ``body()`` method
                (presumably a Starlette ``Request`` -- confirm against the
                Serve version in use).

        Returns:
            The result of :meth:`predict` for the request body.
        """
        payload = await request.body()
        # ``predict`` is annotated to take ``str``; a raw request body is
        # typically ``bytes``, so decode it before handing it to the model.
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        # ``predict`` is a coroutine function: without ``await`` the caller
        # would get back an un-run coroutine object instead of the output.
        return await self.predict(payload)
@app.on_event("startup")
async def startup_event():
    """Server-startup hook: attach to Ray, start Serve, deploy ``GPT2``.

    The three steps run in this order because the deployment needs a
    started Serve instance, which in turn needs a Ray connection.
    """
    # Attach to an already-running Ray cluster rather than starting a new one.
    ray.init(address="auto")

    # Bring up the Ray Serve instance.
    # NOTE(review): http_host=None presumably leaves HTTP handling to this
    # FastAPI app -- confirm against the Ray Serve documentation.
    serve.start(http_host=None)

    # Deploy the GPT2 deployment.
    GPT2.deploy()
@app.get("/generate")
async def generate(query: str):
    """Serve ``GET /generate`` by forwarding ``query`` to the GPT2 deployment.

    Args:
        query: Prompt text passed on to the deployment's ``predict`` method.

    Returns:
        The awaited result of the remote ``predict`` call.
    """
    # Look up a Python handle to the deployment and await the remote call
    # in one expression.
    return await GPT2.get_handle().predict.remote(query)
@app.on_event("shutdown")
async def shutdown_event():
    """Server-shutdown hook: shut down Ray Serve."""
    serve.shutdown()