mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[workflow] Deprecate "workflow.step" [Part 2 - most nested workflows] (#23728)
* remove workflow.step * convert examples
This commit is contained in:
parent
c0e38e335c
commit
46465abd6d
18 changed files with 174 additions and 154 deletions
|
@ -4,14 +4,14 @@ import ray
|
|||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def extract() -> dict:
|
||||
data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
|
||||
order_data_dict = json.loads(data_string)
|
||||
return order_data_dict
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def transform(order_data_dict: dict) -> dict:
|
||||
total_order_value = 0
|
||||
for value in order_data_dict.values():
|
||||
|
@ -19,7 +19,7 @@ def transform(order_data_dict: dict) -> dict:
|
|||
return {"total_order_value": ray.put(total_order_value)}
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def load(data_dict: dict) -> str:
|
||||
total_order_value = ray.get(data_dict["total_order_value"])
|
||||
return f"Total order value is: {total_order_value:.2f}"
|
||||
|
@ -27,7 +27,7 @@ def load(data_dict: dict) -> str:
|
|||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
order_data = extract.step()
|
||||
order_summary = transform.step(order_data)
|
||||
etl = load.step(order_summary)
|
||||
print(etl.run())
|
||||
order_data = extract.bind()
|
||||
order_summary = transform.bind(order_data)
|
||||
etl = load.bind(order_summary)
|
||||
print(workflow.create(etl).run())
|
||||
|
|
|
@ -1,30 +1,30 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def handle_heads() -> str:
|
||||
return "It was heads"
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def handle_tails() -> str:
|
||||
return "It was tails"
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def flip_coin() -> str:
|
||||
import random
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def decide(heads: bool) -> str:
|
||||
if heads:
|
||||
return handle_heads.step()
|
||||
else:
|
||||
return handle_tails.step()
|
||||
return workflow.continuation(
|
||||
handle_heads.bind() if heads else handle_tails.bind()
|
||||
)
|
||||
|
||||
return decide.step(random.random() > 0.5)
|
||||
return workflow.continuation(decide.bind(random.random() > 0.5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
print(flip_coin.step().run())
|
||||
print(workflow.create(flip_coin.bind()).run())
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Tuple, Optional
|
||||
|
||||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
|
@ -8,17 +9,17 @@ def intentional_fail() -> str:
|
|||
raise RuntimeError("oops")
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def cry(error: Exception) -> None:
|
||||
print("Sadly", error)
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def celebrate(result: str) -> None:
|
||||
print("Success!", result)
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def send_email(result: str) -> None:
|
||||
print("Sending email", result)
|
||||
|
||||
|
@ -26,17 +27,17 @@ def send_email(result: str) -> None:
|
|||
@workflow.step
|
||||
def exit_handler(res: Tuple[Optional[str], Optional[Exception]]) -> None:
|
||||
result, error = res
|
||||
email = send_email.step("Raw result: {}, {}".format(result, error))
|
||||
email = send_email.bind(f"Raw result: {result}, {error}")
|
||||
if error:
|
||||
handler = cry.step(error)
|
||||
handler = cry.bind(error)
|
||||
else:
|
||||
handler = celebrate.step(result)
|
||||
return wait_all.step(handler, email)
|
||||
handler = celebrate.bind(result)
|
||||
return workflow.continuation(wait_all.bind(handler, email))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def wait_all(*deps):
|
||||
pass
|
||||
return "done"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
# TODO(ekl) should support something like runtime_env={"pip": ["whalesay"]}
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def hello(msg: str) -> None:
|
||||
print(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
hello.step("hello world").run()
|
||||
workflow.create(hello.bind("hello world")).run()
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def hello(msg: str) -> None:
|
||||
print(msg)
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def wait_all(*args) -> None:
|
||||
pass
|
||||
|
||||
|
@ -15,5 +16,5 @@ if __name__ == "__main__":
|
|||
workflow.init()
|
||||
children = []
|
||||
for msg in ["hello world", "goodbye world"]:
|
||||
children.append(hello.step(msg))
|
||||
wait_all.step(*children).run()
|
||||
children.append(hello.bind(msg))
|
||||
workflow.create(wait_all.bind(*children)).run()
|
||||
|
|
|
@ -1,31 +1,32 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def handle_heads() -> str:
|
||||
return "It was heads"
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def handle_tails() -> str:
|
||||
print("It was tails, retrying")
|
||||
return flip_coin.step()
|
||||
return workflow.continuation(flip_coin.bind())
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def flip_coin() -> str:
|
||||
import random
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def decide(heads: bool) -> str:
|
||||
if heads:
|
||||
return handle_heads.step()
|
||||
return workflow.continuation(handle_heads.bind())
|
||||
else:
|
||||
return handle_tails.step()
|
||||
return workflow.continuation(handle_tails.bind())
|
||||
|
||||
return decide.step(random.random() > 0.5)
|
||||
return workflow.continuation(decide.bind(random.random() > 0.5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
print(flip_coin.step().run())
|
||||
print(workflow.create(flip_coin.bind()).run())
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def compose_greeting(greeting: str, name: str) -> str:
|
||||
return greeting + ": " + name
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def main_workflow(name: str) -> str:
|
||||
return compose_greeting.step("Hello", name)
|
||||
return workflow.continuation(compose_greeting.bind("Hello", name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
wf = main_workflow.step("Alice")
|
||||
wf = workflow.create(main_workflow.bind("Alice"))
|
||||
print(wf.run())
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Tuple, Optional
|
||||
|
||||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
|
@ -16,24 +17,19 @@ def generate_request_id():
|
|||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
@workflow.step
|
||||
def cancel(request_id: str) -> None:
|
||||
make_request("cancel", request_id)
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def book_car(request_id: str) -> str:
|
||||
car_reservation_id = make_request("book_car", request_id)
|
||||
return car_reservation_id
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def book_hotel(request_id: str, *deps) -> str:
|
||||
hotel_reservation_id = make_request("book_hotel", request_id)
|
||||
return hotel_reservation_id
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def book_flight(request_id: str, *deps) -> str:
|
||||
flight_reservation_id = make_request("book_flight", request_id)
|
||||
return flight_reservation_id
|
||||
|
@ -41,15 +37,15 @@ def book_flight(request_id: str, *deps) -> str:
|
|||
|
||||
@workflow.step
|
||||
def book_all(car_req_id: str, hotel_req_id: str, flight_req_id: str) -> str:
|
||||
car_res_id = book_car.step(car_req_id)
|
||||
hotel_res_id = book_hotel.step(hotel_req_id, car_res_id)
|
||||
flight_res_id = book_flight.step(hotel_req_id, hotel_res_id)
|
||||
car_res_id = book_car.bind(car_req_id)
|
||||
hotel_res_id = book_hotel.bind(hotel_req_id, car_res_id)
|
||||
flight_res_id = book_flight.bind(hotel_req_id, hotel_res_id)
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def concat(*ids: List[str]) -> str:
|
||||
return ", ".join(ids)
|
||||
|
||||
return concat.step(car_res_id, hotel_res_id, flight_res_id)
|
||||
return workflow.continuation(concat.bind(car_res_id, hotel_res_id, flight_res_id))
|
||||
|
||||
|
||||
@workflow.step
|
||||
|
@ -61,15 +57,21 @@ def handle_errors(
|
|||
) -> str:
|
||||
result, error = final_result
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def wait_all(*deps) -> None:
|
||||
pass
|
||||
|
||||
@ray.remote
|
||||
def cancel(request_id: str) -> None:
|
||||
make_request("cancel", request_id)
|
||||
|
||||
if error:
|
||||
return wait_all.step(
|
||||
cancel.step(car_req_id),
|
||||
cancel.step(hotel_req_id),
|
||||
cancel.step(flight_req_id),
|
||||
return workflow.continuation(
|
||||
wait_all.bind(
|
||||
cancel.bind(car_req_id),
|
||||
cancel.bind(hotel_req_id),
|
||||
cancel.bind(flight_req_id),
|
||||
)
|
||||
)
|
||||
else:
|
||||
return result
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
from typing import List
|
||||
|
||||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def iterate(array: List[str], result: str, i: int) -> str:
|
||||
if i >= len(array):
|
||||
return result
|
||||
return iterate.step(array, result + array[i], i + 1)
|
||||
return workflow.continuation(iterate.bind(array, result + array[i], i + 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
print(iterate.step(["foo", "ba", "r"], "", 0).run())
|
||||
print(workflow.create(iterate.bind(["foo", "ba", "r"], "", 0)).run())
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
|
@ -6,36 +7,36 @@ def make_request(url: str) -> str:
|
|||
return "42"
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def get_size() -> int:
|
||||
return int(make_request("https://www.example.com/callA"))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def small(result: int) -> str:
|
||||
return make_request("https://www.example.com/SmallFunc")
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def medium(result: int) -> str:
|
||||
return make_request("https://www.example.com/MediumFunc")
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def large(result: int) -> str:
|
||||
return make_request("https://www.example.com/LargeFunc")
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def decide(result: int) -> str:
|
||||
if result < 10:
|
||||
return small.step(result)
|
||||
return workflow.continuation(small.bind(result))
|
||||
elif result < 100:
|
||||
return medium.step(result)
|
||||
return workflow.continuation(medium.bind(result))
|
||||
else:
|
||||
return large.step(result)
|
||||
return workflow.continuation(large.bind(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
print(decide.step(get_size.step()).run())
|
||||
print(workflow.create(decide.bind(get_size.bind())).run())
|
||||
|
|
|
@ -1,23 +1,24 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def hello(name: str) -> str:
|
||||
return format_name.step(name)
|
||||
return workflow.continuation(format_name.bind(name))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def format_name(name: str) -> str:
|
||||
return "hello, {}".format(name)
|
||||
return f"hello, {name}"
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def report(msg: str) -> None:
|
||||
print(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
r1 = hello.step("Kristof")
|
||||
r2 = report.step(r1)
|
||||
r2.run()
|
||||
r1 = hello.bind("Kristof")
|
||||
r2 = report.bind(r1)
|
||||
workflow.create(r2).run()
|
||||
|
|
|
@ -1,27 +1,26 @@
|
|||
from typing import List
|
||||
|
||||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def start():
|
||||
titles = ["Stranger Things", "House of Cards", "Narcos"]
|
||||
children = []
|
||||
for t in titles:
|
||||
children.append(a.step(t))
|
||||
return end.step(children)
|
||||
children = [a.bind(t) for t in titles]
|
||||
return workflow.continuation(end.bind(children))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def a(title: str) -> str:
|
||||
return "{} processed".format(title)
|
||||
return f"{title} processed"
|
||||
|
||||
|
||||
@workflow.step
|
||||
def end(results: List[str]) -> str:
|
||||
return "\n".join(results)
|
||||
@ray.remote
|
||||
def end(results: "List[ray.ObjectRef[str]]") -> str:
|
||||
return "\n".join(ray.get(results))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
start.step().run()
|
||||
workflow.create(start.bind()).run()
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import ray
|
||||
from ray import workflow
|
||||
import requests
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def compute_large_fib(M: int, n: int = 1, fib: int = 1):
|
||||
next_fib = requests.post(
|
||||
"https://nemo.api.stdlib.com/fibonacci@0.0.1/", data={"nth": n}
|
||||
|
@ -10,9 +11,9 @@ def compute_large_fib(M: int, n: int = 1, fib: int = 1):
|
|||
if next_fib > M:
|
||||
return fib
|
||||
else:
|
||||
return compute_large_fib.step(M, n + 1, next_fib)
|
||||
return workflow.continuation(compute_large_fib.bind(M, n + 1, next_fib))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
workflow.init()
|
||||
assert compute_large_fib.step(100).run() == 89
|
||||
assert workflow.create(compute_large_fib.bind(100)).run() == 89
|
||||
|
|
|
@ -64,16 +64,16 @@ def test_basic_workflows(workflow_start_regular_shared):
|
|||
z = append2.bind(x)
|
||||
return workflow.continuation(join.bind(y, z))
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def mul(a, b):
|
||||
return a * b
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def factorial(n):
|
||||
if n == 1:
|
||||
return 1
|
||||
else:
|
||||
return mul.step(n, factorial.step(n - 1))
|
||||
return workflow.continuation(mul.bind(n, factorial.bind(n - 1)))
|
||||
|
||||
# This test also shows different "style" of running workflows.
|
||||
assert (
|
||||
|
@ -92,7 +92,7 @@ def test_basic_workflows(workflow_start_regular_shared):
|
|||
wf = fork_join.bind()
|
||||
assert workflow.create(wf).run() == "join([source1][append1], [source1][append2])"
|
||||
|
||||
assert factorial.step(10).run() == 3628800
|
||||
assert workflow.create(factorial.bind(10)).run() == 3628800
|
||||
|
||||
|
||||
def test_async_execution(workflow_start_regular_shared):
|
||||
|
@ -269,13 +269,13 @@ def test_step_failure_decorator(workflow_start_regular_shared, tmp_path):
|
|||
|
||||
|
||||
def test_nested_catch_exception(workflow_start_regular_shared, tmp_path):
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def f2():
|
||||
return 10
|
||||
|
||||
@workflow.step
|
||||
def f1():
|
||||
return f2.step()
|
||||
return workflow.continuation(f2.bind())
|
||||
|
||||
assert (10, None) == f1.options(catch_exceptions=True).step().run()
|
||||
|
||||
|
|
|
@ -210,15 +210,15 @@ def test_get_named_step_output_error(workflow_start_regular, tmp_path):
|
|||
|
||||
|
||||
def test_get_named_step_default(workflow_start_regular, tmp_path):
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def factorial(n, r=1):
|
||||
if n == 1:
|
||||
return r
|
||||
return factorial.step(n - 1, r * n)
|
||||
return workflow.continuation(factorial.bind(n - 1, r * n))
|
||||
|
||||
import math
|
||||
|
||||
assert math.factorial(5) == factorial.step(5).run("factorial")
|
||||
assert math.factorial(5) == workflow.create(factorial.bind(5)).run("factorial")
|
||||
for i in range(5):
|
||||
step_name = (
|
||||
"test_basic_workflows_2.test_get_named_step_default.locals.factorial"
|
||||
|
|
|
@ -3,30 +3,31 @@ import pytest
|
|||
from ray.tests.conftest import * # noqa
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from ray import workflow
|
||||
|
||||
|
||||
def test_simple_large_intermediate(workflow_start_regular_shared):
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def large_input():
|
||||
return np.arange(2 ** 24)
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def average(x):
|
||||
return np.mean(x)
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def simple_large_intermediate():
|
||||
x = large_input.step()
|
||||
y = identity.step(x)
|
||||
return average.step(y)
|
||||
x = large_input.bind()
|
||||
y = identity.bind(x)
|
||||
return workflow.continuation(average.bind(y))
|
||||
|
||||
start = time.time()
|
||||
outputs = simple_large_intermediate.step().run()
|
||||
outputs = workflow.create(simple_large_intermediate.bind()).run()
|
||||
print(f"duration = {time.time() - start}")
|
||||
assert np.isclose(outputs, 8388607.5)
|
||||
|
||||
|
|
|
@ -12,30 +12,34 @@ from ray import workflow
|
|||
|
||||
|
||||
def test_objectref_inputs(workflow_start_regular_shared):
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def nested_workflow(n: int):
|
||||
if n <= 0:
|
||||
return "nested"
|
||||
else:
|
||||
return nested_workflow.step(n - 1)
|
||||
return workflow.continuation(nested_workflow.bind(n - 1))
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def deref_check(u: int, x: str, y: List[str], z: List[Dict[str, str]]):
|
||||
try:
|
||||
return (
|
||||
u == 42
|
||||
and x == "nested"
|
||||
and y[0] == "nested"
|
||||
and z[0]["output"] == "nested"
|
||||
and isinstance(y[0], ray.ObjectRef)
|
||||
and ray.get(y) == ["nested"]
|
||||
and isinstance(z[0]["output"], ray.ObjectRef)
|
||||
and ray.get(z[0]["output"]) == "nested"
|
||||
), f"{u}, {x}, {y}, {z}"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
output, s = deref_check.step(
|
||||
output, s = workflow.create(
|
||||
deref_check.bind(
|
||||
ray.put(42),
|
||||
nested_workflow.step(10),
|
||||
[nested_workflow.step(9)],
|
||||
[{"output": nested_workflow.step(7)}],
|
||||
nested_workflow.bind(10),
|
||||
[nested_workflow.bind(9)],
|
||||
[{"output": nested_workflow.bind(7)}],
|
||||
)
|
||||
).run()
|
||||
assert output is True, s
|
||||
|
||||
|
@ -45,29 +49,38 @@ def test_objectref_outputs(workflow_start_regular_shared):
|
|||
def nested_ref():
|
||||
return ray.put(42)
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def nested_ref_workflow():
|
||||
return nested_ref.remote()
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def return_objectrefs() -> List[ObjectRef]:
|
||||
return [ray.put(x) for x in range(5)]
|
||||
|
||||
single = nested_ref_workflow.step().run()
|
||||
single = workflow.create(nested_ref_workflow.bind()).run()
|
||||
assert ray.get(ray.get(single)) == 42
|
||||
|
||||
multi = return_objectrefs.step().run()
|
||||
multi = workflow.create(return_objectrefs.bind()).run()
|
||||
assert ray.get(multi) == list(range(5))
|
||||
|
||||
|
||||
def test_object_deref(workflow_start_regular_shared):
|
||||
def test_object_input_dedup(workflow_start_regular_shared):
|
||||
@workflow.step
|
||||
def empty_list():
|
||||
return [1]
|
||||
|
||||
@workflow.step
|
||||
def deref_shared(x, y):
|
||||
# x and y should share the same variable.
|
||||
x.append(2)
|
||||
return y == [1, 2]
|
||||
|
||||
@workflow.step
|
||||
x = empty_list.step()
|
||||
assert deref_shared.step(x, x).run()
|
||||
|
||||
|
||||
def test_object_deref(workflow_start_regular_shared):
|
||||
@ray.remote
|
||||
def empty_list():
|
||||
return [1]
|
||||
|
||||
|
@ -77,22 +90,18 @@ def test_object_deref(workflow_start_regular_shared):
|
|||
|
||||
@ray.remote
|
||||
def return_workflow():
|
||||
return empty_list.step()
|
||||
return workflow.create(empty_list.bind())
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def return_data() -> ray.ObjectRef:
|
||||
obj = ray.put(np.ones(4096))
|
||||
return obj
|
||||
return ray.put(np.ones(4096))
|
||||
|
||||
@workflow.step
|
||||
def receive_data(data: np.ndarray):
|
||||
return data
|
||||
|
||||
x = empty_list.step()
|
||||
assert deref_shared.step(x, x).run()
|
||||
@ray.remote
|
||||
def receive_data(data: "ray.ObjectRef[np.ndarray]"):
|
||||
return ray.get(data)
|
||||
|
||||
# test we are forbidden from directly passing workflow to Ray.
|
||||
x = empty_list.step()
|
||||
x = workflow.create(empty_list.bind())
|
||||
with pytest.raises(ValueError):
|
||||
ray.put(x)
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -101,8 +110,8 @@ def test_object_deref(workflow_start_regular_shared):
|
|||
ray.get(return_workflow.remote())
|
||||
|
||||
# test return object ref
|
||||
obj = return_data.step()
|
||||
arr: np.ndarray = receive_data.step(obj).run()
|
||||
obj = return_data.bind()
|
||||
arr: np.ndarray = workflow.create(receive_data.bind(obj)).run()
|
||||
assert np.array_equal(arr, np.ones(4096))
|
||||
|
||||
|
||||
|
|
|
@ -12,45 +12,45 @@ from ray.workflow.storage.filesystem import FilesystemStorageImpl
|
|||
from ray.workflow.tests.utils import _alter_storage
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def pass_1(x: str, y: str):
|
||||
return sha1((x + y + "1").encode()).hexdigest()
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def pass_2(x: str, y: str):
|
||||
if sha1((x + y + "_2").encode()).hexdigest() > x:
|
||||
return sha1((x + y + "2").encode()).hexdigest()
|
||||
return pass_1.step(x, y)
|
||||
return workflow.continuation(pass_1.bind(x, y))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def pass_3(x: str, y: str):
|
||||
if sha1((x + y + "_3").encode()).hexdigest() > x:
|
||||
return sha1((x + y + "3").encode()).hexdigest()
|
||||
return pass_2.step(x, y)
|
||||
return workflow.continuation(pass_2.bind(x, y))
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def merge(x0: str, x1: str, x2: str) -> str:
|
||||
return sha1((x0 + x1 + x2).encode()).hexdigest()
|
||||
|
||||
|
||||
@workflow.step
|
||||
@ray.remote
|
||||
def scan(x0: str, x1: str, x2: str):
|
||||
x0 = sha1((x0 + x2).encode()).hexdigest()
|
||||
x1 = sha1((x1 + x2).encode()).hexdigest()
|
||||
x2 = sha1((x0 + x1 + x2).encode()).hexdigest()
|
||||
y0, y1, y2 = pass_1.step(x0, x1), pass_2.step(x1, x2), pass_3.step(x2, x0)
|
||||
return merge.step(y0, y1, y2)
|
||||
y0, y1, y2 = pass_1.bind(x0, x1), pass_2.bind(x1, x2), pass_3.bind(x2, x0)
|
||||
return workflow.continuation(merge.bind(y0, y1, y2))
|
||||
|
||||
|
||||
def construct_workflow(length: int):
|
||||
results = ["a", "b"]
|
||||
for i in range(length):
|
||||
x0, x1, x2 = results[-2], results[-1], str(i)
|
||||
results.append(scan.step(x0, x1, x2))
|
||||
return results[-1]
|
||||
results.append(scan.bind(x0, x1, x2))
|
||||
return workflow.create(results[-1])
|
||||
|
||||
|
||||
def _locate_initial_commit(debug_store: DebugStorage) -> int:
|
||||
|
|
Loading…
Add table
Reference in a new issue