[RaySGD] Convert the head worker to a local model (#7746)

Why are these changes needed?

Running a worker on the head node (locally, not as a Ray actor) allows for easier handling of stateful components such as logging, and makes debugging easier.
Maksim Smolin 2020-03-27 20:19:15 -07:00 committed by GitHub
parent 875309fc48
commit 7b27ce2b23
10 changed files with 341 additions and 301 deletions
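
The core of the change: the rank-0 worker is now constructed directly in the driver process (as a LocalDistributedRunner) instead of as a Ray actor, while the remaining ranks stay actors; each training step launches the remote epochs, runs the local epoch inline, and merges the results. A minimal sketch of that flow, simplified from the new TorchTrainer._start_workers and _train_epoch; local_runner_cls, remote_runner_cls, and params are illustrative placeholders, not names from the commit:

import ray

def start_workers(params, num_workers, local_runner_cls, remote_runner_cls):
    # Rank 0 runs inside the driver process, so loggers, debuggers, and other
    # stateful tooling behave like ordinary local code.
    local_worker = local_runner_cls(**params)
    # Ranks 1..num_workers-1 are still started as Ray actors.
    remote_workers = [
        ray.remote(remote_runner_cls).remote(**params)
        for _ in range(num_workers - 1)
    ]
    return local_worker, remote_workers

def train_epoch(local_worker, remote_workers, **kwargs):
    # Kick off the remote epochs first, then run the local epoch inline and
    # merge the stats, mirroring the new _train_epoch control flow.
    remote_stats = [w.train_epoch.remote(**kwargs) for w in remote_workers]
    local_stats = local_worker.train_epoch(**kwargs)
    return [local_stats] + ray.get(remote_stats)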

View file

@ -41,6 +41,7 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
val_metrics = trainer.validate(num_steps=1)
assert val_metrics[BATCH_COUNT] == 1
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@ -62,6 +63,7 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
assert validation_loss2 <= validation_loss1, (validation_loss2,
validation_loss1)
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@ -278,6 +280,7 @@ def test_split_batch(ray_start_2_cpus):
assert trainer.config[BATCH_SIZE] == (batch_size - 1)
assert stats[NUM_SAMPLES] == 600
assert stats[BATCH_COUNT] == (data_size // 20)
trainer.shutdown()
def test_reduce_result(ray_start_2_cpus):
@ -302,6 +305,7 @@ def test_reduce_result(ray_start_2_cpus):
assert len(list_stats) == 2
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@ -388,6 +392,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
assert "mean_score" in stats
assert stats["last_score"] == 0
assert np.isnan(stats["mean_score"])
trainer.shutdown()
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
@ -479,6 +484,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
for k in model1_state_dict:
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
trainer2.shutdown()
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
@ -490,15 +496,25 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
def step_with_fail(self, **params):
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 3:
time.sleep(1) # Make sure the batch fails correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
ray.kill(self.remote_workers[0])
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
@ -512,6 +528,8 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
with pytest.raises(RuntimeError):
trainer1.train(max_retries=1)
trainer1.shutdown(force=True)
def test_resize(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
@ -522,15 +540,25 @@ def test_resize(ray_start_2_cpus): # noqa: F811
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
def step_with_fail(self, **params):
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 1:
time.sleep(1) # Make sure the batch fails correctly.
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
self.remote_workers[0].__ray_kill__()
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
@ -548,7 +576,9 @@ def test_resize(ray_start_2_cpus): # noqa: F811
try_test.remote()
trainer1.train(max_retries=1)
assert len(trainer1.workers) == 1
assert len(trainer1.remote_workers) == 1
trainer1.shutdown()
def test_fail_twice(ray_start_2_cpus): # noqa: F811
@ -560,15 +590,25 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, *args, **kwargs):
worker_stats = [
w.train_epoch.remote(*args, **kwargs) for w in self.workers
def step_with_fail(self, **params):
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 2:
time.sleep(1)
self.workers[0].__ray_kill__()
success = check_for_failure(worker_stats)
return success, worker_stats
time.sleep(1) # Make sure the batch fails correctly.
self.remote_workers[0].__ray_kill__()
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
@ -580,6 +620,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
num_workers=2)
trainer1.train(max_retries=2)
trainer1.shutdown()
if __name__ == "__main__":

View file

@ -1,7 +1,8 @@
USE_FP16 = "__use_fp16__"
NUM_STEPS = "__num_steps__"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
BATCH_LOGS_RATE_LIMIT = .2
NCCL_TIMEOUT_IN_SECONDS = 10
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}

View file

@ -1,12 +1,17 @@
from datetime import timedelta
import collections
import logging
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner
logger = logging.getLogger(__name__)
@ -45,11 +50,20 @@ class DistributedTorchRunner(TorchRunner):
logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ:
logger.debug(
"Setting NCCL_BLOCKING_WAIT for detecting node failure. "
"To override this behavior, you can set NCCL_BLOCKING_WAIT=0.")
os.environ["NCCL_BLOCKING_WAIT"] = "1"
timeout = timedelta(seconds=NCCL_TIMEOUT_IN_SECONDS)
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)
world_size=world_size,
timeout=timeout)
def _setup_training(self):
logger.debug("Creating model")
@ -84,7 +98,8 @@ class DistributedTorchRunner(TorchRunner):
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
def _initialize_dataloaders(self):
super(DistributedTorchRunner, self)._initialize_dataloaders()
@ -140,12 +155,60 @@ class DistributedTorchRunner(TorchRunner):
for model, model_state_dict in zip(self.models, model_state_dicts):
model.module.load_state_dict(model_state_dict)
# def shutdown(self):
def shutdown(self):
"""Attempts to shut down the worker."""
# super(DistributedTorchRunner, self).shutdown()
# TODO: Temporarily removing since it causes hangs on MacOSX.
# However, it seems to be harmless to remove permanently,
# since the processes are shut down anyway. This comment can be
# removed in a future release if it is still not documented in
# the stable PyTorch docs.
# dist.destroy_process_group()
dist.destroy_process_group()
super(DistributedTorchRunner, self).shutdown()
class _DummyActor:
def cuda_devices(self):
return os.environ["CUDA_VISIBLE_DEVICES"]
# This is a bit of a hack. It prevents the reassignment of CUDA_VISIBLE_DEVICES
# during a trainer resize. We won't need this if we don't shut down
# all the actors.
_dummy_actor = None
class LocalDistributedRunner(DistributedTorchRunner):
"""A wrapper for running a distributed Runner on the driver.
A dummy actor is used to reserve resources on the driver node,
as specified by `num_cpus` and `num_gpus`. If the Trainer is already
in an actor, we will ignore this resource request.
"""
def __init__(self, *args, num_cpus=None, num_gpus=None, **kwargs):
ip = ray.services.get_node_ip_address()
# Reserve a local GPU or CPU for the local worker
# TODO: we should make sure this NEVER dies.
global _dummy_actor
if not self.is_actor() and _dummy_actor is None:
_dummy_actor = ray.remote(
num_cpus=num_cpus,
num_gpus=num_gpus,
resources={"node:" + ip: 0.1})(_DummyActor).remote()
head_cuda = ray.get(_dummy_actor.cuda_devices.remote())
os.environ["CUDA_VISIBLE_DEVICES"] = head_cuda
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
def shutdown(self, cleanup=True):
super(LocalDistributedRunner, self).shutdown()
global _dummy_actor
if cleanup and _dummy_actor:
assert not self.is_actor(), "Actor shouldn't have a dummy actor."
ray.kill(_dummy_actor)
_dummy_actor = None
def is_actor(self):
actor_id = ray.worker.global_worker.actor_id
return actor_id != actor_id.nil()
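
The _DummyActor above exists only to hold on to resources: a tiny actor is pinned to the head node via Ray's node:<ip> custom resource so that one CPU/GPU stays reserved for the local worker across resizes, and its CUDA_VISIBLE_DEVICES assignment is mirrored into the driver. A standalone sketch of the same trick under those assumptions; _Placeholder and reserve_on_this_node are illustrative names, not part of the commit:

import os

import ray

class _Placeholder:
    # Holds the reserved resources and reports which GPUs Ray assigned to it.
    def cuda_devices(self):
        return os.environ.get("CUDA_VISIBLE_DEVICES", "")

def reserve_on_this_node(num_cpus=1, num_gpus=0):
    ip = ray.services.get_node_ip_address()
    # "node:<ip>" is a custom resource Ray creates for every node; requesting
    # a small fraction of it pins the actor to this machine.
    actor = ray.remote(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        resources={"node:" + ip: 0.1})(_Placeholder).remote()
    # Mirror the actor's GPU assignment into the driver so the local worker
    # trains on the GPUs that were actually reserved.
    os.environ["CUDA_VISIBLE_DEVICES"] = ray.get(actor.cuda_devices.remote())
    return actor  # keep a reference alive, or the reservation is released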

View file

@ -85,14 +85,14 @@ def train_example(num_workers=1,
backend="nccl" if use_gpu else "gloo",
scheduler_step_freq="epoch",
use_fp16=use_fp16,
tqdm=True)
use_tqdm=True)
pbar = trange(num_epochs, unit="epoch")
for i in pbar:
info = {"num_steps": 1} if test_mode else {}
info["epoch_idx"] = i
info["num_epochs"] = num_epochs
# Increase `max_retries` to turn on fault tolerance.
stats = trainer1.train(max_retries=0, info=info)
stats = trainer1.train(max_retries=1, info=info)
pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
print(trainer1.validate())

View file

@ -243,7 +243,7 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
config=config,
use_gpu=use_gpu,
backend="nccl" if use_gpu else "gloo",
tqdm=True)
use_tqdm=True)
from tabulate import tabulate
pbar = trange(5, unit="epoch")

View file

@ -3,9 +3,9 @@ cluster_name: sgd-pytorch
# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers defaults to 0.
min_workers: 0
initial_workers: 0
max_workers: 0
min_workers: 2
initial_workers: 2
max_workers: 2
target_utilization_fraction: 0.9
@ -27,11 +27,13 @@ auth:
# ssh_private_key: ...
head_node:
InstanceType: p3dn.24xlarge
InstanceType: p3.2xlarge
ImageId: ami-0698bcaf8bd9ef56d
# KeyName: ...
InstanceMarketOptions:
MarketType: spot
SpotOptions:
BlockDurationMinutes: 360
BlockDeviceMappings:
- DeviceName: /dev/sda1
Ebs:
@ -41,11 +43,13 @@ head_node:
worker_nodes:
InstanceType: p3.16xlarge
InstanceType: p3.8xlarge
ImageId: ami-0698bcaf8bd9ef56d
# KeyName: ...
InstanceMarketOptions:
MarketType: spot
SpotOptions:
BlockDurationMinutes: 360
BlockDeviceMappings:
- DeviceName: /dev/sda1
Ebs:
@ -65,7 +69,7 @@ setup_commands:
# Installing this without -U to make sure we don't replace the existing Ray installation
- pip install ray[rllib]
- pip install -U ipdb torch torchvision
- pip install -U ipdb torch torchvision tqdm
# Install Apex
- rm -rf apex || true
- git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true

View file

@ -8,7 +8,7 @@ import tempfile
import torch
import ray
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd import utils
@ -49,6 +49,7 @@ class TorchRunner:
training_operator_cls=None,
config=None,
use_fp16=False,
use_tqdm=False,
apex_args=None,
scheduler_step_freq="batch"):
self.model_creator = model_creator
@ -68,6 +69,7 @@ class TorchRunner:
self.train_loader = None
self.validation_loader = None
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.apex_args = apex_args or {}
if use_fp16 and not amp:
raise ImportError(
@ -133,9 +135,6 @@ class TorchRunner:
self.models, self.optimizers = amp.initialize(
self.models, self.optimizers, **self.apex_args)
def set_reporters(self, reporters):
return self.training_operator.set_reporters(reporters)
def setup(self):
"""Initializes the model."""
logger.debug("Creating model")
@ -163,7 +162,8 @@ class TorchRunner:
validation_loader=self.validation_loader,
world_rank=0,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
def get_node_ip(self):
"""Returns the IP address of the current node."""
@ -180,6 +180,7 @@ class TorchRunner:
self._toggle_profiling(profile=profile)
info.update({
NUM_STEPS: num_steps,
USE_FP16: self.use_fp16,
SCHEDULER_STEP: self.scheduler_step_freq
})

View file

@ -4,22 +4,18 @@ import logging
import numbers
import tempfile
import time
import asyncio
import torch
import torch.distributed as dist
import ray
from ray.exceptions import RayActorError
from ray.tune import Trainable
from ray.tune.trial import Resources
from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner)
DistributedTorchRunner, LocalDistributedRunner)
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP,
BATCH_LOGS_RATE_LIMIT)
from ray.util.sgd.torch.tqdm_handler import TqdmHandler
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
@ -149,7 +145,7 @@ class TorchTrainer:
use_gpu=False,
backend="auto",
use_fp16=False,
tqdm=False,
use_tqdm=False,
apex_args=None,
scheduler_step_freq="batch",
num_replicas=None,
@ -212,6 +208,7 @@ class TorchTrainer:
self.max_replicas = num_workers
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
if apex_args and not isinstance(apex_args, dict):
raise ValueError("apex_args needs to be a dict object.")
@ -221,10 +218,6 @@ class TorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
self.handlers = []
if tqdm:
self.handlers.append(TqdmHandler())
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
@ -256,68 +249,71 @@ class TorchTrainer:
batch_size_per_worker = self._configure_and_split_batch(num_workers)
if batch_size_per_worker:
worker_config[BATCH_SIZE] = batch_size_per_worker
self.local_worker = None
self.remote_workers = []
if num_workers == 1:
# Generate actor class
Runner = ray.remote(
num_cpus=1, num_gpus=int(self.use_gpu))(TorchRunner)
# Start workers
self.workers = [
Runner.remote(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq,
)
]
# Start local worker
self.local_worker = TorchRunner(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
# Get setup tasks in order to throw errors on failure
ray.get(self.workers[0].setup.remote())
ray.get(self.workers[0].set_reporters.remote(
[h.create_reporter() for h in self.handlers]))
self.local_worker.setup()
else:
params = dict(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
backend=self.backend,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
# Start local worker
self.local_worker = LocalDistributedRunner(
num_cpus=1, num_gpus=int(self.use_gpu), **params)
# Generate actor class
Runner = ray.remote(
RemoteRunner = ray.remote(
num_cpus=1, num_gpus=int(self.use_gpu))(DistributedTorchRunner)
# Start workers
self.workers = [
Runner.remote(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
backend=self.backend,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
for i in range(num_workers)
self.remote_workers = [
RemoteRunner.remote(**params) for i in range(num_workers - 1)
]
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
# Compute URL for initializing distributed PyTorch
ip = ray.get(self.workers[0].get_node_ip.remote())
port = ray.get(self.workers[0].find_free_port.remote())
ip = ray.services.get_node_ip_address()
port = self.local_worker.find_free_port()
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
remote_setups = [
worker.setup.remote(address, i + 1, num_workers)
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup(address, 0, num_workers)
# Get setup tasks in order to throw errors on failure
ray.get([
worker.setup.remote(address, i, len(self.workers))
for i, worker in enumerate(self.workers)
])
ray.get([
w.set_reporters.remote(
[h.create_reporter() for h in self.handlers])
for w in self.workers
])
ray.get(remote_setups)
def train(self,
num_steps=None,
@ -374,9 +370,6 @@ class TorchTrainer:
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers(checkpoint=checkpoint)
for h in self.handlers:
h.record_train_info(info, num_steps)
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
# Fault handling
@ -386,14 +379,13 @@ class TorchTrainer:
else:
self._num_failures += 1
self._resize_workers(checkpoint=checkpoint)
logger.info(
"Retrying training step with %d workers." % len(self.workers))
logger.info("Retrying training step with %d workers." %
(len(self.remote_workers) + 1))
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
if not success:
raise RuntimeError("Training run failed.")
worker_stats = ray.get(worker_stats)
if reduce_results:
return self._process_stats(worker_stats)
else:
@ -413,42 +405,30 @@ class TorchTrainer:
stats[stat_key] = worker_stats[0][stat_key]
return stats
def _train_epoch(self,
num_steps=None,
profile=False,
info=None,
batch_logs_handler=None):
worker_trains = [
w.train_epoch.remote(
num_steps=num_steps, profile=profile, info=info)
for w in self.workers
def _train_epoch(self, num_steps=None, profile=False, info=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if not self.handlers:
success = check_for_failure(worker_trains)
return success, worker_trains
unfinished = worker_trains
try:
while len(unfinished) > 0:
finished, unfinished = ray.wait(
unfinished, timeout=BATCH_LOGS_RATE_LIMIT)
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError as err:
if "gloo" in err.args[0] and "Timed out" in err.args[0]:
logger.warning(err)
return False, None
if "NCCL" in err.args[0]: # there is no specific error message
logger.warning(err)
return False, None
# throw errors on agent failure
finished = ray.get(finished)
raise err
futures = [h.update() for h in self.handlers]
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.wait(futures))
loop.close()
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return True, worker_trains
except RayActorError as exc:
logger.exception(str(exc))
return False, worker_trains
return success, None
def apply_all_workers(self, fn):
"""Run a function on all operators on the workers.
@ -460,7 +440,9 @@ class TorchTrainer:
A list of objects returned by ``fn`` on each worker.
"""
return ray.get([w.apply.remote(fn) for w in self.workers])
remote_calls = [w.apply.remote(fn) for w in self.remote_workers]
local_call = self.local_worker.apply(fn)
return [local_call] + ray.get(remote_calls)
def apply_all_operators(self, fn):
"""Run a function on all operators on the workers.
@ -473,7 +455,11 @@ class TorchTrainer:
A list of objects returned by ``fn`` on each operator.
"""
return ray.get([w.apply_operator.remote(fn) for w in self.workers])
remote_calls = [
w.apply_operator.remote(fn) for w in self.remote_workers
]
local_call = self.local_worker.apply_operator(fn)
return [local_call] + ray.get(remote_calls)
def validate(self, num_steps=None, profile=False, info=None):
"""Evaluates the model on the validation data set.
@ -491,12 +477,15 @@ class TorchTrainer:
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
worker_stats = ray.get([
w.validate.remote(num_steps=num_steps, profile=profile, info=info)
for w in self.workers
])
params = dict(num_steps=num_steps, profile=profile, info=info)
return self._process_stats(worker_stats)
remote_worker_stats = [
w.validate.remote(**params) for w in self.remote_workers
]
local_worker_stats = self.local_worker.validate(**params)
return self._process_stats([local_worker_stats] +
ray.get(remote_worker_stats))
def update_scheduler(self, metric):
"""Calls ``scheduler.step(metric)`` on all schedulers.
@ -509,7 +498,7 @@ class TorchTrainer:
def get_model(self):
"""Returns the learned model(s)."""
models = self.model_creator(self.config)
state = ray.get(self.workers[0].get_state.remote())
state = self.local_worker.get_state()
if len(state["models"]) == 1:
models.load_state_dict(state["models"][0])
else:
@ -517,6 +506,18 @@ class TorchTrainer:
model.load_state_dict(state_dict)
return models
def state_dict(self):
return self.local_worker.get_state()
def load_state_dict(self, state):
state_id = ray.put(state)
remote_calls = [
worker.set_state.remote(state_id) for worker in self.remote_workers
]
self.local_worker.set_state(state)
ray.get(remote_calls)
def save(self, checkpoint):
"""Saves the model(s) to the provided checkpoint.
@ -526,8 +527,7 @@ class TorchTrainer:
Returns:
checkpoint (str): Path to target checkpoint file.
"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, checkpoint)
torch.save(self.state_dict(), checkpoint)
return checkpoint
def restore(self, checkpoint):
@ -537,36 +537,67 @@ class TorchTrainer:
checkpoint (str): Path to target checkpoint file.
"""
state = torch.load(checkpoint)
state_id = ray.put(state)
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
self.load_state_dict(state)
def shutdown(self, force=False):
"""Shuts down workers and releases resources."""
if not force:
cleanup = [worker.shutdown.remote() for worker in self.workers]
ray.get(cleanup)
[worker.__ray_terminate__.remote() for worker in self.workers]
else:
for worker in self.workers:
logger.warning("Killing worker {}.".format(worker))
worker.__ray_kill__()
cleanup = [
worker.shutdown.remote() for worker in self.remote_workers
]
self.local_worker.shutdown()
try:
ray.get(cleanup)
[
worker.__ray_terminate__.remote()
for worker in self.remote_workers
]
except RayActorError:
logger.warning(
"Failed to shutdown gracefully, forcing a shutdown.")
self.workers = []
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
ray.kill(worker)
else:
self.local_worker.shutdown()
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
ray.kill(worker)
self.local_worker = None
self.remote_workers = []
def _reset(self):
"""Terminates models without giving up local resource reservation."""
self.local_worker.shutdown(cleanup=False)
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
ray.kill(worker)
self.local_worker = None
self.remote_workers = []
def _check_potential_remote_workers_size(self):
# ASSUME 1 GPU + 1 CPU is already reserved for the local worker
remote_resources = ray.available_resources()
max_remote_workers = self.max_replicas - 1
new_remote_workers = min(
remote_resources.get("CPU", 0), max_remote_workers)
if self.use_gpu:
new_remote_workers = min(
remote_resources.get("GPU", 0), new_remote_workers)
return new_remote_workers
def _resize_workers(self, checkpoint, max_retries=10):
# check available resources
self.shutdown(force=True)
self._reset()
assert checkpoint, "Cannot restore without checkpoint."
time.sleep(1)
for i in range(max_retries):
resources = ray.available_resources()
new_workers = min(resources.get("CPU", 0), self.max_replicas)
if self.use_gpu:
new_workers = min(resources.get("GPU", 0), new_workers)
if new_workers:
new_remote_workers = self._check_potential_remote_workers_size()
if new_remote_workers:
self._last_resize = time.time()
self._start_workers(int(new_workers))
self._start_workers(int(new_remote_workers) + 1)
self.restore(checkpoint)
return
else:
@ -578,26 +609,24 @@ class TorchTrainer:
def _should_resize(self):
"""Returns True if past cooldown and exists resources to scale up."""
worker_gap = self.max_replicas - len(self.workers)
worker_gap = self.max_replicas - 1 - len(self.remote_workers)
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
if past_cooldown and worker_gap:
resources = ray.available_resources()
potential_workers = min(resources.get("CPU", 0), self.max_replicas)
if self.use_gpu:
potential_workers = min(
resources.get("GPU", 0), potential_workers)
return potential_workers > 0
# Assume 1 resource is already reserved for local worker.
potential_remote_size = self._check_potential_remote_workers_size()
return potential_remote_size > 0
return False
class TorchTrainable(Trainable):
@classmethod
def default_resource_request(cls, config):
remote_worker_count = config["num_workers"] - 1
return Resources(
cpu=0,
gpu=0,
extra_cpu=config["num_workers"],
extra_gpu=int(config["use_gpu"]) * config["num_workers"])
cpu=1,
gpu=int(config["use_gpu"]),
extra_cpu=int(remote_worker_count),
extra_gpu=int(int(config["use_gpu"]) * remote_worker_count))
def _setup(self, config):
self._trainer = TorchTrainer(**config)

View file

@ -1,116 +0,0 @@
import asyncio
import time
from tqdm import tqdm
import ray
from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT
@ray.remote(num_cpus=0)
class _ReporterActor:
def __init__(self):
# we need the new_data field to allow sending back None as the logs
self._logs = {"new_data": False, "data": None}
self._setup = {"new_data": False, "data": None}
def _send_setup(self, data):
self._setup = {"new_data": True, "data": data}
def _send_logs(self, data):
self._logs = {"new_data": True, "data": data}
def _read_logs(self):
res = self._logs
self._logs = {"new_data": False, "data": None}
return res
def _read_setup(self):
res = self._setup
self._setup = {"new_data": False, "data": None}
return res
class TqdmReporter:
def __init__(self, actor):
self.actor = actor
self.last_packet_time = 0
def _send_setup(self, packet):
ray.get(self.actor._send_setup.remote(packet))
def _send_logs(self, packet):
cur_time = time.monotonic()
if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT:
return
self.last_packet_time = cur_time
ray.get(self.actor._send_logs.remote(packet))
def on_epoch_begin(self, info, training_op):
if training_op.world_rank != 0:
return
self.last_packet_time = 0
self._send_setup({"loader_len": len(training_op.train_loader)})
def on_batch_end(self, batch_info, metrics, training_op):
if training_op.world_rank != 0:
return
pbar_metrics = {}
if "train_loss" in metrics:
pbar_metrics["loss"] = metrics["train_loss"]
self._send_logs({
"batch_idx": batch_info["batch_idx"],
"pbar_metrics": pbar_metrics
})
class TqdmHandler:
def __init__(self):
self.batch_pbar = None
self.reporter_actor = _ReporterActor.remote()
def create_reporter(self):
return TqdmReporter(self.reporter_actor)
def handle_setup_packet(self, packet):
n = self.num_steps
if n is None:
n = packet["loader_len"]
desc = ""
if self.train_info is not None and "epoch_idx" in self.train_info:
if "num_epochs" in self.train_info:
desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1,
self.train_info["num_epochs"])
else:
desc = "{}e".format(self.train_info["epoch_idx"] + 1)
self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False)
def handle_logs_packet(self, packet):
self.batch_pbar.n = packet["batch_idx"] + 1
self.batch_pbar.set_postfix(packet["pbar_metrics"])
def record_train_info(self, info, num_steps):
self.train_info = info
self.num_steps = num_steps
async def update(self):
setup_read, logs_read = await asyncio.gather(
self.reporter_actor._read_setup.remote(),
self.reporter_actor._read_logs.remote())
if setup_read["new_data"]:
self.handle_setup_packet(setup_read["data"])
if logs_read["new_data"]:
self.handle_logs_packet(logs_read["data"])

View file

@ -1,10 +1,11 @@
import collections
from tqdm import tqdm
import torch
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
NUM_SAMPLES)
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH,
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS,
SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
amp = None
@ -55,7 +56,8 @@ class TrainingOperator:
world_rank,
criterion=None,
schedulers=None,
use_fp16=False):
use_fp16=False,
use_tqdm=False):
# You are not expected to override this method.
self._models = models # List of models
assert isinstance(models, collections.Iterable), (
@ -74,6 +76,7 @@ class TrainingOperator:
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16
self._use_tqdm = use_tqdm
self.global_step = 0
if type(self) is TrainingOperator:
@ -84,12 +87,8 @@ class TrainingOperator:
"TrainingOperator if using multi-scheduler, "
"multi-model or multi-optimizer training/validation.")
self.timers = TimerCollection()
self.reporters = []
self.setup(config)
def set_reporters(self, reporters):
self.reporters = reporters
def _set_timers(self, timers):
"""Passes in the timers from the Runner."""
self.timers = timers
@ -142,8 +141,19 @@ class TrainingOperator:
Returns:
A dict of metrics from training.
"""
for r in self.reporters:
r.on_epoch_begin(info, self)
if self.use_tqdm and self.world_rank == 0:
desc = ""
if info is not None and "epoch_idx" in info:
if "num_epochs" in info:
desc = "{}/{}e".format(info["epoch_idx"] + 1,
info["num_epochs"])
else:
desc = "{}e".format(info["epoch_idx"] + 1)
_progress_bar = tqdm(
total=info[NUM_STEPS] or len(self.train_loader),
desc=desc,
unit="batch",
leave=False)
metric_meters = AverageMeterCollection()
@ -156,8 +166,10 @@ class TrainingOperator:
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info)
for r in self.reporters:
r.on_batch_end(batch_info, metrics, self)
if self.use_tqdm and self.world_rank == 0:
_progress_bar.n = batch_idx + 1
if "train_loss" in metrics:
_progress_bar.set_postfix({"loss": metrics["train_loss"]})
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
@ -376,6 +388,11 @@ class TrainingOperator:
"""Whether the model and optimizer have been FP16 enabled."""
return self._use_fp16
@property
def use_tqdm(self):
"""Whether tqdm progress bars are enabled."""
return self._use_tqdm
class _TestingOperator(TrainingOperator):
def train_epoch(self, iterator, info):