"""
|
|
This script can be used to find learning- or performance regressions in RLlib.
|
|
|
|
If you think something broke after(!) some good commit C, do the following
|
|
while checked out in the current bad commit D (where D is newer than C):
|
|
|
|
$ cd ray
|
|
$ git bisect start
|
|
$ git bisect bad
|
|
$ git bisect good [the hash code of C]
|
|
$ git bisect run python debug_learning_failure_git_bisect.py [... options]
|
|
|
|
Produces an error if the given reward is not reached within
|
|
the stopping criteria (training iters or timesteps) OR if some number
|
|
of env timesteps are not reached within some wall time or iterations,
|
|
and thus allowing git bisect to properly analyze and find the faulty commit.
|
|
|
|
Run as follows using a simple command line config
|
|
(must run 1M timesteps in 2min):
|
|
$ python debug_learning_failure_git_bisect.py --config '{...}'
|
|
--env CartPole-v0 --run PPO --stop-time=120 --stop-timesteps=1000000
|
|
|
|
With a yaml file (must reach 180.0 reward in 100 training iterations):
|
|
$ python debug_learning_failure_git_bisect.py -f [yaml file] --stop-reward=180
|
|
--stop-iters=100
|
|
"""

import argparse
import importlib
import json
import os
import subprocess

import yaml

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default=None,
    help="The RLlib-registered algorithm to use, even if -f (yaml file) given "
    "(will override yaml run setting).")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default=None,
    help="The DL framework specifier.")
parser.add_argument(
    "--skip-install-ray",
    action="store_true",
    help="If set, do not attempt to re-build ray from source.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=None,
    help="Number of iterations to train. Skip if this criterion is not "
    "important.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=None,
    help="Number of env timesteps to train. Can be used in combination with "
    "--stop-time to ascertain that we reach a certain number of (env) "
    "timesteps per (wall) time interval. Skip if this criterion is not "
    "important.")
parser.add_argument(
    "--stop-time",
    type=int,
    default=None,
    help="Time in seconds after which to stop the run. Can be used in "
    "combination with --stop-timesteps to ascertain that we reach a certain "
    "number of (env) timesteps per (wall) time interval. Skip if this "
    "criterion is not important.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=None,
    help="The minimum reward that must be reached within the given "
    "time/timesteps/iters. Skip if this criterion is not important.")
parser.add_argument(
    "-f",
    type=str,
    default=None,
    help="The yaml file to use as config. Alternatively, use --run, "
    "--config, and --env.")
parser.add_argument(
    "--config",
    type=str,
    default=None,
    help="If no -f (yaml file) given, use this config instead.")
parser.add_argument(
    "--env",
    type=str,
    default=None,
    help="Sets the env to use, even if -f (yaml file) given "
    "(will override yaml env setting).")

if __name__ == "__main__":
    args = parser.parse_args()

    run = None

    # Explicit yaml config file.
    if args.f:
        with open(args.f, "r") as fp:
            experiment_config = yaml.safe_load(fp)
        # Use the first (and expectedly only) experiment in the file.
        experiment_config = experiment_config[next(
            iter(experiment_config))]
        config = experiment_config.get("config", {})
        config["env"] = experiment_config.get("env")
        run = experiment_config.pop("run")
    # JSON string on command line.
    else:
        config = json.loads(args.config)
        assert args.env
        config["env"] = args.env

    # Explicit run.
    if args.run:
        run = args.run

    # Set --framework, if provided.
    if args.framework:
        config["framework"] = args.framework

    # Define stopping criteria.
    stop = {}
    if args.stop_iters:
        stop["training_iteration"] = args.stop_iters
    if args.stop_timesteps:
        stop["timesteps_total"] = args.stop_timesteps
    if args.stop_reward:
        stop["episode_reward_mean"] = args.stop_reward
    if args.stop_time:
        stop["time_total_s"] = args.stop_time
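
    # Tune stops the trial as soon as ANY of the criteria in `stop` is met,
    # so e.g. --stop-timesteps combined with --stop-time bounds the wall time
    # of a pure throughput check.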

    # - Stop ray.
    # - Uninstall and re-install ray (from source) if required.
    # - Start ray.
    try:
        subprocess.run("ray stop".split(" "))
        subprocess.run("ray stop".split(" "))
    except Exception:
        pass

    # Install ray from the checked out repo.
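    # Re-building ray from source at each bisect step ensures the experiment
    # below actually runs against the currently checked-out commit (use
    # --skip-install-ray to skip this slow step if that is already the case).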
    if not args.skip_install_ray:
        subprocess.run("sudo apt-get update".split(" "))
        subprocess.run("sudo apt-get install -y build-essential curl unzip "
                       "psmisc".split(" "))
        subprocess.run("pip install cython==0.29.0 pytest".split(" "))
        # Assume we are in the ray (git clone) directory.
        try:
            subprocess.run("pip uninstall -y ray".split(" "))
        except Exception:
            pass
        subprocess.run("ci/travis/install-bazel.sh".split(" "))
        os.chdir("python")
        subprocess.run("pip install -e . --verbose".split(" "))
        os.chdir("../")

    try:
        subprocess.run("ray start --head --include-dashboard false".split(" "))
    except Exception:
        subprocess.run("ray stop".split(" "))
        subprocess.run("ray start --head --include-dashboard false".split(" "))

    # Run the training experiment.
    # Import ray only now (after the possible re-install above) so that the
    # freshly built version is picked up.
    importlib.invalidate_caches()
    import ray
    from ray import tune

    ray.init()

    results = tune.run(run, stop=stop, config=config)

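    # A non-zero exit (e.g. via any of the ValueErrors raised below) tells
    # `git bisect run` that this commit is bad; exiting cleanly (code 0)
    # marks it as good.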
    # Criterion is to have reached some min reward.
    if args.stop_reward:
        last_result = results.trials[0].last_result
        avg_reward = last_result["episode_reward_mean"]
        if avg_reward < args.stop_reward:
            raise ValueError("`stop-reward` of {} not reached!".format(
                args.stop_reward))
    # Criterion is to have run through n env timesteps in some wall time m.
    elif args.stop_timesteps and args.stop_time:
        last_result = results.trials[0].last_result
        total_timesteps = last_result["timesteps_total"]
        # We stopped because we reached the time limit ->
        # Means throughput is too slow (time steps not reached).
        if total_timesteps - 100 < args.stop_timesteps:
            raise ValueError(
                "`stop-timesteps` of {} not reached in {}sec!".format(
                    args.stop_timesteps, args.stop_time))
    else:
        raise ValueError("Invalid pass criterion! Must use either "
                         "(--stop-reward + optionally any other) OR "
                         "(--stop-timesteps + --stop-time).")

    print("ok")
    ray.shutdown()