mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[tune] Update Lightning examples to support PTL 1.5 (#20562)
To helps resolve the issues users are facing with running Lightning examples with Ray Tune PyTorchLightning/pytorch-lightning#10407 Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
This commit is contained in:
parent
e8e35169c6
commit
8515fdd6db
6 changed files with 13 additions and 12 deletions
|
@ -3,6 +3,7 @@ import math
|
|||
import torch
|
||||
from filelock import FileLock
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics import Accuracy
|
||||
import pytorch_lightning as pl
|
||||
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
|
||||
import os
|
||||
|
@ -24,7 +25,7 @@ class LightningMNISTClassifier(pl.LightningModule):
|
|||
self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
|
||||
self.layer_2 = torch.nn.Linear(layer_1, layer_2)
|
||||
self.layer_3 = torch.nn.Linear(layer_2, 10)
|
||||
self.accuracy = pl.metrics.Accuracy()
|
||||
self.accuracy = Accuracy()
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, channels, width, height = x.size()
|
||||
|
@ -75,7 +76,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0):
|
|||
max_epochs=num_epochs,
|
||||
# If fractional GPUs passed in, convert to int.
|
||||
gpus=math.ceil(num_gpus),
|
||||
progress_bar_refresh_rate=0,
|
||||
enable_progress_bar=False,
|
||||
callbacks=[TuneReportCallback(metrics, on="validation_end")],
|
||||
)
|
||||
trainer.fit(model, dm)
|
||||
|
|
|
@ -121,7 +121,7 @@ class LightningMNISTClassifier(pl.LightningModule):
|
|||
|
||||
def train_mnist(config):
|
||||
model = LightningMNISTClassifier(config)
|
||||
trainer = pl.Trainer(max_epochs=10, show_progress_bar=False)
|
||||
trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False)
|
||||
|
||||
trainer.fit(model)
|
||||
# __lightning_end__
|
||||
|
@ -148,7 +148,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
|
|||
gpus=math.ceil(num_gpus),
|
||||
logger=TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
progress_bar_refresh_rate=0,
|
||||
enable_progress_bar=False,
|
||||
callbacks=[
|
||||
TuneReportCallback(
|
||||
{
|
||||
|
@ -174,7 +174,7 @@ def train_mnist_tune_checkpoint(config,
|
|||
"gpus": math.ceil(num_gpus),
|
||||
"logger": TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
"progress_bar_refresh_rate": 0,
|
||||
"enable_progress_bar": False,
|
||||
"callbacks": [
|
||||
TuneReportCheckpointCallback(
|
||||
metrics={
|
||||
|
|
|
@ -8,7 +8,7 @@ from torchvision import transforms
|
|||
import pytorch_lightning as pl
|
||||
|
||||
from ray.util.ray_lightning import RayPlugin
|
||||
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources
|
||||
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_resources
|
||||
|
||||
num_cpus_per_actor = 1
|
||||
num_workers = 1
|
||||
|
@ -70,7 +70,7 @@ def main():
|
|||
num_samples=1,
|
||||
metric="loss",
|
||||
mode="min",
|
||||
resources_per_trial=get_tune_ddp_resources(
|
||||
resources_per_trial=get_tune_resources(
|
||||
num_workers=num_workers, cpus_per_worker=num_cpus_per_actor
|
||||
),
|
||||
)
|
||||
|
|
|
@ -4,13 +4,13 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
TuneReportCallback = None
|
||||
TuneReportCheckpointCallback = None
|
||||
get_tune_ddp_resources = None
|
||||
get_tune_resources = None
|
||||
|
||||
try:
|
||||
from ray_lightning.tune import (
|
||||
TuneReportCallback,
|
||||
TuneReportCheckpointCallback,
|
||||
get_tune_ddp_resources,
|
||||
get_tune_resources,
|
||||
)
|
||||
except ImportError:
|
||||
logger.info(
|
||||
|
@ -22,5 +22,5 @@ except ImportError:
|
|||
__all__ = [
|
||||
"TuneReportCallback",
|
||||
"TuneReportCheckpointCallback",
|
||||
"get_tune_ddp_resources",
|
||||
"get_tune_resources",
|
||||
]
|
||||
|
|
|
@ -29,7 +29,7 @@ nevergrad==0.4.3.post7
|
|||
optuna==2.9.1
|
||||
pytest-remotedata==0.3.2
|
||||
lightning-bolts==0.4.0
|
||||
pytorch-lightning==1.4.9
|
||||
pytorch-lightning==1.5.10
|
||||
shortuuid==1.0.1
|
||||
scikit-learn==0.24.2
|
||||
scikit-optimize==0.8.1
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Because they depend on Ray, we can't pin the subdependencies.
|
||||
# So we separate its own requirements file.
|
||||
|
||||
ray_lightning==0.1.1
|
||||
ray_lightning==0.2.0
|
||||
tune-sklearn==0.4.1
|
||||
xgboost_ray==0.1.4
|
||||
lightgbm_ray==0.0.2
|
||||
|
|
Loading…
Add table
Reference in a new issue