mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] More PTL example cleanup (#11585)
This commit is contained in:
parent
b02e61f672
commit
4ad8af9b0d
1 changed files with 6 additions and 7 deletions
|
@ -13,8 +13,7 @@ import os
|
|||
|
||||
# __import_tune_begin__
|
||||
import shutil
|
||||
from functools import partial
|
||||
from tempfile import mkdtemp
|
||||
import tempfile
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from ray import tune
|
||||
|
@ -178,7 +177,7 @@ def train_mnist_tune_checkpoint(config,
|
|||
ckpt = pl_load(
|
||||
os.path.join(checkpoint_dir, "checkpoint"),
|
||||
map_location=lambda storage, loc: storage)
|
||||
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
|
||||
model = LightningMNISTClassifier._load_model_state(ckpt, config=config, data_dir=data_dir)
|
||||
trainer.current_epoch = ckpt["epoch"]
|
||||
else:
|
||||
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
|
||||
|
@ -189,7 +188,7 @@ def train_mnist_tune_checkpoint(config,
|
|||
|
||||
# __tune_asha_begin__
|
||||
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir = mkdtemp(prefix="mnist_data_")
|
||||
data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
|
||||
LightningMNISTClassifier.download_data(data_dir)
|
||||
|
||||
config = {
|
||||
|
@ -211,7 +210,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
|||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||
|
||||
tune.run(
|
||||
partial(
|
||||
tune.with_parameters(
|
||||
train_mnist_tune,
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
|
@ -232,7 +231,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
|||
|
||||
# __tune_pbt_begin__
|
||||
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir = mkdtemp(prefix="mnist_data_")
|
||||
data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
|
||||
LightningMNISTClassifier.download_data(data_dir)
|
||||
|
||||
config = {
|
||||
|
@ -257,7 +256,7 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
|||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||
|
||||
tune.run(
|
||||
partial(
|
||||
tune.with_parameters(
|
||||
train_mnist_tune_checkpoint,
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
|
|
Loading…
Add table
Reference in a new issue