[tune] More PTL example cleanup (#11585)

Richard Liaw 2020-10-26 12:26:14 -07:00 committed by GitHub
parent b02e61f672
commit 4ad8af9b0d

@@ -13,8 +13,7 @@ import os
 # __import_tune_begin__
 import shutil
-from functools import partial
-from tempfile import mkdtemp
+import tempfile
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from ray import tune
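Two of the removed imports map to the two behavioral changes below: `functools.partial` gives way to `tune.with_parameters`, and the per-call `mkdtemp(prefix="mnist_data_")` gives way to a fixed directory under the system temp root, so repeated runs reuse one MNIST download instead of fetching into a fresh directory each time. A minimal sketch of the new data-dir pattern (the `os.makedirs` guard is an assumption; in the example itself `download_data` is expected to handle creation):

    import os
    import tempfile

    # One stable path per machine (e.g. /tmp/mnist_data_) instead of a fresh
    # /tmp/mnist_data_XXXXXX from mkdtemp() on every call.
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    os.makedirs(data_dir, exist_ok=True)  # assumed guard, not in the diff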
@@ -178,7 +177,7 @@ def train_mnist_tune_checkpoint(config,
         ckpt = pl_load(
             os.path.join(checkpoint_dir, "checkpoint"),
             map_location=lambda storage, loc: storage)
-        model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
+        model = LightningMNISTClassifier._load_model_state(ckpt, config=config, data_dir=data_dir)
         trainer.current_epoch = ckpt["epoch"]
     else:
         model = LightningMNISTClassifier(config=config, data_dir=data_dir)
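The extra `data_dir=data_dir` matters because `_load_model_state` rebuilds the model from the checkpointed hyperparameters plus any keyword arguments it receives, so constructor arguments that are not stored in the checkpoint have to be passed explicitly. Since `_load_model_state` is a private PyTorch Lightning helper, a sketch of the same restore through the public API (assuming the file is an ordinary Lightning checkpoint; extra kwargs are forwarded to `LightningMNISTClassifier.__init__`):

    # Sketch only: public-API equivalent of the private _load_model_state call.
    model = LightningMNISTClassifier.load_from_checkpoint(
        os.path.join(checkpoint_dir, "checkpoint"),
        map_location=lambda storage, loc: storage,
        config=config,
        data_dir=data_dir)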
@@ -189,7 +188,7 @@ def train_mnist_tune_checkpoint(config,
 # __tune_asha_begin__
 def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
-    data_dir = mkdtemp(prefix="mnist_data_")
+    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
     LightningMNISTClassifier.download_data(data_dir)
     config = {
@ -211,7 +210,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
metric_columns=["loss", "mean_accuracy", "training_iteration"]) metric_columns=["loss", "mean_accuracy", "training_iteration"])
tune.run( tune.run(
partial( tune.with_parameters(
train_mnist_tune, train_mnist_tune,
data_dir=data_dir, data_dir=data_dir,
num_epochs=num_epochs, num_epochs=num_epochs,
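`tune.with_parameters` replaces `functools.partial` for binding constant arguments to a trainable. Besides currying, it passes the bound values through the Ray object store, which avoids repeatedly serializing large objects such as datasets into each trial, and it keeps the trainable's original name in Tune's status output. A self-contained sketch of the pattern (the toy trainable and search space are illustrative, not taken from this example):

    from ray import tune

    def train_fn(config, data_dir=None, num_epochs=1):
        # Toy trainable: report one metric per "epoch".
        for _ in range(num_epochs):
            tune.report(loss=config["lr"])

    # Bound kwargs travel via the object store, not via pickled closures.
    trainable = tune.with_parameters(train_fn, data_dir="/tmp/data", num_epochs=2)
    tune.run(
        trainable,
        config={"lr": tune.loguniform(1e-4, 1e-1)},
        num_samples=2)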
@@ -232,7 +231,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
 # __tune_pbt_begin__
 def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
-    data_dir = mkdtemp(prefix="mnist_data_")
+    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
     LightningMNISTClassifier.download_data(data_dir)
     config = {
@ -257,7 +256,7 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
metric_columns=["loss", "mean_accuracy", "training_iteration"]) metric_columns=["loss", "mean_accuracy", "training_iteration"])
tune.run( tune.run(
partial( tune.with_parameters(
train_mnist_tune_checkpoint, train_mnist_tune_checkpoint,
data_dir=data_dir, data_dir=data_dir,
num_epochs=num_epochs, num_epochs=num_epochs,
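After these changes both entry points are invoked the same way; a hypothetical driver (argument values are illustrative defaults, not mandated by the diff):

    if __name__ == "__main__":
        # Run the ASHA search, then the PBT search, over the shared data dir.
        tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0)
        tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0)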