{ "cells": [ { "cell_type": "markdown", "id": "3b05af3b", "metadata": {}, "source": [ "(tune-rllib-example)=\n", "\n", "# Using RLlib with Tune\n", "\n", "```{image} /rllib/images/rllib-logo.png\n", ":align: center\n", ":alt: RLlib Logo\n", ":height: 120px\n", ":target: https://docs.ray.io\n", "```\n", "\n", "```{contents}\n", ":backlinks: none\n", ":local: true\n", "```\n", "\n", "## Example\n", "\n", "Example of using PBT with RLlib.\n", "\n", "Note that this requires a cluster with at least 8 GPUs in order for all trials\n", "to run concurrently, otherwise PBT will round-robin train the trials which\n", "is less efficient (or you can set {\"gpu\": 0} to use CPUs for SGD instead).\n", "\n", "Note that Tune in general does not need 8 GPUs, and this is just a more\n", "computationally demanding example." ] }, { "cell_type": "code", "execution_count": null, "id": "19e3c389", "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from ray import tune\n", "from ray.tune.schedulers import PopulationBasedTraining\n", "\n", "if __name__ == \"__main__\":\n", "\n", " # Postprocess the perturbed config to ensure it's still valid\n", " def explore(config):\n", " # ensure we collect enough timesteps to do sgd\n", " if config[\"train_batch_size\"] < config[\"sgd_minibatch_size\"] * 2:\n", " config[\"train_batch_size\"] = config[\"sgd_minibatch_size\"] * 2\n", " # ensure we run at least one sgd iter\n", " if config[\"num_sgd_iter\"] < 1:\n", " config[\"num_sgd_iter\"] = 1\n", " return config\n", "\n", " pbt = PopulationBasedTraining(\n", " time_attr=\"time_total_s\",\n", " perturbation_interval=120,\n", " resample_probability=0.25,\n", " # Specifies the mutations of these hyperparams\n", " hyperparam_mutations={\n", " \"lambda\": lambda: random.uniform(0.9, 1.0),\n", " \"clip_param\": lambda: random.uniform(0.01, 0.5),\n", " \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n", " \"num_sgd_iter\": lambda: random.randint(1, 30),\n", " \"sgd_minibatch_size\": lambda: 
random.randint(128, 16384),\n", "            \"train_batch_size\": lambda: random.randint(2000, 160000),\n", "        },\n", "        custom_explore_fn=explore,\n", "    )\n", "\n", "    analysis = tune.run(\n", "        \"PPO\",\n", "        name=\"pbt_humanoid_test\",\n", "        scheduler=pbt,\n", "        num_samples=1,\n", "        metric=\"episode_reward_mean\",\n", "        mode=\"max\",\n", "        config={\n", "            \"env\": \"Humanoid-v1\",\n", "            \"kl_coeff\": 1.0,\n", "            \"num_workers\": 8,\n", "            \"num_gpus\": 0,  # Number of GPUs to use (0 = train on CPUs).\n", "            \"model\": {\"free_log_std\": True},\n", "            # These params are tuned from a fixed starting value.\n", "            \"lambda\": 0.95,\n", "            \"clip_param\": 0.2,\n", "            \"lr\": 1e-4,\n", "            # These params start off randomly drawn from a set.\n", "            \"num_sgd_iter\": tune.choice([10, 20, 30]),\n", "            \"sgd_minibatch_size\": tune.choice([128, 512, 2048]),\n", "            \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n", "        },\n", "    )\n", "\n", "    print(\"Best hyperparameters:\", analysis.best_config)\n" ] }, { "cell_type": "markdown", "id": "6fb69a24", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## More RLlib Examples\n", "\n", "- {doc}`/tune/examples/includes/pb2_ppo_example`:\n", "  Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n", "  Uses a small population size of 4, so it can be trained on a laptop." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "orphan": true }, "nbformat": 4, "nbformat_minor": 5 }