[tune] Fix Categorical Space + Add Keras Example (#2401)

Previously, categorical (``hp.choice``) variables were not properly resolved for HyperOpt.
Richard Liaw 2018-07-17 23:52:52 +02:00 committed by GitHub
parent e3badb9b09
commit 8e8c733696
7 changed files with 230 additions and 10 deletions


@@ -72,7 +72,7 @@ This PyTorch script runs a small grid search over the ``train_func`` function us
In order to report incremental progress, ``train_func`` periodically calls the ``reporter`` function passed in by Ray Tune to return the current timestep and other metrics as defined in `ray.tune.result.TrainingResult <https://github.com/ray-project/ray/blob/master/python/ray/tune/result.py>`__. Incremental results will be synced to local disk on the head node of the cluster.
``tune.run_experiments`` returns a list of Trial objects which you can inspect results of via ``trial.last_result``.
`tune.run_experiments <tune.html#ray.tune.run_experiments>`__ returns a list of Trial objects which you can inspect results of via ``trial.last_result``.
Learn more `about specifying experiments <tune-config.html>`__.
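A minimal sketch of this pattern (editorial, not taken from the repository; ``my_func``, ``my_experiment``, and the config values are illustrative assumptions):

.. code-block:: python

    import numpy as np
    import ray
    from ray import tune

    def my_func(config, reporter):
        accuracy = 0.0
        for step in range(100):
            accuracy += config["alpha"]  # stand-in for real training work
            # Report incremental progress; fields map onto TrainingResult.
            reporter(timesteps_total=step, mean_accuracy=accuracy)

    ray.init()
    tune.register_trainable("my_func", my_func)
    trials = tune.run_experiments({
        "my_experiment": {
            "run": "my_func",
            "stop": {"mean_accuracy": 0.98},
            "repeat": 2,
            "config": {"alpha": lambda spec: np.random.uniform(0.001, 0.01)},
        }
    })
    print(trials[0].last_result)  # each Trial keeps its most recent result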
@@ -160,6 +160,11 @@ In order to use this scheduler, you will need to install HyperOpt via the follow
An example of this can be found in `hyperopt_example.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/hyperopt_example.py>`__.
.. note::
The HyperOptScheduler expects an *increasing* metric in the reward attribute. If you are trying to
minimize a loss, be sure to report *mean_loss* from your function/class and to pass *reward_attr=neg_mean_loss* to the HyperOptScheduler initializer.
.. autoclass:: ray.tune.hpo_scheduler.HyperOptScheduler
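A minimal illustration of that note (an editorial sketch, not taken from the example file):

.. code-block:: python

    from ray.tune.hpo_scheduler import HyperOptScheduler

    def trainable(config, reporter):
        loss = (config["height"] - 14) ** 2
        # Report the loss as mean_loss; the scheduler then maximizes
        # neg_mean_loss, i.e. its negation.
        reporter(timesteps_total=1, mean_loss=loss)

    hpo_sched = HyperOptScheduler(reward_attr="neg_mean_loss")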
@@ -274,4 +279,4 @@ For an example notebook for using the Client API, see the `Client API Example <h
Examples
--------
You can find a list of examples `using Ray Tune and its various features here <https://github.com/ray-project/ray/tree/master/python/ray/tune/examples>`__.
You can find a list of examples `using Ray Tune and its various features here <https://github.com/ray-project/ray/tree/master/python/ray/tune/examples>`__, including examples using Keras, TensorFlow, and Population-Based Training.


@@ -5,6 +5,6 @@ FROM ray-project/deploy
# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
RUN conda install pytorch-cpu torchvision-cpu -c pytorch


@@ -1,3 +1,4 @@
# Ray Tune Examples
Ray Tune Examples
=================
Code examples for various schedulers and Ray Tune features.


@@ -10,10 +10,10 @@ from ray.tune.hpo_scheduler import HyperOptScheduler
def easy_objective(config, reporter):
import time
time.sleep(0.2)
assert type(config["activation"]) == str
reporter(
timesteps_total=1,
episode_reward_mean=-(
(config["height"] - 14)**2 + abs(config["width"] - 3)))
mean_loss=((config["height"] - 14)**2 + abs(config["width"] - 3)))
time.sleep(0.2)
@@ -32,6 +32,7 @@ if __name__ == '__main__':
space = {
'width': hp.uniform('width', 0, 20),
'height': hp.uniform('height', -100, 100),
'activation': hp.choice("activation", ["relu", "tanh"])
}
config = {
@@ -46,6 +47,6 @@ if __name__ == '__main__':
}
}
}
hpo_sched = HyperOptScheduler()
hpo_sched = HyperOptScheduler(reward_attr="neg_mean_loss")
run_experiments(config, verbose=False, scheduler=hpo_sched)
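(Editorial note, not part of the diff: the new ``assert`` passes because the scheduler now resolves ``hp.choice`` parameters to their chosen values before the config reaches ``easy_objective``. With made-up sample values, the suggested config changes roughly like this:)

    # before this commit -- hp.choice reported as an index:
    #   {'width': 7.3, 'height': -12.5, 'activation': 1}
    # after this commit -- the index is resolved to the option itself:
    #   {'width': 7.3, 'height': -12.5, 'activation': 'tanh'}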


@@ -0,0 +1,197 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import argparse
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import ray
from ray import tune
from ray.tune.async_hyperband import AsyncHyperBandScheduler
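# TuneCallback forwards Keras metrics to Tune's reporter: per-batch accuracy
# while training runs, plus a final report with done=1 when training ends.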
class TuneCallback(keras.callbacks.Callback):
def __init__(self, reporter, logs={}):
self.reporter = reporter
self.iteration = 0
def on_train_end(self, epoch, logs={}):
self.reporter(
timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"])
def on_batch_end(self, batch, logs={}):
self.iteration += 1
self.reporter(
timesteps_total=self.iteration, mean_accuracy=logs["acc"])
def train_mnist(args, cfg, reporter):
# We set threads here to avoid contention, as Keras
# is heavily parallelized across multiple cores.
K.set_session(
K.tf.Session(
config=K.tf.ConfigProto(
intra_op_parallelism_threads=args.threads,
inter_op_parallelism_threads=args.threads)))
vars(args).update(cfg)
batch_size = 128
num_classes = 10
epochs = 12
# input image dimensions
img_rows, img_cols = 28, 28
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = Sequential()
model.add(
Conv2D(
32,
kernel_size=(args.kernel1, args.kernel1),
activation='relu',
input_shape=input_shape))
model.add(Conv2D(64, (args.kernel2, args.kernel2), activation='relu'))
model.add(MaxPooling2D(pool_size=(args.poolsize, args.poolsize)))
model.add(Dropout(args.dropout1))
model.add(Flatten())
model.add(Dense(args.hidden, activation='relu'))
model.add(Dropout(args.dropout2))
model.add(Dense(num_classes, activation='softmax'))
model.compile(
loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum),
metrics=['accuracy'])
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
verbose=0,
validation_data=(x_test, y_test),
callbacks=[TuneCallback(reporter)])
def create_parser():
parser = argparse.ArgumentParser(description='Keras MNIST Example')
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
parser.add_argument(
'--jobs',
type=int,
default=1,
help='number of jobs to run concurrently (default: 1)')
parser.add_argument(
'--threads',
type=int,
default=None,
help='threads used in operations (default: all)')
parser.add_argument(
'--steps',
type=float,
default=0.01,
help='steps (unused in this example; default: 0.01)')
parser.add_argument(
'--lr',
type=float,
default=0.01,
metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument(
'--momentum',
type=float,
default=0.5,
metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument(
'--kernel1',
type=int,
default=3,
help='Size of first kernel (default: 3)')
parser.add_argument(
'--kernel2',
type=int,
default=3,
help='Size of second kernel (default: 3)')
parser.add_argument(
'--poolsize', type=int, default=2, help='Size of Pooling (default: 2)')
parser.add_argument(
'--dropout1',
type=float,
default=0.25,
help='First dropout rate (default: 0.25)')
parser.add_argument(
'--hidden',
type=int,
default=128,
help='Size of Hidden Layer (default: 128)')
parser.add_argument(
'--dropout2',
type=float,
default=0.5,
help='Second dropout rate (default: 0.5)')
return parser
if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
mnist.load_data()  # pre-download the dataset here; load_data() is not threadsafe when trials call it concurrently
ray.init()
sched = AsyncHyperBandScheduler(
time_attr="timesteps_total",
reward_attr="mean_accuracy",
max_t=400,
grace_period=20)
tune.register_trainable("train_mnist",
lambda cfg, rprtr: train_mnist(args, cfg, rprtr))
tune.run_experiments(
{
"exp": {
"stop": {
"mean_accuracy": 0.99,
"timesteps_total": 10 if args.smoke_test else 300
},
"run": "train_mnist",
"repeat": 1 if args.smoke_test else 10,
"config": {
"lr": lambda spec: np.random.uniform(0.001, 0.1),
"momentum": lambda spec: np.random.uniform(0.1, 0.9),
"hidden": lambda spec: np.random.randint(32, 512),
"dropout1": lambda spec: np.random.uniform(0.2, 0.8),
}
}
},
verbose=0,
scheduler=sched)
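A side note on the structure above (editorial, not from the diff): ``register_trainable`` wraps ``train_mnist`` in a lambda so that the config Tune samples for each trial is merged over the argparse defaults via ``vars(args).update(cfg)``. A self-contained sketch of that merge, with hypothetical values:

    import argparse

    args = argparse.Namespace(lr=0.01, hidden=128, dropout1=0.25)  # defaults
    cfg = {"lr": 0.05, "hidden": 256}  # values Tune sampled for this trial
    vars(args).update(cfg)
    print(args.lr, args.hidden)  # -> 0.05 256; sampled values take precedence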


@@ -32,8 +32,8 @@ class HyperOptScheduler(FIFOScheduler):
are available.
reward_attr (str): The TrainingResult objective value attribute.
This refers to an increasing value, which is internally negated
when interacting with HyperOpt. Suggestion procedures
will use this attribute.
when interacting with HyperOpt so that HyperOpt can "maximize"
this value.
Examples:
>>> space = {'param': hp.uniform('param', 0, 20)}
@@ -108,7 +108,19 @@ class HyperOptScheduler(FIFOScheduler):
self._hpopt_trials.refresh()
new_trial = new_trials[0]
new_trial_id = new_trial["tid"]
suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
# Taken from HyperOpt.base.evaluate
config = hpo.base.spec_from_misc(new_trial["misc"])
ctrl = hpo.base.Ctrl(self._hpopt_trials, current_trial=new_trial)
memo = self.domain.memo_from_config(config)
hpo.utils.use_obj_for_literal_in_memo(self.domain.expr, ctrl,
hpo.base.Ctrl, memo)
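# rec_eval re-evaluates the search space expression so that hp.choice
# parameters resolve to the chosen option values rather than raw indices.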
suggested_config = hpo.pyll.rec_eval(
self.domain.expr,
memo=memo,
print_node_on_error=self.domain.rec_eval_print_node_on_error)
new_cfg.update(suggested_config)
kv_str = "_".join([
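A minimal standalone sketch (editorial; it uses plain HyperOpt rather than this scheduler) of why the extra evaluation step matters: HyperOpt's trial record stores ``hp.choice`` parameters as the index of the chosen option, so ``spec_from_misc`` alone would hand Tune an index instead of the value.

    from hyperopt import Trials, fmin, hp, space_eval, tpe

    space = {
        "width": hp.uniform("width", 0, 20),
        "activation": hp.choice("activation", ["relu", "tanh"]),
    }

    trials = Trials()
    fmin(lambda cfg: 0.0, space, algo=tpe.suggest, max_evals=1, trials=trials)

    raw = trials.trials[0]["misc"]["vals"]  # what spec_from_misc works from
    print(raw["activation"])                # e.g. [1] -- an index, not "tanh"
    point = {k: v[0] for k, v in raw.items()}
    print(space_eval(space, point))         # index resolved to the actual value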


@@ -240,6 +240,10 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/hyperopt_example.py \
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/tune/examples/tune_mnist_keras.py \
--smoke-test
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py