2017-07-10 23:36:14 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2017-11-06 23:41:17 -08:00
|
|
|
import argparse
|
2020-02-15 23:50:44 +01:00
|
|
|
import os
|
|
|
|
from pathlib import Path
|
2017-10-13 16:18:16 -07:00
|
|
|
import yaml
|
2017-07-10 23:36:14 +00:00
|
|
|
|
2018-01-24 16:55:17 -08:00
|
|
|
import ray
|
2018-12-29 11:42:25 +08:00
|
|
|
from ray.tune.config_parser import make_parser
|
2019-07-02 20:46:00 -07:00
|
|
|
from ray.tune.result import DEFAULT_RESULTS_DIR
|
2019-08-01 01:03:10 -07:00
|
|
|
from ray.tune.resources import resources_to_json
|
2020-10-08 23:10:23 +00:00
|
|
|
from ray.tune.tune import run_experiments
|
|
|
|
from ray.tune.schedulers import create_scheduler
|
2021-05-18 13:18:12 +02:00
|
|
|
from ray.rllib.utils.deprecation import deprecation_warning
|
2020-01-18 03:48:44 +01:00
|
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
|
|
|
|
|
|
# Try to import both backends for flag checking/warnings.
|
2020-06-30 10:13:20 +02:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
2020-01-18 03:48:44 +01:00
|
|
|
torch, _ = try_import_torch()
|
2017-11-06 23:41:17 -08:00
|
|
|
|
|
|
|
EXAMPLE_USAGE = """
|
2018-07-12 19:12:04 +02:00
|
|
|
Training example via RLlib CLI:
|
|
|
|
rllib train --run DQN --env CartPole-v0
|
2017-10-13 16:18:16 -07:00
|
|
|
|
2018-07-12 19:12:04 +02:00
|
|
|
Grid search example via RLlib CLI:
|
2022-06-04 07:35:24 +02:00
|
|
|
rllib train -f tuned_examples/cartpole-ppo-grid-search-example.yaml
|
2018-07-12 19:12:04 +02:00
|
|
|
|
|
|
|
Grid search example via executable:
|
2022-06-04 07:35:24 +02:00
|
|
|
./train.py -f tuned_examples/cartpole-ppo-grid-search-example.yaml
|
2017-11-12 12:05:18 -08:00
|
|
|
|
|
|
|
Note that -f overrides all other trial-specific command-line options.
|
2017-11-06 23:41:17 -08:00
|
|
|
"""
|
2017-10-13 16:18:16 -07:00
|
|
|
|
2017-07-10 23:36:14 +00:00
|
|
|
|
2018-07-12 19:12:04 +02:00
|
|
|
def create_parser(parser_creator=None):
    """Build the command-line parser for training an RLlib agent.

    The parser object is obtained from ``make_parser`` (see also the base
    parser definition in ray/tune/config_parser.py); this function then
    registers the additional flags listed in the table below on it.

    Args:
        parser_creator: Forwarded unchanged to ``make_parser``.

    Returns:
        The parser returned by ``make_parser``, with the extra arguments
        registered.
    """
    cli = make_parser(
        parser_creator=parser_creator,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Train a reinforcement learning agent.",
        epilog=EXAMPLE_USAGE,
    )

    local_dir_help = (
        f"Local dir to save training results to. Defaults to '{DEFAULT_RESULTS_DIR}'."
    )

    # One entry per flag: (option strings, keyword arguments for add_argument).
    # Registration order is significant for the generated --help output.
    extra_arguments = (
        (
            ("--ray-address",),
            dict(
                default=None,
                type=str,
                help="Connect to an existing Ray cluster at this address instead "
                "of starting a new one.",
            ),
        ),
        (
            ("--ray-ui",),
            dict(action="store_true", help="Whether to enable the Ray web UI."),
        ),
        # Deprecated: Use --ray-ui, instead.
        (
            ("--no-ray-ui",),
            dict(
                action="store_true",
                help="Deprecated! Ray UI is disabled by default now. "
                "Use `--ray-ui` to enable.",
            ),
        ),
        (
            ("--local-mode",),
            dict(
                action="store_true",
                help="Run ray in local mode for easier debugging.",
            ),
        ),
        (
            ("--ray-num-cpus",),
            dict(
                default=None,
                type=int,
                help="--num-cpus to use if starting a new cluster.",
            ),
        ),
        (
            ("--ray-num-gpus",),
            dict(
                default=None,
                type=int,
                help="--num-gpus to use if starting a new cluster.",
            ),
        ),
        (
            ("--ray-num-nodes",),
            dict(
                default=None,
                type=int,
                help="Emulate multiple cluster nodes for debugging.",
            ),
        ),
        (
            ("--ray-object-store-memory",),
            dict(
                default=None,
                type=int,
                help="--object-store-memory to use if starting a new cluster.",
            ),
        ),
        (
            ("--experiment-name",),
            dict(
                default="default",
                type=str,
                help="Name of the subdirectory under `local_dir` to put results in.",
            ),
        ),
        (
            ("--local-dir",),
            dict(default=DEFAULT_RESULTS_DIR, type=str, help=local_dir_help),
        ),
        (
            ("--upload-dir",),
            dict(
                default="",
                type=str,
                help="Optional URI to sync training results to (e.g. s3://bucket).",
            ),
        ),
        # This will override any framework setting found in a yaml file.
        (
            ("--framework",),
            dict(
                choices=["tf", "tf2", "tfe", "torch"],
                default=None,
                help="The DL framework specifier.",
            ),
        ),
        (
            ("-v",),
            dict(action="store_true", help="Whether to use INFO level logging."),
        ),
        (
            ("-vv",),
            dict(action="store_true", help="Whether to use DEBUG level logging."),
        ),
        (
            ("--resume",),
            dict(
                action="store_true",
                help="Whether to attempt to resume previous Tune experiments.",
            ),
        ),
        (
            ("--trace",),
            dict(
                action="store_true",
                help="Whether to attempt to enable tracing for eager mode.",
            ),
        ),
        (
            ("--env",),
            dict(default=None, type=str, help="The gym environment to use."),
        ),
        (
            ("-f", "--config-file"),
            dict(
                default=None,
                type=str,
                help="If specified, use config options from this file. Note that this "
                "overrides any trial-specific options set via flags above.",
            ),
        ),
        # Obsolete: Use --framework=torch|tf2|tfe instead!
        (
            ("--torch",),
            dict(
                action="store_true",
                help="Whether to use PyTorch (instead of tf) as the DL framework.",
            ),
        ),
        (
            ("--eager",),
            dict(
                action="store_true",
                help="Whether to attempt to enable TF eager execution.",
            ),
        ),
    )

    for option_strings, options in extra_arguments:
        cli.add_argument(*option_strings, **options)

    return cli
|
2017-07-10 23:36:14 +00:00
|
|
|
|
2018-07-12 19:12:04 +02:00
|
|
|
|
|
|
|
def run(args, parser):
    """Build the experiment spec(s) from the CLI arguments and run them.

    Experiments come either from the YAML file given via ``--config-file``
    or, when no file is given, from a single experiment assembled from the
    individual command-line flags. Each experiment is then validated and
    patched with the framework / tracing / logging flags, Ray is
    initialized, ``run_experiments`` is called, and Ray is shut down again.

    Args:
        args: Parsed command-line arguments (as produced by the parser from
            ``create_parser``).
        parser: The parser that produced ``args``; only its ``error`` method
            is used, to report missing or malformed input.

    Raises:
        ValueError: If ``--trace`` is requested but the experiment's
            framework is neither "tf2" nor "tfe".
    """
    if args.config_file:
        with open(args.config_file) as f:
            experiments = yaml.safe_load(f)
        # An empty or non-mapping YAML document would otherwise fail below
        # with an unhelpful AttributeError on `.values()`.
        if not isinstance(experiments, dict):
            parser.error(
                "config file '{}' must contain a mapping of experiment "
                "names to experiment specs".format(args.config_file)
            )
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "checkpoint_at_end": args.checkpoint_at_end,
                "keep_checkpoints_num": args.keep_checkpoints_num,
                "checkpoint_score_attr": args.checkpoint_score_attr,
                "local_dir": args.local_dir,
                "resources_per_trial": (
                    args.resources_per_trial
                    and resources_to_json(args.resources_per_trial)
                ),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "num_samples": args.num_samples,
                "sync_config": {
                    "upload_dir": args.upload_dir,
                },
            }
        }

    # Ray UI.
    if args.no_ray_ui:
        deprecation_warning(old="--no-ray-ui", new="--ray-ui", error=False)
        args.ray_ui = False

    verbose = 1
    for exp in experiments.values():
        # NOTE: Some of our yaml files don't have a `config` section (or have
        # an empty one, which YAML loads as None). Normalize to a dict so the
        # flag handling below can read from and write to it safely.
        if exp.get("config") is None:
            exp["config"] = {}
        config = exp["config"]

        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here.
        input_ = config.get("input")
        if input_ and input_ != "sampler":
            # This script runs in the ray/rllib dir.
            config["input"] = _patch_input_paths(input_, Path(__file__).parent)

        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not config.get("env"):
            parser.error("the following arguments are required: --env")

        if args.torch:
            deprecation_warning("--torch", "--framework=torch")
            config["framework"] = "torch"
        elif args.eager:
            deprecation_warning("--eager", "--framework=[tf2|tfe]")
            config["framework"] = "tfe"
        elif args.framework is not None:
            config["framework"] = args.framework

        if args.trace:
            # Use .get(): no framework may have been chosen at all, which
            # must yield the ValueError below rather than a bare KeyError.
            if config.get("framework") not in ("tf2", "tfe"):
                raise ValueError("Must enable --eager to enable tracing.")
            config["eager_tracing"] = True

        if args.v:
            config["log_level"] = "INFO"
            verbose = 3  # Print details on trial result
        if args.vv:
            config["log_level"] = "DEBUG"
            verbose = 3  # Print details on trial result

    if args.ray_num_nodes:
        # Import this only here so that train.py also works with
        # older versions (and user doesn't use `--ray-num-nodes`).
        from ray.cluster_utils import Cluster

        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
            cluster.add_node(
                num_cpus=args.ray_num_cpus or 1,
                num_gpus=args.ray_num_gpus or 0,
                object_store_memory=args.ray_object_store_memory,
            )
        ray.init(address=cluster.address)
    else:
        ray.init(
            include_dashboard=args.ray_ui,
            address=args.ray_address,
            object_store_memory=args.ray_object_store_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus,
            local_mode=args.local_mode,
        )

    try:
        run_experiments(
            experiments,
            scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
            resume=args.resume,
            verbose=verbose,
            concurrent=True,
        )
    finally:
        # Shut Ray down even when the experiments raise.
        ray.shutdown()


def _patch_input_paths(value, base_dir):
    """Re-root input paths under ``base_dir`` when they don't exist as given.

    Lists and dicts are walked recursively (dict keys are patched as well as
    values). A string that does not exist as a path is retried relative to
    ``base_dir`` and replaced only if that candidate exists. Every other
    value is returned unchanged.
    """
    if isinstance(value, list):
        return [_patch_input_paths(item, base_dir) for item in value]
    if isinstance(value, dict):
        return {
            _patch_input_paths(key, base_dir): _patch_input_paths(item, base_dir)
            for key, item in value.items()
        }
    if isinstance(value, str) and not os.path.exists(value):
        candidate = str(base_dir.absolute().joinpath(value))
        if os.path.exists(candidate):
            return candidate
    return value
|
|
|
|
|
2018-07-12 19:12:04 +02:00
|
|
|
|
2021-07-26 11:12:59 -04:00
|
|
|
def main():
    """Script entry point: parse the command line and hand it to ``run``."""
    cli_parser = create_parser()
    run(cli_parser.parse_args(), cli_parser)
|
2021-07-26 11:12:59 -04:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|