[tune] added MXNet integration callbacks (#10533)

This commit is contained in:
Kai Fricke 2020-09-04 02:06:44 +01:00 committed by GitHub
parent ead30ca655
commit 5c3d4a6670
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 234 additions and 0 deletions

View file

@ -39,6 +39,7 @@ MOCK_MODULES = [
"horovod", "horovod",
"horovod.ray", "horovod.ray",
"kubernetes", "kubernetes",
"mxnet.model",
"psutil", "psutil",
"ray._raylet", "ray._raylet",
"ray.core.generated", "ray.core.generated",

View file

@ -23,6 +23,16 @@ Kubernetes (tune.integration.kubernetes)
.. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer .. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer
.. _tune-integration-mxnet:
MXNet (tune.integration.mxnet)
------------------------------
.. autoclass:: ray.tune.integration.mxnet.TuneReportCallback
.. autoclass:: ray.tune.integration.mxnet.TuneCheckpointCallback
.. _tune-integration-pytorch-lightning: .. _tune-integration-pytorch-lightning:
PyTorch Lightning (tune.integration.pytorch_lightning) PyTorch Lightning (tune.integration.pytorch_lightning)

View file

@ -460,6 +460,15 @@ py_test(
args = ["--smoke-test"] args = ["--smoke-test"]
) )
py_test(
name = "mxnet_example",
size = "small",
srcs = ["examples/mxnet_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example"],
args = ["--smoke-test"]
)
py_test( py_test(
name = "nevergrad_example", name = "nevergrad_example",
size = "medium", size = "medium",

View file

@ -0,0 +1,95 @@
from functools import partial
import mxnet as mx
from ray import tune, logger
from ray.tune import CLIReporter
from ray.tune.integration.mxnet import TuneCheckpointCallback, \
TuneReportCallback
from ray.tune.schedulers import ASHAScheduler
def train_mnist_mxnet(config, mnist, num_epochs=10):
batch_size = config["batch_size"]
train_iter = mx.io.NDArrayIter(
mnist["train_data"], mnist["train_label"], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist["test_data"], mnist["test_label"],
batch_size)
data = mx.sym.var("data")
data = mx.sym.flatten(data=data)
fc1 = mx.sym.FullyConnected(data=data, num_hidden=config["layer_1_size"])
act1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, num_hidden=config["layer_2_size"])
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# MNIST has 10 classes
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp = mx.sym.SoftmaxOutput(data=fc3, name="softmax")
# create a trainable module on CPU
mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())
mlp_model.fit(
train_iter,
eval_data=val_iter,
optimizer="sgd",
optimizer_params={"learning_rate": config["lr"]},
eval_metric="acc",
batch_end_callback=mx.callback.Speedometer(batch_size, 100),
eval_end_callback=TuneReportCallback({
"mean_accuracy": "accuracy"
}),
epoch_end_callback=TuneCheckpointCallback(
filename="mxnet_cp", frequency=3),
num_epoch=num_epochs)
def tune_mnist_mxnet(num_samples=10, num_epochs=10):
logger.info("Downloading MNIST data...")
mnist_data = mx.test_utils.get_mnist()
logger.info("Got MNIST data, starting Ray Tune.")
config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": tune.loguniform(1e-3, 1e-1),
"batch_size": tune.choice([32, 64, 128])
}
scheduler = ASHAScheduler(
metric="mean_accuracy",
mode="max",
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])
tune.run(
partial(train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs),
resources_per_trial={
"cpu": 1,
},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter,
name="tune_mnist_mxnet")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
if args.smoke_test:
tune_mnist_mxnet(num_samples=1, num_epochs=1)
else:
tune_mnist_mxnet(num_samples=10, num_epochs=10)

View file

@ -0,0 +1,119 @@
from typing import Dict, List, Union
from ray import tune
from mxnet.model import save_checkpoint
import os
class TuneCallback:
"""Base class for Tune's MXNet callbacks."""
pass
class TuneReportCallback(TuneCallback):
"""MXNet to Ray Tune reporting callback
Reports metrics to Ray Tune.
This has to be passed to MXNet as the ``eval_end_callback``.
Args:
metrics (str|list|dict): Metrics to report to Tune. If this is a list,
each item describes the metric key reported to MXNet,
and it will reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to MXNet.
Example:
.. code-block:: python
from ray.tune.integration.mxnet import TuneReportCallback
# mlp_model is a MXNet model
mlp_model.fit(
train_iter,
# ...
eval_metric="acc",
eval_end_callback=TuneReportCallback({
"mean_accuracy": "accuracy"
}))
"""
def __init__(self,
metrics: Union[None, str, List[str], Dict[str, str]] = None):
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
def __call__(self, param):
if not param.eval_metric:
return
if not self._metrics:
report_dict = dict(param.eval_metric.get_name_value())
else:
report_dict = {}
lookup_dict = dict(param.eval_metric.get_name_value())
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = lookup_dict[metric]
tune.report(**report_dict)
class TuneCheckpointCallback(TuneCallback):
"""MXNet checkpoint callback
Saves checkpoints after each epoch.
This has to be passed to the ``epoch_end_callback`` of the MXNet model.
Checkpoint are currently not registered if no ``tune.report()`` call
is made afterwards. You have to use this in conjunction with the
``TuneReportCallback`` to work!
Args:
filename (str): Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
frequency (int): Integer indicating how often checkpoints should be
saved.
Example:
.. code-block:: python
from ray.tune.integration.mxnet import TuneReportCallback, \
TuneCheckpointCallback
# mlp_model is a MXNet model
mlp_model.fit(
train_iter,
# ...
eval_metric="acc",
eval_end_callback=TuneReportCallback({
"mean_accuracy": "accuracy"
}),
epoch_end_callback=TuneCheckpointCallback(
filename="mxnet_cp",
frequency=3
))
"""
def __init__(self, filename: str = "checkpoint", frequency: int = 1):
self._filename = filename
self._frequency = frequency
def __call__(self, epoch, sym, arg, aux):
if epoch % self._frequency != 0:
return
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
save_checkpoint(
os.path.join(checkpoint_dir, self._filename), epoch, sym, arg,
aux)