mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[7/X][Pipeline] pipeline user facing build function (#22934)
This commit is contained in:
parent
34ffc7e5cf
commit
3546aabefd
23 changed files with 665 additions and 206 deletions
|
@ -643,7 +643,7 @@ class ActorClass:
|
|||
**cls_options,
|
||||
)
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
"""
|
||||
**Experimental**
|
||||
|
||||
|
@ -1041,7 +1041,7 @@ class ActorClass:
|
|||
|
||||
return actor_handle
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
"""
|
||||
**Experimental**
|
||||
|
||||
|
|
|
@ -93,7 +93,18 @@ class ClassNode(DAGNode):
|
|||
|
||||
def to_json(self, encoder_cls) -> Dict[str, Any]:
|
||||
json_dict = super().to_json_base(encoder_cls, ClassNode.__name__)
|
||||
json_dict["import_path"] = self.get_import_path()
|
||||
import_path = self.get_import_path()
|
||||
error_message = (
|
||||
"Class used in DAG should not be in-line defined when exporting"
|
||||
"import path for deployment. Please ensure it has fully "
|
||||
"qualified name with valid __module__ and __qualname__ for "
|
||||
"import path, with no __main__ or <locals>. \n"
|
||||
f"Current import path: {import_path}"
|
||||
)
|
||||
assert "__main__" not in import_path, error_message
|
||||
assert "<locals>" not in import_path, error_message
|
||||
|
||||
json_dict["import_path"] = import_path
|
||||
return json_dict
|
||||
|
||||
@classmethod
|
||||
|
@ -117,7 +128,7 @@ class _UnboundClassMethodNode(object):
|
|||
self._method_name = method_name
|
||||
self._options = {}
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
other_args_to_resolve = {
|
||||
PARENT_CLASS_NODE_KEY: self._actor,
|
||||
PREV_CLASS_METHOD_CALL_KEY: self._actor._last_call,
|
||||
|
|
|
@ -84,7 +84,7 @@ class DAGNode:
|
|||
|
||||
def execute(self, *args, **kwargs) -> Union[ray.ObjectRef, ray.actor.ActorHandle]:
|
||||
"""Execute this DAG using the Ray default executor."""
|
||||
return self._apply_recursive(lambda node: node._execute_impl(*args, **kwargs))
|
||||
return self.apply_recursive(lambda node: node._execute_impl(*args, **kwargs))
|
||||
|
||||
def _get_toplevel_child_nodes(self) -> Set["DAGNode"]:
|
||||
"""Return the set of nodes specified as top-level args.
|
||||
|
@ -135,7 +135,7 @@ class DAGNode:
|
|||
"""Apply and replace all immediate child nodes using a given function.
|
||||
|
||||
This is a shallow replacement only. To recursively transform nodes in
|
||||
the DAG, use ``_apply_recursive()``.
|
||||
the DAG, use ``apply_recursive()``.
|
||||
|
||||
Args:
|
||||
fn: Callable that will be applied once to each child of this node.
|
||||
|
@ -168,7 +168,7 @@ class DAGNode:
|
|||
new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
|
||||
)
|
||||
|
||||
def _apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
|
||||
def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
|
||||
"""Apply callable on each node in this DAG in a bottom-up tree walk.
|
||||
|
||||
Args:
|
||||
|
@ -203,11 +203,11 @@ class DAGNode:
|
|||
|
||||
return fn(
|
||||
self._apply_and_replace_all_child_nodes(
|
||||
lambda node: node._apply_recursive(fn)
|
||||
lambda node: node.apply_recursive(fn)
|
||||
)
|
||||
)
|
||||
|
||||
def _apply_functional(
|
||||
def apply_functional(
|
||||
self,
|
||||
source_input_list: Any,
|
||||
predictate_fn: Callable,
|
||||
|
|
|
@ -63,7 +63,17 @@ class FunctionNode(DAGNode):
|
|||
|
||||
def to_json(self, encoder_cls) -> Dict[str, Any]:
|
||||
json_dict = super().to_json_base(encoder_cls, FunctionNode.__name__)
|
||||
json_dict["import_path"] = self.get_import_path()
|
||||
import_path = self.get_import_path()
|
||||
error_message = (
|
||||
"Function used in DAG should not be in-line defined when exporting"
|
||||
"import path for deployment. Please ensure it has fully "
|
||||
"qualified name with valid __module__ and __qualname__ for "
|
||||
"import path, with no __main__ or <locals>. \n"
|
||||
f"Current import path: {import_path}"
|
||||
)
|
||||
assert "__main__" not in import_path, error_message
|
||||
assert "<locals>" not in import_path, error_message
|
||||
json_dict["import_path"] = import_path
|
||||
return json_dict
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -231,6 +231,22 @@ class InputAtrributeNode(DAGNode):
|
|||
def __str__(self) -> str:
|
||||
return get_dag_node_str(self, f'["{self._key}"]')
|
||||
|
||||
def to_json(self, encoder_cls) -> Dict[str, Any]:
|
||||
json_dict = super().to_json_base(encoder_cls, InputAtrributeNode.__name__)
|
||||
return json_dict
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, input_json, object_hook=None):
|
||||
assert input_json[DAGNODE_TYPE_KEY] == InputAtrributeNode.__name__
|
||||
args_dict = super().from_json_base(input_json, object_hook=object_hook)
|
||||
node = cls(
|
||||
args_dict["other_args_to_resolve"]["dag_input_node"],
|
||||
args_dict["other_args_to_resolve"]["key"],
|
||||
args_dict["other_args_to_resolve"]["accessor_method"],
|
||||
)
|
||||
node._stable_uuid = input_json["uuid"]
|
||||
return node
|
||||
|
||||
|
||||
class DAGInputData:
|
||||
"""If user passed multiple args and kwargs directly to dag.execute(), we
|
||||
|
|
|
@ -44,16 +44,16 @@ def test_basic_actor_dag(shared_ray_instance):
|
|||
def combine(x, y):
|
||||
return x + y
|
||||
|
||||
a1 = Actor._bind(10)
|
||||
res = a1.get._bind()
|
||||
a1 = Actor.bind(10)
|
||||
res = a1.get.bind()
|
||||
print(res)
|
||||
assert ray.get(res.execute()) == 10
|
||||
|
||||
a2 = Actor._bind(10)
|
||||
a1.inc._bind(2)
|
||||
a1.inc._bind(4)
|
||||
a2.inc._bind(6)
|
||||
dag = combine._bind(a1.get._bind(), a2.get._bind())
|
||||
a2 = Actor.bind(10)
|
||||
a1.inc.bind(2)
|
||||
a1.inc.bind(4)
|
||||
a2.inc.bind(6)
|
||||
dag = combine.bind(a1.get.bind(), a2.get.bind())
|
||||
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 32
|
||||
|
@ -71,9 +71,9 @@ def test_class_as_class_constructor_arg(shared_ray_instance):
|
|||
def get(self):
|
||||
return ray.get(self.inner_actor.get.remote())
|
||||
|
||||
outer = OuterActor._bind(Actor._bind(10))
|
||||
outer.inc._bind(2)
|
||||
dag = outer.get._bind()
|
||||
outer = OuterActor.bind(Actor.bind(10))
|
||||
outer.inc.bind(2)
|
||||
dag = outer.get.bind()
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 12
|
||||
|
||||
|
@ -83,19 +83,19 @@ def test_class_as_function_constructor_arg(shared_ray_instance):
|
|||
def f(actor_handle):
|
||||
return ray.get(actor_handle.get.remote())
|
||||
|
||||
dag = f._bind(Actor._bind(10))
|
||||
dag = f.bind(Actor.bind(10))
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 10
|
||||
|
||||
|
||||
def test_basic_actor_dag_constructor_options(shared_ray_instance):
|
||||
a1 = Actor._bind(10)
|
||||
dag = a1.get._bind()
|
||||
a1 = Actor.bind(10)
|
||||
dag = a1.get.bind()
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 10
|
||||
|
||||
a1 = Actor.options(name="Actor", namespace="test", max_pending_calls=10)._bind(10)
|
||||
dag = a1.get._bind()
|
||||
a1 = Actor.options(name="Actor", namespace="test", max_pending_calls=10).bind(10)
|
||||
dag = a1.get.bind()
|
||||
print(dag)
|
||||
# Ensure execution result is identical with .options() in init()
|
||||
assert ray.get(dag.execute()) == 10
|
||||
|
@ -106,16 +106,16 @@ def test_basic_actor_dag_constructor_options(shared_ray_instance):
|
|||
|
||||
|
||||
def test_actor_method_options(shared_ray_instance):
|
||||
a1 = Actor._bind(10)
|
||||
dag = a1.get.options(name="actor_method_options")._bind()
|
||||
a1 = Actor.bind(10)
|
||||
dag = a1.get.options(name="actor_method_options").bind()
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 10
|
||||
assert dag.get_options().get("name") == "actor_method_options"
|
||||
|
||||
|
||||
def test_basic_actor_dag_constructor_invalid_options(shared_ray_instance):
|
||||
a1 = Actor.options(num_cpus=-1)._bind(10)
|
||||
invalid_dag = a1.get._bind()
|
||||
a1 = Actor.options(num_cpus=-1).bind(10)
|
||||
invalid_dag = a1.get.bind()
|
||||
with pytest.raises(ValueError, match=".*Resource quantities may not be negative.*"):
|
||||
ray.get(invalid_dag.execute())
|
||||
|
||||
|
@ -130,24 +130,24 @@ def test_actor_options_complicated(shared_ray_instance):
|
|||
def combine(x, y):
|
||||
return x + y
|
||||
|
||||
a1 = Actor.options(name="a1_v0")._bind(10)
|
||||
res = a1.get.options(name="v1")._bind()
|
||||
a1 = Actor.options(name="a1_v0").bind(10)
|
||||
res = a1.get.options(name="v1").bind()
|
||||
print(res)
|
||||
assert ray.get(res.execute()) == 10
|
||||
assert a1.get_options().get("name") == "a1_v0"
|
||||
assert res.get_options().get("name") == "v1"
|
||||
|
||||
a1 = Actor.options(name="a1_v1")._bind(10) # Cannot
|
||||
a2 = Actor.options(name="a2_v0")._bind(10)
|
||||
a1.inc.options(name="v1")._bind(2)
|
||||
a1.inc.options(name="v2")._bind(4)
|
||||
a2.inc.options(name="v3")._bind(6)
|
||||
dag = combine.options(name="v4")._bind(a1.get._bind(), a2.get._bind())
|
||||
a1 = Actor.options(name="a1_v1").bind(10) # Cannot
|
||||
a2 = Actor.options(name="a2_v0").bind(10)
|
||||
a1.inc.options(name="v1").bind(2)
|
||||
a1.inc.options(name="v2").bind(4)
|
||||
a2.inc.options(name="v3").bind(6)
|
||||
dag = combine.options(name="v4").bind(a1.get.bind(), a2.get.bind())
|
||||
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == 32
|
||||
test_a1 = dag.get_args()[0] # call graph for a1.get._bind()
|
||||
test_a2 = dag.get_args()[1] # call graph for a2.get._bind()
|
||||
test_a1 = dag.get_args()[0] # call graph for a1.get.bind()
|
||||
test_a2 = dag.get_args()[1] # call graph for a2.get.bind()
|
||||
assert test_a2.get_options() == {} # No .options() at outer call
|
||||
# refer to a2 constructor .options() call
|
||||
assert (
|
||||
|
@ -198,8 +198,8 @@ def test_pass_actor_handle(shared_ray_instance):
|
|||
assert isinstance(handle, ray.actor.ActorHandle), handle
|
||||
return ray.get(handle.ping.remote())
|
||||
|
||||
a1 = Actor._bind()
|
||||
dag = caller._bind(a1)
|
||||
a1 = Actor.bind()
|
||||
dag = caller.bind(a1)
|
||||
print(dag)
|
||||
assert ray.get(dag.execute()) == "hello"
|
||||
|
||||
|
@ -227,15 +227,15 @@ def test_dynamic_pipeline(shared_ray_instance):
|
|||
result = m2.forward.remote(x)
|
||||
return ray.get(result)
|
||||
|
||||
m1 = Model._bind("Even: ")
|
||||
m2 = Model._bind("Odd: ")
|
||||
selection = ModelSelection._bind()
|
||||
m1 = Model.bind("Even: ")
|
||||
m2 = Model.bind("Odd: ")
|
||||
selection = ModelSelection.bind()
|
||||
|
||||
even_input = pipeline._bind(20, m1, m2, selection)
|
||||
even_input = pipeline.bind(20, m1, m2, selection)
|
||||
print(even_input)
|
||||
assert ray.get(even_input.execute()) == "Even: 20"
|
||||
|
||||
odd_input = pipeline._bind(21, m1, m2, selection)
|
||||
odd_input = pipeline.bind(21, m1, m2, selection)
|
||||
print(odd_input)
|
||||
assert ray.get(odd_input.execute()) == "Odd: 21"
|
||||
|
||||
|
|
|
@ -38,13 +38,13 @@ def test_basic_task_dag(shared_ray_instance):
|
|||
ray.get(ct.inc.remote())
|
||||
return x + y
|
||||
|
||||
a_ref = a._bind()
|
||||
b_ref = b._bind(a_ref)
|
||||
c_ref = c._bind(a_ref)
|
||||
d_ref = d._bind(b_ref, c_ref)
|
||||
d1_ref = d._bind(d_ref, d_ref)
|
||||
d2_ref = d._bind(d1_ref, d_ref)
|
||||
dag = d._bind(d2_ref, d_ref)
|
||||
a_ref = a.bind()
|
||||
b_ref = b.bind(a_ref)
|
||||
c_ref = c.bind(a_ref)
|
||||
d_ref = d.bind(b_ref, c_ref)
|
||||
d1_ref = d.bind(d_ref, d_ref)
|
||||
d2_ref = d.bind(d1_ref, d_ref)
|
||||
dag = d.bind(d2_ref, d_ref)
|
||||
print(dag)
|
||||
|
||||
assert ray.get(dag.execute()) == 28
|
||||
|
@ -74,10 +74,10 @@ def test_basic_task_dag_with_options(shared_ray_instance):
|
|||
ray.get(ct.inc.remote())
|
||||
return x + y
|
||||
|
||||
a_ref = a._bind()
|
||||
b_ref = b.options(name="b", num_returns=1)._bind(a_ref)
|
||||
c_ref = c.options(name="c", max_retries=3)._bind(a_ref)
|
||||
dag = d.options(name="d", num_cpus=2)._bind(b_ref, c_ref)
|
||||
a_ref = a.bind()
|
||||
b_ref = b.options(name="b", num_returns=1).bind(a_ref)
|
||||
c_ref = c.options(name="c", max_retries=3).bind(a_ref)
|
||||
dag = d.options(name="d", num_cpus=2).bind(b_ref, c_ref)
|
||||
|
||||
print(dag)
|
||||
|
||||
|
@ -106,12 +106,12 @@ def test_invalid_task_options(shared_ray_instance):
|
|||
def b(x):
|
||||
return x * 2
|
||||
|
||||
a_ref = a._bind()
|
||||
dag = b._bind(a_ref)
|
||||
a_ref = a.bind()
|
||||
dag = b.bind(a_ref)
|
||||
|
||||
# Ensure current DAG is executable
|
||||
assert ray.get(dag.execute()) == 4
|
||||
invalid_dag = b.options(num_cpus=-1)._bind(a_ref)
|
||||
invalid_dag = b.options(num_cpus=-1).bind(a_ref)
|
||||
with pytest.raises(ValueError, match=".*Resource quantities may not be negative.*"):
|
||||
ray.get(invalid_dag.execute())
|
||||
|
||||
|
@ -121,17 +121,17 @@ def test_node_accessors(shared_ray_instance):
|
|||
def a(*a, **kw):
|
||||
pass
|
||||
|
||||
tmp1 = a._bind()
|
||||
tmp2 = a._bind()
|
||||
tmp3 = a._bind()
|
||||
node = a._bind(1, tmp1, x=tmp2, y={"foo": tmp3})
|
||||
tmp1 = a.bind()
|
||||
tmp2 = a.bind()
|
||||
tmp3 = a.bind()
|
||||
node = a.bind(1, tmp1, x=tmp2, y={"foo": tmp3})
|
||||
assert node.get_args() == (1, tmp1)
|
||||
assert node.get_kwargs() == {"x": tmp2, "y": {"foo": tmp3}}
|
||||
assert node._get_toplevel_child_nodes() == {tmp1, tmp2}
|
||||
assert node._get_all_child_nodes() == {tmp1, tmp2, tmp3}
|
||||
|
||||
tmp4 = a._bind()
|
||||
tmp5 = a._bind()
|
||||
tmp4 = a.bind()
|
||||
tmp5 = a.bind()
|
||||
replace = {tmp1: tmp4, tmp2: tmp4, tmp3: tmp5}
|
||||
n2 = node._apply_and_replace_all_child_nodes(lambda x: replace[x])
|
||||
assert n2._get_all_child_nodes() == {tmp4, tmp5}
|
||||
|
@ -160,10 +160,10 @@ def test_nested_args(shared_ray_instance):
|
|||
ray.get(ct.inc.remote())
|
||||
return ray.get(nested["x"]) + ray.get(nested["y"])
|
||||
|
||||
a_ref = a._bind()
|
||||
b_ref = b._bind(x=a_ref)
|
||||
c_ref = c._bind(x=a_ref)
|
||||
dag = d._bind({"x": b_ref, "y": c_ref})
|
||||
a_ref = a.bind()
|
||||
b_ref = b.bind(x=a_ref)
|
||||
c_ref = c.bind(x=a_ref)
|
||||
dag = d.bind({"x": b_ref, "y": c_ref})
|
||||
print(dag)
|
||||
|
||||
assert ray.get(dag.execute()) == 7
|
||||
|
|
|
@ -21,13 +21,13 @@ def test_no_args_to_input_node(shared_ray_instance):
|
|||
ValueError, match="InputNode should not take any args or kwargs"
|
||||
):
|
||||
with InputNode(0) as dag_input:
|
||||
f._bind(dag_input)
|
||||
f.bind(dag_input)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="InputNode should not take any args or kwargs",
|
||||
):
|
||||
with InputNode(key=1) as dag_input:
|
||||
f._bind(dag_input)
|
||||
f.bind(dag_input)
|
||||
|
||||
|
||||
def test_simple_func(shared_ray_instance):
|
||||
|
@ -42,8 +42,8 @@ def test_simple_func(shared_ray_instance):
|
|||
|
||||
# input -> a - > b -> ouput
|
||||
with InputNode() as dag_input:
|
||||
a_node = a._bind(dag_input)
|
||||
dag = b._bind(a_node)
|
||||
a_node = a.bind(dag_input)
|
||||
dag = b.bind(a_node)
|
||||
|
||||
assert ray.get(dag.execute("input")) == "input -> a -> b"
|
||||
assert ray.get(dag.execute("test")) == "test -> a -> b"
|
||||
|
@ -67,13 +67,13 @@ def test_func_dag(shared_ray_instance):
|
|||
return x + y
|
||||
|
||||
with InputNode() as dag_input:
|
||||
a_ref = a._bind(dag_input)
|
||||
b_ref = b._bind(a_ref)
|
||||
c_ref = c._bind(a_ref)
|
||||
d_ref = d._bind(b_ref, c_ref)
|
||||
d1_ref = d._bind(d_ref, d_ref)
|
||||
d2_ref = d._bind(d1_ref, d_ref)
|
||||
dag = d._bind(d2_ref, d_ref)
|
||||
a_ref = a.bind(dag_input)
|
||||
b_ref = b.bind(a_ref)
|
||||
c_ref = c.bind(a_ref)
|
||||
d_ref = d.bind(b_ref, c_ref)
|
||||
d1_ref = d.bind(d_ref, d_ref)
|
||||
d2_ref = d.bind(d1_ref, d_ref)
|
||||
dag = d.bind(d2_ref, d_ref)
|
||||
|
||||
# [(2*2 + 2+1) + (2*2 + 2+1)] + [(2*2 + 2+1) + (2*2 + 2+1)]
|
||||
assert ray.get(dag.execute(2)) == 28
|
||||
|
@ -95,9 +95,9 @@ def test_multi_input_func_dag(shared_ray_instance):
|
|||
return x + y
|
||||
|
||||
with InputNode() as dag_input:
|
||||
a_ref = a._bind(dag_input)
|
||||
b_ref = b._bind(dag_input)
|
||||
dag = c._bind(a_ref, b_ref)
|
||||
a_ref = a.bind(dag_input)
|
||||
b_ref = b.bind(dag_input)
|
||||
dag = c.bind(a_ref, b_ref)
|
||||
|
||||
# (2*2) + (2*1)
|
||||
assert ray.get(dag.execute(2)) == 7
|
||||
|
@ -124,7 +124,7 @@ def test_invalid_input_node_as_class_constructor(shared_ray_instance):
|
|||
),
|
||||
):
|
||||
with InputNode() as dag_input:
|
||||
Actor._bind(dag_input)
|
||||
Actor.bind(dag_input)
|
||||
|
||||
|
||||
def test_class_method_input(shared_ray_instance):
|
||||
|
@ -145,10 +145,10 @@ def test_class_method_input(shared_ray_instance):
|
|||
return input * self.scale
|
||||
|
||||
with InputNode() as dag_input:
|
||||
preprocess = FeatureProcessor._bind(0.5)
|
||||
feature = preprocess.process._bind(dag_input)
|
||||
model = Model._bind(4)
|
||||
dag = model.forward._bind(feature)
|
||||
preprocess = FeatureProcessor.bind(0.5)
|
||||
feature = preprocess.process.bind(dag_input)
|
||||
model = Model.bind(4)
|
||||
dag = model.forward.bind(feature)
|
||||
|
||||
# 2 * 0.5 * 4
|
||||
assert ray.get(dag.execute(2)) == 4
|
||||
|
@ -174,13 +174,13 @@ def test_multi_class_method_input(shared_ray_instance):
|
|||
return m1 + m2
|
||||
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(2)
|
||||
m2 = Model._bind(3)
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
|
||||
m1_output = m1.forward._bind(dag_input)
|
||||
m2_output = m2.forward._bind(dag_input)
|
||||
m1_output = m1.forward.bind(dag_input)
|
||||
m2_output = m2.forward.bind(dag_input)
|
||||
|
||||
dag = combine._bind(m1_output, m2_output)
|
||||
dag = combine.bind(m1_output, m2_output)
|
||||
|
||||
# 1*2 + 1*3
|
||||
assert ray.get(dag.execute(1)) == 5
|
||||
|
@ -211,11 +211,11 @@ def test_func_class_mixed_input(shared_ray_instance):
|
|||
return m1 + m2
|
||||
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(3)
|
||||
m1_output = m1.forward._bind(dag_input)
|
||||
m2_output = model_func._bind(dag_input)
|
||||
m1 = Model.bind(3)
|
||||
m1_output = m1.forward.bind(dag_input)
|
||||
m2_output = model_func.bind(dag_input)
|
||||
|
||||
dag = combine._bind(m1_output, m2_output)
|
||||
dag = combine.bind(m1_output, m2_output)
|
||||
# 2*3 + 2*2
|
||||
assert ray.get(dag.execute(2)) == 10
|
||||
# 3*3 + 3*2
|
||||
|
@ -240,11 +240,11 @@ def test_input_attr_partial_access(shared_ray_instance):
|
|||
|
||||
# 1) Test default wrapping of args and kwargs into internal python object
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(1)
|
||||
m2 = Model._bind(2)
|
||||
m1_output = m1.forward._bind(dag_input[0])
|
||||
m2_output = m2.forward._bind(dag_input[1])
|
||||
dag = combine._bind(m1_output, m2_output, dag_input.m3, dag_input.m4)
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
m1_output = m1.forward.bind(dag_input[0])
|
||||
m2_output = m2.forward.bind(dag_input[1])
|
||||
dag = combine.bind(m1_output, m2_output, dag_input.m3, dag_input.m4)
|
||||
# 1*1 + 2*2 + 3 + 4 = 12
|
||||
assert ray.get(dag.execute(1, 2, m3=3, m4={"deep": {"nested": 4}})) == 12
|
||||
|
||||
|
@ -262,32 +262,32 @@ def test_input_attr_partial_access(shared_ray_instance):
|
|||
self.field_3 = field_3
|
||||
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(1)
|
||||
m2 = Model._bind(2)
|
||||
m1_output = m1.forward._bind(dag_input.user_object_field_0)
|
||||
m2_output = m2.forward._bind(dag_input.user_object_field_1)
|
||||
dag = combine._bind(m1_output, m2_output, dag_input.field_3)
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
m1_output = m1.forward.bind(dag_input.user_object_field_0)
|
||||
m2_output = m2.forward.bind(dag_input.user_object_field_1)
|
||||
dag = combine.bind(m1_output, m2_output, dag_input.field_3)
|
||||
|
||||
# 1*1 + 2*2 + 3
|
||||
assert ray.get(dag.execute(UserDataObj(1, 2, 3))) == 8
|
||||
|
||||
# 3) Test user passed only one list object with regular list index accessor
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(1)
|
||||
m2 = Model._bind(2)
|
||||
m1_output = m1.forward._bind(dag_input[0])
|
||||
m2_output = m2.forward._bind(dag_input[1])
|
||||
dag = combine._bind(m1_output, m2_output, dag_input[2])
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
m1_output = m1.forward.bind(dag_input[0])
|
||||
m2_output = m2.forward.bind(dag_input[1])
|
||||
dag = combine.bind(m1_output, m2_output, dag_input[2])
|
||||
# 1*1 + 2*2 + 3 + 4 = 12
|
||||
assert ray.get(dag.execute([1, 2, 3])) == 8
|
||||
|
||||
# 4) Test user passed only one dict object with key str accessor
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(1)
|
||||
m2 = Model._bind(2)
|
||||
m1_output = m1.forward._bind(dag_input["m1"])
|
||||
m2_output = m2.forward._bind(dag_input["m2"])
|
||||
dag = combine._bind(m1_output, m2_output, dag_input["m3"])
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
m1_output = m1.forward.bind(dag_input["m1"])
|
||||
m2_output = m2.forward.bind(dag_input["m2"])
|
||||
dag = combine.bind(m1_output, m2_output, dag_input["m3"])
|
||||
# 1*1 + 2*2 + 3 + 4 = 12
|
||||
assert ray.get(dag.execute({"m1": 1, "m2": 2, "m3": 3})) == 8
|
||||
|
||||
|
@ -296,8 +296,8 @@ def test_input_attr_partial_access(shared_ray_instance):
|
|||
match="Please only use int index or str as first-level key",
|
||||
):
|
||||
with InputNode() as dag_input:
|
||||
m1 = Model._bind(1)
|
||||
dag = m1.forward._bind(dag_input[(1, 2)])
|
||||
m1 = Model.bind(1)
|
||||
dag = m1.forward.bind(dag_input[(1, 2)])
|
||||
|
||||
|
||||
def test_ensure_in_context_manager(shared_ray_instance):
|
||||
|
@ -317,7 +317,7 @@ def test_ensure_in_context_manager(shared_ray_instance):
|
|||
return input
|
||||
|
||||
# No enforcement on creation given __enter__ executes after __init__
|
||||
dag = f._bind(InputNode())
|
||||
dag = f.bind(InputNode())
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=(
|
||||
|
@ -338,10 +338,10 @@ def test_ensure_input_node_singleton(shared_ray_instance):
|
|||
return a + b
|
||||
|
||||
with InputNode() as input_1:
|
||||
a = f._bind(input_1)
|
||||
a = f.bind(input_1)
|
||||
with InputNode() as input_2:
|
||||
b = f._bind(input_2)
|
||||
dag = combine._bind(a, b)
|
||||
b = f.bind(input_2)
|
||||
dag = combine.bind(a, b)
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError, match="Each DAG should only have one unique InputNode"
|
||||
|
|
|
@ -248,7 +248,7 @@ class RemoteFunction:
|
|||
def remote(self, *args, **kwargs):
|
||||
return func_cls._remote(args=args, kwargs=kwargs, **options)
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
"""
|
||||
**Experimental**
|
||||
|
||||
|
@ -460,7 +460,7 @@ class RemoteFunction:
|
|||
|
||||
return invocation(args, kwargs)
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
"""
|
||||
**Experimental**
|
||||
|
||||
|
|
|
@ -460,3 +460,11 @@ py_test(
|
|||
tags = ["exclusive", "team:serve"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_build",
|
||||
size = "medium",
|
||||
srcs = pipeline_tests_srcs,
|
||||
tags = ["exclusive", "team:serve"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
|
|
@ -1167,7 +1167,7 @@ class Deployment:
|
|||
_internal=True,
|
||||
)
|
||||
|
||||
def _bind(self, *args, **kwargs):
|
||||
def bind(self, *args, **kwargs):
|
||||
raise AttributeError(
|
||||
"DAG building API should only be used for @ray.remote decorated "
|
||||
"class or function, not in serve deployment or library "
|
||||
|
|
74
python/ray/serve/pipeline/api.py
Normal file
74
python/ray/serve/pipeline/api.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
from ray.experimental.dag import DAGNode
|
||||
from ray.serve.pipeline.generate import (
|
||||
transform_ray_dag_to_serve_dag,
|
||||
extract_deployments_from_serve_dag,
|
||||
get_pipeline_input_node,
|
||||
get_ingress_deployment,
|
||||
)
|
||||
|
||||
|
||||
def build(ray_dag_root_node: DAGNode):
|
||||
"""Do all the DAG transformation, extraction and generation needed to
|
||||
produce a runnable and deployable serve pipeline application from a valid
|
||||
DAG authored with Ray DAG API.
|
||||
|
||||
This should be the only user facing API that user interacts with.
|
||||
|
||||
Assumptions:
|
||||
Following enforcements are only applied at generating and applying
|
||||
pipeline artifact, but not blockers for local development and testing.
|
||||
|
||||
- ALL args and kwargs used in DAG building should be JSON serializable.
|
||||
This means in order to ensure your pipeline application can run on
|
||||
a remote cluster potentially with different runtime environment,
|
||||
among all options listed:
|
||||
|
||||
1) binding in-memory objects
|
||||
2) Rely on pickling
|
||||
3) Enforce JSON serialibility on all args used
|
||||
|
||||
We believe both 1) & 2) rely on unstable in-memory objects or
|
||||
cross version pickling / closure capture, where JSON serialization
|
||||
provides the right contract needed for proper deployment.
|
||||
|
||||
- ALL classes and methods used should be visible on top of the file and
|
||||
importable via a fully qualified name. Thus no inline class or
|
||||
function definitions should be used.
|
||||
|
||||
Args:
|
||||
ray_dag_root_node: DAGNode acting as root of a Ray authored DAG. It
|
||||
should be executable via `ray_dag_root_node.execute(user_input)`
|
||||
and should have `PipelineInputNode` in it.
|
||||
|
||||
Returns:
|
||||
app: The Ray Serve application object that wraps all deployments needed
|
||||
along with ingress deployment for an e2e runnable serve pipeline,
|
||||
accessible via python .remote() call and HTTP.
|
||||
|
||||
Examples:
|
||||
>>> with ServeInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
... m1 = Model.bind(1)
|
||||
... m2 = Model.bind(2)
|
||||
... m1_output = m1.forward.bind(dag_input[0])
|
||||
... m2_output = m2.forward.bind(dag_input[1])
|
||||
... ray_dag = ensemble.bind(m1_output, m2_output)
|
||||
|
||||
Assuming we have non-JSON serializable or inline defined class or
|
||||
function in local pipeline development.
|
||||
|
||||
>>> app = serve.pipeline.build(ray_dag) # This works
|
||||
>>> handle = app.deploy()
|
||||
>>> # This also works, we're simply executing the transformed serve_dag.
|
||||
>>> ray.get(handle.remote(data)
|
||||
>>> # This will fail where enforcements are applied.
|
||||
>>> deployment_yaml = app.to_yaml()
|
||||
"""
|
||||
serve_root_dag = ray_dag_root_node.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
pipeline_input_node = get_pipeline_input_node(serve_root_dag)
|
||||
ingress_deployment = get_ingress_deployment(serve_root_dag, pipeline_input_node)
|
||||
deployments.insert(0, ingress_deployment)
|
||||
|
||||
# TODO (jiaodong): Call into Application once Shreyas' PR is merged
|
||||
# TODO (jiaodong): Apply enforcements at serve app to_yaml level
|
||||
return deployments
|
|
@ -50,7 +50,7 @@ class DeploymentNode(DAGNode):
|
|||
(
|
||||
replaced_deployment_init_args,
|
||||
replaced_deployment_init_kwargs,
|
||||
) = self._apply_functional(
|
||||
) = self.apply_functional(
|
||||
[deployment_init_args, deployment_init_kwargs],
|
||||
predictate_fn=lambda node: isinstance(
|
||||
node, (DeploymentNode, DeploymentMethodNode)
|
||||
|
@ -169,7 +169,19 @@ class DeploymentNode(DAGNode):
|
|||
def to_json(self, encoder_cls) -> Dict[str, Any]:
|
||||
json_dict = super().to_json_base(encoder_cls, DeploymentNode.__name__)
|
||||
json_dict["deployment_name"] = self.get_deployment_name()
|
||||
json_dict["import_path"] = self.get_import_path()
|
||||
import_path = self.get_import_path()
|
||||
|
||||
error_message = (
|
||||
"Class used in DAG should not be in-line defined when exporting"
|
||||
"import path for deployment. Please ensure it has fully "
|
||||
"qualified name with valid __module__ and __qualname__ for "
|
||||
"import path, with no __main__ or <locals>. \n"
|
||||
f"Current import path: {import_path}"
|
||||
)
|
||||
assert "__main__" not in import_path, error_message
|
||||
assert "<locals>" not in import_path, error_message
|
||||
|
||||
json_dict["import_path"] = import_path
|
||||
|
||||
return json_dict
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from ray.experimental.dag import (
|
|||
ClassMethodNode,
|
||||
PARENT_CLASS_NODE_KEY,
|
||||
)
|
||||
from ray.experimental.dag.input_node import InputNode
|
||||
from ray.serve.api import Deployment
|
||||
from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode
|
||||
from ray.serve.pipeline.deployment_node import DeploymentNode
|
||||
|
@ -132,11 +133,44 @@ def extract_deployments_from_serve_dag(
|
|||
deployments[deployment.name] = deployment
|
||||
return dag_node
|
||||
|
||||
serve_dag_root._apply_recursive(extractor)
|
||||
serve_dag_root.apply_recursive(extractor)
|
||||
|
||||
return list(deployments.values())
|
||||
|
||||
|
||||
def get_pipeline_input_node(serve_dag_root_node: DAGNode):
|
||||
"""Return the PipelineInputNode singleton node from serve dag, and throw
|
||||
exceptions if we didn't find any, or found more than one.
|
||||
|
||||
Args:
|
||||
ray_dag_root_node: DAGNode acting as root of a Ray authored DAG. It
|
||||
should be executable via `ray_dag_root_node.execute(user_input)`
|
||||
and should have `PipelineInputNode` in it.
|
||||
Returns
|
||||
pipeline_input_node: Singleton input node for the serve pipeline.
|
||||
"""
|
||||
|
||||
input_nodes = []
|
||||
|
||||
def extractor(dag_node):
|
||||
if isinstance(dag_node, PipelineInputNode):
|
||||
input_nodes.append(dag_node)
|
||||
elif isinstance(dag_node, InputNode):
|
||||
raise ValueError(
|
||||
"Please change Ray DAG InputNode to PipelineInputNode in order "
|
||||
"to build serve application. See docstring of "
|
||||
"PipelineInputNode for examples."
|
||||
)
|
||||
|
||||
serve_dag_root_node.apply_recursive(extractor)
|
||||
assert len(input_nodes) == 1, (
|
||||
"There should be one and only one PipelineInputNode in the DAG. "
|
||||
f"Found {len(input_nodes)} PipelineInputNode(s) instead."
|
||||
)
|
||||
|
||||
return input_nodes[0]
|
||||
|
||||
|
||||
def get_ingress_deployment(
|
||||
serve_dag_root_node: DAGNode, pipeline_input_node: PipelineInputNode
|
||||
) -> Deployment:
|
||||
|
|
|
@ -9,6 +9,7 @@ from ray.experimental.dag import (
|
|||
ClassMethodNode,
|
||||
FunctionNode,
|
||||
InputNode,
|
||||
InputAtrributeNode,
|
||||
DAGNODE_TYPE_KEY,
|
||||
)
|
||||
from ray.serve.pipeline.deployment_node import DeploymentNode
|
||||
|
@ -90,6 +91,8 @@ def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
|
|||
# Deserialize DAGNode type
|
||||
elif input_json[DAGNODE_TYPE_KEY] == InputNode.__name__:
|
||||
return InputNode.from_json(input_json, object_hook=dagnode_from_json)
|
||||
elif input_json[DAGNODE_TYPE_KEY] == InputAtrributeNode.__name__:
|
||||
return InputAtrributeNode.from_json(input_json, object_hook=dagnode_from_json)
|
||||
elif input_json[DAGNODE_TYPE_KEY] == PipelineInputNode.__name__:
|
||||
return PipelineInputNode.from_json(input_json, object_hook=dagnode_from_json)
|
||||
elif input_json[DAGNODE_TYPE_KEY] == ClassMethodNode.__name__:
|
||||
|
@ -100,6 +103,7 @@ def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:
|
|||
return DeploymentMethodNode.from_json(input_json, object_hook=dagnode_from_json)
|
||||
else:
|
||||
# Class and Function nodes require original module as body.
|
||||
print(f"import_path: {input_json['import_path']}")
|
||||
module_name, attr_name = parse_import_path(input_json["import_path"])
|
||||
module = getattr(import_module(module_name), attr_name)
|
||||
if input_json[DAGNODE_TYPE_KEY] == FunctionNode.__name__:
|
||||
|
|
|
@ -24,8 +24,8 @@ class PipelineInputNode(InputNode):
|
|||
>>> with PipelineInputNode(
|
||||
... preprocessor=request_to_data_int
|
||||
... ) as dag_input:
|
||||
... model = Model._bind(2, ratio=0.3)
|
||||
... ray_dag = model.forward._bind(dag_input)
|
||||
... model = Model.bind(2, ratio=0.3)
|
||||
... ray_dag = model.forward.bind(dag_input)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -81,6 +81,17 @@ class PipelineInputNode(InputNode):
|
|||
|
||||
def to_json(self, encoder_cls) -> Dict[str, Any]:
|
||||
json_dict = super().to_json_base(encoder_cls, PipelineInputNode.__name__)
|
||||
preprocessor_import_path = self.get_preprocessor_import_path()
|
||||
error_message = (
|
||||
"Preprocessor used in DAG should not be in-line defined when "
|
||||
"exporting import path for deployment. Please ensure it has fully "
|
||||
"qualified name with valid __module__ and __qualname__ for "
|
||||
"import path, with no __main__ or <locals>. \n"
|
||||
f"Current import path: {preprocessor_import_path}"
|
||||
)
|
||||
assert "__main__" not in preprocessor_import_path, error_message
|
||||
assert "<locals>" not in preprocessor_import_path, error_message
|
||||
|
||||
return json_dict
|
||||
|
||||
@classmethod
|
||||
|
|
0
python/ray/serve/pipeline/tests/resources/__init__.py
Normal file
0
python/ray/serve/pipeline/tests/resources/__init__.py
Normal file
64
python/ray/serve/pipeline/tests/resources/test_dags.py
Normal file
64
python/ray/serve/pipeline/tests/resources/test_dags.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from ray.serve.pipeline.tests.resources.test_modules import (
|
||||
Model,
|
||||
Combine,
|
||||
combine,
|
||||
NESTED_HANDLE_KEY,
|
||||
request_to_data_int,
|
||||
request_to_data_obj,
|
||||
)
|
||||
from ray.serve.pipeline.pipeline_input_node import PipelineInputNode
|
||||
|
||||
|
||||
def get_simple_func_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_obj) as dag_input:
|
||||
ray_dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
|
||||
|
||||
return ray_dag, dag_input
|
||||
|
||||
|
||||
def get_simple_class_with_class_method_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
model = Model.bind(2, ratio=0.3)
|
||||
ray_dag = model.forward.bind(dag_input)
|
||||
|
||||
return ray_dag, dag_input
|
||||
|
||||
|
||||
def get_func_class_with_class_method_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_obj) as dag_input:
|
||||
m1 = Model.bind(1)
|
||||
m2 = Model.bind(2)
|
||||
m1_output = m1.forward.bind(dag_input[0])
|
||||
m2_output = m2.forward.bind(dag_input[1])
|
||||
ray_dag = combine.bind(m1_output, m2_output, kwargs_output=dag_input[2])
|
||||
|
||||
return ray_dag, dag_input
|
||||
|
||||
|
||||
def get_multi_instantiation_class_deployment_in_init_args_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2=m2)
|
||||
ray_dag = combine.__call__.bind(dag_input)
|
||||
|
||||
return ray_dag, dag_input
|
||||
|
||||
|
||||
def get_shared_deployment_handle_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m = Model.bind(2)
|
||||
combine = Combine.bind(m, m2=m)
|
||||
ray_dag = combine.__call__.bind(dag_input)
|
||||
|
||||
return ray_dag, dag_input
|
||||
|
||||
|
||||
def get_multi_instantiation_class_nested_deployment_arg_dag():
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
ray_dag = combine.__call__.bind(dag_input)
|
||||
|
||||
return ray_dag, dag_input
|
|
@ -4,6 +4,7 @@ fully qualified name as import_path to test DAG building, artifact generation
|
|||
and structured deployment.
|
||||
"""
|
||||
import starlette
|
||||
import json
|
||||
from typing import TypeVar
|
||||
|
||||
import ray
|
||||
|
@ -74,6 +75,22 @@ def combine(m1_output, m2_output, kwargs_output=0):
|
|||
return m1_output + m2_output + kwargs_output
|
||||
|
||||
|
||||
def class_factory():
|
||||
class MyInlineClass:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def get(self):
|
||||
return self.val
|
||||
|
||||
return MyInlineClass
|
||||
|
||||
|
||||
async def request_to_data_int(request: starlette.requests.Request):
|
||||
data = await request.body()
|
||||
return int(data)
|
||||
|
||||
|
||||
async def request_to_data_obj(request: starlette.requests.Request):
|
||||
data = await request.body()
|
||||
return json.loads(data)
|
102
python/ray/serve/pipeline/tests/test_build.py
Normal file
102
python/ray/serve/pipeline/tests/test_build.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
import requests
|
||||
import json
|
||||
|
||||
from ray.serve.pipeline.api import build
|
||||
|
||||
|
||||
from ray.serve.pipeline.tests.resources.test_dags import (
|
||||
get_simple_func_dag,
|
||||
get_simple_class_with_class_method_dag,
|
||||
get_func_class_with_class_method_dag,
|
||||
get_multi_instantiation_class_deployment_in_init_args_dag,
|
||||
get_shared_deployment_handle_dag,
|
||||
get_multi_instantiation_class_nested_deployment_arg_dag,
|
||||
)
|
||||
|
||||
|
||||
def test_build_simple_func_dag(serve_instance):
|
||||
ray_dag, _ = get_simple_func_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data=json.dumps([1, 2]))
|
||||
assert resp.text == "4"
|
||||
|
||||
|
||||
def test_build_simple_class_with_class_method_dag(serve_instance):
|
||||
ray_dag, _ = get_simple_class_with_class_method_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "0.6"
|
||||
|
||||
|
||||
def test_build_func_class_with_class_method_dag(serve_instance):
|
||||
ray_dag, _ = get_func_class_with_class_method_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data=json.dumps([1, 2, 3]))
|
||||
assert resp.text == "8"
|
||||
|
||||
|
||||
def test_build_multi_instantiation_class_deployment_in_init_args(
|
||||
serve_instance,
|
||||
):
|
||||
"""
|
||||
Test we can pass deployments as init_arg or init_kwarg, instantiated
|
||||
multiple times for the same class, and we can still correctly replace
|
||||
args with deployment handle and parse correct deployment instances.
|
||||
"""
|
||||
ray_dag, _ = get_multi_instantiation_class_deployment_in_init_args_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "5"
|
||||
|
||||
|
||||
def test_build_shared_deployment_handle(serve_instance):
|
||||
"""
|
||||
Test we can re-use the same deployment handle multiple times or in
|
||||
multiple places, without incorrectly parsing duplicated deployments.
|
||||
"""
|
||||
ray_dag, _ = get_shared_deployment_handle_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "4"
|
||||
|
||||
|
||||
def test_build_multi_instantiation_class_nested_deployment_arg(serve_instance):
|
||||
"""
|
||||
Test we can pass deployments with **nested** init_arg or init_kwarg,
|
||||
instantiated multiple times for the same class, and we can still correctly
|
||||
replace args with deployment handle and parse correct deployment instances.
|
||||
"""
|
||||
ray_dag, _ = get_multi_instantiation_class_nested_deployment_arg_dag()
|
||||
|
||||
deployments = build(ray_dag)
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "5"
|
|
@ -51,7 +51,7 @@ def test_disallow_binding_deployments():
|
|||
AttributeError,
|
||||
match="DAG building API should only be used for @ray.remote decorated",
|
||||
):
|
||||
_ = ServeActor._bind(10)
|
||||
_ = ServeActor.bind(10)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -1,20 +1,30 @@
|
|||
import pytest
|
||||
import requests
|
||||
import json
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.handle import RayServeSyncHandle
|
||||
from ray.experimental.dag import InputNode
|
||||
from ray.serve.pipeline.generate import (
|
||||
transform_ray_dag_to_serve_dag,
|
||||
extract_deployments_from_serve_dag,
|
||||
get_pipeline_input_node,
|
||||
get_ingress_deployment,
|
||||
)
|
||||
from ray.serve.pipeline.tests.test_modules import (
|
||||
from ray.serve.pipeline.tests.resources.test_modules import (
|
||||
Model,
|
||||
Combine,
|
||||
NESTED_HANDLE_KEY,
|
||||
combine,
|
||||
request_to_data_int,
|
||||
)
|
||||
from ray.serve.pipeline.tests.resources.test_dags import (
|
||||
get_simple_class_with_class_method_dag,
|
||||
get_func_class_with_class_method_dag,
|
||||
get_multi_instantiation_class_deployment_in_init_args_dag,
|
||||
get_shared_deployment_handle_dag,
|
||||
get_multi_instantiation_class_nested_deployment_arg_dag,
|
||||
)
|
||||
from ray.serve.pipeline.pipeline_input_node import PipelineInputNode
|
||||
|
||||
|
||||
|
@ -34,12 +44,9 @@ def _validate_consistent_python_output(
|
|||
|
||||
|
||||
def test_simple_single_class(serve_instance):
|
||||
# Assert converting both arg and kwarg
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
model = Model._bind(2, ratio=0.3)
|
||||
ray_dag = model.forward._bind(dag_input)
|
||||
ray_dag, dag_input = get_simple_class_with_class_method_dag()
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
ingress_deployment = get_ingress_deployment(serve_root_dag, dag_input)
|
||||
assert len(deployments) == 1
|
||||
|
@ -50,18 +57,16 @@ def test_simple_single_class(serve_instance):
|
|||
)
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get(
|
||||
f"http://127.0.0.1:8000/{ingress_deployment.name}", data="1"
|
||||
)
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "0.6"
|
||||
|
||||
|
||||
def test_single_class_with_valid_ray_options(serve_instance):
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
model = Model.options(num_cpus=1, memory=1000)._bind(2, ratio=0.3)
|
||||
ray_dag = model.forward._bind(dag_input)
|
||||
model = Model.options(num_cpus=1, memory=1000).bind(2, ratio=0.3)
|
||||
ray_dag = model.forward.bind(dag_input)
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
assert len(deployments) == 1
|
||||
deployments[0].deploy()
|
||||
|
@ -77,10 +82,10 @@ def test_single_class_with_valid_ray_options(serve_instance):
|
|||
|
||||
def test_single_class_with_invalid_deployment_options(serve_instance):
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
model = Model.options(name="my_deployment")._bind(2, ratio=0.3)
|
||||
ray_dag = model.forward._bind(dag_input)
|
||||
model = Model.options(name="my_deployment").bind(2, ratio=0.3)
|
||||
ray_dag = model.forward.bind(dag_input)
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
assert len(deployments) == 1
|
||||
with pytest.raises(
|
||||
|
@ -89,20 +94,33 @@ def test_single_class_with_invalid_deployment_options(serve_instance):
|
|||
deployments[0].deploy()
|
||||
|
||||
|
||||
def test_func_class_with_class_method_dag(serve_instance):
|
||||
ray_dag, dag_input = get_func_class_with_class_method_dag()
|
||||
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
ingress_deployment = get_ingress_deployment(serve_root_dag, dag_input)
|
||||
assert len(deployments) == 2
|
||||
for deployment in deployments:
|
||||
deployment.deploy()
|
||||
ingress_deployment.deploy()
|
||||
|
||||
assert ray.get(ray_dag.execute(1, 2, 3)) == 8
|
||||
assert ray.get(serve_root_dag.execute(1, 2, 3)) == 8
|
||||
for _ in range(5):
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data=json.dumps([1, 2, 3]))
|
||||
assert resp.text == "8"
|
||||
|
||||
|
||||
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
||||
"""
|
||||
Test we can pass deployments as init_arg or init_kwarg, instantiated
|
||||
multiple times for the same class, and we can still correctly replace
|
||||
args with deployment handle and parse correct deployment instances.
|
||||
"""
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model._bind(2)
|
||||
m2 = Model._bind(3)
|
||||
combine = Combine._bind(m1, m2=m2)
|
||||
ray_dag = combine.__call__._bind(dag_input)
|
||||
print(f"Ray DAG: \n{ray_dag}")
|
||||
ray_dag, dag_input = get_multi_instantiation_class_deployment_in_init_args_dag()
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
print(f"Serve DAG: \n{serve_root_dag}")
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
assert len(deployments) == 3
|
||||
|
@ -115,10 +133,9 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
|
|||
_validate_consistent_python_output(
|
||||
deployments[2], ray_dag, "Combine", input=1, output=5
|
||||
)
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get(
|
||||
f"http://127.0.0.1:8000/{ingress_deployment.name}", data="1"
|
||||
)
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "5"
|
||||
|
||||
|
||||
|
@ -127,13 +144,9 @@ def test_shared_deployment_handle(serve_instance):
|
|||
Test we can re-use the same deployment handle multiple times or in
|
||||
multiple places, without incorrectly parsing duplicated deployments.
|
||||
"""
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m = Model._bind(2)
|
||||
combine = Combine._bind(m, m2=m)
|
||||
ray_dag = combine.__call__._bind(dag_input)
|
||||
print(f"Ray DAG: \n{ray_dag}")
|
||||
ray_dag, dag_input = get_shared_deployment_handle_dag()
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
print(f"Serve DAG: \n{serve_root_dag}")
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
assert len(deployments) == 2
|
||||
|
@ -147,9 +160,7 @@ def test_shared_deployment_handle(serve_instance):
|
|||
deployments[1], ray_dag, "Combine", input=1, output=4
|
||||
)
|
||||
for _ in range(5):
|
||||
resp = requests.get(
|
||||
f"http://127.0.0.1:8000/{ingress_deployment.name}", data="1"
|
||||
)
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "4"
|
||||
|
||||
|
||||
|
@ -159,14 +170,9 @@ def test_multi_instantiation_class_nested_deployment_arg(serve_instance):
|
|||
instantiated multiple times for the same class, and we can still correctly
|
||||
replace args with deployment handle and parse correct deployment instances.
|
||||
"""
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model._bind(2)
|
||||
m2 = Model._bind(3)
|
||||
combine = Combine._bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
ray_dag = combine.__call__._bind(dag_input)
|
||||
print(f"Ray DAG: \n{ray_dag}")
|
||||
ray_dag, dag_input = get_multi_instantiation_class_nested_deployment_arg_dag()
|
||||
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
print(f"Serve DAG: \n{serve_root_dag}")
|
||||
deployments = extract_deployments_from_serve_dag(serve_root_dag)
|
||||
assert len(deployments) == 3
|
||||
|
@ -189,13 +195,43 @@ def test_multi_instantiation_class_nested_deployment_arg(serve_instance):
|
|||
_validate_consistent_python_output(
|
||||
deployments[2], ray_dag, "Combine", input=1, output=5
|
||||
)
|
||||
|
||||
for _ in range(5):
|
||||
resp = requests.get(
|
||||
f"http://127.0.0.1:8000/{ingress_deployment.name}", data="1"
|
||||
)
|
||||
resp = requests.get("http://127.0.0.1:8000/ingress", data="1")
|
||||
assert resp.text == "5"
|
||||
|
||||
|
||||
def test_get_pipeline_input_node():
|
||||
# 1) No PipelineInputNode found
|
||||
ray_dag = combine.bind(1, 2)
|
||||
serve_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
with pytest.raises(
|
||||
AssertionError, match="There should be one and only one PipelineInputNode"
|
||||
):
|
||||
get_pipeline_input_node(serve_dag)
|
||||
|
||||
# 2) More than one PipelineInputNode found
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
a = combine.bind(dag_input[0], dag_input[1])
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input_2:
|
||||
b = combine.bind(dag_input_2[0], dag_input_2[1])
|
||||
ray_dag = combine.bind(a, b)
|
||||
serve_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
with pytest.raises(
|
||||
AssertionError, match="There should be one and only one PipelineInputNode"
|
||||
):
|
||||
get_pipeline_input_node(serve_dag)
|
||||
|
||||
# 3) User forgot to change InputNode to PipelineInputNode
|
||||
with InputNode() as dag_input:
|
||||
ray_dag = combine.bind(dag_input[0], dag_input[1])
|
||||
serve_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
with pytest.raises(
|
||||
ValueError, match="Please change Ray DAG InputNode to PipelineInputNode"
|
||||
):
|
||||
get_pipeline_input_node(serve_dag)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
|
@ -9,12 +9,13 @@ from ray.serve.pipeline.json_serde import (
|
|||
dagnode_from_json,
|
||||
DAGNODE_TYPE_KEY,
|
||||
)
|
||||
from ray.serve.pipeline.tests.test_modules import (
|
||||
from ray.serve.pipeline.tests.resources.test_modules import (
|
||||
Model,
|
||||
combine,
|
||||
Counter,
|
||||
ClassHello,
|
||||
fn_hello,
|
||||
class_factory,
|
||||
Combine,
|
||||
request_to_data_int,
|
||||
NESTED_HANDLE_KEY,
|
||||
|
@ -90,7 +91,7 @@ def test_non_json_serializable_args():
|
|||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
ray_dag = combine._bind(MyNonJSONClass(1), MyNonJSONClass(2))
|
||||
ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2))
|
||||
# General context
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
|
@ -117,6 +118,65 @@ def test_non_json_serializable_args():
|
|||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
|
||||
|
||||
def test_no_inline_class_or_func(serve_instance):
|
||||
# 1) Inline function
|
||||
@ray.remote
|
||||
def inline_func(val):
|
||||
return val
|
||||
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
ray_dag = inline_func.bind(dag_input)
|
||||
|
||||
assert ray.get(ray_dag.execute(1)) == 1
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Function used in DAG should not be in-line defined",
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
|
||||
# 2) Inline class
|
||||
@ray.remote
|
||||
class InlineClass:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def get(self, input):
|
||||
return self.val + input
|
||||
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
node = InlineClass.bind(1)
|
||||
ray_dag = node.get.bind(dag_input)
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Class used in DAG should not be in-line defined",
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
|
||||
# 3) Inline preprocessor fn
|
||||
def inline_preprocessor_fn(input):
|
||||
return input
|
||||
|
||||
with PipelineInputNode(preprocessor=inline_preprocessor_fn) as dag_input:
|
||||
ray_dag = combine.bind(dag_input[0], 2)
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Preprocessor used in DAG should not be in-line defined",
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
|
||||
# 4) Class factory that function returns class object
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
instance = ray.remote(class_factory()).bind()
|
||||
ray_dag = instance.get.bind()
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Class used in DAG should not be in-line defined",
|
||||
):
|
||||
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
|
||||
|
||||
|
||||
def test_simple_function_node_json_serde(serve_instance):
|
||||
"""
|
||||
Test the following behavior
|
||||
|
@ -129,13 +189,13 @@ def test_simple_function_node_json_serde(serve_instance):
|
|||
- Simple function with only args, all primitive types
|
||||
- Simple function with args + kwargs, all primitive types
|
||||
"""
|
||||
original_dag_node = combine._bind(1, 2)
|
||||
original_dag_node = combine.bind(1, 2)
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_function_node,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "FunctionNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.combine",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.combine",
|
||||
"args": "[1, 2]",
|
||||
"kwargs": "{}",
|
||||
"options": "{}",
|
||||
|
@ -144,13 +204,13 @@ def test_simple_function_node_json_serde(serve_instance):
|
|||
},
|
||||
)
|
||||
|
||||
original_dag_node = combine._bind(1, 2, kwargs_output=3)
|
||||
original_dag_node = combine.bind(1, 2, kwargs_output=3)
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_function_node,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "FunctionNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.combine",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.combine",
|
||||
"args": "[1, 2]",
|
||||
"kwargs": '{"kwargs_output": 3}',
|
||||
"options": "{}",
|
||||
|
@ -159,13 +219,13 @@ def test_simple_function_node_json_serde(serve_instance):
|
|||
},
|
||||
)
|
||||
|
||||
original_dag_node = fn_hello._bind()
|
||||
original_dag_node = fn_hello.bind()
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_function_node,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "FunctionNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.fn_hello",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.fn_hello",
|
||||
"args": "[]",
|
||||
"kwargs": "{}",
|
||||
"options": "{}",
|
||||
|
@ -189,13 +249,13 @@ def test_simple_class_node_json_serde(serve_instance):
|
|||
- Simple class with args + kwargs, all primitive types
|
||||
- Simple chain of class method calls, all primitive types
|
||||
"""
|
||||
original_dag_node = ClassHello._bind()
|
||||
original_dag_node = ClassHello.bind()
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_class_node_ClassHello,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "ClassNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.ClassHello",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.ClassHello",
|
||||
"args": "[]",
|
||||
"kwargs": "{}",
|
||||
"options": "{}",
|
||||
|
@ -204,13 +264,13 @@ def test_simple_class_node_json_serde(serve_instance):
|
|||
},
|
||||
)
|
||||
|
||||
original_dag_node = Model._bind(1)
|
||||
original_dag_node = Model.bind(1)
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_class_node_Model,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "ClassNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.Model",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.Model",
|
||||
"args": "[1]",
|
||||
"kwargs": "{}",
|
||||
"options": "{}",
|
||||
|
@ -219,13 +279,13 @@ def test_simple_class_node_json_serde(serve_instance):
|
|||
},
|
||||
)
|
||||
|
||||
original_dag_node = Model._bind(1, ratio=0.5)
|
||||
original_dag_node = Model.bind(1, ratio=0.5)
|
||||
_test_json_serde_helper(
|
||||
original_dag_node,
|
||||
executor_fn=_test_execution_class_node_Model,
|
||||
expected_json_dict={
|
||||
DAGNODE_TYPE_KEY: "ClassNode",
|
||||
"import_path": "ray.serve.pipeline.tests.test_modules.Model",
|
||||
"import_path": "ray.serve.pipeline.tests.resources.test_modules.Model",
|
||||
"args": "[1]",
|
||||
"kwargs": '{"ratio": 0.5}',
|
||||
"options": "{}",
|
||||
|
@ -246,7 +306,7 @@ def _test_deployment_json_serde_helper(
|
|||
3) Deserialized serve dag can extract correct number and definition of
|
||||
serve deployments.
|
||||
"""
|
||||
serve_root_dag = ray_dag._apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
serve_root_dag = ray_dag.apply_recursive(transform_ray_dag_to_serve_dag)
|
||||
json_serialized = json.dumps(serve_root_dag, cls=DAGNodeEncoder)
|
||||
deserialized_serve_root_dag_node = json.loads(
|
||||
json_serialized, object_hook=dagnode_from_json
|
||||
|
@ -272,10 +332,10 @@ def test_simple_deployment_method_call_chain(serve_instance):
|
|||
ClassMethodNode to DeploymentMethodNode that acts on deployment handle
|
||||
that is uniquely identified by its name without dependency of uuid.
|
||||
"""
|
||||
counter = Counter._bind(0)
|
||||
counter.inc._bind(1)
|
||||
counter.inc._bind(2)
|
||||
ray_dag = counter.get._bind()
|
||||
counter = Counter.bind(0)
|
||||
counter.inc.bind(1)
|
||||
counter.inc.bind(2)
|
||||
ray_dag = counter.get.bind()
|
||||
assert ray.get(ray_dag.execute()) == 3
|
||||
(
|
||||
serve_root_dag,
|
||||
|
@ -291,10 +351,10 @@ def test_simple_deployment_method_call_chain(serve_instance):
|
|||
|
||||
def test_multi_instantiation_class_nested_deployment_arg(serve_instance):
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model._bind(2)
|
||||
m2 = Model._bind(3)
|
||||
combine = Combine._bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
ray_dag = combine.__call__._bind(dag_input)
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
|
||||
ray_dag = combine.__call__.bind(dag_input)
|
||||
|
||||
(
|
||||
serve_root_dag,
|
||||
|
@ -307,13 +367,13 @@ def test_multi_instantiation_class_nested_deployment_arg(serve_instance):
|
|||
|
||||
def test_nested_deployment_node_json_serde(serve_instance):
|
||||
with PipelineInputNode(preprocessor=request_to_data_int) as dag_input:
|
||||
m1 = Model._bind(2)
|
||||
m2 = Model._bind(3)
|
||||
m1 = Model.bind(2)
|
||||
m2 = Model.bind(3)
|
||||
|
||||
m1_output = m1.forward._bind(dag_input)
|
||||
m2_output = m2.forward._bind(dag_input)
|
||||
m1_output = m1.forward.bind(dag_input)
|
||||
m2_output = m2.forward.bind(dag_input)
|
||||
|
||||
ray_dag = combine._bind(m1_output, m2_output)
|
||||
ray_dag = combine.bind(m1_output, m2_output)
|
||||
(
|
||||
serve_root_dag,
|
||||
deserialized_serve_root_dag_node,
|
||||
|
|
Loading…
Add table
Reference in a new issue