mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] added MXNet integration callbacks (#10533)
This commit is contained in:
parent
ead30ca655
commit
5c3d4a6670
5 changed files with 234 additions and 0 deletions
|
@ -39,6 +39,7 @@ MOCK_MODULES = [
|
|||
"horovod",
|
||||
"horovod.ray",
|
||||
"kubernetes",
|
||||
"mxnet.model",
|
||||
"psutil",
|
||||
"ray._raylet",
|
||||
"ray.core.generated",
|
||||
|
|
|
@ -23,6 +23,16 @@ Kubernetes (tune.integration.kubernetes)
|
|||
|
||||
.. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer
|
||||
|
||||
.. _tune-integration-mxnet:
|
||||
|
||||
MXNet (tune.integration.mxnet)
|
||||
------------------------------
|
||||
|
||||
.. autoclass:: ray.tune.integration.mxnet.TuneReportCallback
|
||||
|
||||
.. autoclass:: ray.tune.integration.mxnet.TuneCheckpointCallback
|
||||
|
||||
|
||||
.. _tune-integration-pytorch-lightning:
|
||||
|
||||
PyTorch Lightning (tune.integration.pytorch_lightning)
|
||||
|
|
|
@ -460,6 +460,15 @@ py_test(
|
|||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "mxnet_example",
|
||||
size = "small",
|
||||
srcs = ["examples/mxnet_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "nevergrad_example",
|
||||
size = "medium",
|
||||
|
|
95
python/ray/tune/examples/mxnet_example.py
Normal file
95
python/ray/tune/examples/mxnet_example.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
from functools import partial
|
||||
|
||||
import mxnet as mx
|
||||
from ray import tune, logger
|
||||
from ray.tune import CLIReporter
|
||||
from ray.tune.integration.mxnet import TuneCheckpointCallback, \
|
||||
TuneReportCallback
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
|
||||
|
||||
def train_mnist_mxnet(config, mnist, num_epochs=10):
    """Fit a small MLP on MNIST with MXNet and report results to Ray Tune.

    Args:
        config (dict): Hyperparameters for this trial. The keys
            ``batch_size``, ``layer_1_size``, ``layer_2_size`` and ``lr``
            are read.
        mnist (dict): Dataset arrays stored under ``train_data``,
            ``train_label``, ``test_data`` and ``test_label``.
        num_epochs (int): Forwarded to ``fit`` as ``num_epoch``.
    """
    bs = config["batch_size"]
    training_batches = mx.io.NDArrayIter(
        mnist["train_data"], mnist["train_label"], bs, shuffle=True)
    validation_batches = mx.io.NDArrayIter(mnist["test_data"],
                                           mnist["test_label"], bs)

    # Flattened input followed by two ReLU-activated hidden layers whose
    # widths come from the trial config.
    net = mx.sym.flatten(data=mx.sym.var("data"))
    for width_key in ("layer_1_size", "layer_2_size"):
        net = mx.sym.FullyConnected(data=net, num_hidden=config[width_key])
        net = mx.sym.Activation(data=net, act_type="relu")

    # MNIST has 10 classes; softmax with cross entropy loss on top.
    logits = mx.sym.FullyConnected(data=net, num_hidden=10)
    network = mx.sym.SoftmaxOutput(data=logits, name="softmax")

    # Trainable module on CPU.
    module = mx.mod.Module(symbol=network, context=mx.cpu())

    sgd_settings = {"learning_rate": config["lr"]}
    speed_logger = mx.callback.Speedometer(bs, 100)
    # Expose MXNet's "accuracy" metric to Tune as "mean_accuracy".
    report_to_tune = TuneReportCallback({"mean_accuracy": "accuracy"})
    save_to_tune = TuneCheckpointCallback(filename="mxnet_cp", frequency=3)

    module.fit(
        training_batches,
        eval_data=validation_batches,
        optimizer="sgd",
        optimizer_params=sgd_settings,
        eval_metric="acc",
        batch_end_callback=speed_logger,
        eval_end_callback=report_to_tune,
        epoch_end_callback=save_to_tune,
        num_epoch=num_epochs)
|
||||
|
||||
|
||||
def tune_mnist_mxnet(num_samples=10, num_epochs=10):
    """Run a Ray Tune hyperparameter search over the MXNet MNIST trainer.

    Args:
        num_samples (int): Forwarded to ``tune.run`` as ``num_samples``.
        num_epochs (int): Epochs per trial; also used as the ASHA
            scheduler's ``max_t``.
    """
    logger.info("Downloading MNIST data...")
    dataset = mx.test_utils.get_mnist()
    logger.info("Got MNIST data, starting Ray Tune.")

    # Search space sampled once per trial.
    search_space = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-3, 1e-1),
        "batch_size": tune.choice([32, 64, 128])
    }

    asha = ASHAScheduler(
        metric="mean_accuracy",
        mode="max",
        max_t=num_epochs,
        grace_period=1,
        reduction_factor=2)

    cli_reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"])

    # Bind the dataset and epoch budget so Tune only has to supply `config`.
    trainable = partial(
        train_mnist_mxnet, mnist=dataset, num_epochs=num_epochs)

    tune.run(
        trainable,
        resources_per_trial={"cpu": 1},
        config=search_space,
        num_samples=num_samples,
        scheduler=asha,
        progress_reporter=cli_reporter,
        name="tune_mnist_mxnet")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # Unknown arguments are tolerated and discarded.
    options, _ = cli.parse_known_args()

    # Smoke test: a single trial for a single epoch; otherwise 10 of each.
    samples, epochs = (1, 1) if options.smoke_test else (10, 10)
    tune_mnist_mxnet(num_samples=samples, num_epochs=epochs)
|
119
python/ray/tune/integration/mxnet.py
Normal file
119
python/ray/tune/integration/mxnet.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
from typing import Dict, List, Union
|
||||
|
||||
from ray import tune
|
||||
|
||||
from mxnet.model import save_checkpoint
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class TuneCallback:
    """Common parent of the MXNet callbacks provided for Ray Tune.

    Defines no behavior of its own; it only gives the callbacks a shared
    base type.
    """
|
||||
|
||||
|
||||
class TuneReportCallback(TuneCallback):
    """MXNet to Ray Tune reporting callback

    Reports metrics to Ray Tune.

    This has to be passed to MXNet as the ``eval_end_callback``.

    Args:
        metrics (str|list|dict): Metrics to report to Tune. If this is a
            string, it is treated as a list with that single item. If this
            is a list, each item describes the metric key reported to MXNet,
            and it will be reported under the same name to Tune. If this is
            a dict, each key will be the name reported to Tune and the
            respective value will be the metric key reported to MXNet.
            If this is ``None`` (the default) or empty, every metric
            returned by the eval metric is reported under its own name.

    Raises:
        KeyError: When called, if a requested metric is not among the
            metrics returned by ``param.eval_metric.get_name_value()``.

    Example:

    .. code-block:: python

        from ray.tune.integration.mxnet import TuneReportCallback

        # mlp_model is a MXNet model
        mlp_model.fit(
            train_iter,
            # ...
            eval_metric="acc",
            eval_end_callback=TuneReportCallback({
                "mean_accuracy": "accuracy"
            }))

    """

    def __init__(self,
                 metrics: Union[None, str, List[str], Dict[str, str]] = None):
        # A single metric name is shorthand for a one-element list.
        if isinstance(metrics, str):
            metrics = [metrics]
        self._metrics = metrics

    def __call__(self, param):
        """Report the metrics found on ``param.eval_metric`` to Tune.

        Does nothing when ``param.eval_metric`` is falsy.
        """
        if not param.eval_metric:
            return

        available = dict(param.eval_metric.get_name_value())

        if not self._metrics:
            # No selection configured: pass everything through unchanged.
            report_dict = available
        else:
            # Normalize both supported shapes to (tune_name, mxnet_name)
            # pairs once, instead of re-checking the type per metric.
            if isinstance(self._metrics, dict):
                wanted = list(self._metrics.items())
            else:
                wanted = [(name, name) for name in self._metrics]

            report_dict = {}
            for tune_name, mxnet_name in wanted:
                if mxnet_name not in available:
                    # Still a KeyError, but one that says what went wrong.
                    raise KeyError(
                        "Metric {!r} was not reported by MXNet. "
                        "Available metrics: {}".format(
                            mxnet_name, sorted(available)))
                report_dict[tune_name] = available[mxnet_name]

        tune.report(**report_dict)
|
||||
|
||||
|
||||
class TuneCheckpointCallback(TuneCallback):
    """MXNet checkpoint callback

    Saves a checkpoint whenever the epoch number passed to the callback is
    a multiple of ``frequency``.

    This has to be passed to the ``epoch_end_callback`` of the MXNet model.

    Checkpoints are currently not registered if no ``tune.report()`` call
    is made afterwards. You have to use this in conjunction with the
    ``TuneReportCallback`` to work!

    Args:
        filename (str): Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        frequency (int): Integer indicating how often checkpoints should be
            saved. Must be at least 1. Defaults to 1.

    Raises:
        ValueError: If ``frequency`` is smaller than 1.

    Example:

    .. code-block:: python

        from ray.tune.integration.mxnet import TuneReportCallback, \
            TuneCheckpointCallback

        # mlp_model is a MXNet model
        mlp_model.fit(
            train_iter,
            # ...
            eval_metric="acc",
            eval_end_callback=TuneReportCallback({
                "mean_accuracy": "accuracy"
            }),
            epoch_end_callback=TuneCheckpointCallback(
                filename="mxnet_cp",
                frequency=3
            ))

    """

    def __init__(self, filename: str = "checkpoint", frequency: int = 1):
        # Fail fast: a zero frequency would otherwise only surface as a
        # ZeroDivisionError from the modulo in __call__, in the middle of
        # training.
        if frequency < 1:
            raise ValueError(
                "`frequency` must be a positive integer, got {!r}.".format(
                    frequency))
        self._filename = filename
        self._frequency = frequency

    def __call__(self, epoch, sym, arg, aux):
        """Save a checkpoint for ``epoch`` if it falls on the frequency.

        All four arguments are forwarded unchanged to MXNet's
        ``save_checkpoint``; ``epoch`` is also used as the Tune checkpoint
        step.
        """
        if epoch % self._frequency != 0:
            return
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            # The joined path is handed to MXNet as the checkpoint prefix.
            prefix = os.path.join(checkpoint_dir, self._filename)
            save_checkpoint(prefix, epoch, sym, arg, aux)
|
Loading…
Add table
Reference in a new issue