from collections import OrderedDict
import argparse
import os
import json
import ray
import time
import timeit

import torch.optim as optim
import numpy as np
import torch
import horovod.torch as hvd
from horovod.ray import RayExecutor

from ray_shuffling_data_loader.data_generation import DATA_SPEC
from ray_shuffling_data_loader.embedding_model import MyModel, annotation, \
    huber_loss
from ray.data.dataset_pipeline import DatasetPipeline
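
# This benchmark measures Ray Dataset ingestion into a Horovod/PyTorch
# training loop: it builds an (optionally windowed) shuffling DatasetPipeline
# over Parquet files, splits it across workers, and reports per-batch wait
# times. The optimizer step runs for real, but the per-batch training cost is
# mocked with a sleep (--mock-train-step-time).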

# Training settings
parser = argparse.ArgumentParser(description="Dataset ingestion Example")
parser.add_argument(
    "--batch-size",
    type=int,
    default=250000,
    metavar="N",
    help="input batch size for training (default: 250000)")
parser.add_argument(
    "--epochs",
    type=int,
    default=10,
    metavar="N",
    help="number of epochs to train (default: 10)")
parser.add_argument(
    "--debug", action="store_true", default=False, help="disables hvd")
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    metavar="S",
    help="random seed (default: 42)")
parser.add_argument(
    "--log-interval",
    type=int,
    default=10,
    metavar="N",
    help=("how many batches to wait before logging training "
          "status"))
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--mock-train-step-time", type=float, default=1.0)
parser.add_argument("--num-files", type=int, default=30)
parser.add_argument("--num-windows", type=int, default=1)
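
# Reference --num-files values for common total dataset sizes; these
# constants are not used elsewhere in this script.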
SIZE_50_G = 30  # 49.17GB
SIZE_100_G = 62  # 101.62GB
SIZE_500_G = 305  # 499.93GB
SIZE_1500_G = 915


def construct_optimizers(model):
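    """Build one Horovod-wrapped optimizer per parameter group.

    Embedding-table parameters (matched by name) are optimized with
    SparseAdam and all remaining parameters with Adam; each optimizer is
    wrapped in hvd.DistributedOptimizer so gradients are aggregated across
    workers.
    """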
    sparse_params = []
    dense_params = []
    for k, v in model.named_parameters():
        if "input.embeddings.embeddings" in k:
            sparse_params.append((k, v))
        else:
            dense_params.append((k, v))

    optimizers = []
    if len(dense_params) > 0:
        opt = optim.Adam([v for _, v in dense_params], lr=0.001)
        opt = hvd.DistributedOptimizer(opt, dense_params)
        optimizers.append(opt)
    if len(sparse_params) > 0:
        opt = optim.SparseAdam([v for _, v in sparse_params], lr=0.001)
        opt = hvd.DistributedOptimizer(opt, sparse_params)
        optimizers.append(opt)

    if hvd.rank() == 0:
        print(optimizers)

    return optimizers


def train_main(args, splits):
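    """Horovod training entrypoint executed on every worker.

    Each worker consumes its shard of the DatasetPipeline (splits[rank]),
    iterating one dataset per epoch and timing how long it waits for each
    batch. The per-batch training cost is mocked with a sleep of
    args.mock_train_step_time seconds before the optimizer step.
    """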
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(args.seed)

    if torch.cuda.is_available():
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)
    rank = hvd.rank()

    model = MyModel(annotation, use_bn=False)
    # By default, Adasum doesn't need scaling up learning rate.
    if torch.cuda.is_available():
        # Move model to GPU.
        model.cuda()

    optimizers = construct_optimizers(model)
    loss_function = huber_loss
    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    for opt in optimizers:
        hvd.broadcast_optimizer_state(opt, root_rank=0)

    def _train(epoch, train_dataset):
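        # Run a single epoch over `train_dataset`, recording how long each
        # batch takes to arrive; returns the list of per-batch wait times.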
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        # train_dataset.set_epoch(epoch)
        start_epoch = timeit.default_timer()
        last_batch_time = start_epoch
        batch_wait_times = []
        for batch_idx, (data, target) in enumerate(train_dataset):
            batch_wait_times.append(timeit.default_timer() - last_batch_time)
            if torch.cuda.is_available():
                data = [t.cuda() for t in data]
                target = target.cuda()
            for opt in optimizers:
                opt.zero_grad()
            batch = OrderedDict()
            batch["embeddings"] = OrderedDict()
            batch["one_hot"] = OrderedDict()
            for name, tensor in zip(annotation["embeddings"], data):
                batch["embeddings"][name] = tensor
            hot0, hot1 = data[-2:]
            batch["one_hot"]["hot0"] = hot0
            batch["one_hot"]["hot1"] = hot1

            batch_pred = model(batch)

            if batch_idx % args.log_interval == 0:
                print(
                    f"Processing batch {batch_idx} in epoch {epoch} on worker "
                    f"{rank}.")
            time.sleep(args.mock_train_step_time)
            loss = loss_function(batch_pred, target, delta=60)
            loss.mean().backward()
            for opt in optimizers:
                opt.step()

            last_batch_time = timeit.default_timer()
        epoch_duration = timeit.default_timer() - start_epoch
        avg_batch_wait_time = np.mean(batch_wait_times)
        std_batch_wait_time = np.std(batch_wait_times)
        max_batch_wait_time = np.max(batch_wait_times)
        min_batch_wait_time = np.min(batch_wait_times)
        print(f"\nEpoch {epoch}, worker {rank} stats over "
              f"{len(batch_wait_times)} steps: {epoch_duration:.3f}s")
        print(f"Mean batch wait time: {avg_batch_wait_time:.3f}s +- "
              f"{std_batch_wait_time:.3f}s")
        print(f"Max batch wait time: {max_batch_wait_time:.3f}s")
        print(f"Min batch wait time: {min_batch_wait_time:.3f}s")
        return batch_wait_times

    print(f"Starting training on worker {rank}.")
    batch_wait_times = []
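    # One pipeline window is consumed per epoch. The first wait time of each
    # epoch is dropped (pop(0) below), presumably so pipeline warm-up does
    # not skew the reported statistics.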
    for epoch, split_ds in enumerate(splits[rank].iter_datasets()):
        train_dataset = create_torch_iterator(split_ds, args.batch_size, rank)
        new_batch_times = _train(epoch, train_dataset)
        new_batch_times.pop(0)
        batch_wait_times.extend(new_batch_times)
    print(f"Done training on worker {rank}.")
    avg_batch_wait_time = np.mean(batch_wait_times)
    std_batch_wait_time = np.std(batch_wait_times)
    max_batch_wait_time = np.max(batch_wait_times)
    min_batch_wait_time = np.min(batch_wait_times)
    print(f"\nWorker {rank} training stats over {args.epochs} epochs:")
    print(f"Mean batch wait time: {avg_batch_wait_time:.3f}s +- "
          f"{std_batch_wait_time:.3f}s")
    print(f"Max batch wait time: {max_batch_wait_time:.3f}s")
    print(f"Min batch wait time: {min_batch_wait_time:.3f}s")


######################################################
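
# Map NumPy dtypes from DATA_SPEC to torch dtypes for Dataset.to_torch.
# Note: np.bool is a deprecated alias (removed in NumPy >= 1.24); on newer
# NumPy versions this key would need to be `bool` or np.bool_.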
numpy_to_torch_dtype = {
    np.bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128
}


def create_torch_iterator(split, batch_size, rank=None):
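    """Convert one Dataset(Pipeline) shard into a torch batch iterator.

    The last column in DATA_SPEC is treated as the label; all remaining
    columns are fed as features with dtypes taken from numpy_to_torch_dtype.
    """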
    print(f"Creating Torch shuffling dataset for worker {rank} with "
          f"{batch_size} batch size.")
    feature_columns = list(DATA_SPEC.keys())
    feature_types = [
        numpy_to_torch_dtype[dtype] for _, _, dtype in DATA_SPEC.values()
    ]
    label_column = feature_columns.pop()
    label_type = feature_types.pop()

    torch_iterator = split.to_torch(
        label_column=label_column,
        feature_columns=feature_columns,
        label_column_dtype=label_type,
        feature_column_dtypes=feature_types,
        batch_size=batch_size,
        # prefetch_blocks: int = 0,
        # drop_last: bool = False
    )
    return torch_iterator


def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
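    """Build a shuffling DatasetPipeline over `files` and split it per worker.

    With num_windows > 1 the Parquet files are read in windows so that only a
    fraction of the data is materialized at a time; otherwise the full
    dataset is read once and repeated for `epochs` epochs. Every window is
    globally shuffled before being split into `num_workers` shards.
    """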
    if num_windows > 1:
        num_rows = ray.data.read_parquet(
            files, _spread_resource_prefix="node:").count(
            )  # This should only read Parquet metadata.
        file_splits = np.array_split(files, num_windows)

        class Windower:
            def __init__(self):
                self.i = 0
                self.iterations = epochs * num_windows

            def __iter__(self):
                return self

            def __next__(self):
                if self.i >= self.iterations:
                    raise StopIteration()
                split = file_splits[self.i % num_windows]
                self.i += 1
                return lambda: ray.data.read_parquet(
                    list(split), _spread_resource_prefix="node:")

        pipe = DatasetPipeline.from_iterable(Windower())
        split_indices = [
            i * num_rows // num_windows // num_workers
            for i in range(1, num_workers)
        ]
        pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
        pipe_shards = pipe.split_at_indices(split_indices)
    else:
        ds = ray.data.read_parquet(files, _spread_resource_prefix="node:")
        pipe = ds.repeat(epochs)
        pipe = pipe.random_shuffle_each_window(_spread_resource_prefix="node:")
        pipe_shards = pipe.split(num_workers, equal=True)
    return pipe_shards


@ray.remote
def consume(split, rank=None, batch_size=None):
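    """Debug-mode task: iterate over a shard without training (--debug)."""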
    torch_iterator = create_torch_iterator(
        split, batch_size=batch_size, rank=rank)
    for i, (x, y) in enumerate(torch_iterator):
        if i % 10 == 0:
            print(i)
    return


if __name__ == "__main__":
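    # Example invocation (illustrative file name; assumes a running Ray
    # cluster reachable via address="auto", GPU workers, S3 access to the
    # benchmark bucket, and the TEST_OUTPUT_JSON environment variable set):
    #   python dataset_ingest_benchmark.py --num-workers 16 --num-files 30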
    args = parser.parse_args()
    print("Connecting to Ray cluster...")
    ray.init(address="auto")

    num = args.num_files

    files = [
        f"s3://shuffling-data-loader-benchmarks/data/r10_000_000_000-f1000"
        f"/input_data_{i}.parquet.snappy" for i in range(args.num_files)
    ]

    start = time.time()

    splits = create_dataset(
        files,
        num_workers=args.num_workers,
        epochs=args.epochs,
        num_windows=args.num_windows)

    if args.debug:
        tasks = [
            consume.options(num_gpus=1).remote(
                split, rank=idx, batch_size=args.batch_size)
            for idx, split in enumerate(splits)
        ]
        ray.get(tasks)
    else:
        print("Create Ray executor")
        settings = RayExecutor.create_settings(timeout_s=30)
        executor = RayExecutor(
            settings, num_workers=args.num_workers, use_gpu=True)
        executor.start()
        executor.run(train_main, args=[args, splits])
        executor.shutdown()

    delta = time.time() - start
    print(f"success! total time {delta}")
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"ingest_time": delta, "success": 1}))