mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RaySGD] Convert the head worker to a local model (#7746)
Why are these changes needed?

Running a worker on the head node (locally, rather than as a Ray actor) makes stateful concerns such as logging easier to handle and simplifies debugging.

parent 875309fc48
commit 7b27ce2b23
10 changed files with 341 additions and 301 deletions
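In other words, rank 0 of the training group now lives in the driver ("head") process, while the remaining ranks stay Ray actors. A minimal sketch of that layout, using a hypothetical Worker class rather than the TorchRunner/LocalDistributedRunner classes touched below:

import ray

ray.init()


class Worker:
    """One SGD worker; rank 0 runs in-process, the others as Ray actors."""

    def __init__(self, rank):
        self.rank = rank
        self.log = []  # stateful stuff (e.g. log buffers) stays reachable on rank 0

    def train_epoch(self):
        self.log.append("epoch done")  # trivially inspectable from the driver
        return {"rank": self.rank}


# Rank 0 is a plain local object: breakpoints and its state work directly.
local_worker = Worker(rank=0)

# Ranks 1..N-1 are remote actors, exactly as before this change.
RemoteWorker = ray.remote(Worker)
remote_workers = [RemoteWorker.remote(rank=i) for i in range(1, 3)]

stats = [local_worker.train_epoch()]
stats += ray.get([w.train_epoch.remote() for w in remote_workers])
print(stats)  # [{'rank': 0}, {'rank': 1}, {'rank': 2}]
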
@@ -41,6 +41,7 @@ def test_single_step(ray_start_2_cpus):  # noqa: F811
     val_metrics = trainer.validate(num_steps=1)
     assert val_metrics[BATCH_COUNT] == 1
     trainer.shutdown()
 
 
 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])

@@ -62,6 +63,7 @@ def test_train(ray_start_2_cpus, num_workers):  # noqa: F811
     assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
     assert validation_loss2 <= validation_loss1, (validation_loss2,
                                                   validation_loss1)
     trainer.shutdown()
 
 
 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])

@@ -278,6 +280,7 @@ def test_split_batch(ray_start_2_cpus):
     assert trainer.config[BATCH_SIZE] == (batch_size - 1)
     assert stats[NUM_SAMPLES] == 600
     assert stats[BATCH_COUNT] == (data_size // 20)
     trainer.shutdown()
 
 
 def test_reduce_result(ray_start_2_cpus):

@@ -302,6 +305,7 @@ def test_reduce_result(ray_start_2_cpus):
     assert len(list_stats) == 2
     assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
     assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
     trainer.shutdown()
 
 
 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])

@@ -388,6 +392,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
     assert "mean_score" in stats
     assert stats["last_score"] == 0
     assert np.isnan(stats["mean_score"])
     trainer.shutdown()
 
 
 def test_scheduler_validate(ray_start_2_cpus):  # noqa: F811

@@ -479,6 +484,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers):  # noqa: F811
     for k in model1_state_dict:
         assert torch.equal(model1_state_dict[k], model2_state_dict[k])
     trainer2.shutdown()
 
 
 def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811

@@ -490,15 +496,25 @@ def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811
         return torch.utils.data.DataLoader(
             dataset, batch_size=config.get("batch_size", 32))
 
-    def step_with_fail(self, *args, **kwargs):
-        worker_stats = [
-            w.train_epoch.remote(*args, **kwargs) for w in self.workers
+    def step_with_fail(self, **params):
+        remote_worker_stats = [
+            w.train_epoch.remote(**params) for w in self.remote_workers
         ]
+
         if self._num_failures < 3:
             time.sleep(1)  # Make the batch will fail correctly.
-            self.workers[0].__ray_kill__()
-        success = check_for_failure(worker_stats)
-        return success, worker_stats
+            ray.kill(self.remote_workers[0])
+
+        try:
+            local_worker_stats = self.local_worker.train_epoch(**params)
+        except RuntimeError:
+            return False, None
+
+        success = check_for_failure(remote_worker_stats)
+        if success:
+            return success, [local_worker_stats] + ray.get(remote_worker_stats)
+
+        return success, None
 
     with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
         trainer1 = TorchTrainer(

@@ -512,6 +528,8 @@ def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811
     with pytest.raises(RuntimeError):
         trainer1.train(max_retries=1)
 
     trainer1.shutdown(force=True)
 
 
 def test_resize(ray_start_2_cpus):  # noqa: F811
     if not dist.is_available():

@@ -522,15 +540,25 @@ def test_resize(ray_start_2_cpus):  # noqa: F811
         return torch.utils.data.DataLoader(
             dataset, batch_size=config.get("batch_size", 32))
 
-    def step_with_fail(self, *args, **kwargs):
-        worker_stats = [
-            w.train_epoch.remote(*args, **kwargs) for w in self.workers
+    def step_with_fail(self, **params):
+        remote_worker_stats = [
+            w.train_epoch.remote(**params) for w in self.remote_workers
         ]
+
         if self._num_failures < 1:
             time.sleep(1)  # Make the batch will fail correctly.
-            self.workers[0].__ray_kill__()
-        success = check_for_failure(worker_stats)
-        return success, worker_stats
+            self.remote_workers[0].__ray_kill__()
+
+        try:
+            local_worker_stats = self.local_worker.train_epoch(**params)
+        except RuntimeError:
+            return False, None
+
+        success = check_for_failure(remote_worker_stats)
+        if success:
+            return success, [local_worker_stats] + ray.get(remote_worker_stats)
+
+        return success, None
 
     with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
         trainer1 = TorchTrainer(

@@ -548,7 +576,9 @@ def test_resize(ray_start_2_cpus):  # noqa: F811
     try_test.remote()
     trainer1.train(max_retries=1)
-    assert len(trainer1.workers) == 1
+    assert len(trainer1.remote_workers) == 1
 
     trainer1.shutdown()
 
 
 def test_fail_twice(ray_start_2_cpus):  # noqa: F811

@@ -560,15 +590,25 @@ def test_fail_twice(ray_start_2_cpus):  # noqa: F811
         return torch.utils.data.DataLoader(
             dataset, batch_size=config.get("batch_size", 32))
 
-    def step_with_fail(self, *args, **kwargs):
-        worker_stats = [
-            w.train_epoch.remote(*args, **kwargs) for w in self.workers
+    def step_with_fail(self, **params):
+        remote_worker_stats = [
+            w.train_epoch.remote(**params) for w in self.remote_workers
         ]
+
         if self._num_failures < 2:
-            time.sleep(1)
-            self.workers[0].__ray_kill__()
-        success = check_for_failure(worker_stats)
-        return success, worker_stats
+            time.sleep(1)  # Make the batch will fail correctly.
+            self.remote_workers[0].__ray_kill__()
+
+        try:
+            local_worker_stats = self.local_worker.train_epoch(**params)
+        except RuntimeError:
+            return False, None
+
+        success = check_for_failure(remote_worker_stats)
+        if success:
+            return success, [local_worker_stats] + ray.get(remote_worker_stats)
+
+        return success, None
 
     with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
         trainer1 = TorchTrainer(

@@ -580,6 +620,7 @@ def test_fail_twice(ray_start_2_cpus):  # noqa: F811
         num_workers=2)
 
     trainer1.train(max_retries=2)
     trainer1.shutdown()
 
 
 if __name__ == "__main__":

@@ -1,7 +1,8 @@
 USE_FP16 = "__use_fp16__"
+NUM_STEPS = "__num_steps__"
 SCHEDULER_STEP = "scheduler_step"
 SCHEDULER_STEP_BATCH = "batch"
 SCHEDULER_STEP_EPOCH = "epoch"
-BATCH_LOGS_RATE_LIMIT = .2
+NCCL_TIMEOUT_IN_SECONDS = 10
 
 VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}

@@ -1,12 +1,17 @@
+from datetime import timedelta
 import collections
 import logging
 import os
 
 import torch
 import torch.nn as nn
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
 
 import ray
 from ray.util.sgd.torch.torch_runner import TorchRunner
 
 logger = logging.getLogger(__name__)

@@ -45,11 +50,20 @@ class DistributedTorchRunner(TorchRunner):
         logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
             url, world_rank, world_size))
         logger.debug("using {}".format(self.backend))
 
+        if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
+            logger.debug(
+                "Setting NCCL_BLOCKING_WAIT for detecting node failure. "
+                "To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
+            os.environ["NCCL_BLOCKING_WAIT"] = "1"
+
+        timeout = timedelta(seconds=NCCL_TIMEOUT_IN_SECONDS)
         dist.init_process_group(
             backend=self.backend,
             init_method=url,
             rank=world_rank,
-            world_size=world_size)
+            world_size=world_size,
+            timeout=timeout)
 
     def _setup_training(self):
         logger.debug("Creating model")

@@ -84,7 +98,8 @@ class DistributedTorchRunner(TorchRunner):
             validation_loader=self.validation_loader,
             world_rank=self.world_rank,
             schedulers=self.schedulers,
-            use_fp16=self.use_fp16)
+            use_fp16=self.use_fp16,
+            use_tqdm=self.use_tqdm)
 
     def _initialize_dataloaders(self):
         super(DistributedTorchRunner, self)._initialize_dataloaders()

@@ -140,12 +155,60 @@ class DistributedTorchRunner(TorchRunner):
         for model, model_state_dict in zip(self.models, model_state_dicts):
             model.module.load_state_dict(model_state_dict)
 
-    # def shutdown(self):
+    def shutdown(self):
         """Attempts to shut down the worker."""
-        # super(DistributedTorchRunner, self).shutdown()
-        # TODO: Temporarily removing since it causes hangs on MacOSX.
-        # However, it seems to be harmless to remove permanently
-        # since the processes are shutdown anyways. This comment can be
-        # removed in a future release if it is still not documented
-        # the stable Pytorch docs.
-        # dist.destroy_process_group()
+        dist.destroy_process_group()
+        super(DistributedTorchRunner, self).shutdown()
+
+
+class _DummyActor:
+    def cuda_devices(self):
+        return os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+# This is a bit of a hack. It prevents the reassignment of CUDA_VISIBLE_DEVICES
+# during a trainer resize. We won't need this if we don't shutdown
+# all the actors.
+_dummy_actor = None
+
+
+class LocalDistributedRunner(DistributedTorchRunner):
+    """A wrapper for running a distributed Runner on the driver.
+
+    A dummy actor is used to reserve resources on the driver node,
+    as specified by `num_cpus` and `num_gpus`. If the Trainer is already
+    in an actor, we will ignore this resource request.
+    """
+
+    def __init__(self, *args, num_cpus=None, num_gpus=None, **kwargs):
+        ip = ray.services.get_node_ip_address()
+
+        # Reserve a local GPU or CPU for the local worker
+        # TODO: we should make sure this NEVER dies.
+        global _dummy_actor
+        if not self.is_actor() and _dummy_actor is None:
+            _dummy_actor = ray.remote(
+                num_cpus=num_cpus,
+                num_gpus=num_gpus,
+                resources={"node:" + ip: 0.1})(_DummyActor).remote()
+
+            head_cuda = ray.get(_dummy_actor.cuda_devices.remote())
+            os.environ["CUDA_VISIBLE_DEVICES"] = head_cuda
+        super(LocalDistributedRunner, self).__init__(*args, **kwargs)
+
+    def shutdown(self, cleanup=True):
+        super(LocalDistributedRunner, self).shutdown()
+        global _dummy_actor
+        if cleanup and _dummy_actor:
+            assert not self.is_actor(), "Actor shouldn't have a dummy actor."
+            ray.kill(_dummy_actor)
+            _dummy_actor = None
+
+    def is_actor(self):
+        actor_id = ray.worker.global_worker.actor_id
+        return actor_id != actor_id.nil()

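The _DummyActor above exists only so that Ray's scheduler accounts for the CPU/GPU consumed by the in-process worker on the head node. A standalone sketch of that reservation trick; the _Placeholder class is hypothetical, while ray.services.get_node_ip_address() and the "node:<ip>" custom resource are used exactly as in the diff, for the Ray version this PR targets:

import os

import ray


class _Placeholder:
    """Does no work; it only holds resources on the head node."""

    def cuda_devices(self):
        # Report the GPUs Ray assigned to this reservation so the driver
        # process can adopt the same CUDA_VISIBLE_DEVICES value.
        return os.environ.get("CUDA_VISIBLE_DEVICES", "")


ray.init(num_cpus=4)
ip = ray.services.get_node_ip_address()

# Requesting a sliver of the "node:<ip>" resource pins the actor to this node.
reservation = ray.remote(
    num_cpus=1, resources={"node:" + ip: 0.1})(_Placeholder).remote()
ray.get(reservation.cuda_devices.remote())  # wait until it is scheduled

print(ray.available_resources().get("CPU"))  # one CPU fewer than before
ray.kill(reservation)  # hand the CPU back when the local worker shuts down
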
@@ -85,14 +85,14 @@ def train_example(num_workers=1,
         backend="nccl" if use_gpu else "gloo",
         scheduler_step_freq="epoch",
         use_fp16=use_fp16,
-        tqdm=True)
+        use_tqdm=True)
     pbar = trange(num_epochs, unit="epoch")
     for i in pbar:
         info = {"num_steps": 1} if test_mode else {}
         info["epoch_idx"] = i
         info["num_epochs"] = num_epochs
         # Increase `max_retries` to turn on fault tolerance.
-        stats = trainer1.train(max_retries=0, info=info)
+        stats = trainer1.train(max_retries=1, info=info)
         pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
 
     print(trainer1.validate())

@@ -243,7 +243,7 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
         config=config,
         use_gpu=use_gpu,
         backend="nccl" if use_gpu else "gloo",
-        tqdm=True)
+        use_tqdm=True)
 
     from tabulate import tabulate
     pbar = trange(5, unit="epoch")

@@ -3,9 +3,9 @@ cluster_name: sgd-pytorch
 
 # The maximum number of workers nodes to launch in addition to the head
 # node. This takes precedence over min_workers. min_workers default to 0.
-min_workers: 0
-initial_workers: 0
-max_workers: 0
+min_workers: 2
+initial_workers: 2
+max_workers: 2
 
 target_utilization_fraction: 0.9
 

@@ -27,11 +27,13 @@ auth:
 # ssh_private_key: ...
 
 head_node:
-    InstanceType: p3dn.24xlarge
+    InstanceType: p3.2xlarge
     ImageId: ami-0698bcaf8bd9ef56d
     # KeyName: ...
     InstanceMarketOptions:
         MarketType: spot
+        SpotOptions:
+            BlockDurationMinutes: 360
     BlockDeviceMappings:
         - DeviceName: /dev/sda1
           Ebs:

@@ -41,11 +43,13 @@ head_node:
 
 
 worker_nodes:
-    InstanceType: p3.16xlarge
+    InstanceType: p3.8xlarge
     ImageId: ami-0698bcaf8bd9ef56d
     # KeyName: ...
     InstanceMarketOptions:
         MarketType: spot
+        SpotOptions:
+            BlockDurationMinutes: 360
     BlockDeviceMappings:
         - DeviceName: /dev/sda1
           Ebs:

@@ -65,7 +69,7 @@ setup_commands:
 
     # Installing this without -U to make sure we don't replace the existing Ray installation
     - pip install ray[rllib]
-    - pip install -U ipdb torch torchvision
+    - pip install -U ipdb torch torchvision tqdm
     # Install Apex
     - rm -rf apex || true
     - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true

@@ -8,7 +8,7 @@ import tempfile
 import torch
 
 import ray
-from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP
+from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS
 from ray.util.sgd.torch.training_operator import TrainingOperator
 from ray.util.sgd import utils
 

@@ -49,6 +49,7 @@ class TorchRunner:
                  training_operator_cls=None,
                  config=None,
                  use_fp16=False,
+                 use_tqdm=False,
                  apex_args=None,
                  scheduler_step_freq="batch"):
         self.model_creator = model_creator

@@ -68,6 +69,7 @@ class TorchRunner:
         self.train_loader = None
         self.validation_loader = None
         self.use_fp16 = use_fp16
+        self.use_tqdm = use_tqdm
         self.apex_args = apex_args or {}
         if use_fp16 and not amp:
             raise ImportError(

@@ -133,9 +135,6 @@ class TorchRunner:
             self.models, self.optimizers = amp.initialize(
                 self.models, self.optimizers, **self.apex_args)
 
-    def set_reporters(self, reporters):
-        return self.training_operator.set_reporters(reporters)
-
     def setup(self):
         """Initializes the model."""
         logger.debug("Creating model")

@@ -163,7 +162,8 @@ class TorchRunner:
             validation_loader=self.validation_loader,
             world_rank=0,
             schedulers=self.schedulers,
-            use_fp16=self.use_fp16)
+            use_fp16=self.use_fp16,
+            use_tqdm=self.use_tqdm)
 
     def get_node_ip(self):
         """Returns the IP address of the current node."""

@@ -180,6 +180,7 @@ class TorchRunner:
         self._toggle_profiling(profile=profile)
 
         info.update({
+            NUM_STEPS: num_steps,
             USE_FP16: self.use_fp16,
             SCHEDULER_STEP: self.scheduler_step_freq
         })

@@ -4,22 +4,18 @@ import logging
 import numbers
 import tempfile
 import time
-import asyncio
 
 import torch
 import torch.distributed as dist
 
 import ray
 
 from ray.exceptions import RayActorError
 from ray.tune import Trainable
 from ray.tune.trial import Resources
 from ray.util.sgd.torch.distributed_torch_runner import (
-    DistributedTorchRunner)
+    DistributedTorchRunner, LocalDistributedRunner)
 from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
 from ray.util.sgd.torch.torch_runner import TorchRunner
-from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP,
-                                          BATCH_LOGS_RATE_LIMIT)
-from ray.util.sgd.torch.tqdm_handler import TqdmHandler
+from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
 
 logger = logging.getLogger(__name__)
 RESIZE_COOLDOWN_S = 10

@@ -149,7 +145,7 @@ class TorchTrainer:
                  use_gpu=False,
                  backend="auto",
                  use_fp16=False,
-                 tqdm=False,
+                 use_tqdm=False,
                  apex_args=None,
                  scheduler_step_freq="batch",
                  num_replicas=None,

@@ -212,6 +208,7 @@ class TorchTrainer:
         self.max_replicas = num_workers
 
         self.use_fp16 = use_fp16
+        self.use_tqdm = use_tqdm
 
         if apex_args and not isinstance(apex_args, dict):
             raise ValueError("apex_args needs to be a dict object.")

@@ -221,10 +218,6 @@ class TorchTrainer:
         self._num_failures = 0
         self._last_resize = float("-inf")
 
-        self.handlers = []
-        if tqdm:
-            self.handlers.append(TqdmHandler())
-
         _validate_scheduler_step_freq(scheduler_step_freq)
         self.scheduler_step_freq = scheduler_step_freq
 

@@ -256,68 +249,71 @@ class TorchTrainer:
         batch_size_per_worker = self._configure_and_split_batch(num_workers)
         if batch_size_per_worker:
             worker_config[BATCH_SIZE] = batch_size_per_worker
 
+        self.local_worker = None
+        self.remote_workers = []
+
         if num_workers == 1:
-            # Generate actor class
-            Runner = ray.remote(
-                num_cpus=1, num_gpus=int(self.use_gpu))(TorchRunner)
-            # Start workers
-            self.workers = [
-                Runner.remote(
-                    model_creator=self.model_creator,
-                    data_creator=self.data_creator,
-                    optimizer_creator=self.optimizer_creator,
-                    loss_creator=self.loss_creator,
-                    scheduler_creator=self.scheduler_creator,
-                    training_operator_cls=self.training_operator_cls,
-                    config=worker_config,
-                    use_fp16=self.use_fp16,
-                    apex_args=self.apex_args,
-                    scheduler_step_freq=self.scheduler_step_freq,
-                )
-            ]
+            # Start local worker
+            self.local_worker = TorchRunner(
+                model_creator=self.model_creator,
+                data_creator=self.data_creator,
+                optimizer_creator=self.optimizer_creator,
+                loss_creator=self.loss_creator,
+                scheduler_creator=self.scheduler_creator,
+                training_operator_cls=self.training_operator_cls,
+                config=worker_config,
+                use_fp16=self.use_fp16,
+                use_tqdm=self.use_tqdm,
+                apex_args=self.apex_args,
+                scheduler_step_freq=self.scheduler_step_freq)
+
             if self.initialization_hook:
                 self.apply_all_workers(self.initialization_hook)
-            # Get setup tasks in order to throw errors on failure
-            ray.get(self.workers[0].setup.remote())
-            ray.get(self.workers[0].set_reporters.remote(
-                [h.create_reporter() for h in self.handlers]))
+
+            self.local_worker.setup()
         else:
+            params = dict(
+                model_creator=self.model_creator,
+                data_creator=self.data_creator,
+                optimizer_creator=self.optimizer_creator,
+                loss_creator=self.loss_creator,
+                scheduler_creator=self.scheduler_creator,
+                backend=self.backend,
+                training_operator_cls=self.training_operator_cls,
+                config=worker_config,
+                use_fp16=self.use_fp16,
+                use_tqdm=self.use_tqdm,
+                apex_args=self.apex_args,
+                scheduler_step_freq=self.scheduler_step_freq)
+
+            # Start local worker
+            self.local_worker = LocalDistributedRunner(
+                num_cpus=1, num_gpus=int(self.use_gpu), **params)
+
             # Generate actor class
-            Runner = ray.remote(
+            RemoteRunner = ray.remote(
                 num_cpus=1, num_gpus=int(self.use_gpu))(DistributedTorchRunner)
             # Start workers
-            self.workers = [
-                Runner.remote(
-                    model_creator=self.model_creator,
-                    data_creator=self.data_creator,
-                    optimizer_creator=self.optimizer_creator,
-                    loss_creator=self.loss_creator,
-                    scheduler_creator=self.scheduler_creator,
-                    backend=self.backend,
-                    training_operator_cls=self.training_operator_cls,
-                    config=worker_config,
-                    use_fp16=self.use_fp16,
-                    apex_args=self.apex_args,
-                    scheduler_step_freq=self.scheduler_step_freq)
-                for i in range(num_workers)
+            self.remote_workers = [
+                RemoteRunner.remote(**params) for i in range(num_workers - 1)
             ]
             if self.initialization_hook:
                 self.apply_all_workers(self.initialization_hook)
 
             # Compute URL for initializing distributed PyTorch
-            ip = ray.get(self.workers[0].get_node_ip.remote())
-            port = ray.get(self.workers[0].find_free_port.remote())
+            ip = ray.services.get_node_ip_address()
+            port = self.local_worker.find_free_port()
 
             address = "tcp://{ip}:{port}".format(ip=ip, port=port)
 
+            remote_setups = [
+                worker.setup.remote(address, i + 1, num_workers)
+                for i, worker in enumerate(self.remote_workers)
+            ]
+            self.local_worker.setup(address, 0, num_workers)
             # Get setup tasks in order to throw errors on failure
-            ray.get([
-                worker.setup.remote(address, i, len(self.workers))
-                for i, worker in enumerate(self.workers)
-            ])
-            ray.get([
-                w.set_reporters.remote(
-                    [h.create_reporter() for h in self.handlers])
-                for w in self.workers
-            ])
+            ray.get(remote_setups)
 
     def train(self,
               num_steps=None,

@@ -374,9 +370,6 @@ class TorchTrainer:
             logger.info("Resize opportunity detected. Attempting to scale up.")
             self._resize_workers(checkpoint=checkpoint)
 
-        for h in self.handlers:
-            h.record_train_info(info, num_steps)
-
         success, worker_stats = self._train_epoch(
             num_steps=num_steps, profile=profile, info=info)
         # Fault handling

@@ -386,14 +379,13 @@ class TorchTrainer:
         else:
             self._num_failures += 1
             self._resize_workers(checkpoint=checkpoint)
-            logger.info(
-                "Retrying training step with %d workers." % len(self.workers))
+            logger.info("Retrying training step with %d workers." %
+                        (len(self.remote_workers) + 1))
             success, worker_stats = self._train_epoch(
                 num_steps=num_steps, profile=profile, info=info)
             if not success:
                 raise RuntimeError("Training run failed.")
 
-        worker_stats = ray.get(worker_stats)
         if reduce_results:
             return self._process_stats(worker_stats)
         else:

@@ -413,42 +405,30 @@ class TorchTrainer:
             stats[stat_key] = worker_stats[0][stat_key]
         return stats
 
-    def _train_epoch(self,
-                     num_steps=None,
-                     profile=False,
-                     info=None,
-                     batch_logs_handler=None):
-        worker_trains = [
-            w.train_epoch.remote(
-                num_steps=num_steps, profile=profile, info=info)
-            for w in self.workers
+    def _train_epoch(self, num_steps=None, profile=False, info=None):
+        params = dict(num_steps=num_steps, profile=profile, info=info)
+
+        remote_worker_stats = [
+            w.train_epoch.remote(**params) for w in self.remote_workers
         ]
 
-        if not self.handlers:
-            success = check_for_failure(worker_trains)
-            return success, worker_trains
-
-        unfinished = worker_trains
         try:
-            while len(unfinished) > 0:
-                finished, unfinished = ray.wait(
-                    unfinished, timeout=BATCH_LOGS_RATE_LIMIT)
-
-                # throw errors on agent failure
-                finished = ray.get(finished)
-
-                futures = [h.update() for h in self.handlers]
-                loop = asyncio.get_event_loop()
-                if loop.is_closed():
-                    loop = asyncio.new_event_loop()
-                    asyncio.set_event_loop(loop)
-                loop.run_until_complete(asyncio.wait(futures))
-                loop.close()
-
-            return True, worker_trains
-        except RayActorError as exc:
-            logger.exception(str(exc))
-            return False, worker_trains
+            local_worker_stats = self.local_worker.train_epoch(**params)
+        except RuntimeError as err:
+            if "gloo" in err.args[0] and "Timed out" in err.args[0]:
+                logger.warning(err)
+                return False, None
+            if "NCCL" in err.args[0]:  # there is no specific error message
+                logger.warning(err)
+                return False, None
+
+            raise err
+
+        success = check_for_failure(remote_worker_stats)
+        if success:
+            return success, [local_worker_stats] + ray.get(remote_worker_stats)
+
+        return success, None
 
     def apply_all_workers(self, fn):
         """Run a function on all operators on the workers.

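The rewritten _train_epoch above detects failures through two different channels: the in-process worker surfaces a dead peer as a RuntimeError raised by the blocking gloo/NCCL collective, while dead remote actors surface as RayActorError when their futures are fetched. A condensed sketch of that control flow, with a simplified stand-in for ray.util.sgd.utils.check_for_failure:

import ray
from ray.exceptions import RayActorError


def check_for_failure(object_refs):
    # Simplified stand-in: the real helper also logs which actor died.
    try:
        ray.get(object_refs)
        return True
    except RayActorError:
        return False


def train_epoch(local_worker, remote_workers, **params):
    # 1. Kick off the epoch on every remote actor without blocking.
    remote_stats = [w.train_epoch.remote(**params) for w in remote_workers]

    # 2. Run the same epoch in this process; a dead peer shows up here as a
    #    RuntimeError from the collective operations.
    try:
        local_stats = local_worker.train_epoch(**params)
    except RuntimeError:
        return False, None

    # 3. Only report stats if the remote futures also resolved cleanly.
    if check_for_failure(remote_stats):
        return True, [local_stats] + ray.get(remote_stats)
    return False, None

On failure, the caller (TorchTrainer.train in the hunks above) resizes the worker group from a checkpoint and retries once per allowed retry.
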
@@ -460,7 +440,9 @@ class TorchTrainer:
             A list of objects returned by ``fn`` on each worker.
 
         """
-        return ray.get([w.apply.remote(fn) for w in self.workers])
+        remote_calls = [w.apply.remote(fn) for w in self.remote_workers]
+        local_call = self.local_worker.apply(fn)
+        return [local_call] + ray.get(remote_calls)
 
     def apply_all_operators(self, fn):
         """Run a function on all operators on the workers.

@@ -473,7 +455,11 @@ class TorchTrainer:
             A list of objects returned by ``fn`` on each operator.
 
         """
-        return ray.get([w.apply_operator.remote(fn) for w in self.workers])
+        remote_calls = [
+            w.apply_operator.remote(fn) for w in self.remote_workers
+        ]
+        local_call = self.local_worker.apply_operator(fn)
+        return [local_call] + ray.get(remote_calls)
 
     def validate(self, num_steps=None, profile=False, info=None):
         """Evaluates the model on the validation data set.

@@ -491,12 +477,15 @@ class TorchTrainer:
         You can provide custom metrics by passing in a custom
         ``training_operator_cls``.
         """
-        worker_stats = ray.get([
-            w.validate.remote(num_steps=num_steps, profile=profile, info=info)
-            for w in self.workers
-        ])
+        params = dict(num_steps=num_steps, profile=profile, info=info)
 
-        return self._process_stats(worker_stats)
+        remote_worker_stats = [
+            w.validate.remote(**params) for w in self.remote_workers
+        ]
+        local_worker_stats = self.local_worker.validate(**params)
+
+        return self._process_stats([local_worker_stats] +
+                                   ray.get(remote_worker_stats))
 
     def update_scheduler(self, metric):
         """Calls ``scheduler.step(metric)`` on all schedulers.

@@ -509,7 +498,7 @@ class TorchTrainer:
     def get_model(self):
         """Returns the learned model(s)."""
         models = self.model_creator(self.config)
-        state = ray.get(self.workers[0].get_state.remote())
+        state = self.local_worker.get_state()
         if len(state["models"]) == 1:
             models.load_state_dict(state["models"][0])
         else:

@@ -517,6 +506,18 @@ class TorchTrainer:
                 model.load_state_dict(state_dict)
         return models
 
+    def state_dict(self):
+        return self.local_worker.get_state()
+
+    def load_state_dict(self, state):
+        state_id = ray.put(state)
+
+        remote_calls = [
+            worker.set_state.remote(state_id) for worker in self.remote_workers
+        ]
+        self.local_worker.set_state(state)
+        ray.get(remote_calls)
+
     def save(self, checkpoint):
         """Saves the model(s) to the provided checkpoint.
 

@@ -526,8 +527,7 @@ class TorchTrainer:
         Returns:
             checkpoint (str): Path to target checkpoint file.
         """
-        state = ray.get(self.workers[0].get_state.remote())
-        torch.save(state, checkpoint)
+        torch.save(self.state_dict(), checkpoint)
         return checkpoint
 
     def restore(self, checkpoint):

@@ -537,36 +537,67 @@ class TorchTrainer:
             checkpoint (str): Path to target checkpoint file.
         """
         state = torch.load(checkpoint)
-        state_id = ray.put(state)
-        ray.get([worker.set_state.remote(state_id) for worker in self.workers])
+        self.load_state_dict(state)
 
     def shutdown(self, force=False):
         """Shuts down workers and releases resources."""
         if not force:
-            cleanup = [worker.shutdown.remote() for worker in self.workers]
-            ray.get(cleanup)
-            [worker.__ray_terminate__.remote() for worker in self.workers]
-        else:
-            for worker in self.workers:
-                logger.warning("Killing worker {}.".format(worker))
-                worker.__ray_kill__()
+            cleanup = [
+                worker.shutdown.remote() for worker in self.remote_workers
+            ]
+            self.local_worker.shutdown()
+            try:
+                ray.get(cleanup)
+                [
+                    worker.__ray_terminate__.remote()
+                    for worker in self.remote_workers
+                ]
+            except RayActorError:
+                logger.warning(
+                    "Failed to shutdown gracefully, forcing a shutdown.")
 
-        self.workers = []
+                for worker in self.remote_workers:
+                    logger.warning("Killing worker {}.".format(worker))
+                    ray.kill(worker)
+        else:
+            self.local_worker.shutdown()
+            for worker in self.remote_workers:
+                logger.warning("Killing worker {}.".format(worker))
+                ray.kill(worker)
+
+        self.local_worker = None
+        self.remote_workers = []
+
+    def _reset(self):
+        """Terminates models without giving up local resource reservation."""
+        self.local_worker.shutdown(cleanup=False)
+        for worker in self.remote_workers:
+            logger.warning("Killing worker {}.".format(worker))
+            ray.kill(worker)
+        self.local_worker = None
+        self.remote_workers = []
+
+    def _check_potential_remote_workers_size(self):
+        # ASSUME 1 GPU + 1 CPU is already reserved for the local worker
+        remote_resources = ray.available_resources()
+        max_remote_workers = self.max_replicas - 1
+        new_remote_workers = min(
+            remote_resources.get("CPU", 0), max_remote_workers)
+        if self.use_gpu:
+            new_remote_workers = min(
+                remote_resources.get("GPU", 0), new_remote_workers)
+        return new_remote_workers
 
     def _resize_workers(self, checkpoint, max_retries=10):
         # check available resources
-        self.shutdown(force=True)
+        self._reset()
         assert checkpoint, "Cannot restore without checkpoint."
 
         time.sleep(1)
         for i in range(max_retries):
-            resources = ray.available_resources()
-            new_workers = min(resources.get("CPU", 0), self.max_replicas)
-            if self.use_gpu:
-                new_workers = min(resources.get("GPU", 0), new_workers)
-            if new_workers:
+            new_remote_workers = self._check_potential_remote_workers_size()
+            if new_remote_workers:
                 self._last_resize = time.time()
-                self._start_workers(int(new_workers))
+                self._start_workers(int(new_remote_workers) + 1)
                 self.restore(checkpoint)
                 return
             else:

@@ -578,26 +609,24 @@ class TorchTrainer:
 
     def _should_resize(self):
         """Returns True if past cooldown and exists resources to scale up."""
-        worker_gap = self.max_replicas - len(self.workers)
+        worker_gap = self.max_replicas - 1 - len(self.remote_workers)
         past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
         if past_cooldown and worker_gap:
-            resources = ray.available_resources()
-            potential_workers = min(resources.get("CPU", 0), self.max_replicas)
-            if self.use_gpu:
-                potential_workers = min(
-                    resources.get("GPU", 0), potential_workers)
-            return potential_workers > 0
+            # Assume 1 resource is already reserved for local worker.
+            potential_remote_size = self._check_potential_remote_workers_size()
+            return potential_remote_size > 0
         return False
 
 
 class TorchTrainable(Trainable):
     @classmethod
     def default_resource_request(cls, config):
+        remote_worker_count = config["num_workers"] - 1
         return Resources(
-            cpu=0,
-            gpu=0,
-            extra_cpu=config["num_workers"],
-            extra_gpu=int(config["use_gpu"]) * config["num_workers"])
+            cpu=1,
+            gpu=int(config["use_gpu"]),
+            extra_cpu=int(remote_worker_count),
+            extra_gpu=int(int(config["use_gpu"]) * remote_worker_count))
 
     def _setup(self, config):
         self._trainer = TorchTrainer(**config)

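The default_resource_request change above just shifts one CPU/GPU slot from the "extra" pool onto the Tune trainable itself, since the rank-0 worker now runs inside it. The accounting, as a tiny self-contained check using a hypothetical helper and a plain dict instead of ray.tune.trial.Resources:

def trainable_resources(num_workers, use_gpu):
    # The trainable hosts the local worker itself, so it claims 1 CPU
    # (plus 1 GPU if enabled) directly and only num_workers - 1 extra slots.
    remote_workers = num_workers - 1
    return {
        "cpu": 1,
        "gpu": int(use_gpu),
        "extra_cpu": remote_workers,
        "extra_gpu": int(use_gpu) * remote_workers,
    }


# Previously num_workers=4 with use_gpu=True requested 0 CPU/GPU for the
# trainable and 4 extra slots of each; now it is 1 on the trainable plus 3 extra.
assert trainable_resources(4, True) == {
    "cpu": 1, "gpu": 1, "extra_cpu": 3, "extra_gpu": 3}
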
@@ -1,116 +0,0 @@
-import asyncio
-import time
-
-from tqdm import tqdm
-
-import ray
-from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT
-
-
-@ray.remote(num_cpus=0)
-class _ReporterActor:
-    def __init__(self):
-        # we need the new_data field to allow sending back None as the legs
-        self._logs = {"new_data": False, "data": None}
-        self._setup = {"new_data": False, "data": None}
-
-    def _send_setup(self, data):
-        self._setup = {"new_data": True, "data": data}
-
-    def _send_logs(self, data):
-        self._logs = {"new_data": True, "data": data}
-
-    def _read_logs(self):
-        res = self._logs
-        self._logs = {"new_data": False, "data": None}
-        return res
-
-    def _read_setup(self):
-        res = self._setup
-        self._setup = {"new_data": False, "data": None}
-        return res
-
-
-class TqdmReporter:
-    def __init__(self, actor):
-        self.actor = actor
-        self.last_packet_time = 0
-
-    def _send_setup(self, packet):
-        ray.get(self.actor._send_setup.remote(packet))
-
-    def _send_logs(self, packet):
-        cur_time = time.monotonic()
-        if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT:
-            return
-        self.last_packet_time = cur_time
-        ray.get(self.actor._send_logs.remote(packet))
-
-    def on_epoch_begin(self, info, training_op):
-        if training_op.world_rank != 0:
-            return
-        self.last_packet_time = 0
-        self._send_setup({"loader_len": len(training_op.train_loader)})
-
-    def on_batch_end(self, batch_info, metrics, training_op):
-        if training_op.world_rank != 0:
-            return
-        pbar_metrics = {}
-        if "train_loss" in metrics:
-            pbar_metrics["loss"] = metrics["train_loss"]
-        self._send_logs({
-            "batch_idx": batch_info["batch_idx"],
-            "pbar_metrics": pbar_metrics
-        })
-
-
-class TqdmHandler:
-    def __init__(self):
-        self.batch_pbar = None
-        self.reporter_actor = _ReporterActor.remote()
-
-    def create_reporter(self):
-        return TqdmReporter(self.reporter_actor)
-
-    def handle_setup_packet(self, packet):
-        n = self.num_steps
-        if n is None:
-            n = packet["loader_len"]
-
-        desc = ""
-        if self.train_info is not None and "epoch_idx" in self.train_info:
-            if "num_epochs" in self.train_info:
-                desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1,
-                                       self.train_info["num_epochs"])
-            else:
-                desc = "{}e".format(self.train_info["epoch_idx"] + 1)
-
-        self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False)
-
-    def handle_logs_packet(self, packet):
-        self.batch_pbar.n = packet["batch_idx"] + 1
-        self.batch_pbar.set_postfix(packet["pbar_metrics"])
-
-    def record_train_info(self, info, num_steps):
-        self.train_info = info
-        self.num_steps = num_steps
-
-    async def update(self):
-        setup_read, logs_read = await asyncio.gather(
-            self.reporter_actor._read_setup.remote(),
-            self.reporter_actor._read_logs.remote())
-
-        if setup_read["new_data"]:
-            self.handle_setup_packet(setup_read["data"])
-        if logs_read["new_data"]:
-            self.handle_logs_packet(logs_read["data"])

@@ -1,10 +1,11 @@
 import collections
 
+from tqdm import tqdm
 import torch
 
 from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
                                 NUM_SAMPLES)
-from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH,
+from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
                                           SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
 
 amp = None

@@ -55,7 +56,8 @@ class TrainingOperator:
                  world_rank,
                  criterion=None,
                  schedulers=None,
-                 use_fp16=False):
+                 use_fp16=False,
+                 use_tqdm=False):
         # You are not expected to override this method.
         self._models = models  # List of models
         assert isinstance(models, collections.Iterable), (

@@ -74,6 +76,7 @@ class TrainingOperator:
                 type(schedulers)))
         self._config = config
         self._use_fp16 = use_fp16
+        self._use_tqdm = use_tqdm
         self.global_step = 0
 
         if type(self) is TrainingOperator:

@@ -84,12 +87,8 @@ class TrainingOperator:
                 "TrainingOperator if using multi-scheduler, "
                 "multi-model or multi-optimizer training/validation.")
         self.timers = TimerCollection()
-        self.reporters = []
         self.setup(config)
 
-    def set_reporters(self, reporters):
-        self.reporters = reporters
-
     def _set_timers(self, timers):
         """Passes in the timers from the Runner."""
         self.timers = timers

@@ -142,8 +141,19 @@ class TrainingOperator:
         Returns:
             A dict of metrics from training.
         """
-        for r in self.reporters:
-            r.on_epoch_begin(info, self)
+        if self.use_tqdm and self.world_rank == 0:
+            desc = ""
+            if info is not None and "epoch_idx" in info:
+                if "num_epochs" in info:
+                    desc = "{}/{}e".format(info["epoch_idx"] + 1,
+                                           info["num_epochs"])
+                else:
+                    desc = "{}e".format(info["epoch_idx"] + 1)
+            _progress_bar = tqdm(
+                total=info[NUM_STEPS] or len(self.train_loader),
+                desc=desc,
+                unit="batch",
+                leave=False)
 
         metric_meters = AverageMeterCollection()
 

@@ -156,8 +166,10 @@ class TrainingOperator:
             batch_info.update(info)
             metrics = self.train_batch(batch, batch_info=batch_info)
 
-            for r in self.reporters:
-                r.on_batch_end(batch_info, metrics, self)
+            if self.use_tqdm and self.world_rank == 0:
+                _progress_bar.n = batch_idx + 1
+                if "train_loss" in metrics:
+                    _progress_bar.set_postfix({"loss": metrics["train_loss"]})
 
             if self.scheduler and batch_info.get(
                     SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:

@@ -376,6 +388,11 @@ class TrainingOperator:
         """Whether the model and optimizer have been FP16 enabled."""
         return self._use_fp16
 
+    @property
+    def use_tqdm(self):
+        """Whether tqdm progress bars are enabled."""
+        return self._use_tqdm
+
 
 class _TestingOperator(TrainingOperator):
     def train_epoch(self, iterator, info):
