[tune] Support user-defined trainable functions / classes / envs with a shared object registry (#1226)

This commit is contained in:
Eric Liang 2017-11-20 17:52:43 -08:00 committed by Richard Liaw
parent 9233e496cc
commit 316f9e2bb7
38 changed files with 739 additions and 299 deletions
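At a glance, the workflow this commit enables looks roughly like the sketch below. It assumes only the APIs introduced in this diff (register_trainable, run_experiments, and the callable status reporter); `my_train`, "my_experiment", and the stopping values are illustrative and not part of the commit.

# Minimal end-to-end sketch of a user-defined trainable function, assuming
# the ray.tune APIs added in this diff; names and values are illustrative.
from ray.tune import register_trainable, run_experiments


def my_train(config, reporter):
    # User-defined trainable function; the reporter is invoked as a callable
    # taking TrainingResult fields as keyword arguments (see script_runner.py).
    for i in range(100):
        reporter(timesteps_total=i, mean_accuracy=i / 100.0)


register_trainable("my_train", my_train)

if __name__ == "__main__":
    run_experiments({
        "my_experiment": {
            "run": "my_train",
            "stop": {"mean_accuracy": 0.50},
            "config": {"script_min_iter_time_s": 0},
        },
    })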

View file

@ -19,10 +19,19 @@ import shlex
# These lines added to enable Sphinx to work without installing Ray.
import mock
MOCK_MODULES = ["gym",
"gym.spaces",
"scipy",
"scipy.signal",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.layers",
"tensorflow.contrib.slim",
"tensorflow.contrib.rnn",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"pyarrow",
"pyarrow.plasma",
"smart_open",

View file

@ -25,7 +25,7 @@ You can run the code with
.. code-block:: bash
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=A3C --config='{"num_workers": N}'
python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}'
Reinforcement Learning
----------------------

View file

@ -18,7 +18,7 @@ on the ``Humanoid-v1`` gym environment.
.. code-block:: bash
python/ray/rllib/train.py --env=Humanoid-v1 --alg=ES
python/ray/rllib/train.py --env=Humanoid-v1 --run=ES
To train a policy on a cluster (e.g., using 900 workers), run the following.
@ -26,7 +26,7 @@ To train a policy on a cluster (e.g., using 900 workers), run the following.
python ray/python/ray/rllib/train.py \
--env=Humanoid-v1 \
--alg=ES \
--run=ES \
--redis-address=<redis-address> \
--config='{"num_workers": 900, "episodes_per_batch": 10000, "timesteps_per_batch": 100000}'

View file

@ -16,7 +16,7 @@ Then you can run the example as follows.
.. code-block:: bash
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=PPO
python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO
This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also
try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment.

View file

@ -30,7 +30,7 @@ You can run training with
::
python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --config '{"timesteps_per_batch": 10000}'
python ray/python/ray/rllib/train.py --env CartPole-v0 --run PPO --config '{"timesteps_per_batch": 10000}'
By default, the results will be logged to a subdirectory of ``/tmp/ray``.
This subdirectory will contain a file ``config.json`` which contains the
@ -51,7 +51,7 @@ The ``train.py`` script has a number of options you can show by running
The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--alg``
can be used) and for choosing the algorithm with ``--run``
(available options are ``PPO``, ``A3C``, ``ES`` and ``DQN``). Each algorithm
has specific hyperparameters that can be set with ``--config``, see the
``DEFAULT_CONFIG`` variable in

View file

@ -8,7 +8,7 @@ You can run training with
::
python train.py --env CartPole-v0 --alg PPO
python train.py --env CartPole-v0 --run PPO
The available algorithms are:

View file

@ -0,0 +1,19 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.tune.registry import register_trainable
from ray.rllib import ppo, es, dqn, a3c
from ray.rllib.agent import _MockAgent, _SigmoidFakeData
def _register_all():
register_trainable("PPO", ppo.PPOAgent)
register_trainable("ES", es.ESAgent)
register_trainable("DQN", dqn.DQNAgent)
register_trainable("A3C", a3c.A3CAgent)
register_trainable("__fake", _MockAgent)
register_trainable("__sigmoid_fake_data", _SigmoidFakeData)
_register_all()

View file

@ -17,13 +17,15 @@ import uuid
import tensorflow as tf
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import ENV_CREATOR
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Agent(object):
class Agent(Trainable):
"""All RLlib agents extend this base class.
Agent objects retain internal model state between calls to train(), so
@ -33,39 +35,40 @@ class Agent(object):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Object registry.
"""
_allow_unknown_configs = False
_default_logdir = "/tmp/ray"
def __init__(
self, env_creator, config, logger_creator=None):
self, config={}, env=None, registry=None, logger_creator=None):
"""Initialize an RLLib agent.
Args:
env_creator (str|func): Name of the OpenAI gym environment to train
against, or a function that creates such an env.
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, it will be assumed empty.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
if type(env_creator) is str:
import gym
env_name = env_creator
self.env_creator = lambda: gym.make(env_name)
env = env or config.get("env")
if env:
config["env"] = env
if registry and registry.contains(ENV_CREATOR, env):
self.env_creator = registry.get(ENV_CREATOR, env)
else:
if hasattr(env_creator, "env_name"):
env_name = env_creator.env_name
else:
env_name = "custom"
self.env_creator = env_creator
import gym
self.env_creator = lambda: gym.make(env)
self.config = self._default_config.copy()
self.registry = registry
if not self._allow_unknown_configs:
for k in config.keys():
if k not in self.config:
if k not in self.config and k != "env":
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
@ -76,8 +79,7 @@ class Agent(object):
self.logdir = self._result_logger.logdir
else:
logdir_suffix = "{}_{}_{}".format(
env_name,
self._agent_name,
env, self._agent_name,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if not os.path.exists(self._default_logdir):
os.makedirs(self._default_logdir)
@ -214,7 +216,14 @@ class Agent(object):
def stop(self):
"""Releases all resources used by this agent."""
self._result_logger.close()
if self._initialize_ok:
self._result_logger.close()
self._stop()
def _stop(self):
"""Subclasses should override this for custom stopping."""
pass
def compute_action(self, observation):
"""Computes an action using the current trained policy."""
@ -336,5 +345,4 @@ def get_agent_class(alg):
return _SigmoidFakeData
else:
raise Exception(
("Unknown algorithm {}, check --alg argument. Valid choices " +
"are PPO, ES, DQN, and A3C.").format(alg))
("Unknown algorithm {}.").format(alg))

View file

@ -355,7 +355,7 @@ class DQNAgent(Agent):
_agent_name = "DQN"
_default_config = DEFAULT_CONFIG
def stop(self):
def _stop(self):
for w in self.workers:
w.stop.remote()

View file

@ -29,8 +29,8 @@ CONFIGS = {
def test(use_object_store, alg_name):
cls = get_agent_class(alg_name)
alg1 = cls("CartPole-v0", CONFIGS[name])
alg2 = cls("CartPole-v0", CONFIGS[name])
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
for _ in range(3):
res = alg1.train()

View file

@ -9,12 +9,12 @@ import sys
import yaml
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import make_scheduler, run_experiments
from ray.tune.tune import _make_scheduler, run_experiments
EXAMPLE_USAGE = """
Training example:
./train.py --alg DQN --env CartPole-v0
./train.py --run DQN --env CartPole-v0
Grid search example:
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
@ -29,16 +29,24 @@ parser = make_parser(
epilog=EXAMPLE_USAGE)
# See also the base parser definition in ray/tune/config_parser.py
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--experiment-name", default="default", type=str,
help="Name of experiment dir.")
parser.add_argument("-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file.")
parser.add_argument(
"--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument(
"--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument(
"--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument(
"--experiment-name", default="default", type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
"-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file. Note that this "
"overrides any trial-specific options set via flags above.")
if __name__ == "__main__":
@ -50,13 +58,12 @@ if __name__ == "__main__":
# Note: keep this in sync with tune/config_parser.py
experiments = {
args.experiment_name: { # i.e. log to /tmp/ray/default
"alg": args.alg,
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"local_dir": args.local_dir,
"env": args.env,
"resources": resources_to_json(args.resources),
"stop": args.stop,
"config": args.config,
"config": dict(args.config, env=args.env),
"restore": args.restore,
"repeat": args.repeat,
"upload_dir": args.upload_dir,
@ -64,12 +71,12 @@ if __name__ == "__main__":
}
for exp in experiments.values():
if not exp.get("alg"):
parser.error("the following arguments are required: --alg")
if not exp.get("env"):
if not exp.get("run"):
parser.error("the following arguments are required: --run")
if not exp.get("env") and not exp.get("config", {}).get("env"):
parser.error("the following arguments are required: --env")
run_experiments(
experiments, scheduler=make_scheduler(args),
experiments, scheduler=_make_scheduler(args),
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)

View file

@ -1,6 +1,6 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
run: PPO
stop:
episode_reward_mean: 200
time_total_s: 180

View file

@ -1,6 +1,6 @@
hopper-ppo:
env: Hopper-v1
alg: PPO
run: PPO
resources:
cpu: 64
gpu: 4

View file

@ -1,6 +1,6 @@
humanoid-es:
env: Humanoid-v1
alg: ES
run: ES
resources:
cpu: 100
driver_cpu_limit: 4

View file

@ -1,6 +1,6 @@
humanoid-ppo-gae:
env: Humanoid-v1
alg: PPO
run: PPO
stop:
episode_reward_mean: 6000
resources:

View file

@ -1,6 +1,6 @@
humanoid-ppo:
env: Humanoid-v1
alg: PPO
run: PPO
stop:
episode_reward_mean: 6000
resources:

View file

@ -1,6 +1,6 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
run: PPO
repeat: 3
stop:
episode_reward_mean: 200

View file

@ -1,6 +1,6 @@
pong-a3c:
env: PongDeterministic-v4
alg: A3C
run: A3C
resources:
cpu: 16
driver_cpu_limit: 1

View file

@ -1,6 +1,6 @@
pong-deterministic-dqn:
env: PongDeterministic-v4
alg: DQN
run: DQN
resources:
cpu: 1
gpu: 1
@ -28,7 +28,7 @@ pong-deterministic-dqn:
]
pong-noframeskip-dqn:
env: PongNoFrameskip-v4
alg: DQN
run: DQN
resources:
cpu: 1
gpu: 1

View file

@ -1,6 +1,6 @@
walker2d-v1-ppo:
env: Walker2d-v1
alg: PPO
run: PPO
resources:
cpu: 64
gpu: 4

View file

@ -86,7 +86,7 @@ expression.
cartpole-ppo:
env: CartPole-v0
alg: PPO
run: PPO
repeat: 2
stop:
episode_reward_mean: 200
@ -119,7 +119,7 @@ When using the Python API, the above is equivalent to the following program:
spec = {
"env": "CartPole-v0",
"alg": "PPO",
"run": "PPO",
"repeat": 2,
"stop": {
"episode_reward_mean": 200,
@ -166,9 +166,9 @@ Using ray.tune with Ray RLlib
Another way to use ray.tune is through RLlib's ``python/ray/rllib/train.py``
script. This script allows you to select between different RL algorithms with
the ``--alg`` option. For example, to train pong with the A3C algorithm, run:
the ``--run`` option. For example, to train pong with the A3C algorithm, run:
- ``./train.py --env=PongDeterministic-v4 --alg=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'``
- ``./train.py --env=PongDeterministic-v4 --run=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'``
or

View file

@ -0,0 +1,24 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.tune.error import TuneError
from ray.tune.tune import run_experiments
from ray.tune.registry import register_env, register_trainable
from ray.tune.result import TrainingResult
from ray.tune.script_runner import ScriptRunner
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search
register_trainable("script", ScriptRunner)
__all__ = [
"Trainable",
"TrainingResult",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run_experiments",
]
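A hedged sketch of registering a user-defined environment through this new public API follows; "my_cartpole" and make_custom_env are illustrative names. Note that in this version the env creator is a zero-argument callable (see register_env and its tests below).

# Illustrative registration of a custom env via the new tune public API.
import gym

from ray.tune import register_env


def make_custom_env():
    # Any zero-argument callable returning an environment instance works here.
    return gym.make("CartPole-v0")


register_env("my_cartpole", make_custom_env)
# RLlib trials can then reference it by name, e.g.
# {"run": "PPO", "env": "my_cartpole", ...}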

View file

@ -6,12 +6,18 @@ from __future__ import print_function
import argparse
import json
from ray.tune import TuneError
from ray.tune.trial import Resources
def json_to_resources(data):
if type(data) is str:
data = json.loads(data)
for k in data:
if k not in Resources._fields:
raise TuneError(
"Unknown resource type {}, must be one of {}".format(
k, Resources._fields))
return Resources(
data.get("cpu", 0), data.get("gpu", 0),
data.get("driver_cpu_limit"), data.get("driver_gpu_limit"))
@ -32,34 +38,49 @@ def make_parser(**kwargs):
parser = argparse.ArgumentParser(**kwargs)
# Note: keep this in sync with rllib/train.py
parser.add_argument("--alg", default=None, type=str,
help="The learning algorithm to train.")
parser.add_argument("--stop", default="{}", type=json.loads,
help="The stopping criteria, specified in JSON.")
parser.add_argument("--config", default="{}", type=json.loads,
help="The config of the algorithm, specified in JSON.")
parser.add_argument("--resources", default='{"cpu": 1}',
type=json_to_resources,
help="Amount of resources to allocate per trial.")
parser.add_argument("--repeat", default=1, type=int,
help="Number of times to repeat each trial.")
parser.add_argument("--local-dir", default="/tmp/ray", type=str,
help="Local dir to save training results to.")
parser.add_argument("--upload-dir", default="", type=str,
help="URI to upload training results to.")
parser.add_argument("--checkpoint-freq", default=0, type=int,
help="How many iterations between checkpoints.")
parser.add_argument("--scheduler", default="FIFO", type=str,
help="FIFO, MedianStopping, or HyperBand")
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
parser.add_argument(
"--run", default=None, type=str,
help="The algorithm or model to train. This may refer to the name "
"of a built-in algorithm (e.g. RLlib's DQN or PPO), or a "
"user-defined trainable function or class registered in the "
"tune registry.")
parser.add_argument(
"--stop", default="{}", type=json.loads,
help="The stopping criteria, specified in JSON. The keys may be any "
"field in TrainingResult, e.g. "
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
"after 600 seconds or 100k timesteps, whichever is reached first.")
parser.add_argument(
"--config", default="{}", type=json.loads,
help="Algorithm-specific configuration (e.g. env, hyperparams), "
"specified in JSON.")
parser.add_argument(
"--resources", default='{"cpu": 1}', type=json_to_resources,
help="Machine resources to allocate per trial, e.g. "
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
"unless you specify them here.")
parser.add_argument(
"--repeat", default=1, type=int,
help="Number of times to repeat each trial.")
parser.add_argument(
"--local-dir", default="/tmp/ray", type=str,
help="Local dir to save training results to. Defaults to '/tmp/ray'.")
parser.add_argument(
"--upload-dir", default="", type=str,
help="Optional URI to upload training results to.")
parser.add_argument(
"--checkpoint-freq", default=0, type=int,
help="How many training iterations between checkpoints. "
"A value of 0 (default) disables checkpointing.")
parser.add_argument(
"--scheduler", default="FIFO", type=str,
help="FIFO (default), MedianStopping, or HyperBand.")
parser.add_argument(
"--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
# Note: this currently only makes sense when running a single trial
parser.add_argument("--restore", default=None, type=str,
help="If specified, restore from this checkpoint.")
# TODO(ekl) environments are RL specific
parser.add_argument("--env", default=None, type=str,
help="The gym environment to use.")
return parser

python/ray/tune/error.py (new file, 8 lines)
View file

@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class TuneError(Exception):
"""General error class raised by ray.tune."""
pass

View file

@ -31,12 +31,9 @@ from __future__ import print_function
import argparse
import sys
import tempfile
import os
import time
import ray
from ray.tune.result import TrainingResult
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import grid_search, generate_trials
from ray.tune import grid_search, run_experiments, register_trainable
from tensorflow.examples.tutorials.mnist import input_data
@ -135,7 +132,12 @@ def bias_variable(shape):
def main(_):
# Import data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
for _ in range(10):
try:
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
break
except Exception:
time.sleep(5)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
@ -174,9 +176,8 @@ def main(_):
# !!! Report status to ray.tune !!!
if status_reporter:
status_reporter.report(TrainingResult(
timesteps_total=i,
mean_accuracy=train_accuracy))
status_reporter(
timesteps_total=i, mean_accuracy=train_accuracy)
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(
@ -201,34 +202,24 @@ def train(config={'activation': 'relu'}, reporter=None):
# !!! Example of using the ray.tune Python API !!!
if __name__ == '__main__':
runner = TrialRunner()
parser = argparse.ArgumentParser()
parser.add_argument(
'--fast', action='store_true', help='Finish quickly for testing')
args, _ = parser.parse_known_args()
spec = {
register_trainable('train_mnist', train)
mnist_spec = {
'run': 'train_mnist',
'stop': {
'mean_accuracy': 0.99,
'time_total_s': 600,
},
'config': {
'script_file_path': os.path.abspath(__file__),
'script_min_iter_time_s': 1,
'activation': grid_search(['relu', 'elu', 'tanh']),
},
}
# These arguments are only for testing purposes.
parser = argparse.ArgumentParser()
parser.add_argument('--fast', action='store_true',
help='Run minimal iterations.')
args, _ = parser.parse_known_args()
if args.fast:
spec['stop']['training_iteration'] = 2
mnist_spec['stop']['training_iteration'] = 2
for trial in generate_trials(spec):
runner.add_trial(trial)
ray.init()
while not runner.is_finished():
runner.step()
print(runner.debug_string())
run_experiments({'tune_mnist_test': mnist_spec})

View file

@ -1,4 +1,5 @@
tune_mnist:
run: script
repeat: 2
resources:
cpu: 1

View file

@ -0,0 +1,87 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from types import FunctionType
import ray
from ray.tune import TuneError
from ray.local_scheduler import ObjectID
from ray.tune.trainable import Trainable, wrap_function
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
KNOWN_CATEGORIES = [TRAINABLE_CLASS, ENV_CREATOR]
def register_trainable(name, trainable):
"""Register a trainable function or class.
Args:
name (str): Name to register.
trainable (obj): Function or tune.Trainable class. Functions must
take (config, status_reporter) as arguments and will be
automatically converted into a class during registration.
"""
if isinstance(trainable, FunctionType):
trainable = wrap_function(trainable)
if not issubclass(trainable, Trainable):
raise TypeError(
"Second argument must be convertable to Trainable", trainable)
_default_registry.register(TRAINABLE_CLASS, name, trainable)
def register_env(name, env_creator):
"""Register a custom environment for use with RLlib.
Args:
name (str): Name to register.
env_creator (obj): Function that creates an env.
"""
if not isinstance(env_creator, FunctionType):
raise TypeError(
"Second argument must be a function.", env_creator)
_default_registry.register(ENV_CREATOR, name, env_creator)
def get_registry():
"""Use this to access the registry. This requires ray to be initialized."""
_default_registry.flush_values_to_object_store()
# returns a registry copy that doesn't include the hard refs
return _Registry(_default_registry._all_objects)
class _Registry(object):
def __init__(self, objs={}):
self._all_objects = objs
self._refs = [] # hard refs that prevent eviction of objects
def register(self, category, key, value):
if category not in KNOWN_CATEGORIES:
raise TuneError("Unknown category {} not among {}".format(
category, KNOWN_CATEGORIES))
self._all_objects[(category, key)] = value
def contains(self, category, key):
return (category, key) in self._all_objects
def get(self, category, key):
value = self._all_objects[(category, key)]
if type(value) == ObjectID:
return ray.get(value)
else:
return value
def flush_values_to_object_store(self):
for k, v in self._all_objects.items():
if type(v) != ObjectID:
obj = ray.put(v)
self._all_objects[k] = obj
self._refs.append(ray.get(obj))
_default_registry = _Registry()
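To illustrate the round trip through this registry: objects are recorded locally, flushed to the Ray object store by get_registry(), and resolved by (category, key) from worker processes. The sketch below uses illustrative names.

# Illustrative round trip through the shared object registry defined above.
import ray
from ray.tune import Trainable
from ray.tune.registry import TRAINABLE_CLASS, get_registry, register_trainable


class MyTrainable(Trainable):
    pass


register_trainable("my_trainable", MyTrainable)

ray.init()
# get_registry() calls ray.put() on every registered object and returns a
# registry copy whose get() resolves the stored ObjectIDs with ray.get().
registry = get_registry()
cls = registry.get(TRAINABLE_CLASS, "my_trainable")
print(cls)  # the MyTrainable class, recovered from the object store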

View file

@ -18,6 +18,9 @@ TrainingResult = namedtuple("TrainingResult", [
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
# (Optional) If training is finished.
"done",
# (Optional) Custom metadata to report for this iteration.
"info",

View file

@ -10,6 +10,8 @@ import threading
import traceback
from ray.rllib.agent import Agent
from ray.tune import TuneError
from ray.tune.result import TrainingResult
class StatusReporter(object):
@ -17,33 +19,30 @@ class StatusReporter(object):
def __init__(self):
self._latest_result = None
self._last_result = None
self._lock = threading.Lock()
self._error = None
self._done = False
def report(self, result):
def __call__(self, **kwargs):
"""Report updated training status.
Args:
result (TrainingResult): Latest training result status. You must
kwargs: Fields of the latest TrainingResult. You must
at least define `timesteps_total`, but probably want to report
some of the other metrics as well.
"""
with self._lock:
self._latest_result = result
def set_error(self, error):
"""Report an error.
Args:
error (obj): Error object or string.
"""
self._error = error
self._latest_result = self._last_result = TrainingResult(**kwargs)
def _get_and_clear_status(self):
if self._error:
raise Exception("Error running script: " + str(self._error))
raise TuneError("Error running trial: " + str(self._error))
if self._done and not self._latest_result:
if not self._last_result:
raise TuneError("Trial finished without reporting result!")
return self._last_result._replace(done=True)
with self._lock:
res = self._latest_result
self._latest_result = None
@ -61,7 +60,7 @@ DEFAULT_CONFIG = {
"script_entrypoint": "train",
# batch results to at least this granularity
"script_min_iter_time_s": 5,
"script_min_iter_time_s": 1,
}
@ -79,9 +78,35 @@ class _RunnerThread(threading.Thread):
try:
self._entrypoint(*self._entrypoint_args)
except Exception as e:
self._status_reporter.set_error(e)
self._status_reporter._error = e
print("Runner thread raised: {}".format(traceback.format_exc()))
raise e
finally:
self._status_reporter._done = True
def import_function(file_path, function_name):
# strong assumption here that we're in a new process
file_path = os.path.expanduser(file_path)
sys.path.insert(0, os.path.dirname(file_path))
if hasattr(importlib, "util"):
# Python 3.4+
spec = importlib.util.spec_from_file_location(
"external_file", file_path)
external_file = importlib.util.module_from_spec(spec)
spec.loader.exec_module(external_file)
elif hasattr(importlib, "machinery"):
# Python 3.3
from importlib.machinery import SourceFileLoader
external_file = SourceFileLoader(
"external_file", file_path).load_module()
else:
# Python 2.x
import imp
external_file = imp.load_source("external_file", file_path)
if not external_file:
raise TuneError("Unable to import file at {}".format(file_path))
return getattr(external_file, function_name)
class ScriptRunner(Agent):
@ -92,51 +117,41 @@ class ScriptRunner(Agent):
_allow_unknown_configs = True
def _init(self):
# strong assumption here that we're in a new process
file_path = os.path.expanduser(self.config["script_file_path"])
sys.path.insert(0, os.path.dirname(file_path))
if hasattr(importlib, "util"):
# Python 3.4+
spec = importlib.util.spec_from_file_location(
"external_file", file_path)
external_file = importlib.util.module_from_spec(spec)
spec.loader.exec_module(external_file)
elif hasattr(importlib, "machinery"):
# Python 3.3
from importlib.machinery import SourceFileLoader
external_file = SourceFileLoader(
"external_file", file_path).load_module()
else:
# Python 2.x
import imp
external_file = imp.load_source("external_file", file_path)
if not external_file:
raise Exception(
"Unable to import file at {}".format(
self.config["script_file_path"]))
entrypoint = getattr(external_file, self.config["script_entrypoint"])
entrypoint = self._trainable_func()
if not entrypoint:
entrypoint = import_function(
self.config["script_file_path"],
self.config["script_entrypoint"])
self._status_reporter = StatusReporter()
scrubbed_config = self.config.copy()
for k in self._default_config:
del scrubbed_config[k]
self._runner = _RunnerThread(
entrypoint, self.config, self._status_reporter)
entrypoint, scrubbed_config, self._status_reporter)
self._start_time = time.time()
self._last_reported_time = self._start_time
self._last_reported_timestep = 0
self._runner.start()
# Subclasses can override this to set the trainable func
# TODO(ekl) this isn't a very clean layering, we should refactor it
def _trainable_func(self):
return None
def train(self):
if not self._initialize_ok:
raise ValueError(
"Agent initialization failed, see previous errors")
poll_start = time.time()
now = time.time()
time.sleep(self.config["script_min_iter_time_s"])
result = self._status_reporter._get_and_clear_status()
while result is None or \
time.time() - poll_start < \
self.config["script_min_iter_time_s"]:
while result is None:
time.sleep(1)
result = self._status_reporter._get_and_clear_status()
now = time.time()
if result.timesteps_total is None:
raise TuneError("Must specify timesteps_total in result", result)
# Include the negative loss to use as a stopping condition
if result.mean_loss is not None:
@ -164,6 +179,5 @@ class ScriptRunner(Agent):
return result
def stop(self):
def _stop(self):
self._status_reporter._stop()
Agent.stop(self)

View file

@ -0,0 +1,57 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class Trainable(object):
"""Interface for trainable models, functions, etc.
Implementing this interface is required to use ray.tune's full
functionality, though you can also get away with supplying just a
`my_train(config, reporter)` function and calling:
register_trainable("my_func", train)
to register it for use with tune. The function will be automatically
converted to this interface (sans checkpoint functionality)."""
def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
"""
raise NotImplementedError
def save(self):
"""Saves the current model state to a checkpoint.
Returns:
Checkpoint path that may be passed to restore().
"""
raise NotImplementedError
def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
"""
raise NotImplementedError
def stop(self):
"""Releases all resources used by this class."""
pass
def wrap_function(train_func):
from ray.tune.script_runner import ScriptRunner
class WrappedFunc(ScriptRunner):
def _trainable_func(self):
return train_func
return WrappedFunc
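A minimal class-based Trainable might look like the sketch below. It is a standalone illustration: the checkpoint path, the fabricated accuracy curve, and the constructor keyword arguments (matching how trial.py constructs runners later in this diff) are illustrative only.

# Hedged sketch of a class-based Trainable implementing the interface above.
import pickle

from ray.tune import Trainable
from ray.tune.result import TrainingResult


class MyTrainable(Trainable):
    def __init__(self, config={}, registry=None, logger_creator=None):
        # Trial constructs runners remotely with these keyword arguments
        # (see trial.py below); they are unused in this toy sketch.
        self.config = config
        self.iteration = 0

    def train(self):
        # One logical iteration; the accuracy curve here is fabricated.
        self.iteration += 1
        return TrainingResult(
            timesteps_total=self.iteration,
            mean_accuracy=min(1.0, self.iteration / 100.0),
            done=self.iteration >= 100)

    def save(self):
        # Illustrative checkpoint location.
        path = "/tmp/my_trainable_checkpoint.pkl"
        with open(path, "wb") as f:
            pickle.dump(self.iteration, f)
        return path

    def restore(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            self.iteration = pickle.load(f)

# register_trainable("my_trainable", MyTrainable) would then make it
# runnable via {"run": "my_trainable"} in an experiment spec.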

View file

@ -8,8 +8,10 @@ import ray
import os
from collections import namedtuple
from ray.rllib.agent import get_agent_class
from ray.tune import TuneError
from ray.tune.logger import NoopLogger, UnifiedLogger
from ray.tune.result import TrainingResult
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
class Resources(
@ -60,7 +62,7 @@ class Trial(object):
ERROR = "ERROR"
def __init__(
self, env_creator, alg, config={}, local_dir='/tmp/ray',
self, trainable_name, config={}, local_dir='/tmp/ray',
experiment_tag=None, resources=Resources(cpu=1, gpu=0),
stopping_criterion={}, checkpoint_freq=0,
restore_path=None, upload_dir=None):
@ -70,16 +72,18 @@ class Trial(object):
in ray.tune.config_parser.
"""
if not _default_registry.contains(
TRAINABLE_CLASS, trainable_name):
raise TuneError("Unknown trainable: " + trainable_name)
for k in stopping_criterion:
if k not in TrainingResult._fields:
raise TuneError(
"Stopping condition key `{}` must be one of {}".format(
k, TrainingResult._fields))
# Immutable config
self.env_creator = env_creator
if type(env_creator) is str:
self.env_name = env_creator
else:
if hasattr(env_creator, "env_name"):
self.env_name = env_creator.env_name
else:
self.env_name = "custom"
self.alg = alg
self.trainable_name = trainable_name
self.config = config
self.local_dir = local_dir
self.experiment_tag = experiment_tag
@ -92,7 +96,7 @@ class Trial(object):
self.last_result = None
self._checkpoint_path = restore_path
self._checkpoint_obj = None
self.agent = None
self.runner = None
self.status = Trial.PENDING
self.location = None
self.logdir = None
@ -105,7 +109,7 @@ class Trial(object):
be thrown.
"""
self._setup_agent()
self._setup_runner()
if self._checkpoint_path:
self.restore_from_path(self._checkpoint_path)
elif self._checkpoint_obj:
@ -128,11 +132,11 @@ class Trial(object):
self.status = Trial.TERMINATED
try:
if self.agent:
if self.runner:
stop_tasks = []
stop_tasks.append(self.agent.stop.remote())
stop_tasks.append(self.agent.__ray_terminate__.remote(
self.agent._ray_actor_id.id()))
stop_tasks.append(self.runner.stop.remote())
stop_tasks.append(self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
@ -140,10 +144,10 @@ class Trial(object):
print(("Stopping %s Actor timed out, "
"but moving on...") % self)
except Exception:
print("Error stopping agent:", traceback.format_exc())
print("Error stopping runner:", traceback.format_exc())
self.status = Trial.ERROR
finally:
self.agent = None
self.runner = None
if stop_logger and self.result_logger:
self.result_logger.close()
@ -159,7 +163,7 @@ class Trial(object):
self.stop(stop_logger=False)
self.status = Trial.PAUSED
except Exception:
print("Error pausing agent:", traceback.format_exc())
print("Error pausing runner:", traceback.format_exc())
self.status = Trial.ERROR
def unpause(self):
@ -177,11 +181,14 @@ class Trial(object):
"""Returns Ray future for one iteration of training."""
assert self.status == Trial.RUNNING, self.status
return self.agent.train.remote()
return self.runner.train.remote()
def should_stop(self, result):
"""Whether the given result meets this trial's stopping criteria."""
if result.done:
return True
for criteria, stop_value in self.stopping_criterion.items():
if getattr(result, criteria) >= stop_value:
return True
@ -240,9 +247,9 @@ class Trial(object):
obj = None
path = None
if to_object_store:
obj = self.agent.save_to_object.remote()
obj = self.runner.save_to_object.remote()
else:
path = ray.get(self.agent.save.remote())
path = ray.get(self.runner.save.remote())
self._checkpoint_path = path
self._checkpoint_obj = obj
@ -250,39 +257,40 @@ class Trial(object):
return path or obj
def restore_from_path(self, path):
"""Restores agent state from specified path.
"""Restores runner state from specified path.
Args:
path (str): A path where state will be restored.
"""
if self.agent is None:
print("Unable to restore - no agent")
if self.runner is None:
print("Unable to restore - no runner")
else:
try:
ray.get(self.agent.restore.remote(path))
ray.get(self.runner.restore.remote(path))
except Exception:
print("Error restoring agent:", traceback.format_exc())
print("Error restoring runner:", traceback.format_exc())
self.status = Trial.ERROR
def restore_from_obj(self, obj):
"""Restores agent state from the specified object."""
"""Restores runner state from the specified object."""
if self.agent is None:
print("Unable to restore - no agent")
if self.runner is None:
print("Unable to restore - no runner")
else:
try:
ray.get(self.agent.restore_from_object.remote(obj))
ray.get(self.runner.restore_from_object.remote(obj))
except Exception:
print("Error restoring agent:", traceback.format_exc())
print("Error restoring runner:", traceback.format_exc())
self.status = Trial.ERROR
def _setup_agent(self):
def _setup_runner(self):
self.status = Trial.RUNNING
agent_cls = get_agent_class(self.alg)
trainable_cls = get_registry().get(
TRAINABLE_CLASS, self.trainable_name)
cls = ray.remote(
num_cpus=self.resources.driver_cpu_limit,
num_gpus=self.resources.driver_gpu_limit)(agent_cls)
num_gpus=self.resources.driver_gpu_limit)(trainable_cls)
if not self.result_logger:
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
@ -292,19 +300,18 @@ class Trial(object):
self.config, self.logdir, self.upload_dir)
remote_logdir = self.logdir
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote agent to use a noop-logger.
self.agent = cls.remote(
self.env_creator, self.config,
lambda config: NoopLogger(config, remote_logdir))
# configure the remote runner to use a noop-logger.
self.runner = cls.remote(
config=self.config,
registry=get_registry(),
logger_creator=lambda config: NoopLogger(config, remote_logdir))
def __str__(self):
identifier = '{}_{}'.format(self.alg, self.env_name)
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
else:
identifier = self.trainable_name
if self.experiment_tag:
identifier += '_' + self.experiment_tag
identifier += "_" + self.experiment_tag
return identifier
def __eq__(self, other):
return str(self) == str(other)
def __hash__(self):
return hash(str(self))
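Driving a Trial by hand, as the unit tests below do, looks roughly like this sketch; "__fake" is the mock agent registered when ray.rllib is imported, and the resource values are illustrative.

# Hedged sketch of the Trial lifecycle after the agent -> runner rename.
import ray
import ray.rllib  # importing rllib registers the built-in trainables, incl. "__fake"
from ray.tune.trial import Trial, Resources

ray.init(num_cpus=1)

trial = Trial(
    "__fake",
    config={},
    resources=Resources(cpu=1, gpu=0),
    stopping_criterion={"training_iteration": 1})

trial.start()                            # creates the remote runner actor
result = ray.get(trial.train_remote())   # one training iteration as a Ray future
print(trial.should_stop(result))         # True once the criterion is met
trial.stop()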

View file

@ -7,6 +7,7 @@ import ray
import time
import traceback
from ray.tune import TuneError
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
@ -27,7 +28,7 @@ class TrialRunner(object):
While Ray itself provides resource management for tasks and actors, this is
not sufficient when scheduling trials that may instantiate multiple actors.
This is because if insufficient resources are available, concurrent agents
This is because if insufficient resources are available, concurrent trials
could deadlock waiting for new resources to become available. Furthermore,
oversubscribing the cluster could degrade training performance, leading to
misleading benchmark results.
@ -77,13 +78,15 @@ class TrialRunner(object):
else:
for trial in self._trials:
if trial.status == Trial.PENDING:
assert self.has_resources(trial.resources), \
("Insufficient cluster resources to launch trial",
(trial.resources, self._avail_resources))
if not self.has_resources(trial.resources):
raise TuneError(
"Insufficient cluster resources to launch trial",
(trial.resources, self._avail_resources))
elif trial.status == Trial.PAUSED:
assert False, "There are paused trials, but no more "\
"pending trials with sufficient resources."
assert False, "Called step when all trials finished?"
raise TuneError(
"There are paused trials, but no more pending "
"trials with sufficient resources.")
raise TuneError("Called step when all trials finished?")
def get_trials(self):
"""Returns the list of trials managed by this TrialRunner.
@ -141,14 +144,14 @@ class TrialRunner(object):
trial.start()
self._running[trial.train_remote()] = trial
except Exception:
print("Error starting agent, retrying:", traceback.format_exc())
print("Error starting runner, retrying:", traceback.format_exc())
time.sleep(2)
trial.stop(error=True)
try:
trial.start()
self._running[trial.train_remote()] = trial
except Exception:
print("Error starting agent, abort:", traceback.format_exc())
print("Error starting runner, abort:", traceback.format_exc())
trial.stop(error=True)
# note that we don't return the resources, since they may
# have been lost

View file

@ -7,9 +7,10 @@ from __future__ import print_function
import argparse
import json
import sys
import yaml
import ray
from ray.tune import TuneError
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial
@ -44,24 +45,25 @@ parser.add_argument("-f", "--config-file", required=True, type=str,
help="Read experiment options from this JSON/YAML file.")
SCHEDULERS = {
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
"HyperBand": HyperBandScheduler,
}
def make_scheduler(args):
if args.scheduler in SCHEDULERS:
return SCHEDULERS[args.scheduler](**args.scheduler_config)
def _make_scheduler(args):
if args.scheduler in _SCHEDULERS:
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
else:
assert False, "Unknown scheduler: {}, should be one of {}".format(
args.scheduler, SCHEDULERS.keys())
raise TuneError(
"Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
def run_experiments(experiments, scheduler=None, **ray_args):
if scheduler is None:
scheduler = make_scheduler(args)
scheduler = FIFOScheduler()
runner = TrialRunner(scheduler)
for name, spec in experiments.items():
@ -77,16 +79,16 @@ def run_experiments(experiments, scheduler=None, **ray_args):
for trial in runner.get_trials():
if trial.status != Trial.TERMINATED:
print("Exit 1")
sys.exit(1)
raise TuneError("Trial did not complete", trial)
print("Exit 0")
return runner.get_trials()
if __name__ == "__main__":
import yaml
args = parser.parse_args(sys.argv[1:])
with open(args.config_file) as f:
experiments = yaml.load(f)
run_experiments(
experiments, make_scheduler(args), redis_address=args.redis_address,
experiments, _make_scheduler(args), redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)

View file

@ -5,6 +5,7 @@ import os
import random
import types
from ray.tune import TuneError
from ray.tune.trial import Trial
from ray.tune.config_parser import make_parser, json_to_resources
@ -20,6 +21,9 @@ def generate_trials(unresolved_spec, output_path=''):
output_path (str): Path where to store experiment outputs.
"""
if "run" not in unresolved_spec:
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
def to_argv(config):
argv = []
for k, v in config.items():
@ -34,15 +38,23 @@ def generate_trials(unresolved_spec, output_path=''):
i = 0
for _ in range(unresolved_spec.get("repeat", 1)):
for resolved_vars, spec in generate_variants(unresolved_spec):
args = parser.parse_args(to_argv(spec))
try:
# Special case the `env` param for RLlib by automatically
# moving it into the `config` section.
if "env" in spec:
spec["config"] = spec.get("config", {})
spec["config"]["env"] = spec["env"]
del spec["env"]
args = parser.parse_args(to_argv(spec))
except SystemExit:
raise TuneError("Error parsing args, see above message", spec)
if resolved_vars:
experiment_tag = "{}_{}".format(i, resolved_vars)
else:
experiment_tag = str(i)
i += 1
yield Trial(
env_creator=spec.get("env", lambda: None),
alg=spec.get("alg", "script"),
trainable_name=spec["run"],
config=spec.get("config", {}),
local_dir=os.path.join(args.local_dir, output_path),
experiment_tag=experiment_tag,
@ -105,8 +117,8 @@ _MAX_RESOLUTION_PASSES = 20
def _format_vars(resolved_vars):
out = []
for path, value in sorted(resolved_vars.items()):
if path[0] in ["alg", "env", "resources"]:
continue # these settings aren't usually search parameters
if path[0] in ["run", "env", "resources"]:
continue # TrialRunner already has these in the experiment_tag
pieces = []
last_string = True
for k in path[::-1]:
@ -229,9 +241,10 @@ def _try_resolve(v):
elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
# Grid search values
grid_values = v["grid_search"]
assert isinstance(grid_values, list), \
"Grid search expected list of values, got: {}".format(
grid_values)
if not isinstance(grid_values, list):
raise TuneError(
"Grid search expected list of values, got: {}".format(
grid_values))
return False, grid_values
return True, v

View file

@ -35,7 +35,7 @@ def _parse_results(res_path):
pass
res_dict = _flatten_dict(json.loads(line.strip()))
except Exception as e:
print("Importing %s failed...Perhaps empty?" % res_path)
print("Importing %s failed...Perhaps empty?" % res_path, e)
return res_dict
@ -60,7 +60,7 @@ def _resolve(directory, result_fname):
def load_results_to_df(directory, result_name="result.json"):
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
for f in files if f == result_name]
data = [_resolve(directory, result_name) for directory in exp_directories]
data = [_resolve(d, result_name) for d in exp_directories]
return pd.DataFrame(data)

View file

@ -59,83 +59,83 @@ python $ROOT_DIR/multi_node_docker_test.py \
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v0 \
--alg A3C \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 16}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v1 \
--alg PPO \
--run PPO \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"free_log_std": true}}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v1 \
--alg PPO \
--run PPO \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "use_gae": false}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env Pendulum-v0 \
--alg ES \
--run ES \
--stop '{"training_iteration": 2}' \
--config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env Pong-v0 \
--alg ES \
--run ES \
--stop '{"training_iteration": 2}' \
--config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v0 \
--alg A3C \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"use_lstm": false}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env CartPole-v0 \
--alg DQN \
--run DQN \
--stop '{"training_iteration": 2}' \
--config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env FrozenLake-v0 \
--alg DQN \
--run DQN \
--stop '{"training_iteration": 2}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env FrozenLake-v0 \
--alg PPO \
--run PPO \
--stop '{"training_iteration": 2}' \
--config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \
--alg DQN \
--run DQN \
--stop '{"training_iteration": 2}' \
--config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env MontezumaRevenge-v0 \
--alg PPO \
--run PPO \
--stop '{"training_iteration": 2}' \
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
--env PongDeterministic-v4 \
--alg A3C \
--run A3C \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'

View file

@ -2,38 +2,201 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import os
import time
import unittest
import ray
from ray.rllib import _register_all
from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
RecursiveDependencyError
class TrainableFunctionApiTest(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def testRegisterEnv(self):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
def testRegisterTrainable(self):
def train(config, reporter):
pass
class A(object):
pass
class B(Trainable):
pass
register_trainable("foo", train)
register_trainable("foo", B)
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"env": "CartPole-v0",
}})
self.assertEqual(trial.config["env"], "CartPole-v0")
def testConfigPurity(self):
def train(config, reporter):
assert config == {"a": "b"}, config
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"config": {"a": "b"},
}})
def testBadParams(self):
def f():
run_experiments({"foo": {}})
self.assertRaises(TuneError, f)
def testBadParams2(self):
def f():
run_experiments({"foo": {
"bah": "this param is not allowed",
}})
self.assertRaises(TuneError, f)
def testBadParams3(self):
def f():
run_experiments({"foo": {
"run": grid_search("invalid grid search"),
}})
self.assertRaises(TuneError, f)
def testBadParams4(self):
def f():
run_experiments({"foo": {
"run": "asdf",
}})
self.assertRaises(TuneError, f)
def testBadParams5(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"stop": {"asdf": 1}
}})
self.assertRaises(TuneError, f)
def testBadParams6(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"resources": {"asdf": 1}
}})
self.assertRaises(TuneError, f)
def testBadReturn(self):
def train(config, reporter):
reporter()
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertRaises(TuneError, f)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
time.sleep(99999)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testAbruptReturn(self):
def train(config, reporter):
reporter(timesteps_total=100)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testErrorReturn(self):
def train(config, reporter):
raise Exception("uh oh")
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertRaises(TuneError, f)
def testSuccess(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 99)
class VariantGeneratorTest(unittest.TestCase):
def testParseToTrials(self):
trials = generate_trials({
"env": "Pong-v0",
"alg": "PPO",
"run": "PPO",
"repeat": 2,
"config": {
"env": "Pong-v0",
"foo": "bar"
},
}, "tune-pong")
trials = list(trials)
self.assertEqual(len(trials), 2)
self.assertEqual(trials[0].env_name, "Pong-v0")
self.assertEqual(trials[0].config, {"foo": "bar"})
self.assertEqual(trials[0].alg, "PPO")
self.assertEqual(str(trials[0]), "PPO_Pong-v0_0")
self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
self.assertEqual(trials[0].trainable_name, "PPO")
self.assertEqual(trials[0].experiment_tag, "0")
self.assertEqual(trials[0].local_dir, "/tmp/ray/tune-pong")
self.assertEqual(trials[1].experiment_tag, "1")
def testEval(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"foo": {
"eval": "2 + 2"
@ -48,7 +211,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testGridSearch(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"bar": {
"grid_search": [True, False]
@ -71,7 +234,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testGridSearchAndEval(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"qux": lambda spec: 2 + 2,
"bar": grid_search([True, False]),
@ -85,7 +248,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testConditionResolution(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"x": 1,
"y": lambda spec: spec.config.x + 1,
@ -98,7 +261,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testDependentLambda(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"x": grid_search([1, 2]),
"y": lambda spec: spec.config.x * 100,
@ -111,7 +274,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testDependentGridSearch(self):
trials = generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"x": grid_search([
lambda spec: spec.config.y * 100,
@ -128,7 +291,7 @@ class VariantGeneratorTest(unittest.TestCase):
def testRecursiveDep(self):
try:
list(generate_trials({
"env": "Pong-v0",
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
@ -142,10 +305,11 @@ class VariantGeneratorTest(unittest.TestCase):
class TrialRunnerTest(unittest.TestCase):
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def testTrialStatus(self):
ray.init()
trial = Trial("CartPole-v0", "__fake")
trial = Trial("__fake")
self.assertEqual(trial.status, Trial.PENDING)
trial.start()
self.assertEqual(trial.status, Trial.RUNNING)
@ -156,11 +320,12 @@ class TrialRunnerTest(unittest.TestCase):
def testTrialErrorOnStart(self):
ray.init()
trial = Trial("CartPole-v0", "asdf")
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf")
try:
trial.start()
except Exception as e:
self.assertIn("Unknown algorithm", str(e))
self.assertIn("a class", str(e))
def testResourceScheduler(self):
ray.init(num_cpus=4, num_gpus=1)
@ -170,8 +335,8 @@ class TrialRunnerTest(unittest.TestCase):
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("CartPole-v0", "__fake", **kwargs),
Trial("CartPole-v0", "__fake", **kwargs)]
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -199,8 +364,8 @@ class TrialRunnerTest(unittest.TestCase):
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("CartPole-v0", "__fake", **kwargs),
Trial("CartPole-v0", "__fake", **kwargs)]
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -227,9 +392,10 @@ class TrialRunnerTest(unittest.TestCase):
"stopping_criterion": {"training_iteration": 1},
"resources": Resources(cpu=1, gpu=1),
}
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trials = [
Trial("CartPole-v0", "asdf", **kwargs),
Trial("CartPole-v0", "__fake", **kwargs)]
Trial("asdf", **kwargs),
Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@ -248,17 +414,17 @@ class TrialRunnerTest(unittest.TestCase):
"stopping_criterion": {"training_iteration": 1},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
path = trials[0].checkpoint()
kwargs["restore_path"] = path
runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
@ -268,7 +434,7 @@ class TrialRunnerTest(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)
self.assertEqual(trials[1].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[1].agent.get_info.remote()), 1)
self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
self.addCleanup(os.remove, path)
def testPauseThenResume(self):
@ -278,14 +444,14 @@ class TrialRunnerTest(unittest.TestCase):
"stopping_criterion": {"training_iteration": 2},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].agent.get_info.remote()), None)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None)
self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
trials[0].pause()
self.assertEqual(trials[0].status, Trial.PAUSED)
@ -295,7 +461,7 @@ class TrialRunnerTest(unittest.TestCase):
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
self.assertEqual(ray.get(trials[0].agent.get_info.remote()), 1)
self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1)
runner.step()
self.assertEqual(trials[0].status, Trial.TERMINATED)

View file

@ -20,8 +20,8 @@ def result(t, rew):
class EarlyStoppingSuite(unittest.TestCase):
def basicSetup(self, rule):
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result(i, i * 100)),
@ -62,7 +62,7 @@ class EarlyStoppingSuite(unittest.TestCase):
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("t3", "PPO")
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 10)),
TrialScheduler.CONTINUE)
@ -77,7 +77,7 @@ class EarlyStoppingSuite(unittest.TestCase):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
t3 = Trial("t3", "PPO")
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.CONTINUE)
@ -91,7 +91,7 @@ class EarlyStoppingSuite(unittest.TestCase):
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("t3", "PPO")
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.CONTINUE)
@ -105,7 +105,7 @@ class EarlyStoppingSuite(unittest.TestCase):
t1, t2 = self.basicSetup(rule)
rule.on_trial_complete(None, t1, result(10, 1000))
rule.on_trial_complete(None, t2, result(10, 1000))
t3 = Trial("t3", "PPO")
t3 = Trial("PPO")
self.assertEqual(
rule.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.CONTINUE)
@ -120,8 +120,8 @@ class EarlyStoppingSuite(unittest.TestCase):
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1,
time_attr='training_iteration', reward_attr='neg_mean_loss')
t1 = Trial("t1", "PPO") # mean is 450, max 900, t_max=10
t2 = Trial("t2", "PPO") # mean is 450, max 450, t_max=5
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
self.assertEqual(
rule.on_trial_result(None, t1, result2(i, i * 100)),
@ -166,7 +166,7 @@ class HyperbandSuite(unittest.TestCase):
(81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
sched = HyperBandScheduler()
for i in range(num_trials):
t = Trial("t%d" % i, "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()
return sched, runner
@ -211,7 +211,7 @@ class HyperbandSuite(unittest.TestCase):
def advancedSetup(self):
sched = self.basicSetup()
for i in range(4):
t = Trial("t%d" % (i + 20), "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
self.assertEqual(sched._cur_band_filled(), False)
@ -232,7 +232,7 @@ class HyperbandSuite(unittest.TestCase):
sched = HyperBandScheduler()
i = 0
while not sched._cur_band_filled():
t = Trial("t%d" % (i), "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
@ -244,7 +244,7 @@ class HyperbandSuite(unittest.TestCase):
sched = HyperBandScheduler(max_t=810)
i = 0
while not sched._cur_band_filled():
t = Trial("t%d" % (i), "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
@ -257,7 +257,7 @@ class HyperbandSuite(unittest.TestCase):
sched = HyperBandScheduler(max_t=1)
i = 0
while len(sched._hyperbands) < 2:
t = Trial("t%d" % (i), "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
i += 1
self.assertEqual(len(sched._hyperbands[0]), 5)
@ -415,7 +415,7 @@ class HyperbandSuite(unittest.TestCase):
status = sched.on_trial_result(
mock_runner, t, result(init_units, i))
self.assertEqual(status, TrialScheduler.CONTINUE)
t = Trial("t%d" % 100, "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
mock_runner._launch_trial(t)
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
@ -440,7 +440,7 @@ class HyperbandSuite(unittest.TestCase):
stats = self.default_statistics()
for i in range(stats["max_trials"]):
t = Trial("t%d" % i, "__fake")
t = Trial("__fake")
sched.on_trial_add(None, t)
runner = _MockTrialRunner()