ray/doc/examples/doc_code/raysgd_torch_signatures.py
Maksim Smolin 3a134c7224
[RaySGD] Rename PyTorch API endpoints to start with Torch (#7425)
* Start renaming pytorch to torch

* Rename PyTorchTrainer to TorchTrainer

* Rename PyTorch runners to Torch runners

* Finish renaming API

* Rename to torch in tests

* Finish renaming docs + tests

* Run format + fix DeprecationWarning

* fix

* move tests up

* rename

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
2020-03-03 16:44:42 -08:00

121 lines
3.1 KiB
Python

# flake8: noqa
"""
This file holds code for the torch Trainer creator signatures.
It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# __torch_model_start__
import torch.nn as nn
def model_creator(config):
    """Constructor function for the model(s) to be optimized.

    If more than one model is returned, you will also need to provide a
    custom training function to specify the optimization procedure for
    multiple models.

    Args:
        config (dict): Configuration dictionary passed into ``TorchTrainer``.

    Returns:
        One or more torch.nn.Module objects.
    """
    # This example ignores ``config`` and always builds a 1-in / 1-out
    # linear layer.
    return nn.Linear(1, 1)
# __torch_model_end__
# __torch_optimizer_start__
import torch
def optimizer_creator(model, config):
    """Constructor of one or more Torch optimizers.

    Args:
        model: The return values from ``model_creator``. This can be one
            or more torch nn modules.
        config (dict): Configuration dictionary passed into ``TorchTrainer``.

    Returns:
        One or more Torch optimizer objects.
    """
    # The learning rate is read from ``config["lr"]`` and falls back to 1e-4
    # when the key is absent.
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
# __torch_optimizer_end__
# __torch_data_start__
from ray.util.sgd.torch.examples.train_example import LinearDataset
def data_creator(config):
    """Constructs torch.utils.data.Dataset objects.

    Note that even though two Dataset objects can be returned,
    only one dataset will be used for training.

    Args:
        config: Configuration dictionary passed into ``TorchTrainer``

    Returns:
        One or two Dataset objects. If only one Dataset object is provided,
        ``trainer.validate()`` will throw a ValueError.
    """
    # This example ignores ``config`` and returns two datasets; the second
    # is constructed with an explicit ``size=400``.
    return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
# __torch_data_end__
# __torch_loss_start__
import torch
def loss_creator(config):
    """Constructs the Torch Loss object.

    Note that optionally, you can pass in a Torch Loss constructor directly
    into the TorchTrainer (i.e., ``TorchTrainer(loss_creator=nn.BCELoss, ...)``).

    Args:
        config: Configuration dictionary passed into ``TorchTrainer``

    Returns:
        Torch Loss object.
    """
    # ``config`` is unused here; a fresh loss instance is built on each call.
    return torch.nn.BCELoss()
# __torch_loss_end__
# __torch_scheduler_start__
import torch
def scheduler_creator(optimizer, config):
    """Constructor of one or more Torch optimizer schedulers.

    Args:
        optimizer: The return values from ``optimizer_creator``.
            This can be one or more torch optimizer objects.
        config: Configuration dictionary passed into ``TorchTrainer``

    Returns:
        One or more Torch scheduler objects.
    """
    # ``config`` is unused here; the step size and decay factor are fixed.
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
# __torch_scheduler_end__
# __torch_ray_start__
import ray
# Initialize Ray with its default arguments.
ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__
# __torch_trainer_start__
from ray.util.sgd import TorchTrainer
# Build the trainer from the creator functions defined in this file.
# The model, data, and optimizer creators are passed positionally.
# ``loss_creator`` is given the ``nn.MSELoss`` class itself rather than the
# ``loss_creator`` function, and ``config`` supplies the ``"lr"`` value that
# ``optimizer_creator`` reads.
trainer = TorchTrainer(
model_creator,
data_creator,
optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator,
config={"lr": 0.001})
# __torch_trainer_end__