mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[workflow] workflow.wait() feature (#20163)
This PR implements `workflow.wait()`. When combined with checkpointing, it allows skipping sync & checkpointing of unfinished workflows.
This commit is contained in:
parent
adbcc4f79a
commit
3eb76466a0
8 changed files with 575 additions and 24 deletions
|
@ -38,3 +38,31 @@ Inplace is also useful when you need to pass something that is only valid in the
|
|||
def Foo():
|
||||
x = "<something that is only valid in the current process>"
|
||||
return Bar.options(allow_inplace=True).step(x)
|
||||
|
||||
|
||||
Wait for Partial Results
|
||||
------------------------
|
||||
|
||||
By default, a workflow step will only execute after the completion of all of its dependencies. This blocking behavior prevents certain types of workflows from being expressed (e.g., wait for two of the three steps to finish).
|
||||
|
||||
Analogous to ``ray.wait()``, in Ray Workflow we have ``workflow.wait(*steps: List[Workflow[T]], num_returns: int = 1, timeout: float = None) -> (List[T], List[Workflow[T])``. Calling `workflow.wait` would generate a logical step . The output of the logical step is a tuple of ready workflow results, and workflow results that have not yet been computed. For example, you can use it to print out workflow results as they are computed in the following dynamic workflow:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@workflow.step
|
||||
def do_task(i):
|
||||
time.sleep(random.random())
|
||||
return "task {}".format(i)
|
||||
|
||||
@workflow.step
|
||||
def report_results(wait_result: Tuple[List[str], List[Workflow[str]]]):
|
||||
ready, remaining = wait_result
|
||||
for result in ready:
|
||||
print("Completed", result)
|
||||
if not remaining:
|
||||
return "All done"
|
||||
else:
|
||||
return report_results.step(workflow.wait(remaining))
|
||||
|
||||
tasks = [do_task.step(i) for i in range(100)]
|
||||
report_results.step(workflow.wait(tasks)).run()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from ray.workflow.api import (
|
||||
step, init, virtual_actor, get_output, get_actor, get_status, get_metadata,
|
||||
resume, cancel, list_all, resume_all, wait_for_event, sleep, delete)
|
||||
resume, cancel, list_all, resume_all, wait_for_event, sleep, delete, wait)
|
||||
from ray.workflow.workflow_access import WorkflowExecutionError
|
||||
from ray.workflow.common import WorkflowStatus
|
||||
from ray.workflow.event_listener import EventListener
|
||||
|
@ -22,6 +22,7 @@ __all__ = [
|
|||
"sleep",
|
||||
"EventListener",
|
||||
"delete",
|
||||
"wait",
|
||||
]
|
||||
|
||||
globals().update(WorkflowStatus.__members__)
|
||||
|
|
|
@ -356,7 +356,7 @@ def wait_for_event(event_listener_type: EventListenerType, *args,
|
|||
get_message.step(event_listener_type, *args, **kwargs))
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@PublicAPI(stability="beta")
|
||||
def sleep(duration: float) -> Workflow[Event]:
|
||||
"""
|
||||
A workfow that resolves after sleeping for a given duration.
|
||||
|
@ -478,5 +478,82 @@ def delete(workflow_id: str) -> None:
|
|||
wf_storage.delete_workflow()
|
||||
|
||||
|
||||
WaitResult = Tuple[List[Any], List[Workflow]]
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
def wait(workflows: List[Workflow],
|
||||
num_returns: int = 1,
|
||||
timeout: Optional[float] = None) -> Workflow[WaitResult]:
|
||||
"""Return a list of result of workflows that are ready and a list of
|
||||
workflows that are pending.
|
||||
|
||||
Examples:
|
||||
>>> tasks = [task.step() for _ in range(3)]
|
||||
>>> wait_step = workflow.wait(tasks, num_returns=1)
|
||||
>>> print(wait_step.run())
|
||||
([result_1], [<Workflow object>, <Workflow object>])
|
||||
|
||||
>>> tasks = [task.step() for _ in range(2)] + [forever.step()]
|
||||
>>> wait_step = workflow.wait(tasks, num_returns=3, timeout=10)
|
||||
>>> print(wait_step.run())
|
||||
([result_1, result_2], [<Workflow object>])
|
||||
|
||||
If timeout is set, the function returns either when the requested number of
|
||||
workflows are ready or when the timeout is reached, whichever occurs first.
|
||||
If it is not set, the function simply waits until that number of workflows
|
||||
is ready and returns that exact number of workflows.
|
||||
|
||||
This method returns two lists. The first list consists of workflows
|
||||
references that correspond to workflows that are ready. The second
|
||||
list corresponds to the rest of the workflows (which may or may not be
|
||||
ready).
|
||||
|
||||
Ordering of the input list of workflows is preserved. That is, if A
|
||||
precedes B in the input list, and both are in the ready list, then A will
|
||||
precede B in the ready list. This also holds true if A and B are both in
|
||||
the remaining list.
|
||||
|
||||
This method will issue a warning if it's running inside an async context.
|
||||
|
||||
Args:
|
||||
workflows (List[Workflow]): List of workflows that may
|
||||
or may not be ready. Note that these workflows must be unique.
|
||||
num_returns (int): The number of workflows that should be returned.
|
||||
timeout (float): The maximum amount of time in seconds to wait before
|
||||
returning.
|
||||
|
||||
Returns:
|
||||
A list of ready workflow results that are ready and a list of the
|
||||
remaining workflows.
|
||||
"""
|
||||
from ray.workflow import serialization_context
|
||||
from ray.workflow.common import WorkflowData
|
||||
for w in workflows:
|
||||
if not isinstance(w, Workflow):
|
||||
raise TypeError("The input of workflow.wait should be a list "
|
||||
"of workflows.")
|
||||
wait_inputs = serialization_context.make_workflow_inputs(workflows)
|
||||
step_options = WorkflowStepRuntimeOptions.make(
|
||||
step_type=StepType.WAIT,
|
||||
# Pass the options through Ray options. "num_returns" conflicts with
|
||||
# the "num_returns" for Ray remote functions, so we need to wrap it
|
||||
# under "wait_options".
|
||||
ray_options={
|
||||
"wait_options": {
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout,
|
||||
}
|
||||
},
|
||||
)
|
||||
workflow_data = WorkflowData(
|
||||
func_body=None,
|
||||
inputs=wait_inputs,
|
||||
step_options=step_options,
|
||||
name="workflow.wait",
|
||||
user_metadata={})
|
||||
return Workflow(workflow_data)
|
||||
|
||||
|
||||
__all__ = ("step", "virtual_actor", "resume", "get_output", "get_actor",
|
||||
"resume_all", "get_status", "get_metadata", "cancel")
|
||||
|
|
|
@ -42,6 +42,15 @@ def ensure_ray_initialized():
|
|||
class WorkflowRef:
|
||||
"""This class represents a dynamic reference of a workflow output.
|
||||
|
||||
A dynamic reference means the workflow
|
||||
|
||||
1. has not executed yet
|
||||
2. has been running
|
||||
3. has failed
|
||||
4. has finished
|
||||
|
||||
So this class only contains the ID of the workflow step.
|
||||
|
||||
See 'step_executor._resolve_dynamic_workflow_refs' for how we handle
|
||||
workflow refs."""
|
||||
# The ID of the step that produces the output of the workflow.
|
||||
|
@ -54,6 +63,42 @@ class WorkflowRef:
|
|||
return hash(self.step_id)
|
||||
|
||||
|
||||
class _RefBypass:
|
||||
"""Prevents an object ref from being hooked by a serializer."""
|
||||
|
||||
def __init__(self, ref):
|
||||
self._ref = ref
|
||||
|
||||
def __reduce__(self):
|
||||
from ray import cloudpickle
|
||||
return cloudpickle.loads, (cloudpickle.dumps(self._ref), )
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowStaticRef:
|
||||
"""This class represents a static reference of a workflow output.
|
||||
|
||||
A static reference means the workflow has already been executed,
|
||||
and we have both the workflow step ID and the object ref to it
|
||||
living outputs.
|
||||
|
||||
This could be used when you want to return a running workflow
|
||||
from a workflow step. For example, the remaining workflows
|
||||
returned by 'workflow.wait' contains a static ref to these
|
||||
pending workflows.
|
||||
"""
|
||||
# The ID of the step that produces the output of the workflow.
|
||||
step_id: StepID
|
||||
# The ObjectRef of the output.
|
||||
ref: ObjectRef
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.step_id + self.ref.hex())
|
||||
|
||||
def __reduce__(self):
|
||||
return WorkflowStaticRef, (self.step_id, _RefBypass(self.ref))
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
@unique
|
||||
class WorkflowStatus(str, Enum):
|
||||
|
@ -77,6 +122,7 @@ class StepType(str, Enum):
|
|||
FUNCTION = "FUNCTION"
|
||||
ACTOR_METHOD = "ACTOR_METHOD"
|
||||
READONLY_ACTOR_METHOD = "READONLY_ACTOR_METHOD"
|
||||
WAIT = "WAIT"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -244,6 +290,12 @@ T = TypeVar("T")
|
|||
|
||||
|
||||
class Workflow(Generic[T]):
|
||||
"""This class represents a workflow.
|
||||
|
||||
It would either be a workflow that is not executed, or it is a reference
|
||||
to a running workflow when 'workflow.ref' is not None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
workflow_data: WorkflowData,
|
||||
prepare_inputs: Optional[Callable] = None):
|
||||
|
@ -257,6 +309,7 @@ class Workflow(Generic[T]):
|
|||
self._result: Optional[WorkflowExecutionResult] = None
|
||||
# step id will be generated during runtime
|
||||
self._step_id: StepID = None
|
||||
self._ref: Optional[WorkflowStaticRef] = None
|
||||
|
||||
@property
|
||||
def _workflow_id(self):
|
||||
|
@ -287,6 +340,8 @@ class Workflow(Generic[T]):
|
|||
def step_id(self) -> StepID:
|
||||
if self._step_id is not None:
|
||||
return self._step_id
|
||||
if self._ref is not None:
|
||||
return self._ref.step_id
|
||||
|
||||
from ray.workflow.workflow_access import \
|
||||
get_or_create_management_actor
|
||||
|
@ -318,11 +373,39 @@ class Workflow(Generic[T]):
|
|||
del self._prepare_inputs
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def ref(self) -> Optional[WorkflowStaticRef]:
|
||||
return self._ref
|
||||
|
||||
@classmethod
|
||||
def from_ref(cls, workflow_ref: WorkflowStaticRef) -> "Workflow":
|
||||
inputs = WorkflowInputs(args=None, workflows=[], workflow_refs=[])
|
||||
data = WorkflowData(
|
||||
func_body=None,
|
||||
inputs=inputs,
|
||||
name=None,
|
||||
step_options=WorkflowStepRuntimeOptions.make(
|
||||
step_type=StepType.FUNCTION),
|
||||
user_metadata={})
|
||||
wf = Workflow(data)
|
||||
wf._ref = workflow_ref
|
||||
return wf
|
||||
|
||||
def __reduce__(self):
|
||||
raise ValueError(
|
||||
"Workflow[T] objects are not serializable. "
|
||||
"This means they cannot be passed or returned from Ray "
|
||||
"remote, or stored in Ray objects.")
|
||||
"""Serialization helper for workflow.
|
||||
|
||||
By default Workflow[T] objects are not serializable, except
|
||||
it is a reference to a workflow (when workflow.ref is not 'None').
|
||||
The reference can be passed around, but the workflow must
|
||||
be processed locally so we can capture it in the DAG and
|
||||
checkpoint its inputs properly.
|
||||
"""
|
||||
if self._ref is None:
|
||||
raise ValueError(
|
||||
"Workflow[T] objects are not serializable. "
|
||||
"This means they cannot be passed or returned from Ray "
|
||||
"remote, or stored in Ray objects.")
|
||||
return Workflow.from_ref, (self._ref, )
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
def run(self,
|
||||
|
|
|
@ -61,14 +61,18 @@ def run(entry_workflow: Workflow,
|
|||
except Exception:
|
||||
wf_exists = False
|
||||
|
||||
# "Is growing" means we could adding steps to the (top-level)
|
||||
# workflow to grow the workflow dynamically at runtime.
|
||||
is_growing = (step_type not in (StepType.FUNCTION, StepType.WAIT))
|
||||
|
||||
# We only commit for
|
||||
# - virtual actor tasks: it's dynamic tasks, so we always add
|
||||
# - it's a new workflow
|
||||
# TODO (yic): follow up with force rerun
|
||||
if step_type != StepType.FUNCTION or not wf_exists:
|
||||
if is_growing or not wf_exists:
|
||||
commit_step(ws, "", entry_workflow, exception=None)
|
||||
workflow_manager = get_or_create_management_actor()
|
||||
ignore_existing = (step_type != StepType.FUNCTION)
|
||||
ignore_existing = is_growing
|
||||
# NOTE: It is important to 'ray.get' the returned output. This
|
||||
# ensures caller of 'run()' holds the reference to the workflow
|
||||
# result. Otherwise if the actor removes the reference of the
|
||||
|
@ -76,7 +80,7 @@ def run(entry_workflow: Workflow,
|
|||
result: "WorkflowExecutionResult" = ray.get(
|
||||
workflow_manager.run_or_resume.remote(workflow_id,
|
||||
ignore_existing))
|
||||
if step_type == StepType.FUNCTION:
|
||||
if not is_growing:
|
||||
return flatten_workflow_output(workflow_id,
|
||||
result.persisted_output)
|
||||
else:
|
||||
|
|
|
@ -4,7 +4,8 @@ import ray
|
|||
from ray.workflow import workflow_context
|
||||
from ray.workflow import serialization
|
||||
from ray.workflow.common import (Workflow, StepID, WorkflowRef,
|
||||
WorkflowExecutionResult)
|
||||
WorkflowStaticRef, WorkflowExecutionResult,
|
||||
StepType)
|
||||
from ray.workflow import storage
|
||||
from ray.workflow import workflow_storage
|
||||
from ray.workflow.step_function import WorkflowStepFunction
|
||||
|
@ -50,6 +51,37 @@ def _recover_workflow_step(args: List[Any], kwargs: Dict[str, Any],
|
|||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _reconstruct_wait_step(reader: workflow_storage.WorkflowStorage,
|
||||
result: workflow_storage.StepInspectResult,
|
||||
input_map: Dict[StepID, Any]):
|
||||
input_workflows = []
|
||||
step_options = result.step_options
|
||||
wait_options = step_options.ray_options.get("wait_options", {})
|
||||
for i, _step_id in enumerate(result.workflows):
|
||||
# Check whether the step has been loaded or not to avoid
|
||||
# duplication
|
||||
if _step_id in input_map:
|
||||
r = input_map[_step_id]
|
||||
else:
|
||||
r = _construct_resume_workflow_from_step(reader, _step_id,
|
||||
input_map)
|
||||
input_map[_step_id] = r
|
||||
if isinstance(r, Workflow):
|
||||
input_workflows.append(r)
|
||||
else:
|
||||
assert isinstance(r, StepID)
|
||||
# TODO (Alex): We should consider caching these outputs too.
|
||||
output = reader.load_step_output(r)
|
||||
# Simulate a workflow with a workflow reference so it could be
|
||||
# used directly by 'workflow.wait'.
|
||||
static_ref = WorkflowStaticRef(step_id=r, ref=ray.put(output))
|
||||
wf = Workflow.from_ref(static_ref)
|
||||
input_workflows.append(wf)
|
||||
|
||||
from ray import workflow
|
||||
return workflow.wait(input_workflows, **wait_options)
|
||||
|
||||
|
||||
def _construct_resume_workflow_from_step(
|
||||
reader: workflow_storage.WorkflowStorage, step_id: StepID,
|
||||
input_map: Dict[StepID, Any]) -> Union[Workflow, StepID]:
|
||||
|
@ -78,6 +110,11 @@ def _construct_resume_workflow_from_step(
|
|||
if not result.is_recoverable():
|
||||
raise WorkflowStepNotRecoverableError(step_id)
|
||||
|
||||
step_options = result.step_options
|
||||
# Process the wait step as a special case.
|
||||
if step_options.step_type == StepType.WAIT:
|
||||
return _reconstruct_wait_step(reader, result, input_map)
|
||||
|
||||
with serialization.objectref_cache():
|
||||
input_workflows = []
|
||||
for i, _step_id in enumerate(result.workflows):
|
||||
|
@ -99,7 +136,6 @@ def _construct_resume_workflow_from_step(
|
|||
|
||||
args, kwargs = reader.load_step_args(step_id, input_workflows,
|
||||
workflow_refs)
|
||||
step_options = result.step_options
|
||||
recovery_workflow: Workflow = _recover_workflow_step.step(
|
||||
args, kwargs, input_workflows, workflow_refs)
|
||||
recovery_workflow._step_id = step_id
|
||||
|
|
|
@ -24,13 +24,14 @@ from ray.workflow.common import (
|
|||
StepType,
|
||||
StepID,
|
||||
WorkflowData,
|
||||
WorkflowStaticRef,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.workflow.common import (WorkflowRef, WorkflowStepRuntimeOptions)
|
||||
from ray.workflow.workflow_context import WorkflowStepContext
|
||||
|
||||
StepInputTupleToResolve = Tuple[ObjectRef, List[ObjectRef], List[ObjectRef]]
|
||||
WaitResult = Tuple[List[Any], List[Workflow]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -110,11 +111,18 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
# Stage 1: prepare inputs
|
||||
workflow_data = workflow.data
|
||||
inputs = workflow_data.inputs
|
||||
workflow_outputs = []
|
||||
with workflow_context.fork_workflow_step_context(
|
||||
outer_most_step_id=None, last_step_of_workflow=False):
|
||||
workflow_outputs = [
|
||||
execute_workflow(w).persisted_output for w in inputs.workflows
|
||||
]
|
||||
for w in inputs.workflows:
|
||||
static_ref = w.ref
|
||||
if static_ref is None:
|
||||
# The input workflow is not a reference to an executed
|
||||
# workflow .
|
||||
output = execute_workflow(w).persisted_output
|
||||
static_ref = WorkflowStaticRef(step_id=w.step_id, ref=output)
|
||||
workflow_outputs.append(static_ref)
|
||||
|
||||
baked_inputs = _BakedWorkflowInputs(
|
||||
args=workflow_data.inputs.args,
|
||||
workflow_outputs=workflow_outputs,
|
||||
|
@ -134,10 +142,20 @@ def execute_workflow(workflow: "Workflow") -> "WorkflowExecutionResult":
|
|||
# TODO(suquark): We still have recursive Python calls.
|
||||
# This would cause stack overflow if we have a really
|
||||
# deep recursive call. We should fix it later.
|
||||
executor = _workflow_step_executor
|
||||
if step_options.step_type == StepType.WAIT:
|
||||
executor = _workflow_wait_executor
|
||||
else:
|
||||
executor = _workflow_step_executor
|
||||
else:
|
||||
executor = _workflow_step_executor_remote.options(
|
||||
**step_options.ray_options).remote
|
||||
if step_options.step_type == StepType.WAIT:
|
||||
# This is very important to set "num_cpus=0" to
|
||||
# ensure "workflow.wait" is not blocked by other
|
||||
# tasks.
|
||||
executor = _workflow_wait_executor_remote.options(
|
||||
num_cpus=0).remote
|
||||
else:
|
||||
executor = _workflow_step_executor_remote.options(
|
||||
**step_options.ray_options).remote
|
||||
|
||||
# Stage 3: execution
|
||||
persisted_output, volatile_output = executor(
|
||||
|
@ -203,10 +221,12 @@ def commit_step(store: workflow_storage.WorkflowStorage, step_id: "StepID",
|
|||
from ray.workflow.common import Workflow
|
||||
if isinstance(ret, Workflow):
|
||||
assert not ret.executed
|
||||
tasks = [
|
||||
_write_step_inputs(store, w.step_id, w.data)
|
||||
for w in ret._iter_workflows_in_dag()
|
||||
]
|
||||
tasks = []
|
||||
for w in ret._iter_workflows_in_dag():
|
||||
# If this is a reference to a workflow, do not checkpoint
|
||||
# its input (again).
|
||||
if w.ref is None:
|
||||
tasks.append(_write_step_inputs(store, w.step_id, w.data))
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
|
||||
|
||||
context = workflow_context.get_workflow_step_context()
|
||||
|
@ -396,13 +416,61 @@ def _workflow_step_executor_remote(
|
|||
runtime_options)
|
||||
|
||||
|
||||
def _workflow_wait_executor(func: Callable, context: "WorkflowStepContext",
|
||||
step_id: "StepID",
|
||||
baked_inputs: "_BakedWorkflowInputs",
|
||||
runtime_options: "WorkflowStepRuntimeOptions"
|
||||
) -> Tuple[WaitResult, None]:
|
||||
"""Executor of 'workflow.wait' steps.
|
||||
|
||||
It returns a tuple that contains wait result. The wait result is a list
|
||||
of result of workflows that are ready and a list of workflows that are
|
||||
pending.
|
||||
"""
|
||||
# Part 1: Update the context for the step.
|
||||
workflow_context.update_workflow_step_context(context, step_id)
|
||||
context = workflow_context.get_workflow_step_context()
|
||||
step_type = runtime_options.step_type
|
||||
assert step_type == StepType.WAIT
|
||||
wait_options = runtime_options.ray_options.get("wait_options", {})
|
||||
|
||||
# Part 2: Resolve any ready workflows.
|
||||
ready_workflows, remaining_workflows = baked_inputs.wait(**wait_options)
|
||||
ready_objects = []
|
||||
for w in ready_workflows:
|
||||
obj, _, = _resolve_object_ref(w.ref.ref)
|
||||
ready_objects.append(obj)
|
||||
persisted_output = (ready_objects, remaining_workflows)
|
||||
|
||||
# Part 3: Save the outputs.
|
||||
store = workflow_storage.get_workflow_storage()
|
||||
commit_step(store, step_id, persisted_output, exception=None)
|
||||
if context.last_step_of_workflow:
|
||||
# advance the progress of the workflow
|
||||
store.advance_progress(step_id)
|
||||
|
||||
_record_step_status(step_id, WorkflowStatus.SUCCESSFUL)
|
||||
logger.info(get_step_status_info(WorkflowStatus.SUCCESSFUL))
|
||||
return persisted_output, None
|
||||
|
||||
|
||||
@ray.remote(num_returns=2)
|
||||
def _workflow_wait_executor_remote(
|
||||
func: Callable, context: "WorkflowStepContext", step_id: "StepID",
|
||||
baked_inputs: "_BakedWorkflowInputs",
|
||||
runtime_options: "WorkflowStepRuntimeOptions") -> Any:
|
||||
"""The remote version of '_workflow_wait_executor'"""
|
||||
return _workflow_wait_executor(func, context, step_id, baked_inputs,
|
||||
runtime_options)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BakedWorkflowInputs:
|
||||
"""This class stores pre-processed inputs for workflow step execution.
|
||||
Especially, all input workflows to the workflow step will be scheduled,
|
||||
and their outputs (ObjectRefs) replace the original workflows."""
|
||||
args: "ObjectRef"
|
||||
workflow_outputs: "List[ObjectRef]"
|
||||
workflow_outputs: "List[WorkflowStaticRef]"
|
||||
workflow_refs: "List[WorkflowRef]"
|
||||
|
||||
def resolve(self) -> Tuple[List, Dict]:
|
||||
|
@ -424,7 +492,7 @@ class _BakedWorkflowInputs:
|
|||
"""
|
||||
objects_mapping = []
|
||||
for obj_ref in self.workflow_outputs:
|
||||
obj, ref = _resolve_object_ref(obj_ref)
|
||||
obj, ref = _resolve_object_ref(obj_ref.ref)
|
||||
objects_mapping.append(obj)
|
||||
|
||||
workflow_ref_mapping = _resolve_dynamic_workflow_refs(
|
||||
|
@ -442,6 +510,33 @@ class _BakedWorkflowInputs:
|
|||
]
|
||||
return signature.recover_args(flattened_args)
|
||||
|
||||
def wait(self, num_returns: int = 1, timeout: Optional[float] = None
|
||||
) -> Tuple[List[Workflow], List[Workflow]]:
|
||||
"""Return a list of workflows that are ready and a list of workflows that
|
||||
are not. See `api.wait()` for details.
|
||||
|
||||
Args:
|
||||
num_returns (int): The number of workflows that should be returned.
|
||||
timeout (float): The maximum amount of time in seconds to wait
|
||||
before returning.
|
||||
|
||||
Returns:
|
||||
A list of workflows that are ready and a list of the remaining
|
||||
workflows.
|
||||
"""
|
||||
if self.workflow_refs:
|
||||
raise ValueError("Currently, we do not support wait operations "
|
||||
"on dynamic workflow refs. They are typically "
|
||||
"generated by virtual actors.")
|
||||
refs_map = {w.ref: w for w in self.workflow_outputs}
|
||||
ready_ids, remaining_ids = ray.wait(
|
||||
list(refs_map.keys()), num_returns=num_returns, timeout=timeout)
|
||||
ready_workflows = [Workflow.from_ref(refs_map[i]) for i in ready_ids]
|
||||
remaining_workflows = [
|
||||
Workflow.from_ref(refs_map[i]) for i in remaining_ids
|
||||
]
|
||||
return ready_workflows, remaining_workflows
|
||||
|
||||
def __reduce__(self):
|
||||
return _BakedWorkflowInputs, (self.args, self.workflow_outputs,
|
||||
self.workflow_refs)
|
||||
|
|
227
python/ray/workflow/tests/test_wait.py
Normal file
227
python/ray/workflow/tests/test_wait.py
Normal file
|
@ -0,0 +1,227 @@
|
|||
from ray.tests.conftest import * # noqa
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import ray
|
||||
from ray import workflow
|
||||
from ray.workflow.common import Workflow
|
||||
from ray.workflow.tests import utils
|
||||
from ray.exceptions import RaySystemError
|
||||
|
||||
|
||||
@workflow.step
|
||||
def wait_multiple_steps():
|
||||
@workflow.step
|
||||
def sleep_identity(x: int):
|
||||
time.sleep(x)
|
||||
return x
|
||||
|
||||
ws = [
|
||||
sleep_identity.step(1),
|
||||
sleep_identity.step(3),
|
||||
sleep_identity.step(10),
|
||||
sleep_identity.step(2),
|
||||
sleep_identity.step(12),
|
||||
]
|
||||
return workflow.wait(ws, num_returns=4, timeout=5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_start_regular_shared",
|
||||
[{
|
||||
"num_cpus": 8
|
||||
# We need more CPUs, otherwise task execution could be blocked.
|
||||
}],
|
||||
indirect=True)
|
||||
def test_wait_basics(workflow_start_regular_shared):
|
||||
# This tests basic usage of 'workflow.wait':
|
||||
# 1. It returns ready tasks precisely and preserves the original order.
|
||||
# 2. All steps would see the same waiting result.
|
||||
# 3. We can pass remaining pending workflows to another workflow,
|
||||
# and they can be resolved like normal workflows.
|
||||
@workflow.step
|
||||
def return_ready(wait_results):
|
||||
ready, unready = wait_results
|
||||
return ready
|
||||
|
||||
@workflow.step
|
||||
def join(a, b):
|
||||
return a, b
|
||||
|
||||
wait_result = wait_multiple_steps.step()
|
||||
a = return_ready.step(wait_result)
|
||||
b = return_ready.step(wait_result)
|
||||
ready1, ready2 = join.step(a, b).run()
|
||||
assert ready1 == ready2 == [1, 3, 2]
|
||||
|
||||
@workflow.step
|
||||
def get_all(ready, unready):
|
||||
return ready, unready
|
||||
|
||||
@workflow.step
|
||||
def filter_all(wait_results):
|
||||
ready, unready = wait_results
|
||||
return get_all.step(ready, unready)
|
||||
|
||||
@workflow.step
|
||||
def composite():
|
||||
w = wait_multiple_steps.step()
|
||||
return filter_all.step(w)
|
||||
|
||||
ready, unready = composite.step().run()
|
||||
assert ready == [1, 3, 2]
|
||||
assert unready == [10, 12]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
workflow.wait([1, 2])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_start_regular_shared",
|
||||
[{
|
||||
"num_cpus": 8
|
||||
# We need more CPUs, otherwise task execution could be blocked.
|
||||
}],
|
||||
indirect=True)
|
||||
def test_wait_basics_2(workflow_start_regular_shared):
|
||||
# Test "workflow.wait" running in the top level DAG,
|
||||
# or running "workflow.wait" directly.
|
||||
@workflow.step
|
||||
def sleep_identity(x: int):
|
||||
time.sleep(x)
|
||||
return x
|
||||
|
||||
@workflow.step
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
ws = [
|
||||
sleep_identity.step(1),
|
||||
sleep_identity.step(5),
|
||||
sleep_identity.step(2),
|
||||
]
|
||||
w = workflow.wait(ws, num_returns=2, timeout=3)
|
||||
ready, remaining = identity.step(w).run()
|
||||
assert ready == [1, 2]
|
||||
|
||||
ws = [
|
||||
sleep_identity.step(2),
|
||||
sleep_identity.step(5),
|
||||
sleep_identity.step(1),
|
||||
]
|
||||
w = workflow.wait(ws, num_returns=2, timeout=3)
|
||||
ready, remaining = w.run()
|
||||
assert ready == [2, 1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_start_regular_shared",
|
||||
[{
|
||||
"num_cpus": 8
|
||||
# We need more CPUs, otherwise task execution could be blocked.
|
||||
}],
|
||||
indirect=True)
|
||||
def test_wait_recursive(workflow_start_regular_shared):
|
||||
# This tests that we can 'workflow.wait' the remaining pending workflow
|
||||
# returned by another 'workflow.wait' recursively.
|
||||
w = wait_multiple_steps.step()
|
||||
|
||||
@workflow.step
|
||||
def recursive_wait(s):
|
||||
ready, unready = s
|
||||
if len(unready) == 2 and not isinstance(unready[0], Workflow):
|
||||
ready_2, unready = unready
|
||||
print(ready, (ready_2, unready))
|
||||
ready += ready_2
|
||||
|
||||
if not unready:
|
||||
return ready
|
||||
w = workflow.wait(unready)
|
||||
return recursive_wait.step([ready, w])
|
||||
|
||||
assert recursive_wait.step(w).run() == [1, 3, 2, 10, 12]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_start_regular_shared",
|
||||
[{
|
||||
"num_cpus": 8
|
||||
# We need more CPUs, otherwise task execution could be blocked.
|
||||
}],
|
||||
indirect=True)
|
||||
def test_wait_failure_recovery_1(workflow_start_regular_shared):
|
||||
# This tests that if a step using the output of "workflow.wait" as its
|
||||
# input, it can be recovered after failure.
|
||||
@workflow.step
|
||||
def get_all(ready, unready):
|
||||
return ready, unready
|
||||
|
||||
@workflow.step
|
||||
def filter_all_2(wait_results):
|
||||
assert wait_results[0] == [1, 3, 2]
|
||||
# failure point
|
||||
assert utils.check_global_mark()
|
||||
ready, unready = wait_results
|
||||
return get_all.step(ready, unready)
|
||||
|
||||
@workflow.step
|
||||
def composite_2():
|
||||
w = wait_multiple_steps.step()
|
||||
return filter_all_2.step(w)
|
||||
|
||||
utils.unset_global_mark()
|
||||
|
||||
with pytest.raises(RaySystemError):
|
||||
composite_2.step().run(workflow_id="wait_failure_recovery")
|
||||
|
||||
utils.set_global_mark()
|
||||
|
||||
ready, unready = ray.get(workflow.resume("wait_failure_recovery"))
|
||||
assert ready == [1, 3, 2]
|
||||
assert unready == [10, 12]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"workflow_start_regular_shared",
|
||||
[{
|
||||
"num_cpus": 8
|
||||
# We need more CPUs, otherwise task execution could be blocked.
|
||||
}],
|
||||
indirect=True)
|
||||
def test_wait_failure_recovery_2(workflow_start_regular_shared):
|
||||
# Test failing "workflow.wait" and its input steps.
|
||||
|
||||
@workflow.step
|
||||
def sleep_identity(x: int):
|
||||
# block the step by a global mark
|
||||
while not utils.check_global_mark():
|
||||
time.sleep(0.1)
|
||||
time.sleep(x)
|
||||
return x
|
||||
|
||||
@workflow.step
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
ws = [
|
||||
sleep_identity.step(2),
|
||||
sleep_identity.step(5),
|
||||
sleep_identity.step(1),
|
||||
]
|
||||
w = workflow.wait(ws, num_returns=2, timeout=None)
|
||||
utils.unset_global_mark()
|
||||
_ = identity.step(w).run_async(workflow_id="wait_failure_recovery_2")
|
||||
# wait util "workflow.wait" has been running
|
||||
time.sleep(10)
|
||||
workflow.cancel("wait_failure_recovery_2")
|
||||
time.sleep(2)
|
||||
|
||||
utils.set_global_mark()
|
||||
ready, unready = ray.get(workflow.resume("wait_failure_recovery_2"))
|
||||
assert ready == [2, 1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
Loading…
Add table
Reference in a new issue