ray/rllib/examples/tune/framework.py
#!/usr/bin/env python3
""" Benchmarking TF against PyTorch on an example task using Ray Tune.
"""

import logging
from pprint import pformat

import ray
from ray import air, tune
from ray.tune import CLIReporter

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger("tune_framework")


def run(smoke_test=False):
    stop = {"training_iteration": 1 if smoke_test else 50}
    num_workers = 1 if smoke_test else 20
    num_gpus = 0 if smoke_test else 1

    config = {
        "env": "PongNoFrameskip-v4",
        "framework": tune.grid_search(["tf", "torch"]),
        "num_gpus": num_gpus,
        "rollout_fragment_length": 50,
        "train_batch_size": 750,
        "num_workers": num_workers,
        "num_envs_per_worker": 1,
        "clip_rewards": True,
        "num_sgd_iter": 2,
        "vf_loss_coeff": 1.0,
        "clip_param": 0.3,
        "grad_clip": 10,
        "vtrace": True,
        "use_kl_loss": False,
    }
    logger.info("Configuration: \n %s", pformat(config))

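    # "framework" is a grid-search column, so with num_samples=1 the Tuner
    # expands this config into two APPO trials (one tf, one torch) that are
    # otherwise identical.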
    # Run the experiment.
    # TODO(jungong) : maybe add checkpointing.
    tune.Tuner(
        "APPO",
        param_space=config,
        run_config=air.RunConfig(
            stop=stop,
            verbose=1,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "snapshots": "snapshots",
                    "episodes_this_iter": "train_episodes",
                    "episode_reward_mean": "reward_mean",
                },
                sort_by_metric=True,
                max_report_frequency=30,
            ),
        ),
        tune_config=tune.TuneConfig(
            num_samples=1,
        ),
    ).fit()
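    # Note: Tuner.fit() returns a ResultGrid; it is not captured here, but a
    # caller could use it to compare the tf and torch trials programmatically.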


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Tune+RLlib Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    args = parser.parse_args()

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        ray.init()

    run(smoke_test=args.smoke_test)

    ray.shutdown()
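

# Example invocations (path assumed relative to the Ray repository root):
#   python rllib/examples/tune/framework.py --smoke-test   # quick 1-iteration check
#   python rllib/examples/tune/framework.py                # full 50-iteration benchmark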