mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[workflow] Ensure that DAGs are dereferenced like ObjectRefs in Ray tasks (#23320)
This commit is contained in:
parent
9b38b6de47
commit
65cc877ad8
5 changed files with 61 additions and 12 deletions
|
@ -103,12 +103,19 @@ class WorkflowStaticRef:
|
|||
step_id: StepID
|
||||
# The ObjectRef of the output.
|
||||
ref: ObjectRef
|
||||
# This tag indicates we should resolve the workflow like an ObjectRef, when
|
||||
# included in the arguments of another workflow.
|
||||
_resolve_like_object_ref_in_args: bool = False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.step_id + self.ref.hex())
|
||||
|
||||
def __reduce__(self):
|
||||
return WorkflowStaticRef, (self.step_id, _RefBypass(self.ref))
|
||||
return WorkflowStaticRef, (
|
||||
self.step_id,
|
||||
_RefBypass(self.ref),
|
||||
self._resolve_like_object_ref_in_args,
|
||||
)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
|
|
|
@ -20,8 +20,14 @@ def transform_ray_dag_to_workflow(dag_node: DAGNode, input_context: DAGInputData
|
|||
|
||||
def _node_visitor(node: Any) -> Any:
|
||||
if isinstance(node, FunctionNode):
|
||||
workflow_step = workflow.step(node._body).options(**node._bound_options)
|
||||
return workflow_step.step(*node._bound_args, **node._bound_kwargs)
|
||||
# "_resolve_like_object_ref_in_args" indicates we should resolve the
|
||||
# workflow like an ObjectRef, when included in the arguments of
|
||||
# another workflow.
|
||||
workflow_step = workflow.step(node._body).options(
|
||||
**node._bound_options, _resolve_like_object_ref_in_args=True
|
||||
)
|
||||
wf = workflow_step.step(*node._bound_args, **node._bound_kwargs)
|
||||
return wf
|
||||
if isinstance(node, InputAtrributeNode):
|
||||
return node._execute_impl() # get data from input node
|
||||
if isinstance(node, InputNode):
|
||||
|
|
|
@ -35,10 +35,10 @@ class WorkflowNotResumableError(Exception):
|
|||
|
||||
@WorkflowStepFunction
|
||||
def _recover_workflow_step(
|
||||
args: List[Any],
|
||||
kwargs: Dict[str, Any],
|
||||
input_workflows: List[Any],
|
||||
input_workflow_refs: List[WorkflowRef],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""A workflow step that recovers the output of an unfinished step.
|
||||
|
||||
|
@ -151,8 +151,13 @@ def _construct_resume_workflow_from_step(
|
|||
workflow_refs = list(map(WorkflowRef, result.workflow_refs))
|
||||
|
||||
args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
|
||||
# Note: we must uppack args and kwargs, so the refs in the args/kwargs can get
|
||||
# resolved consistently like in Ray.
|
||||
recovery_workflow: Workflow = _recover_workflow_step.step(
|
||||
args, kwargs, input_workflows, workflow_refs
|
||||
input_workflows,
|
||||
workflow_refs,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
recovery_workflow._step_id = step_id
|
||||
# override step_options
|
||||
|
|
|
@ -149,10 +149,17 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
for w in inputs.workflows:
|
||||
static_ref = w.ref
|
||||
if static_ref is None:
|
||||
extra_options = w.data.step_options.ray_options
|
||||
# The input workflow is not a reference to an executed
|
||||
# workflow .
|
||||
# workflow.
|
||||
output = execute_workflow(w).persisted_output
|
||||
static_ref = WorkflowStaticRef(step_id=w.step_id, ref=output)
|
||||
static_ref = WorkflowStaticRef(
|
||||
step_id=w.step_id,
|
||||
ref=output,
|
||||
_resolve_like_object_ref_in_args=extra_options.get(
|
||||
"_resolve_like_object_ref_in_args", False
|
||||
),
|
||||
)
|
||||
workflow_outputs.append(static_ref)
|
||||
|
||||
baked_inputs = _BakedWorkflowInputs(
|
||||
|
@ -187,9 +194,10 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
# tasks.
|
||||
executor = _workflow_wait_executor_remote.options(num_cpus=0).remote
|
||||
else:
|
||||
executor = _workflow_step_executor_remote.options(
|
||||
**step_options.ray_options
|
||||
).remote
|
||||
ray_options = step_options.ray_options.copy()
|
||||
# cleanup the "_resolve_like_object_ref_in_args" option, it is not for Ray.
|
||||
ray_options.pop("_resolve_like_object_ref_in_args", None)
|
||||
executor = _workflow_step_executor_remote.options(**ray_options).remote
|
||||
|
||||
# Stage 3: execution
|
||||
persisted_output, volatile_output = executor(
|
||||
|
@ -627,7 +635,10 @@ class _BakedWorkflowInputs:
|
|||
"""
|
||||
objects_mapping = []
|
||||
for obj_ref in self.workflow_outputs:
|
||||
obj, ref = _resolve_object_ref(obj_ref.ref)
|
||||
if obj_ref._resolve_like_object_ref_in_args:
|
||||
obj = obj_ref.ref
|
||||
else:
|
||||
obj, ref = _resolve_object_ref(obj_ref.ref)
|
||||
objects_mapping.append(obj)
|
||||
|
||||
workflow_ref_mapping = _resolve_dynamic_workflow_refs(self.workflow_refs)
|
||||
|
|
|
@ -137,6 +137,26 @@ def test_dereference_object_refs(workflow_start_regular_shared):
|
|||
ray.get(dag.execute())
|
||||
|
||||
|
||||
def test_dereference_dags(workflow_start_regular_shared):
|
||||
"""Ensure that DAGs are dereferenced like ObjectRefs in ray tasks."""
|
||||
|
||||
@ray.remote
|
||||
def g(x, y):
|
||||
assert x == 314
|
||||
assert isinstance(y[0], ray.ObjectRef)
|
||||
assert ray.get(y) == [2022]
|
||||
|
||||
@ray.remote
|
||||
def h(x):
|
||||
return x
|
||||
|
||||
dag = g.bind(x=h.bind(314), y=[h.bind(2022)])
|
||||
|
||||
# Run with workflow and normal Ray engine.
|
||||
workflow.create(dag).run()
|
||||
ray.get(dag.execute())
|
||||
|
||||
|
||||
def test_workflow_continuation(workflow_start_regular_shared):
|
||||
"""Test unified behavior of returning continuation inside
|
||||
workflow and default Ray execution engine."""
|
||||
|
|
Loading…
Add table
Reference in a new issue