mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Workflow] Optimize out tail recursion in python (#22794)
* add test * warning when inplace subworkflows may use different resources
This commit is contained in:
parent
60a3340387
commit
d67c34256b
2 changed files with 83 additions and 17 deletions
|
@ -1,6 +1,7 @@
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union
|
from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union
|
||||||
import ray
|
import ray
|
||||||
|
@ -106,15 +107,8 @@ def _resolve_dynamic_workflow_refs(workflow_refs: "List[WorkflowRef]"):
|
||||||
return workflow_ref_mapping
|
return workflow_ref_mapping
|
||||||
|
|
||||||
|
|
||||||
def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||||
"""Execute workflow.
|
"""Internal function of workflow execution."""
|
||||||
|
|
||||||
Args:
|
|
||||||
workflow: The workflow to be executed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An object ref that represent the result.
|
|
||||||
"""
|
|
||||||
if workflow.executed:
|
if workflow.executed:
|
||||||
return workflow.result
|
return workflow.result
|
||||||
|
|
||||||
|
@ -183,7 +177,9 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||||
if step_options.step_type == StepType.WAIT:
|
if step_options.step_type == StepType.WAIT:
|
||||||
executor = _workflow_wait_executor
|
executor = _workflow_wait_executor
|
||||||
else:
|
else:
|
||||||
executor = _workflow_step_executor
|
# Tell the executor that we are running inplace. This enables
|
||||||
|
# tail-recursion optimization.
|
||||||
|
executor = functools.partial(_workflow_step_executor, inplace=True)
|
||||||
else:
|
else:
|
||||||
if step_options.step_type == StepType.WAIT:
|
if step_options.step_type == StepType.WAIT:
|
||||||
# This is very important to set "num_cpus=0" to
|
# This is very important to set "num_cpus=0" to
|
||||||
|
@ -205,11 +201,6 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stage 4: post processing outputs
|
# Stage 4: post processing outputs
|
||||||
if not isinstance(persisted_output, WorkflowOutputType):
|
|
||||||
persisted_output = ray.put(persisted_output)
|
|
||||||
if not isinstance(persisted_output, WorkflowOutputType):
|
|
||||||
volatile_output = ray.put(volatile_output)
|
|
||||||
|
|
||||||
if step_options.step_type != StepType.READONLY_ACTOR_METHOD:
|
if step_options.step_type != StepType.READONLY_ACTOR_METHOD:
|
||||||
if not step_options.allow_inplace:
|
if not step_options.allow_inplace:
|
||||||
# TODO: [Possible flaky bug] Here the RUNNING state may
|
# TODO: [Possible flaky bug] Here the RUNNING state may
|
||||||
|
@ -225,6 +216,45 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InplaceReturnedWorkflow:
    """Carrier for a sub-workflow handed back by a step that ran inplace.

    Instead of executing the returned workflow itself, an inplace step
    wraps it in this object together with the keyword arguments that the
    caller should use to fork the workflow step context before running it.
    """

    # The workflow that the inplace step returned.
    workflow: Workflow
    # Keyword arguments for forking the step context of the returned
    # workflow (e.g. the "outer_most_step_id" entry).
    context: Dict
|
||||||
|
|
||||||
|
|
||||||
|
def execute_workflow(workflow: Workflow) -> "WorkflowExecutionResult":
    """Execute workflow.

    This function also performs tail-recursion optimization for inplace
    workflow steps: when the executed step hands back an
    ``InplaceReturnedWorkflow``, the wrapped workflow is executed in the
    next loop iteration (under the context it carries) rather than through
    a nested call.

    Args:
        workflow: The workflow to be executed.

    Returns:
        The execution result, with both the persisted and the volatile
        output converted into object refs.
    """
    # Tail recursion optimization.
    context = {}
    while True:
        with workflow_context.fork_workflow_step_context(**context):
            result = _execute_workflow(workflow)
        if not isinstance(result.persisted_output, InplaceReturnedWorkflow):
            break
        # The step returned another workflow to run inplace; loop instead of
        # recursing, carrying over the context it asked for.
        workflow = result.persisted_output.workflow
        context = result.persisted_output.context

    # Convert the outputs into ObjectRefs.
    if not isinstance(result.persisted_output, WorkflowOutputType):
        result.persisted_output = ray.put(result.persisted_output)
    # Each output is checked on its own. Testing the persisted output a
    # second time would always be False here (it was just converted above),
    # leaving the volatile output unconverted.
    if not isinstance(result.volatile_output, WorkflowOutputType):
        result.volatile_output = ray.put(result.volatile_output)
    return result
|
||||||
|
|
||||||
|
|
||||||
async def _write_step_inputs(
|
async def _write_step_inputs(
|
||||||
wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData
|
wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -383,6 +413,7 @@ def _workflow_step_executor(
|
||||||
step_id: "StepID",
|
step_id: "StepID",
|
||||||
baked_inputs: "_BakedWorkflowInputs",
|
baked_inputs: "_BakedWorkflowInputs",
|
||||||
runtime_options: "WorkflowStepRuntimeOptions",
|
runtime_options: "WorkflowStepRuntimeOptions",
|
||||||
|
inplace: bool = False,
|
||||||
) -> Tuple[Any, Any]:
|
) -> Tuple[Any, Any]:
|
||||||
"""Executor function for workflow step.
|
"""Executor function for workflow step.
|
||||||
|
|
||||||
|
@ -392,6 +423,7 @@ def _workflow_step_executor(
|
||||||
baked_inputs: The processed inputs for the step.
|
baked_inputs: The processed inputs for the step.
|
||||||
context: Workflow step context. Used to access correct storage etc.
|
context: Workflow step context. Used to access correct storage etc.
|
||||||
runtime_options: Parameters for workflow step execution.
|
runtime_options: Parameters for workflow step execution.
|
||||||
|
inplace: Execute the workflow inplace.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Workflow step output.
|
Workflow step output.
|
||||||
|
@ -440,7 +472,9 @@ def _workflow_step_executor(
|
||||||
exception=None,
|
exception=None,
|
||||||
)
|
)
|
||||||
if isinstance(persisted_output, Workflow):
|
if isinstance(persisted_output, Workflow):
|
||||||
|
sub_workflow = persisted_output
|
||||||
outer_most_step_id = context.outer_most_step_id
|
outer_most_step_id = context.outer_most_step_id
|
||||||
|
assert volatile_output is None
|
||||||
if step_type == StepType.FUNCTION:
|
if step_type == StepType.FUNCTION:
|
||||||
# Passing down outer most step so inner nested steps would
|
# Passing down outer most step so inner nested steps would
|
||||||
# access the same outer most step.
|
# access the same outer most step.
|
||||||
|
@ -450,12 +484,30 @@ def _workflow_step_executor(
|
||||||
# current step is the outer most step for the inner nested
|
# current step is the outer most step for the inner nested
|
||||||
# workflow steps.
|
# workflow steps.
|
||||||
outer_most_step_id = workflow_context.get_current_step_id()
|
outer_most_step_id = workflow_context.get_current_step_id()
|
||||||
assert volatile_output is None
|
if inplace:
|
||||||
|
_step_options = sub_workflow.data.step_options
|
||||||
|
if (
|
||||||
|
_step_options.step_type != StepType.WAIT
|
||||||
|
and runtime_options.ray_options != _step_options.ray_options
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Workflow step '{sub_workflow.step_id}' uses "
|
||||||
|
f"a Ray option different to its caller step '{step_id}' "
|
||||||
|
f"and will be executed inplace. Ray assumes it still "
|
||||||
|
f"consumes the same resource as the caller. This may result "
|
||||||
|
f"in oversubscribing resources."
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
InplaceReturnedWorkflow(
|
||||||
|
sub_workflow, {"outer_most_step_id": outer_most_step_id}
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
# Execute sub-workflow. Pass down "outer_most_step_id".
|
# Execute sub-workflow. Pass down "outer_most_step_id".
|
||||||
with workflow_context.fork_workflow_step_context(
|
with workflow_context.fork_workflow_step_context(
|
||||||
outer_most_step_id=outer_most_step_id
|
outer_most_step_id=outer_most_step_id
|
||||||
):
|
):
|
||||||
result = execute_workflow(persisted_output)
|
result = execute_workflow(sub_workflow)
|
||||||
# When virtual actor returns a workflow in the method,
|
# When virtual actor returns a workflow in the method,
|
||||||
# the volatile_output and persisted_output will be put together
|
# the volatile_output and persisted_output will be put together
|
||||||
persisted_output = result.persisted_output
|
persisted_output = result.persisted_output
|
||||||
|
|
|
@ -63,6 +63,20 @@ def test_inplace_workflows(workflow_start_regular_shared):
|
||||||
assert exp_remote.step(k, n).run() == k * 2 ** n
|
assert exp_remote.step(k, n).run() == k * 2 ** n
|
||||||
|
|
||||||
|
|
||||||
|
def test_tail_recursion_optimization(workflow_start_regular_shared):
    """Check that chained inplace steps do not grow the Python stack.

    Each step returns the next one with ``allow_inplace=True``. The step
    body asserts that the interpreter stack stays shallow, and the test
    also verifies that the chain finally resolves to the innermost value.
    """

    @workflow.step
    def tail_recursion(n):
        # Imported inside the step so the step body is self-contained.
        import inspect

        # check if the stack is growing
        assert len(inspect.stack(0)) < 20
        if n <= 0:
            return "ok"
        return tail_recursion.options(allow_inplace=True).step(n - 1)

    # Assert on the result as well, so a chain that returns the wrong value
    # fails the test instead of passing silently.
    assert tail_recursion.step(30).run() == "ok"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue