mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
34 lines
909 B
Python
34 lines
909 B
Python
# fmt: off
|
|
# __doc_import_begin__
|
|
from typing import List
|
|
|
|
from starlette.requests import Request
|
|
from transformers import pipeline, Pipeline
|
|
|
|
from ray import serve
|
|
# __doc_import_end__
|
|
# fmt: on
|
|
|
|
|
|
# __doc_define_servable_begin__
|
|
@serve.deployment
class BatchTextGenerator:
    """Serve deployment that generates text for HTTP requests in batches.

    Each request supplies a single prompt through the ``text`` query
    parameter. Prompts are routed through ``handle_batch``, which is wrapped
    with ``serve.batch`` so that the model is invoked on a list of prompts
    rather than once per request.
    """

    def __init__(self, model: Pipeline):
        """Store the pipeline used for generation.

        Args:
            model: Callable pipeline; it is invoked with a list of prompts.
        """
        self.model = model

    @serve.batch(max_batch_size=4)
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        """Run the model once over a list of prompts.

        Args:
            inputs: Prompts gathered by the ``serve.batch`` wrapper (the
                decorator is configured with ``max_batch_size=4``).

        Returns:
            One generated string per entry in the model's output.
        """
        print("Our input array has length:", len(inputs))

        results = self.model(inputs)
        # Each per-prompt result is indexable; keep the text of its first
        # entry only.
        return [result[0]["generated_text"] for result in results]

    async def __call__(self, request: Request) -> str:
        """Handle one HTTP request and return the text generated for it.

        Args:
            request: Incoming request; its ``text`` query parameter is the
                prompt.

        Returns:
            The generated text for this request's prompt. Although
            ``handle_batch`` is written in terms of lists, the batching
            wrapper is called here with a single prompt and hands back only
            the matching element, hence ``str`` rather than ``List[str]``.

        Raises:
            KeyError: If the request has no ``text`` query parameter.
        """
        return await self.handle_batch(request.query_params["text"])
|
|
# __doc_define_servable_end__
|
|
|
|
|
|
# __doc_deploy_begin__
|
|
# Build the GPT-2 text-generation pipeline once, when this module is loaded.
model = pipeline(task="text-generation", model="gpt2")
# Bind the pipeline as the constructor argument of the deployment.
generator = BatchTextGenerator.bind(model)
|
|
# __doc_deploy_end__
|