ray/doc/source/tune/examples/mxnet_example.ipynb
Kai Fricke 803c094534
[air/tuner/docs] Update docs for Tuner() API 2b: Tune examples (ipynb) (#26884)
This PR updates the Ray AIR/Tune ipynb examples to use the Tuner() API instead of tune.run().

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Richard Liaw <rliaw@berkeley.edu>
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Signed-off-by: Kai Fricke <coding@kaifricke.com>

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
2022-07-24 18:53:57 +01:00

175 lines
5.5 KiB
Text

{
"cells": [
{
"cell_type": "markdown",
"id": "8a6398d4",
"metadata": {},
"source": [
"(tune-mxnet-example)=\n",
"\n",
"# Using MXNet with Tune\n",
"\n",
"```{image} /images/mxnet_logo.png\n",
":align: center\n",
":alt: MXNet Logo\n",
":height: 120px\n",
":target: https://mxnet.apache.org/\n",
"```\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```\n",
"\n",
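"## Example\n",
"\n",
"The following example trains a small multi-layer perceptron on MNIST with MXNet's symbolic API. Ray Tune samples the two hidden-layer sizes, the learning rate, and the batch size, and the ASHA scheduler stops under-performing trials early. `TuneReportCallback` reports validation accuracy to Tune, and `TuneCheckpointCallback` saves checkpoints during training."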
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3f38a2f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import mxnet as mx\n",
"\n",
"from ray import tune, logger\n",
"from ray.tune.integration.mxnet import TuneCheckpointCallback, TuneReportCallback\n",
"from ray.tune.schedulers import ASHAScheduler\n",
"\n",
"\n",
"def train_mnist_mxnet(config, mnist, num_epochs=10):\n",
"    \"\"\"Train a two-hidden-layer MLP on MNIST with MXNet and report to Tune.\n",
"\n",
"    Args:\n",
"        config: Hyperparameters sampled by Tune. Reads ``batch_size``,\n",
"            ``layer_1_size``, ``layer_2_size`` and ``lr``.\n",
"        mnist: Dict-like dataset holding ``train_data``, ``train_label``,\n",
"            ``test_data`` and ``test_label`` (the object returned by\n",
"            ``mx.test_utils.get_mnist()``).\n",
"        num_epochs: Number of passes over the training data.\n",
"    \"\"\"\n",
"    batch_size = config[\"batch_size\"]\n",
"\n",
"    # Only the training split is shuffled.\n",
"    training_batches = mx.io.NDArrayIter(\n",
"        mnist[\"train_data\"],\n",
"        mnist[\"train_label\"],\n",
"        batch_size,\n",
"        shuffle=True,\n",
"    )\n",
"    validation_batches = mx.io.NDArrayIter(\n",
"        mnist[\"test_data\"],\n",
"        mnist[\"test_label\"],\n",
"        batch_size,\n",
"    )\n",
"\n",
"    # Flatten the input, then two fully-connected + ReLU hidden layers\n",
"    # whose widths are tuned hyperparameters.\n",
"    net = mx.sym.flatten(data=mx.sym.var(\"data\"))\n",
"    for width_key in (\"layer_1_size\", \"layer_2_size\"):\n",
"        net = mx.sym.FullyConnected(data=net, num_hidden=config[width_key])\n",
"        net = mx.sym.Activation(data=net, act_type=\"relu\")\n",
"\n",
"    # Output layer: one unit per MNIST digit class, followed by softmax\n",
"    # with cross-entropy loss.\n",
"    logits = mx.sym.FullyConnected(data=net, num_hidden=10)\n",
"    network = mx.sym.SoftmaxOutput(data=logits, name=\"softmax\")\n",
"\n",
"    # Trainable module placed on the CPU.\n",
"    module = mx.mod.Module(symbol=network, context=mx.cpu())\n",
"    module.fit(\n",
"        training_batches,\n",
"        eval_data=validation_batches,\n",
"        optimizer=\"sgd\",\n",
"        optimizer_params={\"learning_rate\": config[\"lr\"]},\n",
"        eval_metric=\"acc\",\n",
"        batch_end_callback=mx.callback.Speedometer(batch_size, 100),\n",
"        # Forward MXNet's \"accuracy\" metric to Tune as \"mean_accuracy\".\n",
"        eval_end_callback=TuneReportCallback({\"mean_accuracy\": \"accuracy\"}),\n",
"        epoch_end_callback=TuneCheckpointCallback(filename=\"mxnet_cp\", frequency=3),\n",
"        num_epoch=num_epochs,\n",
"    )\n",
"\n",
"\n",
"def tune_mnist_mxnet(num_samples=10, num_epochs=10):\n",
"    \"\"\"Search MLP hyperparameters for ``train_mnist_mxnet`` with Ray Tune.\n",
"\n",
"    Args:\n",
"        num_samples: Number of hyperparameter configurations (trials) to run.\n",
"        num_epochs: Epochs per trial; also used as ASHA's ``max_t``.\n",
"\n",
"    Returns:\n",
"        The results object returned by ``Tuner.fit()``.\n",
"    \"\"\"\n",
"    logger.info(\"Downloading MNIST data...\")\n",
"    mnist_data = mx.test_utils.get_mnist()\n",
"    logger.info(\"Got MNIST data, starting Ray Tune.\")\n",
"\n",
"    search_space = {\n",
"        \"layer_1_size\": tune.choice([32, 64, 128]),\n",
"        \"layer_2_size\": tune.choice([64, 128, 256]),\n",
"        \"lr\": tune.loguniform(1e-3, 1e-1),\n",
"        \"batch_size\": tune.choice([32, 64, 128]),\n",
"    }\n",
"\n",
"    # ASHA early-stops weak trials; max_t caps each trial at num_epochs.\n",
"    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)\n",
"\n",
"    # Bind the dataset and epoch count to the training function so each\n",
"    # trial receives them alongside its sampled config.\n",
"    trainable = tune.with_parameters(\n",
"        train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs\n",
"    )\n",
"    tuner = tune.Tuner(\n",
"        trainable,\n",
"        tune_config=tune.TuneConfig(\n",
"            metric=\"mean_accuracy\",\n",
"            mode=\"max\",\n",
"            scheduler=asha,\n",
"            num_samples=num_samples,\n",
"        ),\n",
"        param_space=search_space,\n",
"    )\n",
"    return tuner.fit()\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    import argparse\n",
"\n",
"    parser = argparse.ArgumentParser()\n",
"    parser.add_argument(\n",
"        \"--smoke-test\",\n",
"        action=\"store_true\",\n",
"        help=\"Finish quickly for testing\",\n",
"    )\n",
"    parser.add_argument(\n",
"        \"--server-address\",\n",
"        type=str,\n",
"        default=None,\n",
"        required=False,\n",
"        help=\"The address of server to connect to if using Ray Client.\",\n",
"    )\n",
"    # parse_known_args tolerates unrecognised arguments, e.g. ones passed\n",
"    # to the process when this runs inside a notebook kernel.\n",
"    args, _ = parser.parse_known_args()\n",
"\n",
"    # Connect to a remote cluster through Ray Client only for full runs.\n",
"    if args.server_address and not args.smoke_test:\n",
"        import ray\n",
"\n",
"        ray.init(f\"ray://{args.server_address}\")\n",
"\n",
"    # Smoke test: one trial of one epoch. Full run: 10 trials of 10 epochs.\n",
"    n_trials, n_epochs = (1, 1) if args.smoke_test else (10, 10)\n",
"    results = tune_mnist_mxnet(num_samples=n_trials, num_epochs=n_epochs)\n",
"\n",
"    print(\"Best hyperparameters found were: \", results.get_best_result().config)\n"
]
},
{
"cell_type": "markdown",
"id": "37ab0db6",
"metadata": {},
"source": [
"## More MXNet Examples\n",
"\n",
"\n",
"- {doc}`/tune/examples/includes/tune_cifar10_gluon`:\n",
"  An MXNet Gluon example that uses Tune's function-based API to train on the CIFAR-10 dataset.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
},
"orphan": true
},
"nbformat": 4,
"nbformat_minor": 5
}