[dataset][nightly-test] add pipelined ingestion/training nightly test
parent 565131a854
commit 7c99aae033
8 changed files with 473 additions and 6 deletions
@@ -82,6 +82,8 @@ CORE_NIGHTLY_TESTS = {
     "~/ray/release/nightly_tests/dataset/dataset_test.yaml": [
         "inference",
         "shuffle_data_loader",
+        "pipelined_training_50_gb",
+        "pipelined_ingestion_1500_gb_15_windows",
     ],
 }
 
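The two registered names must match "- name:" entries in dataset_test.yaml, which the next hunks add. As a purely hypothetical illustration of that coupling, not the actual release-harness code, the registered names could be resolved against the parsed suite file like this:

import yaml

CORE_NIGHTLY_TESTS = {
    "~/ray/release/nightly_tests/dataset/dataset_test.yaml": [
        "inference",
        "shuffle_data_loader",
        "pipelined_training_50_gb",
        "pipelined_ingestion_1500_gb_15_windows",
    ],
}

# Hypothetical check: every registered name should resolve to a test entry
# in the suite YAML, which is a list of mappings with "name" and "run" keys.
with open("release/nightly_tests/dataset/dataset_test.yaml") as f:
    suite = {t["name"]: t for t in yaml.safe_load(f)}

for path, names in CORE_NIGHTLY_TESTS.items():
    for name in names:
        print(name, "->", suite[name]["run"]["script"])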
(The following two hunks modify the dataset test suite file, presumably release/nightly_tests/dataset/dataset_test.yaml, which is referenced above.)

@@ -9,7 +9,7 @@
 
   run:
     timeout: 600
-    prepare: python wait_cluster.py
+    prepare: python wait_cluster.py 2 600
     script: python inference.py
 
 - name: shuffle_data_loader

@@ -24,3 +24,31 @@
   run:
     timeout: 1800
     script: python dataset_shuffle_data_loader.py
+
+- name: pipelined_training_50_gb
+  owner:
+    mail: "core@anyscale.com"
+    slack: "@Chen Shen"
+
+  cluster:
+    app_config: pipelined_training_app.yaml
+    compute_template: pipelined_training_compute.yaml
+
+  run:
+    timeout: 4800
+    prepare: python wait_cluster.py 15 1200
+    script: python pipelined_training.py --epochs 5
+
+- name: pipelined_ingestion_1500_gb_15_windows
+  owner:
+    mail: "core@anyscale.com"
+    slack: "@Chen Shen"
+
+  cluster:
+    app_config: pipelined_ingestion_app.yaml
+    compute_template: pipelined_ingestion_compute.yaml
+
+  run:
+    timeout: 4800
+    prepare: python wait_cluster.py 21 2400
+    script: python pipelined_training.py --epochs 2 --num-windows 15 --num-files 915 --debug
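For reference, the two positional arguments passed to wait_cluster.py in the prepare steps are the node count to wait for (head node included) and the maximum wait time in seconds, matching the argparse definition in the wait_cluster.py change at the bottom of this diff. The training test waits for 15 nodes (1 head plus 10 memory_node and 4 gpu_node workers in its compute template) with a 1200 s budget; the ingestion test waits for 21 nodes (1 head plus 16 and 4) with a 2400 s budget.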
release/nightly_tests/dataset/pipelined_ingestion_app.yaml (new file, 17 lines)
@@ -0,0 +1,17 @@
base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
env_vars: {}

python:
  pip_packages: []
  conda_packages: []

post_build_cmds:
  - pip uninstall -y numpy ray || true
  - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy
  - pip install numpy || true
  - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
  - pip install -U git+https://github.com/ray-project/ray_shuffling_data_loader.git@add-embedding-model
  - pip install ray[default]
  - pip install pyarrow
  - pip install torch torchvision
  - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install -U git+https://github.com/horovod/horovod.git
(new file, 29 lines; path not captured here, presumably release/nightly_tests/dataset/pipelined_ingestion_compute.yaml, the compute_template referenced above)
@@ -0,0 +1,29 @@
cloud_id: cld_17WvYIBBkdgLwEUNcLeRAE
region: us-west-2

max_workers: 999

aws:
  IamInstanceProfile: {"Name": "ray-autoscaler-v1"}
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        VolumeSize: 500

head_node_type:
  name: head_node
  instance_type: i3.8xlarge

worker_node_types:
  - name: memory_node
    instance_type: i3.8xlarge
    min_workers: 16
    max_workers: 16
    use_spot: false
  - name: gpu_node
    instance_type: i3.8xlarge
    min_workers: 4
    max_workers: 4
    use_spot: false
    resources:
      gpu: 4
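Note that the gpu_node group here uses i3.8xlarge instances, which have no physical GPUs, but advertises gpu: 4 as a custom Ray resource. Presumably this lets the ingestion-only test, which runs pipelined_training.py with --debug and therefore skips the Horovod trainer, schedule its GPU-annotated consumer tasks without paying for real GPU instances.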
release/nightly_tests/dataset/pipelined_training.py (new file, 307 lines)
@@ -0,0 +1,307 @@
from collections import OrderedDict
import argparse
import os
import json
import ray
import time
import timeit

import torch.optim as optim
import numpy as np
import torch
import horovod.torch as hvd
from horovod.ray import RayExecutor

from ray_shuffling_data_loader.data_generation import DATA_SPEC
from ray_shuffling_data_loader.embedding_model import MyModel, annotation, \
    huber_loss
from ray.data.dataset_pipeline import DatasetPipeline

# Training settings
parser = argparse.ArgumentParser(description="Dataset ingestion Example")
parser.add_argument(
    "--batch-size",
    type=int,
    default=250000,
    metavar="N",
    help="input batch size for training (default: 64)")
parser.add_argument(
    "--epochs",
    type=int,
    default=10,
    metavar="N",
    help="number of epochs to train (default: 10)")
parser.add_argument(
    "--debug", action="store_true", default=False, help="disables hvd")
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    metavar="S",
    help="random seed (default: 42)")
parser.add_argument(
    "--log-interval",
    type=int,
    default=10,
    metavar="N",
    help=("how many batches to wait before logging training "
          "status"))
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--mock-train-step-time", type=float, default=1.0)
parser.add_argument("--num-files", type=int, default=30)
parser.add_argument("--num-windows", type=int, default=1)

SIZE_50_G = 30  # 49.17GB
SIZE_100_G = 62  # 101.62GB
SIZE_500_G = 305  # 499.93GB
SIZE_1500_G = 915


def construct_optimizers(model):
    sparse_params = []
    dense_params = []
    for k, v in model.named_parameters():
        if "input.embeddings.embeddings" in k:
            sparse_params.append((k, v))
        else:
            dense_params.append((k, v))

    optimizers = []
    if len(dense_params) > 0:
        opt = optim.Adam([v for _, v in dense_params], lr=0.001)
        opt = hvd.DistributedOptimizer(opt, dense_params)
        optimizers.append(opt)
    if len(sparse_params) > 0:
        opt = optim.SparseAdam([v for _, v in sparse_params], lr=0.001)
        opt = hvd.DistributedOptimizer(opt, sparse_params)
        optimizers.append(opt)

    if hvd.rank() == 0:
        print(optimizers)

    return optimizers


def train_main(args, splits):
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(args.seed)

    if torch.cuda.is_available():
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)
    rank = hvd.rank()

    model = MyModel(annotation, use_bn=False)
    # By default, Adasum doesn't need scaling up learning rate.
    if torch.cuda.is_available():
        # Move model to GPU.
        model.cuda()

    optimizers = construct_optimizers(model)
    loss_function = huber_loss
    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    for opt in optimizers:
        hvd.broadcast_optimizer_state(opt, root_rank=0)

    def _train(epoch, train_dataset):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        # train_dataset.set_epoch(epoch)
        start_epoch = timeit.default_timer()
        last_batch_time = start_epoch
        batch_wait_times = []
        for batch_idx, (data, target) in enumerate(train_dataset):
            batch_wait_times.append(timeit.default_timer() - last_batch_time)
            if torch.cuda.is_available():
                data = [t.cuda() for t in data]
                target = target.cuda()
            for opt in optimizers:
                opt.zero_grad()
            batch = OrderedDict()
            batch["embeddings"] = OrderedDict()
            batch["one_hot"] = OrderedDict()
            for name, tensor in zip(annotation["embeddings"], data):
                batch["embeddings"][name] = tensor
            hot0, hot1 = data[-2:]
            batch["one_hot"]["hot0"] = hot0
            batch["one_hot"]["hot1"] = hot1

            batch_pred = model(batch)

            if batch_idx % args.log_interval == 0:
                print(
                    f"Processing batch {batch_idx} in epoch {epoch} on worker "
                    f"{rank}.")
            time.sleep(args.mock_train_step_time)
            loss = loss_function(batch_pred, target, delta=60)
            loss.mean().backward()
            for opt in optimizers:
                opt.step()

            last_batch_time = timeit.default_timer()
        epoch_duration = timeit.default_timer() - start_epoch
        avg_batch_wait_time = np.mean(batch_wait_times)
        std_batch_wait_time = np.std(batch_wait_times)
        max_batch_wait_time = np.max(batch_wait_times)
        min_batch_wait_time = np.min(batch_wait_times)
        print(f"\nEpoch {epoch}, worker {rank} stats over "
              f"{len(batch_wait_times)} steps: {epoch_duration:.3f}")
        print(f"Mean batch wait time: {avg_batch_wait_time:.3f}s +- "
              f"{std_batch_wait_time}")
        print(f"Max batch wait time: {max_batch_wait_time:.3f}s")
        print(f"Min batch wait time: {min_batch_wait_time:.3f}s")
        return batch_wait_times

    print(f"Starting training on worker {rank}.")
    batch_wait_times = []
    for epoch, split_ds in enumerate(splits[rank].iter_datasets()):
        train_dataset = create_torch_iterator(split_ds, args.batch_size, rank)
        new_batch_times = _train(epoch, train_dataset)
        new_batch_times.pop(0)
        batch_wait_times.extend(new_batch_times)
    print(f"Done training on worker {rank}.")
    avg_batch_wait_time = np.mean(batch_wait_times)
    std_batch_wait_time = np.std(batch_wait_times)
    max_batch_wait_time = np.max(batch_wait_times)
    min_batch_wait_time = np.min(batch_wait_times)
    print(f"\nWorker {rank} training stats over {args.epochs} epochs:")
    print(f"Mean batch wait time: {avg_batch_wait_time:.3f}s +- "
          f"{std_batch_wait_time}")
    print(f"Max batch wait time: {max_batch_wait_time:.3f}s")
    print(f"Min batch wait time: {min_batch_wait_time:.3f}s")


######################################################

numpy_to_torch_dtype = {
    np.bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128
}


def create_torch_iterator(split, batch_size, rank=None):
    print(f"Creating Torch shuffling dataset for worker {rank} with "
          f"{batch_size} batch size.")
    feature_columns = list(DATA_SPEC.keys())
    feature_types = [
        numpy_to_torch_dtype[dtype] for _, _, dtype in DATA_SPEC.values()
    ]
    label_column = feature_columns.pop()
    label_type = feature_types.pop()

    torch_iterator = split.to_torch(
        label_column=label_column,
        feature_columns=feature_columns,
        label_column_dtype=label_type,
        feature_column_dtypes=feature_types,
        batch_size=batch_size,
        # prefetch_blocks: int = 0,
        # drop_last: bool = False
    )
    return torch_iterator


def create_dataset(files, num_workers=4, epochs=50, num_windows=1):
    if num_windows > 1:
        num_rows = ray.data.read_parquet(
            files, _spread_resource_prefix="node:").count(
            )  # This should only read Parquet metadata.
        file_splits = np.array_split(files, num_windows)

        class Windower:
            def __init__(self):
                self.i = 0
                self.iterations = epochs * num_windows

            def __iter__(self):
                return self

            def __next__(self):
                if self.i >= self.iterations:
                    raise StopIteration()
                split = file_splits[self.i % num_windows]
                self.i += 1
                return lambda: ray.data.read_parquet(
                    list(split), _spread_resource_prefix="node:")

        pipe = DatasetPipeline.from_iterable(Windower())
        split_indices = [
            i * num_rows // num_windows // num_workers
            for i in range(1, num_workers)
        ]
        pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
        pipe_shards = pipe.split_at_indices(split_indices)
    else:
        ds = ray.data.read_parquet(files, _spread_resource_prefix="node:")
        pipe = ds.repeat(epochs)
        pipe = pipe.random_shuffle(_spread_resource_prefix="node:")
        pipe_shards = pipe.split(num_workers, equal=True)
    return pipe_shards


@ray.remote
def consume(split, rank=None, batch_size=None):
    torch_iterator = create_torch_iterator(
        split, batch_size=batch_size, rank=rank)
    for i, (x, y) in enumerate(torch_iterator):
        if i % 10 == 0:
            print(i)
    return


if __name__ == "__main__":
    args = parser.parse_args()
    import ray
    print("Connecting to Ray cluster...")
    ray.init(address="auto")

    num = args.num_files

    files = [
        f"s3://shuffling-data-loader-benchmarks/data/r10_000_000_000-f1000"
        f"/input_data_{i}.parquet.snappy" for i in range(args.num_files)
    ]

    start = time.time()

    splits = create_dataset(
        files,
        num_workers=args.num_workers,
        epochs=args.epochs,
        num_windows=args.num_windows)

    if args.debug:
        tasks = [
            consume.options(num_gpus=1).remote(
                split, rank=idx, batch_size=args.batch_size)
            for idx, split in enumerate(splits)
        ]
        ray.get(tasks)
    else:
        print("Create Ray executor")
        settings = RayExecutor.create_settings(timeout_s=30)
        executor = RayExecutor(
            settings, num_workers=args.num_workers, use_gpu=True)
        executor.start()
        executor.run(train_main, args=[args, splits])
        executor.shutdown()

    delta = time.time() - start
    print(f"success! total time {delta}")
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"ingest_time": delta, "success": 1}))
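The windowed-ingestion pattern in create_dataset() is the core of this test: only one window of files is materialized at a time, each window is shuffled, and the resulting pipeline is sharded across trainers. Below is a minimal local sketch of the same pattern on tiny in-memory data rather than S3 parquet. It assumes a Ray version where ray.data.dataset_pipeline.DatasetPipeline and DatasetPipeline.from_iterable are available (as in the nightly wheels these app configs install), and it omits the private _spread_resource_prefix locality hint used on the real cluster; the names rows, windows, and consume are illustrative stand-ins.

import numpy as np
import ray
from ray.data.dataset_pipeline import DatasetPipeline

ray.init()

epochs, num_windows, num_workers = 2, 3, 2
rows = list(range(60))                       # stand-in for the parquet rows
windows = np.array_split(rows, num_windows)  # stand-in for file_splits


class Windower:
    """Yields one dataset-creating callable per window, epochs * num_windows times."""

    def __init__(self):
        self.i = 0
        self.iterations = epochs * num_windows

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.iterations:
            raise StopIteration
        window = windows[self.i % num_windows]
        self.i += 1
        # Defer materialization: the pipeline invokes this when the window is
        # actually needed, so only ~1/num_windows of the data is live at once.
        return lambda: ray.data.from_items([int(x) for x in window])


pipe = DatasetPipeline.from_iterable(Windower())
pipe = pipe.random_shuffle()
# Same per-window, per-worker split arithmetic as create_dataset().
num_rows = len(rows)
split_indices = [
    i * num_rows // num_windows // num_workers for i in range(1, num_workers)
]
shards = pipe.split_at_indices(split_indices)


@ray.remote
def consume(shard, rank):
    # Stand-in for the trainer loop: count rows per window in this shard.
    return rank, [ds.count() for ds in shard.iter_datasets()]


# All shards must be consumed concurrently, as in the --debug path above.
print(ray.get([consume.remote(s, i) for i, s in enumerate(shards)]))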
release/nightly_tests/dataset/pipelined_training_app.yaml (new file, 17 lines)
@@ -0,0 +1,17 @@
base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
env_vars: {}

python:
  pip_packages: []
  conda_packages: []

post_build_cmds:
  - pip uninstall -y numpy ray || true
  - sudo rm -rf /home/ray/anaconda3/lib/python3.7/site-packages/numpy
  - pip install numpy || true
  - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
  - pip install -U git+https://github.com/ray-project/ray_shuffling_data_loader.git@add-embedding-model
  - pip install ray[default]
  - pip install pyarrow
  - pip install torch torchvision
  - HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install -U git+https://github.com/horovod/horovod.git
(new file, 27 lines; path not captured here, presumably release/nightly_tests/dataset/pipelined_training_compute.yaml, the compute_template referenced above)
@@ -0,0 +1,27 @@
cloud_id: cld_17WvYIBBkdgLwEUNcLeRAE
region: us-west-2

max_workers: 999

aws:
  IamInstanceProfile: {"Name": "ray-autoscaler-v1"}
  BlockDeviceMappings:
    - DeviceName: /dev/sda1
      Ebs:
        VolumeSize: 500

head_node_type:
  name: head_node
  instance_type: i3.8xlarge

worker_node_types:
  - name: memory_node
    instance_type: i3.8xlarge
    min_workers: 10
    max_workers: 10
    use_spot: false
  - name: gpu_node
    instance_type: p3.8xlarge
    min_workers: 4
    max_workers: 4
    use_spot: false
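Each p3.8xlarge gpu_node provides 4 GPUs, so the 4 GPU workers together expose 16 GPUs, which lines up with the --num-workers default of 16 Horovod workers in pipelined_training.py; the matching prepare step (python wait_cluster.py 15 1200) waits for the 1 head, 10 memory, and 4 GPU nodes defined here.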
(modified file; presumably release/nightly_tests/dataset/wait_cluster.py, the script invoked by the prepare: steps above)
@@ -1,9 +1,49 @@
-import ray
 import argparse
 import time
 
-ray.init()
+import ray
+
+ray.init(address="auto")
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "num_nodes",
+    type=int,
+    help="Wait for this number of nodes (includes head)")
+
+parser.add_argument(
+    "max_time_s", type=int, help="Wait for this number of seconds")
+
+parser.add_argument(
+    "--feedback_interval_s",
+    type=int,
+    default=10,
+    help="Wait for this number of seconds")
+
+args = parser.parse_args()
+
+curr_nodes = 0
+start = time.time()
+next_feedback = start
+max_time = start + args.max_time_s
+while not curr_nodes >= args.num_nodes:
+    now = time.time()
+
+    if now >= max_time:
+        raise RuntimeError(
+            f"Maximum wait time reached, but only "
+            f"{curr_nodes}/{args.num_nodes} nodes came up. Aborting.")
+
+    if now >= next_feedback:
+        passed = now - start
+        print(f"Waiting for more nodes to come up: "
+              f"{curr_nodes}/{args.num_nodes} "
+              f"({passed:.0f} seconds passed)")
+        next_feedback = now + args.feedback_interval_s
+
-while ray.cluster_resources().get("GPU", 0) != 2:
-    print("Waiting for GPUs {}/2".format(ray.cluster_resources().get(
-        "GPU", 400)))
     time.sleep(5)
+    curr_nodes = len(ray.nodes())
+
+passed = time.time() - start
+print(f"Cluster is up: {curr_nodes}/{args.num_nodes} nodes online after "
+      f"{passed:.0f} seconds")
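Usage note: after this change the script is invoked as python wait_cluster.py NUM_NODES MAX_TIME_S [--feedback_interval_s SECONDS], which is exactly how the prepare: steps above call it (for example, python wait_cluster.py 2 600 for the inference test).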