mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] tf2.0 mnist example (#5898)
* tfmnistexample * tfmnist * add_to_ci * format * exampledownlaod * fix
This commit is contained in:
parent
6843a01a7f
commit
9f23620412
5 changed files with 139 additions and 0 deletions
|
@ -64,6 +64,9 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
|
|||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
bash -c 'pip install tensorflow==1.15.0rc1 && python /ray/python/ray/tune/examples/async_hyperband_example.py --smoke-test'
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/tf_mnist_example.py --smoke-test
|
||||
|
||||
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
|
||||
python /ray/python/ray/tune/examples/lightgbm_example.py
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ Tensorflow/Keras Examples
|
|||
|
||||
- `tune_mnist_keras <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tune_mnist_keras.py>`__: Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune.
|
||||
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
|
||||
- `Tensorflow 2 Example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tf_mnist_example.py>`__: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced
|
||||
|
||||
|
||||
PyTorch Examples
|
||||
|
|
|
@ -30,6 +30,7 @@ Tensorflow/Keras Examples
|
|||
|
||||
- `tune_mnist_keras <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tune_mnist_keras.py>`__: Converts the Keras MNIST example to use Tune with the function-based API and a Keras callback. Also shows how to easily convert something relying on argparse to use Tune.
|
||||
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
|
||||
- `Tensorflow 2 Example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tf_mnist_example.py>`__: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced
|
||||
|
||||
|
||||
PyTorch Examples
|
||||
|
|
128
python/ray/tune/examples/tf_mnist_example.py
Normal file
128
python/ray/tune/examples/tf_mnist_example.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
#
|
||||
# This example showcases how to use TF2.0 APIs with Tune.
|
||||
# Original code: https://www.tensorflow.org/tutorials/quickstart/advanced
|
||||
#
|
||||
# As of 10/12/2019: One caveat of using TF2.0 is that TF AutoGraph
|
||||
# functionality does not interact nicely with Ray actors. One way to get around
|
||||
# this is to `import tensorflow` inside the Tune Trainable.
|
||||
#
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import argparse
|
||||
from tensorflow.keras.layers import Dense, Flatten, Conv2D
|
||||
from tensorflow.keras import Model
|
||||
from tensorflow.keras.datasets.mnist import load_data
|
||||
|
||||
from ray import tune
|
||||
|
||||
MAX_TRAIN_BATCH = 10
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
class MyModel(Model):
|
||||
def __init__(self, hiddens=128):
|
||||
super(MyModel, self).__init__()
|
||||
self.conv1 = Conv2D(32, 3, activation="relu")
|
||||
self.flatten = Flatten()
|
||||
self.d1 = Dense(hiddens, activation="relu")
|
||||
self.d2 = Dense(10, activation="softmax")
|
||||
|
||||
def call(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.flatten(x)
|
||||
x = self.d1(x)
|
||||
return self.d2(x)
|
||||
|
||||
|
||||
class MNISTTrainable(tune.Trainable):
|
||||
def _setup(self, config):
|
||||
# IMPORTANT: See the above note.
|
||||
import tensorflow as tf
|
||||
(x_train, y_train), (x_test, y_test) = load_data()
|
||||
x_train, x_test = x_train / 255.0, x_test / 255.0
|
||||
|
||||
# Add a channels dimension
|
||||
x_train = x_train[..., tf.newaxis]
|
||||
x_test = x_test[..., tf.newaxis]
|
||||
self.train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||
self.train_ds = self.train_ds.shuffle(10000).batch(
|
||||
config.get("batch", 32))
|
||||
|
||||
self.test_ds = tf.data.Dataset.from_tensor_slices((x_test,
|
||||
y_test)).batch(32)
|
||||
|
||||
self.model = MyModel(hiddens=config.get("hiddens", 128))
|
||||
self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.Adam()
|
||||
self.train_loss = tf.keras.metrics.Mean(name="train_loss")
|
||||
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
name="train_accuracy")
|
||||
|
||||
self.test_loss = tf.keras.metrics.Mean(name="test_loss")
|
||||
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
|
||||
name="test_accuracy")
|
||||
|
||||
@tf.function
|
||||
def train_step(images, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
predictions = self.model(images)
|
||||
loss = self.loss_object(labels, predictions)
|
||||
gradients = tape.gradient(loss, self.model.trainable_variables)
|
||||
self.optimizer.apply_gradients(
|
||||
zip(gradients, self.model.trainable_variables))
|
||||
|
||||
self.train_loss(loss)
|
||||
self.train_accuracy(labels, predictions)
|
||||
|
||||
@tf.function
|
||||
def test_step(images, labels):
|
||||
predictions = self.model(images)
|
||||
t_loss = self.loss_object(labels, predictions)
|
||||
|
||||
self.test_loss(t_loss)
|
||||
self.test_accuracy(labels, predictions)
|
||||
|
||||
self.tf_train_step = train_step
|
||||
self.tf_test_step = test_step
|
||||
|
||||
def _train(self):
|
||||
self.train_loss.reset_states()
|
||||
self.train_accuracy.reset_states()
|
||||
self.test_loss.reset_states()
|
||||
self.test_accuracy.reset_states()
|
||||
|
||||
for idx, (images, labels) in enumerate(self.train_ds):
|
||||
if idx > MAX_TRAIN_BATCH: # This is optional and can be removed.
|
||||
break
|
||||
self.tf_train_step(images, labels)
|
||||
|
||||
for test_images, test_labels in self.test_ds:
|
||||
self.tf_test_step(test_images, test_labels)
|
||||
|
||||
# It is important to return tf.Tensors as numpy objects.
|
||||
return {
|
||||
"epoch": self.iteration,
|
||||
"loss": self.train_loss.result().numpy(),
|
||||
"accuracy": self.train_accuracy.result().numpy() * 100,
|
||||
"test_loss": self.test_loss.result().numpy(),
|
||||
"mean_accuracy": self.test_accuracy.result().numpy() * 100
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_data() # we download data on the driver to avoid race conditions.
|
||||
tune.run(
|
||||
MNISTTrainable,
|
||||
stop={"training_iteration": 5 if args.smoke_test else 50},
|
||||
verbose=1,
|
||||
config={"hiddens": tune.grid_search([32, 64, 128])})
|
|
@ -112,6 +112,12 @@ class PopulationBasedTraining(FIFOScheduler):
|
|||
they will be time-multiplexed as to balance training progress across the
|
||||
population. To run multiple trials, use `tune.run(num_samples=<int>)`.
|
||||
|
||||
In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in
|
||||
`pbt_global.txt` and individual policy perturbations are recorded
|
||||
in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag,
|
||||
target trial iteration, clone trial iteration, old config, new config]
|
||||
on each perturbation step.
|
||||
|
||||
Args:
|
||||
time_attr (str): The training result attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
|
|
Loading…
Add table
Reference in a new issue