ray/release/ml_user_tests/ray-lightning/simple_example.py
Jimmy Yao 1c1cca2736
[release/ray-lightning] adjust the release test of ray lightning master
First of all, sorry I messed up the previous PR while syncing with master (#27374). This PR is a duplicate of the previous one with the changes updated (change: adding a version check on ray_lightning for compatibility). Also, apologies for the massive review requests on the previous PR.
2022-08-03 16:01:32 +01:00


# This file is duplicated in ray/tests/ray_lightning
import argparse
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from importlib_metadata import version
from packaging.version import parse as v_parse
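
# ray_lightning >= 0.3.0 exposes RayStrategy (passed to the Trainer via the
# `strategy` argument), while older releases expose RayPlugin (passed via
# `plugins`). Check the installed version once and import the matching class.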
rlt_use_master = v_parse(version("ray_lightning")) >= v_parse("0.3.0")

if rlt_use_master:
    # ray_lightning >= 0.3.0
    from ray_lightning import RayStrategy
else:
    # ray_lightning < 0.3.0
    from ray_lightning import RayPlugin
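

# A minimal autoencoder LightningModule used as the training workload:
# a two-layer encoder down to a 3-dimensional latent space and a mirrored decoder.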
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
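

# Train the autoencoder on MNIST. The Trainer distributes training across
# `num_workers` Ray workers, using RayStrategy or RayPlugin depending on the
# installed ray_lightning version (see the import block above).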
def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])

    autoencoder = LitAutoEncoder()
    if rlt_use_master:
        trainer = pl.Trainer(
            strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
            max_steps=max_steps,
        )
    else:
        trainer = pl.Trainer(
            plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
            max_steps=max_steps,
        )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
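

# Example invocation (assuming a local or cluster Ray runtime is available):
#   python simple_example.py --num-workers 2 --max-steps 10 [--use-gpu]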
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Ray Lightning Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=2,
        help="Number of workers to use for training.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=10,
        help="Maximum number of steps to run for training.",
    )
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Whether to enable GPU training.",
    )

    args = parser.parse_args()

    main(num_workers=args.num_workers, max_steps=args.max_steps, use_gpu=args.use_gpu)