[workflow] Test for better coverage (#17233)

* update

* workflow.init

* update

* update

* update tests

* check

* up

* update

* update

* check

* merge

* fix tests

* update

* add tests

* up

* format

* add space

* Update test_storage.py

Co-authored-by: Siyuan <suquark@gmail.com>
Author: Yi Cheng
Date:   2021-07-21 14:48:36 -07:00 (committed by GitHub)
parent  2e37826458
commit  5accfa662c
6 changed files with 61 additions and 49 deletions

@@ -28,6 +28,7 @@ def init(storage: "Optional[Union[str, Storage]]" = None) -> None:
"""
if storage is None:
storage = os.environ.get("RAY_WORKFLOW_STORAGE")
if storage is None:
# We should use get_temp_dir_path, but for ray client, we don't
# have this one. We need a flag to tell whether it's a client
@@ -36,7 +37,6 @@ def init(storage: "Optional[Union[str, Storage]]" = None) -> None:
logger.warning("Using default local dir: `/tmp/ray/workflow_data`. "
"This should only be used for testing purposes.")
storage = "file:///tmp/ray/workflow_data"
if isinstance(storage, str):
storage = storage_base.create_storage(storage)
elif not isinstance(storage, Storage):
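
The two separate `if` checks above give storage resolution a clear precedence: an explicit argument wins, then the RAY_WORKFLOW_STORAGE environment variable, then the local default. A minimal sketch of the resulting order (the helper name is hypothetical):

    import os

    def resolve_storage_url(storage=None):
        # 1. An explicit argument takes priority.
        if storage is None:
            # 2. Otherwise consult the environment variable.
            storage = os.environ.get("RAY_WORKFLOW_STORAGE")
        if storage is None:
            # 3. Fall back to a local directory, for testing only.
            storage = "file:///tmp/ray/workflow_data"
        return storage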

@@ -3,71 +3,84 @@ from contextlib import contextmanager
import pytest
from moto import mock_s3
from mock_server import * # noqa
from pytest_lazyfixture import lazy_fixture
import tempfile
import os
import ray
from ray.experimental import workflow
from ray.experimental.workflow import storage
from ray.tests.conftest import get_default_fixture_ray_kwargs
@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def filesystem_storage():
# TODO: use tmp path once fixed the path issues
storage.set_global_storage(
storage.create_storage("/tmp/ray/workflow_data/"))
yield storage.get_global_storage()
with tempfile.TemporaryDirectory() as d:
yield d
@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def aws_credentials():
import os
old_env = os.environ
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
yield
yield (f"aws_access_key_id={os.environ['AWS_ACCESS_KEY_ID']}&"
f"aws_secret_access_key={os.environ['AWS_SECRET_ACCESS_KEY']}&"
f"aws_session_token={os.environ['AWS_SESSION_TOKEN']}")
os.environ = old_env
@pytest.fixture(scope="function")
@pytest.fixture(scope="session")
def s3_storage(aws_credentials, s3_server):
with mock_s3():
client = boto3.client(
"s3", region_name="us-west-2", endpoint_url=s3_server)
client.create_bucket(Bucket="test_bucket")
url = ("s3://test_bucket/workflow"
f"?region_name=us-west-2&endpoint_url={s3_server}")
storage.set_global_storage(storage.create_storage(url))
yield storage.get_global_storage()
f"?region_name=us-west-2&endpoint_url={s3_server}"
f"&{aws_credentials}")
yield url
@contextmanager
def _workflow_start(**kwargs):
def _workflow_start(storage_url, shared, **kwargs):
init_kwargs = get_default_fixture_ray_kwargs()
init_kwargs.update(kwargs)
# Start the Ray processes.
address_info = ray.init(**init_kwargs)
if ray.is_initialized():
ray.shutdown()
storage.set_global_storage(None)
# Sometimes pytest does not cleanup all global variables.
# we have to manually reset the workflow storage. This
# should not be an issue for normal use cases, because global variables
# are freed after the driver exits.
storage.set_global_storage(None)
workflow.init()
address_info = ray.init(**init_kwargs)
workflow.init(storage_url)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
storage.set_global_storage(None)
@pytest.fixture
def workflow_start_regular(request):
@pytest.fixture(scope="function")
def workflow_start_regular(storage_url, request):
param = getattr(request, "param", {})
with _workflow_start(**param) as res:
with _workflow_start(storage_url, False, **param) as res:
yield res
@pytest.fixture(scope="module")
def workflow_start_regular_shared(request):
@pytest.fixture(scope="session")
def workflow_start_regular_shared(storage_url, request):
param = getattr(request, "param", {})
with _workflow_start(**param) as res:
with _workflow_start(storage_url, True, **param) as res:
yield res
def pytest_generate_tests(metafunc):
if "storage_url" in metafunc.fixturenames:
metafunc.parametrize(
"storage_url",
[lazy_fixture("s3_storage"),
lazy_fixture("filesystem_storage")],
scope="session")

@@ -19,7 +19,8 @@ def start_service(service_name, host, port):
# For debugging
# args = '{0} {1} -H {2} -p {3} 2>&1 | tee -a /tmp/moto.log'.format(moto_svr_path, service_name, host, port)
process = sp.Popen(
args, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE) # shell=True
args, stdin=sp.PIPE, stdout=sp.DEVNULL,
stderr=sp.DEVNULL) # shell=True
url = "http://{host}:{port}".format(host=host, port=port)
for i in range(0, 30):
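
Redirecting the moto server's stdout/stderr to `sp.DEVNULL` instead of `sp.PIPE` both silences log noise and avoids a latent hang: nothing ever read those pipes, so once the OS pipe buffer filled, the server would block on its next write. A sketch of the difference (the rationale is inferred; the command name is illustrative):

    import subprocess as sp

    # With PIPE and no reader, a chatty child blocks once the pipe
    # buffer (commonly 64 KiB on Linux) is full:
    p = sp.Popen(["chatty-server"], stdout=sp.PIPE, stderr=sp.PIPE)

    # DEVNULL discards output at the OS level, so the child never blocks:
    p = sp.Popen(["chatty-server"], stdout=sp.DEVNULL, stderr=sp.DEVNULL)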

@@ -120,11 +120,11 @@ if __name__ == "__main__":
def test_recovery_cluster_failure():
subprocess.run(["ray start --head"], shell=True)
subprocess.check_call(["ray", "start", "--head"])
time.sleep(1)
proc = run_string_as_driver_nonblocking(driver_script)
time.sleep(10)
subprocess.run(["ray stop"], shell=True)
subprocess.check_call(["ray", "stop"])
proc.kill()
time.sleep(1)
workflow.init()
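
Swapping `subprocess.run([...], shell=True)` for `subprocess.check_call([...])` does two things: the command no longer goes through a shell, and a non-zero exit now raises instead of being silently ignored, so a failed `ray start` aborts the test at the real point of failure:

    import subprocess

    # Old: the exit status is discarded; failures go unnoticed.
    subprocess.run(["ray start --head"], shell=True)

    # New: raises subprocess.CalledProcessError on a non-zero exit.
    subprocess.check_call(["ray", "start", "--head"])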

@@ -3,8 +3,8 @@ import asyncio
import ray
from ray.tests.conftest import * # noqa
from ray.experimental.workflow import workflow_storage
from ray.experimental.workflow import storage
from ray.experimental.workflow.workflow_storage import asyncio_run
from pytest_lazyfixture import lazy_fixture
def some_func(x):
@@ -16,11 +16,8 @@ def some_func2(x):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"raw_storage",
[lazy_fixture("filesystem_storage"),
lazy_fixture("s3_storage")])
async def test_raw_storage(workflow_start_regular, raw_storage):
async def test_raw_storage(workflow_start_regular):
raw_storage = storage.get_global_storage()
workflow_id = test_workflow_storage.__name__
step_id = "some_step"
input_metadata = {"2": "c"}
@@ -95,17 +92,17 @@ async def test_raw_storage(workflow_start_regular, raw_storage):
assert load_step_output == output
@pytest.mark.parametrize(
"raw_storage",
[lazy_fixture("filesystem_storage"),
lazy_fixture("s3_storage")])
def test_workflow_storage(workflow_start_regular, raw_storage):
def test_workflow_storage(workflow_start_regular):
raw_storage = storage.get_global_storage()
workflow_id = test_workflow_storage.__name__
step_id = "some_step"
input_metadata = {
"name": "test_basic_workflows.append1",
"object_refs": ["abc"],
"workflows": ["def"]
"workflows": ["def"],
"step_max_retries": 1,
"catch_exceptions": False,
"ray_options": {},
}
output_metadata = {
"output_step_id": "a12423",
@@ -169,7 +166,8 @@ def test_workflow_storage(workflow_start_regular, raw_storage):
args_valid=True,
func_body_valid=True,
object_refs=input_metadata["object_refs"],
workflows=input_metadata["workflows"])
workflows=input_metadata["workflows"],
ray_options={})
assert inspect_result.is_recoverable()
step_id = "some_step4"
@@ -182,7 +180,8 @@ def test_workflow_storage(workflow_start_regular, raw_storage):
assert inspect_result == workflow_storage.StepInspectResult(
func_body_valid=True,
object_refs=input_metadata["object_refs"],
workflows=input_metadata["workflows"])
workflows=input_metadata["workflows"],
ray_options={})
assert not inspect_result.is_recoverable()
step_id = "some_step5"
@@ -192,7 +191,8 @@ def test_workflow_storage(workflow_start_regular, raw_storage):
inspect_result = wf_storage.inspect_step(step_id)
assert inspect_result == workflow_storage.StepInspectResult(
object_refs=input_metadata["object_refs"],
workflows=input_metadata["workflows"])
workflows=input_metadata["workflows"],
ray_options={})
assert not inspect_result.is_recoverable()
step_id = "some_step6"

@@ -39,9 +39,9 @@ class StepInspectResult:
# The number of retries for application exceptions
step_max_retries: int = 1
# Whether the user wants to handle the exception manually
catch_exceptions: Optional[bool] = None
catch_exceptions: bool = False
# ray_remote options
ray_options: Dict[str, Any] = None
ray_options: Optional[Dict[str, Any]] = None
def is_recoverable(self) -> bool:
return (self.output_object_valid or self.output_step_id
@@ -234,11 +234,9 @@ class WorkflowStorage:
catch_exceptions = metadata.get("catch_exceptions")
ray_options = metadata.get("ray_options", {})
except storage.DataLoadError:
input_object_refs = None
input_workflows = None
step_max_retries = None
catch_exceptions = None
ray_options = {}
return StepInspectResult(
args_valid=field_list.args_exists,
func_body_valid=field_list.func_body_exists)
return StepInspectResult(
args_valid=field_list.args_exists,
func_body_valid=field_list.func_body_exists,
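
The `StepInspectResult` changes tighten the defaults (`catch_exceptions: bool = False`, `ray_options: Optional[Dict]`), and on `storage.DataLoadError` the inspection now returns early with only the existence flags rather than filling every field with None. A hedged sketch of the amended dataclass, with the fields outside this hunk elided:

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class StepInspectResult:
        # (output/args/func-body validity flags elided)
        # The number of retries for application exceptions.
        step_max_retries: int = 1
        # Whether the user wants to handle the exception manually.
        catch_exceptions: bool = False
        # Options passed through to ray.remote for this step.
        ray_options: Optional[Dict[str, Any]] = None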