mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[tune] Update Lightning examples to support PTL 1.5 (#20562)
Helps resolve the issues users are facing when running the Lightning examples with Ray Tune (see PyTorchLightning/pytorch-lightning#10407). Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
This commit is contained in:
parent
e8e35169c6
commit
8515fdd6db
6 changed files with 13 additions and 12 deletions
|
@ -3,6 +3,7 @@ import math
|
||||||
import torch
|
import torch
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
from torchmetrics import Accuracy
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
|
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
|
||||||
import os
|
import os
|
||||||
|
@ -24,7 +25,7 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||||
self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
|
self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
|
||||||
self.layer_2 = torch.nn.Linear(layer_1, layer_2)
|
self.layer_2 = torch.nn.Linear(layer_1, layer_2)
|
||||||
self.layer_3 = torch.nn.Linear(layer_2, 10)
|
self.layer_3 = torch.nn.Linear(layer_2, 10)
|
||||||
self.accuracy = pl.metrics.Accuracy()
|
self.accuracy = Accuracy()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
batch_size, channels, width, height = x.size()
|
batch_size, channels, width, height = x.size()
|
||||||
|
@ -75,7 +76,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0):
|
||||||
max_epochs=num_epochs,
|
max_epochs=num_epochs,
|
||||||
# If fractional GPUs passed in, convert to int.
|
# If fractional GPUs passed in, convert to int.
|
||||||
gpus=math.ceil(num_gpus),
|
gpus=math.ceil(num_gpus),
|
||||||
progress_bar_refresh_rate=0,
|
enable_progress_bar=False,
|
||||||
callbacks=[TuneReportCallback(metrics, on="validation_end")],
|
callbacks=[TuneReportCallback(metrics, on="validation_end")],
|
||||||
)
|
)
|
||||||
trainer.fit(model, dm)
|
trainer.fit(model, dm)
|
||||||
|
|
|
@ -121,7 +121,7 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||||
|
|
||||||
def train_mnist(config):
|
def train_mnist(config):
|
||||||
model = LightningMNISTClassifier(config)
|
model = LightningMNISTClassifier(config)
|
||||||
trainer = pl.Trainer(max_epochs=10, show_progress_bar=False)
|
trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False)
|
||||||
|
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
# __lightning_end__
|
# __lightning_end__
|
||||||
|
@ -148,7 +148,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
|
||||||
gpus=math.ceil(num_gpus),
|
gpus=math.ceil(num_gpus),
|
||||||
logger=TensorBoardLogger(
|
logger=TensorBoardLogger(
|
||||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||||
progress_bar_refresh_rate=0,
|
enable_progress_bar=False,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
TuneReportCallback(
|
TuneReportCallback(
|
||||||
{
|
{
|
||||||
|
@ -174,7 +174,7 @@ def train_mnist_tune_checkpoint(config,
|
||||||
"gpus": math.ceil(num_gpus),
|
"gpus": math.ceil(num_gpus),
|
||||||
"logger": TensorBoardLogger(
|
"logger": TensorBoardLogger(
|
||||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||||
"progress_bar_refresh_rate": 0,
|
"enable_progress_bar": False,
|
||||||
"callbacks": [
|
"callbacks": [
|
||||||
TuneReportCheckpointCallback(
|
TuneReportCheckpointCallback(
|
||||||
metrics={
|
metrics={
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torchvision import transforms
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
from ray.util.ray_lightning import RayPlugin
|
from ray.util.ray_lightning import RayPlugin
|
||||||
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources
|
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_resources
|
||||||
|
|
||||||
num_cpus_per_actor = 1
|
num_cpus_per_actor = 1
|
||||||
num_workers = 1
|
num_workers = 1
|
||||||
|
@ -70,7 +70,7 @@ def main():
|
||||||
num_samples=1,
|
num_samples=1,
|
||||||
metric="loss",
|
metric="loss",
|
||||||
mode="min",
|
mode="min",
|
||||||
resources_per_trial=get_tune_ddp_resources(
|
resources_per_trial=get_tune_resources(
|
||||||
num_workers=num_workers, cpus_per_worker=num_cpus_per_actor
|
num_workers=num_workers, cpus_per_worker=num_cpus_per_actor
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,13 +4,13 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TuneReportCallback = None
|
TuneReportCallback = None
|
||||||
TuneReportCheckpointCallback = None
|
TuneReportCheckpointCallback = None
|
||||||
get_tune_ddp_resources = None
|
get_tune_resources = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ray_lightning.tune import (
|
from ray_lightning.tune import (
|
||||||
TuneReportCallback,
|
TuneReportCallback,
|
||||||
TuneReportCheckpointCallback,
|
TuneReportCheckpointCallback,
|
||||||
get_tune_ddp_resources,
|
get_tune_resources,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -22,5 +22,5 @@ except ImportError:
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TuneReportCallback",
|
"TuneReportCallback",
|
||||||
"TuneReportCheckpointCallback",
|
"TuneReportCheckpointCallback",
|
||||||
"get_tune_ddp_resources",
|
"get_tune_resources",
|
||||||
]
|
]
|
||||||
|
|
|
@ -29,7 +29,7 @@ nevergrad==0.4.3.post7
|
||||||
optuna==2.9.1
|
optuna==2.9.1
|
||||||
pytest-remotedata==0.3.2
|
pytest-remotedata==0.3.2
|
||||||
lightning-bolts==0.4.0
|
lightning-bolts==0.4.0
|
||||||
pytorch-lightning==1.4.9
|
pytorch-lightning==1.5.10
|
||||||
shortuuid==1.0.1
|
shortuuid==1.0.1
|
||||||
scikit-learn==0.24.2
|
scikit-learn==0.24.2
|
||||||
scikit-optimize==0.8.1
|
scikit-optimize==0.8.1
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
# Because they depend on Ray, we can't pin the subdependencies.
|
# Because they depend on Ray, we can't pin the subdependencies.
|
||||||
# So we separate its own requirements file.
|
# So we separate its own requirements file.
|
||||||
|
|
||||||
ray_lightning==0.1.1
|
ray_lightning==0.2.0
|
||||||
tune-sklearn==0.4.1
|
tune-sklearn==0.4.1
|
||||||
xgboost_ray==0.1.4
|
xgboost_ray==0.1.4
|
||||||
lightgbm_ray==0.0.2
|
lightgbm_ray==0.0.2
|
||||||
|
|
Loading…
Add table
Reference in a new issue