mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[workflow] Enhance dataset tests (#25876)
This commit is contained in:
parent
ce02ac0311
commit
fea8dd08fc
1 changed files with 34 additions and 3 deletions
|
@ -8,22 +8,35 @@ from ray import workflow
|
|||
|
||||
@ray.remote
|
||||
def gen_dataset():
|
||||
# TODO(ekl) seems checkpointing hangs with nested refs of
|
||||
# LazyBlockList.
|
||||
return ray.data.range(1000).map(lambda x: x)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def gen_dataset_1():
|
||||
return ray.data.range(1000)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def gen_dataset_2():
|
||||
return ray.data.range_table(1000)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def transform_dataset(in_data):
|
||||
return in_data.map(lambda x: x * 2)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def transform_dataset_1(in_data):
|
||||
return in_data.map(lambda r: {"v2": r["value"] * 2})
|
||||
|
||||
|
||||
@ray.remote
|
||||
def sum_dataset(ds):
|
||||
return ds.sum()
|
||||
|
||||
|
||||
def test_dataset(workflow_start_regular):
|
||||
def test_dataset(workflow_start_regular_shared):
|
||||
ds_ref = gen_dataset.bind()
|
||||
transformed_ref = transform_dataset.bind(ds_ref)
|
||||
output_ref = sum_dataset.bind(transformed_ref)
|
||||
|
@ -32,6 +45,24 @@ def test_dataset(workflow_start_regular):
|
|||
assert result == 2 * sum(range(1000))
|
||||
|
||||
|
||||
def test_dataset_1(workflow_start_regular_shared):
|
||||
ds_ref = gen_dataset_1.bind()
|
||||
transformed_ref = transform_dataset.bind(ds_ref)
|
||||
output_ref = sum_dataset.bind(transformed_ref)
|
||||
|
||||
result = workflow.create(output_ref).run()
|
||||
assert result == 2 * sum(range(1000))
|
||||
|
||||
|
||||
def test_dataset_2(workflow_start_regular_shared):
|
||||
ds_ref = gen_dataset_2.bind()
|
||||
transformed_ref = transform_dataset_1.bind(ds_ref)
|
||||
output_ref = sum_dataset.bind(transformed_ref)
|
||||
|
||||
result = workflow.create(output_ref).run()
|
||||
assert result == 2 * sum(range(1000))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue