# (stray revision timestamp, not code) 2017-10-13 16:18:16 -07:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import numpy as np
import os
import random
from ray.tune.trial import Trial, Resources
def _resource_json(data):
    """Build a ``Resources`` object from a JSON-encoded string.

    Args:
        data (str): JSON object text. The keys read are ``cpu``, ``gpu``,
            ``driver_cpu_limit`` and ``driver_gpu_limit``.

    Returns:
        Resources: constructed positionally from those four values, in that
        order. ``cpu`` and ``gpu`` fall back to 0 when absent; the two
        driver limits fall back to None.
    """
    spec = json.loads(data)
    cpu = spec.get('cpu', 0)
    gpu = spec.get('gpu', 0)
    driver_cpu_limit = spec.get('driver_cpu_limit')
    driver_gpu_limit = spec.get('driver_gpu_limit')
    return Resources(cpu, gpu, driver_cpu_limit, driver_gpu_limit)
# (stray revision timestamp, not code) 2017-10-13 16:18:16 -07:00
def make_parser(description):
    """Returns a base argument parser for the ray.tune tool.

    Args:
        description (str): Text passed to ``argparse.ArgumentParser`` as the
            parser description.

    Returns:
        argparse.ArgumentParser: parser with the --alg, --stop, --config,
        --resources, --num-trials, --local-dir, --upload-dir,
        --checkpoint-freq and --env options registered.
    """
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument("--alg", default="PPO", type=str,
                        help="The learning algorithm to train.")
    # The JSON-valued options keep string defaults on purpose: argparse runs
    # the ``type`` callable over string defaults as well, so the defaults go
    # through the same conversion as user-supplied values.
    parser.add_argument("--stop", default="{}", type=json.loads,
                        help="The stopping criteria, specified in JSON.")
    parser.add_argument("--config", default="{}", type=json.loads,
                        help="The config of the algorithm, specified in JSON.")
    # Without a ``type`` this option would stay a raw JSON string; convert it
    # with the module's _resource_json helper like the other JSON options.
    parser.add_argument("--resources", default='{"cpu": 1}',
                        type=_resource_json,
                        help="Amount of resources to allocate per trial.")
    parser.add_argument("--num-trials", default=1, type=int,
                        help="Number of trials to evaluate.")
    parser.add_argument("--local-dir", default="/tmp/ray", type=str,
                        help="Local dir to save training results to.")
    parser.add_argument("--upload-dir", default=None, type=str,
                        help="URI to upload training results to.")
    parser.add_argument("--checkpoint-freq", default=None, type=int,
                        help="How many iterations between checkpoints.")

    # TODO(ekl) environments are RL specific
    parser.add_argument("--env", default=None, type=str,
                        help="The gym environment to use.")

    return parser
def parse_to_trials(config):
    """Parses a json config to the number of trials specified by the config.

    The input config is a mapping from experiment names to an argument
    dictionary describing a set of trials. These args include the parser args
    documented in make_parser().

    Args:
        config (dict): Mapping from experiment name to a dict of arguments.
            Each inner key is turned into a command-line flag (underscores
            become dashes) and parsed with the parser from make_parser().

    Returns:
        list: The Trial objects created, ``num_trials`` per experiment.
    """

    def resolve(agent_cfg, resolved_vars, i):
        """Replace every ``{"eval": <expr>}`` value with its evaluation.

        ``i`` is exposed to the expression as ``_i``. Keys that were
        evaluated are flagged in ``resolved_vars``, which is updated in place
        and returned alongside the resolved copy of the config.
        """
        assert isinstance(agent_cfg, dict)
        cfg = agent_cfg.copy()
        for p, val in agent_cfg.items():
            if isinstance(val, dict) and "eval" in val:
                # SECURITY: eval() runs arbitrary code taken from the config.
                # Only feed this function configs from a trusted source.
                cfg[p] = eval(val["eval"], {
                    "random": random,
                    "np": np,
                }, {
                    "_i": i,
                })
                resolved_vars[p] = True
        return cfg, resolved_vars

    def to_argv(config):
        """Flatten an argument dict into an argv-style list of strings."""
        argv = []
        for k, v in config.items():
            argv.append("--{}".format(k.replace("_", "-")))
            if isinstance(v, str):
                argv.append(v)
            else:
                # NOTE(review): this branch was missing from the file as
                # received. JSON-encoding is assumed because the parser
                # decodes --stop/--config with json.loads - confirm.
                argv.append(json.dumps(v))
        return argv

    def param_str(config, resolved_vars):
        """Join the resolved ``key=value`` pairs, sorted by key, with "_"."""
        return "_".join(
            [k + "=" + str(v) for k, v in sorted(config.items())
             if resolved_vars.get(k)])

    parser = make_parser("Ray hyperparameter tuning tool")
    trials = []
    for experiment_name, exp_cfg in config.items():
        args = parser.parse_args(to_argv(exp_cfg))
        grid_search = _GridSearchGenerator(args.config)
        for i in range(args.num_trials):
            next_cfg, resolved_vars = grid_search.next()
            resolved, resolved_vars = resolve(next_cfg, resolved_vars, i)
            # Tag the trial with its index, plus the resolved parameter
            # values when any grid/eval variables were substituted.
            if resolved_vars:
                experiment_tag = "{}_{}".format(
                    i, param_str(resolved, resolved_vars))
            else:
                experiment_tag = str(i)
            # NOTE(review): the opening of this call and its final argument
            # were missing from the file as received. ``trials.append(Trial(``
            # and ``args.upload_dir`` are reconstructed from the surrounding
            # code - confirm against the Trial constructor signature.
            trials.append(Trial(
                args.env, args.alg, resolved,
                os.path.join(args.local_dir, experiment_name), experiment_tag,
                args.resources, args.stop, args.checkpoint_freq, None,
                args.upload_dir))

    return trials
class _GridSearchGenerator(object):
    """Generator that implements grid search over a set of value lists.

    Config entries of the form ``{"grid_search": [v0, v1, ...]}`` are treated
    as grid axes. Each call to next() returns the config with every axis
    replaced by its current value and then advances to the next combination.
    Axes are ordered by key; the first axis varies fastest, and after the
    last combination the sequence wraps around to the first one.
    """

    def __init__(self, agent_cfg):
        """Collect the grid axes found in ``agent_cfg``.

        Args:
            agent_cfg (dict): Config whose ``{"grid_search": [...]}`` values
                are expanded by next(). The dict itself is not modified.

        Raises:
            AssertionError: if a ``grid_search`` value is not a list.
        """
        self.cfg = agent_cfg
        # (key, candidate values) pairs, sorted by key for a stable order.
        self.grid_values = []
        for p, val in sorted(agent_cfg.items()):
            if isinstance(val, dict) and "grid_search" in val:
                assert isinstance(val["grid_search"], list), (
                    "grid_search value for {!r} must be a list".format(p))
                self.grid_values.append((p, val["grid_search"]))
        # Current position along each axis, parallel to self.grid_values.
        self.value_indices = [0] * len(self.grid_values)

    def next(self):
        """Return the current grid point and advance to the next one.

        Returns:
            tuple: ``(cfg, resolved_vars)`` where ``cfg`` is a shallow copy
            of the config with each grid axis set to its current value, and
            ``resolved_vars`` maps every grid key to True.
        """
        cfg = self.cfg.copy()
        resolved_vars = {}
        for (k, values), idx in zip(self.grid_values, self.value_indices):
            cfg[k] = values[idx]
            resolved_vars[k] = True
        # With no grid axes there is nothing to advance.
        if self.grid_values:
            self._increment(0)
        return cfg, resolved_vars

    def _increment(self, i):
        """Advance axis ``i`` by one, carrying into later axes on overflow."""
        while i < len(self.value_indices):
            self.value_indices[i] += 1
            if self.value_indices[i] < len(self.grid_values[i][1]):
                return
            # Axis exhausted: reset it and carry into the next axis.
            self.value_indices[i] = 0
            i += 1