diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 49806967b..41d4923e5 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -103,12 +103,19 @@ class WorkflowStaticRef: step_id: StepID # The ObjectRef of the output. ref: ObjectRef + # This tag indicates we should resolve the workflow like an ObjectRef, when + # included in the arguments of another workflow. + _resolve_like_object_ref_in_args: bool = False def __hash__(self): return hash(self.step_id + self.ref.hex()) def __reduce__(self): - return WorkflowStaticRef, (self.step_id, _RefBypass(self.ref)) + return WorkflowStaticRef, ( + self.step_id, + _RefBypass(self.ref), + self._resolve_like_object_ref_in_args, + ) @PublicAPI(stability="beta") diff --git a/python/ray/workflow/dag_to_workflow.py b/python/ray/workflow/dag_to_workflow.py index 0efdc7732..4fecf8eb8 100644 --- a/python/ray/workflow/dag_to_workflow.py +++ b/python/ray/workflow/dag_to_workflow.py @@ -20,8 +20,14 @@ def transform_ray_dag_to_workflow(dag_node: DAGNode, input_context: DAGInputData def _node_visitor(node: Any) -> Any: if isinstance(node, FunctionNode): - workflow_step = workflow.step(node._body).options(**node._bound_options) - return workflow_step.step(*node._bound_args, **node._bound_kwargs) + # "_resolve_like_object_ref_in_args" indicates we should resolve the + # workflow like an ObjectRef, when included in the arguments of + # another workflow. 
+ workflow_step = workflow.step(node._body).options( + **node._bound_options, _resolve_like_object_ref_in_args=True + ) + wf = workflow_step.step(*node._bound_args, **node._bound_kwargs) + return wf if isinstance(node, InputAtrributeNode): return node._execute_impl() # get data from input node if isinstance(node, InputNode): diff --git a/python/ray/workflow/recovery.py b/python/ray/workflow/recovery.py index 8fcdee162..0887a6d98 100644 --- a/python/ray/workflow/recovery.py +++ b/python/ray/workflow/recovery.py @@ -35,10 +35,10 @@ class WorkflowNotResumableError(Exception): @WorkflowStepFunction def _recover_workflow_step( - args: List[Any], - kwargs: Dict[str, Any], input_workflows: List[Any], input_workflow_refs: List[WorkflowRef], + *args, + **kwargs, ): """A workflow step that recovers the output of an unfinished step. @@ -151,8 +151,13 @@ def _construct_resume_workflow_from_step( workflow_refs = list(map(WorkflowRef, result.workflow_refs)) args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs) + # Note: we must unpack args and kwargs, so the refs in the args/kwargs can get + # resolved consistently like in Ray. recovery_workflow: Workflow = _recover_workflow_step.step( - args, kwargs, input_workflows, workflow_refs + input_workflows, + workflow_refs, + *args, + **kwargs, ) recovery_workflow._step_id = step_id # override step_options diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index fe21d6f75..e82ce2250 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -149,10 +149,17 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": for w in inputs.workflows: static_ref = w.ref if static_ref is None: + extra_options = w.data.step_options.ray_options # The input workflow is not a reference to an executed - # workflow . + # workflow.
output = execute_workflow(w).persisted_output - static_ref = WorkflowStaticRef(step_id=w.step_id, ref=output) + static_ref = WorkflowStaticRef( + step_id=w.step_id, + ref=output, + _resolve_like_object_ref_in_args=extra_options.get( + "_resolve_like_object_ref_in_args", False + ), + ) workflow_outputs.append(static_ref) baked_inputs = _BakedWorkflowInputs( @@ -187,9 +194,10 @@ def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": # tasks. executor = _workflow_wait_executor_remote.options(num_cpus=0).remote else: - executor = _workflow_step_executor_remote.options( - **step_options.ray_options - ).remote + ray_options = step_options.ray_options.copy() + # clean up the "_resolve_like_object_ref_in_args" option, it is not for Ray. + ray_options.pop("_resolve_like_object_ref_in_args", None) + executor = _workflow_step_executor_remote.options(**ray_options).remote # Stage 3: execution persisted_output, volatile_output = executor( @@ -627,7 +635,10 @@ class _BakedWorkflowInputs: """ objects_mapping = [] for obj_ref in self.workflow_outputs: - obj, ref = _resolve_object_ref(obj_ref.ref) + if obj_ref._resolve_like_object_ref_in_args: + obj = obj_ref.ref + else: + obj, ref = _resolve_object_ref(obj_ref.ref) objects_mapping.append(obj) workflow_ref_mapping = _resolve_dynamic_workflow_refs(self.workflow_refs) diff --git a/python/ray/workflow/tests/test_dag_to_workflow.py b/python/ray/workflow/tests/test_dag_to_workflow.py index d48644b84..77c835d51 100644 --- a/python/ray/workflow/tests/test_dag_to_workflow.py +++ b/python/ray/workflow/tests/test_dag_to_workflow.py @@ -137,6 +137,26 @@ def test_dereference_object_refs(workflow_start_regular_shared): ray.get(dag.execute()) +def test_dereference_dags(workflow_start_regular_shared): + """Ensure that DAGs are dereferenced like ObjectRefs in ray tasks.""" + + @ray.remote + def g(x, y): + assert x == 314 + assert isinstance(y[0], ray.ObjectRef) + assert ray.get(y) == [2022] + + @ray.remote + def h(x): + return x
+ + dag = g.bind(x=h.bind(314), y=[h.bind(2022)]) + + # Run with workflow and normal Ray engine. + workflow.create(dag).run() + ray.get(dag.execute()) + + def test_workflow_continuation(workflow_start_regular_shared): """Test unified behavior of returning continuation inside workflow and default Ray execution engine."""