diff --git a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py index 042f4cd10..1bde9a541 100644 --- a/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py +++ b/python/ray/util/sgd/torch/examples/raysgd_torch_signatures.py @@ -8,7 +8,12 @@ in the documentation. # yapf: disable # __torch_operator_start__ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + from ray.util.sgd.torch import TrainingOperator +from ray.util.sgd.torch.examples.train_example import LinearDataset class MyTrainingOperator(TrainingOperator): def setup(self, config): @@ -44,10 +49,29 @@ class MyTrainingOperator(TrainingOperator): self.model, self.optimizer, self.criterion, self.scheduler = \ self.register(models=model, optimizers=optimizer, criterion=criterion, - scheduler=scheduler) + schedulers=scheduler) self.register_data(train_loader=train_loader, validation_loader=val_loader) # __torch_operator_end__ +# __torch_ray_start__ +import ray + +ray.init() +# or ray.init(address="auto") to connect to a running cluster. +# __torch_ray_end__ + +# __torch_trainer_start__ +from ray.util.sgd import TorchTrainer + +trainer = TorchTrainer( + training_operator_cls=MyTrainingOperator, + scheduler_step_freq="epoch", # if scheduler is used + config={"lr": 0.001, "batch_size": 64}) + +# __torch_trainer_end__ + +trainer.shutdown() + # __torch_model_start__ import torch.nn as nn @@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config): # __torch_scheduler_end__ -# __torch_ray_start__ -import ray - -ray.init() -# or ray.init(address="auto") to connect to a running cluster. -# __torch_ray_end__ - # __backwards_compat_start__ from ray.util.sgd import TorchTrainer @@ -167,15 +184,3 @@ trainer = TorchTrainer( # __backwards_compat_end__ trainer.shutdown() - -# __torch_trainer_start__ -from ray.util.sgd import TorchTrainer - -trainer = TorchTrainer( - training_operator_cls=MyTrainingOperator, - scheduler_step_freq="epoch", # if scheduler is used - config={"lr": 0.001, "batch_size": 64}) - -# __torch_trainer_end__ - -trainer.shutdown()