diff --git a/doc/source/conf.py b/doc/source/conf.py
index a350a2810..e17ded13b 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -39,6 +39,7 @@ MOCK_MODULES = [
     "horovod",
     "horovod.ray",
     "kubernetes",
+    "mxnet.model",
     "psutil",
     "ray._raylet",
     "ray.core.generated",
diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst
index ef29d89b3..24debbf94 100644
--- a/doc/source/tune/api_docs/integration.rst
+++ b/doc/source/tune/api_docs/integration.rst
@@ -23,6 +23,16 @@ Kubernetes (tune.integration.kubernetes)
 
 .. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer
 
+.. _tune-integration-mxnet:
+
+MXNet (tune.integration.mxnet)
+------------------------------
+
+.. autoclass:: ray.tune.integration.mxnet.TuneReportCallback
+
+.. autoclass:: ray.tune.integration.mxnet.TuneCheckpointCallback
+
+
 .. _tune-integration-pytorch-lightning:
 
 PyTorch Lightning (tune.integration.pytorch_lightning)
diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD
index af17d7c42..36303e6de 100644
--- a/python/ray/tune/BUILD
+++ b/python/ray/tune/BUILD
@@ -460,6 +460,15 @@ py_test(
     args = ["--smoke-test"]
 )
 
+py_test(
+    name = "mxnet_example",
+    size = "small",
+    srcs = ["examples/mxnet_example.py"],
+    deps = [":tune_lib"],
+    tags = ["exclusive", "example"],
+    args = ["--smoke-test"]
+)
+
 py_test(
     name = "nevergrad_example",
     size = "medium",
diff --git a/python/ray/tune/examples/mxnet_example.py b/python/ray/tune/examples/mxnet_example.py
new file mode 100644
index 000000000..b128c121d
--- /dev/null
+++ b/python/ray/tune/examples/mxnet_example.py
@@ -0,0 +1,95 @@
+from functools import partial
+
+import mxnet as mx
+from ray import tune, logger
+from ray.tune import CLIReporter
+from ray.tune.integration.mxnet import TuneCheckpointCallback, \
+    TuneReportCallback
+from ray.tune.schedulers import ASHAScheduler
+
+
+def train_mnist_mxnet(config, mnist, num_epochs=10):
+    batch_size = config["batch_size"]
+    train_iter = mx.io.NDArrayIter(
+        mnist["train_data"], mnist["train_label"], batch_size, shuffle=True)
+    val_iter = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"],
+                                 batch_size)
+
+    data = mx.sym.var("data")
+    data = mx.sym.flatten(data=data)
+
+    fc1 = mx.sym.FullyConnected(data=data, num_hidden=config["layer_1_size"])
+    act1 = mx.sym.Activation(data=fc1, act_type="relu")
+
+    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=config["layer_2_size"])
+    act2 = mx.sym.Activation(data=fc2, act_type="relu")
+
+    # MNIST has 10 classes
+    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
+    # Softmax with cross entropy loss
+    mlp = mx.sym.SoftmaxOutput(data=fc3, name="softmax")
+
+    # create a trainable module on CPU
+    mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())
+    mlp_model.fit(
+        train_iter,
+        eval_data=val_iter,
+        optimizer="sgd",
+        optimizer_params={"learning_rate": config["lr"]},
+        eval_metric="acc",
+        batch_end_callback=mx.callback.Speedometer(batch_size, 100),
+        eval_end_callback=TuneReportCallback({
+            "mean_accuracy": "accuracy"
+        }),
+        epoch_end_callback=TuneCheckpointCallback(
+            filename="mxnet_cp", frequency=3),
+        num_epoch=num_epochs)
+
+
+def tune_mnist_mxnet(num_samples=10, num_epochs=10):
+    logger.info("Downloading MNIST data...")
+    mnist_data = mx.test_utils.get_mnist()
+    logger.info("Got MNIST data, starting Ray Tune.")
+
+    config = {
+        "layer_1_size": tune.choice([32, 64, 128]),
+        "layer_2_size": tune.choice([64, 128, 256]),
+        "lr": tune.loguniform(1e-3, 1e-1),
+        "batch_size": tune.choice([32, 64, 128])
+    }
+
+    scheduler = ASHAScheduler(
+        metric="mean_accuracy",
+        mode="max",
+        max_t=num_epochs,
+        grace_period=1,
+        reduction_factor=2)
+
+    reporter = CLIReporter(
+        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
+        metric_columns=["loss", "mean_accuracy", "training_iteration"])
+
+    tune.run(
+        partial(train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs),
+        resources_per_trial={
+            "cpu": 1,
+        },
+        config=config,
+        num_samples=num_samples,
+        scheduler=scheduler,
+        progress_reporter=reporter,
+        name="tune_mnist_mxnet")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--smoke-test", action="store_true", help="Finish quickly for testing")
+    args, _ = parser.parse_known_args()
+
+    if args.smoke_test:
+        tune_mnist_mxnet(num_samples=1, num_epochs=1)
+    else:
+        tune_mnist_mxnet(num_samples=10, num_epochs=10)
diff --git a/python/ray/tune/integration/mxnet.py b/python/ray/tune/integration/mxnet.py
new file mode 100644
index 000000000..435f2c34a
--- /dev/null
+++ b/python/ray/tune/integration/mxnet.py
@@ -0,0 +1,119 @@
+from typing import Dict, List, Union
+
+from ray import tune
+
+from mxnet.model import save_checkpoint
+
+import os
+
+
+class TuneCallback:
+    """Base class for Tune's MXNet callbacks."""
+    pass
+
+
+class TuneReportCallback(TuneCallback):
+    """MXNet to Ray Tune reporting callback
+
+    Reports metrics to Ray Tune.
+
+    This has to be passed to MXNet as the ``eval_end_callback``.
+
+    Args:
+        metrics (str|list|dict): Metrics to report to Tune. If this is a list,
+            each item describes the metric key reported to MXNet,
+            and it will be reported under the same name to Tune. If this is a
+            dict, each key will be the name reported to Tune and the respective
+            value will be the metric key reported to MXNet.
+
+    Example:
+
+    .. code-block:: python
+
+        from ray.tune.integration.mxnet import TuneReportCallback
+
+        # mlp_model is an MXNet model
+        mlp_model.fit(
+            train_iter,
+            # ...
+            eval_metric="acc",
+            eval_end_callback=TuneReportCallback({
+                "mean_accuracy": "accuracy"
+            }))
+
+    """
+
+    def __init__(self,
+                 metrics: Union[None, str, List[str], Dict[str, str]] = None):
+        if isinstance(metrics, str):
+            metrics = [metrics]
+        self._metrics = metrics
+
+    def __call__(self, param):
+        if not param.eval_metric:
+            return
+        if not self._metrics:
+            report_dict = dict(param.eval_metric.get_name_value())
+        else:
+            report_dict = {}
+            lookup_dict = dict(param.eval_metric.get_name_value())
+            for key in self._metrics:
+                if isinstance(self._metrics, dict):
+                    metric = self._metrics[key]
+                else:
+                    metric = key
+                report_dict[key] = lookup_dict[metric]
+        tune.report(**report_dict)
+
+
+class TuneCheckpointCallback(TuneCallback):
+    """MXNet checkpoint callback
+
+    Saves checkpoints after each epoch.
+
+    This has to be passed to the ``epoch_end_callback`` of the MXNet model.
+
+    Checkpoints are currently not registered if no ``tune.report()`` call
+    is made afterwards. You have to use this callback in conjunction with
+    the ``TuneReportCallback``.
+
+    Args:
+        filename (str): Filename of the checkpoint within the checkpoint
+            directory. Defaults to "checkpoint".
+        frequency (int): Integer indicating how often (in epochs) checkpoints
+            should be saved.
+
+    Example:
+
+    .. code-block:: python
+
+
+        from ray.tune.integration.mxnet import TuneReportCallback, \
+            TuneCheckpointCallback
+
+        # mlp_model is an MXNet model
+        mlp_model.fit(
+            train_iter,
+            # ...
+ eval_metric="acc", + eval_end_callback=TuneReportCallback({ + "mean_accuracy": "accuracy" + }), + epoch_end_callback=TuneCheckpointCallback( + filename="mxnet_cp", + frequency=3 + )) + + """ + + def __init__(self, filename: str = "checkpoint", frequency: int = 1): + self._filename = filename + self._frequency = frequency + + def __call__(self, epoch, sym, arg, aux): + if epoch % self._frequency != 0: + return + with tune.checkpoint_dir(step=epoch) as checkpoint_dir: + save_checkpoint( + os.path.join(checkpoint_dir, self._filename), epoch, sym, arg, + aux)