{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "(tune-mxnet-example)=\n",
    "\n",
    "# Using MXNet with Tune\n",
    "\n",
    "```{image} /images/mxnet_logo.png\n",
    ":align: center\n",
    ":alt: MXNet Logo\n",
    ":height: 120px\n",
    ":target: https://mxnet.apache.org/\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
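    "## Example\n",
    "\n",
    "The example below trains a small multilayer perceptron on MNIST with MXNet and uses Tune's\n",
    "ASHA scheduler to search over the two hidden-layer sizes, the learning rate, and the batch size."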
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "\n",
    "from ray import tune, logger\n",
    "from ray.tune.integration.mxnet import TuneCheckpointCallback, TuneReportCallback\n",
    "from ray.tune.schedulers import ASHAScheduler\n",
    "\n",
    "\n",
    "def train_mnist_mxnet(config, mnist, num_epochs=10):\n",
    "    \"\"\"Train a two-hidden-layer MLP on MNIST and report results to Tune.\n",
    "\n",
    "    Args:\n",
    "        config: Trial hyperparameters; reads ``batch_size``, ``layer_1_size``,\n",
    "            ``layer_2_size`` and ``lr``.\n",
    "        mnist: Dict with ``train_data``, ``train_label``, ``test_data`` and\n",
    "            ``test_label`` entries (the object ``mx.test_utils.get_mnist()`` returns).\n",
    "        num_epochs: Number of training epochs passed to ``Module.fit``.\n",
    "    \"\"\"\n",
    "    batch_size = config[\"batch_size\"]\n",
    "\n",
    "    # Only the training split is shuffled.\n",
    "    train_iter = mx.io.NDArrayIter(\n",
    "        mnist[\"train_data\"], mnist[\"train_label\"], batch_size, shuffle=True\n",
    "    )\n",
    "    val_iter = mx.io.NDArrayIter(mnist[\"test_data\"], mnist[\"test_label\"], batch_size)\n",
    "\n",
    "    # Flattened input followed by two ReLU hidden layers sized by the trial config.\n",
    "    # Symbols are created in the same order as a hand-unrolled version would.\n",
    "    net = mx.sym.flatten(data=mx.sym.var(\"data\"))\n",
    "    for size_key in (\"layer_1_size\", \"layer_2_size\"):\n",
    "        net = mx.sym.FullyConnected(data=net, num_hidden=config[size_key])\n",
    "        net = mx.sym.Activation(data=net, act_type=\"relu\")\n",
    "\n",
    "    # One output unit per MNIST class (10), trained with softmax cross-entropy.\n",
    "    logits = mx.sym.FullyConnected(data=net, num_hidden=10)\n",
    "    mlp = mx.sym.SoftmaxOutput(data=logits, name=\"softmax\")\n",
    "\n",
    "    # Train on CPU. MXNet's \"accuracy\" metric is reported to Tune as\n",
    "    # \"mean_accuracy\" after each evaluation; checkpoints use frequency=3 at epoch end.\n",
    "    model = mx.mod.Module(symbol=mlp, context=mx.cpu())\n",
    "    progress = mx.callback.Speedometer(batch_size, 100)\n",
    "    report = TuneReportCallback({\"mean_accuracy\": \"accuracy\"})\n",
    "    checkpoint = TuneCheckpointCallback(filename=\"mxnet_cp\", frequency=3)\n",
    "    model.fit(\n",
    "        train_iter,\n",
    "        eval_data=val_iter,\n",
    "        num_epoch=num_epochs,\n",
    "        optimizer=\"sgd\",\n",
    "        optimizer_params={\"learning_rate\": config[\"lr\"]},\n",
    "        eval_metric=\"acc\",\n",
    "        batch_end_callback=progress,\n",
    "        eval_end_callback=report,\n",
    "        epoch_end_callback=checkpoint,\n",
    "    )\n",
    "\n",
    "\n",
    "def tune_mnist_mxnet(num_samples=10, num_epochs=10):\n",
    "    \"\"\"Run an ASHA-scheduled hyperparameter search over ``train_mnist_mxnet``.\n",
    "\n",
    "    Args:\n",
    "        num_samples: ``num_samples`` for ``tune.run`` (how many configs to sample).\n",
    "        num_epochs: Epochs per trial; also used as the ASHA ``max_t``.\n",
    "\n",
    "    Returns:\n",
    "        The analysis object from ``tune.run``, configured to maximize ``mean_accuracy``.\n",
    "    \"\"\"\n",
    "    logger.info(\"Downloading MNIST data...\")\n",
    "    mnist_data = mx.test_utils.get_mnist()\n",
    "    logger.info(\"Got MNIST data, starting Ray Tune.\")\n",
    "\n",
    "    search_space = dict(\n",
    "        layer_1_size=tune.choice([32, 64, 128]),\n",
    "        layer_2_size=tune.choice([64, 128, 256]),\n",
    "        lr=tune.loguniform(1e-3, 1e-1),\n",
    "        batch_size=tune.choice([32, 64, 128]),\n",
    "    )\n",
    "\n",
    "    # Early-stopping scheduler; trial length is capped at num_epochs.\n",
    "    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)\n",
    "\n",
    "    # Bind the dataset and epoch count so each trial is called with only its config.\n",
    "    trainable = tune.with_parameters(\n",
    "        train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs\n",
    "    )\n",
    "\n",
    "    return tune.run(\n",
    "        trainable,\n",
    "        name=\"tune_mnist_mxnet\",\n",
    "        config=search_space,\n",
    "        num_samples=num_samples,\n",
    "        scheduler=scheduler,\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        resources_per_trial={\"cpu\": 1},\n",
    "    )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Finish quickly for testing\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    # parse_known_args ignores unrecognized argv entries instead of erroring.\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # Connect to a remote cluster through Ray Client only for full (non-smoke) runs.\n",
    "    if not args.smoke_test and args.server_address:\n",
    "        import ray\n",
    "\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "\n",
    "    num_samples, num_epochs = (1, 1) if args.smoke_test else (10, 10)\n",
    "    analysis = tune_mnist_mxnet(num_samples=num_samples, num_epochs=num_epochs)\n",
    "\n",
    "    print(\"Best hyperparameters found were: \", analysis.best_config)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## More MXNet Examples\n",
    "\n",
    "\n",
    "- {doc}`/tune/examples/includes/tune_cifar10_gluon`:\n",
    "  An MXNet Gluon example that uses Tune's function-based API on the CIFAR-10 dataset.\n"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}