[release/ray-lightning] adjust the release test of ray lightning master

First of all, sorry, I messed up the previous PR when syncing with master (#27374). This PR is a duplicate of the previous one, updated with the new change (adding a version check on ray_lightning for compatibility). Also, apologies for the massive review requests on the previous PR.
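
For reviewers: the gate relies on PEP 440 ordering from packaging.version. A quick standalone illustration of how the comparison behaves (not part of the diff itself):

# Standalone illustration of the version comparison used in the diff below.
from packaging.version import parse as v_parse

assert v_parse("0.3.0") >= v_parse("0.3.0")      # the final 0.3.0 release passes the gate
assert v_parse("0.2.1") < v_parse("0.3.0")       # older releases fall back to RayPlugin
assert v_parse("0.3.0.dev0") < v_parse("0.3.0")  # PEP 440: dev releases sort before the final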
Jimmy Yao 2022-08-03 08:01:32 -07:00 committed by GitHub
parent 20119c7022
commit 1c1cca2736


@@ -9,7 +9,17 @@ from torch.utils.data import DataLoader, random_split
 from torchvision import transforms
 import pytorch_lightning as pl
-from ray_lightning import RayPlugin
+from importlib_metadata import version
+from packaging.version import parse as v_parse
+
+rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0")
+
+if rlt_use_master:
+    # ray_lightning >= 0.3.0
+    from ray_lightning import RayStrategy
+else:
+    # ray_lightning < 0.3.0
+    from ray_lightning import RayPlugin
 
 class LitAutoEncoder(pl.LightningModule):
@@ -47,10 +57,16 @@ def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
     train, val = random_split(dataset, [55000, 5000])
     autoencoder = LitAutoEncoder()
-    trainer = pl.Trainer(
-        plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
-        max_steps=max_steps,
-    )
+    if rlt_use_master:
+        trainer = pl.Trainer(
+            strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
+            max_steps=max_steps,
+        )
+    else:
+        trainer = pl.Trainer(
+            plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
+            max_steps=max_steps,
+        )
     trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
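
As a follow-up thought, the two Trainer constructions in main() could be folded into a single helper. A minimal sketch of that refactor (hypothetical, not part of this PR):

from importlib_metadata import version
from packaging.version import parse as v_parse
import pytorch_lightning as pl


def make_trainer(num_workers: int, use_gpu: bool, max_steps: int) -> pl.Trainer:
    # Hypothetical helper, not in this PR: build the Trainer once, choosing
    # the ray_lightning integration that matches the installed version.
    if v_parse(version("ray_lightning")) >= v_parse("0.3.0"):
        from ray_lightning import RayStrategy  # ray_lightning >= 0.3.0
        return pl.Trainer(
            strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
            max_steps=max_steps,
        )
    from ray_lightning import RayPlugin  # ray_lightning < 0.3.0
    return pl.Trainer(
        plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
        max_steps=max_steps,
    )

This would keep the version branching in one place if more call sites are added to the release test later.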