[workflow] Fix workflow recovery issue due to a bug of dynamic output (#21571)

* Fix workflow recovery issue due to a bug of dynamic output

* add tests
This commit is contained in:
Siyuan (Ryans) Zhuang 2022-01-24 15:34:57 -08:00 committed by GitHub
parent c2199a50e3
commit 99b287d236
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 3 deletions

View file

@ -282,6 +282,29 @@ def test_nested_catch_exception_2(workflow_start_regular_shared, tmp_path):
assert isinstance(err, ValueError)
def test_dynamic_output(workflow_start_regular_shared):
    """Check the dynamic output recorded for a workflow that fails midway.

    The workflow is a chain of steps, each returning a continuation step,
    that raises once ``n`` drops below 3. After the failure, the stored
    record of the outermost step (``"step_0"``) must report ``"step_3"`` as
    its output step.

    Args:
        workflow_start_regular_shared: pytest fixture; presumably provides
            the running workflow environment (defined outside this view).
    """

    @workflow.step
    def exponential_fail(k, n):
        # A continuation is named after the current ``n`` but invoked with
        # ``n - 1``, so the step named "step_3" runs with n == 2 and is the
        # one that raises.
        if n > 0:
            if n < 3:
                raise Exception("Failed intentionally")
            return exponential_fail.options(name=f"step_{n}").step(
                k * 2, n - 1)
        return k

    # When the workflow fails, the dynamic output should point to the
    # latest successful step. Assert the failure explicitly instead of
    # silently swallowing it, so an unexpected success is reported as such.
    with pytest.raises(Exception):
        exponential_fail.options(name="step_0").step(
            3, 10).run(workflow_id="dynamic_output")

    from ray.workflow.workflow_storage import get_workflow_storage
    wf_storage = get_workflow_storage(workflow_id="dynamic_output")
    result = wf_storage.inspect_step("step_0")
    assert result.output_step_id == "step_3"
if __name__ == "__main__":
    # Running this file directly executes its tests verbosely through
    # pytest and uses pytest's result as the process exit status.
    import sys

    pytest_args = ["-v", __file__]
    exit_status = pytest.main(pytest_args)
    sys.exit(exit_status)

View file

@ -134,6 +134,7 @@ class WorkflowStorage:
outer_most_step_id: See WorkflowStepContext.
"""
tasks = []
dynamic_output_id = None
if isinstance(ret, Workflow):
# This workflow step returns a nested workflow.
assert step_id != ret.step_id
@ -154,9 +155,6 @@ class WorkflowStorage:
# tasks.append(self._put(self._key_step_output(step_id), ret))
dynamic_output_id = step_id
# TODO (yic): Delete exception file
tasks.append(
self._update_dynamic_output(outer_most_step_id,
dynamic_output_id))
else:
assert ret is None
promise = serialization.dump_to_storage(
@ -166,8 +164,17 @@ class WorkflowStorage:
# tasks.append(
# self._put(self._key_step_exception(step_id), exception))
# Finish checkpointing.
asyncio_run(asyncio.gather(*tasks))
# NOTE: if we update the dynamic output before
# finishing checkpointing, then during recovery, the dynamic output
# would point to a checkpoint that does not exist.
if dynamic_output_id is not None:
asyncio_run(
self._update_dynamic_output(outer_most_step_id,
dynamic_output_id))
def load_step_func_body(self, step_id: StepID) -> Callable:
"""Load the function body of the workflow step.