mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[ML/Train] TensorflowTrainer implementation (#23250)

Implements `TensorflowTrainer`. Depends on https://github.com/ray-project/ray/pull/23211 (review only files with `tensorflow` in the name).

Co-authored-by: Eric Liang <ekhliang@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
parent 8c9e3f6c2e
commit 1211c452d4

7 changed files with 503 additions and 1 deletion
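
For orientation, here is a minimal end-to-end sketch of the API this commit adds, assembled from the examples and tests below. It assumes a Ray build with the `ray.ml` namespace as of this PR (these modules were later reorganized), and the loop body is a placeholder:

    import ray
    import ray.train as train
    from ray.ml.train.integrations.tensorflow import TensorflowTrainer

    def train_loop_per_worker(config: dict):
        # Runs on each of the `num_workers` Ray Train workers; metrics
        # reported here surface in result.metrics on the driver.
        train.report(loss=0.0)

    ray.init(num_cpus=4)
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"lr": 1e-3},
        scaling_config={"num_workers": 2, "use_gpu": False},
    )
    result = trainer.fit()
    print(result.metrics)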
python/ray/ml/BUILD

@@ -19,6 +19,26 @@ py_test(
     deps = [":ml_lib"]
 )
 
+py_test(
+    name = "tensorflow_linear_dataset_example",
+    size = "medium",
+    main = "examples/tensorflow/tensorflow_linear_dataset_example.py",
+    srcs = ["examples/tensorflow/tensorflow_linear_dataset_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"],
+    args = ["--smoke-test"]
+)
+
+py_test(
+    name = "tensorflow_mnist_example",
+    size = "medium",
+    main = "examples/tensorflow/tensorflow_mnist_example.py",
+    srcs = ["examples/tensorflow/tensorflow_mnist_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"],
+    args = ["--smoke-test"]
+)
+
 py_test(
     name = "torch_fashion_mnist_example",
     size = "medium",
@@ -49,6 +69,16 @@ py_test(
     args = ["--smoke-test"]
 )
 
+py_test(
+    name = "tune_tensorflow_mnist_example",
+    size = "medium",
+    main = "examples/tensorflow/tune_tensorflow_mnist_example.py",
+    srcs = ["examples/tensorflow/tune_tensorflow_mnist_example.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"],
+    args = ["--smoke-test"]
+)
+
 py_test(
     name = "tune_torch_linear_dataset_example.py",
     size = "medium",
@@ -129,6 +159,14 @@ py_test(
     deps = [":ml_lib"]
 )
 
+py_test(
+    name = "test_tensorflow_trainer",
+    size = "medium",
+    srcs = ["tests/test_tensorflow_trainer.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "test_torch_predictor",
     size = "small",
0  python/ray/ml/examples/tensorflow/__init__.py  Normal file

153  python/ray/ml/examples/tensorflow/tensorflow_linear_dataset_example.py  Normal file
@@ -0,0 +1,153 @@
import argparse
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow.keras.callbacks import Callback

import ray
import ray.train as train
from ray.data import Dataset
from ray.train.tensorflow import prepare_dataset_shard
from ray.ml.checkpoint import Checkpoint
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.result import Result


class TrainCheckpointReportCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        train.save_checkpoint(**{"model": self.model.get_weights()})
        train.report(**logs)


def get_dataset(a=5, b=10, size=1000) -> Dataset:
    items = [i / size for i in range(size)]
    dataset = ray.data.from_items([{"x": x, "y": a * x + b} for x in items])
    return dataset


def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    return model


def train_func(config: dict):
    batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    dataset = train.get_dataset_shard("train")

    results = []
    for _ in range(epochs):
        tf_dataset = prepare_dataset_shard(
            dataset.to_tf(
                label_column="y",
                output_signature=(
                    tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                    tf.TensorSpec(shape=(None), dtype=tf.float32),
                ),
                batch_size=batch_size,
            )
        )
        history = multi_worker_model.fit(
            tf_dataset, callbacks=[TrainCheckpointReportCallback()]
        )
        results.append(history.history)
    return results


def train_tensorflow_linear(num_workers: int = 2, use_gpu: bool = False) -> Result:
    dataset_pipeline = get_dataset()
    config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}
    scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=config,
        scaling_config=scaling_config,
        datasets={"train": dataset_pipeline},
    )
    results = trainer.fit()
    print(results.metrics)
    return results


def predict_linear(result: Result) -> Dataset:
    items = [{"x": np.random.uniform(0, 1)}] * 10
    prediction_dataset = ray.data.from_items(items)

    checkpoint_object_ref = result.checkpoint.to_object_ref()

    class TFScorer:
        def __init__(self):
            self.predictor = TensorflowPredictor.from_checkpoint(
                Checkpoint.from_object_ref(checkpoint_object_ref),
                model_definition=build_model,
            )

        def __call__(self, batch) -> pd.DataFrame:
            return self.predictor.predict(batch, dtype=tf.float32)

    predictions = prediction_dataset.map_batches(
        TFScorer, compute="actors", batch_format="pandas"
    )

    pandas_predictions = predictions.to_pandas(float("inf"))

    print(f"PREDICTIONS\n{pandas_predictions}")

    return predictions


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )

    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # 2 workers, 1 for trainer, 1 for datasets
        num_gpus = args.num_workers if args.use_gpu else 0
        ray.init(num_cpus=4, num_gpus=num_gpus)
        result = train_tensorflow_linear(num_workers=2, use_gpu=args.use_gpu)
    else:
        ray.init(address=args.address)
        result = train_tensorflow_linear(
            num_workers=args.num_workers, use_gpu=args.use_gpu
        )

    predict_linear(result)
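
The callback above stores the Keras weights under the `"model"` key, and `TensorflowPredictor.from_checkpoint` pairs those weights with the `model_definition` function. A rough conceptual sketch of that pairing (the `restore_model` helper and the raw `checkpoint_dict` access are illustrative, not the predictor's actual internals):

    import tensorflow as tf

    def restore_model(checkpoint_dict: dict, model_definition) -> tf.keras.Model:
        # Conceptually what the predictor does: rebuild the architecture via
        # model_definition, then load the weights that the training loop saved
        # with train.save_checkpoint under the "model" key.
        model = model_definition()
        model.set_weights(checkpoint_dict["model"])
        return model

    # usage sketch: model = restore_model(ckpt_dict, build_model)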
137  python/ray/ml/examples/tensorflow/tensorflow_mnist_example.py  Normal file
@@ -0,0 +1,137 @@
# This example showcases how to use Tensorflow with Ray Train.
# Original code:
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
import argparse
import json
import os

import numpy as np
from ray.ml.result import Result
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

import ray.train as train
from ray.ml.train.integrations.tensorflow import TensorflowTrainer


class TrainCheckpointReportCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        train.save_checkpoint(**{"model": self.model.get_weights()})
        train.report(**logs)


def mnist_dataset(batch_size: int) -> tf.data.Dataset:
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the [0, 255] range.
    # You need to convert them to float32 with values in the [0, 1] range.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(60000)
        .repeat()
        .batch(batch_size)
    )
    return train_dataset


def build_cnn_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    return model


def train_func(config: dict):
    per_worker_batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)
    steps_per_epoch = config.get("steps_per_epoch", 70)

    tf_config = json.loads(os.environ["TF_CONFIG"])
    num_workers = len(tf_config["cluster"]["worker"])

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_cnn_model()
        learning_rate = config.get("lr", 0.001)
        multi_worker_model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            metrics=["accuracy"],
        )

    history = multi_worker_model.fit(
        multi_worker_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[TrainCheckpointReportCallback()],
    )
    results = history.history
    return results


def train_tensorflow_mnist(
    num_workers: int = 2, use_gpu: bool = False, epochs: int = 4
) -> Result:
    config = {"lr": 1e-3, "batch_size": 64, "epochs": epochs}
    scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=config,
        scaling_config=scaling_config,
    )
    results = trainer.fit()
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--epochs", type=int, default=3, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )

    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        # 2 workers, 1 for trainer, 1 for datasets
        num_gpus = args.num_workers if args.use_gpu else 0
        ray.init(num_cpus=4, num_gpus=num_gpus)
        train_tensorflow_mnist(num_workers=2, use_gpu=args.use_gpu)
    else:
        ray.init(address=args.address)
        train_tensorflow_mnist(
            num_workers=args.num_workers, use_gpu=args.use_gpu, epochs=args.epochs
        )
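
The MNIST loop derives the worker count from the standard `TF_CONFIG` environment variable, which the TensorFlow backend populates on each worker. For reference, a two-worker `TF_CONFIG` looks roughly like this (the addresses are made up for illustration):

    import json

    # Illustrative TF_CONFIG for worker 0 of a 2-worker job, following the
    # standard TensorFlow multi-worker format.
    tf_config = {
        "cluster": {"worker": ["10.0.0.1:12345", "10.0.0.2:12345"]},
        "task": {"type": "worker", "index": 0},
    }
    num_workers = len(tf_config["cluster"]["worker"])  # -> 2, as in train_func
    global_batch_size = 64 * num_workers  # per-worker batch 64 -> global 128
    print(json.dumps(tf_config), num_workers, global_batch_size)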
77  python/ray/ml/examples/tensorflow/tune_tensorflow_mnist_example.py  Normal file

@@ -0,0 +1,77 @@
import argparse

import ray
from ray import tune
from ray.ml.train.integrations.tensorflow import TensorflowTrainer

from ray.ml.examples.tensorflow.tensorflow_mnist_example import train_func


def tune_tensorflow_mnist(
    num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
):
    scaling_config = dict(num_workers=num_workers, use_gpu=use_gpu)
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scaling_config,
    )
    Trainable = trainer.as_trainable()
    analysis = tune.run(
        Trainable,
        num_samples=num_samples,
        config={
            "train_loop_config": {
                "lr": tune.loguniform(1e-4, 1e-1),
                "batch_size": tune.choice([32, 64, 128]),
                "epochs": 3,
            }
        },
    )
    best_loss = analysis.get_best_config(metric="loss", mode="min")
    best_accuracy = analysis.get_best_config(metric="accuracy", mode="max")
    print(f"Best loss config: {best_loss}")
    print(f"Best accuracy config: {best_accuracy}")
    return analysis


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Sets number of samples for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )

    args = parser.parse_args()

    if args.smoke_test:
        num_gpus = args.num_workers if args.use_gpu else 0
        ray.init(num_cpus=8, num_gpus=num_gpus)
        tune_tensorflow_mnist(num_workers=2, num_samples=2, use_gpu=args.use_gpu)
    else:
        ray.init(address=args.address)
        tune_tensorflow_mnist(
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            use_gpu=args.use_gpu,
        )
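
A note on how the sampled hyperparameters reach the training loop: each Tune trial samples a `train_loop_config` dict, and the Trainable produced by `trainer.as_trainable()` hands it to `train_func` as its `config` argument, so the `config.get(...)` calls in the MNIST loop pick up the sampled values. A tiny illustration with plain dicts (no Ray needed):

    # One trial's sampled config, shaped like the search space above.
    sampled = {"train_loop_config": {"lr": 0.01, "batch_size": 64, "epochs": 3}}
    config = sampled["train_loop_config"]
    # This mirrors `config.get("lr", 0.001)` inside train_func.
    assert config.get("lr", 0.001) == 0.01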
85  python/ray/ml/tests/test_tensorflow_trainer.py  Normal file
@@ -0,0 +1,85 @@
import pytest
import numpy as np

import ray
from ray import train
from ray.ml.train.integrations.tensorflow import TensorflowTrainer
from ray.ml.examples.tensorflow.tensorflow_linear_dataset_example import (
    train_func as tensorflow_linear_train_func,
    get_dataset,
)
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
from ray.ml.constants import MODEL_KEY, TRAIN_DATASET_KEY


@pytest.fixture
def ray_start_4_cpus():
    address_info = ray.init(num_cpus=4)
    yield address_info
    # The code after the yield will run as teardown code.
    ray.shutdown()


def build_model():
    import tensorflow as tf

    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.Dense(1),
        ]
    )

    return model


@pytest.mark.parametrize("num_workers", [1, 2])
def test_tensorflow_linear(ray_start_4_cpus, num_workers):
    def train_func(config):
        result = tensorflow_linear_train_func(config)
        assert len(result) == epochs
        assert result[-1]["loss"] < result[0]["loss"]

    num_workers = num_workers
    epochs = 3
    scaling_config = {"num_workers": num_workers}
    config = {
        "lr": 1e-3,
        "batch_size": 32,
        "epochs": epochs,
    }
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=config,
        scaling_config=scaling_config,
        datasets={TRAIN_DATASET_KEY: get_dataset()},
    )
    trainer.fit()


def test_tensorflow_e2e(ray_start_4_cpus):
    def train_func():
        model = build_model().get_weights()
        train.save_checkpoint(**{MODEL_KEY: model})

    scaling_config = {"num_workers": 2}
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func, scaling_config=scaling_config
    )
    result = trainer.fit()

    predictor = TensorflowPredictor.from_checkpoint(result.checkpoint, build_model)

    predict_dataset = ray.data.range(3)
    predictions = predict_dataset.map_batches(
        lambda batch: predictor.predict(batch, dtype=np.float),
        batch_format="pandas",
    )
    assert predictions.count() == 3


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", "-x", __file__]))
python/ray/ml/train/integrations/tensorflow/tensorflow_trainer.py

@@ -163,4 +163,16 @@ class TensorflowTrainer(DataParallelTrainer):
         preprocessor: Optional[Preprocessor] = None,
         resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
-        raise NotImplementedError
+        if not tensorflow_config:
+            tensorflow_config = TensorflowConfig()
+
+        super(TensorflowTrainer, self).__init__(
+            train_loop_per_worker=train_loop_per_worker,
+            train_loop_config=train_loop_config,
+            backend_config=tensorflow_config,
+            scaling_config=scaling_config,
+            run_config=run_config,
+            datasets=datasets,
+            preprocessor=preprocessor,
+            resume_from_checkpoint=resume_from_checkpoint,
+        )
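
The constructor body above is the whole implementation: it defaults the framework-specific `TensorflowConfig` and delegates everything else to the generic `DataParallelTrainer`. A self-contained sketch of that pattern follows, using local stub classes (the `*Stub`/`*Sketch` names are illustrative, not Ray APIs; the real base class launches distributed workers rather than just storing arguments):

    from typing import Any, Callable, Optional

    class BackendConfigStub:
        """Stand-in for a per-framework backend config (e.g. TensorflowConfig)."""

    class DataParallelTrainerStub:
        """Stand-in base class: stores args; the real one runs the loop on workers."""
        def __init__(self, train_loop_per_worker: Callable,
                     backend_config: BackendConfigStub, **kwargs: Any):
            self.train_loop_per_worker = train_loop_per_worker
            self.backend_config = backend_config
            self.kwargs = kwargs

    class TensorflowTrainerSketch(DataParallelTrainerStub):
        # Mirrors the committed constructor: default the backend config,
        # then delegate all distributed-execution concerns to the base class.
        def __init__(self, train_loop_per_worker: Callable,
                     tensorflow_config: Optional[BackendConfigStub] = None,
                     **kwargs: Any):
            if not tensorflow_config:
                tensorflow_config = BackendConfigStub()
            super().__init__(train_loop_per_worker,
                             backend_config=tensorflow_config, **kwargs)

    trainer = TensorflowTrainerSketch(train_loop_per_worker=lambda config: None)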