[workflow] Ensure that DAGs are dereferenced like ObjectRefs in Ray tasks (#23320)

This commit is contained in:
Siyuan (Ryans) Zhuang 2022-03-18 17:02:15 -07:00 committed by GitHub
parent 9b38b6de47
commit 65cc877ad8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 61 additions and 12 deletions

View file

@ -103,12 +103,19 @@ class WorkflowStaticRef:
step_id: StepID
# The ObjectRef of the output.
ref: ObjectRef
# This tag indicates we should resolve the workflow like an ObjectRef, when
# included in the arguments of another workflow.
_resolve_like_object_ref_in_args: bool = False
def __hash__(self):
# Hash the concatenation of the step ID and the hex form of the output ref.
return hash(self.step_id + self.ref.hex())
def __reduce__(self):
return WorkflowStaticRef, (self.step_id, _RefBypass(self.ref))
return WorkflowStaticRef, (
self.step_id,
_RefBypass(self.ref),
self._resolve_like_object_ref_in_args,
)
@PublicAPI(stability="beta")

View file

@ -20,8 +20,14 @@ def transform_ray_dag_to_workflow(dag_node: DAGNode, input_context: DAGInputData
def _node_visitor(node: Any) -> Any:
if isinstance(node, FunctionNode):
workflow_step = workflow.step(node._body).options(**node._bound_options)
return workflow_step.step(*node._bound_args, **node._bound_kwargs)
# "_resolve_like_object_ref_in_args" indicates we should resolve the
# workflow like an ObjectRef, when included in the arguments of
# another workflow.
workflow_step = workflow.step(node._body).options(
**node._bound_options, _resolve_like_object_ref_in_args=True
)
wf = workflow_step.step(*node._bound_args, **node._bound_kwargs)
return wf
if isinstance(node, InputAtrributeNode):
return node._execute_impl() # get data from input node
if isinstance(node, InputNode):

View file

@ -35,10 +35,10 @@ class WorkflowNotResumableError(Exception):
@WorkflowStepFunction
def _recover_workflow_step(
args: List[Any],
kwargs: Dict[str, Any],
input_workflows: List[Any],
input_workflow_refs: List[WorkflowRef],
*args,
**kwargs,
):
"""A workflow step that recovers the output of an unfinished step.
@ -151,8 +151,13 @@ def _construct_resume_workflow_from_step(
workflow_refs = list(map(WorkflowRef, result.workflow_refs))
args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
# Note: we must unpack args and kwargs, so the refs in the args/kwargs can get
# resolved consistently like in Ray.
recovery_workflow: Workflow = _recover_workflow_step.step(
args, kwargs, input_workflows, workflow_refs
input_workflows,
workflow_refs,
*args,
**kwargs,
)
recovery_workflow._step_id = step_id
# override step_options

View file

@ -149,10 +149,17 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
for w in inputs.workflows:
static_ref = w.ref
if static_ref is None:
extra_options = w.data.step_options.ray_options
# The input workflow is not a reference to an executed
# workflow .
# workflow.
output = execute_workflow(w).persisted_output
static_ref = WorkflowStaticRef(step_id=w.step_id, ref=output)
static_ref = WorkflowStaticRef(
step_id=w.step_id,
ref=output,
_resolve_like_object_ref_in_args=extra_options.get(
"_resolve_like_object_ref_in_args", False
),
)
workflow_outputs.append(static_ref)
baked_inputs = _BakedWorkflowInputs(
@ -187,9 +194,10 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
# tasks.
executor = _workflow_wait_executor_remote.options(num_cpus=0).remote
else:
executor = _workflow_step_executor_remote.options(
**step_options.ray_options
).remote
ray_options = step_options.ray_options.copy()
# Remove the "_resolve_like_object_ref_in_args" option; it is not a Ray option.
ray_options.pop("_resolve_like_object_ref_in_args", None)
executor = _workflow_step_executor_remote.options(**ray_options).remote
# Stage 3: execution
persisted_output, volatile_output = executor(
@ -627,7 +635,10 @@ class _BakedWorkflowInputs:
"""
objects_mapping = []
for obj_ref in self.workflow_outputs:
obj, ref = _resolve_object_ref(obj_ref.ref)
if obj_ref._resolve_like_object_ref_in_args:
obj = obj_ref.ref
else:
obj, ref = _resolve_object_ref(obj_ref.ref)
objects_mapping.append(obj)
workflow_ref_mapping = _resolve_dynamic_workflow_refs(self.workflow_refs)

View file

@ -137,6 +137,26 @@ def test_dereference_object_refs(workflow_start_regular_shared):
ray.get(dag.execute())
def test_dereference_dags(workflow_start_regular_shared):
"""Ensure that DAGs are dereferenced like ObjectRefs in ray tasks."""
# g asserts on the form in which each of its arguments arrives.
@ray.remote
def g(x, y):
# A DAG node bound directly as an argument must arrive as its value.
assert x == 314
# A DAG node nested inside a list must arrive as an ObjectRef, which
# still resolves to the expected value when fetched explicitly.
assert isinstance(y[0], ray.ObjectRef)
assert ray.get(y) == [2022]
# h returns its input unchanged; it supplies the upstream DAG nodes.
@ray.remote
def h(x):
return x
# x is a top-level DAG node; y wraps a DAG node inside a list.
dag = g.bind(x=h.bind(314), y=[h.bind(2022)])
# Run with workflow and normal Ray engine.
workflow.create(dag).run()
ray.get(dag.execute())
def test_workflow_continuation(workflow_start_regular_shared):
"""Test unified behavior of returning continuation inside
workflow and default Ray execution engine."""