mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
130 lines
No EOL
4.2 KiB
Text
130 lines
No EOL
4.2 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3b05af3b",
|
|
"metadata": {},
|
|
"source": [
|
|
"(tune-rllib-example)=\n",
|
|
"\n",
|
|
"# Using RLlib with Tune\n",
|
|
"\n",
|
|
"```{image} /rllib/images/rllib-logo.png\n",
|
|
":align: center\n",
|
|
":alt: RLlib Logo\n",
|
|
":height: 120px\n",
|
|
":target: https://docs.ray.io\n",
|
|
"```\n",
|
|
"\n",
|
|
"```{contents}\n",
|
|
":backlinks: none\n",
|
|
":local: true\n",
|
|
"```\n",
|
|
"\n",
|
|
"## Example\n",
|
|
"\n",
|
|
"Example of using Population Based Training (PBT) with RLlib.\n",
|
|
"\n",
|
|
"Note that this requires a cluster with at least 8 GPUs in order for all trials\n",
|
|
"to run concurrently; otherwise, PBT will train the trials round-robin, which\n",
|
|
"is less efficient (or you can set {\"num_gpus\": 0} to use CPUs for SGD instead).\n",
|
|
"\n",
|
|
"Note that Tune in general does not need 8 GPUs, and this is just a more\n",
|
|
"computationally demanding example."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "19e3c389",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"from ray import tune\n",
|
|
"from ray.tune.schedulers import PopulationBasedTraining\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
|
|
"\n",
|
|
" # Postprocess the perturbed config to ensure it's still valid\n",
|
|
" def explore(config):\n",
|
|
" # ensure we collect enough timesteps to do sgd\n",
|
|
" if config[\"train_batch_size\"] < config[\"sgd_minibatch_size\"] * 2:\n",
|
|
" config[\"train_batch_size\"] = config[\"sgd_minibatch_size\"] * 2\n",
|
|
" # ensure we run at least one sgd iter\n",
|
|
" if config[\"num_sgd_iter\"] < 1:\n",
|
|
" config[\"num_sgd_iter\"] = 1\n",
|
|
" return config\n",
|
|
"\n",
|
|
" pbt = PopulationBasedTraining(\n",
|
|
" time_attr=\"time_total_s\",\n",
|
|
" perturbation_interval=120,\n",
|
|
" resample_probability=0.25,\n",
|
|
" # Specifies the mutations of these hyperparams\n",
|
|
" hyperparam_mutations={\n",
|
|
" \"lambda\": lambda: random.uniform(0.9, 1.0),\n",
|
|
" \"clip_param\": lambda: random.uniform(0.01, 0.5),\n",
|
|
" \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n",
|
|
" \"num_sgd_iter\": lambda: random.randint(1, 30),\n",
|
|
" \"sgd_minibatch_size\": lambda: random.randint(128, 16384),\n",
|
|
" \"train_batch_size\": lambda: random.randint(2000, 160000),\n",
|
|
" },\n",
|
|
" custom_explore_fn=explore,\n",
|
|
" )\n",
|
|
"\n",
|
|
" analysis = tune.run(\n",
|
|
" \"PPO\",\n",
|
|
" name=\"pbt_humanoid_test\",\n",
|
|
" scheduler=pbt,\n",
|
|
" num_samples=1,\n",
|
|
" metric=\"episode_reward_mean\",\n",
|
|
" mode=\"max\",\n",
|
|
" config={\n",
|
|
" \"env\": \"Humanoid-v1\",\n",
|
|
" \"kl_coeff\": 1.0,\n",
|
|
" \"num_workers\": 8,\n",
|
|
" \"num_gpus\": 0, # number of GPUs to use\n",
|
|
" \"model\": {\"free_log_std\": True},\n",
|
|
" # These params are tuned from a fixed starting value.\n",
|
|
" \"lambda\": 0.95,\n",
|
|
" \"clip_param\": 0.2,\n",
|
|
" \"lr\": 1e-4,\n",
|
|
" # These params start off randomly drawn from a set.\n",
|
|
" \"num_sgd_iter\": tune.choice([10, 20, 30]),\n",
|
|
" \"sgd_minibatch_size\": tune.choice([128, 512, 2048]),\n",
|
|
" \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n",
|
|
" },\n",
|
|
" )\n",
|
|
"\n",
|
|
" print(\"best hyperparameters: \", analysis.best_config)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6fb69a24",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## More RLlib Examples\n",
|
|
"\n",
|
|
"- {doc}`/tune/examples/includes/pb2_ppo_example`:\n",
|
|
" Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n",
|
|
" Uses a small population size of 4, so it can train on a laptop."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"orphan": true
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
} |