2017-07-10 23:36:14 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
import os
|
2017-09-12 14:28:16 -07:00
|
|
|
import pprint
|
2017-08-07 19:05:48 -07:00
|
|
|
import sys
|
2017-07-10 23:36:14 +00:00
|
|
|
|
|
|
|
import ray
|
2017-08-29 16:56:42 -07:00
|
|
|
import ray.rllib.ppo as ppo
|
|
|
|
import ray.rllib.es as es
|
2017-07-10 23:36:14 +00:00
|
|
|
import ray.rllib.dqn as dqn
|
|
|
|
import ray.rllib.a3c as a3c
|
|
|
|
|
|
|
|
def _build_arg_parser():
    """Construct the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser whose namespace exposes
        ``redis_address``, ``env``, ``alg``, ``num_iterations``, ``config``,
        ``upload_dir``, ``checkpoint_freq`` and ``restore``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Train a reinforcement learning agent.")
    add = arg_parser.add_argument

    # Cluster / problem selection.
    add("--redis-address", type=str, default=None,
        help="The Redis address of the cluster.")
    add("--env", type=str, required=True,
        help="The gym environment to use.")
    add("--alg", type=str, required=True,
        help="The reinforcement learning algorithm to use.")

    # Training loop control; sys.maxsize acts as "effectively unbounded".
    add("--num-iterations", type=int, default=sys.maxsize,
        help="The number of training iterations to run.")
    add("--config", type=str, default="{}",
        help="The configuration options of the algorithm.")

    # Output and checkpointing.
    add("--upload-dir", type=str, default="file:///tmp/ray",
        help="Where the traces are stored.")
    add("--checkpoint-freq", type=int, default=sys.maxsize,
        help="How many iterations between checkpoints.")
    add("--restore", type=str, default="",
        help="If specified, restores state from this checkpoint.")
    return arg_parser


parser = _build_arg_parser()
|
2017-07-10 23:36:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2017-07-13 14:53:57 -07:00
|
|
|
args = parser.parse_args()
|
2017-08-07 19:05:48 -07:00
|
|
|
json_config = json.loads(args.config)
|
2017-07-13 14:53:57 -07:00
|
|
|
|
2017-08-07 19:05:48 -07:00
|
|
|
ray.init(redis_address=args.redis_address)
|
2017-07-13 14:53:57 -07:00
|
|
|
|
2017-09-02 17:20:56 -07:00
|
|
|
def _check_and_update(config, json):
|
|
|
|
for k in json.keys():
|
|
|
|
if k not in config:
|
|
|
|
raise Exception(
|
|
|
|
"Unknown model config `{}`, all model configs: {}".format(
|
|
|
|
k, config.keys()))
|
|
|
|
config.update(json)
|
|
|
|
|
2017-07-13 14:53:57 -07:00
|
|
|
env_name = args.env
|
2017-08-29 16:56:42 -07:00
|
|
|
if args.alg == "PPO":
|
|
|
|
config = ppo.DEFAULT_CONFIG.copy()
|
2017-09-02 17:20:56 -07:00
|
|
|
_check_and_update(config, json_config)
|
2017-08-29 16:56:42 -07:00
|
|
|
alg = ppo.PPOAgent(
|
2017-07-19 23:45:05 +00:00
|
|
|
env_name, config, upload_dir=args.upload_dir)
|
2017-08-29 16:56:42 -07:00
|
|
|
elif args.alg == "ES":
|
2017-07-19 23:45:05 +00:00
|
|
|
config = es.DEFAULT_CONFIG.copy()
|
2017-09-02 17:20:56 -07:00
|
|
|
_check_and_update(config, json_config)
|
2017-08-29 16:56:42 -07:00
|
|
|
alg = es.ESAgent(
|
2017-07-19 23:45:05 +00:00
|
|
|
env_name, config, upload_dir=args.upload_dir)
|
2017-07-13 14:53:57 -07:00
|
|
|
elif args.alg == "DQN":
|
2017-07-19 23:45:05 +00:00
|
|
|
config = dqn.DEFAULT_CONFIG.copy()
|
2017-09-02 17:20:56 -07:00
|
|
|
_check_and_update(config, json_config)
|
2017-08-29 16:56:42 -07:00
|
|
|
alg = dqn.DQNAgent(
|
2017-07-19 23:45:05 +00:00
|
|
|
env_name, config, upload_dir=args.upload_dir)
|
2017-07-13 14:53:57 -07:00
|
|
|
elif args.alg == "A3C":
|
2017-07-19 23:45:05 +00:00
|
|
|
config = a3c.DEFAULT_CONFIG.copy()
|
2017-09-02 17:20:56 -07:00
|
|
|
_check_and_update(config, json_config)
|
2017-08-29 16:56:42 -07:00
|
|
|
alg = a3c.A3CAgent(
|
2017-07-19 23:45:05 +00:00
|
|
|
env_name, config, upload_dir=args.upload_dir)
|
2017-07-13 14:53:57 -07:00
|
|
|
else:
|
|
|
|
assert False, ("Unknown algorithm, check --alg argument. Valid "
|
2017-08-29 16:56:42 -07:00
|
|
|
"choices are PPO, ES, DQN and A3C.")
|
2017-07-13 14:53:57 -07:00
|
|
|
|
|
|
|
result_logger = ray.rllib.common.RLLibLogger(
|
|
|
|
os.path.join(alg.logdir, "result.json"))
|
|
|
|
|
2017-08-27 18:56:52 -07:00
|
|
|
if args.restore:
|
|
|
|
alg.restore(args.restore)
|
|
|
|
|
2017-08-07 19:05:48 -07:00
|
|
|
for i in range(args.num_iterations):
|
2017-07-13 14:53:57 -07:00
|
|
|
result = alg.train()
|
|
|
|
|
|
|
|
# We need to use a custom json serializer class so that NaNs get
|
|
|
|
# encoded as null as required by Athena.
|
|
|
|
json.dump(result._asdict(), result_logger,
|
|
|
|
cls=ray.rllib.common.RLLibEncoder)
|
|
|
|
result_logger.write("\n")
|
2017-08-07 19:05:48 -07:00
|
|
|
|
2017-09-12 14:28:16 -07:00
|
|
|
print("== Iteration {} ==".format(alg.iteration))
|
|
|
|
pprint.pprint(result._asdict())
|
2017-08-27 18:56:52 -07:00
|
|
|
|
|
|
|
if (i + 1) % args.checkpoint_freq == 0:
|
|
|
|
print("checkpoint path: {}".format(alg.save()))
|