import tempfile
import os
from pathlib import Path
import unittest
import pytest
import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import (
    DatasetReader,
    get_dataset_and_shards,
    _unzip_if_needed,
)


class TestDatasetReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()
        # TODO(Kourosh): Hitting S3 in CI is currently broken due to some AWS
        # credentials issues, using a local file instead for now.
        # cls.dset_path = "s3://air-example-data/rllib/cartpole/large.json"
        cls.dset_path = "tests/data/pendulum/large.json"

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_dataset_reader_itr_batches(self):
        """Test that the dataset reader iterates over batches of rows correctly."""
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )

        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
        reader = DatasetReader(dataset, ioctx)
        assert len(reader.next()) >= 1200
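
    # Illustrative helper (not part of the original tests): the same read
    # pattern as above, parameterized over the requested batch size. Note that
    # reader.next() can return more than n_timesteps rows, presumably because
    # it only returns whole stored rows.
    def _example_read_at_least(self, dataset, n_timesteps):
        ioctx = IOContext(config={"train_batch_size": n_timesteps}, worker_index=0)
        reader = DatasetReader(dataset, ioctx)
        return reader.next()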

    def test_dataset_shard_with_only_local(self):
        """Tests whether the dataset_shard function works correctly for a single shard
        for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }

        # No remote workers: only the local worker needs a shard.
        _, shards = get_dataset_and_shards(config, num_workers=0)
        assert len(shards) == 1
        assert isinstance(shards[0], ray.data.Dataset)
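
        # Contrast with the remote-worker tests below: with num_workers=0
        # there is no None placeholder, and the single returned shard is the
        # local worker's ray.data.Dataset itself.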

    def test_dataset_shard_remote_workers_with_local_worker(self):
        """Tests whether the dataset_shard function works correctly for the remote
        workers with a dummy dataset shard for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)
        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )
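
        # Shard list layout, as asserted above: index 0 is reserved for the
        # local worker and is None when remote workers do the reading; indices
        # 1..NUM_WORKERS each hold one ray.data.Dataset shard.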

    def test_dataset_shard_with_task_parallelization(self):
        """Tests whether the dataset_shard function works correctly with parallelism
        for reading the dataset."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "parallelism": 10,
            },
        }
        NUM_WORKERS = 4

        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)
        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )
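
        # The "parallelism" key presumably controls how many read tasks
        # ray.data uses to ingest the files; as asserted above, it does not
        # change the shard layout.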

    def test_dataset_shard_with_loader_fn(self):
        """Tests whether the dataset_shard function works correctly with loader_fn."""
        dset = ray.data.range(100)
        config = {"input": "dataset", "input_config": {"loader_fn": lambda: dset}}

        ret_dataset, _ = get_dataset_and_shards(config)
        assert ret_dataset.count() == dset.count()
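
    # Sketch of the loader_fn pattern (an assumption, not from the original
    # file): any zero-argument callable returning a ray.data.Dataset should
    # work, e.g. wrapping a parquet read:
    #
    #     config = {
    #         "input": "dataset",
    #         "input_config": {
    #             "loader_fn": lambda: ray.data.read_parquet("s3://my-bucket/data"),
    #         },
    #     }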

    def test_dataset_shard_error_with_unsupported_dataset_format(self):
        """Tests whether the dataset_shard function raises an error when an unsupported
        dataset format is specified."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "__UNSUPPORTED_FORMAT__",
                "paths": self.dset_path,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)

    def test_dataset_shard_error_with_both_format_and_loader_fn(self):
        """Tests whether the dataset_shard function raises an error when both format
        and loader_fn are specified."""
        dset = ray.data.range(100)
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "loader_fn": lambda: dset,
            },
        }

        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)
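
        # Taken together with the previous test: input_config must specify
        # exactly one source, either "format" + "paths" or "loader_fn"; an
        # unknown format or a mix of both raises a ValueError.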

    def test_default_ioctx(self):
        # Test DatasetReader without passing in an IOContext.
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )
        reader = DatasetReader(dataset)
        # Reads in one line of the Pendulum dataset with 600 timesteps.
        assert len(reader.next()) == 600
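
        # With no IOContext given, DatasetReader falls back to a default
        # config; since the file stores one 600-timestep episode per row, a
        # single row is presumably already enough to fill the default batch
        # size, so exactly one row comes back.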


class TestUnzipIfNeeded(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.s3_path = "s3://air-example-data/rllib/pendulum"
        cls.relative_path = "tests/data/pendulum"
        cls.absolute_path = str(
            Path(__file__).parent.parent.parent / "tests" / "data" / "pendulum"
        )

    # TODO: Unskip once hitting S3 in CI works again.
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_zip(self):
        """Tests whether the unzip_if_needed function works correctly on S3 zip
        files."""
        unzipped_paths = _unzip_if_needed([self.s3_path + "/enormous.zip"], "json")
        self.assertEqual(
            str(Path(unzipped_paths[0]).absolute()),
            str(Path("./").absolute() / "enormous.json"),
        )
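
    # As the assertions above and below show, _unzip_if_needed extracts zip
    # archives into the current working directory and returns the paths of
    # the extracted files.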

    def test_relative_zip(self):
        """Tests whether the unzip_if_needed function works correctly on relative zip
        files."""
        # This should work regardless of where the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "enormous.zip")], "json"
            )
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_zip(self):
        """Tests whether the unzip_if_needed function works correctly on absolute zip
        files."""
        # This should work regardless of where the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "enormous.zip")], "json"
            )
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)
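
    # Note: the getcwd()/chdir() dance above restores the working directory
    # only on the success path; a failing assertion would leave it changed.
    # A hypothetical context-manager refactor (a sketch, not part of this
    # file) would restore it even on failure:
    #
    #     import contextlib
    #
    #     @contextlib.contextmanager
    #     def pushd(path):
    #         prev = os.getcwd()
    #         os.chdir(path)
    #         try:
    #             yield
    #         finally:
    #             os.chdir(prev)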

    # TODO: Unskip once hitting S3 in CI works again.
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_json(self):
        """Tests whether the unzip_if_needed function works correctly on S3 json
        files."""
        # This should work regardless of where the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed([self.s3_path + "/large.json"], "json")
            self.assertEqual(
                unzipped_paths[0],
                self.s3_path + "/large.json",
            )
            os.chdir(cwdir)
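
            # Note the passthrough asserted above: the returned path is the
            # input s3:// path itself; non-zip files are not downloaded or
            # extracted.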

    def test_relative_json(self):
        """Tests whether the unzip_if_needed function works correctly on relative json
        files."""
        # This should work regardless of where the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "large.json")], "json"
            )
            self.assertEqual(
                os.path.realpath(str(Path(unzipped_paths[0]).absolute())),
                os.path.realpath(
                    str(
                        Path(__file__).parent.parent.parent
                        / self.relative_path
                        / "large.json"
                    )
                ),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_json(self):
        """Tests whether the unzip_if_needed function works correctly on absolute json
        files."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "large.json")], "json"
            )
            self.assertEqual(
                os.path.realpath(unzipped_paths[0]),
                os.path.realpath(
                    str(Path(self.absolute_path).absolute() / "large.json")
                ),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))