{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8a6398d4",
   "metadata": {},
   "source": [
    "(tune-mxnet-example)=\n",
    "\n",
    "# Using MXNet with Tune\n",
    "\n",
    "```{image} /images/mxnet_logo.png\n",
    ":align: center\n",
    ":alt: MXNet Logo\n",
    ":height: 120px\n",
    ":target: https://mxnet.apache.org/\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f38a2f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import mxnet as mx\n",
    "\n",
    "import ray\n",
    "from ray import tune, logger\n",
    "from ray.tune.integration.mxnet import TuneCheckpointCallback, TuneReportCallback\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "\n",
    "\n",
    "def train_mnist_mxnet(config, mnist, num_epochs=10):\n",
    "    \"\"\"Train a three-layer MLP on MNIST and report its accuracy to Tune.\n",
    "\n",
    "    Args:\n",
    "        config: Hyperparameters of this trial. The keys ``batch_size``,\n",
    "            ``layer_1_size``, ``layer_2_size`` and ``lr`` are read.\n",
    "        mnist: Mapping with the entries ``train_data``, ``train_label``,\n",
    "            ``test_data`` and ``test_label``.\n",
    "        num_epochs: Number of epochs passed to ``Module.fit``.\n",
    "    \"\"\"\n",
    "    batch_size = config[\"batch_size\"]\n",
    "    train_iter = mx.io.NDArrayIter(\n",
    "        mnist[\"train_data\"], mnist[\"train_label\"], batch_size, shuffle=True\n",
    "    )\n",
    "    # The MNIST test split is used as the evaluation set during training.\n",
    "    val_iter = mx.io.NDArrayIter(mnist[\"test_data\"], mnist[\"test_label\"], batch_size)\n",
    "\n",
    "    data = mx.sym.var(\"data\")\n",
    "    data = mx.sym.flatten(data=data)\n",
    "\n",
    "    # The widths of both hidden layers are tunable hyperparameters.\n",
    "    fc1 = mx.sym.FullyConnected(data=data, num_hidden=config[\"layer_1_size\"])\n",
    "    act1 = mx.sym.Activation(data=fc1, act_type=\"relu\")\n",
    "\n",
    "    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=config[\"layer_2_size\"])\n",
    "    act2 = mx.sym.Activation(data=fc2, act_type=\"relu\")\n",
    "\n",
    "    # MNIST has 10 classes\n",
    "    fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)\n",
    "    # Softmax with cross entropy loss\n",
    "    mlp = mx.sym.SoftmaxOutput(data=fc3, name=\"softmax\")\n",
    "\n",
    "    # create a trainable module on CPU\n",
    "    mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())\n",
    "    mlp_model.fit(\n",
    "        train_iter,\n",
    "        eval_data=val_iter,\n",
    "        optimizer=\"sgd\",\n",
    "        optimizer_params={\"learning_rate\": config[\"lr\"]},\n",
    "        eval_metric=\"acc\",\n",
    "        batch_end_callback=mx.callback.Speedometer(batch_size, 100),\n",
    "        # Forward MXNet's \"accuracy\" metric to Tune as \"mean_accuracy\", the\n",
    "        # metric name that the TuneConfig in tune_mnist_mxnet() optimizes.\n",
    "        eval_end_callback=TuneReportCallback({\"mean_accuracy\": \"accuracy\"}),\n",
    "        # Write Tune checkpoints named \"mxnet_cp\"; see TuneCheckpointCallback\n",
    "        # for the exact meaning of `frequency`.\n",
    "        epoch_end_callback=TuneCheckpointCallback(filename=\"mxnet_cp\", frequency=3),\n",
    "        num_epoch=num_epochs,\n",
    "    )\n",
    "\n",
    "\n",
    "def tune_mnist_mxnet(num_samples=10, num_epochs=10):\n",
    "    \"\"\"Fetch MNIST and run a Tune hyperparameter search for the MLP.\n",
    "\n",
    "    Args:\n",
    "        num_samples: Passed to ``TuneConfig(num_samples=...)``.\n",
    "        num_epochs: Epochs per trial; also used as ``max_t`` of the ASHA\n",
    "            scheduler.\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``Tuner.fit()``.\n",
    "    \"\"\"\n",
    "    logger.info(\"Downloading MNIST data...\")\n",
    "    mnist_data = mx.test_utils.get_mnist()\n",
    "    logger.info(\"Got MNIST data, starting Ray Tune.\")\n",
    "\n",
    "    # Search space; the keys must match what train_mnist_mxnet() reads.\n",
    "    config = {\n",
    "        \"layer_1_size\": tune.choice([32, 64, 128]),\n",
    "        \"layer_2_size\": tune.choice([64, 128, 256]),\n",
    "        \"lr\": tune.loguniform(1e-3, 1e-1),\n",
    "        \"batch_size\": tune.choice([32, 64, 128]),\n",
    "    }\n",
    "\n",
    "    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)\n",
    "\n",
    "    tuner = tune.Tuner(\n",
    "        # with_parameters() binds the dataset and epoch count to the trainable,\n",
    "        # so they are not part of the sampled `config`.\n",
    "        tune.with_parameters(\n",
    "            train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs\n",
    "        ),\n",
    "        tune_config=tune.TuneConfig(\n",
    "            metric=\"mean_accuracy\",\n",
    "            mode=\"max\",\n",
    "            scheduler=scheduler,\n",
    "            num_samples=num_samples,\n",
    "        ),\n",
    "        param_space=config,\n",
    "    )\n",
    "    results = tuner.fit()\n",
    "    return results\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    # parse_known_args() ignores arguments that this parser does not define,\n",
    "    # e.g. the ones a notebook kernel is launched with.\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # The Ray Client connection is skipped for smoke tests.\n",
    "    if args.server_address and not args.smoke_test:\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "\n",
    "    if args.smoke_test:\n",
    "        results = tune_mnist_mxnet(num_samples=1, num_epochs=1)\n",
    "    else:\n",
    "        results = tune_mnist_mxnet(num_samples=10, num_epochs=10)\n",
    "\n",
    "    print(\"Best hyperparameters found were: \", results.get_best_result().config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37ab0db6",
   "metadata": {},
   "source": [
    "## More MXNet Examples\n",
    "\n",
    "\n",
    "- {doc}`/tune/examples/includes/tune_cifar10_gluon`:\n",
    "  MXNet Gluon example to use Tune with the function-based API on the CIFAR-10 dataset.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}