[tune] tf2.0 mnist example (#5898)
* tfmnistexample
* tfmnist
* add_to_ci
* format
* exampledownload
* fix
parent 6843a01a7f
commit 9f23620412
5 changed files with 139 additions and 0 deletions
@@ -64,6 +64,9 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA
 $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     bash -c 'pip install tensorflow==1.15.0rc1 && python /ray/python/ray/tune/examples/async_hyperband_example.py --smoke-test'

+$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/tf_mnist_example.py --smoke-test
+
 $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     python /ray/python/ray/tune/examples/lightgbm_example.py
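For context, the `--smoke-test` flag passed by the new CI line is consumed inside the example itself. A minimal sketch of the gating pattern, mirroring the argparse handling in the file added below:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    # Under --smoke-test (as in CI), the run stops after 5 training
    # iterations instead of 50.
    stop_iters = 5 if args.smoke_test else 50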
@@ -30,6 +30,7 @@ Tensorflow/Keras Examples

 - `tune_mnist_keras <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tune_mnist_keras.py>`__: Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune.
 - `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
+- `Tensorflow 2 Example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tf_mnist_example.py>`__: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced


 PyTorch Examples
@@ -30,6 +30,7 @@ Tensorflow/Keras Examples

 - `tune_mnist_keras <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tune_mnist_keras.py>`__: Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune.
 - `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
+- `Tensorflow 2 Example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tf_mnist_example.py>`__: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced


 PyTorch Examples
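The entries above contrast Tune's function-based API (used by tune_mnist_keras) with the class-based Trainable that the new example uses. For orientation, a minimal, hypothetical function-based trainable in the reporter style of this Tune release; the metric values and loop are illustrative only:

    from ray import tune

    def train_fn(config, reporter):
        # Placeholder loop; a real trainable would fit a model here
        # and report its actual validation metric each step.
        for step in range(10):
            reporter(mean_accuracy=step / 10.0)

    tune.run(train_fn, config={"lr": tune.grid_search([0.01, 0.001])})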
python/ray/tune/examples/tf_mnist_example.py (new file, 128 lines)
@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf-8
#
# This example showcases how to use TF2.0 APIs with Tune.
# Original code: https://www.tensorflow.org/tutorials/quickstart/advanced
#
# As of 10/12/2019: One caveat of using TF2.0 is that TF AutoGraph
# functionality does not interact nicely with Ray actors. One way to get around
# this is to `import tensorflow` inside the Tune Trainable.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
from tensorflow.keras.datasets.mnist import load_data

from ray import tune

MAX_TRAIN_BATCH = 10

parser = argparse.ArgumentParser()
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()


class MyModel(Model):
    def __init__(self, hiddens=128):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation="relu")
        self.flatten = Flatten()
        self.d1 = Dense(hiddens, activation="relu")
        self.d2 = Dense(10, activation="softmax")

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


class MNISTTrainable(tune.Trainable):
    def _setup(self, config):
        # IMPORTANT: See the above note.
        import tensorflow as tf
        (x_train, y_train), (x_test, y_test) = load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0

        # Add a channels dimension
        x_train = x_train[..., tf.newaxis]
        x_test = x_test[..., tf.newaxis]
        self.train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        self.train_ds = self.train_ds.shuffle(10000).batch(
            config.get("batch", 32))

        self.test_ds = tf.data.Dataset.from_tensor_slices((x_test,
                                                           y_test)).batch(32)

        self.model = MyModel(hiddens=config.get("hiddens", 128))
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam()
        self.train_loss = tf.keras.metrics.Mean(name="train_loss")
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="train_accuracy")

        self.test_loss = tf.keras.metrics.Mean(name="test_loss")
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="test_accuracy")

        @tf.function
        def train_step(images, labels):
            with tf.GradientTape() as tape:
                predictions = self.model(images)
                loss = self.loss_object(labels, predictions)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables))

            self.train_loss(loss)
            self.train_accuracy(labels, predictions)

        @tf.function
        def test_step(images, labels):
            predictions = self.model(images)
            t_loss = self.loss_object(labels, predictions)

            self.test_loss(t_loss)
            self.test_accuracy(labels, predictions)

        self.tf_train_step = train_step
        self.tf_test_step = test_step

    def _train(self):
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()
        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

        for idx, (images, labels) in enumerate(self.train_ds):
            if idx > MAX_TRAIN_BATCH:  # This is optional and can be removed.
                break
            self.tf_train_step(images, labels)

        for test_images, test_labels in self.test_ds:
            self.tf_test_step(test_images, test_labels)

        # It is important to return tf.Tensors as numpy objects.
        return {
            "epoch": self.iteration,
            "loss": self.train_loss.result().numpy(),
            "accuracy": self.train_accuracy.result().numpy() * 100,
            "test_loss": self.test_loss.result().numpy(),
            "mean_accuracy": self.test_accuracy.result().numpy() * 100
        }


if __name__ == "__main__":
    load_data()  # we download data on the driver to avoid race conditions.
    tune.run(
        MNISTTrainable,
        stop={"training_iteration": 5 if args.smoke_test else 50},
        verbose=1,
        config={"hiddens": tune.grid_search([32, 64, 128])})
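As a usage note, the Trainable above composes with the rest of tune.run's arguments. A small, hypothetical variation that also searches over the batch size read by config.get("batch", 32) in _setup, and repeats each grid point:

    from ray import tune

    tune.run(
        MNISTTrainable,
        stop={"training_iteration": 5},
        num_samples=2,  # repeat each grid point twice
        config={
            "hiddens": tune.grid_search([64, 128]),
            "batch": tune.grid_search([32, 64]),
        })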
@@ -112,6 +112,12 @@ class PopulationBasedTraining(FIFOScheduler):
     they will be time-multiplexed as to balance training progress across the
     population. To run multiple trials, use `tune.run(num_samples=<int>)`.
+
+    In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in
+    `pbt_global.txt` and individual policy perturbations are recorded
+    in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag,
+    target trial iteration, clone trial iteration, old config, new config]
+    on each perturbation step.

     Args:
         time_attr (str): The training result attr to use for comparing time.
             Note that you can pass in something non-temporal such as
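The docstring addition above describes where PBT records its mutations. For orientation, a minimal, hypothetical PBT setup whose perturbations would produce those pbt_global.txt / pbt_policy_{i}.txt files; the trainable, metric, and mutation values are illustrative only:

    from ray import tune
    from ray.tune.schedulers import PopulationBasedTraining

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            # Candidate values to resample or perturb between.
            "lr": [1e-2, 1e-3, 1e-4],
        })

    # num_samples sets the population size, per the docstring above.
    tune.run(
        MyTrainable,  # hypothetical tune.Trainable subclass
        scheduler=pbt,
        num_samples=4,
        config={"lr": 1e-3})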