mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00

This PR updates the Ray AIR/Tune ipynb examples to use the Tuner() API instead of tune.run(). Signed-off-by: Kai Fricke <kai@anyscale.com> Signed-off-by: Richard Liaw <rliaw@berkeley.edu> Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com> Signed-off-by: Kai Fricke <coding@kaifricke.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
175 lines
5.5 KiB
Text
175 lines
5.5 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8a6398d4",
|
|
"metadata": {},
|
|
"source": [
|
|
"(tune-mxnet-example)=\n",
|
|
"\n",
|
|
"# Using MXNet with Tune\n",
|
|
"\n",
|
|
"```{image} /images/mxnet_logo.png\n",
|
|
":align: center\n",
|
|
":alt: MXNet Logo\n",
|
|
":height: 120px\n",
|
|
":target: https://mxnet.apache.org/\n",
|
|
"```\n",
|
|
"\n",
|
|
"```{contents}\n",
|
|
":backlinks: none\n",
|
|
":local: true\n",
|
|
"```\n",
|
|
"\n",
|
|
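"## Example\n",
"\n",
"The example below trains a small MXNet multilayer perceptron on MNIST and uses Tune's ASHA scheduler to search over the two hidden-layer sizes, the learning rate, and the batch size."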
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d3f38a2f",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import mxnet as mx\n",
|
|
"\n",
|
|
"from ray import tune, logger\n",
|
|
"from ray.tune.integration.mxnet import TuneCheckpointCallback, TuneReportCallback\n",
|
|
"from ray.tune.schedulers import ASHAScheduler\n",
|
|
"\n",
|
|
"\n",
|
|
"def train_mnist_mxnet(config, mnist, num_epochs=10):\n",
"    \"\"\"Train a two-hidden-layer MLP on MNIST with MXNet for one Tune trial.\n",
"\n",
"    Args:\n",
"        config: Hyperparameters sampled by Tune; must provide \"batch_size\",\n",
"            \"layer_1_size\", \"layer_2_size\" and \"lr\".\n",
"        mnist: Mapping with \"train_data\", \"train_label\", \"test_data\" and\n",
"            \"test_label\" entries, as returned by mx.test_utils.get_mnist().\n",
"        num_epochs: Number of passes over the training set.\n",
"    \"\"\"\n",
"    bs = config[\"batch_size\"]\n",
"\n",
"    # Shuffle the training batches; iterate the test split in order.\n",
"    train_iter = mx.io.NDArrayIter(\n",
"        mnist[\"train_data\"], mnist[\"train_label\"], bs, shuffle=True\n",
"    )\n",
"    val_iter = mx.io.NDArrayIter(mnist[\"test_data\"], mnist[\"test_label\"], bs)\n",
"\n",
"    # Symbolic graph: flatten -> (FullyConnected -> ReLU) x 2 -> FullyConnected(10).\n",
"    net = mx.sym.flatten(data=mx.sym.var(\"data\"))\n",
"    for width in (config[\"layer_1_size\"], config[\"layer_2_size\"]):\n",
"        net = mx.sym.FullyConnected(data=net, num_hidden=width)\n",
"        net = mx.sym.Activation(data=net, act_type=\"relu\")\n",
"    # One output unit per MNIST class (10), followed by softmax with\n",
"    # cross-entropy loss.\n",
"    net = mx.sym.SoftmaxOutput(\n",
"        data=mx.sym.FullyConnected(data=net, num_hidden=10), name=\"softmax\"\n",
"    )\n",
"\n",
"    # Fit on CPU. The Tune callbacks report the eval \"accuracy\" metric as\n",
"    # \"mean_accuracy\" and write \"mxnet_cp\" checkpoints with frequency=3.\n",
"    model = mx.mod.Module(symbol=net, context=mx.cpu())\n",
"    model.fit(\n",
"        train_iter,\n",
"        eval_data=val_iter,\n",
"        optimizer=\"sgd\",\n",
"        optimizer_params={\"learning_rate\": config[\"lr\"]},\n",
"        eval_metric=\"acc\",\n",
"        batch_end_callback=mx.callback.Speedometer(bs, 100),\n",
"        eval_end_callback=TuneReportCallback({\"mean_accuracy\": \"accuracy\"}),\n",
"        epoch_end_callback=TuneCheckpointCallback(filename=\"mxnet_cp\", frequency=3),\n",
"        num_epoch=num_epochs,\n",
"    )\n",
|
|
"\n",
|
|
"\n",
|
|
"def tune_mnist_mxnet(num_samples=10, num_epochs=10):\n",
"    \"\"\"Search MLP hyperparameters for MNIST with Ray Tune and ASHA.\n",
"\n",
"    Args:\n",
"        num_samples: Number of hyperparameter configurations (trials) to sample.\n",
"        num_epochs: Epochs per trial; also passed as the ASHA scheduler's max_t.\n",
"\n",
"    Returns:\n",
"        The results object returned by ``Tuner.fit()``.\n",
"    \"\"\"\n",
"    logger.info(\"Downloading MNIST data...\")\n",
"    mnist_data = mx.test_utils.get_mnist()\n",
"    logger.info(\"Got MNIST data, starting Ray Tune.\")\n",
"\n",
"    search_space = {\n",
"        \"layer_1_size\": tune.choice([32, 64, 128]),\n",
"        \"layer_2_size\": tune.choice([64, 128, 256]),\n",
"        \"lr\": tune.loguniform(1e-3, 1e-1),\n",
"        \"batch_size\": tune.choice([32, 64, 128]),\n",
"    }\n",
"\n",
"    # grace_period=1: trials become eligible for early stopping after their\n",
"    # first report; max_t bounds each trial at num_epochs.\n",
"    asha = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)\n",
"\n",
"    # The dataset and epoch count are bound with with_parameters instead of\n",
"    # being part of the searched param_space.\n",
"    trainable = tune.with_parameters(\n",
"        train_mnist_mxnet, mnist=mnist_data, num_epochs=num_epochs\n",
"    )\n",
"    tuner = tune.Tuner(\n",
"        trainable,\n",
"        tune_config=tune.TuneConfig(\n",
"            metric=\"mean_accuracy\",\n",
"            mode=\"max\",\n",
"            scheduler=asha,\n",
"            num_samples=num_samples,\n",
"        ),\n",
"        param_space=search_space,\n",
"    )\n",
"    return tuner.fit()\n",
|
|
"\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
"    import argparse\n",
"\n",
"    parser = argparse.ArgumentParser()\n",
"    parser.add_argument(\n",
"        \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n",
"    )\n",
"    parser.add_argument(\n",
"        \"--server-address\",\n",
"        type=str,\n",
"        default=None,\n",
"        required=False,\n",
"        help=\"The address of the server to connect to if using Ray Client.\",\n",
"    )\n",
"    # parse_known_args tolerates extra arguments (e.g. those passed to a\n",
"    # Jupyter kernel) instead of exiting with an error.\n",
"    args, _ = parser.parse_known_args()\n",
"\n",
"    # Connect through Ray Client only for full runs; smoke tests stay local.\n",
"    if args.server_address and not args.smoke_test:\n",
"        import ray\n",
"\n",
"        ray.init(f\"ray://{args.server_address}\")\n",
"\n",
"    # A single one-epoch trial is enough to exercise the example end to end.\n",
"    num_samples, num_epochs = (1, 1) if args.smoke_test else (10, 10)\n",
"    results = tune_mnist_mxnet(num_samples=num_samples, num_epochs=num_epochs)\n",
"\n",
"    print(\"Best hyperparameters found were: \", results.get_best_result().config)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "37ab0db6",
|
|
"metadata": {},
|
|
"source": [
|
|
"## More MXNet Examples\n",
|
|
"\n",
|
|
"\n",
|
|
"- {doc}`/tune/examples/includes/tune_cifar10_gluon`:\n",
|
|
" MXNet Gluon example of using Tune with the function-based API on the CIFAR-10 dataset.\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
},
|
|
"orphan": true
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|