ray/release/nightly_tests/shuffle_data_loader/benchmark.py


import argparse
import asyncio
import collections
import contextlib
import glob
import json
import os
import timeit
from typing import List

import numpy as np
import ray
from ray_shuffling_data_loader.shuffle import (
    shuffle, BatchConsumer as BatchConsumerBase)
from ray_shuffling_data_loader.stats import (
    TrialStatsCollector, ObjectStoreStatsCollector, process_stats,
    human_readable_size)
from ray_shuffling_data_loader.data_generation import generate_data

BATCHQUEUE_ACTOR_NAME = "BatchQueue"
DEFAULT_DATA_DIR = "/mnt/disk0/benchmark_scratch"
DEFAULT_STATS_DIR = "./results"
DEFAULT_UTILIZATION_SAMPLE_PERIOD = 5.0
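
# Benchmark flow: generate (or reuse) Parquet input files, then run one or
# more shuffle trials in which the from-memory shuffler feeds shuffled
# batches to one Consumer actor per trainer, with up to
# max_concurrent_epochs epochs pipelined at a time.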


@ray.remote(num_cpus=0)
class Consumer:
    """Actor that consumes shuffled batches for a single trainer rank."""

    def __init__(self,
                 rank,
                 num_epochs,
                 max_concurrent_epochs,
                 stats_collector=None):
        self._rank = rank
        self._num_epochs = num_epochs
        self._max_epochs = max_concurrent_epochs
        self._curr_epochs = collections.deque()
        self._epoch_done_evs = [asyncio.Event() for _ in range(num_epochs)]
        self._stats_collector = stats_collector

    async def new_epoch(self, epoch):
        # Throttle pipelining: if the maximum number of concurrent epochs is
        # already in flight, wait for the oldest one to finish first.
        if len(self._curr_epochs) == self._max_epochs:
            first_epoch = self._curr_epochs.popleft()
            await self._epoch_done_evs[first_epoch].wait()
        self._curr_epochs.append(epoch)
        print(f"Starting epoch {epoch} on consumer {self._rank}.")

    def consume(self, epoch, batch):
        print(f"Consuming batch on consumer {self._rank} for epoch {epoch}.")
        if self._stats_collector is not None:
            self._stats_collector.consume_batch.remote(epoch, len(batch))

    def producer_done(self, epoch):
        if self._stats_collector is not None:
            self._stats_collector.consume_done.remote(epoch)
        self._epoch_done_evs[epoch].set()
        print(f"Epoch {epoch} done on consumer {self._rank}.")

    async def wait_until_all_epochs_done(self):
        # Wait for the final epoch to be marked done.
        await self._epoch_done_evs[self._num_epochs - 1].wait()

    def ready(self):
        pass
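
    # NOTE: BatchConsumer.get_stats() below calls Consumer.get_stats(),
    # which this actor does not define. A minimal sketch of the shape it
    # would need to return (per-epoch lists; attribute names hypothetical):
    #
    # def get_stats(self):
    #     return (self._consume_times, self._time_to_consumes,
    #             self._consume_start_times, self._consume_end_times)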


class BatchConsumer(BatchConsumerBase):
    """Routes shuffled batches to one per-rank Consumer actor."""

    def __init__(self,
                 num_trainers,
                 num_epochs,
                 pg,
                 max_concurrent_epochs,
                 stats_collector=None):
        self._consumers = [
            Consumer.options(placement_group=pg).remote(
                rank, num_epochs, max_concurrent_epochs, stats_collector)
            for rank in range(num_trainers)
        ]

    def consume(self, rank: int, epoch: int, batches: List[ray.ObjectRef]):
        if batches is not None:
            for batch in batches:
                self._consumers[rank].consume.remote(epoch, batch)

    def producer_done(self, rank: int, epoch: int):
        self._consumers[rank].producer_done.remote(epoch)

    def wait_until_ready(self, epoch: int):
        # Signal the new epoch to all consumers; this blocks while any
        # consumer is still at its concurrent-epoch limit.
        ray.get(
            [consumer.new_epoch.remote(epoch) for consumer in self._consumers])

    def wait_until_all_epochs_done(self):
        ray.get([
            consumer.wait_until_all_epochs_done.remote()
            for consumer in self._consumers
        ])

    def actors_ready(self):
        ray.get([consumer.ready.remote() for consumer in self._consumers])

    def get_stats(self):
        # Aggregate per-epoch consume stats across all consumers. Note that
        # this calls Consumer.get_stats(), which the Consumer actor above
        # does not define; see the note in Consumer.
        (consume_times, time_to_consumes, consume_start_times,
         consume_end_times) = tuple(
             list(zip(*stats_)) for stats_ in zip(*ray.get(
                 [consumer.get_stats.remote()
                  for consumer in self._consumers])))
        consume_stage_durations = []
        for start_times, end_times in zip(consume_start_times,
                                          consume_end_times):
            consume_stage_durations.append(max(end_times) - min(start_times))
        return consume_times, time_to_consumes, consume_stage_durations


def run_trials(num_epochs,
               filenames,
               num_reducers,
               num_trainers,
               max_concurrent_epochs,
               utilization_sample_period,
               collect_stats=True,
               num_trials=None,
               trials_timeout=None):
    """Run shuffle trials with the from-memory shuffler.

    Runs either ``num_trials`` trials or as many trials as fit within
    ``trials_timeout`` seconds, and returns a list of per-trial
    (trial stats, object store stats) pairs.
    """
    print("Using from-memory shuffler.")
    all_stats = []
    # Spread one lightweight consumer actor per trainer across the cluster.
    pg = ray.util.placement_group(
        [{"CPU": 0.1} for _ in range(num_trainers)], strategy="SPREAD")
    ray.get(pg.ready())
    if collect_stats:
        stats_collector = TrialStatsCollector.remote(
            num_epochs, len(filenames), num_reducers, num_trainers)
        object_store_stats_collector = ObjectStoreStatsCollector(
            utilization_sample_period)
    else:
        stats_collector = None
        try:
            object_store_stats_collector = contextlib.nullcontext()
        except AttributeError:
            # Python 3.6 doesn't have contextlib.nullcontext(); an empty
            # contextlib.suppress() is an equivalent no-op context manager.
            object_store_stats_collector = contextlib.suppress()
    batch_consumer = BatchConsumer(num_trainers, num_epochs, pg,
                                   max_concurrent_epochs, stats_collector)
    # Wait until the batch consumer actors have been created.
    batch_consumer.actors_ready()
    if num_trials is not None:
        for trial in range(num_trials):
            print(f"Starting trial {trial}.")
            with object_store_stats_collector:
                duration = shuffle(filenames, batch_consumer, num_epochs,
                                   num_reducers, num_trainers,
                                   stats_collector)
            print(f"Trial {trial} done after {duration} seconds.")
            if collect_stats:
                stats = ray.get(stats_collector.get_stats.remote())
                store_stats = object_store_stats_collector.get_stats()
            else:
                stats = duration
                store_stats = None
            all_stats.append((stats, store_stats))
    elif trials_timeout is not None:
        start = timeit.default_timer()
        trial = 0
        while timeit.default_timer() - start < trials_timeout:
            print(f"Starting trial {trial}.")
            with object_store_stats_collector:
                duration = shuffle(filenames, batch_consumer, num_epochs,
                                   num_reducers, num_trainers,
                                   stats_collector)
            print(f"Trial {trial} done after {duration} seconds.")
            if collect_stats:
                stats = ray.get(stats_collector.get_stats.remote())
                store_stats = object_store_stats_collector.get_stats()
            else:
                stats = duration
                store_stats = None
            all_stats.append((stats, store_stats))
            trial += 1
    else:
        raise ValueError(
            "One of num_trials and trials_timeout must be specified.")
    return all_stats


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Shuffling data loader")
    parser.add_argument("--num-rows", type=int, default=4 * (10**11))
    parser.add_argument("--num-files", type=int, default=100)
    parser.add_argument("--max-row-group-skew", type=float, default=0.0)
    parser.add_argument("--num-row-groups-per-file", type=int, default=1)
    parser.add_argument("--num-reducers", type=int, default=5)
    parser.add_argument("--num-trainers", type=int, default=5)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--max-concurrent-epochs", type=int, default=None)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--num-trials", type=int, default=None)
    parser.add_argument("--trials-timeout", type=int, default=None)
    parser.add_argument(
        "--utilization-sample-period",
        type=float,
        default=DEFAULT_UTILIZATION_SAMPLE_PERIOD)
    parser.add_argument("--cluster", action="store_true")
    parser.add_argument("--data-dir", type=str, default=DEFAULT_DATA_DIR)
    parser.add_argument("--stats-dir", type=str, default=DEFAULT_STATS_DIR)
    parser.add_argument("--clear-old-data", action="store_true")
    parser.add_argument("--use-old-data", action="store_true")
    parser.add_argument("--no-stats", action="store_true")
    parser.add_argument("--no-epoch-stats", action="store_true")
    parser.add_argument("--no-consumer-stats", action="store_true")
    parser.add_argument("--overwrite-stats", action="store_true")
    parser.add_argument("--unique-stats", action="store_true")
    args = parser.parse_args()
    if args.num_row_groups_per_file < 1:
        raise ValueError("Must have at least one row group per file.")
    num_trials = args.num_trials
    trials_timeout = args.trials_timeout
    if num_trials is not None and trials_timeout is not None:
        raise ValueError(
            "Only one of --num-trials and --trials-timeout should be "
            "specified.")
    if num_trials is None and trials_timeout is None:
        # Default to a fixed number of trials.
        num_trials = 3
    if args.clear_old_data and args.use_old_data:
        raise ValueError(
            "Only one of --clear-old-data and --use-old-data should be "
            "specified.")
    data_dir = args.data_dir
    if args.clear_old_data:
        print(f"Clearing old data from {data_dir}.")
        files = glob.glob(os.path.join(data_dir, "*.parquet.snappy"))
        for f in files:
            os.remove(f)
    if args.cluster:
        print("Connecting to an existing Ray cluster.")
        ray.init(address="auto")
    else:
        print("Starting a new local Ray cluster.")
        # Registers a custom resource named "resources" with capacity 100.
        ray.init(resources={"resources": 100})
    num_rows = args.num_rows
    num_row_groups_per_file = args.num_row_groups_per_file
    num_files = args.num_files
    max_row_group_skew = args.max_row_group_skew
    if not args.use_old_data:
        print(f"Generating {num_rows} rows over {num_files} files, with "
              f"{num_row_groups_per_file} row groups per file and at most "
              f"{100 * max_row_group_skew:.1f}% row group skew.")
        filenames, num_bytes = generate_data(num_rows, num_files,
                                             num_row_groups_per_file,
                                             max_row_group_skew, data_dir)
        print(f"Generated {len(filenames)} files containing {num_rows} rows "
              f"with {num_row_groups_per_file} row groups per file, "
              f"totalling {human_readable_size(num_bytes)}.")
    else:
        filenames = [
            os.path.join(data_dir, f"input_data_{file_index}.parquet.snappy")
            for file_index in range(num_files)
        ]
        print("Not generating input data, using existing data instead.")
    num_reducers = args.num_reducers
    num_trainers = args.num_trainers
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    max_concurrent_epochs = args.max_concurrent_epochs
    if max_concurrent_epochs is None or max_concurrent_epochs > num_epochs:
        # Without an explicit cap, allow all epochs to be pipelined at once.
        max_concurrent_epochs = num_epochs
    assert max_concurrent_epochs > 0
    utilization_sample_period = args.utilization_sample_period
    # TODO(Clark): Add warmup trials.
    print("\nRunning real trials.")
    if num_trials is not None:
        print(f"Running {num_trials} shuffle trials with {num_epochs} "
              f"epochs, {num_reducers} reducers, {num_trainers} trainers, "
              f"and a batch size of {batch_size} over {num_rows} rows.")
    else:
        print(f"Running {trials_timeout} seconds of shuffle trials with "
              f"{num_epochs} epochs, {num_reducers} reducers, {num_trainers} "
              f"trainers, and a batch size of {batch_size} over {num_rows} "
              "rows.")
    print(f"Shuffling will be pipelined with at most "
          f"{max_concurrent_epochs} concurrent epochs.")
    collect_stats = not args.no_stats
    all_stats = run_trials(num_epochs, filenames, num_reducers, num_trainers,
                           max_concurrent_epochs, utilization_sample_period,
                           collect_stats, num_trials, trials_timeout)
    if collect_stats:
        process_stats(all_stats, args.overwrite_stats, args.stats_dir,
                      args.no_epoch_stats, args.no_consumer_stats,
                      args.unique_stats, num_rows, num_files,
                      num_row_groups_per_file, batch_size, num_reducers,
                      num_trainers, num_epochs, max_concurrent_epochs)
    else:
        print("Shuffle trials done, no detailed stats collected.")
        times, _ = zip(*all_stats)
        mean = np.mean(times)
        std = np.std(times)
        throughput_std = np.std(
            [num_epochs * num_rows / time for time in times])
        batch_throughput_std = np.std(
            [(num_epochs * num_rows / batch_size) / time for time in times])
        print(f"\nMean over {len(times)} trials: {mean:.3f}s +- {std:.3f}s")
        print(f"Mean throughput over {len(times)} trials: "
              f"{num_epochs * num_rows / mean:.2f} rows/s +- "
              f"{throughput_std:.2f}")
        print(f"Mean batch throughput over {len(times)} trials: "
              f"{(num_epochs * num_rows / batch_size) / mean:.2f} batches/s "
              f"+- {batch_throughput_std:.2f}")
    # Signal success to the release-test harness.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"success": 1}))
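
# Example invocations (flag values are illustrative, not recommendations):
#
#   # Against an existing cluster, 3 trials with freshly generated data:
#   TEST_OUTPUT_JSON=/tmp/out.json python benchmark.py --cluster \
#       --num-rows 100000000 --num-files 100 --num-reducers 5 \
#       --num-trainers 5 --num-epochs 10 --num-trials 3
#
#   # Locally, reusing previously generated files, running for 10 minutes:
#   TEST_OUTPUT_JSON=/tmp/out.json python benchmark.py --use-old-data \
#       --data-dir ./data --trials-timeout 600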