mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
27 lines
599 B
Python
27 lines
599 B
Python
import ray
|
|
from ray import serve
|
|
from ray.serve.dag import InputNode
|
|
from ray.serve.drivers import DAGDriver
|
|
|
|
|
|
@serve.deployment
def preprocess(inp: int) -> int:
    """Bump the incoming value by one before it reaches the model.

    In this file the function is bound to the graph's ``InputNode``, so it
    receives the raw value passed to the driver handle.

    Args:
        inp: Value to preprocess.

    Returns:
        ``inp + 1``.
    """
    shifted = inp + 1
    return shifted
|
|
|
|
|
|
@serve.deployment
class Model:
    """Toy model deployment that adds a fixed offset to its input.

    Args:
        increment: Amount added to every value passed to ``predict``.
    """

    def __init__(self, increment: int):
        # Stored once at construction; read on every ``predict`` call.
        self.increment = increment

    def predict(self, inp: int) -> int:
        """Return ``inp`` shifted by the configured ``increment``.

        Args:
            inp: Value to score (in this file, the output of ``preprocess``).

        Returns:
            ``inp + self.increment``.
        """
        offset = self.increment
        return inp + offset
|
|
|
|
|
|
# Build the deployment graph: the request value flows through ``preprocess``
# and then into ``Model.predict``. Nodes are bound inside the ``InputNode``
# context so that ``inp`` stands in for the value supplied when the graph is
# called.
with InputNode() as inp:
    # Model deployment configured to add 2 to whatever it receives.
    model = Model.bind(increment=2)
    # Chain the nodes: the result of preprocess(inp) feeds model.predict.
    output = model.predict.bind(preprocess.bind(inp))
    # Wrap the graph's output node in a DAGDriver; the driver is what gets run
    # and exposes ``predict`` on the returned handle (used below).
    # NOTE(review): ``serve_dag`` may be referenced by import path elsewhere
    # (e.g. a Serve config) — confirm before renaming.
    serve_dag = DAGDriver.bind(output)

# Run the graph on Serve and obtain a handle to the driver.
handle = serve.run(serve_dag)
# For input 1: preprocess -> 1 + 1 = 2, then Model.predict -> 2 + 2 = 4.
assert ray.get(handle.predict.remote(1)) == 4
|