mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[release/ray-lightning] adjust the release test of ray lightning master
First of all, sorry I messed up the previous PR when syncing with master (#27374). This PR duplicates the previous PR up to the point where we update the changes (change: adding a version check on ray_lightning for compatibility). Also, apologies for the massive review requests on the previous PR.
This commit is contained in:
parent
20119c7022
commit
1c1cca2736
1 changed file with 21 additions and 5 deletions
|
@ -9,7 +9,17 @@ from torch.utils.data import DataLoader, random_split
|
|||
from torchvision import transforms
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from ray_lightning import RayPlugin
|
||||
from importlib_metadata import version
from packaging.version import parse as v_parse

# ray_lightning 0.3.0 renamed the plugin-style API; detect which one the
# installed package provides so the release test works against both.
_installed_version = v_parse(version("ray_lightning"))
rlt_use_master = _installed_version >= v_parse("0.3.0")

if not rlt_use_master:
    # ray_lightning < 0.3.0: the pre-rename plugin API.
    from ray_lightning import RayPlugin
else:
    # ray_lightning >= 0.3.0: plugins were renamed to strategies.
    from ray_lightning import RayStrategy
|
||||
|
||||
|
||||
class LitAutoEncoder(pl.LightningModule):
|
||||
|
@ -47,10 +57,16 @@ def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
|
|||
train, val = random_split(dataset, [55000, 5000])
|
||||
|
||||
autoencoder = LitAutoEncoder()
|
||||
trainer = pl.Trainer(
|
||||
plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
|
||||
max_steps=max_steps,
|
||||
)
|
||||
if rlt_use_master:
|
||||
trainer = pl.Trainer(
|
||||
strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
|
||||
max_steps=max_steps,
|
||||
)
|
||||
else:
|
||||
trainer = pl.Trainer(
|
||||
plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
|
||||
max_steps=max_steps,
|
||||
)
|
||||
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue