mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00

Enable checking of the ray core module, excluding serve, workflows, and tune, in ./ci/lint/check_api_annotations.py. This required moving many files to ray._private and associated fixes.
37 lines
862 B
Python
37 lines
862 B
Python
import ray
|
|
from ray import serve
|
|
from ray.serve.deployment_graph import InputNode
|
|
|
|
ray.init()
|
|
|
|
|
|
@serve.deployment
|
|
class Model:
|
|
def __init__(self, weight):
|
|
self.weight = weight
|
|
|
|
def forward(self, input):
|
|
return input + self.weight
|
|
|
|
|
|
@serve.deployment
|
|
def combine(output_1, output_2, kwargs_output=0):
|
|
return output_1 + output_2 + kwargs_output
|
|
|
|
|
|
with InputNode() as user_input:
|
|
m1 = Model.bind(1)
|
|
m2 = Model.bind(2)
|
|
m1_output = m1.forward.bind(user_input[0])
|
|
m2_output = m2.forward.bind(user_input[1])
|
|
dag = combine.bind(m1_output, m2_output, kwargs_output=user_input[2])
|
|
|
|
# Partial DAG visualization
|
|
graph = ray.dag.vis_utils._dag_to_dot(m1_output)
|
|
to_string = graph.to_string()
|
|
print(to_string)
|
|
|
|
# Entire DAG visualization
|
|
graph = ray.dag.vis_utils._dag_to_dot(dag)
|
|
to_string = graph.to_string()
|
|
print(to_string)
|