[Workflow] Optimize out tail recursion in python (#22794)

* add test

* warning when inplace subworkflows may use different resources
This commit is contained in:
Siyuan (Ryans) Zhuang 2022-03-16 01:51:18 -07:00 committed by GitHub
parent 60a3340387
commit d67c34256b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 17 deletions

View file

@ -1,6 +1,7 @@
import time import time
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
import functools
import logging import logging
from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union
import ray import ray
@ -106,15 +107,8 @@ def _resolve_dynamic_workflow_refs(workflow_refs: "List[WorkflowRef]"):
return workflow_ref_mapping return workflow_ref_mapping
def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult": def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
"""Execute workflow. """Internal function of workflow execution."""
Args:
workflow: The workflow to be executed.
Returns:
An object ref that represent the result.
"""
if workflow.executed: if workflow.executed:
return workflow.result return workflow.result
@ -183,7 +177,9 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
if step_options.step_type == StepType.WAIT: if step_options.step_type == StepType.WAIT:
executor = _workflow_wait_executor executor = _workflow_wait_executor
else: else:
executor = _workflow_step_executor # Tell the executor that we are running inplace. This enables
# tail-recursion optimization.
executor = functools.partial(_workflow_step_executor, inplace=True)
else: else:
if step_options.step_type == StepType.WAIT: if step_options.step_type == StepType.WAIT:
# This is very important to set "num_cpus=0" to # This is very important to set "num_cpus=0" to
@ -205,11 +201,6 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
) )
# Stage 4: post processing outputs # Stage 4: post processing outputs
if not isinstance(persisted_output, WorkflowOutputType):
persisted_output = ray.put(persisted_output)
if not isinstance(persisted_output, WorkflowOutputType):
volatile_output = ray.put(volatile_output)
if step_options.step_type != StepType.READONLY_ACTOR_METHOD: if step_options.step_type != StepType.READONLY_ACTOR_METHOD:
if not step_options.allow_inplace: if not step_options.allow_inplace:
# TODO: [Possible flaky bug] Here the RUNNING state may # TODO: [Possible flaky bug] Here the RUNNING state may
@ -225,6 +216,45 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
return result return result
@dataclass
class InplaceReturnedWorkflow:
    """Hold information about a workflow returned from an inplace step.

    Used as a marker by the inplace step executor: when a step executed
    inplace returns a sub-workflow, it is wrapped in this type so that
    ``execute_workflow`` can unwrap and continue it iteratively
    (tail-recursion optimization) instead of recursing.
    """

    # The returned sub-workflow to continue executing.
    workflow: Workflow
    # Keyword arguments used to fork the workflow step context before
    # executing the returned workflow (e.g. "outer_most_step_id").
    context: Dict
def execute_workflow(workflow: Workflow) -> "WorkflowExecutionResult":
    """Execute workflow.

    This function also performs tail-recursion optimization for inplace
    workflow steps: an inplace step that returns a sub-workflow yields an
    ``InplaceReturnedWorkflow`` marker, which is unwrapped here in a loop
    so the Python call stack does not grow with recursion depth.

    Args:
        workflow: The workflow to be executed.

    Returns:
        A result object whose outputs are converted to Ray object refs.
    """
    # Tail recursion optimization: keep unwrapping inplace-returned
    # sub-workflows iteratively, forking the step context as requested
    # by each returned workflow.
    context = {}
    while True:
        with workflow_context.fork_workflow_step_context(**context):
            result = _execute_workflow(workflow)
        if not isinstance(result.persisted_output, InplaceReturnedWorkflow):
            break
        workflow = result.persisted_output.workflow
        context = result.persisted_output.context

    # Convert the outputs into ObjectRefs.
    if not isinstance(result.persisted_output, WorkflowOutputType):
        result.persisted_output = ray.put(result.persisted_output)
    # Fix: this guard must inspect volatile_output, not persisted_output.
    # After the wrap above, persisted_output is always a WorkflowOutputType,
    # so testing it here would leave volatile_output unwrapped (or wrap it
    # spuriously when persisted_output was already an ObjectRef).
    if not isinstance(result.volatile_output, WorkflowOutputType):
        result.volatile_output = ray.put(result.volatile_output)
    return result
async def _write_step_inputs( async def _write_step_inputs(
wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData
) -> None: ) -> None:
@ -383,6 +413,7 @@ def _workflow_step_executor(
step_id: "StepID", step_id: "StepID",
baked_inputs: "_BakedWorkflowInputs", baked_inputs: "_BakedWorkflowInputs",
runtime_options: "WorkflowStepRuntimeOptions", runtime_options: "WorkflowStepRuntimeOptions",
inplace: bool = False,
) -> Tuple[Any, Any]: ) -> Tuple[Any, Any]:
"""Executor function for workflow step. """Executor function for workflow step.
@ -392,6 +423,7 @@ def _workflow_step_executor(
baked_inputs: The processed inputs for the step. baked_inputs: The processed inputs for the step.
context: Workflow step context. Used to access correct storage etc. context: Workflow step context. Used to access correct storage etc.
runtime_options: Parameters for workflow step execution. runtime_options: Parameters for workflow step execution.
inplace: Execute the workflow inplace.
Returns: Returns:
Workflow step output. Workflow step output.
@ -440,7 +472,9 @@ def _workflow_step_executor(
exception=None, exception=None,
) )
if isinstance(persisted_output, Workflow): if isinstance(persisted_output, Workflow):
sub_workflow = persisted_output
outer_most_step_id = context.outer_most_step_id outer_most_step_id = context.outer_most_step_id
assert volatile_output is None
if step_type == StepType.FUNCTION: if step_type == StepType.FUNCTION:
# Passing down outer most step so inner nested steps would # Passing down outer most step so inner nested steps would
# access the same outer most step. # access the same outer most step.
@ -450,12 +484,30 @@ def _workflow_step_executor(
# current step is the outer most step for the inner nested # current step is the outer most step for the inner nested
# workflow steps. # workflow steps.
outer_most_step_id = workflow_context.get_current_step_id() outer_most_step_id = workflow_context.get_current_step_id()
assert volatile_output is None if inplace:
_step_options = sub_workflow.data.step_options
if (
_step_options.step_type != StepType.WAIT
and runtime_options.ray_options != _step_options.ray_options
):
logger.warning(
f"Workflow step '{sub_workflow.step_id}' uses "
f"a Ray option different to its caller step '{step_id}' "
f"and will be executed inplace. Ray assumes it still "
f"consumes the same resource as the caller. This may result "
f"in oversubscribing resources."
)
return (
InplaceReturnedWorkflow(
sub_workflow, {"outer_most_step_id": outer_most_step_id}
),
None,
)
# Execute sub-workflow. Pass down "outer_most_step_id". # Execute sub-workflow. Pass down "outer_most_step_id".
with workflow_context.fork_workflow_step_context( with workflow_context.fork_workflow_step_context(
outer_most_step_id=outer_most_step_id outer_most_step_id=outer_most_step_id
): ):
result = execute_workflow(persisted_output) result = execute_workflow(sub_workflow)
# When virtual actor returns a workflow in the method, # When virtual actor returns a workflow in the method,
# the volatile_output and persisted_output will be put together # the volatile_output and persisted_output will be put together
persisted_output = result.persisted_output persisted_output = result.persisted_output

View file

@ -63,6 +63,20 @@ def test_inplace_workflows(workflow_start_regular_shared):
assert exp_remote.step(k, n).run() == k * 2 ** n assert exp_remote.step(k, n).run() == k * 2 ** n
def test_tail_recursion_optimization(workflow_start_regular_shared):
    """Deep inplace recursion must not grow the interpreter stack."""

    @workflow.step
    def countdown(remaining):
        import inspect

        # With tail-recursion optimization, inplace-returned sub-workflows
        # are unwound in a loop by the executor, so the Python stack stays
        # shallow regardless of the logical recursion depth.
        assert len(inspect.stack(0)) < 20
        if remaining > 0:
            return countdown.options(allow_inplace=True).step(remaining - 1)
        return "ok"

    countdown.step(30).run()
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys