[SGD] v2 Tune integration + iterator API (#17839)

* [SGD] implement SGD Trainer.to_tune_trainable
* address some comments
* add RESULT_DUPLICATE
* extract trainable creation logic out of Trainer
* add 1 CPU for driver
* use class attribute to fix serialization issues
* add examples
* add test for tune error
* tune
* test tune_linear
* run_iterator
* add to build file
* Update python/ray/util/sgd/v2/trainer.py

  Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>

* Update python/ray/util/sgd/v2/trainer.py

  Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>

* address comments
* fix tests & address comments
* resolve merge
* lint
* fix
* add team tag to tests
* fix tests
* lint

Co-authored-by: Matthew Deng <matthew.j.deng@gmail.com>
This commit is contained in:
parent
60aee4a330
commit
9416fce91b
11 changed files with 679 additions and 37 deletions
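For orientation before the per-file hunks, here is a condensed sketch of the two new entry points this commit adds, pieced together from the examples and tests below (the body of train_func is illustrative, not from the diff):

    import ray
    from ray import tune
    import ray.util.sgd.v2 as sgd
    from ray.util.sgd.v2 import Trainer


    def train_func(config):
        # Illustrative training function: each sgd.report() call produces one
        # intermediate result (a list with one dict per worker) in the iterator.
        for epoch in range(config["epochs"]):
            sgd.report(epoch=epoch, loss=1.0 / (epoch + 1))
        return "final model"


    ray.init()

    # Iterator API: consume intermediate results as they are reported.
    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    iterator = trainer.run_iterator(train_func, config={"epochs": 4})
    for intermediate_results in iterator:
        print(intermediate_results[0])            # rank 0 worker's dict
    final_results = iterator.get_final_results()  # one return value per worker
    trainer.shutdown()

    # Tune integration: a fresh, *not started* Trainer becomes a Trainable.
    trainer = Trainer(backend="torch", num_workers=2)
    trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(trainable, config={"epochs": tune.choice([2, 4])})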
@@ -2,6 +2,15 @@
 # Tests from the python/ray/util/sgd/v2/examples directory.
 # Please keep these sorted alphabetically.
 # --------------------------------------------------------------------
+py_test(
+    name = "mlflow_fashion_mnist",
+    size = "medium",
+    main = "examples/mlflow_fashion_mnist.py",
+    srcs = ["examples/mlflow_fashion_mnist.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"],
+    args = ["--smoke-test"]
+)
 
 py_test(
     name = "transformers_example",

@@ -15,6 +24,16 @@ py_test(
         "--max_train_steps=2", "--start_local", "--num_workers=2"]
 )
 
+py_test(
+    name = "tune_linear",
+    size = "medium",
+    main = "examples/tune_linear.py",
+    srcs = ["examples/tune_linear.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"],
+    args = ["--smoke-test"]
+)
+
 # --------------------------------------------------------------------
 # Tests from the python/ray/util/sgd/v2/tests directory.
 # Please keep these sorted alphabetically.

@@ -44,6 +63,14 @@ py_test(
     deps = [":sgd_v2_lib"]
 )
 
+py_test(
+    name = "test_tune",
+    size = "medium",
+    srcs = ["tests/test_tune.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"]
+)
+
 py_test(
     name = "test_worker_group",
     size = "small",
@@ -414,6 +414,10 @@ class BackendExecutor:
         self.worker_group.shutdown()
         self.worker_group = InactiveWorkerGroup()
 
+    @property
+    def is_started(self):
+        return not isinstance(self.worker_group, InactiveWorkerGroup)
+
     @property
     def latest_run_dir(self) -> Optional[Path]:
         """Path to the latest run directory."""
@@ -455,7 +459,16 @@ class InactiveWorkerGroupError(Exception):
 
 class InactiveWorkerGroup():
     # TODO: fix inheritance. Perhaps create a WorkerGroupInterface.
-    def __getattribute__(self, *args, **kwargs):
 
+    # Need to define __getstate__ and __setstate__ so that __getattr__ does
+    # not screw up pickling. See https://stackoverflow.com/a/50888571/11249691
+    def __getstate__(self):
+        return vars(self)
+
+    def __setstate__(self, state):
+        vars(self).update(state)
+
+    def __getattr__(self, name):
+        raise InactiveWorkerGroupError()
 
     def __len__(self):
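The __getstate__/__setstate__ pair matters because pickle probes the instance for those methods with a plain attribute lookup; a __getattr__ that raises something other than AttributeError breaks that probe. A minimal sketch of the failure mode (class names are hypothetical; on Python 3.11+ object defines a default __getstate__, so the broken variant may pickle fine there):

    import pickle


    class Broken:
        def __getattr__(self, name):
            # Raises for *every* missing attribute, including pickle's probe
            # for __getstate__, which only expects AttributeError.
            raise RuntimeError("inactive")


    class Fixed:
        # Defining these short-circuits the __getattr__ fallback during
        # pickling, so the probe never reaches the raising __getattr__.
        def __getstate__(self):
            return vars(self)

        def __setstate__(self, state):
            vars(self).update(state)

        def __getattr__(self, name):
            raise RuntimeError("inactive")


    pickle.loads(pickle.dumps(Fixed()))  # round-trips fine
    try:
        pickle.dumps(Broken())  # RuntimeError on Python < 3.11
    except RuntimeError as e:
        print("pickling failed:", e)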
python/ray/util/sgd/v2/examples/mlflow_fashion_mnist.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import argparse

import mlflow

from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.examples.train_fashion_mnist import train_func


def main(num_workers=1, use_gpu=False):
    mlflow.set_experiment("sgd_torch_fashion_mnist")

    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    iterator = trainer.run_iterator(
        train_func=train_func,
        config={
            "lr": 1e-3,
            "batch_size": 64,
            "epochs": 4
        })

    for intermediate_result in iterator:
        first_worker_result = intermediate_result[0]
        mlflow.log_metric("loss", first_worker_result["loss"])

    print("Full losses for rank 0 worker: ", iterator.get_final_results())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")

    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")
    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        ray.init(num_cpus=2)
        args.num_workers = 2
        args.use_gpu = False
    else:
        ray.init(address=args.address)
    main(num_workers=args.num_workers, use_gpu=args.use_gpu)
@@ -7,9 +7,17 @@ import os
 
 import numpy as np
 import tensorflow as tf
+from tensorflow.keras.callbacks import Callback
 
+import ray.util.sgd.v2 as sgd
 from ray.util.sgd.v2 import Trainer
 
 
+class SGDReportCallback(Callback):
+    def on_epoch_end(self, epoch, logs=None):
+        sgd.report(**logs)
+
+
 def mnist_dataset(batch_size):
     (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
     # The `x` arrays are in uint8 and have values in the [0, 255] range.

@@ -56,8 +64,12 @@ def train_func(config):
     multi_worker_model = build_and_compile_cnn_model(config)
 
     history = multi_worker_model.fit(
-        multi_worker_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)
-    return history.history
+        multi_worker_dataset,
+        epochs=epochs,
+        steps_per_epoch=steps_per_epoch,
+        callbacks=[SGDReportCallback()])
+    results = history.history
+    return results
 
 
 def train_tensorflow_mnist(num_workers=1, use_gpu=False):
@@ -117,7 +117,8 @@ def train_func(config: Dict):
 
 
 def train_fashion_mnist(num_workers=1, use_gpu=False):
-    trainer = Trainer(backend="torch", num_workers=num_workers)
+    trainer = Trainer(
+        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
     trainer.start()
     result = trainer.run(
         train_func=train_func,
@@ -1,6 +1,7 @@
 import argparse
 
 import numpy as np
+import ray.util.sgd.v2 as sgd
 import torch
 import torch.nn as nn
 from ray.util.sgd.v2 import Trainer, TorchConfig

@@ -80,6 +81,7 @@ def train_func(config):
     for _ in range(epochs):
         train(train_loader, model, loss_fn, optimizer)
         result = validate(validation_loader, model, loss_fn)
+        sgd.report(**result)
         results.append(result)
 
     return results
python/ray/util/sgd/v2/examples/tune_linear.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import argparse

import ray
from ray import tune
from ray.util.sgd.v2 import Trainer

from train_linear import train_func


def tune_linear(num_workers, num_samples):
    trainer = Trainer("torch", num_workers=num_workers)
    Trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(
        Trainable,
        num_samples=num_samples,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([4, 16, 32]),
            "epochs": 3
        })
    results = analysis.get_best_config(metric="loss", mode="min")
    print(results)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Sets number of samples for training.")

    args = parser.parse_args()

    if args.smoke_test:
        ray.init(num_cpus=4)
    else:
        ray.init(address=args.address)
    tune_linear(num_workers=args.num_workers, num_samples=args.num_samples)
python/ray/util/sgd/v2/examples/tune_tensorflow_mnist.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import argparse

import ray
from ray import tune
from ray.util.sgd.v2 import Trainer

from tensorflow_mnist_example import train_func


def tune_tensorflow_mnist(num_workers, num_samples):
    trainer = Trainer(backend="tensorflow", num_workers=num_workers)
    Trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(
        Trainable,
        num_samples=num_samples,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
            "epochs": 3
        })
    best_loss = analysis.get_best_config(metric="loss", mode="min")
    best_accuracy = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss}")
    print(f"Best accuracy config: {best_accuracy}")
    return analysis


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Sets number of samples for training.")

    args = parser.parse_args()

    if args.smoke_test:
        ray.init(num_cpus=4)
    else:
        ray.init(address=args.address)
    tune_tensorflow_mnist(
        num_workers=args.num_workers, num_samples=args.num_samples)
@@ -7,6 +7,7 @@ import ray
 import ray.util.sgd.v2 as sgd
+import tensorflow as tf
 import torch
 
 from ray.util.sgd.v2 import Trainer
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface, \
     BackendExecutor

@@ -36,6 +37,14 @@ def ray_start_2_cpus_2_gpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_start_8_cpus():
+    address_info = ray.init(num_cpus=8)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
 class TestConfig(BackendConfig):
     @property
     def backend_cls(self):

@@ -212,6 +221,67 @@ def test_mismatch_report(ray_start_2_cpus):
         trainer.run(train)
 
 
+def test_run_iterator(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func():
+        for i in range(3):
+            sgd.report(index=i)
+        return 1
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(train_func)
+
+    count = 0
+    for results in iterator:
+        assert all(value["index"] == count for value in results)
+        count += 1
+
+    assert count == 3
+    assert iterator.is_finished()
+    assert iterator.get_final_results() == [1, 1]
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+
+def test_run_iterator_returns(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func():
+        for i in range(3):
+            sgd.report(index=i)
+        return 1
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(train_func)
+
+    assert iterator.get_final_results() is None
+    assert iterator.get_final_results(force=True) == [1, 1]
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+
+def test_run_iterator_error(ray_start_2_cpus):
+    config = TestConfig()
+
+    def fail_train():
+        raise NotImplementedError
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(fail_train)
+
+    with pytest.raises(NotImplementedError):
+        next(iterator)
+
+    assert iterator.get_final_results() is None
+    assert iterator.is_finished()
+
+
 def test_checkpoint(ray_start_2_cpus):
     config = TestConfig()
python/ray/util/sgd/v2/tests/test_tune.py (new file, 134 lines)
@@ -0,0 +1,134 @@
import pytest

import torch
import tensorflow as tf

import ray
from ray import tune
from ray.tune import TuneError

from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.backends.backend import BackendInterface, BackendConfig
from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \
    tensorflow_mnist_train_func
from ray.util.sgd.v2.examples.train_fashion_mnist import train_func as \
    fashion_mnist_train_func
from ray.util.sgd.v2.worker_group import WorkerGroup


@pytest.fixture
def ray_start_2_cpus():
    address_info = ray.init(num_cpus=2)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


@pytest.fixture
def ray_start_4_cpus_4_gpus():
    address_info = ray.init(num_cpus=4, num_gpus=4)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


@pytest.fixture
def ray_start_8_cpus():
    address_info = ray.init(num_cpus=8)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


class TestConfig(BackendConfig):
    @property
    def backend_cls(self):
        return TestBackend


class TestBackend(BackendInterface):
    def on_start(self, worker_group: WorkerGroup, backend_config: TestConfig):
        pass

    def on_shutdown(self, worker_group: WorkerGroup,
                    backend_config: TestConfig):
        pass


def torch_fashion_mnist(num_workers, use_gpu, num_samples):
    epochs = 2

    trainer = Trainer("torch", num_workers=num_workers, use_gpu=use_gpu)
    MnistTrainable = trainer.to_tune_trainable(fashion_mnist_train_func)

    analysis = tune.run(
        MnistTrainable,
        num_samples=num_samples,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
            "epochs": epochs
        })

    # Check that loss decreases in each trial.
    for path, df in analysis.trial_dataframes.items():
        assert df.loc[1, "loss"] < df.loc[0, "loss"]


def test_tune_torch_fashion_mnist(ray_start_8_cpus):
    torch_fashion_mnist(num_workers=2, use_gpu=False, num_samples=2)


@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Only run if multiple GPUs are available.")
def test_tune_fashion_mnist_gpu(ray_start_4_cpus_4_gpus):
    torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1)


def tune_tensorflow_mnist(num_workers, use_gpu, num_samples):
    epochs = 2
    trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu)
    MnistTrainable = trainer.to_tune_trainable(tensorflow_mnist_train_func)

    analysis = tune.run(
        MnistTrainable,
        num_samples=num_samples,
        config={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
            "epochs": epochs
        })

    # Check that loss decreases in each trial.
    for path, df in analysis.trial_dataframes.items():
        assert df.loc[1, "loss"] < df.loc[0, "loss"]


def test_tune_tensorflow_mnist(ray_start_8_cpus):
    tune_tensorflow_mnist(num_workers=2, use_gpu=False, num_samples=2)


@pytest.mark.skipif(
    len(tf.config.list_physical_devices("GPU")) < 2,
    reason="Only run if multiple GPUs are available.")
def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_4_gpus):
    tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1)


def test_tune_error(ray_start_2_cpus):
    def train_func(config):
        raise RuntimeError("Error in training function!")

    trainer = Trainer(TestConfig())
    TestTrainable = trainer.to_tune_trainable(train_func)

    with pytest.raises(TuneError):
        tune.run(TestTrainable)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", "-x", __file__]))
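The CPU counts in these fixtures follow from the commit's "add 1 CPU for driver" change: each Tune trial reserves one CPU for the trainable's driver process plus one CPU (and optionally one GPU) per worker. A hypothetical helper spelling out that arithmetic (names are illustrative, not part of the API):

    # Hypothetical helper mirroring the bundle math in trainer.py's
    # default_resource_request (see the hunk further down).
    def trial_resources(num_workers: int, use_gpu: bool) -> dict:
        return {
            "CPU": 1 + num_workers,            # 1 extra CPU for the trial driver
            "GPU": num_workers * int(use_gpu),
        }

    # Two concurrent CPU trials with 2 workers each fit in the 8-CPU fixture:
    assert trial_resources(num_workers=2, use_gpu=False) == {"CPU": 3, "GPU": 0}
    # The GPU tests run a single trial needing 3 CPUs and 2 GPUs, hence the
    # ray_start_4_cpus_4_gpus fixture.
    assert trial_resources(num_workers=2, use_gpu=True) == {"CPU": 3, "GPU": 2}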
@@ -1,9 +1,9 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import Union, Callable, List, TypeVar, Optional, Any, Dict
+from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
+    Type, Iterator
 
-from ray.tune import Trainable
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \
     InactiveWorkerGroupError, SGDBackendError
 from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
@@ -11,6 +11,22 @@ from ray.util.sgd.v2.backends.torch import TorchConfig
 from ray.util.sgd.v2.callbacks.callback import SGDCallback
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 
+# Ray SGD should be usable even if Tune is not installed.
+try:
+    TUNE_INSTALLED = True
+    from ray import tune
+    from ray.tune import Trainable
+    from ray.tune import PlacementGroupFactory
+    from ray.tune.function_runner import wrap_function
+except ImportError:
+    TUNE_INSTALLED = False
+    tune = PlacementGroupFactory = Trainable = object
+
+    def noop():
+        return
+
+    wrap_function = noop
+
 T = TypeVar("T")
 S = TypeVar("S")
@@ -47,6 +63,12 @@ class Trainer:
                  use_gpu: bool = False,
                  resources_per_worker: Optional[Dict[str, float]] = None,
                  logdir: Optional[str] = None):
+
+        self._backend = backend
+        self._num_workers = num_workers
+        self._use_gpu = use_gpu
+        self._resources_per_worker = resources_per_worker
+
         # Setup executor.
         backend_config = self._get_backend_config(backend)
@@ -131,42 +153,83 @@ class Trainer:
             list corresponds to the output of the training function from
             each worker.
         """
-        train_func = self._get_train_func(train_func, config)
         # TODO(matt): Set default callbacks.
         callbacks = [] if callbacks is None else callbacks
         finished_with_errors = False
 
-        for callback in callbacks:
-            callback.start_training()
-
         try:
+            for callback in callbacks:
+                callback.start_training()
+
-            self._executor.start_training(train_func, checkpoint,
-                                          checkpoint_strategy)
+            iterator = self.run_iterator(
+                train_func=train_func,
+                config=config,
+                checkpoint=checkpoint,
+                checkpoint_strategy=checkpoint_strategy)
+            for intermediate_result in iterator:
+                for callback in callbacks:
+                    callback.handle_result(intermediate_result)
 
-            while True:
-                intermediate_results = self._executor.fetch_next_result()
-                if intermediate_results is None:
-                    break
-                else:
-                    for callback in callbacks:
-                        callback.handle_result(intermediate_results)
-
-            return self._executor.finish_training()
-        except InactiveWorkerGroupError:
-            finished_with_errors = True
-            raise RuntimeError(
-                "This Trainer is not active. It is either shutdown already or "
-                "never started in the first place. Either create a new "
-                "Trainer or start this one.") from None
-        except SGDBackendError:
-            finished_with_errors = True
-            raise RuntimeError("Training failed. You should not be seeing "
-                               "this error and this is a bug. Please create "
-                               "a new issue at "
-                               "https://github.com/ray-project/ray.") from None
+            assert iterator.is_finished()
+            return iterator.get_final_results()
         finally:
             for callback in callbacks:
                 callback.finish_training(error=finished_with_errors)
 
+    def run_iterator(
+            self,
+            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
+            config: Optional[Dict[str, Any]] = None,
+            checkpoint: Optional[Dict] = None,
+            checkpoint_strategy: Optional[CheckpointStrategy] = None
+    ) -> Iterator[List[Dict]]:
+        """Same as ``run`` except returns an iterator over the results.
+
+        This is useful if you want to have more customization of what to do
+        with the intermediate results or how to use the ``Trainer`` with Ray
+        Tune.
+
+        .. code-block:: python
+
+            def train_func(config):
+                ...
+                for _ in range(config["epochs"]):
+                    metrics = train()
+                    metrics = validate(...)
+                    ray.sgd.report(**metrics)
+                return model
+
+            iterator = trainer.run_iterator(train_func, config=config)
+
+            for result in iterator:
+                do_stuff(result)
+                latest_ckpt = trainer.get_latest_checkpoint()
+
+            assert iterator.is_finished()
+            model = iterator.get_final_results()[0]
+
+        Args:
+            train_func (Callable): The training function to execute.
+                This can either take in no arguments or a ``config`` dict.
+            config (Optional[Dict]): Configurations to pass into
+                ``train_func``. If None then an empty Dict will be created.
+            checkpoint (Optional[Dict]): The checkpoint data that should be
+                loaded onto each worker and accessed by the training function
+                via ``sgd.load_checkpoint()``.
+            checkpoint_strategy (Optional[CheckpointStrategy]): The
+                configurations for saving checkpoints.
+
+        Returns:
+            An Iterator over the intermediate results from ``sgd.report()``.
+        """
+        train_func = self._get_train_func(train_func, config)
+
+        return SGDIterator(
+            backend_executor=self._executor,
+            train_func=train_func,
+            checkpoint=checkpoint,
+            checkpoint_strategy=checkpoint_strategy)
 
     def _get_train_func(
             self,
             train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
@@ -273,8 +336,8 @@ class Trainer:
         """Shuts down the training execution service."""
         self._executor.shutdown()
 
-    def to_tune_trainable(
-            self, train_func: Callable[[Dict[str, Any]], T]) -> Trainable:
+    def to_tune_trainable(self, train_func: Callable[[Dict[str, Any]], T]
+                          ) -> Type[Trainable]:
         """Creates a Tune ``Trainable`` from the input training function.
 
         Args:
@@ -284,8 +347,149 @@ class Trainer:
         Returns:
             A Trainable that can directly be passed into ``tune.run()``.
         """
+        if not TUNE_INSTALLED:
+            raise ValueError("Tune is not installed. Please install ray["
+                             "tune] to use the Tune integration.")
+
-        def trainable_func(config: Dict[str, Any]) -> T:
-            pass
+        if self._executor.is_started:
+            raise RuntimeError("The Trainer must not be active to use "
+                               "`to_tune_trainable`. Either shutdown the "
+                               "Trainer or don't start it in the first place.")
 
-        raise NotImplementedError
+        return _create_tune_trainable(train_func, self._backend,
+                                      self._num_workers, self._use_gpu,
+                                      self._resources_per_worker)
+
+
+class SGDIterator:
+    """An iterator over SGD results. Returned by ``trainer.run_iterator``."""
+
+    def __init__(
+            self, backend_executor: BackendExecutor,
+            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
+            checkpoint: Optional[Dict],
+            checkpoint_strategy: Optional[CheckpointStrategy]):
+        self._executor = backend_executor
+        self._run_with_error_handling(
+            lambda: self._executor.start_training(
+                train_func=train_func,
+                checkpoint=checkpoint,
+                checkpoint_strategy=checkpoint_strategy)
+        )
+
+        self._final_results = None
+        self._finished_training = False
+
+    def __iter__(self):
+        return self
+
+    def _run_with_error_handling(self, func: Callable):
+        try:
+            return func()
+        except InactiveWorkerGroupError:
+            raise RuntimeError(
+                "This Trainer is not active. It is either shutdown "
+                "already or never started in the first place. "
+                "Either create a new Trainer or start this one.") \
+                from None
+        except SGDBackendError:
+            raise RuntimeError("Training failed. You should not be seeing "
+                               "this error and this is a bug. Please create "
+                               "a new issue at "
+                               "https://github.com/ray-project/ray.") from None
+
+    def __next__(self):
+        if self.is_finished():
+            raise StopIteration
+        next_results = self._run_with_error_handling(
+            self._executor.fetch_next_result)
+        if next_results is None:
+            try:
+                self._final_results = \
+                    self._run_with_error_handling(
+                        self._executor.finish_training)
+            finally:
+                self._finished_training = True
+            raise StopIteration
+        else:
+            return next_results
+
+    def is_finished(self) -> bool:
+        return self._finished_training
+
+    def get_final_results(self, force: bool = False) -> List[T]:
+        """Gets the training func return values from each worker.
+
+        If ``force`` is ``True``, then immediately finish training
+        and return even if all the intermediate results have not
+        been processed yet. Else, intermediate results must be
+        processed before obtaining the final results. Defaults to
+        False.
+        """
+        if not self.is_finished():
+            assert self._final_results is None
+            if force:
+                try:
+                    self._final_results = \
+                        self._run_with_error_handling(
+                            self._executor.finish_training)
+                finally:
+                    self._finished_training = True
+            else:
+                logger.info("Please finish iterating through the "
+                            "intermediate results before getting the "
+                            "final returns. If you would like "
+                            "training to finish immediately and get "
+                            "the final returns, then set "
+                            "`force=True`.")
+
+        return self._final_results
+
+
+def _create_tune_trainable(train_func, backend, num_workers, use_gpu,
+                           resources_per_worker):
+    """Creates a Tune Trainable class for SGD training.
+
+    This function populates class attributes and methods.
+    """
+
+    # TODO(amog): Implement checkpointing for Tune.
+    def tune_function(config, checkpoint_dir=None):
+        trainer = Trainer(
+            backend=backend,
+            num_workers=num_workers,
+            use_gpu=use_gpu,
+            resources_per_worker=resources_per_worker)
+
+        trainer.start()
+
+        iterator = trainer.run_iterator(train_func, config)
+
+        for results in iterator:
+            first_worker_results = results[0]
+            tune.report(**first_worker_results)
+
+        trainer.shutdown()
+
+    trainable_cls = wrap_function(tune_function)
+
+    class SgdTrainable(trainable_cls):
+        """Add default resources to the Trainable."""
+
+        @classmethod
+        def default_resource_request(cls,
+                                     config: Dict) -> PlacementGroupFactory:
+            head_bundle = [{"CPU": 1}]  # driver
+            worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
+            worker_resources_extra = {} if resources_per_worker is None else \
+                resources_per_worker
+            worker_bundles = [{
+                **worker_resources,
+                **worker_resources_extra
+            } for _ in range(num_workers)]
+            bundles = head_bundle + worker_bundles
+            return PlacementGroupFactory(bundles, strategy="PACK")
+
+    return SgdTrainable
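To make the resource math concrete, here is a sketch of the bundles default_resource_request produces for num_workers=2, use_gpu=True (values traced by hand from the code above, not an official API call); the PACK strategy then co-locates these bundles on as few nodes as possible:

    # One head bundle for the Tune trial driver plus one bundle per worker.
    num_workers, use_gpu, resources_per_worker = 2, True, None

    head_bundle = [{"CPU": 1}]  # trial driver ("add 1 CPU for driver")
    worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
    worker_resources_extra = {} if resources_per_worker is None else \
        resources_per_worker
    worker_bundles = [{**worker_resources, **worker_resources_extra}
                      for _ in range(num_workers)]

    assert head_bundle + worker_bundles == [
        {"CPU": 1},            # driver
        {"CPU": 1, "GPU": 1},  # worker 0
        {"CPU": 1, "GPU": 1},  # worker 1
    ]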