[tune] [rllib] Automatically determine RLlib resources and add queueing mechanism for autoscaling (#1848)

Eric Liang 2018-04-16 16:58:15 -07:00 committed by Richard Liaw
parent 2e25972d4d
commit 7ab890f4a1
39 changed files with 286 additions and 122 deletions

View file

@ -287,7 +287,8 @@ Here is an example of using the command-line interface with RLlib:
python ray/python/ray/rllib/train.py -f tuned_examples/cartpole-grid-search-example.yaml
Here is an example using the Python API. The same config passed to ``Agents`` may be placed
in the ``config`` section of the experiments.
in the ``config`` section of the experiments. RLlib agents automatically declare their
resource requirements (e.g., based on ``num_workers``) to Tune, so you don't have to.
.. code-block:: python
@ -300,10 +301,6 @@ in the ``config`` section of the experiments.
'cartpole-ppo': {
'run': 'PPO',
'env': 'CartPole-v0',
'trial_resources': {
'cpu': 1,
'extra_cpu': 2, # for workers
},
'stop': {
'episode_reward_mean': 200,
'time_total_s': 180
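
A quick way to see what an agent will declare to Tune is to call the new ``default_resource_request`` classmethod directly. A minimal sketch, assuming the ``PPOAgent`` import path below and that the default ``devices`` list contains no GPUs:

# Sketch only: inspect the resources RLlib would request for this config.
from ray.rllib.ppo import PPOAgent  # assumed import path at this commit

res = PPOAgent.default_resource_request({"num_workers": 2})
# Expected to be roughly Resources(cpu=1, gpu=0, extra_cpu=2, extra_gpu=0):
# one CPU for the driver plus one extra CPU per rollout worker.
print(res)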

View file

@ -13,6 +13,7 @@ from ray.rllib.utils import FilterManager
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator, \
GPURemoteA3CEvaluator
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
DEFAULT_CONFIG = {
@ -68,6 +69,14 @@ class A3CAgent(Agent):
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["model", "optimizer", "env_config"]
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1, gpu=0,
extra_cpu=cf["num_workers"],
extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)
def _init(self):
self.local_evaluator = A3CEvaluator(
self.registry, self.env_creator, self.config, self.logdir,

View file

@ -4,6 +4,7 @@ from __future__ import print_function
import logging
import numpy as np
import json
import os
import pickle
@ -62,6 +63,14 @@ class Agent(Trainable):
_allow_unknown_configs = False
_allow_unknown_subkeys = []
@classmethod
def resource_help(cls, config):
return (
"\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers` and other configs. See the "
"DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: " + json.dumps(config))
def __init__(
self, config=None, env=None, registry=None,
logger_creator=None):

View file

@ -8,16 +8,19 @@ from ray.rllib.bc.bc_evaluator import BCEvaluator, GPURemoteBCEvaluator, \
RemoteBCEvaluator
from ray.rllib.optimizers import AsyncOptimizer
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
DEFAULT_CONFIG = {
# Number of workers (excluding master)
"num_workers": 4,
"num_workers": 1,
# Size of rollout batch
"batch_size": 100,
# Max global norm for each gradient calculated by worker
"grad_clip": 40.0,
# Learning rate
"lr": 0.0001,
# Whether to use a GPU for local optimization.
"gpu": False,
# Whether to place workers on GPUs
"use_gpu_for_workers": False,
# Model and preprocessor options
@ -46,6 +49,18 @@ class BCAgent(Agent):
_default_config = DEFAULT_CONFIG
_allow_unknown_configs = True
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
if cf["use_gpu_for_workers"]:
num_gpus_per_worker = 1
else:
num_gpus_per_worker = 0
return Resources(
cpu=1, gpu=cf["gpu"] and 1 or 0,
extra_cpu=cf["num_workers"],
extra_gpu=num_gpus_per_worker * cf["num_workers"])
def _init(self):
self.local_evaluator = BCEvaluator(
self.registry, self.env_creator, self.config, self.logdir)

View file

@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
from ray.tune.trial import Resources
APEX_DEFAULT_CONFIG = dict(DQN_CONFIG, **dict(
optimizer_class="ApexOptimizer",
@ -12,6 +13,7 @@ APEX_DEFAULT_CONFIG = dict(DQN_CONFIG, **dict(
debug=False,
)),
n_step=3,
gpu=True,
num_workers=32,
buffer_size=2000000,
learning_starts=50000,
@ -35,6 +37,15 @@ class ApexAgent(DQNAgent):
_agent_name = "APEX"
_default_config = APEX_DEFAULT_CONFIG
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1 + cf["optimizer_config"]["num_replay_buffer_shards"],
gpu=cf["gpu"] and 1 or 0,
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
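
Plugging assumed Ape-X defaults into the formula above (4 replay shards, 32 workers, 1 CPU and no GPUs per worker, gpu=True), the request works out as follows; the shard count is an assumption for illustration:

# Illustrative arithmetic only; the shard count below is assumed.
cf = {
    "optimizer_config": {"num_replay_buffer_shards": 4},  # assumed default
    "gpu": True,
    "num_workers": 32,
    "num_cpus_per_worker": 1,
    "num_gpus_per_worker": 0,
}
cpu = 1 + cf["optimizer_config"]["num_replay_buffer_shards"]  # 5: driver + replay actors
gpu = 1 if cf["gpu"] else 0                                   # 1: local learner GPU
extra_cpu = cf["num_cpus_per_worker"] * cf["num_workers"]     # 32: one per sampling worker
extra_gpu = cf["num_gpus_per_worker"] * cf["num_workers"]     # 0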

View file

@ -13,6 +13,7 @@ from ray.rllib import optimizers
from ray.rllib.dqn.dqn_evaluator import DQNEvaluator
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
OPTIMIZER_SHARED_CONFIGS = [
@ -100,14 +101,16 @@ DEFAULT_CONFIG = dict(
},
# === Parallelism ===
# Whether to use a GPU for local optimization.
gpu=False,
# Number of workers used for collecting samples. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
num_workers=0,
# Whether to allocate GPUs for workers (if > 0).
num_gpus_per_worker=0,
# Whether to reserve CPUs for workers (if not None).
num_cpus_per_worker=None,
# Whether to allocate CPUs for workers (if > 0).
num_cpus_per_worker=1,
# Optimizer class to use.
optimizer_class="LocalSyncReplayOptimizer",
# Config to pass to the optimizer.
@ -124,6 +127,14 @@ class DQNAgent(Agent):
"model", "optimizer", "tf_session_args", "env_config"]
_default_config = DEFAULT_CONFIG
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1, gpu=cf["gpu"] and 1 or 0,
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
def _init(self):
self.local_evaluator = DQNEvaluator(
self.registry, self.env_creator, self.config, self.logdir, 0)

View file

@ -13,6 +13,7 @@ import time
import ray
from ray.rllib import agent
from ray.tune.trial import Resources
from ray.rllib.es import optimizers
from ray.rllib.es import policies
@ -138,6 +139,11 @@ class ESAgent(agent.Agent):
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["env_config"]
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])
def _init(self):
policy_params = {
"action_noise_std": 0.01

View file

@ -18,7 +18,7 @@ import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.optimizers.sample_batch import SampleBatch
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat
@ -163,15 +163,12 @@ class ApexOptimizer(PolicyOptimizer):
self.learner = LearnerThread(self.local_evaluator)
self.learner.start()
# TODO(ekl) use create_colocated() for these actors once
# https://github.com/ray-project/ray/issues/1734 is fixed
self.replay_actors = [
ReplayActor.remote(
num_replay_buffer_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps, clip_rewards)
for _ in range(num_replay_buffer_shards)
]
self.replay_actors = create_colocated(
ReplayActor,
[num_replay_buffer_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps, clip_rewards],
num_replay_buffer_shards)
assert len(self.remote_evaluators) > 0
# Stats
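
The create_colocated helper itself is not shown in this diff. A rough sketch of the general idea (over-provision actors, then keep only the ones that landed on the driver's node) could look like the code below; the get_node_ip actor method and the 2x over-provisioning factor are assumptions for illustration, not the actual utility.

import ray

def colocated_sketch(actor_cls, args, count):
    # Rough illustration only, not ray.rllib.utils.actors.create_colocated:
    # start extra actors, keep those scheduled on the local node.
    local_ip = ray.services.get_node_ip_address()
    candidates = [actor_cls.remote(*args) for _ in range(2 * count)]
    # Assumes the actor class exposes a get_node_ip() method for this check.
    ips = ray.get([a.get_node_ip.remote() for a in candidates])
    local = [a for a, ip in zip(candidates, ips) if ip == local_ip]
    assert len(local) >= count, "not enough actors landed on this node"
    return local[:count]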

View file

@ -9,6 +9,8 @@ from ray.rllib.optimizers import LocalSyncOptimizer
from ray.rllib.pg.pg_evaluator import PGEvaluator
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
DEFAULT_CONFIG = {
# Number of workers (excluding master)
@ -41,6 +43,11 @@ class PGAgent(Agent):
_agent_name = "PG"
_default_config = DEFAULT_CONFIG
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])
def _init(self):
self.optimizer = LocalSyncOptimizer.make(
evaluator_cls=PGEvaluator,

View file

@ -12,6 +12,7 @@ from tensorflow.python import debug as tf_debug
import ray
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
from ray.rllib.agent import Agent
from ray.rllib.utils import FilterManager
from ray.rllib.ppo.ppo_evaluator import PPOEvaluator
@ -69,8 +70,10 @@ DEFAULT_CONFIG = {
"min_steps_per_task": 200,
# Number of actors used to collect the rollouts
"num_workers": 5,
# Resource requirements for remote actors
"worker_resources": {"num_cpus": None},
# Whether to allocate GPUs for workers (if > 0).
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
"num_cpus_per_worker": 1,
# Dump TensorFlow timeline after this many SGD minibatches
"full_trace_nth_sgd_batch": -1,
# Whether to profile data loading
@ -89,17 +92,26 @@ DEFAULT_CONFIG = {
class PPOAgent(Agent):
_agent_name = "PPO"
_allow_unknown_subkeys = ["model", "tf_session_args", "env_config",
"worker_resources"]
_allow_unknown_subkeys = ["model", "tf_session_args", "env_config"]
_default_config = DEFAULT_CONFIG
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
return Resources(
cpu=1,
gpu=len([d for d in cf["devices"] if "gpu" in d.lower()]),
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
def _init(self):
self.global_step = 0
self.kl_coeff = self.config["kl_coeff"]
self.local_evaluator = PPOEvaluator(
self.registry, self.env_creator, self.config, self.logdir, False)
RemotePPOEvaluator = ray.remote(
**self.config["worker_resources"])(PPOEvaluator)
num_cpus=self.config["num_cpus_per_worker"],
num_gpus=self.config["num_gpus_per_worker"])(PPOEvaluator)
self.remote_evaluators = [
RemotePPOEvaluator.remote(
self.registry, self.env_creator, self.config, self.logdir,
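
The worker-creation change above is the standard Ray pattern of wrapping a class with explicit per-actor resources. A self-contained sketch with an illustrative class and values:

import ray

class Worker(object):
    def ping(self):
        return "ok"

ray.init()
# num_cpus/num_gpus here mirror num_cpus_per_worker/num_gpus_per_worker above;
# 1 CPU and 0 GPUs per actor are illustrative values.
RemoteWorker = ray.remote(num_cpus=1, num_gpus=0)(Worker)
workers = [RemoteWorker.remote() for _ in range(4)]
print(ray.get([w.ping.remote() for w in workers]))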

View file

@ -34,16 +34,22 @@ parser.add_argument(
"--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument(
"--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
"--ray-num-cpus", default=None, type=int,
help="--num-cpus to pass to Ray. This only has an effect in local mode.")
parser.add_argument(
"--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
"--ray-num-gpus", default=None, type=int,
help="--num-gpus to pass to Ray. This only has an effect in local mode.")
parser.add_argument(
"--experiment-name", default="default", type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
"--queue-trials", action='store_true',
help=(
"Whether to queue trials when the cluster does not currently have "
"enough resources to launch one. This should be set to True when "
"running on an autoscaling cluster to enable automatic scale-up."))
parser.add_argument(
"-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file. Note that this "
@ -62,7 +68,9 @@ if __name__ == "__main__":
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"local_dir": args.local_dir,
"trial_resources": resources_to_json(args.trial_resources),
"trial_resources": (
args.trial_resources and
resources_to_json(args.trial_resources)),
"stop": args.stop,
"config": dict(args.config, env=args.env),
"restore": args.restore,
@ -79,5 +87,7 @@ if __name__ == "__main__":
ray.init(
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
run_experiments(experiments, scheduler=_make_scheduler(args))
num_cpus=args.ray_num_cpus, num_gpus=args.ray_num_gpus)
run_experiments(
experiments, scheduler=_make_scheduler(args),
queue_trials=args.queue_trials)
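
The Python equivalent of passing --queue-trials on the command line, reusing the CartPole example from the docs above:

import ray
from ray.tune import run_experiments

ray.init()  # or ray.init(redis_address=...) to join an existing cluster
run_experiments({
    "cartpole-ppo": {
        "run": "PPO",
        "env": "CartPole-v0",
        "stop": {"episode_reward_mean": 200},
        "config": {"num_workers": 2},
    },
}, queue_trials=True)  # queue trials until the autoscaling cluster catches up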

View file

@ -4,9 +4,6 @@ cartpole-ppo:
stop:
episode_reward_mean: 200
time_total_s: 180
trial_resources:
cpu: 1
extra_cpu: 1
config:
num_workers: 2
num_sgd_iter:

View file

@ -1,8 +1,4 @@
hopper-ppo:
env: Hopper-v1
run: PPO
trial_resources:
cpu: 1
gpu: 4
extra_cpu: 64
config: {"gamma": 0.995, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": .0001, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}

View file

@ -1,9 +1,6 @@
humanoid-es:
env: Humanoid-v1
run: ES
trial_resources:
cpu: 1
extra_cpu: 100
stop:
episode_reward_mean: 6000
config:

View file

@ -3,9 +3,5 @@ humanoid-ppo-gae:
run: PPO
stop:
episode_reward_mean: 6000
trial_resources:
cpu: 1
gpu: 4
extra_cpu: 64
config: {"lambda": 0.95, "clip_param": 0.2, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": .0001, "sgd_batchsize": 32768, "horizon": 5000, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "write_logs": false}

View file

@ -3,8 +3,4 @@ humanoid-ppo:
run: PPO
stop:
episode_reward_mean: 6000
trial_resources:
cpu: 1
gpu: 4
extra_cpu: 64
config: {"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": .0001, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64, "model": {"free_log_std": true}, "use_gae": false}

View file

@ -5,9 +5,6 @@ cartpole-ppo:
stop:
episode_reward_mean: 200
time_total_s: 180
trial_resources:
cpu: 1
extra_cpu: 1
config:
num_workers: 1
num_sgd_iter:

View file

@ -2,9 +2,6 @@
pendulum-ppo:
env: Pendulum-v0
run: PPO
trial_resources:
cpu: 1
extra_cpu: 4
config:
timesteps_per_batch: 2048
num_workers: 4

View file

@ -1,9 +1,6 @@
pong-a3c-pytorch-cnn:
env: PongDeterministic-v4
run: A3C
trial_resources:
cpu: 1
extra_cpu: 16
config:
num_workers: 16
batch_size: 20

View file

@ -1,9 +1,6 @@
pong-a3c:
env: PongDeterministic-v4
run: A3C
trial_resources:
cpu: 1
extra_cpu: 16
config:
num_workers: 16
batch_size: 20

View file

@ -4,11 +4,6 @@
pong-apex:
env: PongNoFrameskip-v4
run: APEX
trial_resources:
cpu: 1
gpu: 1
extra_cpu:
eval: 4 + spec.config.num_workers
config:
target_network_update_freq: 50000
num_workers: 32

View file

@ -8,10 +8,6 @@
pong-deterministic-ppo:
env: PongDeterministic-v4
run: PPO
trial_resources:
cpu: 1
gpu: 1
extra_cpu: 4
stop:
episode_reward_mean: 21
config:

View file

@ -4,8 +4,6 @@ cartpole-a3c:
stop:
episode_reward_mean: 200
time_total_s: 600
trial_resources:
cpu: 2
config:
num_workers: 4
gamma: 0.95

View file

@ -4,8 +4,6 @@ cartpole-dqn:
stop:
episode_reward_mean: 200
time_total_s: 600
trial_resources:
cpu: 1
config:
n_step: 3
gamma: 0.95

View file

@ -4,8 +4,6 @@ cartpole-es:
stop:
episode_reward_mean: 200
time_total_s: 300
trial_resources:
cpu: 2
config:
num_workers: 2
noise_size: 25000000

View file

@ -4,7 +4,5 @@ cartpole-ppo:
stop:
episode_reward_mean: 200
time_total_s: 300
trial_resources:
cpu: 1
config:
num_workers: 1

View file

@ -1,8 +1,4 @@
walker2d-v1-ppo:
env: Walker2d-v1
run: PPO
trial_resources:
cpu: 1
gpu: 4
extra_cpu: 64
config: {"kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": .0001, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 320000, "num_workers": 64}

View file

@ -11,6 +11,8 @@ from ray.tune.trial import Resources
def json_to_resources(data):
if data is None or data == "null":
return None
if type(data) is str:
data = json.loads(data)
for k in data:
@ -29,7 +31,7 @@ def json_to_resources(data):
def resources_to_json(resources):
if resources is None:
resources = Resources(cpu=1, gpu=0)
return None
return {
"cpu": resources.cpu,
"gpu": resources.gpu,
@ -70,19 +72,14 @@ def make_parser(**kwargs):
type=json.loads,
help="Algorithm-specific configuration (e.g. env, hyperparams), "
"specified in JSON.")
parser.add_argument(
"--resources",
help="Deprecated, use --trial-resources.",
type=lambda v: _tune_error("The `resources` argument is no longer "
"supported. Use `trial_resources` or "
"--trial-resources instead."))
parser.add_argument(
"--trial-resources",
default='{"cpu": 1}',
default=None,
type=json_to_resources,
help="Machine resources to allocate per trial, e.g. "
help="Override the machine resources to allocate per trial, e.g. "
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
"unless you specify them here.")
"unless you specify them here. For RLlib, you probably want to "
"leave this alone and use RLlib configs to control parallelism.")
parser.add_argument(
"--repeat",
default=1,
@ -115,7 +112,7 @@ def make_parser(**kwargs):
"--scheduler",
default="FIFO",
type=str,
help="FIFO (default), MedianStopping, AsyncHyperBand,"
help="FIFO (default), MedianStopping, AsyncHyperBand, "
"HyperBand, or HyperOpt.")
parser.add_argument(
"--scheduler-config",

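A short usage sketch of the parsing helpers after this change, assuming they live in ray.tune.config_parser:

from ray.tune.config_parser import json_to_resources, resources_to_json

# Strings are now accepted and parsed as JSON before conversion.
override = json_to_resources('{"cpu": 4, "gpu": 1}')
print(override.cpu, override.gpu)  # 4 1

# With no override, both helpers now pass None through, so the
# trainable's default_resource_request() takes effect instead.
assert json_to_resources(None) is None
assert resources_to_json(None) is None
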
View file

@ -86,10 +86,6 @@ if __name__ == "__main__":
"training_iteration": 2 if args.smoke_test else 99999
},
"repeat": 10,
"trial_resources": {
"cpu": 1,
"gpu": 0
},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,

View file

@ -51,10 +51,6 @@ if __name__ == "__main__":
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"trial_resources": {
"cpu": 4,
"gpu": 1
},
"config": {
"kl_coeff":
1.0,

View file

@ -12,7 +12,7 @@ from ray.rllib import _register_all
from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial, Resources
@ -23,7 +23,7 @@ from ray.tune.variant_generator import generate_trials, grid_search, \
class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init()
ray.init(num_cpus=4, num_gpus=0)
def tearDown(self):
ray.worker.cleanup()
@ -76,6 +76,46 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
self.assertRaises(TypeError, lambda: register_trainable("foo", A))
def testBuiltInTrainableResources(self):
class B(Trainable):
@classmethod
def default_resource_request(cls, config):
return Resources(cpu=config["cpu"], gpu=config["gpu"])
def _train(self):
return TrainingResult(timesteps_this_iter=1, done=True)
register_trainable("B", B)
def f(cpus, gpus, queue_trials):
return run_experiments(
{
"foo": {
"run": "B",
"config": {
"cpu": cpus,
"gpu": gpus,
},
}
},
queue_trials=queue_trials)[0]
# Should all succeed
self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
# Infeasible even with queueing enabled (no gpus)
self.assertRaises(TuneError, lambda: f(1, 1, True))
# Too large resource request
self.assertRaises(TuneError, lambda: f(100, 100, False))
self.assertRaises(TuneError, lambda: f(0, 100, False))
self.assertRaises(TuneError, lambda: f(100, 0, False))
# TODO(ekl) how can we test this is queued (hangs)?
# f(100, 0, True)
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
@ -357,6 +397,13 @@ class RunExperimentTest(unittest.TestCase):
class VariantGeneratorTest(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def testParseToTrials(self):
trials = generate_trials({
"run": "PPO",
@ -531,7 +578,7 @@ class TrialRunnerTest(unittest.TestCase):
def testTrialErrorOnStart(self):
ray.init()
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf")
trial = Trial("asdf", resources=Resources(1, 0))
try:
trial.start()
except Exception as e:

View file

@ -6,6 +6,7 @@ import random
import unittest
import numpy as np
import ray
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.async_hyperband import AsyncHyperBandScheduler
from ray.tune.pbt import PopulationBasedTraining, explore
@ -24,6 +25,13 @@ def result(t, rew):
class EarlyStoppingSuite(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def basicSetup(self, rule):
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
@ -184,6 +192,13 @@ class _MockTrialRunner():
class HyperbandSuite(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def schedulerSetup(self, num_trials):
"""Setup a scheduler and Runner with max Iter = 9
@ -538,6 +553,13 @@ class _MockTrial(Trial):
class PopulationBasedTestingSuite(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def basicSetup(self, resample_prob=0.0, explore=None):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
@ -751,6 +773,13 @@ class PopulationBasedTestingSuite(unittest.TestCase):
class AsyncHyperBandSuite(unittest.TestCase):
def setUp(self):
ray.init()
def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
def basicSetup(self, scheduler):
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5

View file

@ -17,6 +17,7 @@ import ray
from ray.tune import TuneError
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trial import Resources
class Trainable(object):
@ -90,6 +91,22 @@ class Trainable(object):
self._initialize_ok = True
self._local_ip = ray.services.get_node_ip_address()
@classmethod
def default_resource_request(cls, config):
"""Returns the resource requirement for the given configuration.
This can be overridden by subclasses to set the correct trial resource
allocation, so the user does not need to.
"""
return Resources(cpu=1, gpu=0)
@classmethod
def resource_help(cls, config):
"""Returns a help string for configuring this trainable's resources."""
return ""
def train(self):
"""Runs one logical iteration of training.

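For a custom trainable, the override point added here can be used like this (a minimal sketch; the num_workers key is made up for illustration):

from ray.tune import Trainable
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources

class MyTrainable(Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # Derive the trial footprint from the trainable's own config so
        # users do not have to repeat it under trial_resources.
        return Resources(cpu=1, gpu=0, extra_cpu=config.get("num_workers", 0))

    def _train(self):
        return TrainingResult(timesteps_this_iter=1, done=True)
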
View file

@ -82,7 +82,7 @@ class Trial(object):
config=None,
local_dir=DEFAULT_RESULTS_DIR,
experiment_tag="",
resources=Resources(cpu=1, gpu=0),
resources=None,
stopping_criterion=None,
checkpoint_freq=0,
restore_path=None,
@ -112,7 +112,9 @@ class Trial(object):
self.config = config or {}
self.local_dir = local_dir
self.experiment_tag = experiment_tag
self.resources = resources
self.resources = (
resources
or self._get_trainable_cls().default_resource_request(self.config))
self.stopping_criterion = stopping_criterion or {}
self.checkpoint_freq = checkpoint_freq
self.upload_dir = upload_dir
@ -350,11 +352,9 @@ class Trial(object):
def _setup_runner(self):
self.status = Trial.RUNNING
trainable_cls = ray.tune.registry.get_registry().get(
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
cls = ray.remote(
num_cpus=self.resources.cpu,
num_gpus=self.resources.gpu)(trainable_cls)
num_gpus=self.resources.gpu)(self._get_trainable_cls())
if not self.result_logger:
if not os.path.exists(self.local_dir):
os.makedirs(self.local_dir)
@ -380,6 +380,10 @@ class Trial(object):
registry=ray.tune.registry.get_registry(),
logger_creator=logger_creator)
def _get_trainable_cls(self):
return ray.tune.registry.get_registry().get(
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
def set_verbose(self, verbose):
self.verbose = verbose

View file

@ -42,7 +42,8 @@ class TrialRunner(object):
scheduler=None,
launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
verbose=True,
queue_trials=False):
"""Initializes a new TrialRunner.
Args:
@ -51,6 +52,10 @@ class TrialRunner(object):
server_port (int): Port number for launching TuneServer
verbose (bool): Flag for verbosity. If False, trial results
will not be output.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
"""
self._scheduler_alg = scheduler or FIFOScheduler()
@ -70,6 +75,7 @@ class TrialRunner(object):
self._server = TuneServer(self, server_port)
self._stop_queue = []
self._verbose = verbose
self._queue_trials = queue_trials
def is_finished(self):
"""Returns whether all trials have finished running."""
@ -102,9 +108,14 @@ class TrialRunner(object):
raise TuneError(
("Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster only has {} "
"available.").format(
"available. Pass `queue_trials=True` in "
"ray.tune.run_experiments() or on the command "
"line to queue trials until the cluster scales "
"up. {}").format(
trial.resources.summary_string(),
self._avail_resources.summary_string()))
self._avail_resources.summary_string(),
trial._get_trainable_cls().resource_help(
trial.config)))
elif trial.status == Trial.PAUSED:
raise TuneError(
"There are paused trials, but no more pending "
@ -177,9 +188,10 @@ class TrialRunner(object):
messages = ["== Status =="]
messages.append(self._scheduler_alg.debug_string())
if self._resources_initialized:
messages.append("Resources used: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu, self._avail_resources.cpu,
self._committed_resources.gpu, self._avail_resources.gpu))
messages.append(
"Resources requested: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu, self._avail_resources.cpu,
self._committed_resources.gpu, self._avail_resources.gpu))
return messages
def has_resources(self, resources):
@ -187,8 +199,29 @@ class TrialRunner(object):
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
return (resources.cpu_total() <= cpu_avail
and resources.gpu_total() <= gpu_avail)
have_space = (resources.cpu_total() <= cpu_avail
and resources.gpu_total() <= gpu_avail)
if have_space:
return True
can_overcommit = self._queue_trials
if ((resources.cpu_total() > 0 and cpu_avail <= 0)
or (resources.gpu_total() > 0 and gpu_avail <= 0)):
can_overcommit = False # requested resource is already saturated
if can_overcommit:
print("WARNING:tune:allowing trial to start even though the "
"cluster does not have enough free resources. Trial actors "
"may appear to hang until enough resources are added to the "
"cluster (e.g., via autoscaling). You can disable this "
"behavior by specifying `queue_trials=False` in "
"ray.tune.run_experiments().")
return True
return False
def _get_next_trial(self):
self._update_avail_resources()
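
Restated as a standalone function, the admission logic above behaves as follows (a simplified sketch, not the TrialRunner API; the last assert mirrors the "no GPUs at all" test case added in this commit):

def would_admit(req_cpu, req_gpu, cpu_avail, gpu_avail, queue_trials):
    # Simplified restatement of TrialRunner.has_resources() above.
    if req_cpu <= cpu_avail and req_gpu <= gpu_avail:
        return True   # fits right now
    if not queue_trials:
        return False  # queueing disabled: fail fast
    # Never overcommit a resource type that is already fully saturated.
    if (req_cpu > 0 and cpu_avail <= 0) or (req_gpu > 0 and gpu_avail <= 0):
        return False
    return True       # overcommit and wait for the cluster to scale up

assert would_admit(8, 0, cpu_avail=4, gpu_avail=0, queue_trials=True)        # queued
assert not would_admit(8, 0, cpu_avail=4, gpu_avail=0, queue_trials=False)   # queueing off
assert not would_admit(1, 1, cpu_avail=4, gpu_avail=0, queue_trials=True)    # no GPUs at all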

View file

@ -37,7 +37,8 @@ def run_experiments(experiments,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
verbose=True,
queue_trials=False):
"""Tunes experiments.
Args:
@ -49,6 +50,10 @@ def run_experiments(experiments,
using the Client API.
server_port (int): Port number for launching TuneServer.
verbose (bool): How much output should be printed for each trial.
queue_trials (bool): Whether to queue trials when the cluster does
not currently have enough resources to launch one. This should
be set to True when running on an autoscaling cluster to enable
automatic scale-up.
"""
if scheduler is None:
@ -58,7 +63,8 @@ def run_experiments(experiments,
scheduler,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose)
verbose=verbose,
queue_trials=queue_trials)
exp_list = experiments
if isinstance(experiments, Experiment):
exp_list = [experiments]

View file

@ -53,12 +53,16 @@ def generate_trials(unresolved_spec, output_path=''):
else:
experiment_tag = str(i)
i += 1
if "trial_resources" in spec:
resources = json_to_resources(spec["trial_resources"])
else:
resources = None
yield Trial(
trainable_name=spec["run"],
config=spec.get("config", {}),
local_dir=os.path.join(args.local_dir, output_path),
experiment_tag=experiment_tag,
resources=json_to_resources(spec.get("trial_resources", {})),
resources=resources,
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
restore_path=spec.get("restore"),

View file

@ -876,10 +876,12 @@ void spillback_tasks_handler(LocalSchedulerState *state) {
<< TaskSpec_actor_creation_id(spec) << " is taking a "
<< "while to be created. It is possible that the "
<< "cluster does not have enough resources to place this "
<< "actor. Try reducing the number of actors created or "
<< "actor (this may be normal while an autoscaling "
<< "cluster is scaling up). Consider reducing the number of "
<< "actors created, or "
<< "increasing the number of slots available by using "
<< "the --num-cpus, --num-gpus, and --resources flags. "
<< " The actor creation task is requesting ";
<< "The actor creation task is requesting ";
for (auto const &resource_pair :
TaskSpec_get_required_resources(spec)) {
error_message << resource_pair.second << " " << resource_pair.first

View file

@ -117,7 +117,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--env CartPole-v0 \
--run APEX \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "timesteps_per_iteration": 1000}'
--config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \