from fastapi import FastAPI, Request
from transformers import pipeline

import ray
from ray import serve
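
# How to run this example (a sketch; "main" assumes this file is saved as
# main.py -- substitute your own module name):
#
#   ray start --head   # start a local Ray cluster for ray.init(address="auto")
#   uvicorn main:app   # launch the FastAPI app defined below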

app = FastAPI()


# Define our deployment.
@serve.deployment(num_replicas=2)
class GPT2:
    def __init__(self):
        # Each replica loads its own copy of the GPT-2 text-generation pipeline.
        self.nlp_model = pipeline("text-generation", model="gpt2")

    async def predict(self, query: str):
        return self.nlp_model(query, max_length=50)

    async def __call__(self, request: Request):
        # Decode the raw request body to text, and await predict() so the
        # prediction (not a coroutine object) is returned.
        return await self.predict((await request.body()).decode())
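
# Note: serve.start() below is called with http_host=None, so Serve does not
# start its own HTTP proxy and __call__ is never reached over HTTP here; all
# traffic flows through the FastAPI routes and the deployment handle. With
# num_replicas=2, handle calls are balanced across the two replicas.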

@app.on_event("startup")  # Code to be run when the server starts.
async def startup_event():
    ray.init(address="auto")  # Connect to the running Ray cluster.
    serve.start(http_host=None)  # Start the Ray Serve instance.

    # Deploy our GPT2 Deployment.
    GPT2.deploy()
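
# Calling deploy() again (e.g. after editing the class or its config) should
# update the running deployment in place; options can also be overridden
# first, as a sketch:
#
#   GPT2.options(num_replicas=4).deploy()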

@app.get("/generate")
async def generate(query: str):
    # Get a handle to our deployment so we can query it in Python.
    handle = GPT2.get_handle()
    return await handle.predict.remote(query)
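
# Example query against the running server (a sketch; 127.0.0.1:8000 assumes
# uvicorn's default host and port):
#
#   curl "http://127.0.0.1:8000/generate?query=Ray+Serve+is"
#
# The text-generation pipeline returns a list like [{"generated_text": "..."}],
# which FastAPI serializes to JSON.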

@app.on_event("shutdown")  # Code to be run when the server shuts down.
async def shutdown_event():
    serve.shutdown()  # Shut down Ray Serve.