mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] More PTL example cleanup (#11585)
This commit is contained in:
parent
b02e61f672
commit
4ad8af9b0d
1 changed files with 6 additions and 7 deletions
|
@ -13,8 +13,7 @@ import os
|
||||||
|
|
||||||
# __import_tune_begin__
|
# __import_tune_begin__
|
||||||
import shutil
|
import shutil
|
||||||
from functools import partial
|
import tempfile
|
||||||
from tempfile import mkdtemp
|
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
@ -178,7 +177,7 @@ def train_mnist_tune_checkpoint(config,
|
||||||
ckpt = pl_load(
|
ckpt = pl_load(
|
||||||
os.path.join(checkpoint_dir, "checkpoint"),
|
os.path.join(checkpoint_dir, "checkpoint"),
|
||||||
map_location=lambda storage, loc: storage)
|
map_location=lambda storage, loc: storage)
|
||||||
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
|
model = LightningMNISTClassifier._load_model_state(ckpt, config=config, data_dir=data_dir)
|
||||||
trainer.current_epoch = ckpt["epoch"]
|
trainer.current_epoch = ckpt["epoch"]
|
||||||
else:
|
else:
|
||||||
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
|
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
|
||||||
|
@ -189,7 +188,7 @@ def train_mnist_tune_checkpoint(config,
|
||||||
|
|
||||||
# __tune_asha_begin__
|
# __tune_asha_begin__
|
||||||
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||||
data_dir = mkdtemp(prefix="mnist_data_")
|
data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
|
||||||
LightningMNISTClassifier.download_data(data_dir)
|
LightningMNISTClassifier.download_data(data_dir)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
|
@ -211,7 +210,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||||
|
|
||||||
tune.run(
|
tune.run(
|
||||||
partial(
|
tune.with_parameters(
|
||||||
train_mnist_tune,
|
train_mnist_tune,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
|
@ -232,7 +231,7 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||||
|
|
||||||
# __tune_pbt_begin__
|
# __tune_pbt_begin__
|
||||||
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||||
data_dir = mkdtemp(prefix="mnist_data_")
|
data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
|
||||||
LightningMNISTClassifier.download_data(data_dir)
|
LightningMNISTClassifier.download_data(data_dir)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
|
@ -257,7 +256,7 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||||
|
|
||||||
tune.run(
|
tune.run(
|
||||||
partial(
|
tune.with_parameters(
|
||||||
train_mnist_tune_checkpoint,
|
train_mnist_tune_checkpoint,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
|
|
Loading…
Add table
Reference in a new issue