[Serve][docs] Add type annotations to code samples (#27795)

zcin 2022-08-12 13:41:08 -07:00 committed by GitHub
parent 192d92bb77
commit fa37ddc584
20 changed files with 63 additions and 202 deletions

View file

@ -8,10 +8,10 @@ from ray.serve.deployment_graph import InputNode
@serve.deployment
class Model:
def __init__(self, weight):
def __init__(self, weight: int):
self.weight = weight
def forward(self, input):
def forward(self, input: int) -> int:
return input + self.weight

View file

@ -8,15 +8,15 @@ from ray.dag.input_node import InputNode
@serve.deployment
class Model:
def __init__(self, weight):
def __init__(self, weight: int):
self.weight = weight
def forward(self, input):
def forward(self, input: int) -> int:
return input + self.weight
@serve.deployment
def combine(value1, value2, operation):
def combine(value1: int, value2: int, operation: str) -> int:
if operation == "sum":
return sum([value1, value2])
else:
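
To see the annotated signatures in context, here is a minimal sketch (assumed wiring, not part of this diff) that composes the typed Model and combine deployments using the InputNode/DAGDriver pattern found elsewhere in these samples:

import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.dag.input_node import InputNode

with InputNode() as user_input:
    model1 = Model.bind(0)
    model2 = Model.bind(1)
    output1 = model1.forward.bind(user_input)    # int -> int
    output2 = model2.forward.bind(user_input)
    dag = combine.bind(output1, output2, "sum")  # "sum" selects the summing branch

handle = serve.run(DAGDriver.bind(dag))
# With input 1: (1 + 0) and (1 + 1), summed by combine -> 3
print(ray.get(handle.predict.remote(1)))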

View file

@ -1,18 +0,0 @@
from ray import serve
import logging
logger = logging.getLogger("ray.serve")
@serve.deployment
class Counter:
def __init__(self):
self.count = 0
async def __call__(self, request):
self.count += 1
return self.count
counter = Counter.bind()
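
For reference, a sketch (not part of the diff) of the removed Counter sample with the same annotation convention this commit applies elsewhere:

from starlette.requests import Request
from ray import serve

@serve.deployment
class Counter:
    def __init__(self):
        self.count = 0

    async def __call__(self, request: Request) -> int:
        self.count += 1
        return self.count

counter = Counter.bind()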

View file

@ -1,98 +0,0 @@
# flake8: noqa
# fmt: off
#
# __serve_example_begin__
#
# This brief example shows how to create, deploy, and expose access to
# deployment models, using the simple Ray Serve deployment APIs.
# Once deployed, you can access the deployments via two methods:
# the ServeHandle API and HTTP.
#
import os
from random import random
import requests
import starlette.requests
from ray import serve
#
# A simple example model stored in a pickled format at an accessible path
# that can be reloaded and deserialized into a model instance. Once deployed
# in Ray Serve, we can use it for prediction. The prediction is faked: if the
# input data is greater than 0.5, a random offset is added to it.
#
class Model:
def __init__(self, path):
self.path = path
def predict(self, data: float) -> float:
return random() + data if data > 0.5 else data
@serve.deployment
class Predictor:
# Take in a path to load your desired model
def __init__(self, path: str) -> None:
self.path = path
self.model = Model(path)
# Get the pid of the process this deployment replica is running in
self.pid = os.getpid()
# Deployments are callable. Here we simply return a prediction from
# our request.
async def predict(self, data: float) -> str:
pred = self.model.predict(data)
return (f"(pid: {self.pid}); path: {self.path}; "
f"data: {float(data):.3f}; prediction: {pred:.3f}")
async def __call__(self, http_request: starlette.requests.Request) -> str:
data = float(http_request.query_params['data'])
return await self.predict(data)
@serve.deployment
class ServeHandleDemo:
def __init__(self, predictor_1: Predictor, predictor_2: Predictor):
self.predictor_1 = predictor_1
self.predictor_2 = predictor_2
async def run(self):
# Query each deployment twice to demonstrate that the requests
# get forwarded to different replicas (below, we will set
# num_replicas to 2 for each deployment).
for _ in range(2):
for predictor in [self.predictor_1, self.predictor_2]:
# Call our deployments from Python using the ServeHandle API.
random_prediction = await predictor.predict.remote(random())
print(f"prediction: {random_prediction}")
async def __call__(self, http_request: starlette.requests.Request) -> str:
return await self.run()
predictor_1 = Predictor.options(num_replicas=2).bind("/model/model-1.pkl")
predictor_2 = Predictor.options(num_replicas=2).bind("/model/model-2.pkl")
# Pass in our deployments as arguments. At runtime, these are resolved to ServeHandles.
serve_handle_demo = ServeHandleDemo.bind(predictor_1, predictor_2)
# Start a local single-node Ray cluster and start Ray Serve. These will shut down upon
# exiting this script.
serve.run(serve_handle_demo)
print("ServeHandle API responses: " + "--" * 5)
url = "http://127.0.0.1:8000/"
response = requests.get(url)
prediction = response.text
print(f"prediction : {prediction}")
# Output ("INFO" logs omitted for brevity):
# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16059); path: /model/model-1.pkl; data: 0.166; prediction: 0.166
# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16061); path: /model/model-2.pkl; data: 0.820; prediction: 0.986
# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16058); path: /model/model-1.pkl; data: 0.691; prediction: 0.857
# (ServeReplica:ServeHandleDemo pid=16062) prediction: (pid: 16060); path: /model/model-2.pkl; data: 0.948; prediction: 1.113
# __serve_example_end__

View file

@ -1,12 +1,14 @@
import subprocess
# __deploy_in_single_file_1_start__
from starlette.requests import Request
import ray
from ray import serve
@serve.deployment
def my_func(request):
def my_func(request: Request) -> str:
return "hello"
@ -24,7 +26,7 @@ ray.init(address="auto", namespace="serve")
@serve.deployment
def my_func(request):
def my_func(request: Request) -> str:
return "hello"

View file

@ -1,50 +0,0 @@
import requests
from pydantic import BaseModel
import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.dag.input_node import InputNode
class ModelInputData(BaseModel):
model_input1: int
model_input2: str
@serve.deployment
class Model:
def __init__(self, weight):
self.weight = weight
def forward(self, input: ModelInputData):
return input.model_input1 + len(input.model_input2) + self.weight
@serve.deployment
def combine(value_refs):
return sum(ray.get(value_refs))
with InputNode() as user_input:
model1 = Model.bind(0)
model2 = Model.bind(1)
output1 = model1.forward.bind(user_input)
output2 = model2.forward.bind(user_input)
dag = combine.bind([output1, output2])
serve_dag = DAGDriver.options(route_prefix="/my-dag").bind(
dag, http_adapter=ModelInputData
)
dag_handle = serve.run(serve_dag)
print(
ray.get(
dag_handle.predict.remote(ModelInputData(model_input1=1, model_input2="test"))
)
)
print(
requests.post(
"http://127.0.0.1:8000/my-dag", json={"model_input1": 1, "model_input2": "test"}
).text
)

View file

@ -1,6 +1,8 @@
# flake8: noqa
# __import_start__
from starlette.requests import Request
import ray
from ray import serve
@ -25,7 +27,7 @@ class Translator:
return translation
async def __call__(self, http_request) -> str:
async def __call__(self, http_request: Request) -> str:
english_text: str = await http_request.json()
return self.translate(english_text)
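
For context, a sketch of the matching client call (the URL assumes the default route): the JSON body becomes english_text in the annotated __call__ above.

import requests

english_text = "Hello world!"
response = requests.post("http://127.0.0.1:8000/", json=english_text)
print(response.text)  # the translated text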

View file

@ -2,8 +2,11 @@
# __deployment_full_start__
# File name: serve_deployment.py
from starlette.requests import Request
import ray
from ray import serve
from transformers import pipeline
@ -22,7 +25,7 @@ class Translator:
return translation
async def __call__(self, http_request) -> str:
async def __call__(self, http_request: Request) -> str:
english_text: str = await http_request.json()
return self.translate(english_text)

View file

@ -2,6 +2,8 @@
# __start_graph__
# File name: graph.py
from starlette.requests import Request
import ray
from ray import serve
@ -40,7 +42,7 @@ class Summarizer:
return summary
async def __call__(self, http_request) -> str:
async def __call__(self, http_request: Request) -> str:
english_text: str = await http_request.json()
summary = self.summarize(english_text)

View file

@ -1,6 +1,8 @@
# flake8: noqa
# __begin_sync_handle__
from starlette.requests import Request
import ray
from ray import serve
from ray.serve.handle import RayServeSyncHandle
@ -8,7 +10,7 @@ from ray.serve.handle import RayServeSyncHandle
@serve.deployment
class Model:
def __call__(self):
def __call__(self) -> str:
return "hello"
@ -27,7 +29,7 @@ from ray.serve.handle import RayServeDeploymentHandle, RayServeSyncHandle
@serve.deployment
class Model:
def __call__(self):
def __call__(self) -> str:
return "hello"
@ -104,10 +106,10 @@ from ray.serve.handle import RayServeSyncHandle
@serve.deployment
class Deployment:
def method1(self, arg):
def method1(self, arg: str) -> str:
return f"Method1: {arg}"
def __call__(self, arg):
def __call__(self, arg: str) -> str:
return f"__call__: {arg}"

View file

@ -5,7 +5,7 @@ from ray.serve.drivers import DAGDriver
@serve.deployment
def preprocess(inp: int):
def preprocess(inp: int) -> int:
return inp + 1
@ -14,7 +14,7 @@ class Model:
def __init__(self, increment: int):
self.increment = increment
def predict(self, inp: int):
def predict(self, inp: int) -> int:
return inp + self.increment

View file

@ -6,7 +6,7 @@ from ray.serve.handle import RayServeDeploymentHandle
from ray.serve.handle import RayServeSyncHandle
import requests
import starlette
from starlette.requests import Request
serve.start()
@ -16,7 +16,7 @@ serve.start()
@serve.deployment
class Model:
def forward(self, input):
def forward(self, input) -> str:
# do some inference work
return "done"
@ -92,7 +92,7 @@ serve.start()
# __customized_route_old_api_start__
@serve.deployment(route_prefix="/my_model1")
class Model:
def __call__(self, req: starlette.requests.Request):
def __call__(self, req: Request) -> str:
# some inference work
return "done"
@ -157,7 +157,7 @@ serve.shutdown()
# __customized_route_old_api_1_start__
@serve.deployment
class Model:
def __call__(self, req: starlette.requests.Request):
def __call__(self, req: Request) -> str:
# some inference work
return "done"
@ -173,14 +173,14 @@ serve.shutdown()
# __customized_route_old_api_2_start__
@serve.deployment
class Model:
def __call__(self, req: starlette.requests.Request):
def __call__(self, req: Request) -> str:
# some inference work
return "done"
@serve.deployment
class Model2:
def __call__(self, req: starlette.requests.Request):
def __call__(self, req: Request) -> str:
# some inference work
return "done"
@ -197,7 +197,7 @@ serve.shutdown()
# __graph_with_new_api_start__
@serve.deployment
class Model:
def forward(self, input):
def forward(self, input) -> str:
# do some inference work
return "done"

View file

@ -1,5 +1,7 @@
from ray import serve
from typing import List, Dict, Any
from typing import List, Dict
from starlette.requests import Request
# __batch_example_start__
@ -9,14 +11,14 @@ class BatchingExample:
self.count = 0
@serve.batch
async def handle_batch(self, requests: List[Any]) -> List[Dict]:
async def handle_batch(self, requests: List[Request]) -> List[Dict]:
responses = []
for request in requests:
responses.append(await request.json())
return responses
async def __call__(self, request) -> List[Dict]:
async def __call__(self, request: Request) -> List[Dict]:
return await self.handle_batch(request)
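
A sketch of exercising the batched handler (the client code and URL are assumptions): concurrent HTTP requests that arrive within the batch window reach handle_batch together as one list.

import requests
from concurrent.futures import ThreadPoolExecutor
from ray import serve

serve.run(BatchingExample.bind())

def query(i: int) -> dict:
    # Each call becomes one element of the `requests` list in handle_batch.
    return requests.post("http://127.0.0.1:8000/", json={"id": i}).json()

with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(query, range(4))))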

View file

@ -6,6 +6,8 @@ from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.deployment_graph import InputNode
from starlette.requests import Request
@serve.deployment
class AddCls:
@ -15,7 +17,7 @@ class AddCls:
def add(self, number: float) -> float:
return number + self.addend
async def unpack_request(self, http_request) -> float:
async def unpack_request(self, http_request: Request) -> float:
return await http_request.json()
@ -25,7 +27,7 @@ def subtract_one_fn(number: float) -> float:
@serve.deployment
async def unpack_request(http_request) -> float:
async def unpack_request(http_request: Request) -> float:
return await http_request.json()

View file

@ -4,6 +4,8 @@ import requests
# __echo_class_start__
# File name: echo.py
from starlette.requests import Request
from ray import serve
@ -12,7 +14,7 @@ class EchoClass:
def __init__(self, echo_str: str):
self.echo_str = echo_str
def __call__(self, request) -> str:
def __call__(self, request: Request) -> str:
return self.echo_str

View file

@ -7,15 +7,15 @@ from ray.dag.vis_utils import _dag_to_dot
@serve.deployment
class Model:
def __init__(self, weight):
def __init__(self, weight: int):
self.weight = weight
def forward(self, input):
def forward(self, input: int) -> int:
return input + self.weight
@serve.deployment
def combine(output_1, output_2, kwargs_output=0):
def combine(output_1: int, output_2: int, kwargs_output: int = 0) -> int:
return output_1 + output_2 + kwargs_output

View file

@ -5,13 +5,14 @@
from ray import serve
import logging
from starlette.requests import Request
logger = logging.getLogger("ray.serve")
@serve.deployment
class SayHello:
async def __call__(self, request):
async def __call__(self, request: Request) -> str:
logger.info("Hello world!")
return "hi"

View file

@ -1,4 +1,7 @@
import requests
from starlette.requests import Request
from typing import Dict
from ray import serve
@ -9,7 +12,7 @@ class MyModelDeployment:
# Initialize model state: could be very large neural net weights.
self._msg = msg
def __call__(self, request):
def __call__(self, request: Request) -> Dict:
return {"result": self._msg}

View file

@ -3,6 +3,8 @@
# __serve_example_begin__
import requests
from starlette.requests import Request
from typing import Dict
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
@ -22,7 +24,7 @@ class BoostingModel:
self.model = model
self.label_list = iris_dataset["target_names"].tolist()
async def __call__(self, request):
async def __call__(self, request: Request) -> Dict:
payload = (await request.json())["vector"]
print(f"Received http request with data {payload}")

View file

@ -1,5 +1,9 @@
import requests
from starlette.requests import Request
from typing import Dict
from transformers import pipeline
from ray import serve
@ -9,7 +13,7 @@ class SentimentAnalysisDeployment:
def __init__(self):
self._model = pipeline("sentiment-analysis")
def __call__(self, request):
def __call__(self, request: Request) -> Dict:
return self._model(request.query_params["text"])[0]