Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[SGD] v2 initial checkpoint functionality (#17632)
* [SGD] initial checkpoint functionality
* remove thread implementation and merge with fetch_next_result
* Update comment
* address comments
* add additional tests
* fix imports
* load most recently saved checkpoint

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
parent: d6eeb5dc70
commit: 55680a1f9e

7 changed files with 487 additions and 44 deletions
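Taken together, the changes below let a training function save and load checkpoints through the SGD session, and let the driver retrieve the most recent one. The following is a rough usage sketch assembled from the docstrings and tests in the hunks below; the exact combination of keyword arguments (e.g. `backend=` together with `num_workers=`) is an assumption, not copied verbatim from the PR.

# Hedged sketch of the workflow this commit enables; see the session and
# Trainer docstrings below for the authoritative signatures.
from ray.util import sgd
from ray.util.sgd.v2 import Trainer


def train_func():
    # Resume from the checkpoint passed to Trainer.run(), if any.
    checkpoint = sgd.load_checkpoint() or {"epoch": 0}
    for epoch in range(checkpoint["epoch"], 5):
        sgd.save_checkpoint(epoch=epoch)  # recorded via the rank 0 worker
        sgd.report(epoch=epoch)           # streamed back to the driver


trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, checkpoint={"epoch": 3})
latest = trainer.get_latest_checkpoint()  # e.g. {"epoch": 4}
trainer.run(train_func, checkpoint=latest)
trainer.shutdown()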
@@ -1,9 +1,10 @@
 from ray.util.sgd.v2.backends import BackendConfig, TorchConfig
 from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.trainer import Trainer
-from ray.util.sgd.v2.session import report, world_rank
+from ray.util.sgd.v2.session import load_checkpoint, save_checkpoint, report, \
+    world_rank
 
 __all__ = [
-    "BackendConfig", "report", "SGDCallback", "TorchConfig", "Trainer",
-    "world_rank"
+    "BackendConfig", "load_checkpoint", "report", "save_checkpoint",
+    "SGDCallback", "TorchConfig", "Trainer", "world_rank"
 ]

@@ -3,10 +3,11 @@ from typing import Callable, TypeVar, List, Optional, Dict
 
 import ray
 from ray.exceptions import RayActorError
-from ray.util.sgd.v2.worker_group import WorkerGroup
-from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
-from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
 from ray.ray_constants import env_integer
+from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
+from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
+from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
+from ray.util.sgd.v2.worker_group import WorkerGroup
 
 T = TypeVar("T")
 
@@ -52,6 +53,7 @@ class BackendExecutor:
         self._num_gpus_per_worker = num_gpus_per_worker
 
         self.worker_group = InactiveWorkerGroup()
+        self.latest_checkpoint = None
 
     def start(self, initialization_hook: Optional[Callable[[], None]] = None):
         """Starts the worker group."""
@@ -62,24 +64,30 @@
             self.worker_group.execute(initialization_hook)
         self._backend.on_start(self.worker_group, self._backend_config)
 
-    def start_training(self, train_func: Callable[[], T]) -> None:
+    def start_training(self,
+                       train_func: Callable[[], T],
+                       checkpoint: Optional[Dict] = None) -> None:
         """Executes a training function on all workers in a separate thread.
 
         ``finish_training`` should be called after this.
 
         Args:
            train_func (Callable): The training function to run on each worker.
+            checkpoint (Optional[Dict]): The checkpoint data that should be
+                loaded onto each worker and accessed by the training function
+                via ``sgd.load_checkpoint()``.
         """
 
         use_detailed_autofilled_metrics = env_integer(
            ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)
 
         # First initialize the session.
-        def initialize_session(world_rank, train_func):
+        def initialize_session(world_rank, train_func, checkpoint):
            try:
                init_session(
                    training_func=train_func,
                    world_rank=world_rank,
+                    checkpoint=checkpoint,
                    detailed_autofilled_metrics=use_detailed_autofilled_metrics
                )
            except ValueError:
@@ -96,7 +104,8 @@ class BackendExecutor:
                     world_rank,
                     initialize_session,
                     world_rank=world_rank,
-                    train_func=train_func))
+                    train_func=train_func,
+                    checkpoint=checkpoint))
 
         ray.get(futures)
 
@@ -107,16 +116,16 @@
 
         self.worker_group.execute_async(train_async)
 
-    def fetch_next_result(self) -> Optional[List[Dict]]:
-        """Fetch next results produced by ``sgd.report()`` from each worker.
+    def _get_next_results(self) -> Optional[List[TrainingResult]]:
+        """Fetches the next ``TrainingResult`` from each worker.
 
-        Assumes ``start_training`` has already been called.
+        Each ``TrainingResult`` is expected to correspond to the same step from
+        each worker (e.g. the same call to ``sgd.report()`` or
+        ``sgd.checkpoint()``).
 
         Returns:
-            A list of dictionaries of values passed to ``sgd.report()`` from
-                each worker. Each item corresponds to an intermediate result
-                a single worker. If there are no more items to fetch,
-                returns None.
+            A list of ``TrainingResult``s with the same
+            ``TrainingResultType``, or ``None`` if there are no more results.
         """
 
         def get_next():
@@ -141,6 +150,7 @@ class BackendExecutor:
 
             return result
 
+        # Get next result from each worker.
         futures = self.worker_group.execute_async(get_next)
         results = self.get_with_failure_handling(futures)
 
@@ -150,13 +160,56 @@ class BackendExecutor:
             if not all(r is None for r in results):
                 raise RuntimeError("Some workers returned results while "
                                    "others didn't. Make sure that "
-                                   "`sgd.report()` is called the same number "
-                                   "of times on all workers.")
+                                   "`sgd.report()` and `sgd.checkpoint()` are "
+                                   "called the same number of times on all "
+                                   "workers.")
             else:
-                results = None
-
+                # Return None if all results are None.
+                return None
+        first_result = results[0]
+        result_type = first_result.type
+        if any(r.type != result_type for r in results):
+            raise RuntimeError("Some workers returned results with "
+                               "different types. Make sure `sgd.report()` and "
+                               "`sgd.save_checkpoint()` are called the same "
+                               "number of times and in the same order on each "
+                               "worker.")
         return results
 
+    def _process_checkpoint(self, results):
+        # Process checkpoint
+        self.latest_checkpoint = results[0].data
+
+    def fetch_next_result(self) -> Optional[List[Dict]]:
+        """Fetch next results produced by ``sgd.report()`` from each worker.
+
+        Assumes ``start_training`` has already been called.
+
+        Returns:
+            A list of dictionaries of values passed to ``sgd.report()`` from
+                each worker. Each item corresponds to an intermediate result
+                a single worker. If there are no more items to fetch,
+                returns None.
+        """
+
+        while True:
+            results = self._get_next_results()
+            if results is None:
+                return None
+            first_result = results[0]
+            result_type = first_result.type
+            if result_type is TrainingResultType.REPORT:
+                result_data = [r.data for r in results]
+                return result_data
+            elif result_type is TrainingResultType.CHECKPOINT:
+                self._process_checkpoint(results)
+                # Iterate until next REPORT call or training has finished.
+            else:
+                raise SGDBackendError(f"Unexpected result type: "
                                      f"{result_type}. "
                                      f"Expected one of "
                                      f"{[type in TrainingResultType]}")
+
     def finish_training(self) -> List[T]:
         """Finish training and return final results. Propagate any exceptions.
 
@@ -169,6 +222,19 @@ class BackendExecutor:
             Each item corresponds to the return value from a single worker.
         """
 
+        def pause_reporting():
+            # Get the session for this worker.
+            try:
+                session = get_session()
+            except ValueError:
+                # Session is not initialized yet.
+                raise SGDBackendError("`finish_training` has been called "
+                                      "before `start_training`. Please call "
+                                      "`start_training` before "
+                                      "`finish_training`.")
+
+            return session.pause_reporting()
+
         def end_training():
             # Get the session for this worker.
             try:
@@ -190,6 +256,23 @@ class BackendExecutor:
 
             return output
 
+        # Disable workers from enqueuing results from `sgd.report()`.
+        # Results will not be processed during the execution of `finish`.
+        # Note: Reported results may still be enqueued at this point,
+        # and should be handled appropriately.
+        futures = self.worker_group.execute_async(pause_reporting)
+        self.get_with_failure_handling(futures)
+
+        # Finish up processing checkpoints. Reporting has been disabled.
+        while True:
+            results = self._get_next_results()
+            if results is None:
+                break
+            result_type = results[0].type
+            # Process checkpoints and ignore other result types.
+            if result_type is TrainingResultType.CHECKPOINT:
+                self._process_checkpoint(results)
+
         futures = self.worker_group.execute_async(end_training)
         return self.get_with_failure_handling(futures)
 
@@ -223,6 +306,9 @@ class BackendExecutor:
             raise RuntimeError("Worker crashed during training. "
                                "Training unsuccessful.")
 
+    def get_latest_checkpoint(self) -> Optional[Dict]:
+        return self.latest_checkpoint
+
     def shutdown(self):
         """Shuts down the workers in the worker group."""
         try:

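On the driver side, `fetch_next_result()` above folds the mixed stream of worker results: CHECKPOINT results only update `latest_checkpoint`, while REPORT results are returned to the caller. A small self-contained mock of that folding behaviour (hypothetical helper names, not part of the PR):

# Hypothetical, self-contained mock of the folding done by
# BackendExecutor.fetch_next_result(): CHECKPOINT results update state,
# only REPORT results are surfaced.
from enum import Enum, auto
from typing import Dict, List, Optional


class TrainingResultType(Enum):
    REPORT = auto()
    CHECKPOINT = auto()


class TrainingResult:
    def __init__(self, type: TrainingResultType, data: Dict):
        self.type = type
        self.data = data


def fold_results(stream: List[List[TrainingResult]],
                 state: Dict) -> Optional[List[Dict]]:
    # `stream` mimics successive _get_next_results() calls (one list per step).
    for results in stream:
        result_type = results[0].type
        if result_type is TrainingResultType.CHECKPOINT:
            state["latest_checkpoint"] = results[0].data  # _process_checkpoint
        elif result_type is TrainingResultType.REPORT:
            return [r.data for r in results]
    return None


state = {"latest_checkpoint": None}
steps = [
    [TrainingResult(TrainingResultType.CHECKPOINT, {"epoch": 0})] * 2,
    [TrainingResult(TrainingResultType.REPORT, {"loss": 0.5})] * 2,
]
print(fold_results(steps, state))   # [{'loss': 0.5}, {'loss': 0.5}]
print(state["latest_checkpoint"])   # {'epoch': 0}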
@@ -1,10 +1,12 @@
-import queue
-import time
-from datetime import datetime
-import threading
 import os
 import platform
+import queue
+import threading
+import time
+from datetime import datetime
+from enum import Enum, auto
 from typing import Callable
 from typing import Optional, Dict
 
 import ray
 from ray.util.sgd.v2.constants import (
@@ -13,17 +15,30 @@ from ray.util.sgd.v2.constants import (
 from ray.util.sgd.v2.utils import PropagatingThread
 
 
+class TrainingResultType(Enum):
+    REPORT = auto()
+    CHECKPOINT = auto()
+
+
+class TrainingResult():
+    def __init__(self, type: TrainingResultType, data: Dict):
+        self.type = type
+        self.data = data
+
+
 class Session:
     """Holds information for training on each worker."""
 
     def __init__(self,
                  training_func: Callable,
                  world_rank: int,
+                 checkpoint: Optional[Dict] = None,
                  detailed_autofilled_metrics: bool = False):
         # The Thread object that is running the training function.
         self.training_thread = PropagatingThread(
             target=training_func, daemon=True)
         self.world_rank = world_rank
+        self.loaded_checkpoint = checkpoint
 
         # This lock is used to control the execution of the training thread.
         self.continue_lock = threading.Semaphore(0)
@@ -50,19 +65,16 @@
         self.training_started = True
         self.training_thread.start()
 
+    def pause_reporting(self):
+        """Ignore all future ``sgd.report()`` calls."""
+        self.ignore_report = True
+
     def finish(self):
         """Finishes the training thread.
 
         Either returns the output from training or raises any Exception from
         training.
-
         """
-        # Ignore all future sgd.report calls.
-        self.ignore_report = True
-
         # Release lock so that training will continue even if
         # fetch_next_result is not exhausted.
         self.continue_lock.release()
 
         # Wait for training to finish.
         # This will raise any errors that occur during training, including
@@ -71,7 +83,7 @@
         # If training finished successfully, then return results.
         return func_output
 
-    def get_next(self):
+    def get_next(self) -> Optional[TrainingResult]:
         """Gets next result from the queue."""
         if not self.training_started:
             raise RuntimeError("Please call start before calling get_next.")
@@ -142,13 +154,36 @@
 
         kwargs = self._auto_fill_metrics(kwargs)
 
+        result = TrainingResult(TrainingResultType.REPORT, kwargs.copy())
+
         # Add result to a thread-safe queue.
-        self.result_queue.put(kwargs, block=True)
+        self.result_queue.put(result, block=True)
 
         # Acquire lock to stop the training thread until main thread
         # triggers resume.
         self.continue_lock.acquire()
 
+    def checkpoint(self, **kwargs):
+        """Adds kwargs to the queue to be consumed by main thread.
+
+        Also stores the checkpoint in ``self.loaded_checkpoint``.
+        """
+
+        # Update session checkpoint to latest checkpoint.
+        self.loaded_checkpoint = kwargs
+
+        # Only store checkpoints on worker with rank 0.
+        if self.world_rank != 0:
+            kwargs = {}
+
+        result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
+        # Add result to a thread-safe queue.
+        self.result_queue.put(result, block=True)
+
+        # Acquire lock to stop the training thread until
+        # checkpoint has been processed.
+        self.continue_lock.acquire()
+
 
 _session = None
 
@@ -226,3 +261,58 @@ def world_rank() -> int:
     """
     session = get_session()
     return session.world_rank
+
+
+def load_checkpoint() -> Optional[Dict]:
+    """Loads checkpoint data onto the worker.
+
+    .. code-block:: python
+
+        from ray.util import sgd
+
+        def train_func():
+            checkpoint = sgd.load_checkpoint()
+            for iter in range(checkpoint["epoch"], 5):
+                print(iter)
+
+        trainer = Trainer(backend="torch")
+        trainer.start()
+        trainer.run(train_func, checkpoint={"epoch": 3})
+        # 3
+        # 4
+        trainer.shutdown()
+
+    Args:
+        **kwargs: Any key value pair to be checkpointed by SGD.
+    Returns:
+        The most recently saved checkpoint if ``sgd.save_checkpoint()``
+        has been called. Otherwise, the checkpoint that the session was
+        originally initialized with. ``None`` if neither exist.
+    """
+    session = get_session()
+    return session.loaded_checkpoint
+
+
+def save_checkpoint(**kwargs) -> None:
+    """Checkpoints all keyword arguments to SGD as restorable state.
+
+    .. code-block:: python
+
+        import time
+        from ray.util import sgd
+
+        def train_func():
+            for iter in range(100):
+                time.sleep(1)
+                sgd.save_checkpoint(epoch=iter)
+
+        trainer = Trainer(backend="torch")
+        trainer.start()
+        trainer.run(train_func)
+        trainer.shutdown()
+
+    Args:
+        **kwargs: Any key value pair to be checkpointed by SGD.
+    """
+    session = get_session()
+    session.checkpoint(**kwargs)

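The Session changes above hand results from the training thread to the main thread through a thread-safe queue plus a semaphore: `report()` and `checkpoint()` enqueue a `TrainingResult` and then block on `continue_lock` until the driver side has consumed it. Below is a minimal, self-contained sketch of that handshake pattern under hypothetical names; it is an illustration of the mechanism, not the actual Session class.

# Stripped-down illustration of the producer/consumer handshake: the training
# thread enqueues a result and blocks on a semaphore; the main thread dequeues
# it and releases the semaphore so training can continue.
import queue
import threading

result_queue = queue.Queue(maxsize=1)
continue_lock = threading.Semaphore(0)


def training_thread():
    for step in range(3):
        result_queue.put({"step": step}, block=True)  # like sgd.report()
        continue_lock.acquire()                       # pause until consumed


def main_thread():
    t = threading.Thread(target=training_thread, daemon=True)
    t.start()
    while t.is_alive() or not result_queue.empty():
        try:
            item = result_queue.get(timeout=1)        # like Session.get_next()
        except queue.Empty:
            continue
        print("got", item)
        continue_lock.release()                       # resume training thread
    t.join()


main_thread()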
@@ -145,6 +145,38 @@ def test_no_exhaust(ray_start_2_cpus):
     assert output == [2, 2]
 
 
+def test_checkpoint(ray_start_2_cpus):
+    def train():
+        for i in range(2):
+            sgd.save_checkpoint(epoch=i)
+
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=1)
+    e.start()
+
+    e.start_training(train)
+    e.finish_training()
+
+    latest_checkpoint = e.get_latest_checkpoint()
+    assert latest_checkpoint is not None
+    assert latest_checkpoint["epoch"] == 1
+
+
+def test_mismatch_checkpoint_report(ray_start_2_cpus):
+    def train():
+        if (sgd.world_rank()) == 0:
+            sgd.save_checkpoint(epoch=0)
+        else:
+            sgd.report(iter=0)
+
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=2)
+    e.start()
+    e.start_training(train)
+    with pytest.raises(RuntimeError):
+        e.finish_training()
+
+
 def test_tensorflow_start(ray_start_2_cpus):
     num_workers = 2
     tensorflow_config = TensorflowConfig()

@@ -1,8 +1,9 @@
 import time
-import pytest
 
+import pytest
 from ray.util.sgd.v2.session import init_session, shutdown_session, \
-    get_session, world_rank, report
+    get_session, world_rank, report, save_checkpoint, TrainingResultType, \
+    load_checkpoint
 
 
 @pytest.fixture(scope="function")
@@ -47,8 +48,8 @@ def test_report():
     init_session(training_func=train, world_rank=0)
     session = get_session()
     session.start()
-    assert session.get_next()["loss"] == 0
-    assert session.get_next()["loss"] == 1
+    assert session.get_next().data["loss"] == 0
+    assert session.get_next().data["loss"] == 1
     shutdown_session()
 
     with pytest.raises(ValueError):
@@ -72,6 +73,7 @@ def test_report_fail():
 
 def test_report_after_finish(session):
     session.start()
+    session.pause_reporting()
     session.finish()
     for _ in range(2):
         report(loss=1)
@@ -83,6 +85,59 @@ def test_no_start(session):
         session.get_next()
 
 
+def test_checkpoint():
+    def train():
+        for i in range(2):
+            save_checkpoint(epoch=i)
+
+    def validate_zero(expected):
+        next = session.get_next()
+        assert next is not None
+        assert next.type == TrainingResultType.CHECKPOINT
+        assert next.data["epoch"] == expected
+
+    init_session(training_func=train, world_rank=0)
+    session = get_session()
+    session.start()
+    validate_zero(0)
+    validate_zero(1)
+    session.finish()
+    shutdown_session()
+
+    def validate_nonzero():
+        next = session.get_next()
+        assert next is not None
+        assert next.type == TrainingResultType.CHECKPOINT
+        assert next.data == {}
+
+    init_session(training_func=train, world_rank=1)
+    session = get_session()
+    session.start()
+    validate_nonzero()
+    validate_nonzero()
+    session.finish()
+    shutdown_session()
+
+    with pytest.raises(ValueError):
+        save_checkpoint(epoch=2)
+
+
+def test_load_checkpoint_after_save():
+    def train():
+        for i in range(2):
+            save_checkpoint(epoch=i)
+            checkpoint = load_checkpoint()
+            assert checkpoint["epoch"] == i
+
+    init_session(training_func=train, world_rank=0)
+    session = get_session()
+    session.start()
+    for i in range(2):
+        session.get_next()
+    session.finish()
+    shutdown_session()
+
+
 def test_locking():
     """Tests that report pauses training until fetch_next or finish."""
 
@@ -106,6 +161,10 @@ def test_locking():
     session.start()
     time.sleep(3)
 
+    session.pause_reporting()
+    # Releases session.continue_lock to resume the training thread.
+    session.get_next()
+
     with pytest.raises(KeyboardInterrupt):
         session.finish()
     shutdown_session()

@@ -73,11 +73,11 @@ def gen_new_backend_executor(special_f):
     """Returns a BackendExecutor that runs special_f on worker 0."""
 
     class TestBackendExecutor(BackendExecutor):
-        def start_training(self, train_func):
+        def start_training(self, train_func, checkpoint):
             special_execute = gen_execute_single_async_special(special_f)
             with patch.object(WorkerGroup, "execute_single_async",
                               special_execute):
-                super().start_training(train_func)
+                super().start_training(train_func, checkpoint)
 
     return TestBackendExecutor
 
@@ -161,12 +161,15 @@ def test_fast_slow(ray_start_2_cpus):
 
     def train():
         for i in range(2):
+            sgd.save_checkpoint(epoch=i)
             sgd.report(index=i)
 
     def train_slow():
-        sgd.report(index=0)
-        time.sleep(5)
-        sgd.report(index=1)
+        for i in range(2):
+            sgd.save_checkpoint(epoch=i)
+            time.sleep(5)
+            sgd.report(index=i)
+            time.sleep(5)
 
     new_backend_executor_cls = gen_new_backend_executor(train_slow)
     callback = TestCallback()
@@ -177,6 +180,8 @@ def test_fast_slow(ray_start_2_cpus):
         trainer.start()
         trainer.run(train, callbacks=[callback])
 
+    assert trainer.get_latest_checkpoint()["epoch"] == 1
+
     result_list = callback.result_list
     assert len(result_list) == 2
     for index in range(len(result_list)):
@@ -206,6 +211,116 @@ def test_mismatch_report(ray_start_2_cpus):
             trainer.run(train)
 
 
+def test_checkpoint(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func():
+        assert sgd.load_checkpoint() is None
+        for i in range(3):
+            sgd.save_checkpoint(epoch=i)
+        return 1
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    trainer.run(train_func)
+    checkpoint = trainer.get_latest_checkpoint()
+
+    assert checkpoint is not None
+    assert checkpoint["epoch"] == 2
+
+    def train_func_checkpoint():
+        checkpoint = sgd.load_checkpoint()
+        assert checkpoint is not None
+        assert checkpoint["epoch"] == 2
+
+        for i in range(checkpoint["epoch"], 5):
+            sgd.save_checkpoint(epoch=i)
+        return 1
+
+    trainer.run(train_func_checkpoint, checkpoint=checkpoint)
+    checkpoint = trainer.get_latest_checkpoint()
+
+    assert checkpoint is not None
+    assert checkpoint["epoch"] == 4
+
+
+def test_mismatch_checkpoint(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        for i in range(2):
+            sgd.save_checkpoint(epoch=i)
+
+    def train_mismatch():
+        sgd.save_checkpoint(epoch=0)
+
+    new_backend_executor_cls = gen_new_backend_executor(train_mismatch)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        with pytest.raises(RuntimeError):
+            trainer.run(train)
+
+
+def test_mismatch_checkpoint_report(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        for i in range(2):
+            sgd.save_checkpoint(epoch=i)
+            sgd.report(index=i)
+
+    def train_mismatch():
+        sgd.save_checkpoint(epoch=0)
+        sgd.report(index=0)
+        # skip checkpoint
+        sgd.report(index=1)
+
+    new_backend_executor_cls = gen_new_backend_executor(train_mismatch)
+    callback = TestCallback()
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        with pytest.raises(RuntimeError):
+            trainer.run(train, callbacks=[callback])
+        # validate checkpoint
+        assert trainer.get_latest_checkpoint()["epoch"] == 0
+        # validate callback
+        result_list = callback.result_list
+        assert len(result_list) == 1  # 1 epoch succeeded
+        intermediate_results = result_list[0]
+        assert len(intermediate_results) == 2  # both workers reported
+        for worker_result in intermediate_results:
+            assert worker_result["index"] == 0
+
+
+def test_load_checkpoint(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func_checkpoint():
+        checkpoint = sgd.load_checkpoint()
+        assert checkpoint is not None
+        assert checkpoint["epoch"] == 3
+
+        result = []
+        for i in range(checkpoint["epoch"], 5):
+            result.append(i)
+        return result
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    result = trainer.run(train_func_checkpoint, checkpoint={"epoch": 3})
+
+    assert result is not None
+    assert len(result) == 2
+    assert result[0] == [3, 4]
+    assert result[1] == [3, 4]
+
+
 def test_world_rank(ray_start_2_cpus):
     config = TestConfig()
 
@@ -425,6 +540,58 @@ def test_worker_failure_2(ray_start_2_cpus):
             trainer.run(train)
 
 
+def test_worker_failure_checkpoint(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        for i in range(3):
+            sgd.save_checkpoint(epoch=i)
+            sgd.report(index=i)
+
+    def train_actor_failure():
+        sgd.save_checkpoint(epoch=0)
+        sgd.report(index=0)
+        sgd.save_checkpoint(epoch=1)
+        import sys
+        sys.exit(0)
+
+    new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        with pytest.raises(RuntimeError):
+            trainer.run(train)
+        assert trainer.get_latest_checkpoint()["epoch"] == 1
+
+
+def test_worker_failure_checkpoint_2(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        for i in range(3):
+            sgd.report(index=i)
+            sgd.save_checkpoint(epoch=i)
+
+    def train_actor_failure():
+        for i in range(3):
+            sgd.report(index=i)
+            sgd.save_checkpoint(epoch=i)
+        import sys
+        sys.exit(0)
+
+    new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        with pytest.raises(RuntimeError):
+            trainer.run(train)
+        assert trainer.get_latest_checkpoint()["epoch"] == 2
+
+
 def test_worker_kill(ray_start_2_cpus):
     test_config = TestConfig()
 

@@ -111,7 +111,8 @@ class Trainer:
     def run(self,
             train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
             config: Optional[Dict[str, Any]] = None,
-            callbacks: Optional[List[SGDCallback]] = None) -> List[T]:
+            callbacks: Optional[List[SGDCallback]] = None,
+            checkpoint: Optional[Dict] = None) -> List[T]:
         """Runs a training function in a distributed manner.
 
         Args:
@@ -122,6 +123,9 @@ class Trainer:
             callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
                 will be executed during training. If this is not set,
                 currently there are NO default Callbacks.
+            checkpoint (Optional[Dict]): The checkpoint data that should be
+                loaded onto each worker and accessed by the training function
+                via ``sgd.load_checkpoint()``.
 
         Returns:
             A list of results from the training function. Each value in the
@@ -136,7 +140,7 @@ class Trainer:
         try:
             for callback in callbacks:
                 callback.start_training()
-            self._executor.start_training(train_func)
+            self._executor.start_training(train_func, checkpoint)
 
             while True:
                 intermediate_results = self._executor.fetch_next_result()
@@ -221,6 +225,10 @@ class Trainer:
         """
         raise NotImplementedError
 
+    def get_latest_checkpoint(self) -> Optional[Dict]:
+        """Gets the latest checkpoint for this Trainer."""
+        return self._executor.get_latest_checkpoint()
+
     def shutdown(self):
         """Shuts down the training execution service."""
         self._executor.shutdown()