(tune-mnist-keras)=

# Using Keras & TensorFlow with Tune

```{image} /images/tf_keras_logo.jpeg
:align: center
:alt: Keras & TensorFlow Logo
:height: 120px
:target: https://keras.io
```

```{contents}
:backlinks: none
:local: true
```

## Example

import argparse
import os

from filelock import FileLock
from tensorflow.keras.datasets import mnist

import ray
from ray import air, tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.keras import TuneReportCallback


def train_mnist(config):
    # Import TensorFlow inside the trainable function; see
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 12

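    # Hold a file lock so that concurrent trials on the same machine
    # do not download the dataset at the same time.
    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0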
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(config["hidden"], activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.compile(
        loss="sparse_categorical_crossentropy",
        # `lr` is deprecated in Keras optimizers; use `learning_rate`.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]
        ),
        metrics=["accuracy"],
    )

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        callbacks=[TuneReportCallback({"mean_accuracy": "accuracy"})],
    )


def tune_mnist(num_training_iterations):
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )
    
    tuner = tune.Tuner(
        tune.with_resources(
            train_mnist,
            resources={"cpu": 2, "gpu": 0}
        ),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=10,
        ),
        run_config=air.RunConfig(
            name="exp",
            stop={"mean_accuracy": 0.99, "training_iteration": num_training_iterations},
        ),
        param_space={
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of the server to connect to if using Ray Client.",
    )
    args, _ = parser.parse_known_args()
    if args.smoke_test:
        ray.init(num_cpus=4)
    elif args.server_address:
        ray.init(f"ray://{args.server_address}")

    tune_mnist(num_training_iterations=5 if args.smoke_test else 300)
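
The scheduler in `tune_mnist` is configured with `grace_period=20` and `max_t=400`, while `train_mnist` trains for only 12 epochs and reports once per epoch. The following is a rough sketch of what that implies, not Tune's implementation: `asha_milestones` is a hypothetical helper, and it assumes a single bracket whose rungs are spaced geometrically with a reduction factor of 4.

```python
def asha_milestones(grace_period, max_t, reduction_factor=4):
    """Approximate the iteration counts at which ASHA compares trials.

    Assumption: one bracket, with rungs at
    grace_period * reduction_factor**k for every k that keeps the
    rung at or below max_t.
    """
    milestones = []
    rung = grace_period
    while rung <= max_t:
        milestones.append(rung)
        rung *= reduction_factor
    return milestones


# Values taken from the AsyncHyperBandScheduler call in the script above.
milestones = asha_milestones(grace_period=20, max_t=400)

# train_mnist runs 12 epochs and reports one result per epoch.
epochs = 12
reachable = [m for m in milestones if m <= epochs]

print("milestones:", milestones)
print("milestones reachable within", epochs, "epochs:", reachable)
```

Under these assumptions no trial reaches the first rung, so the scheduler has no opportunity to stop a trial early. Trials then end when training finishes or when a `stop` criterion in `RunConfig` is met.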


2022-07-22 16:16:58,114	INFO services.py:1483 -- View the Ray dashboard at http://127.0.0.1:8269


| Trial name | status | loc | hidden | lr | momentum | acc | iter | total time (s) |
|---|---|---|---|---|---|---|---|---|
| train_mnist_55a9b_00000 | TERMINATED | 127.0.0.1:51968 | 276 | 0.0406397 | 0.817788 | 0.98455 | 12 | 78.3252 |
| train_mnist_55a9b_00001 | TERMINATED | 127.0.0.1:51977 | 380 | 0.0873557 | 0.524634 | 0.983717 | 12 | 74.9888 |
| train_mnist_55a9b_00002 | TERMINATED | 127.0.0.1:51984 | 258 | 0.0951813 | 0.825499 | 0.990417 | 11 | 64.1272 |
| train_mnist_55a9b_00003 | TERMINATED | 127.0.0.1:51991 | 255 | 0.0971683 | 0.23161 | 0.977633 | 12 | 60.8475 |
| train_mnist_55a9b_00004 | TERMINATED | 127.0.0.1:52000 | 303 | 0.00440117 | 0.325439 | 0.90775 | 12 | 55.5722 |
| train_mnist_55a9b_00005 | TERMINATED | 127.0.0.1:52007 | 92 | 0.0651919 | 0.710183 | 0.974867 | 12 | 44.8092 |
| train_mnist_55a9b_00006 | TERMINATED | 127.0.0.1:52016 | 211 | 0.0731116 | 0.127751 | 0.97025 | 12 | 42.1217 |
| train_mnist_55a9b_00007 | TERMINATED | 127.0.0.1:52021 | 181 | 0.0362389 | 0.790345 | 0.979967 | 12 | 41.7632 |
| train_mnist_55a9b_00008 | TERMINATED | 127.0.0.1:52007 | 142 | 0.0323741 | 0.660418 | 0.969367 | 12 | 14.1527 |
| train_mnist_55a9b_00009 | TERMINATED | 127.0.0.1:51984 | 97 | 0.0244971 | 0.175045 | 0.9407 | 12 | 12.6405 |


2022-07-22 16:17:01,834	INFO plugin_schema_manager.py:52 -- Loading the default runtime env schemas: ['/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/working_dir_schema.json', '/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/pip_schema.json'].
[2m[36m(train_mnist pid=51968)[0m 2022-07-22 16:17:08.627419: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
[2m[36m(train_mnist pid=51968)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2m[36m(train_mnist pid=51968)[0m   "The `lr` argument is deprecated, use `learning_rate` instead.")
[2m[36m(train_mnist pid=51968)[0m 2022-07-22 16:17:08.947939: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are e

Result for train_mnist_55a9b_00000:
  date: 2022-07-22_16-17-10
  done: false
  experiment_id: 3659349c38c746cfb71b4db5eb9302a0
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  mean_accuracy: 0.8903833627700806
  node_ip: 127.0.0.1
  pid: 51968
  time_since_restore: 2.439258098602295
  time_this_iter_s: 2.439258098602295
  time_total_s: 2.439258098602295
  timestamp: 1658503030
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 55a9b_00000
  warmup_time: 0.003445863723754883
  
Result for train_mnist_55a9b_00004:
  date: 2022-07-22_16-17-33
  done: false
  experiment_id: 6eb62b7cb38f442a867a9094f0664701
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  mean_accuracy: 0.6376166939735413
  node_ip: 127.0.0.1
  pid: 52000
  time_since_restore: 2.4364511966705322
  time_this_iter_s: 2.4364511966705322
  time_total_s: 2.4364511966705322
  timestamp: 1658503053
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 55a9b_00004
  warm

[2m[36m(train_mnist pid=52021)[0m 2022-07-22 16:17:51.567914: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
[2m[36m(train_mnist pid=52021)[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2m[36m(train_mnist pid=52021)[0m   "The `lr` argument is deprecated, use `learning_rate` instead.")
[2m[36m(train_mnist pid=52021)[0m 2022-07-22 16:17:52.977183: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Result for train_mnist_55a9b_00005:
  date: 2022-07-22_16-17-54
  done: false
  experiment_id: 8dbd22e6caed4fe39351dffa3ef14eac
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 3
  mean_accuracy: 0.9490833282470703
  node_ip: 127.0.0.1
  pid: 52007
  time_since_restore: 17.22033405303955
  time_this_iter_s: 2.672102928161621
  time_total_s: 17.22033405303955
  timestamp: 1658503074
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: 55a9b_00005
  warmup_time: 0.005449056625366211
  
Result for train_mnist_55a9b_00006:
  date: 2022-07-22_16-17-54
  done: false
  experiment_id: 9594405e38084311a891b48addd13f75
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 3
  mean_accuracy: 0.9327999949455261
  node_ip: 127.0.0.1
  pid: 52016
  time_since_restore: 11.758372068405151
  time_this_iter_s: 3.0426323413848877
  time_total_s: 11.758372068405151
  timestamp: 1658503074
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: 55a9b_00006
  warm

Result for train_mnist_55a9b_00007:
  date: 2022-07-22_16-18-07
  done: false
  experiment_id: d9469b1fc58b41db88da5446dc2a3b23
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 3
  mean_accuracy: 0.951033353805542
  node_ip: 127.0.0.1
  pid: 52021
  time_since_restore: 18.96213722229004
  time_this_iter_s: 3.252371311187744
  time_total_s: 18.96213722229004
  timestamp: 1658503087
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: 55a9b_00007
  warmup_time: 0.0028028488159179688
  
Result for train_mnist_55a9b_00006:
  date: 2022-07-22_16-18-08
  done: false
  experiment_id: 9594405e38084311a891b48addd13f75
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 7
  mean_accuracy: 0.9584500193595886
  node_ip: 127.0.0.1
  pid: 52016
  time_since_restore: 25.336583852767944
  time_this_iter_s: 3.311979055404663
  time_total_s: 25.336583852767944
  timestamp: 1658503088
  timesteps_since_restore: 0
  training_iteration: 7
  trial_id: 55a9b_00006
  warmu

Result for train_mnist_55a9b_00005:
  date: 2022-07-22_16-18-21
  done: true
  experiment_id: 8dbd22e6caed4fe39351dffa3ef14eac
  experiment_tag: 5_hidden=92,lr=0.0652,momentum=0.7102
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 12
  mean_accuracy: 0.9748666882514954
  node_ip: 127.0.0.1
  pid: 52007
  time_since_restore: 44.80922222137451
  time_this_iter_s: 3.3184430599212646
  time_total_s: 44.80922222137451
  timestamp: 1658503101
  timesteps_since_restore: 0
  training_iteration: 12
  trial_id: 55a9b_00005
  warmup_time: 0.005449056625366211
  
Result for train_mnist_55a9b_00006:
  date: 2022-07-22_16-18-22
  done: false
  experiment_id: 9594405e38084311a891b48addd13f75
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 11
  mean_accuracy: 0.9679166674613953
  node_ip: 127.0.0.1
  pid: 52016
  time_since_restore: 39.08963179588318
  time_this_iter_s: 3.4860758781433105
  time_total_s: 39.08963179588318
  timestamp: 1658503102
  timesteps_since_restore:

Result for train_mnist_55a9b_00008:
  date: 2022-07-22_16-18-33
  done: false
  experiment_id: 8dbd22e6caed4fe39351dffa3ef14eac
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 8
  mean_accuracy: 0.9599000215530396
  node_ip: 127.0.0.1
  pid: 52007
  time_since_restore: 11.612935304641724
  time_this_iter_s: 0.6818761825561523
  time_total_s: 11.612935304641724
  timestamp: 1658503113
  timesteps_since_restore: 0
  training_iteration: 8
  trial_id: 55a9b_00008
  warmup_time: 0.005449056625366211
  
Result for train_mnist_55a9b_00009:
  date: 2022-07-22_16-18-34
  done: false
  experiment_id: c4f803baf65f4d4e9fd6abc85b2fd00c
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 9
  mean_accuracy: 0.9319833517074585
  node_ip: 127.0.0.1
  pid: 51984
  time_since_restore: 10.803268194198608
  time_this_iter_s: 0.606992244720459
  time_total_s: 10.803268194198608
  timestamp: 1658503114
  timesteps_since_restore: 0
  training_iteration: 9
  trial_id: 55a9b_00009
  wa

2022-07-22 16:18:36,803	INFO tune.py:738 -- Total run time: 95.98 seconds (95.03 seconds for the tuning loop).


Best hyperparameters found were:  {'threads': 2, 'lr': 0.09518133271957563, 'momentum': 0.8254987643140009, 'hidden': 258}
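
The printed configuration comes from `results.get_best_result()`, which selects a trial using the `metric="mean_accuracy"` and `mode="max"` settings in `TuneConfig`. As a plain-Python sketch of that selection (not Tune's implementation), the rows below are copied from the trial status table in the output above, and the `Trial` tuple is a hypothetical container used only for illustration.

```python
from collections import namedtuple

# Hypothetical container; the values are copied from the status table above.
Trial = namedtuple("Trial", ["trial_id", "hidden", "lr", "momentum", "acc"])

trials = [
    Trial("55a9b_00000", 276, 0.0406397, 0.817788, 0.98455),
    Trial("55a9b_00001", 380, 0.0873557, 0.524634, 0.983717),
    Trial("55a9b_00002", 258, 0.0951813, 0.825499, 0.990417),
    Trial("55a9b_00003", 255, 0.0971683, 0.23161, 0.977633),
    Trial("55a9b_00004", 303, 0.00440117, 0.325439, 0.90775),
    Trial("55a9b_00005", 92, 0.0651919, 0.710183, 0.974867),
    Trial("55a9b_00006", 211, 0.0731116, 0.127751, 0.97025),
    Trial("55a9b_00007", 181, 0.0362389, 0.790345, 0.979967),
    Trial("55a9b_00008", 142, 0.0323741, 0.660418, 0.969367),
    Trial("55a9b_00009", 97, 0.0244971, 0.175045, 0.9407),
]

# mode="max" on the accuracy column: keep the trial with the highest value.
best = max(trials, key=lambda t: t.acc)

print(best.trial_id, {"lr": best.lr, "momentum": best.momentum, "hidden": best.hidden})
```

The selected row is trial `55a9b_00002`, whose `hidden`, `lr`, and `momentum` values agree (up to rounding) with the best hyperparameters printed above. It is also the only trial that ended at iteration 11, which is consistent with it meeting the `mean_accuracy: 0.99` stop criterion.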


## More Keras and TensorFlow Examples

- {doc}`/tune/examples/includes/pbt_memnn_example`: Example of training a Memory NN on bAbI with Keras using PBT.
- {doc}`/tune/examples/includes/tf_mnist_example`: Converts the advanced TF 2.0 MNIST example to use Tune
  with the Trainable API. This example uses `tf.function`.
  Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced
- {doc}`/tune/examples/includes/pbt_tune_cifar10_with_keras`:
  A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.
