mirror of
https://github.com/vale981/ray
synced 2025-03-08 19:41:38 -05:00
75 lines
1.9 KiB
Python
75 lines
1.9 KiB
Python
# flake8: noqa
|
|
|
|
# __start_graph__
|
|
# File name: graph.py
|
|
import ray
|
|
from ray import serve
|
|
|
|
from transformers import pipeline
|
|
|
|
|
|
@serve.deployment
class Translator:
    """Serve deployment wrapping an English-to-French translation pipeline."""

    def __init__(self):
        # Build the "t5-small" translation pipeline when the instance is created.
        self.model = pipeline("translation_en_to_fr", model="t5-small")

    def translate(self, text: str) -> str:
        """Run the translation pipeline on ``text`` and return the translated string."""
        # Only the "translation_text" field of the first result is returned.
        first_result = self.model(text)[0]
        return first_result["translation_text"]
|
|
|
|
|
|
@serve.deployment
class Summarizer:
    """Serve deployment that summarizes text and forwards the summary for translation.

    Args:
        translator: Object exposing ``translate.remote(text)``; the value that
            call returns is awaited in ``__call__`` to obtain the translation.
    """

    def __init__(self, translator):
        # Load model
        self.model = pipeline("summarization", model="t5-small")
        self.translator = translator

    def summarize(self, text: str) -> str:
        """Run the summarization pipeline on ``text`` and return the summary string.

        The pipeline is called with ``min_length=5`` and ``max_length=15``.
        """
        # Run inference
        model_output = self.model(text, min_length=5, max_length=15)

        # Post-process output to return only the summary text
        summary = model_output[0]["summary_text"]

        return summary

    async def __call__(self, http_request) -> str:
        """Summarize the JSON body of ``http_request`` and return its translation."""
        english_text: str = await http_request.json()
        summary = self.summarize(english_text)

        translation_ref = self.translator.translate.remote(summary)
        # Await the reference rather than calling ray.get(): ray.get() is a
        # blocking call and would stall the event loop running this coroutine
        # until the translation finishes.
        translation = await translation_ref

        return translation
|
|
|
|
|
|
# Bind the translator first, then pass the bound node to the summarizer.
translator_node = Translator.bind()
deployment_graph = Summarizer.bind(translator_node)
# __end_graph__

serve.run(deployment_graph)
|
|
|
|
# __start_client__
|
|
# File name: graph_client.py
|
|
import requests
|
|
|
|
# Adjacent string literals are joined into a single string.
english_text = (
    "It was the best of times, it was the worst of times, it was the age "
    "of wisdom, it was the age of foolishness, it was the epoch of belief"
)

# POST the text as the JSON body and read the reply as plain text.
serve_url = "http://127.0.0.1:8000/"
response = requests.post(serve_url, json=english_text)
french_text = response.text

print(french_text)
# __end_client__

# Check the reply against the expected translation of the summary.
expected_translation = "c'était le meilleur des temps, c'était le pire des temps ."
assert french_text == expected_translation

# Tear down Serve, then Ray.
serve.shutdown()
ray.shutdown()
|