[ML/Train] TensorflowTrainer implementation (#23250)

Implements `TensorflowTrainer`. Depends on https://github.com/ray-project/ray/pull/23211 (review only files with `tensorflow` in the name).
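
For reviewers, a minimal usage sketch of the new API (argument names are taken from the constructor diff at the bottom of this page; the loop body is elided and purely illustrative):

from ray.ml.train.integrations.tensorflow import TensorflowTrainer

def train_loop_per_worker(config: dict):
    # Runs on each distributed worker; the TensorFlow backend is expected
    # to set up TF_CONFIG before this function is invoked (see the MNIST
    # example below).
    ...

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "batch_size": 32, "epochs": 4},
    scaling_config={"num_workers": 2, "use_gpu": False},
)
result = trainer.fit()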

Co-authored-by: Eric Liang <ekhliang@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Antoni Baum, 2022-03-17 19:34:47 +01:00, committed by GitHub
parent 8c9e3f6c2e
commit 1211c452d4
7 changed files with 503 additions and 1 deletion

python/ray/ml/BUILD

@@ -19,6 +19,26 @@ py_test(
deps = [":ml_lib"]
)
py_test(
name = "tensorflow_linear_dataset_example",
size = "medium",
main = "examples/tensorflow/tensorflow_linear_dataset_example.py",
srcs = ["examples/tensorflow/tensorflow_linear_dataset_example.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"],
args = ["--smoke-test"]
)
py_test(
name = "tensorflow_mnist_example",
size = "medium",
main = "examples/tensorflow/tensorflow_mnist_example.py",
srcs = ["examples/tensorflow/tensorflow_mnist_example.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"],
args = ["--smoke-test"]
)
py_test(
name = "torch_fashion_mnist_example",
size = "medium",
@@ -49,6 +69,16 @@ py_test(
args = ["--smoke-test"]
)
py_test(
name = "tune_tensorflow_mnist_example",
size = "medium",
main = "examples/tensorflow/tune_tensorflow_mnist_example.py",
srcs = ["examples/tensorflow/tune_tensorflow_mnist_example.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"],
args = ["--smoke-test"]
)
py_test(
name = "tune_torch_linear_dataset_example.py",
size = "medium",
@@ -129,6 +159,14 @@ py_test(
deps = [":ml_lib"]
)
py_test(
name = "test_tensorflow_trainer",
size = "medium",
srcs = ["tests/test_tensorflow_trainer.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)
py_test(
name = "test_torch_predictor",
size = "small",
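
The example targets above run with --smoke-test, so they finish quickly on a small local cluster. To run one directly (path inferred from the srcs attributes, assuming this BUILD file lives under python/ray/ml):

python python/ray/ml/examples/tensorflow/tensorflow_mnist_example.py --smoke-test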

python/ray/ml/examples/tensorflow/tensorflow_linear_dataset_example.py

@@ -0,0 +1,153 @@
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import ray
import ray.train as train
from ray.data import Dataset
from ray.train.tensorflow import prepare_dataset_shard
from ray.ml.checkpoint import Checkpoint
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.result import Result

class TrainCheckpointReportCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
train.save_checkpoint(**{"model": self.model.get_weights()})
train.report(**logs)

def get_dataset(a=5, b=10, size=1000) -> Dataset:
items = [i / size for i in range(size)]
dataset = ray.data.from_items([{"x": x, "y": a * x + b} for x in items])
return dataset

def build_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(1,)),
tf.keras.layers.Dense(10),
tf.keras.layers.Dense(1),
]
)
return model

def train_func(config: dict):
batch_size = config.get("batch_size", 64)
epochs = config.get("epochs", 3)
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_model()
multi_worker_model.compile(
optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
loss=tf.keras.losses.mean_squared_error,
metrics=[tf.keras.metrics.mean_squared_error],
)
dataset = train.get_dataset_shard("train")
results = []
for _ in range(epochs):
tf_dataset = prepare_dataset_shard(
dataset.to_tf(
label_column="y",
output_signature=(
tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                    tf.TensorSpec(shape=(None,), dtype=tf.float32),
),
batch_size=batch_size,
)
)
history = multi_worker_model.fit(
tf_dataset, callbacks=[TrainCheckpointReportCallback()]
)
results.append(history.history)
return results

def train_tensorflow_linear(num_workers: int = 2, use_gpu: bool = False) -> Result:
dataset_pipeline = get_dataset()
config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}
scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
scaling_config=scaling_config,
datasets={"train": dataset_pipeline},
)
results = trainer.fit()
print(results.metrics)
return results

def predict_linear(result: Result) -> Dataset:
    # Draw a fresh random input for each of the 10 rows.
    items = [{"x": np.random.uniform(0, 1)} for _ in range(10)]
prediction_dataset = ray.data.from_items(items)
checkpoint_object_ref = result.checkpoint.to_object_ref()
class TFScorer:
def __init__(self):
self.predictor = TensorflowPredictor.from_checkpoint(
Checkpoint.from_object_ref(checkpoint_object_ref),
model_definition=build_model,
)
def __call__(self, batch) -> pd.DataFrame:
return self.predictor.predict(batch, dtype=tf.float32)
predictions = prediction_dataset.map_batches(
TFScorer, compute="actors", batch_format="pandas"
)
pandas_predictions = predictions.to_pandas(float("inf"))
print(f"PREDICTIONS\n{pandas_predictions}")
return predictions

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.",
)
args, _ = parser.parse_known_args()
if args.smoke_test:
# 2 workers, 1 for trainer, 1 for datasets
num_gpus = args.num_workers if args.use_gpu else 0
ray.init(num_cpus=4, num_gpus=num_gpus)
result = train_tensorflow_linear(num_workers=2, use_gpu=args.use_gpu)
else:
ray.init(address=args.address)
result = train_tensorflow_linear(
num_workers=args.num_workers, use_gpu=args.use_gpu
)
predict_linear(result)
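
One review note on the inner loop above: each worker already receives a disjoint shard via train.get_dataset_shard, so prepare_dataset_shard has to stop tf.data from auto-sharding the dataset a second time. A sketch of the equivalent raw TensorFlow options, assuming that is all the helper does (the helper name here is illustrative):

import tensorflow as tf

def disable_tf_autosharding(tf_dataset: tf.data.Dataset) -> tf.data.Dataset:
    # The Ray Dataset shard is already split per worker, so TF's own
    # multi-worker auto-sharding must be switched off.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF
    )
    return tf_dataset.with_options(options)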

python/ray/ml/examples/tensorflow/tensorflow_mnist_example.py

@@ -0,0 +1,137 @@
# This example showcases how to use Tensorflow with Ray Train.
# Original code:
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
import argparse
import json
import os
import numpy as np
from ray.ml.result import Result
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import ray.train as train
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

class TrainCheckpointReportCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
train.save_checkpoint(**{"model": self.model.get_weights()})
train.report(**logs)

def mnist_dataset(batch_size: int) -> tf.data.Dataset:
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the [0, 255] range.
# You need to convert them to float32 with values in the [0, 1] range.
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(60000)
.repeat()
.batch(batch_size)
)
return train_dataset

def build_cnn_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[
tf.keras.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
]
)
return model

def train_func(config: dict):
per_worker_batch_size = config.get("batch_size", 64)
epochs = config.get("epochs", 3)
steps_per_epoch = config.get("steps_per_epoch", 70)
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])
strategy = tf.distribute.MultiWorkerMirroredStrategy()
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)
with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_cnn_model()
learning_rate = config.get("lr", 0.001)
multi_worker_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
metrics=["accuracy"],
)
history = multi_worker_model.fit(
multi_worker_dataset,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
callbacks=[TrainCheckpointReportCallback()],
)
results = history.history
return results

def train_tensorflow_mnist(
num_workers: int = 2, use_gpu: bool = False, epochs: int = 4
) -> Result:
config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
scaling_config=scaling_config,
)
results = trainer.fit()
return results

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
)
parser.add_argument(
"--epochs", type=int, default=3, help="Number of epochs to train for."
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.",
)
args, _ = parser.parse_known_args()
import ray
if args.smoke_test:
# 2 workers, 1 for trainer, 1 for datasets
num_gpus = args.num_workers if args.use_gpu else 0
ray.init(num_cpus=4, num_gpus=num_gpus)
train_tensorflow_mnist(num_workers=2, use_gpu=args.use_gpu)
else:
ray.init(address=args.address)
train_tensorflow_mnist(
num_workers=args.num_workers, use_gpu=args.use_gpu, epochs=args.epochs
)
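
The train_func above derives the world size from TF_CONFIG, the standard environment variable that MultiWorkerMirroredStrategy also consumes and that the Ray TensorFlow backend populates on each worker. Its shape for a two-worker run looks like this (hostnames and ports are made up):

import json
import os

# Illustrative TF_CONFIG as seen by worker 0; only the structure matters:
# {"cluster": {"worker": ["10.0.0.1:12345", "10.0.0.2:12345"]},
#  "task": {"type": "worker", "index": 0}}
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])  # 2 in this sketch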

python/ray/ml/examples/tensorflow/tune_tensorflow_mnist_example.py

@@ -0,0 +1,77 @@
import argparse
import ray
from ray import tune
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
from ray.ml.examples.tensorflow.tensorflow_mnist_example import train_func

def tune_tensorflow_mnist(
num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
):
scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
)
Trainable = trainer.as_trainable()
analysis = tune.run(
Trainable,
num_samples=num_samples,
config={
"train_loop_config": {
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
"epochs": 3,
}
},
)
    best_loss_config = analysis.get_best_config(metric="loss", mode="min")
    best_accuracy_config = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss_config}")
    print(f"Best accuracy config: {best_accuracy_config}")
return analysis

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.",
)
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="Sets number of workers for training.",
)
parser.add_argument(
"--num-samples",
type=int,
default=2,
help="Sets number of samples for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
)
args = parser.parse_args()
if args.smoke_test:
num_gpus = args.num_workers if args.use_gpu else 0
ray.init(num_cpus=8, num_gpus=num_gpus)
tune_tensorflow_mnist(num_workers=2, num_samples=2, use_gpu=args.use_gpu)
else:
ray.init(address=args.address)
tune_tensorflow_mnist(
num_workers=args.num_workers,
num_samples=args.num_samples,
use_gpu=args.use_gpu,
)
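
Note how the search space is nested: everything under "train_loop_config" is what each trial passes to train_func as its config dict, so the sampled values surface through the same config.get calls the MNIST example already uses:

def train_func(config: dict):
    lr = config.get("lr")  # sampled per trial from loguniform(1e-4, 1e-1)
    batch_size = config.get("batch_size")  # one of 32, 64, 128
    epochs = config.get("epochs")  # fixed at 3 in this search space
    ...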

python/ray/ml/tests/test_tensorflow_trainer.py

@@ -0,0 +1,85 @@
import pytest
import numpy as np
import ray
from ray import train
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
from ray.ml.examples.tensorflow.tensorflow_linear_dataset_example import (
train_func as tensorflow_linear_train_func,
get_dataset,
)
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.constants import MODEL_KEY, TRAIN_DATASET_KEY

@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()

def build_model():
import tensorflow as tf
model = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(1,)),
tf.keras.layers.Dense(1),
]
)
return model

@pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_linear(ray_start_4_cpus, num_workers):
def train_func(config):
result = tensorflow_linear_train_func(config)
assert len(result) == epochs
assert result[-1]["loss"] < result[0]["loss"]
epochs = 3
scaling_config = {"num_workers": num_workers}
config = {
"lr": 1e-3,
"batch_size": 32,
"epochs": epochs,
}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func,
train_loop_config=config,
scaling_config=scaling_config,
datasets={TRAIN_DATASET_KEY: get_dataset()},
)
trainer.fit()

def test_tensorflow_e2e(ray_start_4_cpus):
def train_func():
        weights = build_model().get_weights()
        train.save_checkpoint(**{MODEL_KEY: weights})
scaling_config = {"num_workers": 2}
trainer = TensorflowTrainer(
train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()
predictor = TensorflowPredictor.from_checkpoint(result.checkpoint, build_model)
predict_dataset = ray.data.range(3)
predictions = predict_dataset.map_batches(
        lambda batch: predictor.predict(batch, dtype=np.float32),
batch_format="pandas",
)
assert predictions.count() == 3

if __name__ == "__main__":
    import sys
sys.exit(pytest.main(["-v", "-x", __file__]))
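
test_tensorflow_e2e pins down the checkpoint contract end to end: the worker saves raw Keras weights under MODEL_KEY, and TensorflowPredictor.from_checkpoint rebuilds the architecture from the model definition and loads them. A sketch of that restore step, assuming the weights round-trip unchanged (restore_model is illustrative, not the predictor's actual code):

import tensorflow as tf

def restore_model(model_definition, weights) -> tf.keras.Model:
    # Rebuild the architecture, then load the weights the training worker
    # checkpointed via train.save_checkpoint(**{MODEL_KEY: ...}).
    model = model_definition()
    model.set_weights(weights)
    return model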

python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py

@@ -163,4 +163,16 @@ class TensorflowTrainer(DataParallelTrainer):
         preprocessor: Optional[Preprocessor] = None,
         resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
-        raise NotImplementedError
+        if not tensorflow_config:
+            tensorflow_config = TensorflowConfig()
+        super(TensorflowTrainer, self).__init__(
+            train_loop_per_worker=train_loop_per_worker,
+            train_loop_config=train_loop_config,
+            backend_config=tensorflow_config,
+            scaling_config=scaling_config,
+            run_config=run_config,
+            datasets=datasets,
+            preprocessor=preprocessor,
+            resume_from_checkpoint=resume_from_checkpoint,
+        )
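
The trainer itself stays thin: it defaults the backend config and defers everything else to DataParallelTrainer. Passing the config explicitly should be equivalent; the tensorflow_config keyword is implied by the __init__ body above, though TensorflowConfig's import path is not shown in this diff:

trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    tensorflow_config=TensorflowConfig(),  # same as leaving it unset
    scaling_config={"num_workers": 2},
)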