{ "cells": [ { "cell_type": "markdown", "source": [ "(tune-horovod-example)=\n", "\n", "# Using Horovod with Tune\n", "\n", "```{image} /images/horovod.png\n", ":align: center\n", ":alt: Horovod Logo\n", ":height: 120px\n", ":target: https://horovod.ai/\n", "```\n", "\n", "```{contents}\n", ":backlinks: none\n", ":local: true\n", "```\n", "\n", "## Example" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "\n", "import ray\n", "from ray import tune\n", "from ray.tune.integration.horovod import DistributedTrainableCreator\n", "import time\n", "\n", "\n", "def sq(x):\n", " m2 = 1.0\n", " m1 = -20.0\n", " m0 = 50.0\n", " return m2 * x * x + m1 * x + m0\n", "\n", "\n", "def qu(x):\n", " m3 = 10.0\n", " m2 = 5.0\n", " m1 = -20.0\n", " m0 = -5.0\n", " return m3 * x * x * x + m2 * x * x + m1 * x + m0\n", "\n", "\n", "class Net(torch.nn.Module):\n", " def __init__(self, mode=\"sq\"):\n", " super(Net, self).__init__()\n", "\n", " if mode == \"square\":\n", " self.mode = 0\n", " self.param = torch.nn.Parameter(torch.FloatTensor([1.0, -1.0]))\n", " else:\n", " self.mode = 1\n", " self.param = torch.nn.Parameter(torch.FloatTensor([1.0, -1.0, 1.0]))\n", "\n", " def forward(self, x):\n", " if ~self.mode:\n", " return x * x + self.param[0] * x + self.param[1]\n", " else:\n", " return_val = 10 * x * x * x\n", " return_val += self.param[0] * x * x\n", " return_val += self.param[1] * x + self.param[2]\n", " return return_val\n", "\n", "\n", "def train(config):\n", " import torch\n", " import horovod.torch as hvd\n", "\n", " hvd.init()\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " mode = config[\"mode\"]\n", " net = Net(mode).to(device)\n", " optimizer = torch.optim.SGD(\n", " net.parameters(),\n", " lr=config[\"lr\"],\n", " )\n", " optimizer = hvd.DistributedOptimizer(optimizer)\n", "\n", " num_steps = 5\n", " print(hvd.size())\n", " np.random.seed(1 + hvd.rank())\n", " torch.manual_seed(1234)\n", " # To ensure consistent initialization across slots,\n", " hvd.broadcast_parameters(net.state_dict(), root_rank=0)\n", " hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n", "\n", " start = time.time()\n", " x_max = config[\"x_max\"]\n", " for step in range(1, num_steps + 1):\n", " features = torch.Tensor(np.random.rand(1) * 2 * x_max - x_max).to(device)\n", " if mode == \"square\":\n", " labels = sq(features)\n", " else:\n", " labels = qu(features)\n", " optimizer.zero_grad()\n", " outputs = net(features)\n", " loss = torch.nn.MSELoss()(outputs, labels)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " time.sleep(0.1)\n", " tune.report(loss=loss.item())\n", " total = time.time() - start\n", " print(f\"Took {total:0.3f} s. Avg: {total / num_steps:0.3f} s.\")\n", "\n", "\n", "def tune_horovod(\n", " hosts_per_trial, slots_per_host, num_samples, use_gpu, mode=\"square\", x_max=1.0\n", "):\n", " horovod_trainable = DistributedTrainableCreator(\n", " train,\n", " use_gpu=use_gpu,\n", " num_hosts=hosts_per_trial,\n", " num_slots=slots_per_host,\n", " replicate_pem=False,\n", " )\n", " analysis = tune.run(\n", " horovod_trainable,\n", " metric=\"loss\",\n", " mode=\"min\",\n", " config={\"lr\": tune.uniform(0.1, 1), \"mode\": mode, \"x_max\": x_max},\n", " num_samples=num_samples,\n", " fail_fast=True,\n", " )\n", " print(\"Best hyperparameters found were: \", analysis.best_config)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " import argparse\n", "\n", " parser = argparse.ArgumentParser()\n", " parser.add_argument(\n", " \"--mode\", type=str, default=\"square\", choices=[\"square\", \"cubic\"]\n", " )\n", " parser.add_argument(\n", " \"--learning_rate\", type=float, default=0.1, dest=\"learning_rate\"\n", " )\n", " parser.add_argument(\"--x_max\", type=float, default=1.0, dest=\"x_max\")\n", " parser.add_argument(\"--gpu\", default=False, action=\"store_true\")\n", " parser.add_argument(\n", " \"--smoke-test\", default=True, action=\"store_true\", help=(\"Finish quickly for testing.\")\n", " )\n", " parser.add_argument(\"--hosts-per-trial\", type=int, default=1)\n", " parser.add_argument(\"--slots-per-host\", type=int, default=2)\n", " parser.add_argument(\n", " \"--server-address\",\n", " type=str,\n", " default=None,\n", " required=False,\n", " help=\"The address of server to connect to if using \" \"Ray Client.\",\n", " )\n", " args, _ = parser.parse_known_args()\n", "\n", " if args.smoke_test:\n", " ray.init(num_cpus=2)\n", " elif args.server_address:\n", " ray.init(f\"ray://{args.server_address}\")\n", "\n", " # import ray\n", " # ray.init(address=\"auto\") # assumes ray is started with ray up\n", "\n", " tune_horovod(\n", " hosts_per_trial=args.hosts_per_trial,\n", " slots_per_host=args.slots_per_host,\n", " num_samples=2 if args.smoke_test else 10,\n", " use_gpu=args.gpu,\n", " mode=args.mode,\n", " x_max=args.x_max,\n", " )" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "orphan": true }, "nbformat": 4, "nbformat_minor": 5 }