Dataset speed up read (#17435)

Alex Wu 2021-08-01 18:03:46 -07:00 committed by GitHub
parent 6703091cdc
commit d9cd3800c7
7 changed files with 64 additions and 21 deletions


@ -23,3 +23,6 @@ class BinaryDatasource(FileBasedDatasource):
return path, data
else:
return data
def _rows_per_file(self):
return 1
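
Binary files map to exactly one record each, so the datasource can now report a fixed per-file row count instead of leaving it unknown. A minimal sketch of the same idea for a hypothetical custom datasource follows; the import path and class name are assumptions, not part of this commit, and may differ by Ray version.

# Hypothetical subclass: any datasource that always yields one record per file
# can override _rows_per_file() so prepare_read can fill in exact row counts.
from ray.experimental.data.datasource import FileBasedDatasource  # assumed import path

class OneRecordPerFileDatasource(FileBasedDatasource):
    def _read_file(self, f, path, **reader_args):
        # One record per file: the raw bytes.
        return f.readall()

    def _rows_per_file(self):
        # A known, fixed count; returning None means "unknown".
        return 1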


@ -62,22 +62,35 @@ class FileBasedDatasource(Datasource[Union[ArrowRow, Any]]):
builder.add(data)
return builder.build()
read_tasks = [
ReadTask(
read_tasks = []
for read_paths, file_sizes in zip(
np.array_split(paths, parallelism),
np.array_split(file_sizes, parallelism)):
if len(read_paths) <= 0:
continue
if self._rows_per_file() is None:
num_rows = None
else:
num_rows = len(read_paths) * self._rows_per_file()
read_task = ReadTask(
lambda read_paths=read_paths: read_files(
read_paths, filesystem),
BlockMetadata(
num_rows=None,
num_rows=num_rows,
size_bytes=sum(file_sizes),
schema=schema,
input_files=read_paths)) for read_paths, file_sizes in zip(
np.array_split(paths, parallelism),
np.array_split(file_sizes, parallelism))
if len(read_paths) > 0
]
input_files=read_paths)
)
read_tasks.append(read_task)
return read_tasks
def _rows_per_file(self):
"""Returns the number of rows per file, or None if unknown.
"""
return None
def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args):
"""Reads a single file, passing all kwargs to the reader.


@ -1,6 +1,6 @@
import logging
from typing import List, Any, Union, Optional, Tuple, Callable, TypeVar, \
TYPE_CHECKING
from typing import List, Any, Dict, Union, Optional, Tuple, Callable, \
TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
import pyarrow
@ -104,6 +104,7 @@ def range_arrow(n: int, *, parallelism: int = 200) -> Dataset[ArrowRow]:
def read_datasource(datasource: Datasource[T],
*,
parallelism: int = 200,
ray_remote_args: Dict[str, Any] = None,
**read_args) -> Dataset[T]:
"""Read a dataset from a custom data source.
@ -111,6 +112,7 @@ def read_datasource(datasource: Datasource[T],
datasource: The datasource to read data from.
parallelism: The requested parallelism of the read.
read_args: Additional kwargs to pass to the datasource impl.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
Returns:
Dataset holding the data read from the datasource.
@ -118,10 +120,14 @@ def read_datasource(datasource: Datasource[T],
read_tasks = datasource.prepare_read(parallelism, **read_args)
@ray.remote
def remote_read(task: ReadTask) -> Block:
return task()
if ray_remote_args:
remote_read = ray.remote(**ray_remote_args)(remote_read)
else:
remote_read = ray.remote(remote_read)
calls: List[Callable[[], ObjectRef[Block]]] = []
metadata: List[BlockMetadata] = []
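
The reason for the branch above: `ray.remote` can either wrap a function directly or be called with options to produce a decorator, so the two cases are handled separately. A small sketch of that pattern (the helper names here are illustrative, not part of this commit):

import ray

def _make_remote(fn, ray_remote_args=None):
    # ray.remote(fn) wraps a function with default options, while
    # ray.remote(**kwargs) returns a decorator carrying those options.
    if ray_remote_args:
        return ray.remote(**ray_remote_args)(fn)
    return ray.remote(fn)

def _run_task(task):
    return task()

# e.g. remote_read = _make_remote(_run_task, {"num_cpus": 0.5})
#      ref = remote_read.remote(some_read_task)   # some_read_task is a placeholder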
@ -157,6 +163,7 @@ def read_parquet(paths: Union[str, List[str]],
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
columns: Optional[List[str]] = None,
parallelism: int = 200,
ray_remote_args: Dict[str, Any] = None,
**arrow_parquet_args) -> Dataset[ArrowRow]:
"""Create an Arrow dataset from parquet files.
@ -172,6 +179,7 @@ def read_parquet(paths: Union[str, List[str]],
filesystem: The filesystem implementation to read from.
columns: A list of column names to read.
parallelism: The amount of parallelism to use for the dataset.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
arrow_parquet_args: Other parquet read options to pass to pyarrow.
Returns:
@ -183,6 +191,7 @@ def read_parquet(paths: Union[str, List[str]],
paths=paths,
filesystem=filesystem,
columns=columns,
ray_remote_args=ray_remote_args,
**arrow_parquet_args)
@ -191,6 +200,7 @@ def read_json(paths: Union[str, List[str]],
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
parallelism: int = 200,
ray_remote_args: Dict[str, Any] = None,
**arrow_json_args) -> Dataset[ArrowRow]:
"""Create an Arrow dataset from json files.
@ -209,6 +219,7 @@ def read_json(paths: Union[str, List[str]],
A list of paths can contain both files and directories.
filesystem: The filesystem implementation to read from.
parallelism: The amount of parallelism to use for the dataset.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
arrow_json_args: Other json read options to pass to pyarrow.
Returns:
@ -219,6 +230,7 @@ def read_json(paths: Union[str, List[str]],
parallelism=parallelism,
paths=paths,
filesystem=filesystem,
ray_remote_args=ray_remote_args,
**arrow_json_args)
@ -227,6 +239,7 @@ def read_csv(paths: Union[str, List[str]],
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
parallelism: int = 200,
ray_remote_args: Dict[str, Any] = None,
**arrow_csv_args) -> Dataset[ArrowRow]:
"""Create an Arrow dataset from csv files.
@ -245,6 +258,7 @@ def read_csv(paths: Union[str, List[str]],
A list of paths can contain both files and directories.
filesystem: The filesystem implementation to read from.
parallelism: The amount of parallelism to use for the dataset.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
arrow_csv_args: Other csv read options to pass to pyarrow.
Returns:
@ -255,6 +269,7 @@ def read_csv(paths: Union[str, List[str]],
parallelism=parallelism,
paths=paths,
filesystem=filesystem,
ray_remote_args=ray_remote_args,
**arrow_csv_args)
@ -264,7 +279,9 @@ def read_binary_files(
*,
include_paths: bool = False,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
parallelism: int = 200) -> Dataset[Union[Tuple[str, bytes], bytes]]:
parallelism: int = 200,
ray_remote_args: Dict[str, Any] = None,
) -> Dataset[Union[Tuple[str, bytes], bytes]]:
"""Create a dataset from binary files of arbitrary contents.
Examples:
@ -280,6 +297,7 @@ def read_binary_files(
dataset records. When specified, the dataset records will be a
tuple of the file path and the file contents.
filesystem: The filesystem implementation to read from.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
parallelism: The amount of parallelism to use for the dataset.
Returns:
@ -291,6 +309,7 @@ def read_binary_files(
paths=paths,
include_paths=include_paths,
filesystem=filesystem,
ray_remote_args=ray_remote_args,
schema=bytes)
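
Taken together, the new keyword lets callers control the resources requested by each read task. A usage sketch (the bucket path is a placeholder), mirroring the release-test change below: with half a CPU per task, twice as many download tasks can run concurrently on each node.

import ray
import ray.experimental.data

ray.init()
ds = ray.experimental.data.read_binary_files(
    "s3://my-bucket/images/",          # placeholder bucket
    parallelism=1000,
    ray_remote_args={"num_cpus": 0.5})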


@ -1782,6 +1782,7 @@ if __name__ == "__main__":
if args.ray_wheels:
os.environ["RAY_WHEELS"] = str(args.ray_wheels)
url = str(args.ray_wheels)
elif not args.check:
url = find_ray_wheels(
GLOBAL_CONFIG["RAY_REPO"],


@ -9,6 +9,6 @@
run:
timeout: 600
prepare: sleep 0
prepare: python wait_cluster.py
script: python inference.py


@ -75,17 +75,15 @@ def infer(batch):
ray.init()
while ray.cluster_resources().get("GPU", 0) != 2:
print("Waiting for GPUs {}/2".format(ray.cluster_resources().get(
"GPU", 400)))
time.sleep(5)
start_time = time.time()
print("Downloading...")
ds = ray.experimental.data.read_binary_files(
"s3://anyscale-data/small-images/", parallelism=400)
ds = ds.limit(100 * 1000)
"s3://anyscale-data/small-images/",
parallelism=1000,
ray_remote_args={"num_cpus": 0.5})
# Do a blocking map so that we can measure the download time.
ds = ds.map(lambda x: x)
end_download_time = time.time()
print("Preprocessing...")


@ -0,0 +1,9 @@
import ray
import time
ray.init()
while ray.cluster_resources().get("GPU", 0) != 2:
print("Waiting for GPUs {}/2".format(ray.cluster_resources().get(
"GPU", 400)))
time.sleep(5)