[sgd] v2 documentation draft (#17253)
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Matthew Deng <matthew.j.deng@gmail.com>
Co-authored-by: Amog Kamsetty <amogkamsetty@yahoo.com>
parent e812691909, commit ecc7cf4c5e

12 changed files with 528 additions and 22 deletions
@@ -68,4 +68,5 @@ You can start a ``TorchTrainer`` with the following:

    trainer1.shutdown()
    print("success!")

.. tip:: Get in touch with us if you're using or considering using `RaySGD <https://forms.gle/26EMwdahdgm7Lscy9>`_!

.. tip:: We are rolling out a lighter-weight version of RaySGD in a future version of Ray. See the documentation :ref:`here <sgd-v2-docs>`.
doc/source/raysgd/v2/api.rst (new file, 16 lines)

@@ -0,0 +1,16 @@
:orphan:

.. _sgd-api:

RaySGD API
----------

.. autoclass:: ray.util.sgd.v2.Trainer
    :members:

.. autoclass:: ray.util.sgd.v2.BackendConfig

.. autoclass:: ray.util.sgd.v2.TorchConfig

.. autoclass:: ray.util.sgd.v2.SGDCallback
doc/source/raysgd/v2/architecture.rst (new file, 47 lines)

@@ -0,0 +1,47 @@
:orphan:

.. _sgd-arch:

Architecture
============

A diagram of the RaySGD architecture is provided below.

.. image:: sgd-arch.svg
    :width: 70%
    :align: center


Trainer
-------

The Trainer is the main class exposed in the RaySGD API, and the one users interact with directly.

* The user passes in a *function* which defines the training logic.
* The Trainer creates an :ref:`Executor <sgd-arch-executor>` to run the distributed training.
* The Trainer handles callbacks based on the results from the BackendExecutor, as sketched below.

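To make that flow concrete, below is a minimal sketch (not a complete example) of how a user drives the Trainer: a user-defined training function is handed to ``Trainer.run`` together with an optional list of callbacks. The body of ``train_func`` is a placeholder, and passing a bare ``SGDCallback`` only illustrates where callbacks are attached.

.. code-block:: python

    from ray.util.sgd.v2 import Trainer, SGDCallback

    def train_func(config):
        # user-defined training logic; runs on every worker
        return 1

    trainer = Trainer(backend="torch", num_workers=2)
    # The Trainer hands train_func to the executor and forwards
    # per-worker results to any callbacks passed here.
    results = trainer.run(train_func, callbacks=[SGDCallback()])
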
.. _sgd-arch-executor:

Executor
--------

The executor is an interface that handles the execution of distributed training.

* The executor handles the creation of an actor group and is initialized in conjunction with a backend.
* Worker resources, the number of workers, and the placement strategy are passed to the WorkerGroup (the resource-related settings surface as ``Trainer`` arguments, as sketched below).

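For reference, here is a minimal sketch of how those resource settings are specified from the user's side, using the ``Trainer`` constructor arguments from the :ref:`API reference <sgd-api>`; the specific resource values are illustrative only.

.. code-block:: python

    from ray.util.sgd.v2 import Trainer

    # Each of the 4 workers is a Ray actor; with use_gpu=True each worker
    # reserves 1 GPU, and resources_per_worker adds further reservations.
    trainer = Trainer(
        backend="torch",
        num_workers=4,
        use_gpu=True,
        resources_per_worker={"CPU": 2})
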
Backend
-------

A backend is used in conjunction with the executor to initialize and manage framework-specific communication protocols.
Each communication library (Torch, Horovod, TensorFlow, etc.) has its own backend, which takes a corresponding configuration object, as shown below.

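As a small sketch of what this looks like from the user's side (mirroring the quickstart), the Torch backend can be selected either by name or by passing its configuration object explicitly:

.. code-block:: python

    from ray.util.sgd.v2 import Trainer, TorchConfig

    # Equivalent ways to select the Torch backend: a string, or an
    # explicit configuration object when settings need to be customized.
    trainer = Trainer(backend="torch", num_workers=2)
    trainer = Trainer(backend=TorchConfig(), num_workers=2)
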
WorkerGroup
-----------

The WorkerGroup is a generic utility class for managing a group of Ray Actors; a conceptual sketch of the idea follows below.

* This is similar in concept to Fiber's `Ring <https://uber.github.io/fiber/experimental/ring/>`_.

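The WorkerGroup itself is internal, so the snippet below is only a conceptual sketch of the idea using plain Ray actors (it does not use the WorkerGroup API): a fixed group of identical actors is created up front and the same work is dispatched to all of them, collecting one result per worker.

.. code-block:: python

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def execute(self, func):
            return func()

    # Conceptually: create a fixed group of actors and run the same
    # function on each of them, gathering one result per worker.
    workers = [Worker.remote() for _ in range(4)]
    results = ray.get([w.execute.remote(lambda: "done") for w in workers])
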
doc/source/raysgd/v2/examples.rst (new file, 28 lines)

@@ -0,0 +1,28 @@
:orphan:

.. _sgd-v2-examples:

RaySGD Examples
===============

Below are examples for using RaySGD with a variety of models, frameworks, and use cases.

* Simple example for Pytorch.
* End-to-end example for Pytorch.
* End-to-end example for HuggingFace Transformers (Pytorch).
* Simple example for Tensorflow.
* End-to-end example for Tensorflow.
* Simple example for Horovod (with Tensorflow).
* End-to-end example for Horovod (with Tensorflow).

Features
--------

* Example for using a custom callback.
* End-to-end example for running on an elastic cluster (elastic training).

Models
------

* Example of training a vision model.
doc/source/raysgd/v2/raysgd.rst (new file, 89 lines)

@@ -0,0 +1,89 @@
:orphan:

.. _sgd-v2-docs:

RaySGD: Distributed Training Wrappers
=====================================

.. _`issue on GitHub`: https://github.com/ray-project/ray/issues

RaySGD is a lightweight library for distributed deep learning, allowing you to scale up and speed up training for your deep learning models.

The main features are:

- **Ease of use**: Scale your single-process training code to a cluster in just a couple of lines of code.
- **Composability**: RaySGD interoperates with :ref:`Ray Tune <tune-main>` to tune your distributed model and :ref:`Ray Datasets <datasets>` to train on large amounts of data.
- **Interactivity**: RaySGD fits in your workflow with support to run from any environment, including seamless Jupyter notebook support.


Intro to RaySGD
---------------

RaySGD is a library that aims to simplify distributed deep learning.

**Frameworks**: RaySGD is built to abstract away the coordination/configuration setup of distributed deep learning frameworks such as Pytorch Distributed and Tensorflow Distributed, allowing users to focus solely on implementing training logic (a sketch of the boilerplate this removes follows the list below).

* For Pytorch, RaySGD automatically handles the construction of the distributed process group.
* For Tensorflow, RaySGD automatically handles the coordination of ``TF_CONFIG``. The current implementation assumes that the user will use a MultiWorkerMirroredStrategy, but this will change in the near future.
* For Horovod, RaySGD automatically handles the construction of the Horovod runtime and the Rendezvous server.

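For Pytorch specifically, the snippet below is a hedged sketch of the setup code you would otherwise have to run yourself on every worker, assuming the usual environment-variable based rendezvous; with RaySGD, the backend performs the equivalent of this step before your training function runs.

.. code-block:: python

    import os
    import torch.distributed as dist

    # Manual Pytorch Distributed setup, per worker, without RaySGD:
    # every process needs a rank, a world size, and a rendezvous address.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="gloo",   # or "nccl" for GPU training
        rank=0,           # this worker's rank
        world_size=1)     # total number of workers
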
**Built for data scientists/ML practitioners**: RaySGD has support for standard ML tools and features that practitioners love:

* Callbacks for early stopping
* Checkpointing
* Integration with Tensorboard, Weights/Biases, and MLflow
* Jupyter notebooks

**Integration with Ray Ecosystem**: Distributed deep learning often comes with a lot of complexity, and RaySGD integrates with the rest of the Ray ecosystem to help manage it:

* Use :ref:`Ray Datasets <datasets>` with RaySGD to handle and train on large amounts of data.
* Use :ref:`Ray Tune <tune-main>` with RaySGD to leverage cutting-edge hyperparameter techniques and distribute both your training and tuning.
* You can leverage the :ref:`Ray cluster launcher <cluster-cloud>` to launch autoscaling or spot instance clusters to train your model at scale on any cloud.


Quickstart
----------

You can run the following on your local machine:

.. code-block:: python

    import torch
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel

    from ray.util.sgd.v2 import Trainer, TorchConfig

    # ``get_data_loaders``, ``ConvNet``, ``train`` and ``test`` below are
    # assumed user-defined helpers (not shown here).

    def train_func(config=None):
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        train_loader, test_loader = get_data_loaders()
        model = ConvNet().to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.1)
        model = DistributedDataParallel(model)
        all_results = []

        for epoch in range(40):
            train(model, optimizer, train_loader, device)
            acc = test(model, test_loader, device)
            all_results.append(acc)

        return model.module, all_results

    trainer = Trainer(
        num_workers=8,
        use_gpu=True,
        backend=TorchConfig())

    print(trainer)
    # prints a table of resource usage

    model = trainer.run(train_func)  # scale out here!

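Since ``train_func`` above accepts a ``config`` argument, ``Trainer.run`` can also forward a configuration dictionary to it. A minimal sketch, reusing the ``trainer`` from the quickstart (the ``lr`` key is purely illustrative):

.. code-block:: python

    # Hyperparameters can be forwarded to train_func via ``config``.
    results = trainer.run(train_func, config={"lr": 0.1})
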
Links
-----

* :ref:`API reference <sgd-api>`
* :ref:`User guide <sgd-user-guide>`
* :ref:`Architecture <sgd-arch>`
* :ref:`Examples <sgd-v2-examples>`


**Next steps:** Check out the :ref:`user guide here <sgd-user-guide>`.
doc/source/raysgd/v2/sgd-arch.svg (new file; image, 32 KiB)

File diff suppressed because one or more lines are too long
doc/source/raysgd/v2/user_guide.rst (new file, 307 lines)

@@ -0,0 +1,307 @@
:orphan:

.. _sgd-user-guide:

RaySGD User Guide
=================

In this guide, we cover examples for the following use cases:

* How do I port my code to use RaySGD?
* How do I use RaySGD to train with a large dataset?
* How do I tune my RaySGD model?
* How do I run my training on pre-emptible instances (fault tolerance)?
* How do I monitor my training?

Quick Start
-----------

RaySGD abstracts away the complexity of setting up a distributed training system. Let's take this simple example function:

.. code-block:: python

    import torch.nn as nn
    import torch.nn.functional as F

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(1, 128)
            self.fc2 = nn.Linear(128, 1)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

    def train_func():
        # ``data`` is assumed to be an iterable of input tensors (not shown).
        model = Net()
        for x in data:
            results = model(x)
        return results

To convert this to RaySGD, we add a ``config`` parameter to ``train_func()``:

.. code-block:: diff

    -def train_func():
    +def train_func(config):

Then, we can construct the Trainer:

.. code-block:: python

    from ray.util.sgd.v2 import Trainer

    trainer = Trainer(backend="torch", num_workers=2)

Then, we can pass the function to the trainer. This will cause the trainer to start the necessary processes and execute the training function:

.. code-block:: python

    results = trainer.run(train_func, config=None)
    print(results)

Now, let's leverage Pytorch's Distributed Data Parallel. With RaySGD, you just pass in your distributed data parallel code as you would normally run it with ``torch.distributed.launch``:

.. code-block:: python

    from typing import Dict

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel
    import torch.optim as optim

    def train_simple(config: Dict):

        # N is batch size; D_in is input dimension;
        # H is hidden dimension; D_out is output dimension.
        N, D_in, H, D_out = 8, 5, 5, 5

        # Create random Tensors to hold inputs and outputs
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        loss_fn = nn.MSELoss()

        # Use the nn package to define our model and loss function.
        model = torch.nn.Sequential(
            torch.nn.Linear(D_in, H),
            torch.nn.ReLU(),
            torch.nn.Linear(H, D_out),
        )
        optimizer = optim.SGD(model.parameters(), lr=0.1)

        model = DistributedDataParallel(model)
        results = []

        for epoch in range(config.get("epochs", 10)):
            optimizer.zero_grad()
            output = model(x)
            loss = loss_fn(output, y)
            loss.backward()
            results.append(loss.item())
            optimizer.step()
        return results

Running this with RaySGD is as simple as the following:

.. code-block:: python

    all_results = trainer.run(train_simple)

Porting code to RaySGD
----------------------

.. tabs::

    .. group-tab:: pytorch

        TODO. Write about how to convert standard pytorch code to distributed.

    .. group-tab:: tensorflow

        TODO. Write about how to convert standard tf code to distributed.

    .. group-tab:: horovod

        TODO. Write about how to convert code to use horovod.

Training on a large dataset
---------------------------

RaySGD provides native support for :ref:`Ray Datasets <datasets>`. You can pass in a Dataset to RaySGD via ``Trainer.run``.
Under the hood, RaySGD will automatically shard the given dataset.

.. code-block:: python

    def train_func(config):
        batch_size = config["worker_batch_size"]
        # Each worker receives its own shard of the dataset.
        data_shard = ray.sgd.get_data_shard()
        dataloader = data_shard.to_torch(batch_size=batch_size)

        for x, y in dataloader:
            output = model(x)
            ...

        return model

    trainer = Trainer(num_workers=8, backend="torch")
    dataset = ray.data.read_csv("...").filter().pipeline(length=50)

    result = trainer.run(
        train_func,
        config={"worker_batch_size": 64},
        dataset=dataset)


.. note:: This feature currently does not work with elastic training.

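To give a rough intuition for the sharding step, the standalone sketch below uses the public Ray Datasets API to split a Dataset into one shard per worker; this is not RaySGD code, only an illustration of the kind of splitting RaySGD performs automatically when a Dataset is passed in.

.. code-block:: python

    import ray

    ray.init()

    # A toy dataset, split into one shard per worker.
    dataset = ray.data.range(1000)
    num_workers = 4
    shards = dataset.split(num_workers)
    print([shard.count() for shard in shards])
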
Monitoring training
-------------------

You may want to plug your training code into your favorite experiment management framework.
RaySGD provides an interface to fetch intermediate results, and callbacks to process/log those intermediate results.

You can plug all of these into RaySGD with the following interface:

.. code-block:: python

    def train_func(config):
        # do something
        for x, y in dataset:
            result = process(x)
            ray.sgd.report(**result)


    # TODO: Where do we pass in the logging folder?
    result = trainer.run(
        train_func,
        config={"worker_batch_size": 64},
        callbacks=[sgd.MlflowCallback()],
        dataset=dataset)

.. Here is a list of callbacks that is supported by RaySGD:

.. * WandbCallback
.. * MlflowCallback
.. * TensorboardCallback
.. * JsonCallback (Automatically logs given parameters)
.. * CSVCallback


.. note:: When using Ray Tune, these callbacks will not be used.

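The callback interface itself is still a stub in this draft (``SGDCallback`` currently defines no hooks), so the exact method a custom callback overrides is not settled. Purely as an illustration of the intent described above, a hypothetical callback that logs each reported result might look like the following; the ``handle_result`` hook name is an assumption, not part of the current API.

.. code-block:: python

    from ray.util.sgd.v2 import Trainer, SGDCallback

    class PrintingCallback(SGDCallback):
        # Hypothetical hook name: assumed to receive the dict passed to
        # ray.sgd.report(); not part of the current (stub) API.
        def handle_result(self, result):
            print(result)

    trainer = Trainer(backend="torch", num_workers=2)
    result = trainer.run(
        train_func,  # the training function defined above
        callbacks=[PrintingCallback()])
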
Checkpointing
-------------

RaySGD provides a way to save state during the training process. This is useful for:

1. :ref:`Integration with Ray Tune <tune-sgd>` to use certain Ray Tune schedulers.
2. Running a long-running training job on a cluster of pre-emptible machines/pods.


.. code-block:: python

    import ray

    def train_func(config):
        state = ray.sgd.load_checkpoint()  # eventually, optional
        for _ in range(config["num_epochs"]):
            train(...)
            # Save whatever state is needed to resume training later.
            ray.sgd.save_checkpoint((model, optimizer))
        return model

    trainer = Trainer(backend="torch", num_workers=4)
    trainer.run(train_func)
    state = trainer.get_last_checkpoint()

.. Running on the cloud
.. --------------------

.. Use RaySGD with the Ray cluster launcher by changing the following:

.. .. code-block:: bash

..     ray up cluster.yaml

.. TODO.



.. Running on pre-emptible machines
.. --------------------------------

.. You may want to

.. TODO.

.. _tune-sgd:

Hyperparameter tuning
---------------------

Hyperparameter tuning with Ray Tune is natively supported with RaySGD. Specifically, you can take an existing training function and follow these steps:

1. Call ``trainer.to_tune_trainable``, which will produce an object ("Trainable") that will be passed to Ray Tune.
2. Call ``tune.run(trainable)`` instead of ``trainer.run``. This will invoke the hyperparameter tuning, starting multiple "trials" each with the resource amount specified by the Trainer.

A couple of caveats:

* Tune won't handle the ``training_func`` return value correctly. To save your best trained model, you'll need to use the checkpointing API.
* You should **not** call ``tune.report`` or ``tune.checkpoint_dir`` in your training function.

.. code-block:: python

    import ray
    import torch.distributed
    from ray import tune
    from ray.util.sgd.v2 import Trainer

    def training_func(config):
        # Each worker reads its own shard of the dataset.
        dataloader = ray.sgd.get_dataset()\
            .get_shard(torch.distributed.get_rank())\
            .to_torch(batch_size=config["batch_size"])

        for i in range(config["epochs"]):
            ray.sgd.report(...)  # use same intermediate reporting API

    # Declare the specification for training.
    trainer = Trainer(backend="torch", num_workers=12, use_gpu=True)
    dataset = ray.data.read_csv("...").filter().pipeline(length=50)

    # Convert this to a trainable.
    trainable = trainer.to_tune_trainable(training_func, dataset=dataset)

    analysis = tune.run(trainable, config={
        "lr": tune.uniform(0.001, 0.1), "batch_size": tune.randint(1, 64)}, num_samples=12)

Distributed metrics (for Pytorch)
---------------------------------

In real applications, you may want to calculate optimization metrics besides accuracy and loss: recall, precision, Fbeta, etc.

RaySGD natively supports `TorchMetrics <https://torchmetrics.readthedocs.io/en/latest/>`_, which provides a collection of machine learning metrics for distributed, scalable Pytorch models.

Here is an example:

.. code-block:: python

    import torch
    import torchmetrics
    import ray

    def train_func(config):
        preds = torch.randn(10, 5).softmax(dim=-1)
        target = torch.randint(5, (10,))

        acc = torchmetrics.functional.accuracy(preds, target)
        ray.sgd.report(accuracy=acc)

    trainer = Trainer(backend="torch", num_workers=2)
    trainer.run(train_func, config=None)

@@ -1,4 +1,5 @@
-from ray.util.sgd.v2.backends.torch import TorchConfig
+from ray.util.sgd.v2.backends import BackendConfig, TorchConfig
+from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.trainer import Trainer

-__all__ = ["TorchConfig", "Trainer"]
+__all__ = ["BackendConfig", "SGDCallback", "TorchConfig", "Trainer"]
@@ -0,0 +1,4 @@
+from ray.util.sgd.v2.backends.backend import BackendConfig
+from ray.util.sgd.v2.backends.torch import TorchConfig
+
+__all__ = ["TorchConfig", "BackendConfig"]
@@ -0,0 +1,3 @@
+from ray.util.sgd.v2.callbacks.callback import SGDCallback
+
+__all__ = ["SGDCallback"]
@@ -1,2 +1,2 @@
-class Callback:
+class SGDCallback:
     pass
@@ -1,32 +1,33 @@
 from typing import Union, Callable, List, TypeVar, Optional, Any, Dict

 from ray.util.sgd.v2.backends.backend import BackendConfig
-from ray.util.sgd.v2.callbacks.callback import Callback
+from ray.util.sgd.v2.callbacks.callback import SGDCallback

 T = TypeVar("T")
 S = TypeVar("S")


 class Trainer:
+    """A class for enabling seamless distributed deep learning.
+
+    Args:
+        backend (Union[str, BackendConfig]): The backend used for
+            distributed communication. If configurations are needed,
+            a subclass of ``BackendConfig`` can be passed in.
+            Supported ``str`` values: {"torch"}.
+        num_workers (int): The number of workers (Ray actors) to launch.
+            Defaults to 1. Each worker will reserve 1 CPU by default.
+        use_gpu (bool): If True, training will be done on GPUs (1 per
+            worker). Defaults to False.
+        resources_per_worker (Optional[Dict]): If specified, the resources
+            defined in this Dict will be reserved for each worker.
+    """
+
     def __init__(self,
                  backend: Union[str, BackendConfig],
                  num_workers: int = 1,
                  use_gpu: bool = False,
                  resources_per_worker: Optional[Dict[str, float]] = None):
-        """A class for distributed training.
-
-        Args:
-            backend (Union[str, BackendConfig]): The backend used for
-                distributed communication. If configurations are needed,
-                a subclass of ``BackendConfig`` can be passed in.
-                Supported ``str`` values: {"torch"}.
-            num_workers (int): The number of workers (Ray actors) to launch.
-                Defaults to 1. Each worker will reserve 1 CPU by default.
-            use_gpu (bool): If True, training will be done on GPUs (1 per
-                worker). Defaults to False.
-            resources_per_worker (Optional[Dict]): If specified, the resources
-                defined in this Dict will be reserved for each worker.
-        """
         pass

     def start(self,
@@ -48,14 +49,14 @@ class Trainer:
     def run(self,
             train_func: Callable[[Dict[str, Any]], T],
             config: Optional[Dict[str, Any]] = None,
-            callbacks: Optional[List[Callback]] = None) -> List[T]:
+            callbacks: Optional[List[SGDCallback]] = None) -> List[T]:
         """Runs a training function in a distributed manner.

         Args:
             train_func (Callable): The training function to execute.
             config (Optional[Dict]): Configurations to pass into
                 ``train_func``. If None then an empty Dict will be created.
-            callbacks (Optional[List[Callback]]): A list of Callbacks which
+            callbacks (Optional[List[SGDCallback]]): A list of Callbacks which
                 will be executed during training. If this is not set,
                 currently there are NO default Callbacks.
         Returns:
@@ -101,7 +102,15 @@ class Trainer:

     def to_tune_trainable(self, train_func: Callable[[Dict[str, Any]], T]
                           ) -> Callable[[Dict[str, Any]], List[T]]:
-        """Creates a Tune trainable function."""
+        """Creates a Tune trainable function.
+
+        Args:
+            train_func (Callable): The function that should be executed on
+                each training worker.
+
+        Returns:
+            :py:class:`ray.tune.Trainable`
+        """

         def trainable(config: Dict[str, Any]) -> List[T]:
             pass