Source: ray/doc/source/ray-overview/doc_test/ray_serve.py
Last commit: ef866d1e49 by clarng — "exclude doc_code from import sorting (#25772)":
skip sorting the imports in doc_code.
Committed: 2022-06-15 11:34:45 -07:00
File size: 42 lines, 1.1 KiB, Python

# NOTE: import order is intentional — doc_code files are excluded from import
# sorting (#25772); do not regroup or re-sort these lines.
from fastapi import FastAPI
from transformers import pipeline
# TODO: actually use this to predict something
import ray
from ray import serve

# FastAPI application; the Serve-backed endpoints below are registered on it.
app = FastAPI()
# Define our deployment.
@serve.deployment(num_replicas=2)
class GPT2:
    """Ray Serve deployment wrapping a Hugging Face GPT-2 text-generation pipeline.

    Two replicas are run; each loads its own copy of the model.
    """

    def __init__(self):
        # Loaded once per replica. The pipeline downloads the "gpt2" weights on
        # first use, so replica startup can be slow.
        self.nlp_model = pipeline("text-generation", model="gpt2")

    async def predict(self, query: str):
        """Generate a continuation of *query*, capped at 50 tokens total."""
        return self.nlp_model(query, max_length=50)

    async def __call__(self, request):
        # HTTP entry point for the deployment: the raw request body is the query.
        # BUG FIX: predict() is a coroutine function, so its result must be
        # awaited here — the original returned an un-awaited coroutine object.
        return await self.predict(await request.body())
@app.on_event("startup")  # Code to be run when the server starts.
async def startup_event():
    """Connect to Ray, start Serve, and deploy GPT2 when the web server boots.

    Order matters: the Ray cluster connection must exist before Serve starts,
    and Serve must be running before the deployment can be created.
    """
    ray.init(address="auto")  # Connect to the running Ray cluster.
    serve.start(http_host=None)  # Start the Ray Serve instance.
    # Deploy our GPT2 Deployment.
    GPT2.deploy()
@app.get("/generate")
async def generate(query: str):
    """Forward *query* to the GPT2 deployment and return its generated text."""
    # A deployment handle lets us call the replicas directly from Python
    # rather than going through Serve's HTTP layer.
    deployment_handle = GPT2.get_handle()
    generation = await deployment_handle.predict.remote(query)
    return generation
@app.on_event("shutdown")  # Code to be run when the server shuts down.
async def shutdown_event():
    """Tear down the Serve instance when the FastAPI app exits."""
    serve.shutdown()  # Shut down Ray Serve.