ray/doc/source/serve/doc_code/tutorial_batch.py

# fmt: off
# __doc_import_begin__
from typing import List
from starlette.requests import Request
from transformers import pipeline, Pipeline
from ray import serve
# __doc_import_end__
# fmt: on

# __doc_define_servable_begin__
@serve.deployment
class BatchTextGenerator:
    def __init__(self, model: Pipeline):
        self.model = model

    # Collect up to 4 concurrent requests into a single list before invoking
    # the pipeline, so the model runs on a batch rather than one input at a time.
    @serve.batch(max_batch_size=4)
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        print("Our input array has length:", len(inputs))

        results = self.model(inputs)
        return [result[0]["generated_text"] for result in results]

    async def __call__(self, request: Request) -> List[str]:
        # Each HTTP request contributes one string; @serve.batch groups
        # concurrent calls and returns this caller's corresponding result.
        return await self.handle_batch(request.query_params["text"])
# __doc_define_servable_end__

# __doc_deploy_begin__
# Load the Hugging Face GPT-2 text-generation pipeline and bind it as the
# deployment's constructor argument to build the Serve application.
model = pipeline("text-generation", "gpt2")
generator = BatchTextGenerator.bind(model)
# __doc_deploy_end__
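

# The block below is not part of the original snippet; it is a minimal sketch
# of one way to try the deployment locally, assuming Ray Serve and the
# `requests` package are installed. Typically you would instead start the app
# with the Serve CLI (e.g. `serve run tutorial_batch:generator`) and query it
# from a separate client process.
if __name__ == "__main__":
    import requests

    serve.run(generator)  # deploy the bound application; HTTP is served on port 8000

    # A single request is just a batch of size 1; concurrent requests would be
    # grouped into batches of up to 4 by @serve.batch before reaching the model.
    resp = requests.get(
        "http://localhost:8000/", params={"text": "Once upon a time"}
    )
    print(resp.text)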