[tune] Support user-defined trainable functions / classes / envs with a shared object registry (#1226)
This commit is contained in:
parent
9233e496cc
commit
316f9e2bb7
38 changed files with 739 additions and 299 deletions
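
In short, this change introduces a shared object registry so that experiments reference trainables (built-in agents, user-defined functions, or Trainable classes) and environments by name via the new ``run`` field. A minimal sketch of the intended user-facing flow, based on the APIs added below; the function name, metric values, and experiment name are illustrative, not from the diff:

from ray.tune import grid_search, register_trainable, run_experiments

# Hypothetical user-defined trainable function; registration wraps it into a
# Trainable class automatically (see wrap_function below).
def my_train(config, reporter):
    accuracy = 0.0
    for step in range(20):
        accuracy = min(1.0, accuracy + config["lr"])
        # Report progress using TrainingResult fields as keyword arguments.
        reporter(timesteps_total=step, mean_accuracy=accuracy)

register_trainable("my_train", my_train)

# run_experiments is expected to handle Ray initialization itself (the CLI
# passes redis_address/num_cpus/num_gpus through to it), so no ray.init() here.
run_experiments({
    "my_experiment": {
        "run": "my_train",                        # looked up in the registry
        "stop": {"mean_accuracy": 0.9},
        "config": {"lr": grid_search([0.05, 0.1])},
    },
})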
|
@ -19,10 +19,19 @@ import shlex
|
|||
# These lines added to enable Sphinx to work without installing Ray.
|
||||
import mock
|
||||
MOCK_MODULES = ["gym",
|
||||
"gym.spaces",
|
||||
"scipy",
|
||||
"scipy.signal",
|
||||
"tensorflow",
|
||||
"tensorflow.contrib",
|
||||
"tensorflow.contrib.layers",
|
||||
"tensorflow.contrib.slim",
|
||||
"tensorflow.contrib.rnn",
|
||||
"tensorflow.core",
|
||||
"tensorflow.core.util",
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.client",
|
||||
"tensorflow.python.util",
|
||||
"pyarrow",
|
||||
"pyarrow.plasma",
|
||||
"smart_open",
|
||||
|
|
|
@ -25,7 +25,7 @@ You can run the code with
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=A3C --config='{"num_workers": N}'
|
||||
python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}'
|
||||
|
||||
Reinforcement Learning
|
||||
----------------------
|
||||
|
|
|
@ -18,7 +18,7 @@ on the ``Humanoid-v1`` gym environment.
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
python/ray/rllib/train.py --env=Humanoid-v1 --alg=ES
|
||||
python/ray/rllib/train.py --env=Humanoid-v1 --run=ES
|
||||
|
||||
To train a policy on a cluster (e.g., using 900 workers), run the following.
|
||||
|
||||
|
@ -26,7 +26,7 @@ To train a policy on a cluster (e.g., using 900 workers), run the following.
|
|||
|
||||
python ray/python/ray/rllib/train.py \
|
||||
--env=Humanoid-v1 \
|
||||
--alg=ES \
|
||||
--run=ES \
|
||||
--redis-address=<redis-address> \
|
||||
--config='{"num_workers": 900, "episodes_per_batch": 10000, "timesteps_per_batch": 100000}'
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ Then you can run the example as follows.
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=PPO
|
||||
python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO
|
||||
|
||||
This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also
|
||||
try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment.
|
||||
|
|
|
@ -30,7 +30,7 @@ You can run training with
|
|||
|
||||
::
|
||||
|
||||
python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --config '{"timesteps_per_batch": 10000}'
|
||||
python ray/python/ray/rllib/train.py --env CartPole-v0 --run PPO --config '{"timesteps_per_batch": 10000}'
|
||||
|
||||
By default, the results will be logged to a subdirectory of ``/tmp/ray``.
|
||||
This subdirectory will contain a file ``config.json`` which contains the
|
||||
|
@ -51,7 +51,7 @@ The ``train.py`` script has a number of options you can show by running
|
|||
|
||||
The most important options are for choosing the environment
|
||||
with ``--env`` (any OpenAI gym environment including ones registered by the user
|
||||
can be used) and for choosing the algorithm with ``--alg``
|
||||
can be used) and for choosing the algorithm with ``--run``
|
||||
(available options are ``PPO``, ``A3C``, ``ES`` and ``DQN``). Each algorithm
|
||||
has specific hyperparameters that can be set with ``--config``, see the
|
||||
``DEFAULT_CONFIG`` variable in
|
||||
|
|
|
@ -8,7 +8,7 @@ You can run training with
|
|||
|
||||
::
|
||||
|
||||
python train.py --env CartPole-v0 --alg PPO
|
||||
python train.py --env CartPole-v0 --run PPO
|
||||
|
||||
The available algorithms are:
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.rllib import ppo, es, dqn, a3c
|
||||
from ray.rllib.agent import _MockAgent, _SigmoidFakeData
|
||||
|
||||
|
||||
def _register_all():
|
||||
register_trainable("PPO", ppo.PPOAgent)
|
||||
register_trainable("ES", es.ESAgent)
|
||||
register_trainable("DQN", dqn.DQNAgent)
|
||||
register_trainable("A3C", a3c.A3CAgent)
|
||||
register_trainable("__fake", _MockAgent)
|
||||
register_trainable("__sigmoid_fake_data", _SigmoidFakeData)
|
||||
|
||||
|
||||
_register_all()
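
Since the built-in agents are now registered under their algorithm names, a tune experiment can reference them through ``run`` just like a user-defined trainable, and a user-registered environment can be referenced through ``env``. A hedged sketch; the env creator, experiment name, and stopping values are illustrative:

from ray.tune import register_env, run_experiments

# Hypothetical environment creator; env creators in this version take no
# arguments and return a gym.Env.
def make_cartpole():
    import gym
    return gym.make("CartPole-v0")

register_env("my_cartpole", make_cartpole)

run_experiments({
    "cartpole_ppo": {
        "run": "PPO",                   # built-in agent, found via the registry
        "env": "my_cartpole",           # resolved through the env registry
        "stop": {"episode_reward_mean": 200},
        "resources": {"cpu": 1},
        "config": {"num_workers": 1},
    },
})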
|
|
@ -17,13 +17,15 @@ import uuid
|
|||
|
||||
import tensorflow as tf
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import ENV_CREATOR
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Agent(object):
|
||||
class Agent(Trainable):
|
||||
"""All RLlib agents extend this base class.
|
||||
|
||||
Agent objects retain internal model state between calls to train(), so
|
||||
|
@ -33,39 +35,40 @@ class Agent(object):
|
|||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Object registry.
|
||||
"""
|
||||
|
||||
_allow_unknown_configs = False
|
||||
_default_logdir = "/tmp/ray"
|
||||
|
||||
def __init__(
|
||||
self, env_creator, config, logger_creator=None):
|
||||
self, config={}, env=None, registry=None, logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
env_creator (str|func): Name of the OpenAI gym environment to train
|
||||
against, or a function that creates such an env.
|
||||
config (dict): Algorithm-specific configuration data.
|
||||
env (str): Name of the environment to use. Note that this can also
|
||||
be specified as the `env` key in config.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, it will be assumed empty.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
if type(env_creator) is str:
|
||||
import gym
|
||||
env_name = env_creator
|
||||
self.env_creator = lambda: gym.make(env_name)
|
||||
env = env or config.get("env")
|
||||
if env:
|
||||
config["env"] = env
|
||||
if registry and registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
if hasattr(env_creator, "env_name"):
|
||||
env_name = env_creator.env_name
|
||||
else:
|
||||
env_name = "custom"
|
||||
self.env_creator = env_creator
|
||||
|
||||
import gym
|
||||
self.env_creator = lambda: gym.make(env)
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
if not self._allow_unknown_configs:
|
||||
for k in config.keys():
|
||||
if k not in self.config:
|
||||
if k not in self.config and k != "env":
|
||||
raise Exception(
|
||||
"Unknown agent config `{}`, "
|
||||
"all agent configs: {}".format(k, self.config.keys()))
|
||||
|
@ -76,8 +79,7 @@ class Agent(object):
|
|||
self.logdir = self._result_logger.logdir
|
||||
else:
|
||||
logdir_suffix = "{}_{}_{}".format(
|
||||
env_name,
|
||||
self._agent_name,
|
||||
env, self._agent_name,
|
||||
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
if not os.path.exists(self._default_logdir):
|
||||
os.makedirs(self._default_logdir)
|
||||
|
@ -214,7 +216,14 @@ class Agent(object):
|
|||
def stop(self):
|
||||
"""Releases all resources used by this agent."""
|
||||
|
||||
self._result_logger.close()
|
||||
if self._initialize_ok:
|
||||
self._result_logger.close()
|
||||
self._stop()
|
||||
|
||||
def _stop(self):
|
||||
"""Subclasses should override this for custom stopping."""
|
||||
|
||||
pass
|
||||
|
||||
def compute_action(self, observation):
|
||||
"""Computes an action using the current trained policy."""
|
||||
|
@ -336,5 +345,4 @@ def get_agent_class(alg):
|
|||
return _SigmoidFakeData
|
||||
else:
|
||||
raise Exception(
|
||||
("Unknown algorithm {}, check --alg argument. Valid choices " +
|
||||
"are PPO, ES, DQN, and A3C.").format(alg))
|
||||
("Unknown algorithm {}.").format(alg))
|
||||
|
|
|
@ -355,7 +355,7 @@ class DQNAgent(Agent):
|
|||
_agent_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
|
||||
def stop(self):
|
||||
def _stop(self):
|
||||
for w in self.workers:
|
||||
w.stop.remote()
|
||||
|
||||
|
|
|
@ -29,8 +29,8 @@ CONFIGS = {
|
|||
|
||||
def test(use_object_store, alg_name):
|
||||
cls = get_agent_class(alg_name)
|
||||
alg1 = cls("CartPole-v0", CONFIGS[name])
|
||||
alg2 = cls("CartPole-v0", CONFIGS[name])
|
||||
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
|
||||
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")
|
||||
|
||||
for _ in range(3):
|
||||
res = alg1.train()
|
||||
|
|
|
@ -9,12 +9,12 @@ import sys
|
|||
import yaml
|
||||
|
||||
from ray.tune.config_parser import make_parser, resources_to_json
|
||||
from ray.tune.tune import make_scheduler, run_experiments
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Training example:
|
||||
./train.py --alg DQN --env CartPole-v0
|
||||
./train.py --run DQN --env CartPole-v0
|
||||
|
||||
Grid search example:
|
||||
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
|
||||
|
@ -29,16 +29,24 @@ parser = make_parser(
|
|||
epilog=EXAMPLE_USAGE)
|
||||
|
||||
# See also the base parser definition in ray/tune/config_parser.py
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--num-cpus", default=None, type=int,
|
||||
help="Number of CPUs to allocate to Ray.")
|
||||
parser.add_argument("--num-gpus", default=None, type=int,
|
||||
help="Number of GPUs to allocate to Ray.")
|
||||
parser.add_argument("--experiment-name", default="default", type=str,
|
||||
help="Name of experiment dir.")
|
||||
parser.add_argument("-f", "--config-file", default=None, type=str,
|
||||
help="If specified, use config options from this file.")
|
||||
parser.add_argument(
|
||||
"--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument(
|
||||
"--num-cpus", default=None, type=int,
|
||||
help="Number of CPUs to allocate to Ray.")
|
||||
parser.add_argument(
|
||||
"--num-gpus", default=None, type=int,
|
||||
help="Number of GPUs to allocate to Ray.")
|
||||
parser.add_argument(
|
||||
"--experiment-name", default="default", type=str,
|
||||
help="Name of the subdirectory under `local_dir` to put results in.")
|
||||
parser.add_argument(
|
||||
"--env", default=None, type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
"-f", "--config-file", default=None, type=str,
|
||||
help="If specified, use config options from this file. Note that this "
|
||||
"overrides any trial-specific options set via flags above.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -50,13 +58,12 @@ if __name__ == "__main__":
|
|||
# Note: keep this in sync with tune/config_parser.py
|
||||
experiments = {
|
||||
args.experiment_name: { # i.e. log to /tmp/ray/default
|
||||
"alg": args.alg,
|
||||
"run": args.run,
|
||||
"checkpoint_freq": args.checkpoint_freq,
|
||||
"local_dir": args.local_dir,
|
||||
"env": args.env,
|
||||
"resources": resources_to_json(args.resources),
|
||||
"stop": args.stop,
|
||||
"config": args.config,
|
||||
"config": dict(args.config, env=args.env),
|
||||
"restore": args.restore,
|
||||
"repeat": args.repeat,
|
||||
"upload_dir": args.upload_dir,
|
||||
|
@ -64,12 +71,12 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
for exp in experiments.values():
|
||||
if not exp.get("alg"):
|
||||
parser.error("the following arguments are required: --alg")
|
||||
if not exp.get("env"):
|
||||
if not exp.get("run"):
|
||||
parser.error("the following arguments are required: --run")
|
||||
if not exp.get("env") and not exp.get("config", {}).get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
|
||||
run_experiments(
|
||||
experiments, scheduler=make_scheduler(args),
|
||||
experiments, scheduler=_make_scheduler(args),
|
||||
redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
run: PPO
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 180
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
hopper-ppo:
|
||||
env: Hopper-v1
|
||||
alg: PPO
|
||||
run: PPO
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
humanoid-es:
|
||||
env: Humanoid-v1
|
||||
alg: ES
|
||||
run: ES
|
||||
resources:
|
||||
cpu: 100
|
||||
driver_cpu_limit: 4
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
humanoid-ppo-gae:
|
||||
env: Humanoid-v1
|
||||
alg: PPO
|
||||
run: PPO
|
||||
stop:
|
||||
episode_reward_mean: 6000
|
||||
resources:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
humanoid-ppo:
|
||||
env: Humanoid-v1
|
||||
alg: PPO
|
||||
run: PPO
|
||||
stop:
|
||||
episode_reward_mean: 6000
|
||||
resources:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
run: PPO
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
pong-a3c:
|
||||
env: PongDeterministic-v4
|
||||
alg: A3C
|
||||
run: A3C
|
||||
resources:
|
||||
cpu: 16
|
||||
driver_cpu_limit: 1
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
pong-deterministic-dqn:
|
||||
env: PongDeterministic-v4
|
||||
alg: DQN
|
||||
run: DQN
|
||||
resources:
|
||||
cpu: 1
|
||||
gpu: 1
|
||||
|
@ -28,7 +28,7 @@ pong-deterministic-dqn:
|
|||
]
|
||||
pong-noframeskip-dqn:
|
||||
env: PongNoFrameskip-v4
|
||||
alg: DQN
|
||||
run: DQN
|
||||
resources:
|
||||
cpu: 1
|
||||
gpu: 1
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
walker2d-v1-ppo:
|
||||
env: Walker2d-v1
|
||||
alg: PPO
|
||||
run: PPO
|
||||
resources:
|
||||
cpu: 64
|
||||
gpu: 4
|
||||
|
|
|
@ -86,7 +86,7 @@ expression.
|
|||
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
alg: PPO
|
||||
run: PPO
|
||||
repeat: 2
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
|
@ -119,7 +119,7 @@ When using the Python API, the above is equivalent to the following program:
|
|||
|
||||
spec = {
|
||||
"env": "CartPole-v0",
|
||||
"alg": "PPO",
|
||||
"run": "PPO",
|
||||
"repeat": 2,
|
||||
"stop": {
|
||||
"episode_reward_mean": 200,
|
||||
|
@ -166,9 +166,9 @@ Using ray.tune with Ray RLlib
|
|||
|
||||
Another way to use ray.tune is through RLlib's ``python/ray/rllib/train.py``
|
||||
script. This script allows you to select between different RL algorithms with
|
||||
the ``--alg`` option. For example, to train pong with the A3C algorithm, run:
|
||||
the ``--run`` option. For example, to train pong with the A3C algorithm, run:
|
||||
|
||||
- ``./train.py --env=PongDeterministic-v4 --alg=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'``
|
||||
- ``./train.py --env=PongDeterministic-v4 --run=A3C --stop '{"time_total_s": 3200}' --resources '{"cpu": 8}' --config '{"num_workers": 8}'``
|
||||
|
||||
or
|
||||
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.tune import run_experiments
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
|
||||
|
||||
register_trainable("script", ScriptRunner)
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TrainingResult",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run_experiments",
|
||||
]
|
|
@ -6,12 +6,18 @@ from __future__ import print_function
|
|||
import argparse
|
||||
import json
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
|
||||
def json_to_resources(data):
|
||||
if type(data) is str:
|
||||
data = json.loads(data)
|
||||
for k in data:
|
||||
if k not in Resources._fields:
|
||||
raise TuneError(
|
||||
"Unknown resource type {}, must be one of {}".format(
|
||||
k, Resources._fields))
|
||||
return Resources(
|
||||
data.get("cpu", 0), data.get("gpu", 0),
|
||||
data.get("driver_cpu_limit"), data.get("driver_gpu_limit"))
|
||||
|
@ -32,34 +38,49 @@ def make_parser(**kwargs):
|
|||
parser = argparse.ArgumentParser(**kwargs)
|
||||
|
||||
# Note: keep this in sync with rllib/train.py
|
||||
parser.add_argument("--alg", default=None, type=str,
|
||||
help="The learning algorithm to train.")
|
||||
parser.add_argument("--stop", default="{}", type=json.loads,
|
||||
help="The stopping criteria, specified in JSON.")
|
||||
parser.add_argument("--config", default="{}", type=json.loads,
|
||||
help="The config of the algorithm, specified in JSON.")
|
||||
parser.add_argument("--resources", default='{"cpu": 1}',
|
||||
type=json_to_resources,
|
||||
help="Amount of resources to allocate per trial.")
|
||||
parser.add_argument("--repeat", default=1, type=int,
|
||||
help="Number of times to repeat each trial.")
|
||||
parser.add_argument("--local-dir", default="/tmp/ray", type=str,
|
||||
help="Local dir to save training results to.")
|
||||
parser.add_argument("--upload-dir", default="", type=str,
|
||||
help="URI to upload training results to.")
|
||||
parser.add_argument("--checkpoint-freq", default=0, type=int,
|
||||
help="How many iterations between checkpoints.")
|
||||
parser.add_argument("--scheduler", default="FIFO", type=str,
|
||||
help="FIFO, MedianStopping, or HyperBand")
|
||||
parser.add_argument("--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
parser.add_argument(
|
||||
"--run", default=None, type=str,
|
||||
help="The algorithm or model to train. This may refer to the name "
|
||||
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
|
||||
"user-defined trainable function or class registered in the "
|
||||
"tune registry.")
|
||||
parser.add_argument(
|
||||
"--stop", default="{}", type=json.loads,
|
||||
help="The stopping criteria, specified in JSON. The keys may be any "
|
||||
"field in TrainingResult, e.g. "
|
||||
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
|
||||
"after 600 seconds or 100k timesteps, whichever is reached first.")
|
||||
parser.add_argument(
|
||||
"--config", default="{}", type=json.loads,
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams), "
|
||||
"specified in JSON.")
|
||||
parser.add_argument(
|
||||
"--resources", default='{"cpu": 1}', type=json_to_resources,
|
||||
help="Machine resources to allocate per trial, e.g. "
|
||||
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
|
||||
"unless you specify them here.")
|
||||
parser.add_argument(
|
||||
"--repeat", default=1, type=int,
|
||||
help="Number of times to repeat each trial.")
|
||||
parser.add_argument(
|
||||
"--local-dir", default="/tmp/ray", type=str,
|
||||
help="Local dir to save training results to. Defaults to '/tmp/ray'.")
|
||||
parser.add_argument(
|
||||
"--upload-dir", default="", type=str,
|
||||
help="Optional URI to upload training results to.")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq", default=0, type=int,
|
||||
help="How many training iterations between checkpoints. "
|
||||
"A value of 0 (default) disables checkpointing.")
|
||||
parser.add_argument(
|
||||
"--scheduler", default="FIFO", type=str,
|
||||
help="FIFO (default), MedianStopping, or HyperBand.")
|
||||
parser.add_argument(
|
||||
"--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
|
||||
# Note: this currently only makes sense when running a single trial
|
||||
parser.add_argument("--restore", default=None, type=str,
|
||||
help="If specified, restore from this checkpoint.")
|
||||
|
||||
# TODO(ekl) environments are RL specific
|
||||
parser.add_argument("--env", default=None, type=str,
|
||||
help="The gym environment to use.")
|
||||
|
||||
return parser
|
||||
|
|
8
python/ray/tune/error.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class TuneError(Exception):
|
||||
"""General error class raised by ray.tune."""
|
||||
pass
|
|
@ -31,12 +31,9 @@ from __future__ import print_function
|
|||
import argparse
|
||||
import sys
|
||||
import tempfile
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.variant_generator import grid_search, generate_trials
|
||||
from ray.tune import grid_search, run_experiments, register_trainable
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
|
@ -135,7 +132,12 @@ def bias_variable(shape):
|
|||
|
||||
def main(_):
|
||||
# Import data
|
||||
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
|
||||
for _ in range(10):
|
||||
try:
|
||||
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(5)
|
||||
|
||||
# Create the model
|
||||
x = tf.placeholder(tf.float32, [None, 784])
|
||||
|
@ -174,9 +176,8 @@ def main(_):
|
|||
|
||||
# !!! Report status to ray.tune !!!
|
||||
if status_reporter:
|
||||
status_reporter.report(TrainingResult(
|
||||
timesteps_total=i,
|
||||
mean_accuracy=train_accuracy))
|
||||
status_reporter(
|
||||
timesteps_total=i, mean_accuracy=train_accuracy)
|
||||
|
||||
print('step %d, training accuracy %g' % (i, train_accuracy))
|
||||
train_step.run(
|
||||
|
@ -201,34 +202,24 @@ def train(config={'activation': 'relu'}, reporter=None):
|
|||
|
||||
# !!! Example of using the ray.tune Python API !!!
|
||||
if __name__ == '__main__':
|
||||
runner = TrialRunner()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--fast', action='store_true', help='Finish quickly for testing')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
spec = {
|
||||
register_trainable('train_mnist', train)
|
||||
mnist_spec = {
|
||||
'run': 'train_mnist',
|
||||
'stop': {
|
||||
'mean_accuracy': 0.99,
|
||||
'time_total_s': 600,
|
||||
},
|
||||
'config': {
|
||||
'script_file_path': os.path.abspath(__file__),
|
||||
'script_min_iter_time_s': 1,
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
},
|
||||
}
|
||||
|
||||
# These arguments are only for testing purposes.
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--fast', action='store_true',
|
||||
help='Run minimal iterations.')
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.fast:
|
||||
spec['stop']['training_iteration'] = 2
|
||||
mnist_spec['stop']['training_iteration'] = 2
|
||||
|
||||
for trial in generate_trials(spec):
|
||||
runner.add_trial(trial)
|
||||
|
||||
ray.init()
|
||||
|
||||
while not runner.is_finished():
|
||||
runner.step()
|
||||
print(runner.debug_string())
|
||||
run_experiments({'tune_mnist_test': mnist_spec})
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
tune_mnist:
|
||||
run: script
|
||||
repeat: 2
|
||||
resources:
|
||||
cpu: 1
|
||||
|
|
87
python/ray/tune/registry.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.local_scheduler import ObjectID
|
||||
from ray.tune.trainable import Trainable, wrap_function
|
||||
|
||||
TRAINABLE_CLASS = "trainable_class"
|
||||
ENV_CREATOR = "env_creator"
|
||||
KNOWN_CATEGORIES = [TRAINABLE_CLASS, ENV_CREATOR]
|
||||
|
||||
|
||||
def register_trainable(name, trainable):
|
||||
"""Register a trainable function or class.
|
||||
|
||||
Args:
|
||||
name (str): Name to register.
|
||||
trainable (obj): Function or tune.Trainable class. Functions must
|
||||
take (config, status_reporter) as arguments and will be
|
||||
automatically converted into a class during registration.
|
||||
"""
|
||||
|
||||
if isinstance(trainable, FunctionType):
|
||||
trainable = wrap_function(trainable)
|
||||
if not issubclass(trainable, Trainable):
|
||||
raise TypeError(
|
||||
"Second argument must be convertable to Trainable", trainable)
|
||||
_default_registry.register(TRAINABLE_CLASS, name, trainable)
|
||||
|
||||
|
||||
def register_env(name, env_creator):
|
||||
"""Register a custom environment for use with RLlib.
|
||||
|
||||
Args:
|
||||
name (str): Name to register.
|
||||
env_creator (obj): Function that creates an env.
|
||||
"""
|
||||
|
||||
if not isinstance(env_creator, FunctionType):
|
||||
raise TypeError(
|
||||
"Second argument must be a function.", env_creator)
|
||||
_default_registry.register(ENV_CREATOR, name, env_creator)
|
||||
|
||||
|
||||
def get_registry():
|
||||
"""Use this to access the registry. This requires ray to be initialized."""
|
||||
|
||||
_default_registry.flush_values_to_object_store()
|
||||
|
||||
# returns a registry copy that doesn't include the hard refs
|
||||
return _Registry(_default_registry._all_objects)
|
||||
|
||||
|
||||
class _Registry(object):
|
||||
def __init__(self, objs={}):
|
||||
self._all_objects = objs
|
||||
self._refs = [] # hard refs that prevent eviction of objects
|
||||
|
||||
def register(self, category, key, value):
|
||||
if category not in KNOWN_CATEGORIES:
|
||||
raise TuneError("Unknown category {} not among {}".format(
|
||||
category, KNOWN_CATEGORIES))
|
||||
self._all_objects[(category, key)] = value
|
||||
|
||||
def contains(self, category, key):
|
||||
return (category, key) in self._all_objects
|
||||
|
||||
def get(self, category, key):
|
||||
value = self._all_objects[(category, key)]
|
||||
if type(value) == ObjectID:
|
||||
return ray.get(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def flush_values_to_object_store(self):
|
||||
for k, v in self._all_objects.items():
|
||||
if type(v) != ObjectID:
|
||||
obj = ray.put(v)
|
||||
self._all_objects[k] = obj
|
||||
self._refs.append(ray.get(obj))
|
||||
|
||||
|
||||
_default_registry = _Registry()
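
The registry itself is a thin mapping from (category, key) to values; get_registry() flushes locally registered objects into the Ray object store so that remote trainables can look them up. A small sketch of how an env creator round-trips through it (the env name is illustrative):

import ray
from ray.tune.registry import ENV_CREATOR, get_registry, register_env

ray.init()

def make_cartpole():
    import gym
    return gym.make("CartPole-v0")

register_env("my_cartpole", make_cartpole)

# get_registry() pushes the registered objects to the object store and
# returns a copy that can be shipped to remote workers.
registry = get_registry()
assert registry.contains(ENV_CREATOR, "my_cartpole")
env = registry.get(ENV_CREATOR, "my_cartpole")()
print(env.action_space)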
|
|
@ -18,6 +18,9 @@ TrainingResult = namedtuple("TrainingResult", [
|
|||
# (Required) Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
|
||||
# (Optional) If training is finished.
|
||||
"done",
|
||||
|
||||
# (Optional) Custom metadata to report for this iteration.
|
||||
"info",
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@ import threading
|
|||
import traceback
|
||||
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
class StatusReporter(object):
|
||||
|
@ -17,33 +19,30 @@ class StatusReporter(object):
|
|||
|
||||
def __init__(self):
|
||||
self._latest_result = None
|
||||
self._last_result = None
|
||||
self._lock = threading.Lock()
|
||||
self._error = None
|
||||
self._done = False
|
||||
|
||||
def report(self, result):
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
|
||||
Args:
|
||||
result (TrainingResult): Latest training result status. You must
|
||||
kwargs (TrainingResult): Latest training result status. You must
|
||||
at least define `timesteps_total`, but probably want to report
|
||||
some of the other metrics as well.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
self._latest_result = result
|
||||
|
||||
def set_error(self, error):
|
||||
"""Report an error.
|
||||
|
||||
Args:
|
||||
error (obj): Error object or string.
|
||||
"""
|
||||
|
||||
self._error = error
|
||||
self._latest_result = self._last_result = TrainingResult(**kwargs)
|
||||
|
||||
def _get_and_clear_status(self):
|
||||
if self._error:
|
||||
raise Exception("Error running script: " + str(self._error))
|
||||
raise TuneError("Error running trial: " + str(self._error))
|
||||
if self._done and not self._latest_result:
|
||||
if not self._last_result:
|
||||
raise TuneError("Trial finished without reporting result!")
|
||||
return self._last_result._replace(done=True)
|
||||
with self._lock:
|
||||
res = self._latest_result
|
||||
self._latest_result = None
|
||||
|
@ -61,7 +60,7 @@ DEFAULT_CONFIG = {
|
|||
"script_entrypoint": "train",
|
||||
|
||||
# batch results to at least this granularity
|
||||
"script_min_iter_time_s": 5,
|
||||
"script_min_iter_time_s": 1,
|
||||
}
|
||||
|
||||
|
||||
|
@ -79,9 +78,35 @@ class _RunnerThread(threading.Thread):
|
|||
try:
|
||||
self._entrypoint(*self._entrypoint_args)
|
||||
except Exception as e:
|
||||
self._status_reporter.set_error(e)
|
||||
self._status_reporter._error = e
|
||||
print("Runner thread raised: {}".format(traceback.format_exc()))
|
||||
raise e
|
||||
finally:
|
||||
self._status_reporter._done = True
|
||||
|
||||
|
||||
def import_function(file_path, function_name):
|
||||
# strong assumption here that we're in a new process
|
||||
file_path = os.path.expanduser(file_path)
|
||||
sys.path.insert(0, os.path.dirname(file_path))
|
||||
if hasattr(importlib, "util"):
|
||||
# Python 3.4+
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"external_file", file_path)
|
||||
external_file = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(external_file)
|
||||
elif hasattr(importlib, "machinery"):
|
||||
# Python 3.3
|
||||
from importlib.machinery import SourceFileLoader
|
||||
external_file = SourceFileLoader(
|
||||
"external_file", file_path).load_module()
|
||||
else:
|
||||
# Python 2.x
|
||||
import imp
|
||||
external_file = imp.load_source("external_file", file_path)
|
||||
if not external_file:
|
||||
raise TuneError("Unable to import file at {}".format(file_path))
|
||||
return getattr(external_file, function_name)
|
||||
|
||||
|
||||
class ScriptRunner(Agent):
|
||||
|
@ -92,51 +117,41 @@ class ScriptRunner(Agent):
|
|||
_allow_unknown_configs = True
|
||||
|
||||
def _init(self):
|
||||
# strong assumption here that we're in a new process
|
||||
file_path = os.path.expanduser(self.config["script_file_path"])
|
||||
sys.path.insert(0, os.path.dirname(file_path))
|
||||
if hasattr(importlib, "util"):
|
||||
# Python 3.4+
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"external_file", file_path)
|
||||
external_file = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(external_file)
|
||||
elif hasattr(importlib, "machinery"):
|
||||
# Python 3.3
|
||||
from importlib.machinery import SourceFileLoader
|
||||
external_file = SourceFileLoader(
|
||||
"external_file", file_path).load_module()
|
||||
else:
|
||||
# Python 2.x
|
||||
import imp
|
||||
external_file = imp.load_source("external_file", file_path)
|
||||
if not external_file:
|
||||
raise Exception(
|
||||
"Unable to import file at {}".format(
|
||||
self.config["script_file_path"]))
|
||||
entrypoint = getattr(external_file, self.config["script_entrypoint"])
|
||||
entrypoint = self._trainable_func()
|
||||
if not entrypoint:
|
||||
entrypoint = import_function(
|
||||
self.config["script_file_path"],
|
||||
self.config["script_entrypoint"])
|
||||
self._status_reporter = StatusReporter()
|
||||
scrubbed_config = self.config.copy()
|
||||
for k in self._default_config:
|
||||
del scrubbed_config[k]
|
||||
self._runner = _RunnerThread(
|
||||
entrypoint, self.config, self._status_reporter)
|
||||
entrypoint, scrubbed_config, self._status_reporter)
|
||||
self._start_time = time.time()
|
||||
self._last_reported_time = self._start_time
|
||||
self._last_reported_timestep = 0
|
||||
self._runner.start()
|
||||
|
||||
# Subclasses can override this to set the trainable func
|
||||
# TODO(ekl) this isn't a very clean layering, we should refactor it
|
||||
def _trainable_func(self):
|
||||
return None
|
||||
|
||||
def train(self):
|
||||
if not self._initialize_ok:
|
||||
raise ValueError(
|
||||
"Agent initialization failed, see previous errors")
|
||||
|
||||
poll_start = time.time()
|
||||
now = time.time()
|
||||
time.sleep(self.config["script_min_iter_time_s"])
|
||||
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None or \
|
||||
time.time() - poll_start < \
|
||||
self.config["script_min_iter_time_s"]:
|
||||
while result is None:
|
||||
time.sleep(1)
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
|
||||
now = time.time()
|
||||
if result.timesteps_total is None:
|
||||
raise TuneError("Must specify timesteps_total in result", result)
|
||||
|
||||
# Include the negative loss to use as a stopping condition
|
||||
if result.mean_loss is not None:
|
||||
|
@ -164,6 +179,5 @@ class ScriptRunner(Agent):
|
|||
|
||||
return result
|
||||
|
||||
def stop(self):
|
||||
def _stop(self):
|
||||
self._status_reporter._stop()
|
||||
Agent.stop(self)
|
||||
|
|
57
python/ray/tune/trainable.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class Trainable(object):
|
||||
"""Interface for trainable models, functions, etc.
|
||||
|
||||
Implementing this interface is required to use ray.tune's full
|
||||
functionality, though you can also get away with supplying just a
|
||||
`my_train(config, reporter)` function and calling:
|
||||
|
||||
register_trainable("my_func", train)
|
||||
|
||||
to register it for use with tune. The function will be automatically
|
||||
converted to this interface (sans checkpoint functionality)."""
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
Returns:
|
||||
A TrainingResult that describes training progress.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
"""Saves the current model state to a checkpoint.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restores training state from a given model checkpoint.
|
||||
|
||||
These checkpoints are returned from calls to save().
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
"""Releases all resources used by this class."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
from ray.tune.script_runner import ScriptRunner
|
||||
|
||||
class WrappedFunc(ScriptRunner):
|
||||
def _trainable_func(self):
|
||||
return train_func
|
||||
|
||||
return WrappedFunc
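
Registering a Trainable subclass (rather than a bare function) additionally enables checkpointing through save()/restore(). A hedged sketch of a minimal subclass, assuming that unspecified TrainingResult fields default to None (as the keyword-style reporter above implies) and using an illustrative JSON checkpoint format:

import json

from ray.tune import Trainable, register_trainable
from ray.tune.result import TrainingResult

class MyModel(Trainable):
    # The keyword arguments mirror how TrialRunner constructs trainables
    # in this change (config, registry, logger_creator).
    def __init__(self, config={}, registry=None, logger_creator=None):
        self.iteration = 0

    def train(self):
        self.iteration += 1
        return TrainingResult(
            timesteps_total=self.iteration,
            mean_accuracy=min(1.0, 0.1 * self.iteration))

    def save(self):
        path = "/tmp/my_model_checkpoint.json"
        with open(path, "w") as f:
            json.dump({"iteration": self.iteration}, f)
        return path

    def restore(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self.iteration = json.load(f)["iteration"]

register_trainable("my_model", MyModel)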
|
|
@ -8,8 +8,10 @@ import ray
|
|||
import os
|
||||
|
||||
from collections import namedtuple
|
||||
from ray.rllib.agent import get_agent_class
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import NoopLogger, UnifiedLogger
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS
|
||||
|
||||
|
||||
class Resources(
|
||||
|
@ -60,7 +62,7 @@ class Trial(object):
|
|||
ERROR = "ERROR"
|
||||
|
||||
def __init__(
|
||||
self, env_creator, alg, config={}, local_dir='/tmp/ray',
|
||||
self, trainable_name, config={}, local_dir='/tmp/ray',
|
||||
experiment_tag=None, resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion={}, checkpoint_freq=0,
|
||||
restore_path=None, upload_dir=None):
|
||||
|
@ -70,16 +72,18 @@ class Trial(object):
|
|||
in ray.tune.config_parser.
|
||||
"""
|
||||
|
||||
if not _default_registry.contains(
|
||||
TRAINABLE_CLASS, trainable_name):
|
||||
raise TuneError("Unknown trainable: " + trainable_name)
|
||||
|
||||
for k in stopping_criterion:
|
||||
if k not in TrainingResult._fields:
|
||||
raise TuneError(
|
||||
"Stopping condition key `{}` must be one of {}".format(
|
||||
k, TrainingResult._fields))
|
||||
|
||||
# Immutable config
|
||||
self.env_creator = env_creator
|
||||
if type(env_creator) is str:
|
||||
self.env_name = env_creator
|
||||
else:
|
||||
if hasattr(env_creator, "env_name"):
|
||||
self.env_name = env_creator.env_name
|
||||
else:
|
||||
self.env_name = "custom"
|
||||
self.alg = alg
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config
|
||||
self.local_dir = local_dir
|
||||
self.experiment_tag = experiment_tag
|
||||
|
@ -92,7 +96,7 @@ class Trial(object):
|
|||
self.last_result = None
|
||||
self._checkpoint_path = restore_path
|
||||
self._checkpoint_obj = None
|
||||
self.agent = None
|
||||
self.runner = None
|
||||
self.status = Trial.PENDING
|
||||
self.location = None
|
||||
self.logdir = None
|
||||
|
@ -105,7 +109,7 @@ class Trial(object):
|
|||
be thrown.
|
||||
"""
|
||||
|
||||
self._setup_agent()
|
||||
self._setup_runner()
|
||||
if self._checkpoint_path:
|
||||
self.restore_from_path(self._checkpoint_path)
|
||||
elif self._checkpoint_obj:
|
||||
|
@ -128,11 +132,11 @@ class Trial(object):
|
|||
self.status = Trial.TERMINATED
|
||||
|
||||
try:
|
||||
if self.agent:
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.agent.stop.remote())
|
||||
stop_tasks.append(self.agent.__ray_terminate__.remote(
|
||||
self.agent._ray_actor_id.id()))
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
stop_tasks.append(self.runner.__ray_terminate__.remote(
|
||||
self.runner._ray_actor_id.id()))
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
|
@ -140,10 +144,10 @@ class Trial(object):
|
|||
print(("Stopping %s Actor timed out, "
|
||||
"but moving on...") % self)
|
||||
except Exception:
|
||||
print("Error stopping agent:", traceback.format_exc())
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
finally:
|
||||
self.agent = None
|
||||
self.runner = None
|
||||
|
||||
if stop_logger and self.result_logger:
|
||||
self.result_logger.close()
|
||||
|
@ -159,7 +163,7 @@ class Trial(object):
|
|||
self.stop(stop_logger=False)
|
||||
self.status = Trial.PAUSED
|
||||
except Exception:
|
||||
print("Error pausing agent:", traceback.format_exc())
|
||||
print("Error pausing runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def unpause(self):
|
||||
|
@ -177,11 +181,14 @@ class Trial(object):
|
|||
"""Returns Ray future for one iteration of training."""
|
||||
|
||||
assert self.status == Trial.RUNNING, self.status
|
||||
return self.agent.train.remote()
|
||||
return self.runner.train.remote()
|
||||
|
||||
def should_stop(self, result):
|
||||
"""Whether the given result meets this trial's stopping criteria."""
|
||||
|
||||
if result.done:
|
||||
return True
|
||||
|
||||
for criteria, stop_value in self.stopping_criterion.items():
|
||||
if getattr(result, criteria) >= stop_value:
|
||||
return True
|
||||
|
@ -240,9 +247,9 @@ class Trial(object):
|
|||
obj = None
|
||||
path = None
|
||||
if to_object_store:
|
||||
obj = self.agent.save_to_object.remote()
|
||||
obj = self.runner.save_to_object.remote()
|
||||
else:
|
||||
path = ray.get(self.agent.save.remote())
|
||||
path = ray.get(self.runner.save.remote())
|
||||
self._checkpoint_path = path
|
||||
self._checkpoint_obj = obj
|
||||
|
||||
|
@ -250,39 +257,40 @@ class Trial(object):
|
|||
return path or obj
|
||||
|
||||
def restore_from_path(self, path):
|
||||
"""Restores agent state from specified path.
|
||||
"""Restores runner state from specified path.
|
||||
|
||||
Args:
|
||||
path (str): A path where state will be restored.
|
||||
"""
|
||||
|
||||
if self.agent is None:
|
||||
print("Unable to restore - no agent")
|
||||
if self.runner is None:
|
||||
print("Unable to restore - no runner")
|
||||
else:
|
||||
try:
|
||||
ray.get(self.agent.restore.remote(path))
|
||||
ray.get(self.runner.restore.remote(path))
|
||||
except Exception:
|
||||
print("Error restoring agent:", traceback.format_exc())
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def restore_from_obj(self, obj):
|
||||
"""Restores agent state from the specified object."""
|
||||
"""Restores runner state from the specified object."""
|
||||
|
||||
if self.agent is None:
|
||||
print("Unable to restore - no agent")
|
||||
if self.runner is None:
|
||||
print("Unable to restore - no runner")
|
||||
else:
|
||||
try:
|
||||
ray.get(self.agent.restore_from_object.remote(obj))
|
||||
ray.get(self.runner.restore_from_object.remote(obj))
|
||||
except Exception:
|
||||
print("Error restoring agent:", traceback.format_exc())
|
||||
print("Error restoring runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
|
||||
def _setup_agent(self):
|
||||
def _setup_runner(self):
|
||||
self.status = Trial.RUNNING
|
||||
agent_cls = get_agent_class(self.alg)
|
||||
trainable_cls = get_registry().get(
|
||||
TRAINABLE_CLASS, self.trainable_name)
|
||||
cls = ray.remote(
|
||||
num_cpus=self.resources.driver_cpu_limit,
|
||||
num_gpus=self.resources.driver_gpu_limit)(agent_cls)
|
||||
num_gpus=self.resources.driver_gpu_limit)(trainable_cls)
|
||||
if not self.result_logger:
|
||||
if not os.path.exists(self.local_dir):
|
||||
os.makedirs(self.local_dir)
|
||||
|
@ -292,19 +300,18 @@ class Trial(object):
|
|||
self.config, self.logdir, self.upload_dir)
|
||||
remote_logdir = self.logdir
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote agent to use a noop-logger.
|
||||
self.agent = cls.remote(
|
||||
self.env_creator, self.config,
|
||||
lambda config: NoopLogger(config, remote_logdir))
|
||||
# configure the remote runner to use a noop-logger.
|
||||
self.runner = cls.remote(
|
||||
config=self.config,
|
||||
registry=get_registry(),
|
||||
logger_creator=lambda config: NoopLogger(config, remote_logdir))
|
||||
|
||||
def __str__(self):
|
||||
identifier = '{}_{}'.format(self.alg, self.env_name)
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
else:
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
identifier += '_' + self.experiment_tag
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
|
|
|
@ -7,6 +7,7 @@ import ray
|
|||
import time
|
||||
import traceback
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
@ -27,7 +28,7 @@ class TrialRunner(object):
|
|||
|
||||
While Ray itself provides resource management for tasks and actors, this is
|
||||
not sufficient when scheduling trials that may instantiate multiple actors.
|
||||
This is because if insufficient resources are available, concurrent agents
|
||||
This is because if insufficient resources are available, concurrent trials
|
||||
could deadlock waiting for new resources to become available. Furthermore,
|
||||
oversubscribing the cluster could degrade training performance, leading to
|
||||
misleading benchmark results.
|
||||
|
@ -77,13 +78,15 @@ class TrialRunner(object):
|
|||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
assert self.has_resources(trial.resources), \
|
||||
("Insufficient cluster resources to launch trial",
|
||||
(trial.resources, self._avail_resources))
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError(
|
||||
"Insufficient cluster resources to launch trial",
|
||||
(trial.resources, self._avail_resources))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
assert False, "There are paused trials, but no more "\
|
||||
"pending trials with sufficient resources."
|
||||
assert False, "Called step when all trials finished?"
|
||||
raise TuneError(
|
||||
"There are paused trials, but no more pending "
|
||||
"trials with sufficient resources.")
|
||||
raise TuneError("Called step when all trials finished?")
|
||||
|
||||
def get_trials(self):
|
||||
"""Returns the list of trials managed by this TrialRunner.
|
||||
|
@ -141,14 +144,14 @@ class TrialRunner(object):
|
|||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
except Exception:
|
||||
print("Error starting agent, retrying:", traceback.format_exc())
|
||||
print("Error starting runner, retrying:", traceback.format_exc())
|
||||
time.sleep(2)
|
||||
trial.stop(error=True)
|
||||
try:
|
||||
trial.start()
|
||||
self._running[trial.train_remote()] = trial
|
||||
except Exception:
|
||||
print("Error starting agent, abort:", traceback.format_exc())
|
||||
print("Error starting runner, abort:", traceback.format_exc())
|
||||
trial.stop(error=True)
|
||||
# note that we don't return the resources, since they may
|
||||
# have been lost
|
||||
|
|
|
@ -7,9 +7,10 @@ from __future__ import print_function
|
|||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.trial import Trial
|
||||
|
@ -44,24 +45,25 @@ parser.add_argument("-f", "--config-file", required=True, type=str,
|
|||
help="Read experiment options from this JSON/YAML file.")
|
||||
|
||||
|
||||
SCHEDULERS = {
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
"HyperBand": HyperBandScheduler,
|
||||
}
|
||||
|
||||
|
||||
def make_scheduler(args):
|
||||
if args.scheduler in SCHEDULERS:
|
||||
return SCHEDULERS[args.scheduler](**args.scheduler_config)
|
||||
def _make_scheduler(args):
|
||||
if args.scheduler in _SCHEDULERS:
|
||||
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
|
||||
else:
|
||||
assert False, "Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, SCHEDULERS.keys())
|
||||
raise TuneError(
|
||||
"Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, **ray_args):
|
||||
if scheduler is None:
|
||||
scheduler = make_scheduler(args)
|
||||
scheduler = FIFOScheduler()
|
||||
runner = TrialRunner(scheduler)
|
||||
|
||||
for name, spec in experiments.items():
|
||||
|
@ -77,16 +79,16 @@ def run_experiments(experiments, scheduler=None, **ray_args):
|
|||
|
||||
for trial in runner.get_trials():
|
||||
if trial.status != Trial.TERMINATED:
|
||||
print("Exit 1")
|
||||
sys.exit(1)
|
||||
raise TuneError("Trial did not complete", trial)
|
||||
|
||||
print("Exit 0")
|
||||
return runner.get_trials()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import yaml
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
with open(args.config_file) as f:
|
||||
experiments = yaml.load(f)
|
||||
run_experiments(
|
||||
experiments, make_scheduler(args), redis_address=args.redis_address,
|
||||
experiments, _make_scheduler(args), redis_address=args.redis_address,
|
||||
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
import random
|
||||
import types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.config_parser import make_parser, json_to_resources
|
||||
|
||||
|
@ -20,6 +21,9 @@ def generate_trials(unresolved_spec, output_path=''):
|
|||
output_path (str): Path where to store experiment outputs.
|
||||
"""
|
||||
|
||||
if "run" not in unresolved_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
|
||||
|
||||
def to_argv(config):
|
||||
argv = []
|
||||
for k, v in config.items():
|
||||
|
@ -34,15 +38,23 @@ def generate_trials(unresolved_spec, output_path=''):
|
|||
i = 0
|
||||
for _ in range(unresolved_spec.get("repeat", 1)):
|
||||
for resolved_vars, spec in generate_variants(unresolved_spec):
|
||||
args = parser.parse_args(to_argv(spec))
|
||||
try:
|
||||
# Special case the `env` param for RLlib by automatically
|
||||
# moving it into the `config` section.
|
||||
if "env" in spec:
|
||||
spec["config"] = spec.get("config", {})
|
||||
spec["config"]["env"] = spec["env"]
|
||||
del spec["env"]
|
||||
args = parser.parse_args(to_argv(spec))
|
||||
except SystemExit:
|
||||
raise TuneError("Error parsing args, see above message", spec)
|
||||
if resolved_vars:
|
||||
experiment_tag = "{}_{}".format(i, resolved_vars)
|
||||
else:
|
||||
experiment_tag = str(i)
|
||||
i += 1
|
||||
yield Trial(
|
||||
env_creator=spec.get("env", lambda: None),
|
||||
alg=spec.get("alg", "script"),
|
||||
trainable_name=spec["run"],
|
||||
config=spec.get("config", {}),
|
||||
local_dir=os.path.join(args.local_dir, output_path),
|
||||
experiment_tag=experiment_tag,
|
||||
|
@ -105,8 +117,8 @@ _MAX_RESOLUTION_PASSES = 20
|
|||
def _format_vars(resolved_vars):
|
||||
out = []
|
||||
for path, value in sorted(resolved_vars.items()):
|
||||
if path[0] in ["alg", "env", "resources"]:
|
||||
continue # these settings aren't usually search parameters
|
||||
if path[0] in ["run", "env", "resources"]:
|
||||
continue # TrialRunner already has these in the experiment_tag
|
||||
pieces = []
|
||||
last_string = True
|
||||
for k in path[::-1]:
|
||||
|
@ -229,9 +241,10 @@ def _try_resolve(v):
|
|||
elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
|
||||
# Grid search values
|
||||
grid_values = v["grid_search"]
|
||||
assert isinstance(grid_values, list), \
|
||||
"Grid search expected list of values, got: {}".format(
|
||||
grid_values)
|
||||
if not isinstance(grid_values, list):
|
||||
raise TuneError(
|
||||
"Grid search expected list of values, got: {}".format(
|
||||
grid_values))
|
||||
return False, grid_values
|
||||
return True, v
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ def _parse_results(res_path):
|
|||
pass
|
||||
res_dict = _flatten_dict(json.loads(line.strip()))
|
||||
except Exception as e:
|
||||
print("Importing %s failed...Perhaps empty?" % res_path)
|
||||
print("Importing %s failed...Perhaps empty?" % res_path, e)
|
||||
return res_dict
|
||||
|
||||
|
||||
|
@ -60,7 +60,7 @@ def _resolve(directory, result_fname):
|
|||
def load_results_to_df(directory, result_name="result.json"):
|
||||
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
|
||||
for f in files if f == result_name]
|
||||
data = [_resolve(directory, result_name) for directory in exp_directories]
|
||||
data = [_resolve(d, result_name) for d in exp_directories]
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
|
|
|
@ -59,83 +59,83 @@ python $ROOT_DIR/multi_node_docker_test.py \
|
|||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env PongDeterministic-v0 \
|
||||
--alg A3C \
|
||||
--run A3C \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"num_workers": 16}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v1 \
|
||||
--alg PPO \
|
||||
--run PPO \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"free_log_std": true}}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v1 \
|
||||
--alg PPO \
|
||||
--run PPO \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "use_gae": false}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env Pendulum-v0 \
|
||||
--alg ES \
|
||||
--run ES \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env Pong-v0 \
|
||||
--alg ES \
|
||||
--run ES \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"stepsize": 0.01, "episodes_per_batch": 20, "timesteps_per_batch": 100}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v0 \
|
||||
--alg A3C \
|
||||
--run A3C \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"use_lstm": false}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env CartPole-v0 \
|
||||
--alg DQN \
|
||||
--run DQN \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env FrozenLake-v0 \
|
||||
--alg DQN \
|
||||
--run DQN \
|
||||
--stop '{"training_iteration": 2}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env FrozenLake-v0 \
|
||||
--alg PPO \
|
||||
--run PPO \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"num_sgd_iter": 10, "sgd_batchsize": 64, "timesteps_per_batch": 1000, "num_workers": 1}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env PongDeterministic-v4 \
|
||||
--alg DQN \
|
||||
--run DQN \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env MontezumaRevenge-v0 \
|
||||
--alg PPO \
|
||||
--run PPO \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "sgd_stepsize": 1e-4, "sgd_batchsize": 64, "timesteps_per_batch": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}, "extra_frameskip": 4}'
|
||||
|
||||
docker run --shm-size=10G --memory=10G $DOCKER_SHA \
|
||||
python /ray/python/ray/rllib/train.py \
|
||||
--env PongDeterministic-v4 \
|
||||
--alg A3C \
|
||||
--run A3C \
|
||||
--stop '{"training_iteration": 2}' \
|
||||
--config '{"num_workers": 2, "use_lstm": false, "use_pytorch": true, "model": {"grayscale": true, "zero_mean": false, "dim": 80, "channel_major": true}}'
|
||||
|
||||
|
|
|
@@ -2,38 +2,201 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import os
import time
import unittest

import ray
from ray.rllib import _register_all

from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.variant_generator import generate_trials, grid_search, \
    RecursiveDependencyError


class TrainableFunctionApiTest(unittest.TestCase):
    def tearDown(self):
        ray.worker.cleanup()
        _register_all()  # re-register the evicted objects

    def testRegisterEnv(self):
        register_env("foo", lambda: None)
        self.assertRaises(TypeError, lambda: register_env("foo", 2))

    def testRegisterTrainable(self):
        def train(config, reporter):
            pass

        class A(object):
            pass

        class B(Trainable):
            pass

        register_trainable("foo", train)
        register_trainable("foo", B)
        self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
        self.assertRaises(TypeError, lambda: register_trainable("foo", A))

    def testRewriteEnv(self):
        def train(config, reporter):
            reporter(timesteps_total=1)
        register_trainable("f1", train)

        [trial] = run_experiments({"foo": {
            "run": "f1",
            "env": "CartPole-v0",
        }})
        self.assertEqual(trial.config["env"], "CartPole-v0")

    def testConfigPurity(self):
        def train(config, reporter):
            assert config == {"a": "b"}, config
            reporter(timesteps_total=1)
        register_trainable("f1", train)
        run_experiments({"foo": {
            "run": "f1",
            "config": {"a": "b"},
        }})

    def testBadParams(self):
        def f():
            run_experiments({"foo": {}})
        self.assertRaises(TuneError, f)

    def testBadParams2(self):
        def f():
            run_experiments({"foo": {
                "bah": "this param is not allowed",
            }})
        self.assertRaises(TuneError, f)

    def testBadParams3(self):
        def f():
            run_experiments({"foo": {
                "run": grid_search("invalid grid search"),
            }})
        self.assertRaises(TuneError, f)

    def testBadParams4(self):
        def f():
            run_experiments({"foo": {
                "run": "asdf",
            }})
        self.assertRaises(TuneError, f)

    def testBadParams5(self):
        def f():
            run_experiments({"foo": {
                "run": "PPO",
                "stop": {"asdf": 1}
            }})
        self.assertRaises(TuneError, f)

    def testBadParams6(self):
        def f():
            run_experiments({"foo": {
                "run": "PPO",
                "resources": {"asdf": 1}
            }})
        self.assertRaises(TuneError, f)

    def testBadReturn(self):
        def train(config, reporter):
            reporter()
        register_trainable("f1", train)

        def f():
            run_experiments({"foo": {
                "run": "f1",
                "config": {
                    "script_min_iter_time_s": 0,
                },
            }})
        self.assertRaises(TuneError, f)

    def testEarlyReturn(self):
        def train(config, reporter):
            reporter(timesteps_total=100, done=True)
            time.sleep(99999)
        register_trainable("f1", train)
        [trial] = run_experiments({"foo": {
            "run": "f1",
            "config": {
                "script_min_iter_time_s": 0,
            },
        }})
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result.timesteps_total, 100)

    def testAbruptReturn(self):
        def train(config, reporter):
            reporter(timesteps_total=100)
        register_trainable("f1", train)
        [trial] = run_experiments({"foo": {
            "run": "f1",
            "config": {
                "script_min_iter_time_s": 0,
            },
        }})
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result.timesteps_total, 100)

    def testErrorReturn(self):
        def train(config, reporter):
            raise Exception("uh oh")
        register_trainable("f1", train)

        def f():
            run_experiments({"foo": {
                "run": "f1",
                "config": {
                    "script_min_iter_time_s": 0,
                },
            }})
        self.assertRaises(TuneError, f)

    def testSuccess(self):
        def train(config, reporter):
            for i in range(100):
                reporter(timesteps_total=i)
        register_trainable("f1", train)
        [trial] = run_experiments({"foo": {
            "run": "f1",
            "config": {
                "script_min_iter_time_s": 0,
            },
        }})
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result.timesteps_total, 99)


class VariantGeneratorTest(unittest.TestCase):
    def testParseToTrials(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "alg": "PPO",
            "run": "PPO",
            "repeat": 2,
            "config": {
                "env": "Pong-v0",
                "foo": "bar"
            },
        }, "tune-pong")
        trials = list(trials)
        self.assertEqual(len(trials), 2)
        self.assertEqual(trials[0].env_name, "Pong-v0")
        self.assertEqual(trials[0].config, {"foo": "bar"})
        self.assertEqual(trials[0].alg, "PPO")
        self.assertEqual(str(trials[0]), "PPO_Pong-v0_0")
        self.assertEqual(trials[0].config, {"foo": "bar", "env": "Pong-v0"})
        self.assertEqual(trials[0].trainable_name, "PPO")
        self.assertEqual(trials[0].experiment_tag, "0")
        self.assertEqual(trials[0].local_dir, "/tmp/ray/tune-pong")
        self.assertEqual(trials[1].experiment_tag, "1")

    def testEval(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "foo": {
                    "eval": "2 + 2"
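The tests above exercise the new shared object registry end to end. Below is a minimal sketch of the same flow outside the test harness, assuming only the `ray.tune` imports shown in this diff; the names `my_train`, `my_env`, and `my_experiment` are placeholders for illustration and are not part of this change:

    from ray.tune import register_env, register_trainable, run_experiments

    # Hypothetical trainable function; tune calls it with a config dict and a
    # reporter callback, and progress is reported via keyword arguments.
    def my_train(config, reporter):
        for i in range(10):
            reporter(timesteps_total=i)

    # Hypothetical zero-argument environment creator, mirroring
    # register_env("foo", lambda: None) in the test above.
    register_env("my_env", lambda: None)
    register_trainable("my_train", my_train)

    # Experiments reference trainables by their registered name via "run".
    run_experiments({"my_experiment": {
        "run": "my_train",
        "config": {"script_min_iter_time_s": 0},
    }})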
@@ -48,7 +211,7 @@ class VariantGeneratorTest(unittest.TestCase):

    def testGridSearch(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "bar": {
                    "grid_search": [True, False]

@@ -71,7 +234,7 @@ class VariantGeneratorTest(unittest.TestCase):

    def testGridSearchAndEval(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "qux": lambda spec: 2 + 2,
                "bar": grid_search([True, False]),

@@ -85,7 +248,7 @@ class VariantGeneratorTest(unittest.TestCase):

    def testConditionResolution(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "x": 1,
                "y": lambda spec: spec.config.x + 1,

@@ -98,7 +261,7 @@ class VariantGeneratorTest(unittest.TestCase):

    def testDependentLambda(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "x": grid_search([1, 2]),
                "y": lambda spec: spec.config.x * 100,

@@ -111,7 +274,7 @@ class VariantGeneratorTest(unittest.TestCase):

    def testDependentGridSearch(self):
        trials = generate_trials({
            "env": "Pong-v0",
            "run": "PPO",
            "config": {
                "x": grid_search([
                    lambda spec: spec.config.y * 100,

@@ -128,7 +291,7 @@ class VariantGeneratorTest(unittest.TestCase):
    def testRecursiveDep(self):
        try:
            list(generate_trials({
                "env": "Pong-v0",
                "run": "PPO",
                "config": {
                    "foo": lambda spec: spec.config.foo,
                },
@@ -142,10 +305,11 @@ class VariantGeneratorTest(unittest.TestCase):
class TrialRunnerTest(unittest.TestCase):
    def tearDown(self):
        ray.worker.cleanup()
        _register_all()  # re-register the evicted objects

    def testTrialStatus(self):
        ray.init()
        trial = Trial("CartPole-v0", "__fake")
        trial = Trial("__fake")
        self.assertEqual(trial.status, Trial.PENDING)
        trial.start()
        self.assertEqual(trial.status, Trial.RUNNING)

@@ -156,11 +320,12 @@ class TrialRunnerTest(unittest.TestCase):

    def testTrialErrorOnStart(self):
        ray.init()
        trial = Trial("CartPole-v0", "asdf")
        _default_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf")
        try:
            trial.start()
        except Exception as e:
            self.assertIn("Unknown algorithm", str(e))
            self.assertIn("a class", str(e))

    def testResourceScheduler(self):
        ray.init(num_cpus=4, num_gpus=1)

@@ -170,8 +335,8 @@ class TrialRunnerTest(unittest.TestCase):
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("CartPole-v0", "__fake", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

@@ -199,8 +364,8 @@ class TrialRunnerTest(unittest.TestCase):
            "resources": Resources(cpu=1, gpu=1),
        }
        trials = [
            Trial("CartPole-v0", "__fake", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
            Trial("__fake", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

@@ -227,9 +392,10 @@ class TrialRunnerTest(unittest.TestCase):
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        _default_registry.register(TRAINABLE_CLASS, "asdf", None)
        trials = [
            Trial("CartPole-v0", "asdf", **kwargs),
            Trial("CartPole-v0", "__fake", **kwargs)]
            Trial("asdf", **kwargs),
            Trial("__fake", **kwargs)]
        for t in trials:
            runner.add_trial(t)

@@ -248,17 +414,17 @@ class TrialRunnerTest(unittest.TestCase):
            "stopping_criterion": {"training_iteration": 1},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)
        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)

        path = trials[0].checkpoint()
        kwargs["restore_path"] = path

        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()

@@ -268,7 +434,7 @@ class TrialRunnerTest(unittest.TestCase):
        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
        self.assertEqual(trials[1].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[1].agent.get_info.remote()), 1)
        self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1)
        self.addCleanup(os.remove, path)

    def testPauseThenResume(self):

@@ -278,14 +444,14 @@ class TrialRunnerTest(unittest.TestCase):
            "stopping_criterion": {"training_iteration": 2},
            "resources": Resources(cpu=1, gpu=1),
        }
        runner.add_trial(Trial("CartPole-v0", "__fake", **kwargs))
        runner.add_trial(Trial("__fake", **kwargs))
        trials = runner.get_trials()

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].agent.get_info.remote()), None)
        self.assertEqual(ray.get(trials[0].runner.get_info.remote()), None)

        self.assertEqual(ray.get(trials[0].agent.set_info.remote(1)), 1)
        self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)

        trials[0].pause()
        self.assertEqual(trials[0].status, Trial.PAUSED)

@@ -295,7 +461,7 @@ class TrialRunnerTest(unittest.TestCase):

        runner.step()
        self.assertEqual(trials[0].status, Trial.RUNNING)
        self.assertEqual(ray.get(trials[0].agent.get_info.remote()), 1)
        self.assertEqual(ray.get(trials[0].runner.get_info.remote()), 1)

        runner.step()
        self.assertEqual(trials[0].status, Trial.TERMINATED)
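The `Trial` hunks above drop the separate environment argument: a trial is now constructed from the registered trainable name alone, with everything else passed as keyword arguments. A rough sketch of the updated construction, using the `__fake` trainable and the keyword arguments that appear in these tests (any other setup is assumed, not shown in this diff):

    from ray.tune.trial import Trial, Resources

    kwargs = {
        "stopping_criterion": {"training_iteration": 1},
        "resources": Resources(cpu=1, gpu=1),
    }
    # Old form removed in this diff: Trial("CartPole-v0", "__fake", **kwargs)
    # New form: only the registered trainable name is positional.
    trial = Trial("__fake", **kwargs)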
@@ -20,8 +20,8 @@ def result(t, rew):

class EarlyStoppingSuite(unittest.TestCase):
    def basicSetup(self, rule):
        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        for i in range(10):
            self.assertEqual(
                rule.on_trial_result(None, t1, result(i, i * 100)),

@@ -62,7 +62,7 @@ class EarlyStoppingSuite(unittest.TestCase):
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))
        rule.on_trial_complete(None, t2, result(10, 1000))
        t3 = Trial("t3", "PPO")
        t3 = Trial("PPO")
        self.assertEqual(
            rule.on_trial_result(None, t3, result(1, 10)),
            TrialScheduler.CONTINUE)

@@ -77,7 +77,7 @@ class EarlyStoppingSuite(unittest.TestCase):
        rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))
        t3 = Trial("t3", "PPO")
        t3 = Trial("PPO")
        self.assertEqual(
            rule.on_trial_result(None, t3, result(3, 10)),
            TrialScheduler.CONTINUE)

@@ -91,7 +91,7 @@ class EarlyStoppingSuite(unittest.TestCase):
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))
        rule.on_trial_complete(None, t2, result(10, 1000))
        t3 = Trial("t3", "PPO")
        t3 = Trial("PPO")
        self.assertEqual(
            rule.on_trial_result(None, t3, result(1, 260)),
            TrialScheduler.CONTINUE)

@@ -105,7 +105,7 @@ class EarlyStoppingSuite(unittest.TestCase):
        t1, t2 = self.basicSetup(rule)
        rule.on_trial_complete(None, t1, result(10, 1000))
        rule.on_trial_complete(None, t2, result(10, 1000))
        t3 = Trial("t3", "PPO")
        t3 = Trial("PPO")
        self.assertEqual(
            rule.on_trial_result(None, t3, result(1, 260)),
            TrialScheduler.CONTINUE)

@@ -120,8 +120,8 @@ class EarlyStoppingSuite(unittest.TestCase):
        rule = MedianStoppingRule(
            grace_period=0, min_samples_required=1,
            time_attr='training_iteration', reward_attr='neg_mean_loss')
        t1 = Trial("t1", "PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("t2", "PPO")  # mean is 450, max 450, t_max=5
        t1 = Trial("PPO")  # mean is 450, max 900, t_max=10
        t2 = Trial("PPO")  # mean is 450, max 450, t_max=5
        for i in range(10):
            self.assertEqual(
                rule.on_trial_result(None, t1, result2(i, i * 100)),

@@ -166,7 +166,7 @@ class HyperbandSuite(unittest.TestCase):
        (81, 1) -> (27, 3) -> (9, 9) -> (3, 27) -> (1, 81);"""
        sched = HyperBandScheduler()
        for i in range(num_trials):
            t = Trial("t%d" % i, "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)
        runner = _MockTrialRunner()
        return sched, runner

@@ -211,7 +211,7 @@ class HyperbandSuite(unittest.TestCase):
    def advancedSetup(self):
        sched = self.basicSetup()
        for i in range(4):
            t = Trial("t%d" % (i + 20), "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)

        self.assertEqual(sched._cur_band_filled(), False)

@@ -232,7 +232,7 @@ class HyperbandSuite(unittest.TestCase):
        sched = HyperBandScheduler()
        i = 0
        while not sched._cur_band_filled():
            t = Trial("t%d" % (i), "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)
            i += 1
        self.assertEqual(len(sched._hyperbands[0]), 5)

@@ -244,7 +244,7 @@ class HyperbandSuite(unittest.TestCase):
        sched = HyperBandScheduler(max_t=810)
        i = 0
        while not sched._cur_band_filled():
            t = Trial("t%d" % (i), "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)
            i += 1
        self.assertEqual(len(sched._hyperbands[0]), 5)

@@ -257,7 +257,7 @@ class HyperbandSuite(unittest.TestCase):
        sched = HyperBandScheduler(max_t=1)
        i = 0
        while len(sched._hyperbands) < 2:
            t = Trial("t%d" % (i), "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)
            i += 1
        self.assertEqual(len(sched._hyperbands[0]), 5)

@@ -415,7 +415,7 @@ class HyperbandSuite(unittest.TestCase):
            status = sched.on_trial_result(
                mock_runner, t, result(init_units, i))
            self.assertEqual(status, TrialScheduler.CONTINUE)
        t = Trial("t%d" % 100, "__fake")
        t = Trial("__fake")
        sched.on_trial_add(None, t)
        mock_runner._launch_trial(t)
        self.assertEqual(len(sched._state["bracket"].current_trials()), 2)

@@ -440,7 +440,7 @@ class HyperbandSuite(unittest.TestCase):
        stats = self.default_statistics()

        for i in range(stats["max_trials"]):
            t = Trial("t%d" % i, "__fake")
            t = Trial("__fake")
            sched.on_trial_add(None, t)
        runner = _MockTrialRunner()