[RLlib] Tune trial + checkpoint selection example. (#14209)
parent de8d9d3e44
commit 3d20d58c90
3 changed files with 89 additions and 1 deletion
@@ -20,6 +20,7 @@ from ray.tune.result import DEFAULT_METRIC, EXPR_PROGRESS_FILE, \
     EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.trainable import TrainableUtil
+from ray.tune.utils.util import unflattened_lookup
 
 logger = logging.getLogger(__name__)
 
@@ -238,7 +239,10 @@ class Analysis:
             return path_metric_df[["chkpt_path", metric]].values.tolist()
         elif isinstance(trial, Trial):
             checkpoints = trial.checkpoint_manager.best_checkpoints()
-            return [(c.value, c.result[metric]) for c in checkpoints]
+            # Support metrics given as paths, e.g.
+            # "info/learner/default_policy/policy_loss".
+            return [(c.value, unflattened_lookup(metric, c.result))
+                    for c in checkpoints]
         else:
             raise ValueError("trial should be a string or a Trial instance.")
 
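For context on the hunk above: with this change, `Analysis` resolves checkpoint metrics through `unflattened_lookup`, so a metric can be given as a "/"-delimited path into the nested result dict rather than only as a top-level key. A minimal sketch of that lookup, using a hand-built (hypothetical) result dict:

from ray.tune.utils.util import unflattened_lookup

# Hypothetical nested result, shaped like the metric path used in the example below.
result = {"info": {"learner": {"default_policy": {"policy_loss": 0.12}}}}

# Plain top-level keys resolve as before ...
assert unflattened_lookup("episode_reward_mean", {"episode_reward_mean": 150.0}) == 150.0
# ... and "/"-delimited paths walk into nested sub-dicts.
assert unflattened_lookup("info/learner/default_policy/policy_loss", result) == 0.12
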
@@ -1707,6 +1707,15 @@ py_test(
     args = ["--as-test", "--torch", "--stop-reward=6.0"]
 )
 
+py_test(
+    name = "examples/checkpoint_by_custom_criteria",
+    main = "examples/checkpoint_by_custom_criteria.py",
+    tags = ["examples", "examples_C"],
+    size = "medium",
+    srcs = ["examples/checkpoint_by_custom_criteria.py"],
+    args = ["--stop-iters=3 --num-cpus=3"]
+)
+
 py_test(
     name = "examples/complex_struct_space_tf", main = "examples/complex_struct_space.py",
     tags = ["examples", "examples_C"],
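The `args` in the new target mirror how the example would presumably be run locally, e.g. `python rllib/examples/checkpoint_by_custom_criteria.py --stop-iters=3 --num-cpus=3` from the repository root.
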
rllib/examples/checkpoint_by_custom_criteria.py (new file, 75 lines)

@@ -0,0 +1,75 @@
import argparse
import os

import ray
from ray import tune

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=150.0)

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Simple PPO config.
    config = {
        "env": "CartPole-v0",
        # Run 3 trials.
        "lr": tune.grid_search([0.01, 0.001, 0.0001]),
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2.
        "eager_tracing": args.framework in ["tfe", "tf2"],
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    # Run tune for some iterations and generate checkpoints.
    results = tune.run(args.run, config=config, stop=stop, checkpoint_freq=1)

    # Get the best of the 3 trials by using some metric.
    # NOTE: Choosing the min `episodes_this_iter` automatically picks the trial
    # with the best performance (over the entire run (scope="all")):
    # The fewer episodes, the longer each episode lasted, the more reward we
    # got each episode.
    # Setting scope to "last", "last-5-avg", or "last-10-avg" will only compare
    # (using `mode=min|max`) the average values of the last 1, 5, or 10
    # iterations with each other, respectively.
    # Setting scope to "avg" will compare (using `mode`=min|max) the average
    # values over the entire run.
    metric = "episodes_this_iter"
    best_trial = results.get_best_trial(metric=metric, mode="min", scope="all")
    value_best_metric = best_trial.metric_analysis[metric]["min"]
    print("Best trial's lowest episode length (over all "
          "iterations): {}".format(value_best_metric))

    # Confirm, we picked the right trial.
    assert all(value_best_metric <= results.results[t][metric]
               for t in results.results.keys())

    # Get the best checkpoints from the trial, based on different metrics.
    # Checkpoint with the lowest policy loss value:
    ckpt = results.get_best_checkpoint(
        best_trial,
        metric="info/learner/default_policy/policy_loss",
        mode="min")
    print("Lowest pol-loss: {}".format(ckpt))

    # Checkpoint with the highest value-function loss:
    ckpt = results.get_best_checkpoint(
        best_trial, metric="info/learner/default_policy/vf_loss", mode="max")
    print("Highest vf-loss: {}".format(ckpt))

    ray.shutdown()
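
A possible follow-up sketch, assuming the default `--run=PPO` and that `ckpt` is a checkpoint path string as returned by `get_best_checkpoint` above:

from ray.rllib.agents.ppo import PPOTrainer

# Tune search values (like the `lr` grid_search above) cannot be passed to a
# single Trainer, so pin `lr` to one concrete value before constructing it.
restore_config = dict(config, lr=0.0001)
agent = PPOTrainer(config=restore_config)
agent.restore(ckpt)  # load weights/state from the selected checkpoint

# One more training iteration from the restored state, to confirm it loads.
print(agent.train()["episode_reward_mean"])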