[sgd] v2 documentation draft (#17253)
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Matthew Deng <matthew.j.deng@gmail.com>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>

Parent commit: e812691909
This commit: ecc7cf4c5e

12 changed files with 528 additions and 22 deletions

@ -68,4 +68,5 @@ You can start a ``TorchTrainer`` with the following:
 trainer1.shutdown()
 print("success!")

 .. tip:: Get in touch with us if you're using or considering using `RaySGD <https://forms.gle/26EMwdahdgm7Lscy9>`_!
+.. tip:: We are rolling out a lighter-weight version of RaySGD in a future version of Ray. See the documentation :ref:`here <sgd-v2-docs>`.

doc/source/raysgd/v2/api.rst (new file, 16 lines)

@ -0,0 +1,16 @@
:orphan:

.. _sgd-api:

RaySGD API
----------

.. autoclass:: ray.util.sgd.v2.Trainer
    :members:

.. autoclass:: ray.util.sgd.v2.BackendConfig

.. autoclass:: ray.util.sgd.v2.TorchConfig

.. autoclass:: ray.util.sgd.v2.SGDCallback
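
All of these classes are exposed at the top level of ``ray.util.sgd.v2`` (per the package ``__init__`` added in this commit), so they can be imported directly:

.. code-block:: python

    from ray.util.sgd.v2 import (
        BackendConfig,
        SGDCallback,
        TorchConfig,
        Trainer,
    )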

doc/source/raysgd/v2/architecture.rst (new file, 47 lines)

@ -0,0 +1,47 @@
:orphan:

.. _sgd-arch:

Architecture
============

A diagram of the RaySGD architecture is provided below.

.. image:: sgd-arch.svg
    :width: 70%
    :align: center

Trainer
-------

The Trainer is the main class exposed in the RaySGD API and the one users will interact with.

* The user will pass in a *function* which defines the training logic.
* The Trainer will create an :ref:`Executor <sgd-arch-executor>` to run the distributed training.
* The Trainer will handle callbacks based on the results from the BackendExecutor (see the sketch below).
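
To make this flow concrete, here is a minimal sketch using the names introduced in this commit (``SGDCallback`` is still a stub at this point, so the callback instance below does nothing yet):

.. code-block:: python

    from ray.util.sgd.v2 import SGDCallback, Trainer, TorchConfig

    def train_func(config):
        # User-defined training logic; this runs on each worker.
        return config.get("x", 1) * 2

    # The Trainer ties the training function, the backend, and any callbacks
    # together, and delegates distributed execution to the executor.
    trainer = Trainer(backend=TorchConfig(), num_workers=2)
    results = trainer.run(train_func, config={"x": 21}, callbacks=[SGDCallback()])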

.. _sgd-arch-executor:

Executor
--------

The executor is an interface which handles execution of distributed training.

* The executor will handle the creation of an actor group and will be initialized in conjunction with a backend.
* Worker resources, number of workers, and placement strategy will be passed to the WorkerGroup.

Backend
-------

A backend is used in conjunction with the executor to initialize and manage framework-specific communication protocols.
Each communication library (Torch, Horovod, TensorFlow, etc.) will have a separate backend and will take a specific configuration value.
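
For example (a sketch using the API surfaced in this commit), the Torch backend can be selected either by name or through its configuration object:

.. code-block:: python

    from ray.util.sgd.v2 import Trainer, TorchConfig

    # Equivalent ways to select the Torch backend: a plain string, or the
    # framework-specific BackendConfig subclass carrying backend options.
    trainer = Trainer(backend="torch", num_workers=2)
    trainer = Trainer(backend=TorchConfig(), num_workers=2)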

WorkerGroup
-----------

The WorkerGroup is a generic utility class for managing a group of Ray Actors.

* This is similar in concept to Fiber's `Ring <https://uber.github.io/fiber/experimental/ring/>`_.
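
To make "a group of Ray Actors" concrete, here is an illustrative sketch of the underlying idea; it is not the actual ``WorkerGroup`` implementation, and the ``Worker`` actor below is a made-up stand-in:

.. code-block:: python

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def execute(self, fn):
            # Run an arbitrary function on this worker and return its result.
            return fn()

    # A worker group is essentially a handle to N such actors that can run
    # the same function everywhere and gather the results.
    workers = [Worker.remote() for _ in range(4)]
    results = ray.get([w.execute.remote(lambda: "ready") for w in workers])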

doc/source/raysgd/v2/examples.rst (new file, 28 lines)

@ -0,0 +1,28 @@
:orphan:

.. _sgd-v2-examples:

RaySGD Examples
===============

Below are examples for using RaySGD with a variety of models, frameworks, and use cases.

* Simple example for Pytorch.
* End-to-end example for Pytorch.
* End-to-end example for HuggingFace Transformers (Pytorch).
* Simple example for Tensorflow.
* End-to-end example for Tensorflow.
* Simple example for Horovod (with Tensorflow).
* End-to-end example for Horovod (with Tensorflow).

Features
--------

* Example for using a custom callback.
* End-to-end example for running on an elastic cluster (elastic training).

Models
------

* Example for training a vision model.

doc/source/raysgd/v2/raysgd.rst (new file, 89 lines)

@ -0,0 +1,89 @@
:orphan:

.. _sgd-v2-docs:

RaySGD: Distributed Training Wrappers
=====================================

.. _`issue on GitHub`: https://github.com/ray-project/ray/issues

RaySGD is a lightweight library for distributed deep learning, allowing you to scale up and speed up training for your deep learning models.

The main features are:

- **Ease of use**: Scale your single-process training code to a cluster in just a couple of lines of code.
- **Composability**: RaySGD interoperates with :ref:`Ray Tune <tune-main>` to tune your distributed model and :ref:`Ray Datasets <datasets>` to train on large amounts of data.
- **Interactivity**: RaySGD fits in your workflow with support to run from any environment, including seamless Jupyter notebook support.

Intro to RaySGD
---------------

RaySGD is a library that aims to simplify distributed deep learning.

**Frameworks**: RaySGD is built to abstract away the coordination/configuration setup of distributed deep learning frameworks such as Pytorch Distributed and Tensorflow Distributed, allowing users to focus only on implementing training logic.

* For Pytorch, RaySGD automatically handles the construction of the distributed process group (see the sketch below).
* For Tensorflow, RaySGD automatically handles the coordination of the ``TF_CONFIG``. The current implementation assumes that the user will use a MultiWorkerMirroredStrategy, but this will change in the near future.
* For Horovod, RaySGD automatically handles the construction of the Horovod runtime and Rendezvous server.
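
To give a sense of what is abstracted away, here is an illustrative sketch (not RaySGD code) of the per-worker setup that Pytorch Distributed normally requires and that RaySGD's Torch backend performs for you; the environment variables and ``init_process_group`` call are standard ``torch.distributed`` conventions:

.. code-block:: python

    import os

    import torch.distributed as dist

    def manual_setup_without_raysgd(rank: int, world_size: int):
        # Every worker must agree on a rendezvous address...
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        # ...and join the same process group before DistributedDataParallel
        # can be used. RaySGD performs this setup on each Ray worker for you.
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)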

**Built for data scientists/ML practitioners**: RaySGD has support for standard ML tools and features that practitioners love:

* Callbacks for early stopping
* Checkpointing
* Integration with Tensorboard, Weights/Biases, and MLflow
* Jupyter notebooks

**Integration with Ray Ecosystem**: Distributed deep learning often comes with a lot of complexity, so RaySGD integrates with the rest of the Ray ecosystem to handle it:

* Use :ref:`Ray Datasets <datasets>` with RaySGD to handle and train on large amounts of data.
* Use :ref:`Ray Tune <tune-main>` with RaySGD to leverage cutting-edge hyperparameter techniques and distribute both your training and tuning.
* You can leverage the :ref:`Ray cluster launcher <cluster-cloud>` to launch autoscaling or spot instance clusters to train your model at scale on any cloud.

Quickstart
----------

You can run the following on your local machine:

.. code-block:: python

    import torch
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel

    from ray.util.sgd.v2 import Trainer, TorchConfig

    def train_func(config=None):
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        # get_data_loaders, ConvNet, train, and test are user-defined helpers.
        train_loader, test_loader = get_data_loaders()
        model = ConvNet().to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.1)
        model = DistributedDataParallel(model)
        all_results = []

        for epoch in range(40):
            train(model, optimizer, train_loader, device)
            acc = test(model, test_loader, device)
            all_results.append(acc)

        return model.module, all_results

    trainer = Trainer(
        num_workers=8,
        use_gpu=True,
        backend=TorchConfig())

    print(trainer)
    # prints a table of resource usage

    model = trainer.run(train_func)  # scale out here!

Links
-----

* :ref:`API reference <sgd-api>`
* :ref:`User guide <sgd-user-guide>`
* :ref:`Architecture <sgd-arch>`
* :ref:`Examples <sgd-v2-examples>`

**Next steps:** Check out the :ref:`user guide here <sgd-user-guide>`.

doc/source/raysgd/v2/sgd-arch.svg (new file, 1 line, 32 KiB)

File diff suppressed because one or more lines are too long.

doc/source/raysgd/v2/user_guide.rst (new file, 307 lines)

@ -0,0 +1,307 @@
:orphan:

.. _sgd-user-guide:

RaySGD User Guide
=================

In this guide, we cover examples for the following use cases:

* How do I port my code to use RaySGD?
* How do I use RaySGD to train with a large dataset?
* How do I tune my RaySGD model?
* How do I run my training on pre-emptible instances (fault tolerance)?
* How do I monitor my training?

Quick Start
-----------

RaySGD abstracts away the complexity of setting up a distributed training system. Let's take this simple example function:

.. code-block:: python

    import torch.nn as nn
    import torch.nn.functional as F

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(1, 128)
            self.fc2 = nn.Linear(128, 1)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

    def train_func():
        model = Net()
        # `data` is a placeholder for your training data iterable.
        for x in data:
            results = model(x)
        return results

To convert this to RaySGD, we add a ``config`` parameter to ``train_func()``:

.. code-block:: diff

    -def train_func():
    +def train_func(config):

Then, we can construct the trainer:

.. code-block:: python

    from ray.util.sgd.v2 import Trainer

    trainer = Trainer(backend="torch", num_workers=2)

Then, we can pass the function to the trainer. This will cause the trainer to start the necessary processes and execute the training function:

.. code-block:: python

    results = trainer.run(train_func, config=None)
    print(results)

Now, let's leverage Pytorch's Distributed Data Parallel. With RaySGD, you just pass in your distributed data parallel code as you would normally run it with ``torch.distributed.launch``:

.. code-block:: python

    from typing import Dict

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel

    def train_simple(config: Dict):

        # N is batch size; D_in is input dimension;
        # H is hidden dimension; D_out is output dimension.
        N, D_in, H, D_out = 8, 5, 5, 5

        # Create random Tensors to hold inputs and outputs
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        loss_fn = nn.MSELoss()

        # Use the nn package to define our model and loss function.
        model = torch.nn.Sequential(
            torch.nn.Linear(D_in, H),
            torch.nn.ReLU(),
            torch.nn.Linear(H, D_out),
        )
        optimizer = optim.SGD(model.parameters(), lr=0.1)

        model = DistributedDataParallel(model)
        results = []

        for epoch in range(config.get("epochs", 10)):
            optimizer.zero_grad()
            output = model(x)
            loss = loss_fn(output, y)
            loss.backward()
            results.append(loss.item())
            optimizer.step()
        return results

Running this with RaySGD is as simple as the following:

.. code-block:: python

    all_results = trainer.run(train_simple)

Porting code to RaySGD
----------------------

.. tabs::

    .. group-tab:: pytorch

        TODO. Write about how to convert standard pytorch code to distributed.

    .. group-tab:: tensorflow

        TODO. Write about how to convert standard tf code to distributed.

    .. group-tab:: horovod

        TODO. Write about how to convert code to use horovod.

Training on a large dataset
---------------------------

SGD provides native support for :ref:`Ray Datasets <datasets>`. You can pass in a Dataset to RaySGD via ``Trainer.run``.
Under the hood, RaySGD will automatically shard the given dataset.

.. code-block:: python

    def train_func(config):
        batch_size = config["worker_batch_size"]
        data_shard = ray.sgd.get_data_shard()
        dataloader = data_shard.to_torch(batch_size=batch_size)

        for x, y in dataloader:
            output = model(x)
            ...

        return model

    trainer = Trainer(num_workers=8, backend="torch")
    dataset = ray.data.read_csv("...").filter().pipeline(length=50)

    result = trainer.run(
        train_func,
        config={"worker_batch_size": 64},
        dataset=dataset)

.. note:: This feature currently does not work with elastic training.

Monitoring training
-------------------

You may want to plug your training code into your favorite experiment management framework.
RaySGD provides an interface to fetch intermediate results and callbacks to process/log your intermediate results.

You can plug all of these into RaySGD with the following interface:

.. code-block:: python

    def train_func(config):
        # do something
        for x, y in dataset:
            result = process(x)
            ray.sgd.report(**result)

    # TODO: Where do we pass in the logging folder?
    result = trainer.run(
        train_func,
        config={"worker_batch_size": 64},
        callbacks=[sgd.MlflowCallback()],
        dataset=dataset)
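
Callbacks are how you hook your own logging into this loop. As a rough sketch of the intent (the ``SGDCallback`` interface is still a stub in this commit, so the ``handle_result`` hook name below is a placeholder, not a finalized API):

.. code-block:: python

    from ray.util.sgd.v2 import SGDCallback

    class PrintingCallback(SGDCallback):
        def handle_result(self, results):
            # Hypothetical hook: called with the intermediate results that
            # each worker passed to ray.sgd.report().
            print(results)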

.. Here is a list of callbacks that are supported by RaySGD:

.. * WandbCallback
.. * MlflowCallback
.. * TensorboardCallback
.. * JsonCallback (Automatically logs given parameters)
.. * CSVCallback

.. note:: When using Ray Tune, these callbacks will not be used.

Checkpointing
-------------

RaySGD provides a way to save state during the training process. This will be useful for:

1. :ref:`Integration with Ray Tune <tune-sgd>` to use certain Ray Tune schedulers.
2. Running a long-running training job on a cluster of pre-emptible machines/pods.

.. code-block:: python

    import ray

    def train_func(config):
        # (Optional) restore any previously saved state.
        state = ray.sgd.load_checkpoint()
        for _ in range(config["num_epochs"]):
            train(...)
            # Save whatever state you need to resume from.
            ray.sgd.save_checkpoint((model, optimizer))
        return model

    trainer = Trainer(backend="torch", num_workers=4)
    trainer.run(train_func)
    state = trainer.get_last_checkpoint()

.. Running on the cloud
.. --------------------

.. Use RaySGD with the Ray cluster launcher by changing the following:

.. .. code-block:: bash

..    ray up cluster.yaml

.. TODO.


.. Running on pre-emptible machines
.. --------------------------------

.. You may want to

.. TODO.

.. _tune-sgd:

Hyperparameter tuning
---------------------

Hyperparameter tuning with Ray Tune is natively supported with RaySGD. Specifically, you can take an existing training function and follow these steps:

1. Call ``trainer.to_tune_trainable``, which will produce an object ("Trainable") that will be passed to Ray Tune.
2. Call ``tune.run(trainable)`` instead of ``trainer.run``. This will invoke the hyperparameter tuning, starting multiple "trials" each with the resource amount specified by the Trainer.

A couple of caveats:

* Tune won't handle the ``training_func`` return value correctly. To save your best trained model, you'll need to use the checkpointing API (see the sketch after the example below).
* You should **not** call ``tune.report`` or ``tune.checkpoint_dir`` in your training function.

.. code-block:: python

    import ray
    import torch.distributed as dist
    from ray import tune

    def training_func(config):
        dataloader = ray.sgd.get_dataset()\
            .get_shard(dist.get_rank())\
            .to_torch(batch_size=config["batch_size"])

        for i in range(config["epochs"]):
            ray.sgd.report(...)  # use same intermediate reporting API

    # Declare the specification for training.
    trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
    dataset = ray.data.read_csv("...").pipeline(length=50)

    # Convert this to a trainable.
    trainable = trainer.to_tune_trainable(training_func, dataset=dataset)

    analysis = tune.run(trainable, config={
        "lr": tune.uniform(0.001, 0.1), "batch_size": tune.randint(16, 64)}, num_samples=12)
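
For the first caveat, here is an illustrative sketch, continuing the example above and reusing the checkpointing API from the previous section; ``build_model`` and ``train_one_epoch`` are hypothetical user-defined helpers:

.. code-block:: python

    def training_func(config):
        model = build_model()  # placeholder for a user-defined model constructor
        for epoch in range(config["epochs"]):
            train_one_epoch(model)  # placeholder for a user-defined training step
            ray.sgd.report(epoch=epoch)
            # Persist state instead of relying on the return value, since Tune
            # does not propagate training_func's return value.
            ray.sgd.save_checkpoint(model)

    trainable = trainer.to_tune_trainable(training_func, dataset=dataset)
    analysis = tune.run(trainable, config={"epochs": 10, "batch_size": 32}, num_samples=4)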

Distributed metrics (for Pytorch)
---------------------------------

In real applications, you may want to calculate optimization metrics besides accuracy and loss: recall, precision, Fbeta, etc.

RaySGD natively supports `TorchMetrics <https://torchmetrics.readthedocs.io/en/latest/>`_, which provides a collection of machine learning metrics for distributed, scalable Pytorch models.

Here is an example:

.. code-block:: python

    import torch
    import torchmetrics
    import ray

    def train_func(config):
        preds = torch.randn(10, 5).softmax(dim=-1)
        target = torch.randint(5, (10,))

        acc = torchmetrics.functional.accuracy(preds, target)
        ray.sgd.report(accuracy=acc)

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.run(train_func, config=None)

@ -1,4 +1,5 @@
-from ray.util.sgd.v2.backends.torch import TorchConfig
+from ray.util.sgd.v2.backends import BackendConfig, TorchConfig
+from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.trainer import Trainer

-__all__ = ["TorchConfig", "Trainer"]
+__all__ = ["BackendConfig", "SGDCallback", "TorchConfig", "Trainer"]

@ -0,0 +1,4 @@
from ray.util.sgd.v2.backends.backend import BackendConfig
from ray.util.sgd.v2.backends.torch import TorchConfig

__all__ = ["TorchConfig", "BackendConfig"]

@ -0,0 +1,3 @@
from ray.util.sgd.v2.callbacks.callback import SGDCallback

__all__ = ["SGDCallback"]

@ -1,2 +1,2 @@
-class Callback:
+class SGDCallback:
     pass

@ -1,32 +1,33 @@
 from typing import Union, Callable, List, TypeVar, Optional, Any, Dict

 from ray.util.sgd.v2.backends.backend import BackendConfig
-from ray.util.sgd.v2.callbacks.callback import Callback
+from ray.util.sgd.v2.callbacks.callback import SGDCallback

 T = TypeVar("T")
 S = TypeVar("S")


 class Trainer:
+    """A class for enabling seamless distributed deep learning.
+
+    Args:
+        backend (Union[str, BackendConfig]): The backend used for
+            distributed communication. If configurations are needed,
+            a subclass of ``BackendConfig`` can be passed in.
+            Supported ``str`` values: {"torch"}.
+        num_workers (int): The number of workers (Ray actors) to launch.
+            Defaults to 1. Each worker will reserve 1 CPU by default.
+        use_gpu (bool): If True, training will be done on GPUs (1 per
+            worker). Defaults to False.
+        resources_per_worker (Optional[Dict]): If specified, the resources
+            defined in this Dict will be reserved for each worker.
+    """
+
     def __init__(self,
                  backend: Union[str, BackendConfig],
                  num_workers: int = 1,
                  use_gpu: bool = False,
                  resources_per_worker: Optional[Dict[str, float]] = None):
-        """A class for distributed training.
-
-        Args:
-            backend (Union[str, BackendConfig]): The backend used for
-                distributed communication. If configurations are needed,
-                a subclass of ``BackendConfig`` can be passed in.
-                Supported ``str`` values: {"torch"}.
-            num_workers (int): The number of workers (Ray actors) to launch.
-                Defaults to 1. Each worker will reserve 1 CPU by default.
-            use_gpu (bool): If True, training will be done on GPUs (1 per
-                worker). Defaults to False.
-            resources_per_worker (Optional[Dict]): If specified, the resources
-                defined in this Dict will be reserved for each worker.
-        """
         pass

     def start(self,

@ -48,14 +49,14 @@ class Trainer:
     def run(self,
             train_func: Callable[[Dict[str, Any]], T],
             config: Optional[Dict[str, Any]] = None,
-            callbacks: Optional[List[Callback]] = None) -> List[T]:
+            callbacks: Optional[List[SGDCallback]] = None) -> List[T]:
         """Runs a training function in a distributed manner.

         Args:
             train_func (Callable): The training function to execute.
             config (Optional[Dict]): Configurations to pass into
                 ``train_func``. If None then an empty Dict will be created.
-            callbacks (Optional[List[Callback]]): A list of Callbacks which
+            callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
                 will be executed during training. If this is not set,
                 currently there are NO default Callbacks.
         Returns:
|
||||||
|
|
||||||
def to_tune_trainable(self, train_func: Callable[[Dict[str, Any]], T]
|
def to_tune_trainable(self, train_func: Callable[[Dict[str, Any]], T]
|
||||||
) -> Callable[[Dict[str, Any]], List[T]]:
|
) -> Callable[[Dict[str, Any]], List[T]]:
|
||||||
"""Creates a Tune trainable function."""
|
"""Creates a Tune trainable function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (Callable): The function that should be executed on each
|
||||||
|
training worker.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:py:class:`ray.tune.Trainable`
|
||||||
|
"""
|
||||||
|
|
||||||
def trainable(config: Dict[str, Any]) -> List[T]:
|
def trainable(config: Dict[str, Any]) -> List[T]:
|
||||||
pass
|
pass
|
||||||
|
|