Mirror of https://github.com/vale981/ray (synced 2025-03-05 18:11:42 -05:00)

[Serve][docs] Add type annotations to code samples (#27795)

parent 192d92bb77
commit fa37ddc584

20 changed files with 63 additions and 202 deletions
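The change is mechanical across the touched code samples: parameters and return values gain explicit type annotations, and samples that handle HTTP traffic import starlette.requests.Request to annotate their entry points. As an illustrative sketch only (the class and names are taken from the hunks below, not from any single file), the before/after pattern looks like this:

    from ray import serve

    # Before: untyped sample
    @serve.deployment
    class Model:
        def __init__(self, weight):
            self.weight = weight

        def forward(self, input):
            return input + self.weight

    # After: the same sample with annotations added
    @serve.deployment
    class Model:
        def __init__(self, weight: int):
            self.weight = weight

        def forward(self, input: int) -> int:
            return input + self.weight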
@@ -8,10 +8,10 @@ from ray.serve.deployment_graph import InputNode
 @serve.deployment
 class Model:
-    def __init__(self, weight):
+    def __init__(self, weight: int):
         self.weight = weight

-    def forward(self, input):
+    def forward(self, input: int) -> int:
         return input + self.weight


@@ -8,15 +8,15 @@ from ray.dag.input_node import InputNode
 @serve.deployment
 class Model:
-    def __init__(self, weight):
+    def __init__(self, weight: int):
         self.weight = weight

-    def forward(self, input):
+    def forward(self, input: int) -> int:
         return input + self.weight


 @serve.deployment
-def combine(value1, value2, operation):
+def combine(value1: int, value2: int, operation: str) -> int:
     if operation == "sum":
         return sum([value1, value2])
     else:
@@ -1,18 +0,0 @@
-from ray import serve
-
-import logging
-
-logger = logging.getLogger("ray.serve")
-
-
-@serve.deployment
-class Counter:
-    def __init__(self):
-        self.count = 0
-
-    async def __call__(self, request):
-        self.count += 1
-        return self.count
-
-
-counter = Counter.bind()
@@ -1,98 +0,0 @@
-# flake8: noqa
-# fmt: off
-#
-# __serve_example_begin__
-#
-# This brief example shows how to create, deploy, and expose access to
-# deployment models, using the simple Ray Serve deployment APIs.
-# Once deployed, you can access deployment via two methods:
-# ServerHandle API and HTTP
-#
-import os
-from random import random
-
-import requests
-import starlette.requests
-from ray import serve
-
-#
-# A simple example model stored in a pickled format at an accessible path
-# that can be reloaded and deserialized into a model instance. Once deployed
-# in Ray Serve, we can use it for prediction. The prediction is a fake condition,
-# based on threshold of weight greater than 0.5.
-#
-
-
-class Model:
-    def __init__(self, path):
-        self.path = path
-
-    def predict(self, data: float) -> float:
-        return random() + data if data > 0.5 else data
-
-
-@serve.deployment
-class Predictor:
-    # Take in a path to load your desired model
-    def __init__(self, path: str) -> None:
-        self.path = path
-        self.model = Model(path)
-        # Get the pid on which this deployment is running on
-        self.pid = os.getpid()
-
-    # Deployments are callable. Here we simply return a prediction from
-    # our request.
-    async def predict(self, data: float) -> str:
-        pred = self.model.predict(data)
-        return (f"(pid: {self.pid}); path: {self.path}; "
-                f"data: {float(data):.3f}; prediction: {pred:.3f}")
-
-    async def __call__(self, http_request: starlette.requests.Request) -> str:
-        data = float(await http_request.query_params['data'])
-        return await self.predict(data)
-
-
-@serve.deployment
-class ServeHandleDemo:
-    def __init__(self, predictor_1: Predictor, predictor_2: Predictor):
-        self.predictor_1 = predictor_1
-        self.predictor_2 = predictor_2
-
-    async def run(self):
-        # Query each deployment twice to demonstrate that the requests
-        # get forwarded to different replicas (below, we will set
-        # num_replicas to 2 for each deployment).
-        for _ in range(2):
-            for predictor in [self.predictor_1, self.predictor_2]:
-                # Call our deployments from Python using the ServeHandle API.
-                random_prediction = await predictor.predict.remote(random())
-                print(f"prediction: {random_prediction}")
-
-    async def __call__(self, http_request: starlette.requests.Request) -> str:
-        return await self.run()
-
-
-predictor_1 = Predictor.options(num_replicas=2).bind("/model/model-1.pkl")
-predictor_2 = Predictor.options(num_replicas=2).bind("/model/model-2.pkl")
-
-# Pass in our deployments as arguments. At runtime, these are resolved to ServeHandles.
-serve_handle_demo = ServeHandleDemo.bind(predictor_1, predictor_2)
-
-# Start a local single-node Ray cluster and start Ray Serve. These will shut down upon
-# exiting this script.
-serve.run(serve_handle_demo)
-
-print("ServeHandle API responses: " + "--" * 5)
-
-url = "http://127.0.0.1:8000/"
-response = requests.get(url)
-prediction = response.text
-print(f"prediction : {prediction}")
-
-# Output ("INFO" logs omitted for brevity):
-
-# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16059); path: /model/model-1.pkl; data: 0.166; prediction: 0.166
-# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16061); path: /model/model-2.pkl; data: 0.820; prediction: 0.986
-# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16058); path: /model/model-1.pkl; data: 0.691; prediction: 0.857
-# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16060); path: /model/model-2.pkl; data: 0.948; prediction: 1.113
-# __serve_example_end__
@@ -1,12 +1,14 @@
 import subprocess

 # __deploy_in_single_file_1_start__
+from starlette.requests import Request
+
 import ray
 from ray import serve


 @serve.deployment
-def my_func(request):
+def my_func(request: Request) -> str:
     return "hello"


@@ -24,7 +26,7 @@ ray.init(address="auto", namespace="serve")

 @serve.deployment
-def my_func(request):
+def my_func(request: Request) -> str:
     return "hello"

@@ -1,50 +0,0 @@
-import requests
-from pydantic import BaseModel
-
-import ray
-from ray import serve
-from ray.serve.drivers import DAGDriver
-from ray.dag.input_node import InputNode
-
-
-class ModelInputData(BaseModel):
-    model_input1: int
-    model_input2: str
-
-
-@serve.deployment
-class Model:
-    def __init__(self, weight):
-        self.weight = weight
-
-    def forward(self, input: ModelInputData):
-        return input.model_input1 + len(input.model_input2) + self.weight
-
-
-@serve.deployment
-def combine(value_refs):
-    return sum(ray.get(value_refs))
-
-
-with InputNode() as user_input:
-    model1 = Model.bind(0)
-    model2 = Model.bind(1)
-    output1 = model1.forward.bind(user_input)
-    output2 = model2.forward.bind(user_input)
-    dag = combine.bind([output1, output2])
-    serve_dag = DAGDriver.options(route_prefix="/my-dag").bind(
-        dag, http_adapter=ModelInputData
-    )
-
-dag_handle = serve.run(serve_dag)
-
-print(
-    ray.get(
-        dag_handle.predict.remote(ModelInputData(model_input1=1, model_input2="test"))
-    )
-)
-print(
-    requests.post(
-        "http://127.0.0.1:8000/my-dag", json={"model_input1": 1, "model_input2": "test"}
-    ).text
-)
@@ -1,6 +1,8 @@
 # flake8: noqa

 # __import_start__
+from starlette.requests import Request
+
 import ray
 from ray import serve

@@ -25,7 +27,7 @@ class Translator:

         return translation

-    async def __call__(self, http_request) -> str:
+    async def __call__(self, http_request: Request) -> str:
         english_text: str = await http_request.json()
         return self.translate(english_text)

@@ -2,8 +2,11 @@

 # __deployment_full_start__
 # File name: serve_deployment.py
+from starlette.requests import Request
+
 import ray
 from ray import serve
+
 from transformers import pipeline


@@ -22,7 +25,7 @@ class Translator:

         return translation

-    async def __call__(self, http_request) -> str:
+    async def __call__(self, http_request: Request) -> str:
         english_text: str = await http_request.json()
         return self.translate(english_text)

@@ -2,6 +2,8 @@

 # __start_graph__
 # File name: graph.py
+from starlette.requests import Request
+
 import ray
 from ray import serve

@@ -40,7 +42,7 @@ class Summarizer:

         return summary

-    async def __call__(self, http_request) -> str:
+    async def __call__(self, http_request: Request) -> str:
         english_text: str = await http_request.json()
         summary = self.summarize(english_text)

@@ -1,6 +1,8 @@
 # flake8: noqa

 # __begin_sync_handle__
+from starlette.requests import Request
+
 import ray
 from ray import serve
 from ray.serve.handle import RayServeSyncHandle

@@ -8,7 +10,7 @@ from ray.serve.handle import RayServeSyncHandle

 @serve.deployment
 class Model:
-    def __call__(self):
+    def __call__(self) -> str:
         return "hello"


@@ -27,7 +29,7 @@ from ray.serve.handle import RayServeDeploymentHandle, RayServeSyncHandle

 @serve.deployment
 class Model:
-    def __call__(self):
+    def __call__(self) -> str:
         return "hello"


@@ -104,10 +106,10 @@ from ray.serve.handle import RayServeSyncHandle

 @serve.deployment
 class Deployment:
-    def method1(self, arg):
+    def method1(self, arg: str) -> str:
         return f"Method1: {arg}"

-    def __call__(self, arg):
+    def __call__(self, arg: str) -> str:
         return f"__call__: {arg}"

@@ -5,7 +5,7 @@ from ray.serve.drivers import DAGDriver


 @serve.deployment
-def preprocess(inp: int):
+def preprocess(inp: int) -> int:
     return inp + 1


@@ -14,7 +14,7 @@ class Model:
     def __init__(self, increment: int):
         self.increment = increment

-    def predict(self, inp: int):
+    def predict(self, inp: int) -> int:
         return inp + self.increment

@@ -6,7 +6,7 @@ from ray.serve.handle import RayServeDeploymentHandle
 from ray.serve.handle import RayServeSyncHandle

 import requests
-import starlette
+from starlette.requests import Request

 serve.start()

@@ -16,7 +16,7 @@ serve.start()

 @serve.deployment
 class Model:
-    def forward(self, input):
+    def forward(self, input) -> str:
         # do some inference work
         return "done"

@@ -92,7 +92,7 @@ serve.start()
 # __customized_route_old_api_start__
 @serve.deployment(route_prefix="/my_model1")
 class Model:
-    def __call__(self, req: starlette.requests.Request):
+    def __call__(self, req: Request) -> str:
         # some inference work
         return "done"

@@ -157,7 +157,7 @@ serve.shutdown()
 # __customized_route_old_api_1_start__
 @serve.deployment
 class Model:
-    def __call__(self, req: starlette.requests.Request):
+    def __call__(self, req: Request) -> str:
         # some inference work
         return "done"

@@ -173,14 +173,14 @@ serve.shutdown()
 # __customized_route_old_api_2_start__
 @serve.deployment
 class Model:
-    def __call__(self, req: starlette.requests.Request):
+    def __call__(self, req: Request) -> str:
         # some inference work
         return "done"


 @serve.deployment
 class Model2:
-    def __call__(self, req: starlette.requests.Request):
+    def __call__(self, req: Request) -> str:
         # some inference work
         return "done"

@@ -197,7 +197,7 @@ serve.shutdown()
 # __graph_with_new_api_start__
 @serve.deployment
 class Model:
-    def forward(self, input):
+    def forward(self, input) -> str:
         # do some inference work
         return "done"

@@ -1,5 +1,7 @@
 from ray import serve
-from typing import List, Dict, Any
+from typing import List, Dict
+
+from starlette.requests import Request


 # __batch_example_start__

@@ -9,14 +11,14 @@ class BatchingExample:
         self.count = 0

     @serve.batch
-    async def handle_batch(self, requests: List[Any]) -> List[Dict]:
+    async def handle_batch(self, requests: List[Request]) -> List[Dict]:
         responses = []
         for request in requests:
             responses.append(request.json())

         return responses

-    async def __call__(self, request) -> List[Dict]:
+    async def __call__(self, request: Request) -> List[Dict]:
         return await self.handle_batch(request)

@@ -6,6 +6,8 @@ from ray import serve
 from ray.serve.drivers import DAGDriver
 from ray.serve.deployment_graph import InputNode

+from starlette.requests import Request
+

 @serve.deployment
 class AddCls:

@@ -15,7 +17,7 @@ class AddCls:
     def add(self, number: float) -> float:
         return number + self.addend

-    async def unpack_request(self, http_request) -> float:
+    async def unpack_request(self, http_request: Request) -> float:
         return await http_request.json()


@@ -25,7 +27,7 @@ def subtract_one_fn(number: float) -> float:


 @serve.deployment
-async def unpack_request(http_request) -> float:
+async def unpack_request(http_request: Request) -> float:
     return await http_request.json()

@@ -4,6 +4,8 @@ import requests

 # __echo_class_start__
 # File name: echo.py
+from starlette.requests import Request
+
 from ray import serve


@@ -12,7 +14,7 @@ class EchoClass:
     def __init__(self, echo_str: str):
         self.echo_str = echo_str

-    def __call__(self, request) -> str:
+    def __call__(self, request: Request) -> str:
         return self.echo_str

@@ -7,15 +7,15 @@ from ray.dag.vis_utils import _dag_to_dot

 @serve.deployment
 class Model:
-    def __init__(self, weight):
+    def __init__(self, weight: int):
         self.weight = weight

-    def forward(self, input):
+    def forward(self, input: int) -> int:
         return input + self.weight


 @serve.deployment
-def combine(output_1, output_2, kwargs_output=0):
+def combine(output_1: int, output_2: int, kwargs_output: int = 0) -> int:
     return output_1 + output_2 + kwargs_output

@@ -5,13 +5,14 @@

 from ray import serve
 import logging
+from starlette.requests import Request

 logger = logging.getLogger("ray.serve")


 @serve.deployment
 class SayHello:
-    async def __call__(self, request):
+    async def __call__(self, request: Request) -> str:
         logger.info("Hello world!")
         return "hi"

@@ -1,4 +1,7 @@
 import requests
+from starlette.requests import Request
+from typing import Dict
+
 from ray import serve


@@ -9,7 +12,7 @@ class MyModelDeployment:
         # Initialize model state: could be very large neural net weights.
         self._msg = msg

-    def __call__(self, request):
+    def __call__(self, request: Request) -> Dict:
         return {"result": self._msg}

@@ -3,6 +3,8 @@

 # __serve_example_begin__
 import requests
+from starlette.requests import Request
+from typing import Dict

 from sklearn.datasets import load_iris
 from sklearn.ensemble import GradientBoostingClassifier

@@ -22,7 +24,7 @@ class BoostingModel:
         self.model = model
         self.label_list = iris_dataset["target_names"].tolist()

-    async def __call__(self, request):
+    async def __call__(self, request: Request) -> Dict:
         payload = (await request.json())["vector"]
         print(f"Received http request with data {payload}")

@@ -1,5 +1,9 @@
 import requests
+from starlette.requests import Request
+from typing import Dict
+
 from transformers import pipeline
+
 from ray import serve


@@ -9,7 +13,7 @@ class SentimentAnalysisDeployment:
     def __init__(self):
         self._model = pipeline("sentiment-analysis")

-    def __call__(self, request):
+    def __call__(self, request: Request) -> Dict:
         return self._model(request.query_params["text"])[0]

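For orientation only (this usage is not part of the commit), a minimal sketch of how a sample annotated this way is typically exercised, assuming the SentimentAnalysisDeployment class from the final hunk is decorated with @serve.deployment and importable in the same script; the URL and query text are illustrative:

    import requests
    from ray import serve

    # Deploy the annotated sample locally and hit its typed __call__ over HTTP.
    # Assumes SentimentAnalysisDeployment (from the sample above) is in scope.
    serve.run(SentimentAnalysisDeployment.bind())
    response = requests.get("http://127.0.0.1:8000/", params={"text": "Ray Serve is great"})
    print(response.json())  # the deployment returns a Dict, serialized as JSON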