mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
56 lines
1.1 KiB
Python
56 lines
1.1 KiB
Python
# flake8: noqa
|
|
|
|
# __import_start__
|
|
import ray
|
|
from ray import serve
|
|
|
|
# __import_end__
|
|
|
|
# __model_start__
|
|
from transformers import pipeline
|
|
|
|
|
|
@serve.deployment
class Translator:
    """Serve deployment that translates English text to French.

    Wraps a ``translation_en_to_fr`` pipeline backed by the ``t5-small``
    model and exposes it through ``__call__``.
    """

    def __init__(self):
        # Build the pipeline once per instance; every call reuses it.
        self.model = pipeline("translation_en_to_fr", model="t5-small")

    def translate(self, text: str) -> str:
        """Return the translation of *text*.

        Runs the pipeline on *text* and keeps only the
        ``"translation_text"`` field of the first result.
        """
        return self.model(text)[0]["translation_text"]

    async def __call__(self, http_request) -> str:
        """Translate the JSON body of *http_request*.

        The decoded body is passed to ``translate`` unchanged; it is
        presumably a plain JSON string -- confirm against the client.
        """
        payload: str = await http_request.json()
        return self.translate(payload)
|
|
|
|
|
|
# __model_end__
|
|
|
|
# __model_deploy_start__
# Bind the Translator deployment; it takes no constructor arguments.
translator = Translator.bind()
# __model_deploy_end__

# Hand the bound deployment to Serve so it starts running.
serve.run(translator)
|
|
|
|
# __client_function_start__
# File name: model_client.py
import requests

english_text = "Hello world!"

# POST the sentence as a JSON body to the deployment at 127.0.0.1:8000.
response = requests.post(url="http://127.0.0.1:8000/", json=english_text)
french_text = response.text

print(french_text)
# __client_function_end__

# Check the translation returned for the sample sentence.
assert french_text == "Bonjour monde!"

# Tear down Serve first, then Ray itself.
serve.shutdown()
ray.shutdown()
|