[horovod] Horovod+Ray Pytorch Lightning Accelerator (#13458)
Parent: 25e1b78eed
Commit: 01d74af89d

7 changed files with 547 additions and 1 deletions
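This commit adds a HorovodRayAccelerator that plugs Horovod-on-Ray distributed training into a stock PyTorch Lightning Trainer, plus an MNIST example and tests. A minimal, self-contained usage sketch, condensed from the docstring and example added below (the ToyRegressor model and the 1-host/2-slot sizing are illustrative placeholders; Horovod with the Ray executor and a PyTorch Lightning version whose Trainer accepts an Accelerator instance are assumed):

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

import ray
from ray.util.lightning_accelerators import HorovodRayAccelerator


# Tiny stand-in model (hypothetical); any pl.LightningModule works here.
class ToyRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)

    def train_dataloader(self):
        # Random data so the sketch runs without downloads.
        x = torch.randn(256, 8)
        y = torch.randn(256, 1)
        return DataLoader(TensorDataset(x, y), batch_size=32)


if __name__ == "__main__":
    ray.init()  # or ray.init(address=...) to join an existing cluster

    # 1 host, 2 workers, CPU only; each worker reserves 1 CPU
    # (and 1 GPU if use_gpu=True).
    accelerator = HorovodRayAccelerator(num_hosts=1, num_slots=2, use_gpu=False)

    # Run as a normal Python script (no horovodrun needed).
    trainer = pl.Trainer(max_epochs=1, accelerator=accelerator)
    trainer.fit(ToyRegressor())

The hunks below add the CI test target, extend the MNIST Tune example model, and create the new lightning_accelerators package (BUILD, __init__.py, example script, accelerator, and tests).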
@@ -420,6 +420,7 @@ matrix:
     script:
       - ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=py37 python/ray/tune/...
       - ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only python/ray/util/xgboost/...
+      - ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only python/ray/util/lightning_accelerators/...
       # There are no python 3.7 tests for RaySGD at the moment
       # - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=py37 python/ray/util/sgd/...
       # - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=py37 doc/...

python/ray/tune/examples/mnist_ptl_mini.py

@@ -1,7 +1,7 @@
 import torch
 from torch.nn import functional as F
 import pytorch_lightning as pl
-from pl_bolts.datamodules import MNISTDataModule
+from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
 import os
 from ray.tune.integration.pytorch_lightning import TuneReportCallback

@@ -16,6 +16,7 @@ class LightningMNISTClassifier(pl.LightningModule):
         self.data_dir = data_dir or os.getcwd()
         self.lr = config["lr"]
         layer_1, layer_2 = config["layer_1"], config["layer_2"]
+        self.batch_size = config["batch_size"]

         # mnist images are (1, 28, 28) (channels, width, height)
         self.layer_1 = torch.nn.Linear(28 * 28, layer_1)

python/ray/util/lightning_accelerators/BUILD (new file, 33 lines)
@@ -0,0 +1,33 @@
# --------------------------------------------------------------------
# Tests from the python/ray/util/lightning_accelerators/tests directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------

py_test(
    name = "test_horovod_ray_accelerator",
    size = "medium",
    srcs = ["tests/test_horovod_ray_accelerator.py"],
    tags = ["exclusive", "pytorch-lightning", "pytorch", "horovod"],
    deps = [":accelerator_lib"],
)

# --------------------------------------------------------------------
# Tests from the python/ray/util/lightning_accelerators/examples directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------

py_test(
    name = "ptl_horovod_ray_example",
    size = "medium",
    srcs = ["examples/ptl_horovod_ray_example.py"],
    tags = ["exclusive", "example", "pytorch-lightning", "pytorch", "horovod"],
    deps = [":accelerator_lib"],
    args = ["--smoke-test"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
    name = "accelerator_lib",
    srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
)

python/ray/util/lightning_accelerators/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from ray.util.lightning_accelerators.horovod_ray_accelerator import \
    HorovodRayAccelerator

__all__ = ["HorovodRayAccelerator"]

python/ray/util/lightning_accelerators/examples/ptl_horovod_ray_example.py (new file, 195 lines)

@@ -0,0 +1,195 @@
"""Example using Pytorch Lightning with a Horovod on Ray Accelerator."""
import os
import tempfile

import pytorch_lightning as pl
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import ray
from ray import tune
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.util.lightning_accelerators import HorovodRayAccelerator


class MNISTClassifier(LightningMNISTClassifier):
    def prepare_data(self):
        self.dataset = MNIST(
            self.data_dir,
            train=True,
            download=True,
            transform=transforms.ToTensor())

    def train_dataloader(self):
        dataset = self.dataset
        train_length = len(dataset)
        dataset_train, _ = random_split(
            dataset, [train_length - 5000, 5000],
            generator=torch.Generator().manual_seed(0))
        loader = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=1,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def val_dataloader(self):
        dataset = self.dataset
        train_length = len(dataset)
        _, dataset_val = random_split(
            dataset, [train_length - 5000, 5000],
            generator=torch.Generator().manual_seed(0))
        loader = DataLoader(
            dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=1,
            drop_last=True,
            pin_memory=True,
        )
        return loader


def train_mnist(config,
                data_dir=None,
                num_epochs=10,
                num_hosts=1,
                num_slots=4,
                use_gpu=False,
                callbacks=None):
    model = MNISTClassifier(config, data_dir)

    callbacks = callbacks or []

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=int(use_gpu),
        callbacks=callbacks,
        accelerator=HorovodRayAccelerator(
            num_hosts=num_hosts, num_slots=num_slots, use_gpu=use_gpu))
    trainer.fit(model)


def tune_mnist(data_dir,
               num_samples=10,
               num_epochs=10,
               num_hosts=1,
               num_slots=4,
               use_gpu=False):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Add Tune callback.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    callbacks = [TuneReportCallback(metrics, on="validation_end")]
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_hosts=num_hosts,
        num_slots=num_slots,
        use_gpu=use_gpu,
        callbacks=callbacks)
    analysis = tune.run(
        trainable,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        resources_per_trial={
            "cpu": 1,
            # Assume 1 cpu per slot.
            "extra_cpu": num_hosts * num_slots,
            # Assume 1 gpu per slot.
            "extra_gpu": num_hosts * num_slots * int(use_gpu)
        },
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-hosts",
        type=int,
        help="Number of machines to train on. If using Tune, then each "
        "trial will use this many machines.",
        default=1)
    parser.add_argument(
        "--num-slots",
        type=int,
        help="Number of workers to place on each machine. If using Tune, "
        "then each trial will use this many slots per machine.",
        default=1)
    parser.add_argument(
        "--use-gpu", action="store_true", help="Use GPU for training.")
    parser.add_argument(
        "--tune",
        action="store_true",
        help="Use Ray Tune for hyperparameter tuning.")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of samples to tune.")
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="Number of epochs to train for.")
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="The address to use for Ray.")
    args, _ = parser.parse_known_args()

    num_epochs = 1 if args.smoke_test else args.num_epochs
    num_hosts = 1 if args.smoke_test else args.num_hosts
    num_slots = 1 if args.smoke_test else args.num_slots
    use_gpu = False if args.smoke_test else args.use_gpu
    num_samples = 1 if args.smoke_test else args.num_samples

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        ray.init(address=args.address)

    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")

    if args.tune:
        raise NotImplementedError("Using Tune + Pytorch Lightning with "
                                  "distributed training is currently not "
                                  "supported.")
        tune_mnist(data_dir, num_samples, num_epochs, num_hosts, num_slots,
                   use_gpu)
    else:
        config = {"layer_1": 32, "layer_2": 64, "lr": 1e-1, "batch_size": 32}
        train_mnist(config, data_dir, num_epochs, num_hosts, num_slots,
                    use_gpu)

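To try this example locally (assuming Horovod, PyTorch Lightning, pl_bolts, and torchvision are installed), the BUILD target added earlier in this diff runs it with --smoke-test, which reduces the run to one epoch, one host, one slot, one Tune sample, and a 2-CPU local Ray instance.
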
python/ray/util/lightning_accelerators/horovod_ray_accelerator.py (new file, 121 lines)

@@ -0,0 +1,121 @@
import ray
from pytorch_lightning.accelerators.horovod_accelerator import \
    HorovodAccelerator

try:
    import horovod.torch as hvd
    from horovod.ray import RayExecutor
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


def get_executable_cls():
    # Only used for testing purposes, currently.
    # We need to override this in tests to ensure test path is set correctly.
    return None


class HorovodRayAccelerator(HorovodAccelerator):
    """Pytorch Lightning Accelerator for Horovod training on a Ray cluster.

    This accelerator is used to manage distributed training on a Ray cluster
    via the Horovod training framework. Internally, the specified number of
    Ray actors are launched in the cluster and are configured as part of the
    Horovod ring. The Pytorch Lightning trainer is instantiated on the
    driver and sent to each of these training workers where training is
    executed. The distributed training protocol is handled by Horovod.

    Each training worker is configured to reserve 1 CPU and, if
    ``use_gpu`` is set to ``True``, 1 GPU.

    If using this accelerator, you should run your code like a normal Python
    script: ``python train.py``, and not with ``horovodrun``.

    Args:
        num_hosts (int): The number of nodes/machines to execute the job on.
        num_slots (int): Number of workers to be placed on each machine.
        use_gpu (bool): Whether to use GPU for allocation. For GPU to be
            used, you must also set the ``gpus`` arg in your Pytorch Lightning
            Trainer to a value > 0.

    Example:

    .. code-block:: python

        import pytorch_lightning as pl
        from ray.util.lightning_accelerators import HorovodRayAccelerator

        ptl_model = MNISTClassifier(...)
        # 2 nodes, 4 workers per node, each using 1 CPU and 1 GPU.
        accelerator = HorovodRayAccelerator(num_hosts=2, num_slots=4,
                                            use_gpu=True)

        # If using GPUs, set the ``gpus`` arg to a value > 0.
        # The actual number of GPUs is determined by ``num_slots``.
        trainer = pl.Trainer(..., gpus=1, accelerator=accelerator)
        trainer.fit(ptl_model)

    """

    def __init__(self,
                 *args,
                 num_hosts=1,
                 num_slots=1,
                 use_gpu=False,
                 **kwargs):
        super().__init__(*args, trainer=None, **kwargs)
        self.nickname = "horovod_ray"
        self.num_hosts = num_hosts
        self.num_slots = num_slots
        self.use_gpu = use_gpu

    def setup(self, model):
        self.trainer.use_horovod = True
        settings = RayExecutor.create_settings(timeout_s=30)
        self.executor = RayExecutor(
            settings,
            num_hosts=self.num_hosts,
            num_slots=self.num_slots,
            use_gpu=self.use_gpu)
        self.trainer.model = model
        self.executor.start(executable_cls=get_executable_cls())

    def train(self):
        trainer = self.trainer
        trainer_ref = ray.put(self.trainer)
        self.trainer = None
        results = self.executor.run(self.train_remote, args=[trainer_ref])
        results, state_dict, best_path = results[0]

        self.trainer = trainer
        self.trainer.model.load_state_dict(state_dict)
        if self.trainer.checkpoint_callback:
            self.trainer.checkpoint_callback.best_model_path = best_path

        return results

    def train_remote(self, trainer_ref):
        self.trainer = ray.get(trainer_ref)
        hvd.init()
        if self.trainer.on_gpu:
            # Horovod assigns one local GPU per process.
            self.trainer.root_gpu = hvd.local_rank()

        # TODO: Make changes in PTL to clean this up.
        super(HorovodRayAccelerator, self).setup(self.trainer.model)
        results = super(HorovodRayAccelerator, self).train()
        if hvd.rank() != 0:
            # Only want results from the first worker.
            return None

        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        model = self.trainer.model
        return results, model.state_dict(), best_model_path

    def teardown(self):
        self.executor.shutdown()

python/ray/util/lightning_accelerators/tests/test_horovod_ray_accelerator.py (new file, 191 lines)

@@ -0,0 +1,191 @@
import os

import torch
import pytest
import ray
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from ray.util.sgd.tests.test_ptl import PTL_Module
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
from ray.util.lightning_accelerators import HorovodRayAccelerator
import pytorch_lightning as pl

try:
    import horovod  # noqa: F401
    from horovod.common.util import nccl_built
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


def _nccl_available():
    if not HOROVOD_AVAILABLE:
        return False
    try:
        return nccl_built()
    except AttributeError:
        return False


@pytest.fixture
def ray_start_2_cpus():
    address_info = ray.init(num_cpus=2)
    yield address_info
    ray.shutdown()


@pytest.fixture
def ray_start_2_gpus():
    address_info = ray.init(num_cpus=2, num_gpus=2)
    yield address_info
    ray.shutdown()
    # This env var is set by Pytorch Lightning.
    # Make sure to reset it after each test.
    # TODO: Upstream to PTL to not set this env var if using Ray.
    del os.environ["CUDA_VISIBLE_DEVICES"]


@pytest.fixture
def seed():
    pl.seed_everything(0)


def get_model(lr=1e-2, hidden_size=1, data_size=10, val_size=10, batch_size=2):
    config = {
        "lr": lr,
        "hidden_size": hidden_size,
        "data_size": data_size,
        "val_size": val_size,
        "batch_size": batch_size
    }
    return PTL_Module(config)


def get_trainer(dir,
                num_slots=2,
                use_gpu=False,
                max_epochs=1,
                limit_train_batches=10,
                limit_val_batches=10,
                progress_bar_refresh_rate=0):
    accelerator = HorovodRayAccelerator(num_slots=num_slots, use_gpu=use_gpu)
    trainer = pl.Trainer(
        default_root_dir=dir,
        gpus=1 if use_gpu else 0,
        max_epochs=max_epochs,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        progress_bar_refresh_rate=progress_bar_refresh_rate,
        checkpoint_callback=True,
        accelerator=accelerator)
    return trainer


def train_test(trainer, model):
    initial_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    result = trainer.fit(model)
    post_train_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    assert result == 1, "trainer failed"
    # Check that the model is actually changed post-training.
    assert torch.norm(initial_values - post_train_values) > 0.1


@pytest.mark.parametrize("num_slots", [1, 2])
def test_train(tmpdir, ray_start_2_cpus, seed, num_slots):
    model = get_model()

    trainer = get_trainer(tmpdir, num_slots=num_slots)
    train_test(trainer, model)


def load_test(trainer, model):
    trainer.fit(model)
    trained_model = PTL_Module.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path, config=model.config)
    assert trained_model is not None, "loading model failed"


@pytest.mark.parametrize("num_slots", [1, 2])
def test_load(tmpdir, ray_start_2_cpus, seed, num_slots):
    model = get_model()
    trainer = get_trainer(tmpdir, num_slots=num_slots)
    load_test(trainer, model)


def predict_test(trainer, model, dm):
    trainer.fit(model, dm)
    test_loader = dm.test_dataloader()
    acc = pl.metrics.Accuracy()
    for batch in test_loader:
        x, y = batch
        with torch.no_grad():
            y_hat = model(x)
        y_hat = y_hat.cpu()
        acc.update(y_hat, y)
    average_acc = acc.compute()
    assert average_acc >= 0.5, f"This model is expected to get > {0.5} in " \
        f"test set (it got {average_acc})"


@pytest.mark.parametrize("num_slots", [1, 2])
def test_predict(tmpdir, ray_start_2_cpus, seed, num_slots):
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }
    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(
        data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
    trainer = get_trainer(
        tmpdir, limit_train_batches=10, max_epochs=1, num_slots=num_slots)
    predict_test(trainer, model, dm)


@pytest.mark.skipif(
    not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("num_slots", [1, 2])
def test_train_gpu(tmpdir, ray_start_2_gpus, seed, num_slots):
    model = get_model()
    trainer = get_trainer(tmpdir, num_slots=num_slots, use_gpu=True)
    train_test(trainer, model)


@pytest.mark.skipif(
    not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("num_slots", [1, 2])
def test_load_gpu(tmpdir, ray_start_2_gpus, seed, num_slots):
    model = get_model()
    trainer = get_trainer(tmpdir, num_slots=num_slots, use_gpu=True)
    load_test(trainer, model)


@pytest.mark.skipif(
    not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("num_slots", [1, 2])
def test_predict_gpu(tmpdir, ray_start_2_gpus, seed, num_slots):
    config = {
        "layer_1": 32,
        "layer_2": 32,
        "lr": 1e-2,
        "batch_size": 32,
    }
    model = LightningMNISTClassifier(config, tmpdir)
    dm = MNISTDataModule(
        data_dir=tmpdir, num_workers=1, batch_size=config["batch_size"])
    trainer = get_trainer(
        tmpdir,
        limit_train_batches=10,
        max_epochs=1,
        num_slots=num_slots,
        use_gpu=True)
    predict_test(trainer, model, dm)