Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
fix example (#10964)
This commit is contained in:
parent
a260e66016
commit
7dbd0ff824
1 changed file with 25 additions and 20 deletions
@@ -8,7 +8,12 @@ in the documentation.
# yapf: disable

# __torch_operator_start__
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset


class MyTrainingOperator(TrainingOperator):
    def setup(self, config):

@@ -44,10 +49,29 @@ class MyTrainingOperator(TrainingOperator):
        self.model, self.optimizer, self.criterion, self.scheduler = \
            self.register(models=model, optimizers=optimizer,
                          criterion=criterion,
-                         scheduler=scheduler)
+                         schedulers=scheduler)
        self.register_data(train_loader=train_loader, validation_loader=val_loader)
# __torch_operator_end__

# __torch_ray_start__
import ray

ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__

# __torch_trainer_start__
from ray.util.sgd import TorchTrainer

trainer = TorchTrainer(
    training_operator_cls=MyTrainingOperator,
    scheduler_step_freq="epoch",  # if scheduler is used
    config={"lr": 0.001, "batch_size": 64})

# __torch_trainer_end__

trainer.shutdown()

# __torch_model_start__
import torch.nn as nn


@@ -144,13 +168,6 @@ def scheduler_creator(optimizer, config):

# __torch_scheduler_end__

# __torch_ray_start__
import ray

ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__

# __backwards_compat_start__
from ray.util.sgd import TorchTrainer


@@ -167,15 +184,3 @@ trainer = TorchTrainer(
# __backwards_compat_end__

trainer.shutdown()

# __torch_trainer_start__
from ray.util.sgd import TorchTrainer

trainer = TorchTrainer(
    training_operator_cls=MyTrainingOperator,
    scheduler_step_freq="epoch",  # if scheduler is used
    config={"lr": 0.001, "batch_size": 64})

# __torch_trainer_end__

trainer.shutdown()
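
For reference, below is a minimal, self-contained sketch of the corrected example, assembled from the lines shown in this diff. Only the register()/register_data() calls (with the fixed schedulers= keyword) and the TorchTrainer construction come from the commit; the concrete model, optimizer, criterion, scheduler, and data loaders built inside setup() are illustrative assumptions.

# Sketch only: component construction below is assumed, not part of the commit.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import ray
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset


class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        # Illustrative components; the real example derives these from config.
        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))
        criterion = nn.MSELoss()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
        train_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config.get("batch_size", 32))
        val_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config.get("batch_size", 32))

        # The fix in this commit: register() expects "schedulers=", not "scheduler=".
        self.model, self.optimizer, self.criterion, self.scheduler = \
            self.register(models=model, optimizers=optimizer,
                          criterion=criterion,
                          schedulers=scheduler)
        self.register_data(train_loader=train_loader, validation_loader=val_loader)


if __name__ == "__main__":
    ray.init()  # or ray.init(address="auto") to connect to a running cluster
    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        scheduler_step_freq="epoch",  # needed when a scheduler is used, per the example
        config={"lr": 0.001, "batch_size": 64})
    trainer.train()  # run one epoch of training
    trainer.shutdown()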