[SGD] v2 initial checkpoint functionality (#17632)

* [SGD] initial checkpoint functionality

* remove thread implementation and merge with fetch_next_result

* Update comment

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>

* address comments

* add additional tests

* fix imports

* load most recently saved checkpoint

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Author: matthewdeng
Date: 2021-08-12 08:52:04 -07:00 (committed by GitHub)
Parent: d6eeb5dc70
Commit: 55680a1f9e
7 changed files with 487 additions and 44 deletions
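
For orientation, a minimal end-to-end sketch of the API this commit adds, pieced together from the diffs below. The torch backend, worker count, and epoch counter are only illustrative, the imports mirror the docstring examples in session.py, and a running Ray instance is assumed:

from ray.util import sgd
from ray.util.sgd.v2 import Trainer

def train_func():
    # Resume from the checkpoint passed to Trainer.run(), if any.
    checkpoint = sgd.load_checkpoint() or {"epoch": 0}
    for epoch in range(checkpoint["epoch"], 5):
        sgd.save_checkpoint(epoch=epoch)  # latest data is kept on the driver
        sgd.report(epoch=epoch)           # intermediate results, as before

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
latest = trainer.get_latest_checkpoint()    # {"epoch": 4}
trainer.run(train_func, checkpoint=latest)  # seed a later run with it
trainer.shutdown()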

View file

@ -1,9 +1,10 @@
from ray.util.sgd.v2.backends import BackendConfig, TorchConfig
from ray.util.sgd.v2.callbacks import SGDCallback
from ray.util.sgd.v2.trainer import Trainer
from ray.util.sgd.v2.session import report, world_rank
from ray.util.sgd.v2.session import load_checkpoint, save_checkpoint, report, \
world_rank
__all__ = [
"BackendConfig", "report", "SGDCallback", "TorchConfig", "Trainer",
"world_rank"
"BackendConfig", "load_checkpoint", "report", "save_checkpoint",
"SGDCallback", "TorchConfig", "Trainer", "world_rank"
]

View file

@ -3,10 +3,11 @@ from typing import Callable, TypeVar, List, Optional, Dict
import ray
from ray.exceptions import RayActorError
from ray.util.sgd.v2.worker_group import WorkerGroup
from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
from ray.ray_constants import env_integer
from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
from ray.util.sgd.v2.worker_group import WorkerGroup
T = TypeVar("T")
@ -52,6 +53,7 @@ class BackendExecutor:
self._num_gpus_per_worker = num_gpus_per_worker
self.worker_group = InactiveWorkerGroup()
self.latest_checkpoint = None
def start(self, initialization_hook: Optional[Callable[[], None]] = None):
"""Starts the worker group."""
@ -62,24 +64,30 @@ class BackendExecutor:
self.worker_group.execute(initialization_hook)
self._backend.on_start(self.worker_group, self._backend_config)
def start_training(self, train_func: Callable[[], T]) -> None:
def start_training(self,
train_func: Callable[[], T],
checkpoint: Optional[Dict] = None) -> None:
"""Executes a training function on all workers in a separate thread.
``finish_training`` should be called after this.
Args:
train_func (Callable): The training function to run on each worker.
checkpoint (Optional[Dict]): The checkpoint data that should be
loaded onto each worker and accessed by the training function
via ``sgd.load_checkpoint()``.
"""
use_detailed_autofilled_metrics = env_integer(
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
# First initialize the session.
def initialize_session(world_rank, train_func):
def initialize_session(world_rank, train_func, checkpoint):
try:
init_session(
training_func=train_func,
world_rank=world_rank,
checkpoint=checkpoint,
detailed_autofilled_metrics=use_detailed_autofilled_metrics
)
except ValueError:
@ -96,7 +104,8 @@ class BackendExecutor:
world_rank,
initialize_session,
world_rank=world_rank,
train_func=train_func))
train_func=train_func,
checkpoint=checkpoint))
ray.get(futures)
@ -107,16 +116,16 @@ class BackendExecutor:
self.worker_group.execute_async(train_async)
def fetch_next_result(self) -> Optional[List[Dict]]:
"""Fetch next results produced by ``sgd.report()`` from each worker.
def _get_next_results(self) -> Optional[List[TrainingResult]]:
"""Fetches the next ``TrainingResult`` from each worker.
Assumes ``start_training`` has already been called.
Each ``TrainingResult`` is expected to correspond to the same step from
each worker (e.g. the same call to ``sgd.report()`` or
``sgd.save_checkpoint()``).
Returns:
A list of dictionaries of values passed to ``sgd.report()`` from
each worker. Each item corresponds to an intermediate result
from a single worker. If there are no more items to fetch,
returns None.
A list of ``TrainingResult``s with the same
``TrainingResultType``, or ``None`` if there are no more results.
"""
def get_next():
@ -141,6 +150,7 @@ class BackendExecutor:
return result
# Get next result from each worker.
futures = self.worker_group.execute_async(get_next)
results = self.get_with_failure_handling(futures)
@ -150,13 +160,56 @@ class BackendExecutor:
if not all(r is None for r in results):
raise RuntimeError("Some workers returned results while "
"others didn't. Make sure that "
"`sgd.report()` is called the same number "
"of times on all workers.")
"`sgd.report()` and `sgd.checkpoint()` are "
"called the same number of times on all "
"workers.")
else:
results = None
# Return None if all results are None.
return None
first_result = results[0]
result_type = first_result.type
if any(r.type != result_type for r in results):
raise RuntimeError("Some workers returned results with "
"different types. Make sure `sgd.report()` and "
"`sgd.save_checkpoint()` are called the same "
"number of times and in the same order on each "
"worker.")
return results
def _process_checkpoint(self, results):
# Store the checkpoint data reported by the rank 0 worker.
self.latest_checkpoint = results[0].data
def fetch_next_result(self) -> Optional[List[Dict]]:
"""Fetch next results produced by ``sgd.report()`` from each worker.
Assumes ``start_training`` has already been called.
Returns:
A list of dictionaries of values passed to ``sgd.report()`` from
each worker. Each item corresponds to an intermediate result
from a single worker. If there are no more items to fetch,
returns None.
"""
while True:
results = self._get_next_results()
if results is None:
return None
first_result = results[0]
result_type = first_result.type
if result_type is TrainingResultType.REPORT:
result_data = [r.data for r in results]
return result_data
elif result_type is TrainingResultType.CHECKPOINT:
self._process_checkpoint(results)
# Iterate until next REPORT call or training has finished.
else:
raise SGDBackendError(f"Unexpected result type: "
f"{result_type}. "
f"Expected one of "
f"{[type for type in TrainingResultType]}")
def finish_training(self) -> List[T]:
"""Finish training and return final results. Propagate any exceptions.
@ -169,6 +222,19 @@ class BackendExecutor:
Each item corresponds to the return value from a single worker.
"""
def pause_reporting():
# Get the session for this worker.
try:
session = get_session()
except ValueError:
# Session is not initialized yet.
raise SGDBackendError("`finish_training` has been called "
"before `start_training`. Please call "
"`start_training` before "
"`finish_training`.")
return session.pause_reporting()
def end_training():
# Get the session for this worker.
try:
@ -190,6 +256,23 @@ class BackendExecutor:
return output
# Disable workers from enqueuing results from `sgd.report()`.
# Results will not be processed during the execution of `finish`.
# Note: Reported results may still be enqueued at this point,
# and should be handled appropriately.
futures = self.worker_group.execute_async(pause_reporting)
self.get_with_failure_handling(futures)
# Finish up processing checkpoints. Reporting has been disabled.
while True:
results = self._get_next_results()
if results is None:
break
result_type = results[0].type
# Process checkpoints and ignore other result types.
if result_type is TrainingResultType.CHECKPOINT:
self._process_checkpoint(results)
futures = self.worker_group.execute_async(end_training)
return self.get_with_failure_handling(futures)
@ -223,6 +306,9 @@ class BackendExecutor:
raise RuntimeError("Worker crashed during training. "
"Training unsuccessful.")
def get_latest_checkpoint(self) -> Optional[Dict]:
return self.latest_checkpoint
def shutdown(self):
"""Shuts down the workers in the worker group."""
try:

View file

@ -1,10 +1,12 @@
import queue
import time
from datetime import datetime
import threading
import os
import platform
import queue
import threading
import time
from datetime import datetime
from enum import Enum, auto
from typing import Callable
from typing import Optional, Dict
import ray
from ray.util.sgd.v2.constants import (
@ -13,17 +15,30 @@ from ray.util.sgd.v2.constants import (
from ray.util.sgd.v2.utils import PropagatingThread
class TrainingResultType(Enum):
REPORT = auto()
CHECKPOINT = auto()
class TrainingResult:
def __init__(self, type: TrainingResultType, data: Dict):
self.type = type
self.data = data
class Session:
"""Holds information for training on each worker."""
def __init__(self,
training_func: Callable,
world_rank: int,
checkpoint: Optional[Dict] = None,
detailed_autofilled_metrics: bool = False):
# The Thread object that is running the training function.
self.training_thread = PropagatingThread(
target=training_func, daemon=True)
self.world_rank = world_rank
self.loaded_checkpoint = checkpoint
# This lock is used to control the execution of the training thread.
self.continue_lock = threading.Semaphore(0)
@ -50,19 +65,16 @@ class Session:
self.training_started = True
self.training_thread.start()
def pause_reporting(self):
"""Ignore all future ``sgd.report()`` calls."""
self.ignore_report = True
def finish(self):
"""Finishes the training thread.
Either returns the output from training or raises any Exception from
training.
"""
# Ignore all future sgd.report calls.
self.ignore_report = True
# Release lock so that training will continue even if
# fetch_next_result is not exhausted.
self.continue_lock.release()
# Wait for training to finish.
# This will raise any errors that occur during training, including
@ -71,7 +83,7 @@ class Session:
# If training finished successfully, then return results.
return func_output
def get_next(self):
def get_next(self) -> Optional[TrainingResult]:
"""Gets next result from the queue."""
if not self.training_started:
raise RuntimeError("Please call start before calling get_next.")
@ -142,13 +154,36 @@ class Session:
kwargs = self._auto_fill_metrics(kwargs)
result = TrainingResult(TrainingResultType.REPORT, kwargs.copy())
# Add result to a thread-safe queue.
self.result_queue.put(kwargs, block=True)
self.result_queue.put(result, block=True)
# Acquire lock to stop the training thread until main thread
# triggers resume.
self.continue_lock.acquire()
def checkpoint(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread.
Also stores the checkpoint in ``self.loaded_checkpoint``.
"""
# Update session checkpoint to latest checkpoint.
self.loaded_checkpoint = kwargs
# Only store checkpoints on worker with rank 0.
if self.world_rank != 0:
kwargs = {}
result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)
# Acquire lock to stop the training thread until
# checkpoint has been processed.
self.continue_lock.acquire()
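
Both ``report()`` and ``checkpoint()`` rely on the same handshake between the training thread and the main thread. A standalone illustration of that pattern in plain queue/semaphore code (no Ray APIs), mirroring ``result_queue`` and ``continue_lock`` above:

import queue
import threading

result_queue = queue.Queue(maxsize=1)
continue_lock = threading.Semaphore(0)

def training_loop():
    for step in range(2):
        result_queue.put({"step": step}, block=True)  # hand result to main thread
        continue_lock.acquire()                       # pause until it is processed

def get_next():
    result = result_queue.get(block=True)
    continue_lock.release()                           # let the training thread resume
    return result

thread = threading.Thread(target=training_loop, daemon=True)
thread.start()
print(get_next())  # {'step': 0}
print(get_next())  # {'step': 1}
thread.join()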
_session = None
@ -226,3 +261,58 @@ def world_rank() -> int:
"""
session = get_session()
return session.world_rank
def load_checkpoint() -> Optional[Dict]:
"""Loads checkpoint data onto the worker.
.. code-block:: python
from ray.util import sgd
def train_func():
checkpoint = sgd.load_checkpoint()
for iter in range(checkpoint["epoch"], 5):
print(iter)
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func, checkpoint={"epoch": 3})
# 3
# 4
trainer.shutdown()
Returns:
The most recently saved checkpoint if ``sgd.save_checkpoint()``
has been called. Otherwise, the checkpoint that the session was
originally initialized with. ``None`` if neither exist.
"""
session = get_session()
return session.loaded_checkpoint
def save_checkpoint(**kwargs) -> None:
"""Checkpoints all keyword arguments to SGD as restorable state.
.. code-block:: python
import time
from ray.util import sgd
def train_func():
for iter in range(100):
time.sleep(1)
sgd.save_checkpoint(epoch=iter)
trainer = Trainer(backend="torch")
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Args:
**kwargs: Any key value pair to be checkpointed by SGD.
"""
session = get_session()
session.checkpoint(**kwargs)

View file

@ -145,6 +145,38 @@ def test_no_exhaust(ray_start_2_cpus):
assert output == [2, 2]
def test_checkpoint(ray_start_2_cpus):
def train():
for i in range(2):
sgd.save_checkpoint(epoch=i)
config = TestConfig()
e = BackendExecutor(config, num_workers=1)
e.start()
e.start_training(train)
e.finish_training()
latest_checkpoint = e.get_latest_checkpoint()
assert latest_checkpoint is not None
assert latest_checkpoint["epoch"] == 1
def test_mismatch_checkpoint_report(ray_start_2_cpus):
def train():
if sgd.world_rank() == 0:
sgd.save_checkpoint(epoch=0)
else:
sgd.report(iter=0)
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
e.start()
e.start_training(train)
with pytest.raises(RuntimeError):
e.finish_training()
def test_tensorflow_start(ray_start_2_cpus):
num_workers = 2
tensorflow_config = TensorflowConfig()

View file

@ -1,8 +1,9 @@
import time
import pytest
import pytest
from ray.util.sgd.v2.session import init_session, shutdown_session, \
get_session, world_rank, report
get_session, world_rank, report, save_checkpoint, TrainingResultType, \
load_checkpoint
@pytest.fixture(scope="function")
@ -47,8 +48,8 @@ def test_report():
init_session(training_func=train, world_rank=0)
session = get_session()
session.start()
assert session.get_next()["loss"] == 0
assert session.get_next()["loss"] == 1
assert session.get_next().data["loss"] == 0
assert session.get_next().data["loss"] == 1
shutdown_session()
with pytest.raises(ValueError):
@ -72,6 +73,7 @@ def test_report_fail():
def test_report_after_finish(session):
session.start()
session.pause_reporting()
session.finish()
for _ in range(2):
report(loss=1)
@ -83,6 +85,59 @@ def test_no_start(session):
session.get_next()
def test_checkpoint():
def train():
for i in range(2):
save_checkpoint(epoch=i)
def validate_zero(expected):
next = session.get_next()
assert next is not None
assert next.type == TrainingResultType.CHECKPOINT
assert next.data["epoch"] == expected
init_session(training_func=train, world_rank=0)
session = get_session()
session.start()
validate_zero(0)
validate_zero(1)
session.finish()
shutdown_session()
def validate_nonzero():
next = session.get_next()
assert next is not None
assert next.type == TrainingResultType.CHECKPOINT
assert next.data == {}
init_session(training_func=train, world_rank=1)
session = get_session()
session.start()
validate_nonzero()
validate_nonzero()
session.finish()
shutdown_session()
with pytest.raises(ValueError):
save_checkpoint(epoch=2)
def test_load_checkpoint_after_save():
def train():
for i in range(2):
save_checkpoint(epoch=i)
checkpoint = load_checkpoint()
assert checkpoint["epoch"] == i
init_session(training_func=train, world_rank=0)
session = get_session()
session.start()
for i in range(2):
session.get_next()
session.finish()
shutdown_session()
def test_locking():
"""Tests that report pauses training until fetch_next or finish."""
@ -106,6 +161,10 @@ def test_locking():
session.start()
time.sleep(3)
session.pause_reporting()
# Releases session.continue_lock to resume the training thread.
session.get_next()
with pytest.raises(KeyboardInterrupt):
session.finish()
shutdown_session()

View file

@ -73,11 +73,11 @@ def gen_new_backend_executor(special_f):
"""Returns a BackendExecutor that runs special_f on worker 0."""
class TestBackendExecutor(BackendExecutor):
def start_training(self, train_func):
def start_training(self, train_func, checkpoint):
special_execute = gen_execute_single_async_special(special_f)
with patch.object(WorkerGroup, "execute_single_async",
special_execute):
super().start_training(train_func)
super().start_training(train_func, checkpoint)
return TestBackendExecutor
@ -161,12 +161,15 @@ def test_fast_slow(ray_start_2_cpus):
def train():
for i in range(2):
sgd.save_checkpoint(epoch=i)
sgd.report(index=i)
def train_slow():
sgd.report(index=0)
time.sleep(5)
sgd.report(index=1)
for i in range(2):
sgd.save_checkpoint(epoch=i)
time.sleep(5)
sgd.report(index=i)
time.sleep(5)
new_backend_executor_cls = gen_new_backend_executor(train_slow)
callback = TestCallback()
@ -177,6 +180,8 @@ def test_fast_slow(ray_start_2_cpus):
trainer.start()
trainer.run(train, callbacks=[callback])
assert trainer.get_latest_checkpoint()["epoch"] == 1
result_list = callback.result_list
assert len(result_list) == 2
for index in range(len(result_list)):
@ -206,6 +211,116 @@ def test_mismatch_report(ray_start_2_cpus):
trainer.run(train)
def test_checkpoint(ray_start_2_cpus):
config = TestConfig()
def train_func():
assert sgd.load_checkpoint() is None
for i in range(3):
sgd.save_checkpoint(epoch=i)
return 1
trainer = Trainer(config, num_workers=2)
trainer.start()
trainer.run(train_func)
checkpoint = trainer.get_latest_checkpoint()
assert checkpoint is not None
assert checkpoint["epoch"] == 2
def train_func_checkpoint():
checkpoint = sgd.load_checkpoint()
assert checkpoint is not None
assert checkpoint["epoch"] == 2
for i in range(checkpoint["epoch"], 5):
sgd.save_checkpoint(epoch=i)
return 1
trainer.run(train_func_checkpoint, checkpoint=checkpoint)
checkpoint = trainer.get_latest_checkpoint()
assert checkpoint is not None
assert checkpoint["epoch"] == 4
def test_mismatch_checkpoint(ray_start_2_cpus):
test_config = TestConfig()
def train():
for i in range(2):
sgd.save_checkpoint(epoch=i)
def train_mismatch():
sgd.save_checkpoint(epoch=0)
new_backend_executor_cls = gen_new_backend_executor(train_mismatch)
with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
new_backend_executor_cls):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
with pytest.raises(RuntimeError):
trainer.run(train)
def test_mismatch_checkpoint_report(ray_start_2_cpus):
test_config = TestConfig()
def train():
for i in range(2):
sgd.save_checkpoint(epoch=i)
sgd.report(index=i)
def train_mismatch():
sgd.save_checkpoint(epoch=0)
sgd.report(index=0)
# skip checkpoint
sgd.report(index=1)
new_backend_executor_cls = gen_new_backend_executor(train_mismatch)
callback = TestCallback()
with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
new_backend_executor_cls):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
with pytest.raises(RuntimeError):
trainer.run(train, callbacks=[callback])
# validate checkpoint
assert trainer.get_latest_checkpoint()["epoch"] == 0
# validate callback
result_list = callback.result_list
assert len(result_list) == 1 # 1 epoch succeeded
intermediate_results = result_list[0]
assert len(intermediate_results) == 2 # both workers reported
for worker_result in intermediate_results:
assert worker_result["index"] == 0
def test_load_checkpoint(ray_start_2_cpus):
config = TestConfig()
def train_func_checkpoint():
checkpoint = sgd.load_checkpoint()
assert checkpoint is not None
assert checkpoint["epoch"] == 3
result = []
for i in range(checkpoint["epoch"], 5):
result.append(i)
return result
trainer = Trainer(config, num_workers=2)
trainer.start()
result = trainer.run(train_func_checkpoint, checkpoint={"epoch": 3})
assert result is not None
assert len(result) == 2
assert result[0] == [3, 4]
assert result[1] == [3, 4]
def test_world_rank(ray_start_2_cpus):
config = TestConfig()
@ -425,6 +540,58 @@ def test_worker_failure_2(ray_start_2_cpus):
trainer.run(train)
def test_worker_failure_checkpoint(ray_start_2_cpus):
test_config = TestConfig()
def train():
for i in range(3):
sgd.save_checkpoint(epoch=i)
sgd.report(index=i)
def train_actor_failure():
sgd.save_checkpoint(epoch=0)
sgd.report(index=0)
sgd.save_checkpoint(epoch=1)
import sys
sys.exit(0)
new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
new_backend_executor_cls):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
with pytest.raises(RuntimeError):
trainer.run(train)
assert trainer.get_latest_checkpoint()["epoch"] == 1
def test_worker_failure_checkpoint_2(ray_start_2_cpus):
test_config = TestConfig()
def train():
for i in range(3):
sgd.report(index=i)
sgd.save_checkpoint(epoch=i)
def train_actor_failure():
for i in range(3):
sgd.report(index=i)
sgd.save_checkpoint(epoch=i)
import sys
sys.exit(0)
new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
new_backend_executor_cls):
trainer = Trainer(test_config, num_workers=2)
trainer.start()
with pytest.raises(RuntimeError):
trainer.run(train)
assert trainer.get_latest_checkpoint()["epoch"] == 2
def test_worker_kill(ray_start_2_cpus):
test_config = TestConfig()

View file

@ -111,7 +111,8 @@ class Trainer:
def run(self,
train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
config: Optional[Dict[str, Any]] = None,
callbacks: Optional[List[SGDCallback]] = None) -> List[T]:
callbacks: Optional[List[SGDCallback]] = None,
checkpoint: Optional[Dict] = None) -> List[T]:
"""Runs a training function in a distributed manner.
Args:
@ -122,6 +123,9 @@ class Trainer:
callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
will be executed during training. If this is not set,
currently there are NO default Callbacks.
checkpoint (Optional[Dict]): The checkpoint data that should be
loaded onto each worker and accessed by the training function
via ``sgd.load_checkpoint()``.
Returns:
A list of results from the training function. Each value in the
@ -136,7 +140,7 @@ class Trainer:
try:
for callback in callbacks:
callback.start_training()
self._executor.start_training(train_func)
self._executor.start_training(train_func, checkpoint)
while True:
intermediate_results = self._executor.fetch_next_result()
@ -221,6 +225,10 @@ class Trainer:
"""
raise NotImplementedError
def get_latest_checkpoint(self) -> Optional[Dict]:
"""Gets the latest checkpoint for this Trainer."""
return self._executor.get_latest_checkpoint()
def shutdown(self):
"""Shuts down the training execution service."""
self._executor.shutdown()
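
Putting the Trainer-side pieces together, a hedged sketch of the resume-after-failure workflow the tests above exercise. ``config`` stands in for any backend config (e.g. ``TorchConfig``), and since reusing a Trainer after a worker crash is not covered by this commit, the sketch starts a fresh one:

from ray.util import sgd
from ray.util.sgd.v2 import Trainer

def train_func():
    start = (sgd.load_checkpoint() or {"epoch": 0})["epoch"]
    for epoch in range(start, 10):
        sgd.save_checkpoint(epoch=epoch)
        sgd.report(epoch=epoch)

trainer = Trainer(config, num_workers=2)
trainer.start()
try:
    trainer.run(train_func)
except RuntimeError:
    # A worker crashed mid-run; the driver still holds the last checkpoint
    # it received (see test_worker_failure_checkpoint above).
    latest = trainer.get_latest_checkpoint()
    trainer.shutdown()
    trainer = Trainer(config, num_workers=2)
    trainer.start()
    trainer.run(train_func, checkpoint=latest)
trainer.shutdown()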