{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b05af3b",
   "metadata": {},
   "source": [
    "(tune-mnist-keras)=\n",
    "\n",
    "# Using Keras & TensorFlow with Tune\n",
    "\n",
    "```{image} /images/tf_keras_logo.jpeg\n",
    ":align: center\n",
    ":alt: Keras & TensorFlow Logo\n",
    ":height: 120px\n",
    ":target: https://keras.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
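  {
   "cell_type": "markdown",
   "id": "example-overview",
   "metadata": {},
   "source": [
    "The example below tunes a small fully connected network on MNIST. `train_mnist` builds and fits the Keras model,\n",
    "reporting its accuracy back to Tune after every epoch through `TuneReportCallback`, while `tune_mnist` defines the\n",
    "search space over the learning rate, momentum, and hidden layer size and uses `AsyncHyperBandScheduler` to stop\n",
    "underperforming trials early."
   ]
  },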
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e3c389",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "from filelock import FileLock\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
    "from ray.tune.integration.keras import TuneReportCallback\n",
    "\n",
    "\n",
    "def train_mnist(config):\n",
    "    # https://github.com/tensorflow/tensorflow/issues/32159\n",
    "    import tensorflow as tf\n",
    "\n",
    "    batch_size = 128\n",
    "    num_classes = 10\n",
    "    epochs = 12\n",
    "\n",
    "    # The FileLock prevents concurrent trials from downloading MNIST at the same time.\n",
    "    with FileLock(os.path.expanduser(\"~/.data.lock\")):\n",
    "        (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "\n",
    "    model = tf.keras.models.Sequential(\n",
    "        [\n",
    "            tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "            tf.keras.layers.Dense(config[\"hidden\"], activation=\"relu\"),\n",
    "            tf.keras.layers.Dropout(0.2),\n",
    "            tf.keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    model.compile(\n",
    "        loss=\"sparse_categorical_crossentropy\",\n",
    "        optimizer=tf.keras.optimizers.SGD(\n",
    "            learning_rate=config[\"lr\"], momentum=config[\"momentum\"]\n",
    "        ),\n",
    "        metrics=[\"accuracy\"],\n",
    "    )\n",
    "\n",
    "    model.fit(\n",
    "        x_train,\n",
    "        y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        verbose=0,\n",
    "        validation_data=(x_test, y_test),\n",
    "        # Report the Keras \"accuracy\" metric back to Tune as \"mean_accuracy\" after each epoch.\n",
    "        callbacks=[TuneReportCallback({\"mean_accuracy\": \"accuracy\"})],\n",
    "    )\n",
    "\n",
    "\n",
    "def tune_mnist(num_training_iterations):\n",
    "    # Asynchronous HyperBand stops underperforming trials early.\n",
    "    sched = AsyncHyperBandScheduler(\n",
    "        time_attr=\"training_iteration\", max_t=400, grace_period=20\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        train_mnist,\n",
    "        name=\"exp\",\n",
    "        scheduler=sched,\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        stop={\"mean_accuracy\": 0.99, \"training_iteration\": num_training_iterations},\n",
    "        num_samples=10,\n",
    "        resources_per_trial={\"cpu\": 2, \"gpu\": 0},\n",
    "        config={\n",
    "            \"threads\": 2,\n",
    "            \"lr\": tune.uniform(0.001, 0.1),\n",
    "            \"momentum\": tune.uniform(0.1, 0.9),\n",
    "            \"hidden\": tune.randint(32, 512),\n",
    "        },\n",
    "    )\n",
    "    print(\"Best hyperparameters found were: \", analysis.best_config)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of the server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    args, _ = parser.parse_known_args()\n",
    "    if args.smoke_test:\n",
    "        ray.init(num_cpus=4)\n",
    "    elif args.server_address:\n",
    "        ray.init(f\"ray://{args.server_address}\")\n",
    "\n",
    "    tune_mnist(num_training_iterations=5 if args.smoke_test else 300)\n"
   ]
  },
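  {
   "cell_type": "markdown",
   "id": "smoke-test-note",
   "metadata": {},
   "source": [
    "If you would rather run a shorter session interactively, a minimal sketch (mirroring the `--smoke-test` path\n",
    "above, and assuming Ray and TensorFlow are installed locally) is to start a small local Ray instance and call\n",
    "`tune_mnist` with a low iteration budget:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "smoke-test-run",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick interactive run, mirroring the --smoke-test path above.\n",
    "# ignore_reinit_error lets this cell be re-run if Ray is already initialized.\n",
    "ray.init(num_cpus=4, ignore_reinit_error=True)\n",
    "tune_mnist(num_training_iterations=5)\n"
   ]
  },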
  {
   "cell_type": "markdown",
   "id": "d7e46189",
   "metadata": {},
   "source": [
    "## More Keras and TensorFlow Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/pbt_memnn_example`: Example of training a memory network (MemNN) on bAbI with Keras using PBT.\n",
    "- {doc}`/tune/examples/includes/tf_mnist_example`: Converts the advanced TF 2.0 MNIST example to use Tune\n",
    "  with the Trainable API. This example uses `tf.function`.\n",
    "  Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced\n",
    "- {doc}`/tune/examples/includes/pbt_tune_cifar10_with_keras`:\n",
    "  A contributed example of tuning a Keras model on CIFAR-10 with the `PopulationBasedTraining` scheduler.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}