[workflow] Enhance dataset tests (#25876)

This commit is contained in:
Siyuan (Ryans) Zhuang 2022-06-16 22:50:31 -07:00 committed by GitHub
parent ce02ac0311
commit fea8dd08fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,22 +8,35 @@ from ray import workflow
@ray.remote
def gen_dataset():
# TODO(ekl) seems checkpointing hangs with nested refs of
# LazyBlockList.
return ray.data.range(1000).map(lambda x: x)
@ray.remote
def gen_dataset_1():
return ray.data.range(1000)
@ray.remote
def gen_dataset_2():
return ray.data.range_table(1000)
@ray.remote
def transform_dataset(in_data):
return in_data.map(lambda x: x * 2)
@ray.remote
def transform_dataset_1(in_data):
return in_data.map(lambda r: {"v2": r["value"] * 2})
@ray.remote
def sum_dataset(ds):
return ds.sum()
def test_dataset(workflow_start_regular):
def test_dataset(workflow_start_regular_shared):
ds_ref = gen_dataset.bind()
transformed_ref = transform_dataset.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)
@ -32,6 +45,24 @@ def test_dataset(workflow_start_regular):
assert result == 2 * sum(range(1000))
def test_dataset_1(workflow_start_regular_shared):
ds_ref = gen_dataset_1.bind()
transformed_ref = transform_dataset.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)
result = workflow.create(output_ref).run()
assert result == 2 * sum(range(1000))
def test_dataset_2(workflow_start_regular_shared):
ds_ref = gen_dataset_2.bind()
transformed_ref = transform_dataset_1.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)
result = workflow.create(output_ref).run()
assert result == 2 * sum(range(1000))
if __name__ == "__main__":
import sys