From 9416fce91b38d4a028be5ac97e3d0d271a7db378 Mon Sep 17 00:00:00 2001
From: Amog Kamsetty
Date: Fri, 20 Aug 2021 08:31:21 -0700
Subject: [PATCH] [SGD] v2 Tune integration + iterator API (#17839)

* [SGD] implement SGD Trainer.to_tune_trainable

* address some comments

* add RESULT_DUPLICATE

* extract trainable creation logic out of Trainer

* add 1 CPU for driver

* use class attribute to fix serialization issues

* add examples

* add test for tune error

* tune

* test tune_linear

* run_iterator

* add to build file

* Update python/ray/util/sgd/v2/trainer.py

Co-authored-by: matthewdeng

* Update python/ray/util/sgd/v2/trainer.py

Co-authored-by: matthewdeng

* address comments

* fix tests & address comments

* resolve merge

* lint

* fix

* add team tag to tests

* fix tests

* lint

Co-authored-by: Matthew Deng
---
 python/ray/util/sgd/v2/BUILD                  |  27 ++
 python/ray/util/sgd/v2/backends/backend.py    |  15 +-
 .../sgd/v2/examples/mlflow_fashion_mnist.py   |  64 +++++
 .../v2/examples/tensorflow_mnist_example.py   |  16 +-
 .../sgd/v2/examples/train_fashion_mnist.py    |   3 +-
 .../ray/util/sgd/v2/examples/train_linear.py  |   2 +
 .../ray/util/sgd/v2/examples/tune_linear.py   |  56 ++++
 .../sgd/v2/examples/tune_tensorflow_mnist.py  |  59 ++++
 python/ray/util/sgd/v2/tests/test_trainer.py  |  70 +++++
 python/ray/util/sgd/v2/tests/test_tune.py     | 134 +++++++++
 python/ray/util/sgd/v2/trainer.py             | 270 +++++++++++++++---
 11 files changed, 679 insertions(+), 37 deletions(-)
 create mode 100644 python/ray/util/sgd/v2/examples/mlflow_fashion_mnist.py
 create mode 100644 python/ray/util/sgd/v2/examples/tune_linear.py
 create mode 100644 python/ray/util/sgd/v2/examples/tune_tensorflow_mnist.py
 create mode 100644 python/ray/util/sgd/v2/tests/test_tune.py

diff --git a/python/ray/util/sgd/v2/BUILD b/python/ray/util/sgd/v2/BUILD
index 7fae74595..712defc04 100644
--- a/python/ray/util/sgd/v2/BUILD
+++ b/python/ray/util/sgd/v2/BUILD
@@ -2,6 +2,15 @@
 # Tests from the python/ray/util/sgd/v2/examples directory.
 # Please keep these sorted alphabetically.
 # --------------------------------------------------------------------
+py_test(
+    name = "mlflow_fashion_mnist",
+    size = "medium",
+    main = "examples/mlflow_fashion_mnist.py",
+    srcs = ["examples/mlflow_fashion_mnist.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"],
+    args = ["--smoke-test"]
+)
 
 py_test(
     name = "transformers_example",
@@ -15,6 +24,16 @@ py_test(
         "--max_train_steps=2",
         "--start_local",
         "--num_workers=2"]
 )
+py_test(
+    name = "tune_linear",
+    size = "medium",
+    main = "examples/tune_linear.py",
+    srcs = ["examples/tune_linear.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"],
+    args = ["--smoke-test"]
+)
+
 # --------------------------------------------------------------------
 # Tests from the python/ray/util/sgd/v2/tests directory.
 # Please keep these sorted alphabetically.
@@ -44,6 +63,14 @@ py_test(
     deps = [":sgd_v2_lib"]
 )
 
+py_test(
+    name = "test_tune",
+    size = "medium",
+    srcs = ["tests/test_tune.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":sgd_v2_lib"]
+)
+
 py_test(
     name = "test_worker_group",
     size = "small",

diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py
index 58743f2fd..b2605df20 100644
--- a/python/ray/util/sgd/v2/backends/backend.py
+++ b/python/ray/util/sgd/v2/backends/backend.py
@@ -414,6 +414,10 @@ class BackendExecutor:
         self.worker_group.shutdown()
         self.worker_group = InactiveWorkerGroup()
 
+    @property
+    def is_started(self):
+        return not isinstance(self.worker_group, InactiveWorkerGroup)
+
     @property
     def latest_run_dir(self) -> Optional[Path]:
         """Path to the latest run directory."""
@@ -455,7 +459,16 @@ class InactiveWorkerGroupError(Exception):
 
 class InactiveWorkerGroup():
     # TODO: fix inheritence. perhaps create WorkerGroupInterface.
-    def __getattribute__(self, *args, **kwargs):
+
+    # Need to define __getstate__ and __setstate__ so that __getattr__ does
+    # not screw up pickling. See https://stackoverflow.com/a/50888571/11249691
+    def __getstate__(self):
+        return vars(self)
+
+    def __setstate__(self, state):
+        vars(self).update(state)
+
+    def __getattr__(self, name):
         raise InactiveWorkerGroupError()
 
     def __len__(self):

diff --git a/python/ray/util/sgd/v2/examples/mlflow_fashion_mnist.py b/python/ray/util/sgd/v2/examples/mlflow_fashion_mnist.py
new file mode 100644
index 000000000..134b9849a
--- /dev/null
+++ b/python/ray/util/sgd/v2/examples/mlflow_fashion_mnist.py
@@ -0,0 +1,64 @@
+import argparse
+
+import mlflow
+
+from ray.util.sgd.v2 import Trainer
+from ray.util.sgd.v2.examples.train_fashion_mnist import train_func
+
+
+def main(num_workers=1, use_gpu=False):
+    mlflow.set_experiment("sgd_torch_fashion_mnist")
+
+    trainer = Trainer(
+        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
+    trainer.start()
+    iterator = trainer.run_iterator(
+        train_func=train_func,
+        config={
+            "lr": 1e-3,
+            "batch_size": 64,
+            "epochs": 4
+        })
+
+    for intermediate_result in iterator:
+        first_worker_result = intermediate_result[0]
+        mlflow.log_metric("loss", first_worker_result["loss"])
+
+    print("Full losses for all workers: ", iterator.get_final_results())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--address",
+        required=False,
+        type=str,
+        help="the address to use for Ray")
+    parser.add_argument(
+        "--num-workers",
+        "-n",
+        type=int,
+        default=1,
+        help="Sets number of workers for training.")
+    parser.add_argument(
+        "--use-gpu",
+        action="store_true",
+        default=False,
+        help="Enables GPU training")
+
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        default=False,
+        help="Finish quickly for testing.")
+    args, _ = parser.parse_known_args()
+
+    import ray
+
+    if args.smoke_test:
+        ray.init(num_cpus=2)
+        args.num_workers = 2
+        args.use_gpu = False
+    else:
+        ray.init(address=args.address)
+    main(num_workers=args.num_workers, use_gpu=args.use_gpu)

diff --git a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py
index 5d1139696..f87380cf9 100644
--- a/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py
+++ b/python/ray/util/sgd/v2/examples/tensorflow_mnist_example.py
@@ -7,9 +7,17 @@ import os
 import numpy as np
 import tensorflow as tf
+from tensorflow.keras.callbacks import Callback
+
+import ray.util.sgd.v2 as sgd
 from ray.util.sgd.v2 import Trainer
 
 
+class SGDReportCallback(Callback):
+    def on_epoch_end(self, epoch, logs=None):
+        sgd.report(**logs)
+
+
 def mnist_dataset(batch_size):
     (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
     # The `x` arrays are in uint8 and have values in the [0, 255] range.
@@ -56,8 +64,12 @@ def train_func(config):
     multi_worker_model = build_and_compile_cnn_model(config)
 
     history = multi_worker_model.fit(
-        multi_worker_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)
-    return history.history
+        multi_worker_dataset,
+        epochs=epochs,
+        steps_per_epoch=steps_per_epoch,
+        callbacks=[SGDReportCallback()])
+    results = history.history
+    return results
 
 
 def train_tensorflow_mnist(num_workers=1, use_gpu=False):

diff --git a/python/ray/util/sgd/v2/examples/train_fashion_mnist.py b/python/ray/util/sgd/v2/examples/train_fashion_mnist.py
index 023e84e5d..f52cc9ae8 100644
--- a/python/ray/util/sgd/v2/examples/train_fashion_mnist.py
+++ b/python/ray/util/sgd/v2/examples/train_fashion_mnist.py
@@ -117,7 +117,8 @@ def train_func(config: Dict):
 
 
 def train_fashion_mnist(num_workers=1, use_gpu=False):
-    trainer = Trainer(backend="torch", num_workers=num_workers)
+    trainer = Trainer(
+        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
     trainer.start()
     result = trainer.run(
         train_func=train_func,

diff --git a/python/ray/util/sgd/v2/examples/train_linear.py b/python/ray/util/sgd/v2/examples/train_linear.py
index 76daf1a81..937832065 100644
--- a/python/ray/util/sgd/v2/examples/train_linear.py
+++ b/python/ray/util/sgd/v2/examples/train_linear.py
@@ -1,6 +1,7 @@
 import argparse
 
 import numpy as np
+import ray.util.sgd.v2 as sgd
 import torch
 import torch.nn as nn
 from ray.util.sgd.v2 import Trainer, TorchConfig
@@ -80,6 +81,7 @@ def train_func(config):
     for _ in range(epochs):
         train(train_loader, model, loss_fn, optimizer)
         result = validate(validation_loader, model, loss_fn)
+        sgd.report(**result)
         results.append(result)
 
     return results

diff --git a/python/ray/util/sgd/v2/examples/tune_linear.py b/python/ray/util/sgd/v2/examples/tune_linear.py
new file mode 100644
index 000000000..c32dfd5d7
--- /dev/null
+++ b/python/ray/util/sgd/v2/examples/tune_linear.py
@@ -0,0 +1,56 @@
+import argparse
+
+import ray
+from ray import tune
+from ray.util.sgd.v2 import Trainer
+
+from train_linear import train_func
+
+
+def tune_linear(num_workers, num_samples):
+    trainer = Trainer("torch", num_workers=num_workers)
+    Trainable = trainer.to_tune_trainable(train_func)
+    analysis = tune.run(
+        Trainable,
+        num_samples=num_samples,
+        config={
+            "lr": tune.loguniform(1e-4, 1e-1),
+            "batch_size": tune.choice([4, 16, 32]),
+            "epochs": 3
+        })
+    best_config = analysis.get_best_config(metric="loss", mode="min")
+    print(best_config)
+    return best_config
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        default=False,
+        help="Finish quickly for testing.")
+    parser.add_argument(
+        "--address",
+        required=False,
+        type=str,
+        help="the address to use for Ray")
+    parser.add_argument(
+        "--num-workers",
+        "-n",
+        type=int,
+        default=2,
+        help="Sets number of workers for training.")
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=2,
+        help="Sets number of samples for training.")
+
+    args = parser.parse_args()
+
+    if args.smoke_test:
+        ray.init(num_cpus=4)
+    else:
+        ray.init(address=args.address)
+    tune_linear(num_workers=args.num_workers, num_samples=args.num_samples)

diff --git a/python/ray/util/sgd/v2/examples/tune_tensorflow_mnist.py b/python/ray/util/sgd/v2/examples/tune_tensorflow_mnist.py
new file mode 100644
index 000000000..0d1525ab7
--- /dev/null
+++ b/python/ray/util/sgd/v2/examples/tune_tensorflow_mnist.py
@@ -0,0 +1,59 @@
+import argparse
+
+import ray
+from ray import tune
+from ray.util.sgd.v2 import Trainer
+
+from tensorflow_mnist_example import train_func
+
+
+def tune_tensorflow_mnist(num_workers, num_samples):
+    trainer = Trainer(backend="tensorflow", num_workers=num_workers)
+    Trainable = trainer.to_tune_trainable(train_func)
+    analysis = tune.run(
+        Trainable,
+        num_samples=num_samples,
+        config={
+            "lr": tune.loguniform(1e-4, 1e-1),
+            "batch_size": tune.choice([32, 64, 128]),
+            "epochs": 3
+        })
+    best_loss_config = analysis.get_best_config(metric="loss", mode="min")
+    best_acc_config = analysis.get_best_config(metric="accuracy", mode="max")
+    print(f"Best loss config: {best_loss_config}")
+    print(f"Best accuracy config: {best_acc_config}")
+    return analysis
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        default=False,
+        help="Finish quickly for testing.")
+    parser.add_argument(
+        "--address",
+        required=False,
+        type=str,
+        help="the address to use for Ray")
+    parser.add_argument(
+        "--num-workers",
+        "-n",
+        type=int,
+        default=2,
+        help="Sets number of workers for training.")
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=2,
+        help="Sets number of samples for training.")
+
+    args = parser.parse_args()
+
+    if args.smoke_test:
+        ray.init(num_cpus=4)
+    else:
+        ray.init(address=args.address)
+    tune_tensorflow_mnist(
+        num_workers=args.num_workers, num_samples=args.num_samples)

diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py
index ada3a5069..768766618 100644
--- a/python/ray/util/sgd/v2/tests/test_trainer.py
+++ b/python/ray/util/sgd/v2/tests/test_trainer.py
@@ -7,6 +7,7 @@ import ray
 import ray.util.sgd.v2 as sgd
 import tensorflow as tf
 import torch
+
 from ray.util.sgd.v2 import Trainer
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface, \
     BackendExecutor
@@ -36,6 +37,14 @@ def ray_start_2_cpus_2_gpus():
     ray.shutdown()
 
 
+@pytest.fixture
+def ray_start_8_cpus():
+    address_info = ray.init(num_cpus=8)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
 class TestConfig(BackendConfig):
     @property
     def backend_cls(self):
         return TestBackend
@@ -212,6 +221,67 @@ def test_mismatch_report(ray_start_2_cpus):
         trainer.run(train)
 
 
+def test_run_iterator(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func():
+        for i in range(3):
+            sgd.report(index=i)
+        return 1
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(train_func)
+
+    count = 0
+    for results in iterator:
+        assert all(value["index"] == count for value in results)
+        count += 1
+
+    assert count == 3
+    assert iterator.is_finished()
+    assert iterator.get_final_results() == [1, 1]
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+
+def test_run_iterator_returns(ray_start_2_cpus):
+    config = TestConfig()
+
+    def train_func():
+        for i in range(3):
+            sgd.report(index=i)
+        return 1
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(train_func)
+
+    assert iterator.get_final_results() is None
+    assert iterator.get_final_results(force=True) == [1, 1]
+
+    with pytest.raises(StopIteration):
+        next(iterator)
+
+
+def test_run_iterator_error(ray_start_2_cpus):
+    config = TestConfig()
+
+    def fail_train():
+        raise NotImplementedError
+
+    trainer = Trainer(config, num_workers=2)
+    trainer.start()
+    iterator = trainer.run_iterator(fail_train)
+
+    with pytest.raises(NotImplementedError):
+        next(iterator)
+
+    assert iterator.get_final_results() is None
+    assert iterator.is_finished()
+
+
 def test_checkpoint(ray_start_2_cpus):
     config = TestConfig()

diff --git a/python/ray/util/sgd/v2/tests/test_tune.py b/python/ray/util/sgd/v2/tests/test_tune.py
new file mode 100644
index 000000000..d056b55cd
--- /dev/null
+++ b/python/ray/util/sgd/v2/tests/test_tune.py
@@ -0,0 +1,134 @@
+import pytest
+
+import torch
+import tensorflow as tf
+
+import ray
+from ray import tune
+from ray.tune import TuneError
+
+from ray.util.sgd.v2 import Trainer
+from ray.util.sgd.v2.backends.backend import BackendInterface, BackendConfig
+from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \
+    tensorflow_mnist_train_func
+from ray.util.sgd.v2.examples.train_fashion_mnist import train_func as \
+    fashion_mnist_train_func
+from ray.util.sgd.v2.worker_group import WorkerGroup
+
+
+@pytest.fixture
+def ray_start_2_cpus():
+    address_info = ray.init(num_cpus=2)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_start_4_cpus_4_gpus():
+    address_info = ray.init(num_cpus=4, num_gpus=4)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_start_8_cpus():
+    address_info = ray.init(num_cpus=8)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+class TestConfig(BackendConfig):
+    @property
+    def backend_cls(self):
+        return TestBackend
+
+
+class TestBackend(BackendInterface):
+    def on_start(self, worker_group: WorkerGroup, backend_config: TestConfig):
+        pass
+
+    def on_shutdown(self, worker_group: WorkerGroup,
+                    backend_config: TestConfig):
+        pass
+
+
+def torch_fashion_mnist(num_workers, use_gpu, num_samples):
+    epochs = 2
+
+    trainer = Trainer("torch", num_workers=num_workers, use_gpu=use_gpu)
+    MnistTrainable = trainer.to_tune_trainable(fashion_mnist_train_func)
+
+    analysis = tune.run(
+        MnistTrainable,
+        num_samples=num_samples,
+        config={
+            "lr": tune.loguniform(1e-4, 1e-1),
+            "batch_size": tune.choice([32, 64, 128]),
+            "epochs": epochs
+        })
+
+    # Check that loss decreases in each trial.
+    for path, df in analysis.trial_dataframes.items():
+        assert df.loc[1, "loss"] < df.loc[0, "loss"]
+
+
+def test_tune_torch_fashion_mnist(ray_start_8_cpus):
+    torch_fashion_mnist(num_workers=2, use_gpu=False, num_samples=2)
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 2,
+    reason="Only run if multiple GPUs are available.")
+def test_tune_fashion_mnist_gpu(ray_start_4_cpus_4_gpus):
+    torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1)
+
+
+def tune_tensorflow_mnist(num_workers, use_gpu, num_samples):
+    epochs = 2
+    trainer = Trainer("tensorflow", num_workers=num_workers, use_gpu=use_gpu)
+    MnistTrainable = trainer.to_tune_trainable(tensorflow_mnist_train_func)
+
+    analysis = tune.run(
+        MnistTrainable,
+        num_samples=num_samples,
+        config={
+            "lr": tune.loguniform(1e-4, 1e-1),
+            "batch_size": tune.choice([32, 64, 128]),
+            "epochs": epochs
+        })
+
+    # Check that loss decreases in each trial.
+    for path, df in analysis.trial_dataframes.items():
+        assert df.loc[1, "loss"] < df.loc[0, "loss"]
+
+
+def test_tune_tensorflow_mnist(ray_start_8_cpus):
+    tune_tensorflow_mnist(num_workers=2, use_gpu=False, num_samples=2)
+
+
+@pytest.mark.skipif(
+    len(tf.config.list_physical_devices("GPU")) < 2,
+    reason="Only run if multiple GPUs are available.")
+def test_tune_tensorflow_mnist_gpu(ray_start_4_cpus_4_gpus):
+    tune_tensorflow_mnist(num_workers=2, use_gpu=True, num_samples=1)
+
+
+def test_tune_error(ray_start_2_cpus):
+    def train_func(config):
+        raise RuntimeError("Error in training function!")
+
+    trainer = Trainer(TestConfig())
+    TestTrainable = trainer.to_tune_trainable(train_func)
+
+    with pytest.raises(TuneError):
+        tune.run(TestTrainable)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))

diff --git a/python/ray/util/sgd/v2/trainer.py b/python/ray/util/sgd/v2/trainer.py
index a3bed2860..4f23146ad 100644
--- a/python/ray/util/sgd/v2/trainer.py
+++ b/python/ray/util/sgd/v2/trainer.py
@@ -1,9 +1,9 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import Union, Callable, List, TypeVar, Optional, Any, Dict
+from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
+    Type, Iterator
 
-from ray.tune import Trainable
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \
     InactiveWorkerGroupError, SGDBackendError
 from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
@@ -11,6 +11,22 @@ from ray.util.sgd.v2.backends.torch import TorchConfig
 from ray.util.sgd.v2.callbacks.callback import SGDCallback
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 
+# Ray SGD should be usable even if Tune is not installed.
+try:
+    TUNE_INSTALLED = True
+    from ray import tune
+    from ray.tune import Trainable
+    from ray.tune import PlacementGroupFactory
+    from ray.tune.function_runner import wrap_function
+except ImportError:
+    TUNE_INSTALLED = False
+    tune = PlacementGroupFactory = Trainable = object
+
+    def noop():
+        return
+
+    wrap_function = noop
+
 T = TypeVar("T")
 S = TypeVar("S")
@@ -47,6 +63,12 @@ class Trainer:
                  use_gpu: bool = False,
                  resources_per_worker: Optional[Dict[str, float]] = None,
                  logdir: Optional[str] = None):
+
+        self._backend = backend
+        self._num_workers = num_workers
+        self._use_gpu = use_gpu
+        self._resources_per_worker = resources_per_worker
+
         # Setup executor.
         backend_config = self._get_backend_config(backend)
@@ -131,42 +153,83 @@ class Trainer:
             list corresponds to the output of the training function from
             each worker.
         """
-        train_func = self._get_train_func(train_func, config)
         # TODO(matt): Set default callbacks.
         callbacks = [] if callbacks is None else callbacks
         finished_with_errors = False
 
+        for callback in callbacks:
+            callback.start_training()
+
         try:
-            for callback in callbacks:
-                callback.start_training()
-
-            self._executor.start_training(train_func, checkpoint,
-                                          checkpoint_strategy)
+            iterator = self.run_iterator(
+                train_func=train_func,
+                config=config,
+                checkpoint=checkpoint,
+                checkpoint_strategy=checkpoint_strategy)
+            for intermediate_result in iterator:
+                for callback in callbacks:
+                    callback.handle_result(intermediate_result)
 
-            while True:
-                intermediate_results = self._executor.fetch_next_result()
-                if intermediate_results is None:
-                    break
-                else:
-                    for callback in callbacks:
-                        callback.handle_result(intermediate_results)
-
-            return self._executor.finish_training()
-        except InactiveWorkerGroupError:
-            finished_with_errors = True
-            raise RuntimeError(
-                "This Trainer is not active. It is either shutdown already or "
-                "never started in the first place. Either create a new "
-                "Trainer or start this one.") from None
-        except SGDBackendError:
-            finished_with_errors = True
-            raise RuntimeError("Training failed. You should not be seeing "
-                               "this error and this is a bug. Please create "
-                               "a new issue at "
-                               "https://github.com/ray-project/ray.") from None
+            assert iterator.is_finished()
+            return iterator.get_final_results()
         finally:
             for callback in callbacks:
                 callback.finish_training(error=finished_with_errors)
 
+    def run_iterator(
+            self,
+            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
+            config: Optional[Dict[str, Any]] = None,
+            checkpoint: Optional[Dict] = None,
+            checkpoint_strategy: Optional[CheckpointStrategy] = None
+    ) -> Iterator[List[Dict]]:
+        """Same as ``run`` except returns an iterator over the results.
+
+        This is useful if you want more control over what to do with the
+        intermediate results, or over how to use the ``Trainer`` with Ray
+        Tune.
+
+        .. code-block:: python
+
+            def train_func(config):
+                ...
+                for _ in range(config["epochs"]):
+                    train(...)
+                    metrics = validate(...)
+                    ray.sgd.report(**metrics)
+                return model
+
+            iterator = trainer.run_iterator(train_func, config=config)
+
+            for result in iterator:
+                do_stuff(result)
+                latest_ckpt = trainer.get_latest_checkpoint()
+
+            assert iterator.is_finished()
+            model = iterator.get_final_results()[0]
+
+        Args:
+            train_func (Callable): The training function to execute.
+                This can either take in no arguments or a ``config`` dict.
+            config (Optional[Dict]): Configurations to pass into
+                ``train_func``. If None then an empty Dict will be created.
+            checkpoint (Optional[Dict]): The checkpoint data that should be
+                loaded onto each worker and accessed by the training function
+                via ``sgd.load_checkpoint()``.
+            checkpoint_strategy (Optional[CheckpointStrategy]): The
+                configurations for saving checkpoints.
+
+        Returns:
+            An Iterator over the intermediate results from ``sgd.report()``.
+        """
+        train_func = self._get_train_func(train_func, config)
+
+        return SGDIterator(
+            backend_executor=self._executor,
+            train_func=train_func,
+            checkpoint=checkpoint,
+            checkpoint_strategy=checkpoint_strategy)
+
     def _get_train_func(
             self,
             train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
@@ -273,8 +336,8 @@ class Trainer:
         """Shuts down the training execution service."""
         self._executor.shutdown()
 
-    def to_tune_trainable(
-            self, train_func: Callable[[Dict[str, Any]], T]) -> Trainable:
+    def to_tune_trainable(self, train_func: Callable[[Dict[str, Any]], T]
+                          ) -> Type[Trainable]:
         """Creates a Tune ``Trainable`` from the input training function.
 
         Args:
@@ -284,8 +347,149 @@ class Trainer:
         Returns:
             A Trainable that can directly be passed into ``tune.run()``.
         """
+        if not TUNE_INSTALLED:
+            raise ValueError("Tune is not installed. Please install ray["
+                             "tune] to use the Tune integration.")
 
-        def trainable_func(config: Dict[str, Any]) -> T:
-            pass
+        if self._executor.is_started:
+            raise RuntimeError("The Trainer must not be active to use "
+                               "`to_tune_trainable`. Either shutdown the "
+                               "Trainer or don't start it in the first place.")
 
-        raise NotImplementedError
+        return _create_tune_trainable(train_func, self._backend,
+                                      self._num_workers, self._use_gpu,
+                                      self._resources_per_worker)
+
+
+class SGDIterator:
+    """An iterator over SGD results. Returned by ``trainer.run_iterator``."""
+
+    def __init__(
+            self, backend_executor: BackendExecutor,
+            train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
+            checkpoint: Optional[Dict],
+            checkpoint_strategy: Optional[CheckpointStrategy]):
+        self._executor = backend_executor
+        self._run_with_error_handling(
+            lambda: self._executor.start_training(
+                train_func=train_func,
+                checkpoint=checkpoint,
+                checkpoint_strategy=checkpoint_strategy)
+        )
+
+        self._final_results = None
+        self._finished_training = False
+
+    def __iter__(self):
+        return self
+
+    def _run_with_error_handling(self, func: Callable):
+        try:
+            return func()
+        except InactiveWorkerGroupError:
+            raise RuntimeError(
+                "This Trainer is not active. It is either shutdown "
+                "already or never started in the first place. "
+                "Either create a new Trainer or start this one.") \
+                from None
+        except SGDBackendError:
+            raise RuntimeError("Training failed. You should not be seeing "
+                               "this error and this is a bug. Please create "
+                               "a new issue at "
+                               "https://github.com/ray-project/ray.") from None
+
+    def __next__(self):
+        if self.is_finished():
+            raise StopIteration
+        next_results = self._run_with_error_handling(
+            self._executor.fetch_next_result)
+        if next_results is None:
+            try:
+                self._final_results = \
+                    self._run_with_error_handling(
+                        self._executor.finish_training)
+            finally:
+                self._finished_training = True
+            raise StopIteration
+        else:
+            return next_results
+
+    def is_finished(self) -> bool:
+        return self._finished_training
+
+    def get_final_results(self, force: bool = False) -> Optional[List[T]]:
+        """Gets the training func return values from each worker.
+
+        If ``force`` is ``True``, then immediately finish training
+        and return even if all the intermediate results have not
+        been processed yet.
+        Else, intermediate results must be processed before
+        obtaining the final results. Defaults to False.
+        """
+
+        if not self.is_finished():
+            assert self._final_results is None
+            if force:
+                try:
+                    self._final_results = \
+                        self._run_with_error_handling(
+                            self._executor.finish_training)
+                finally:
+                    self._finished_training = True
+            else:
+                logger.info("Please finish iterating through the "
+                            "intermediate results before getting the "
+                            "final returns. If you would like "
+                            "training to finish immediately and get "
+                            "the final returns, then set "
+                            "`force=True`.")
+
+        return self._final_results
+
+
+def _create_tune_trainable(train_func, backend, num_workers, use_gpu,
+                           resources_per_worker):
+    """Creates a Tune Trainable class for SGD training.
+
+    This function populates class attributes and methods.
+    """
+
+    # TODO(amog): Implement checkpointing for Tune.
+    def tune_function(config, checkpoint_dir=None):
+        trainer = Trainer(
+            backend=backend,
+            num_workers=num_workers,
+            use_gpu=use_gpu,
+            resources_per_worker=resources_per_worker)
+
+        trainer.start()
+
+        iterator = trainer.run_iterator(train_func, config)
+
+        for results in iterator:
+            first_worker_results = results[0]
+
+            tune.report(**first_worker_results)
+
+        trainer.shutdown()
+
+    trainable_cls = wrap_function(tune_function)
+
+    class SgdTrainable(trainable_cls):
+        """Add default resources to the Trainable."""
+
+        @classmethod
+        def default_resource_request(cls,
+                                     config: Dict) -> PlacementGroupFactory:
+            head_bundle = [{"CPU": 1}]  # driver
+            worker_resources = {"CPU": 1, "GPU": int(use_gpu)}
+            worker_resources_extra = {} if resources_per_worker is None else \
+                resources_per_worker
+            worker_bundles = [{
+                **worker_resources,
+                **worker_resources_extra
+            } for _ in range(num_workers)]
+            bundles = head_bundle + worker_bundles
+            return PlacementGroupFactory(bundles, strategy="PACK")
+
+    return SgdTrainable
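
Usage sketch (illustrative, not part of the diff): the two APIs this patch adds compose as below. The `train_func` body, the metric values, and the config are hypothetical placeholders; the example assumes the torch backend and Tune are installed.

    import ray
    import ray.util.sgd.v2 as sgd
    from ray import tune
    from ray.util.sgd.v2 import Trainer

    def train_func(config):
        # Each sgd.report() call yields one List[Dict] (one entry per
        # worker) from the iterator below.
        for epoch in range(config["epochs"]):
            loss = 1.0 / (epoch + 1)  # placeholder metric
            sgd.report(loss=loss)
        return loss

    ray.init()

    # Iterator API: consume intermediate results as they arrive.
    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    iterator = trainer.run_iterator(train_func, config={"epochs": 3})
    for intermediate_results in iterator:
        print(intermediate_results[0]["loss"])  # rank 0 worker's metric
    final_returns = iterator.get_final_results()  # one value per worker
    trainer.shutdown()

    # Tune integration: the Trainer must not be started before converting.
    trainer = Trainer(backend="torch", num_workers=2)
    Trainable = trainer.to_tune_trainable(train_func)
    analysis = tune.run(Trainable, config={"epochs": 3}, num_samples=2)
    print(analysis.get_best_config(metric="loss", mode="min"))

Via `default_resource_request`, each trial reserves one 1-CPU bundle for the driver plus one bundle per worker, so this example needs at least 3 CPUs per trial.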