# ray/rllib/offline/tests/test_dataset_reader.py

import os
import tempfile
import unittest
from pathlib import Path

import pytest

import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import (
    DatasetReader,
    get_dataset_and_shards,
    _unzip_if_needed,
)
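

# These tests exercise two parts of RLlib's offline "dataset" input path:
# get_dataset_and_shards(), which builds a ray.data.Dataset from an
# input_config and splits it into per-worker shards, and _unzip_if_needed(),
# which resolves local and S3 paths (extracting zip archives into the current
# working directory) before the data is read.
#
# A minimal usage sketch of that wiring, not used by the tests below; `path`
# is a hypothetical JSON file of SampleBatch rows:
def _example_build_reader(path: str) -> DatasetReader:
    dataset, _ = get_dataset_and_shards(
        {"input": "dataset", "input_config": {"format": "json", "paths": path}}
    )
    return DatasetReader(dataset, IOContext(config={"train_batch_size": 256}))
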


class TestDatasetReader(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()
        # TODO(Kourosh): Hitting S3 in CI is currently broken due to some AWS
        # credentials issues, using a local file instead for now.
        # cls.dset_path = "s3://air-example-data/rllib/cartpole/large.json"
        cls.dset_path = "tests/data/pendulum/large.json"

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_dataset_reader_itr_batches(self):
        """Test that the dataset reader iterates over batches of rows correctly."""
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )
        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
        reader = DatasetReader(dataset, ioctx)
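        # next() returns whole rows (one episode of 600 timesteps per row in
        # this dataset) and keeps adding rows until at least train_batch_size
        # timesteps are collected, so the batch may overshoot 1200; hence the
        # >= assertion rather than ==.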
        assert len(reader.next()) >= 1200

    def test_dataset_shard_with_only_local(self):
        """Tests whether the dataset_shard function works correctly for a single shard
        for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }
        # With no remote workers (num_workers=0), we expect exactly one shard,
        # which belongs to the local worker.
        _, shards = get_dataset_and_shards(config, num_workers=0)
        assert len(shards) == 1
        assert isinstance(shards[0], ray.data.Dataset)

    def test_dataset_shard_remote_workers_with_local_worker(self):
        """Tests whether the dataset_shard function works correctly for the remote
        workers with a dummy dataset shard for the local worker."""
        config = {
            "input": "dataset",
            "input_config": {"format": "json", "paths": self.dset_path},
        }
        NUM_WORKERS = 4
        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)
        assert len(shards) == NUM_WORKERS + 1
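        # By convention, shards[0] belongs to the local worker; when remote
        # workers exist it is only a dummy placeholder (None), and the actual
        # data shards live in shards[1:].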
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )

    def test_dataset_shard_with_task_parallelization(self):
        """Tests whether the dataset_shard function works correctly with parallelism
        for reading the dataset."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "parallelism": 10,
            },
        }
        NUM_WORKERS = 4
        _, shards = get_dataset_and_shards(config, num_workers=NUM_WORKERS)
        assert len(shards) == NUM_WORKERS + 1
        assert shards[0] is None
        assert all(
            isinstance(remote_shard, ray.data.Dataset) for remote_shard in shards[1:]
        )

    def test_dataset_shard_with_loader_fn(self):
        """Tests whether the dataset_shard function works correctly with loader_fn."""
        dset = ray.data.range(100)
        config = {"input": "dataset", "input_config": {"loader_fn": lambda: dset}}
        ret_dataset, _ = get_dataset_and_shards(config)
        assert ret_dataset.count() == dset.count()
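        # Note: loader_fn lets the caller construct the Dataset themselves and
        # replaces format/paths entirely; specifying both is an error, as
        # test_dataset_shard_error_with_both_format_and_loader_fn below
        # verifies.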

    def test_dataset_shard_error_with_unsupported_dataset_format(self):
        """Tests whether the dataset_shard function raises an error when an unsupported
        dataset format is specified."""
        config = {
            "input": "dataset",
            "input_config": {
                "format": "__UNSUPPORTED_FORMAT__",
                "paths": self.dset_path,
            },
        }
        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)

    def test_dataset_shard_error_with_both_format_and_loader_fn(self):
        """Tests whether the dataset_shard function raises an error when both format
        and loader_fn are specified."""
        dset = ray.data.range(100)
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": self.dset_path,
                "loader_fn": lambda: dset,
            },
        }
        with self.assertRaises(ValueError):
            get_dataset_and_shards(config)

    def test_default_ioctx(self):
        # Test DatasetReader without passing in an IOContext.
        input_config = {"format": "json", "paths": self.dset_path}
        dataset, _ = get_dataset_and_shards(
            {"input": "dataset", "input_config": input_config}
        )
        reader = DatasetReader(dataset)
        # Reads one row of the pendulum dataset, an episode of 600 timesteps.
        assert len(reader.next()) == 600


class TestUnzipIfNeeded(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.s3_path = "s3://air-example-data/rllib/pendulum"
        cls.relative_path = "tests/data/pendulum"
        cls.absolute_path = str(
            Path(__file__).parent.parent.parent / "tests" / "data" / "pendulum"
        )
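
    # The tests below pin down the path handling of _unzip_if_needed as
    # exercised here: *.zip inputs are extracted into the current working
    # directory and the extracted file paths are returned; non-zip local
    # inputs are returned as (resolved) paths without copying; and S3 URIs
    # to non-zip files are passed through unchanged.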

    # @TODO: unskip when this is fixed
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_zip(self):
        """Tests whether the unzip_if_needed function works correctly on s3 zip
        files"""
        unzipped_paths = _unzip_if_needed([self.s3_path + "/enormous.zip"], "json")
        self.assertEqual(
            str(Path(unzipped_paths[0]).absolute()),
            str(Path("./").absolute() / "enormous.json"),
        )

    def test_relative_zip(self):
        """Tests whether the unzip_if_needed function works correctly on relative zip
        files"""
        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "enormous.zip")], "json"
            )
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_zip(self):
        """Tests whether the unzip_if_needed function works correctly on absolute zip
        files"""
        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "enormous.zip")], "json"
            )
            self.assertEqual(
                str(Path(unzipped_paths[0]).absolute()),
                str(Path("./").absolute() / "enormous.json"),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    # @TODO: unskip when this is fixed
    @pytest.mark.skip(reason="Shouldn't hit S3 in CI")
    def test_s3_json(self):
        """Tests whether the unzip_if_needed function works correctly on s3 json
        files"""
        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed([self.s3_path + "/large.json"], "json")
            self.assertEqual(
                unzipped_paths[0],
                self.s3_path + "/large.json",
            )
            os.chdir(cwdir)

    def test_relative_json(self):
        """Tests whether the unzip_if_needed function works correctly on relative json
        files"""
        # This should work regardless of what the current working directory is.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.relative_path) / "large.json")], "json"
            )
            self.assertEqual(
                os.path.realpath(str(Path(unzipped_paths[0]).absolute())),
                os.path.realpath(
                    str(
                        Path(__file__).parent.parent.parent
                        / self.relative_path
                        / "large.json"
                    )
                ),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)

    def test_absolute_json(self):
        """Tests whether the unzip_if_needed function works correctly on absolute json
        files"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            cwdir = os.getcwd()
            os.chdir(tmp_dir)
            unzipped_paths = _unzip_if_needed(
                [str(Path(self.absolute_path) / "large.json")], "json"
            )
            self.assertEqual(
                os.path.realpath(unzipped_paths[0]),
                os.path.realpath(
                    str(Path(self.absolute_path).absolute() / "large.json")
                ),
            )
            assert all(Path(fpath).exists() for fpath in unzipped_paths)
            os.chdir(cwdir)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))