2017-07-10 23:36:14 +00:00
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
2017-09-12 14:28:16 -07:00
import pprint
2017-08-07 19:05:48 -07:00
import sys
2017-07-10 23:36:14 +00:00
import ray
2017-08-29 16:56:42 -07:00
import ray.rllib.ppo as ppo
import ray.rllib.es as es
2017-07-10 23:36:14 +00:00
import ray.rllib.dqn as dqn
import ray.rllib.a3c as a3c
# Command-line interface for the training script. Only --env and --alg are
# required. The sys.maxsize defaults for --num-iterations and
# --checkpoint-freq are large enough that, in practice, training runs until
# interrupted and no periodic checkpoint is taken.
parser = argparse.ArgumentParser(
    description="Train a reinforcement learning agent.")
parser.add_argument("--redis-address", default=None, type=str,
                    help="The Redis address of the cluster.")
parser.add_argument("--env", required=True, type=str,
                    help="The gym environment to use.")
parser.add_argument("--alg", required=True, type=str,
                    help="The reinforcement learning algorithm to use.")
parser.add_argument("--num-iterations", default=sys.maxsize, type=int,
                    help="The number of training iterations to run.")
# Parsed later with json.loads; "{}" means "no overrides".
parser.add_argument("--config", default="{}", type=str,
                    help="The configuration options of the algorithm.")
parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str,
                    help="Where the traces are stored.")
parser.add_argument("--checkpoint-freq", default=sys.maxsize, type=int,
                    help="How many iterations between checkpoints.")
# An empty string (the default) means "start from scratch".
parser.add_argument("--restore", default="", type=str,
                    help="If specified, restores state from this checkpoint.")
def _check_and_update(config, overrides):
    """Validate user overrides against ``config`` and merge them in place.

    Args:
        config: Dict of default options for the selected algorithm. It is
            updated in place with ``overrides`` once every key is known.
        overrides: Value decoded from the ``--config`` JSON string.

    Raises:
        ValueError: if ``overrides`` is not a JSON object (dict), or if it
            names a key that ``config`` does not define. Nothing is merged
            in either case.
    """
    if not isinstance(overrides, dict):
        raise ValueError(
            "--config must be a JSON object, got: {!r}".format(overrides))
    for k in overrides:
        if k not in config:
            raise ValueError(
                "Unknown model config `{}`, all model configs: {}".format(
                    k, config.keys()))
    # Validation alone is not enough: without this merge the --config
    # values would be checked and then silently ignored.
    config.update(overrides)


def _scrub_nans(value):
    """Return a copy of ``value`` with every float NaN replaced by ``None``.

    ``json.dump`` writes NaN as the bare token ``NaN``, which is not valid
    JSON; ``None`` is written as ``null`` instead. Dicts, lists and tuples
    are walked recursively (tuples come back as lists, which is how ``json``
    encodes them anyway) and dict key order is preserved.
    """
    import collections
    import math

    if isinstance(value, float) and math.isnan(value):
        return None
    if isinstance(value, dict):
        scrubbed = collections.OrderedDict()
        for key, item in value.items():
            scrubbed[key] = _scrub_nans(item)
        return scrubbed
    if isinstance(value, (list, tuple)):
        return [_scrub_nans(item) for item in value]
    return value


def _resolve_agent(alg_name, json_config):
    """Look up the agent class for ``alg_name`` and build its config.

    Args:
        alg_name: Value of ``--alg``; one of "PPO", "ES", "DQN", "A3C".
        json_config: Overrides decoded from ``--config``.

    Returns:
        An ``(agent_cls, config)`` pair. ``config`` is a copy of the
        algorithm's ``DEFAULT_CONFIG`` with ``json_config`` merged in.

    Raises:
        ValueError: for an unknown algorithm name or invalid overrides.
    """
    registry = {
        "PPO": (ppo.PPOAgent, ppo.DEFAULT_CONFIG),
        "ES": (es.ESAgent, es.DEFAULT_CONFIG),
        "DQN": (dqn.DQNAgent, dqn.DEFAULT_CONFIG),
        "A3C": (a3c.A3CAgent, a3c.DEFAULT_CONFIG),
    }
    if alg_name not in registry:
        raise ValueError("Unknown algorithm, check --alg argument. Valid "
                         "choices are PPO, ES, DQN and A3C.")
    agent_cls, default_config = registry[alg_name]
    # Copy so that merging the overrides leaves the module-level defaults
    # untouched (update() only replaces top-level keys).
    config = default_config.copy()
    _check_and_update(config, json_config)
    return agent_cls, config


def main():
    """Parse the command line, build the agent and run the training loop."""
    args = parser.parse_args()

    # Validate --config and --alg before connecting to Ray, so usage mistakes
    # are reported immediately. parser.error prints the usage line and exits;
    # an `assert` would be stripped under `python -O`.
    try:
        json_config = json.loads(args.config)
        agent_cls, config = _resolve_agent(args.alg, json_config)
    except ValueError as e:
        parser.error(str(e))

    # Attach to the cluster named by --redis-address (None presumably starts
    # a local instance -- confirm against the installed Ray version).
    ray.init(redis_address=args.redis_address)

    alg = agent_cls(args.env, config, upload_dir=args.upload_dir)
    result_logger = ray.rllib.common.RLLibLogger(
        os.path.join(alg.logdir, "result.json"))

    if args.restore:
        # Counterpart of alg.save() below, which returns a checkpoint path.
        alg.restore(args.restore)

    # Count iterations manually rather than with range(): with the
    # sys.maxsize default, Python 2's range() would try to build a list of
    # sys.maxsize integers before the first iteration.
    iteration = 0
    while iteration < args.num_iterations:
        result = alg.train()
        result_dict = result._asdict()

        # NaNs are converted to None so they are encoded as null, as
        # required by Athena. One JSON document per line keeps result.json
        # parseable record by record.
        json.dump(_scrub_nans(result_dict), result_logger)
        result_logger.write("\n")

        print("== Iteration {} ==".format(alg.iteration))
        pprint.pprint(dict(result_dict))

        iteration += 1
        # A non-positive --checkpoint-freq disables checkpointing; a value
        # of 0 would otherwise raise ZeroDivisionError here.
        if (args.checkpoint_freq > 0 and
                iteration % args.checkpoint_freq == 0):
            print("checkpoint path: {}".format(alg.save()))


if __name__ == "__main__":
    main()