mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[SGD] Add imagenet example CI (#8150)
This commit is contained in:
parent
518ef4c0b3
commit
c2acb7ffe2
5 changed files with 70 additions and 8 deletions
|
@ -40,6 +40,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
|||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py
|
||||
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/util/sgd/torch/examples/image_models/train.py --no-gpu --mock-data --smoke-test --ray-num-workers=2 --model mobilenetv3_small_075 data
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --memory-swap=-1 $DOCKER_SHA \
|
||||
python /ray/python/ray/util/sgd/torch/examples/train_example.py
|
||||
|
||||
|
|
|
@ -28,3 +28,4 @@ torch
|
|||
torchvision
|
||||
xgboost
|
||||
zoopt>=0.4.0
|
||||
timm
|
||||
|
|
|
@ -427,10 +427,7 @@ parser.add_argument("--local_rank", default=0, type=int)
|
|||
|
||||
# ray
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
default="auto",
|
||||
metavar="ADDR",
|
||||
help="Ray cluster address. [default=auto]")
|
||||
"--ray-address", metavar="ADDR", help="Ray cluster address.")
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--ray-num-workers",
|
||||
|
@ -438,6 +435,16 @@ parser.add_argument(
|
|||
default=1,
|
||||
metavar="N",
|
||||
help="Number of Ray replicas to use. [default=1]")
|
||||
# Testing switches: both are plain boolean flags that stay False unless
# given on the command line.
parser.add_argument(
    "--mock-data", default=False, action="store_true",
    help="Use mocked data for testing. [default=False]")
parser.add_argument(
    "--smoke-test", default=False, action="store_true",
    help="Only run one step for testing. [default=False]")
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -460,7 +467,7 @@ def parse_args():
|
|||
args.distributed = False # ray SGD handles this (DistributedSampler)
|
||||
args.device = "cuda" # ray should handle this
|
||||
|
||||
if args.no_gpu == 0 and args.prefetcher:
|
||||
if args.no_gpu and args.prefetcher:
|
||||
logging.warning("Prefetcher needs CUDA currently "
|
||||
"(might be a bug in timm). "
|
||||
"Disabling it.")
|
||||
|
|
|
@ -26,13 +26,14 @@ from ray.util.sgd import TorchTrainer
|
|||
# from ray.util.sgd.torch import TrainingOperator
|
||||
|
||||
from ray.util.sgd.torch.examples.image_models.args import parse_args
|
||||
import ray.util.sgd.torch.examples.image_models.util as util
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
args = config["args"]
|
||||
|
||||
model = create_model(
|
||||
"resnet101", # args.model,
|
||||
args.model,
|
||||
pretrained=args.pretrained,
|
||||
num_classes=args.num_classes,
|
||||
drop_rate=args.drop,
|
||||
|
@ -58,6 +59,12 @@ def data_creator(config):
|
|||
|
||||
args = config["args"]
|
||||
|
||||
train_dir = join(args.data, "train")
|
||||
val_dir = join(args.data, "val")
|
||||
|
||||
if args.mock_data:
|
||||
util.mock_data(train_dir, val_dir)
|
||||
|
||||
# todo: verbose should depend on rank
|
||||
data_config = resolve_data_config(vars(args), verbose=True)
|
||||
|
||||
|
@ -137,11 +144,14 @@ def main():
|
|||
},
|
||||
num_workers=args.ray_num_workers)
|
||||
|
||||
if args.smoke_test:
|
||||
args.epochs = 1
|
||||
|
||||
pbar = trange(args.epochs, unit="epoch")
|
||||
for i in pbar:
|
||||
trainer.train()
|
||||
trainer.train(num_steps=1 if args.smoke_test else None)
|
||||
|
||||
val_stats = trainer.validate()
|
||||
val_stats = trainer.validate(num_steps=1 if args.smoke_test else None)
|
||||
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
|
||||
|
||||
trainer.shutdown()
|
||||
|
|
40
python/ray/util/sgd/torch/examples/image_models/util.py
Normal file
40
python/ray/util/sgd/torch/examples/image_models/util.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import random
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
|
||||
def mock_data(train_dir, val_dir):
    """Populate ``train_dir`` and ``val_dir`` with a tiny fake image dataset.

    Each directory is topped up to three class sub-directories named
    ``n<8 digits>``; every newly created class gets three all-black
    500x375 RGB JPEG files named ``ILSVRC2012_val_<8 digits>.JPEG``.
    Entries already present in a directory count towards the three
    classes, so calling this again on a populated directory adds nothing.

    Args:
        train_dir: Path of the training split; created if missing.
        val_dir: Path of the validation split; created if missing.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that ``PIL.Image`` is loaded, so ``PIL.Image.fromarray``
    # would only work if another module had imported it first.
    from PIL import Image

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    max_cls_n = 99999999
    total_classes = 3
    per_cls = max_cls_n // total_classes

    max_img_n = 99999999
    total_imgs = 3
    per_img = max_img_n // total_imgs

    # Every mock image is identical, so build it once and reuse it.
    blank_image = Image.fromarray(np.zeros((375, 500, 3), dtype=np.uint8))

    def mock_class(base, n):
        """Create the n-th mock class directory under ``base`` with images."""
        # randint is inclusive on both ends; the ``- 1`` keeps the bucket of
        # class ``n`` disjoint from the bucket of class ``n + 1``.
        random_cls = random.randint(per_cls * n, per_cls * (n + 1) - 1)
        sub_dir = join(base, "n{:08d}".format(random_cls))
        os.makedirs(sub_dir, exist_ok=True)

        for i in range(total_imgs):
            # Disjoint buckets also guarantee distinct file names per class.
            random_img = random.randint(per_img * i, per_img * (i + 1) - 1)
            file_path = join(
                sub_dir, "ILSVRC2012_val_{:08d}.JPEG".format(random_img))
            blank_image.save(file_path)

    # Only create the classes that are still missing in each split.
    for split_dir in (train_dir, val_dir):
        existing_cls = len(os.listdir(split_dir))
        for i in range(existing_cls, total_classes):
            mock_class(split_dir, i)
|
Loading…
Add table
Reference in a new issue