fix example (#10964)

This commit is contained in:
Amog Kamsetty 2020-09-23 10:33:19 -07:00 committed by GitHub
parent a260e66016
commit 7dbd0ff824
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -8,7 +8,12 @@ in the documentation.
# yapf: disable
# __torch_operator_start__
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset
class MyTrainingOperator(TrainingOperator):
def setup(self, config):
@ -44,10 +49,29 @@ class MyTrainingOperator(TrainingOperator):
self.model, self.optimizer, self.criterion, self.scheduler = \
self.register(models=model, optimizers=optimizer,
criterion=criterion,
scheduler=scheduler)
schedulers=scheduler)
self.register_data(train_loader=train_loader, validation_loader=val_loader)
# __torch_operator_end__
# __torch_ray_start__
import ray
ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__
# __torch_trainer_start__
from ray.util.sgd import TorchTrainer
trainer = TorchTrainer(
training_operator_cls=MyTrainingOperator,
scheduler_step_freq="epoch", # if scheduler is used
config={"lr": 0.001, "batch_size": 64})
# __torch_trainer_end__
trainer.shutdown()
# __torch_model_start__
import torch.nn as nn
@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config):
# __torch_scheduler_end__
# __torch_ray_start__
import ray
ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__
# __backwards_compat_start__
from ray.util.sgd import TorchTrainer
@ -167,15 +184,3 @@ trainer = TorchTrainer(
# __backwards_compat_end__
trainer.shutdown()
# __torch_trainer_start__
from ray.util.sgd import TorchTrainer
trainer = TorchTrainer(
training_operator_cls=MyTrainingOperator,
scheduler_step_freq="epoch", # if scheduler is used
config={"lr": 0.001, "batch_size": 64})
# __torch_trainer_end__
trainer.shutdown()