[Workflow] Optimize out tail recursion in python (#22794)

* add test

* warn when inplace sub-workflows may use different resources
This commit is contained in:
Siyuan (Ryans) Zhuang 2022-03-16 01:51:18 -07:00 committed by GitHub
parent 60a3340387
commit d67c34256b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 17 deletions

View file

@ -1,6 +1,7 @@
import time
import asyncio
from dataclasses import dataclass
import functools
import logging
from typing import List, Tuple, Any, Dict, Callable, Optional, TYPE_CHECKING, Union
import ray
@ -106,15 +107,8 @@ def _resolve_dynamic_workflow_refs(workflow_refs: "List[WorkflowRef]"):
return workflow_ref_mapping
def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
"""Execute workflow.
Args:
workflow: The workflow to be executed.
Returns:
An object ref that represent the result.
"""
def _execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
"""Internal function of workflow execution."""
if workflow.executed:
return workflow.result
@ -183,7 +177,9 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
if step_options.step_type == StepType.WAIT:
executor = _workflow_wait_executor
else:
executor = _workflow_step_executor
# Tell the executor that we are running inplace. This enables
# tail-recursion optimization.
executor = functools.partial(_workflow_step_executor, inplace=True)
else:
if step_options.step_type == StepType.WAIT:
# This is very important to set "num_cpus=0" to
@ -205,11 +201,6 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
)
# Stage 4: post processing outputs
if not isinstance(persisted_output, WorkflowOutputType):
persisted_output = ray.put(persisted_output)
if not isinstance(persisted_output, WorkflowOutputType):
volatile_output = ray.put(volatile_output)
if step_options.step_type != StepType.READONLY_ACTOR_METHOD:
if not step_options.allow_inplace:
# TODO: [Possible flaky bug] Here the RUNNING state may
@ -225,6 +216,45 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
return result
@dataclass
class InplaceReturnedWorkflow:
    """Hold information about a workflow returned from an inplace step.

    `execute_workflow` detects this wrapper in `persisted_output` and keeps
    unrolling the returned workflow iteratively (tail-recursion optimization)
    instead of recursing.
    """

    # The sub-workflow returned by the inplace step; execution continues here.
    workflow: Workflow
    # The dict that contains the context of the inplace returned workflow;
    # passed to workflow_context.fork_workflow_step_context(**context).
    context: Dict
def execute_workflow(workflow: Workflow) -> "WorkflowExecutionResult":
    """Execute workflow.

    This function also performs tail-recursion optimization for inplace
    workflow steps: instead of recursing into ``execute_workflow`` for every
    inplace sub-workflow (which would grow the Python call stack without
    bound), the loop below unrolls the chain of inplace-returned workflows
    iteratively.

    Args:
        workflow: The workflow to be executed.

    Returns:
        An object ref that represents the result.
    """
    # Tail recursion optimization: keep executing as long as the step
    # yields an InplaceReturnedWorkflow instead of a final value.
    context = {}
    while True:
        with workflow_context.fork_workflow_step_context(**context):
            result = _execute_workflow(workflow)
        if not isinstance(result.persisted_output, InplaceReturnedWorkflow):
            break
        workflow = result.persisted_output.workflow
        context = result.persisted_output.context

    # Convert the outputs into ObjectRefs.
    if not isinstance(result.persisted_output, WorkflowOutputType):
        result.persisted_output = ray.put(result.persisted_output)
    # BUGFIX: the original re-checked `result.persisted_output` here, which is
    # always a WorkflowOutputType after the conversion above, so
    # `volatile_output` was never wrapped. Check `volatile_output` itself.
    if not isinstance(result.volatile_output, WorkflowOutputType):
        result.volatile_output = ray.put(result.volatile_output)
    return result
async def _write_step_inputs(
wf_storage: workflow_storage.WorkflowStorage, step_id: StepID, inputs: WorkflowData
) -> None:
@ -383,6 +413,7 @@ def _workflow_step_executor(
step_id: "StepID",
baked_inputs: "_BakedWorkflowInputs",
runtime_options: "WorkflowStepRuntimeOptions",
inplace: bool = False,
) -> Tuple[Any, Any]:
"""Executor function for workflow step.
@ -392,6 +423,7 @@ def _workflow_step_executor(
baked_inputs: The processed inputs for the step.
context: Workflow step context. Used to access correct storage etc.
runtime_options: Parameters for workflow step execution.
inplace: Execute the workflow inplace.
Returns:
Workflow step output.
@ -440,7 +472,9 @@ def _workflow_step_executor(
exception=None,
)
if isinstance(persisted_output, Workflow):
sub_workflow = persisted_output
outer_most_step_id = context.outer_most_step_id
assert volatile_output is None
if step_type == StepType.FUNCTION:
# Passing down outer most step so inner nested steps would
# access the same outer most step.
@ -450,12 +484,30 @@ def _workflow_step_executor(
# current step is the outer most step for the inner nested
# workflow steps.
outer_most_step_id = workflow_context.get_current_step_id()
assert volatile_output is None
if inplace:
_step_options = sub_workflow.data.step_options
if (
_step_options.step_type != StepType.WAIT
and runtime_options.ray_options != _step_options.ray_options
):
logger.warning(
f"Workflow step '{sub_workflow.step_id}' uses "
f"a Ray option different to its caller step '{step_id}' "
f"and will be executed inplace. Ray assumes it still "
f"consumes the same resource as the caller. This may result "
f"in oversubscribing resources."
)
return (
InplaceReturnedWorkflow(
sub_workflow, {"outer_most_step_id": outer_most_step_id}
),
None,
)
# Execute sub-workflow. Pass down "outer_most_step_id".
with workflow_context.fork_workflow_step_context(
outer_most_step_id=outer_most_step_id
):
result = execute_workflow(persisted_output)
result = execute_workflow(sub_workflow)
# When virtual actor returns a workflow in the method,
# the volatile_output and persisted_output will be put together
persisted_output = result.persisted_output

View file

@ -63,6 +63,20 @@ def test_inplace_workflows(workflow_start_regular_shared):
assert exp_remote.step(k, n).run() == k * 2 ** n
def test_tail_recursion_optimization(workflow_start_regular_shared):
    @workflow.step
    def tail_recursion(n):
        import inspect

        # The interpreter stack must stay bounded: inplace steps are
        # unrolled iteratively by the executor, not via Python recursion.
        assert len(inspect.stack(0)) < 20
        if n > 0:
            return tail_recursion.options(allow_inplace=True).step(n - 1)
        return "ok"

    tail_recursion.step(30).run()
if __name__ == "__main__":
import sys