mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Workflow] Optimize out tail recursion in python (#22794)
* add test * warn when inplace sub-workflows may use different resources
This commit is contained in:
parent
60a3340387
commit
d67c34256b
2 changed files with 83 additions and 17 deletions
|
@ -1,6 +1,7 @@
|
|||
import time
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
import logging
|
||||
from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union
|
||||
import ray
|
||||
|
@ -106,15 +107,8 @@ def _resolve_dynamic_workflow_refs(workflow_refs: "List[WorkflowRef]"):
|
|||
return workflow_ref_mapping
|
||||
|
||||
|
||||
def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||
"""Execute workflow.
|
||||
|
||||
Args:
|
||||
workflow: The workflow to be executed.
|
||||
|
||||
Returns:
|
||||
An object ref that represent the result.
|
||||
"""
|
||||
def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||
"""Internal function of workflow execution."""
|
||||
if workflow.executed:
|
||||
return workflow.result
|
||||
|
||||
|
@ -183,7 +177,9 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
if step_options.step_type == StepType.WAIT:
|
||||
executor = _workflow_wait_executor
|
||||
else:
|
||||
executor = _workflow_step_executor
|
||||
# Tell the executor that we are running inplace. This enables
|
||||
# tail-recursion optimization.
|
||||
executor = functools.partial(_workflow_step_executor, inplace=True)
|
||||
else:
|
||||
if step_options.step_type == StepType.WAIT:
|
||||
# This is very important to set "num_cpus=0" to
|
||||
|
@ -205,11 +201,6 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
)
|
||||
|
||||
# Stage 4: post processing outputs
|
||||
if not isinstance(persisted_output, WorkflowOutputType):
|
||||
persisted_output = ray.put(persisted_output)
|
||||
if not isinstance(persisted_output, WorkflowOutputType):
|
||||
volatile_output = ray.put(volatile_output)
|
||||
|
||||
if step_options.step_type != StepType.READONLY_ACTOR_METHOD:
|
||||
if not step_options.allow_inplace:
|
||||
# TODO: [Possible flaky bug] Here the RUNNING state may
|
||||
|
@ -225,6 +216,45 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
return result
|
||||
|
||||
|
||||
@dataclass
class InplaceReturnedWorkflow:
    """Hold information about a workflow returned from an inplace step.

    `execute_workflow` unwraps instances of this class in a loop
    (tail-recursion optimization) instead of recursing into them.
    """

    # The sub-workflow returned by the inplace step; execution continues here.
    workflow: Workflow
    # Keyword arguments used to fork the workflow step context before
    # executing the returned workflow (e.g. "outer_most_step_id").
    context: Dict
|
||||
|
||||
|
||||
def execute_workflow(workflow: Workflow) -> "WorkflowExecutionResult":
    """Execute workflow.

    This function also performs tail-recursion optimization for inplace
    workflow steps: when an inplace step returns another workflow, we
    continue it in this loop instead of recursing, so the Python call
    stack stays bounded.

    Args:
        workflow: The workflow to be executed.

    Returns:
        An object ref that represents the result.
    """
    # Tail recursion optimization: keep unwrapping InplaceReturnedWorkflow
    # results, forking the step context each iteration.
    context = {}
    while True:
        with workflow_context.fork_workflow_step_context(**context):
            result = _execute_workflow(workflow)
        if not isinstance(result.persisted_output, InplaceReturnedWorkflow):
            break
        workflow = result.persisted_output.workflow
        context = result.persisted_output.context

    # Convert the outputs into ObjectRefs.
    if not isinstance(result.persisted_output, WorkflowOutputType):
        result.persisted_output = ray.put(result.persisted_output)
    # FIX: guard on volatile_output here. The original re-checked
    # persisted_output, which is always a WorkflowOutputType after the
    # ray.put above, so volatile_output was never wrapped.
    if not isinstance(result.volatile_output, WorkflowOutputType):
        result.volatile_output = ray.put(result.volatile_output)
    return result
|
||||
|
||||
|
||||
async def _write_step_inputs(
|
||||
wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData
|
||||
) -> None:
|
||||
|
@ -383,6 +413,7 @@ def _workflow_step_executor(
|
|||
step_id: "StepID",
|
||||
baked_inputs: "_BakedWorkflowInputs",
|
||||
runtime_options: "WorkflowStepRuntimeOptions",
|
||||
inplace: bool = False,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""Executor function for workflow step.
|
||||
|
||||
|
@ -392,6 +423,7 @@ def _workflow_step_executor(
|
|||
baked_inputs: The processed inputs for the step.
|
||||
context: Workflow step context. Used to access correct storage etc.
|
||||
runtime_options: Parameters for workflow step execution.
|
||||
inplace: Execute the workflow inplace.
|
||||
|
||||
Returns:
|
||||
Workflow step output.
|
||||
|
@ -440,7 +472,9 @@ def _workflow_step_executor(
|
|||
exception=None,
|
||||
)
|
||||
if isinstance(persisted_output, Workflow):
|
||||
sub_workflow = persisted_output
|
||||
outer_most_step_id = context.outer_most_step_id
|
||||
assert volatile_output is None
|
||||
if step_type == StepType.FUNCTION:
|
||||
# Passing down outer most step so inner nested steps would
|
||||
# access the same outer most step.
|
||||
|
@ -450,12 +484,30 @@ def _workflow_step_executor(
|
|||
# current step is the outer most step for the inner nested
|
||||
# workflow steps.
|
||||
outer_most_step_id = workflow_context.get_current_step_id()
|
||||
assert volatile_output is None
|
||||
if inplace:
|
||||
_step_options = sub_workflow.data.step_options
|
||||
if (
|
||||
_step_options.step_type != StepType.WAIT
|
||||
and runtime_options.ray_options != _step_options.ray_options
|
||||
):
|
||||
logger.warning(
|
||||
f"Workflow step '{sub_workflow.step_id}' uses "
|
||||
f"a Ray option different to its caller step '{step_id}' "
|
||||
f"and will be executed inplace. Ray assumes it still "
|
||||
f"consumes the same resource as the caller. This may result "
|
||||
f"in oversubscribing resources."
|
||||
)
|
||||
return (
|
||||
InplaceReturnedWorkflow(
|
||||
sub_workflow, {"outer_most_step_id": outer_most_step_id}
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Execute sub-workflow. Pass down "outer_most_step_id".
|
||||
with workflow_context.fork_workflow_step_context(
|
||||
outer_most_step_id=outer_most_step_id
|
||||
):
|
||||
result = execute_workflow(persisted_output)
|
||||
result = execute_workflow(sub_workflow)
|
||||
# When virtual actor returns a workflow in the method,
|
||||
# the volatile_output and persisted_output will be put together
|
||||
persisted_output = result.persisted_output
|
||||
|
|
|
@ -63,6 +63,20 @@ def test_inplace_workflows(workflow_start_regular_shared):
|
|||
assert exp_remote.step(k, n).run() == k * 2 ** n
|
||||
|
||||
|
||||
def test_tail_recursion_optimization(workflow_start_regular_shared):
    """Inplace tail calls must be unrolled into a loop, not nested frames."""

    @workflow.step
    def tail_recursion(n):
        import inspect

        # Without the optimization, each inplace sub-step would add
        # stack frames, and 30 levels would blow past this bound.
        assert len(inspect.stack(0)) < 20
        if n > 0:
            return tail_recursion.options(allow_inplace=True).step(n - 1)
        return "ok"

    tail_recursion.step(30).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue