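"""Tests for RLlib's offline DatasetReader and its dataset shard/unzip helpers."""
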
import os
import tempfile
import unittest
from pathlib import Path

import pytest

import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import (
    DatasetReader,
    get_dataset_and_shards,
    _unzip_if_needed,
)


class TestDatasetReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()
        # TODO(Kourosh): Hitting S3 in CI is currently broken due to some AWS
        # credentials issues; use a local file instead for now.
        # cls.dset_path = "s3://air-example-data/rllib/cartpole/large.json"
        cls.dset_path = "tests/data/pendulum/large.json"

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_dataset_reader_itr_batches(self):
        """Test that the dataset reader iterates over batches of rows correctly."""
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )

        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
        reader = DatasetReader(dataset, ioctx)
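        # Each dataset row holds a full episode, so the returned batch may
        # overshoot train_batch_size; expect at least 1200 timesteps.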
        assert len(reader.next()) >= 1200

    def test_dataset_shard_with_only_local(self):
        """Tests whether the dataset_shard function works correctly for a single shard
        for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }

        # With no remote workers, the local worker gets the one and only shard.
        _, shards = get_dataset_and_shards(config, num_workers=0)

        assert len(shards) == 1
        assert isinstance(shards[0], ray.data.Dataset)

    def test_dataset_shard_remote_workers_with_local_worker(self):
        """Tests whether the dataset_shard function works correctly for the remote
        workers with a dummy dataset shard for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)
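
        # Expect one dummy `None` entry at index 0 for the local worker, plus
        # one dataset shard per remote worker.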
        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )

    def test_dataset_shard_with_task_parallelization(self):
        """Tests whether the dataset_shard function works correctly with parallelism
        for reading the dataset."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "parallelism": 10,
            },
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)

        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )

    def test_dataset_shard_with_loader_fn(self):
        """Tests whether the dataset_shard function works correctly with loader_fn."""
        dset = ray.data.range(100)
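        # A `loader_fn` hands RLlib a ready-made dataset instead of having it
        # read one from `format`/`paths`.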
        config = {"input": "dataset", "input_config": {"loader_fn": lambda: dset}}

        ret_dataset, _ = get_dataset_and_shards(config)
        assert ret_dataset.count() == dset.count()

    def test_dataset_shard_error_with_unsupported_dataset_format(self):
        """Tests whether the dataset_shard function raises an error when an unsupported
        dataset format is specified."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "__UNSUPPORTED_FORMAT__",
                "paths": self.dset_path,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)

    def test_dataset_shard_error_with_both_format_and_loader_fn(self):
        """Tests whether the dataset_shard function raises an error when both format
        and loader_fn are specified."""
        dset = ray.data.range(100)
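        # `format`/`paths` and `loader_fn` are two mutually exclusive ways of
        # specifying the input dataset, so passing both should raise.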
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "loader_fn": lambda: dset,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)

    def test_default_ioctx(self):
        """Tests DatasetReader instantiated without passing in an IOContext."""
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )
        reader = DatasetReader(dataset)
        # Reads one row (episode) of the pendulum dataset, which has 600 timesteps.
        assert len(reader.next()) == 600


class TestUnzipIfNeeded(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.s3_path = "s3://air-example-data/rllib/pendulum"
        cls.relative_path = "tests/data/pendulum"
        cls.absolute_path = str(
            Path(__file__).parent.parent.parent / "tests" / "data" / "pendulum"
        )

    # TODO: Unskip once the S3 credentials issue in CI is fixed.
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_zip(self):
        """Tests whether the unzip_if_needed function works correctly on s3 zip
        files."""
        unzipped_paths = _unzip_if_needed([self.s3_path + "/enormous.zip"], "json")
        self.assertEqual(
            str(Path(unzipped_paths[0]).absolute()),
            str(Path("./").absolute() / "enormous.json"),
        )

    def test_relative_zip(self):
        """Tests whether the unzip_if_needed function works correctly on relative zip
        files."""

        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "enormous.zip")], "json"
            )
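            # The archive should be extracted into the current (temporary)
            # working directory.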
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )

            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_zip(self):
        """Tests whether the unzip_if_needed function works correctly on absolute zip
        files."""

        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "enormous.zip")], "json"
            )
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )

            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    # TODO: Unskip once the S3 credentials issue in CI is fixed.
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_json(self):
        """Tests whether the unzip_if_needed function works correctly on s3 json
        files."""

        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed([self.s3_path + "/large.json"], "json")
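            # A plain .json file needs no unzipping, so the path should come
            # back unchanged.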
            self.assertEqual(
                unzipped_paths[0],
                self.s3_path + "/large.json",
            )

            os.chdir(cwdir)

    def test_relative_json(self):
        """Tests whether the unzip_if_needed function works correctly on relative json
        files."""

        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "large.json")], "json"
            )
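            # Relative (non-S3) paths appear to be resolved against the rllib
            # package directory (three levels above this file), not the
            # current working directory.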
            self.assertEqual(
                os.path.realpath(str(Path(unzipped_paths[0]).absolute())),
                os.path.realpath(
                    str(
                        Path(__file__).parent.parent.parent
                        / self.relative_path
                        / "large.json"
                    )
                ),
            )

            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_json(self):
        """Tests whether the unzip_if_needed function works correctly on absolute json
        files."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "large.json")], "json"
            )
            self.assertEqual(
                os.path.realpath(unzipped_paths[0]),
                os.path.realpath(
                    str(Path(self.absolute_path).absolute() / "large.json")
                ),
            )

            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))