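"""Ray Dataset shuffling data loader benchmark.

Reads 25 Parquet files from S3 into a Ray DatasetPipeline, repeats the data
the requested number of times with a reshuffle per window, splits the
pipeline across GPU consumer tasks, and writes the total time to the JSON
file named by the TEST_OUTPUT_JSON environment variable.
"""
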
import argparse
import json
import os
import time

import numpy as np
import torch

import ray
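
# The benchmark input: 25 snappy-compressed Parquet files hosted on S3.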
PATH = [
    f"s3://shuffling-data-loader-benchmarks/data/input_data_{i}.parquet.snappy"
    for i in range(25)
]


def create_parser():
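    """Build the CLI parser for the shuffle benchmark."""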
    parser = argparse.ArgumentParser(description="Dataset shuffle")
    parser.add_argument(
        # Fall back to None (which starts a local cluster) when RAY_ADDRESS
        # is unset, instead of raising KeyError while building the parser.
        "--address", type=str, default=os.environ.get("RAY_ADDRESS"))
    parser.add_argument(
        "--batch-size",
        type=int,
        default=250000,
        metavar="N",
        help="input batch size for training (default: 250000)")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--repeat-times", type=int, default=16)
    return parser


def create_torch_iterator(split, batch_size, rank=None):
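    """Create a Torch iterator over one pipeline split.

    Feature and label names and dtypes come from the DATA_SPEC table below.
    """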
    print(f"Creating Torch shuffling dataset for worker {rank} with "
          f"{batch_size} batch size.")
    # Map NumPy dtypes to their Torch equivalents. Note: the deprecated
    # np.bool alias was removed in NumPy 1.24; np.bool_ is the scalar type.
    numpy_to_torch_dtype = {
        np.bool_: torch.bool,
        np.uint8: torch.uint8,
        np.int8: torch.int8,
        np.int16: torch.int16,
        np.int32: torch.int32,
        np.int64: torch.int64,
        np.float16: torch.float16,
        np.float32: torch.float32,
        np.float64: torch.float64,
        np.complex64: torch.complex64,
        np.complex128: torch.complex128,
    }
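    # Column spec: name -> (low, high, dtype). Only the dtype is used here;
    # the bounds presumably describe the value range of each column.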
    DATA_SPEC = {
        "embeddings_name0": (0, 2385, np.int64),
        "embeddings_name1": (0, 201, np.int64),
        "embeddings_name2": (0, 201, np.int64),
        "embeddings_name3": (0, 6, np.int64),
        "embeddings_name4": (0, 19, np.int64),
        "embeddings_name5": (0, 1441, np.int64),
        "embeddings_name6": (0, 201, np.int64),
        "embeddings_name7": (0, 22, np.int64),
        "embeddings_name8": (0, 156, np.int64),
        "embeddings_name9": (0, 1216, np.int64),
        "embeddings_name10": (0, 9216, np.int64),
        "embeddings_name11": (0, 88999, np.int64),
        "embeddings_name12": (0, 941792, np.int64),
        "embeddings_name13": (0, 9405, np.int64),
        "embeddings_name14": (0, 83332, np.int64),
        "embeddings_name15": (0, 828767, np.int64),
        "embeddings_name16": (0, 945195, np.int64),
        "one_hot0": (0, 3, np.int64),
        "one_hot1": (0, 50, np.int64),
        "labels": (0, 1, np.float64),
    }
    feature_columns = list(DATA_SPEC.keys())
    feature_types = [
        numpy_to_torch_dtype[dtype] for _, _, dtype in DATA_SPEC.values()
    ]
    # "labels" is the last entry in DATA_SPEC, so popping splits it off as
    # the label column; everything before it is a feature.
    label_column = feature_columns.pop()
    label_type = feature_types.pop()

    torch_iterator = split.to_torch(
        label_column=label_column,
        feature_columns=feature_columns,
        label_column_dtype=label_type,
        feature_column_dtypes=feature_types,
        batch_size=batch_size,
    )
    return torch_iterator


def create_dataset(filenames, repeat_times):
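    """Return a shuffling DatasetPipeline over the given Parquet files.

    The dataset is repeated `repeat_times` times and every window of the
    resulting pipeline is randomly reshuffled.
    """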
    pipeline = (
        ray.data.read_parquet(list(filenames))
        .repeat(times=repeat_times)
        .random_shuffle_each_window()
    )
    return pipeline


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    print("Connecting to Ray cluster...")
    ray.init(address=args.address)

    start = time.time()

    pipeline = create_dataset(PATH, args.repeat_times)
    # One pipeline shard per consumer task.
    splits = pipeline.split(args.num_workers)
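
    # Each consumer claims one GPU and drains its shard through a Torch
    # iterator; the sleep stands in for per-batch training work.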
    @ray.remote(num_gpus=1)
    def consume(split, rank=None, batch_size=None):
        torch_iterator = create_torch_iterator(
            split, batch_size=batch_size, rank=rank)
        for i, (x, y) in enumerate(torch_iterator):
            time.sleep(1)
            if i % 10 == 0:
                print(i)
        return
    tasks = [
        consume.remote(split, rank=idx, batch_size=args.batch_size)
        for idx, split in enumerate(splits)
    ]
    # Block until every consumer has finished its shard.
    ray.get(tasks)

    delta = time.time() - start
    print(f"success! total time {delta}")
    # The release-test harness is expected to set TEST_OUTPUT_JSON.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"shuffle_time": delta, "success": 1}))