ray/doc/source/ray-overview/doc_test/ray_train.py
Max Pumperla d6bff736f3
[docs] test ray.io snippets (#22822)
Tests all snippets we have on ray.io. There were some minor issues, which I'll fix upstream.

Signed-off-by: Max Pumperla <max.pumperla@googlemail.com>
2022-03-08 15:50:57 +00:00

37 lines
938 B
Python

import ray.train as train
from ray.train import Trainer
import torch
def train_func():
    """Fit a single-feature linear model on synthetic ``y = 2 * x`` data.

    This function is handed to ``trainer.run`` further down in this file.
    The model and the data loader are each passed through the
    ``train.torch.prepare_*`` helpers before the training loop starts
    (presumably to set them up for distributed execution -- confirm against
    the Ray Train documentation).

    Returns:
        The ``state_dict()`` of the prepared model after 5 passes over the
        data loader.
    """
    # Setup model.
    model = torch.nn.Linear(1, 1)
    model = train.torch.prepare_model(model)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Setup data. Named `features` rather than `input` so the builtin
    # `input()` is not shadowed inside this function.
    features = torch.randn(1000, 1)
    labels = features * 2
    dataset = torch.utils.data.TensorDataset(features, labels)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
    dataloader = train.torch.prepare_data_loader(dataloader)

    # Train.
    for _ in range(5):
        for X, y in dataloader:
            pred = model(X)
            loss = loss_fn(pred, y)
            # Clear gradients from the previous step before back-propagating.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model.state_dict()
# Run `train_func` through a Trainer configured with the "torch" backend and
# 4 workers, then print whatever `trainer.run` returns.
trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
try:
    results = trainer.run(train_func)
finally:
    # Shut the trainer down even if the run raised, so a failed run does not
    # leave the started trainer behind.
    trainer.shutdown()
print(results)