[SGD] v2 Class API (#18571)
* wip
* wip
* add horovod example
* add example
* lint
* fix
* address comments
* updates
* lint
* update example
* address comment
* address comment
* update
* fix
* Update python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py
  Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
* address comments
* add back name mangling
* fix tests
* Update python/ray/util/sgd/v2/trainer.py
* fix
* lint
* fix
* fix docstring
* Update python/ray/util/sgd/v2/tests/test_trainer.py
  Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
* update

Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Parent: eeaae5aa08
Commit: de050e8187
6 changed files with 433 additions and 61 deletions
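At a high level, this commit adds a class-based entry point: instead of passing a training function to ``Trainer.run``, a user hands ``Trainer.to_worker_group`` a training class and then drives the resulting Ray actors directly. A minimal usage sketch, adapted from the ``to_worker_group`` docstring added below; the class name ``MyTrainingClass`` and the config values are illustrative, not part of the commit:

import ray
from ray.util.sgd.v2 import Trainer


class MyTrainingClass:
    # Illustrative training class; any class whose __init__ accepts the
    # keyword arguments passed to to_worker_group will work.
    def __init__(self, config):
        self.config = config

    def train_epoch(self):
        # ... one epoch of training ...
        return 1


ray.init()
trainer = Trainer(num_workers=2, backend="torch")
workers = trainer.to_worker_group(train_cls=MyTrainingClass, config={"lr": 0.1})
# Each worker is a Ray actor; methods are invoked with .remote().
assert ray.get([w.train_epoch.remote() for w in workers]) == [1, 1]
workers.shutdown()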
@@ -3,7 +3,7 @@ import logging
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Callable, TypeVar, List, Optional, Dict, Union
+from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple

 import ray
 from ray import cloudpickle
@@ -257,11 +257,21 @@ class BackendExecutor:

         self.checkpoint_manager.on_init()

-    def start(self, initialization_hook: Optional[Callable[[], None]] = None):
+    def start(self,
+              initialization_hook: Optional[Callable[[], None]] = None,
+              train_cls: Optional[Type] = None,
+              train_cls_args: Optional[Tuple] = None,
+              train_cls_kwargs: Optional[Dict] = None):
         """Starts the worker group."""
         self.worker_group = WorkerGroup(
-            self._num_workers, self._num_cpus_per_worker,
-            self._num_gpus_per_worker, self._additional_resources_per_worker)
+            num_workers=self._num_workers,
+            num_cpus_per_worker=self._num_cpus_per_worker,
+            num_gpus_per_worker=self._num_gpus_per_worker,
+            additional_resources_per_worker=self.
+            _additional_resources_per_worker,
+            actor_cls=train_cls,
+            actor_cls_args=train_cls_args,
+            actor_cls_kwargs=train_cls_kwargs)
         try:
             if initialization_hook:
                 self._initialization_hook = initialization_hook
@@ -0,0 +1,225 @@
+import argparse
+import os
+
+import horovod.torch as hvd
+import ray
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data.distributed
+from filelock import FileLock
+from ray.util.sgd.v2 import Trainer
+from torchvision import datasets, transforms
+
+
+def metric_average(val, name):
+    tensor = torch.tensor(val)
+    avg_tensor = hvd.allreduce(tensor, name=name)
+    return avg_tensor.item()
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x)
+
+
+class TrainClass:
+    def __init__(self, config):
+        self.model = Net()
+        self.use_cuda = config.get("use_cuda", False)
+
+        data_dir = config.get("data_dir", None)
+        seed = config.get("seed", 42)
+        batch_size = config.get("batch_size", 64)
+        use_adasum = config.get("use_adasum", False)
+        lr = config.get("lr", 0.01)
+        momentum = config.get("momentum", 0.5)
+
+        # Horovod: initialize library.
+        hvd.init()
+        torch.manual_seed(seed)
+
+        if self.use_cuda:
+            # Horovod: pin GPU to local rank.
+            torch.cuda.set_device(hvd.local_rank())
+            torch.cuda.manual_seed(seed)
+
+        # Horovod: limit # of CPU threads to be used per worker.
+        torch.set_num_threads(1)
+
+        kwargs = {"num_workers": 1, "pin_memory": True} if self.use_cuda \
+            else {}
+        data_dir = data_dir or "~/data"
+        with FileLock(os.path.expanduser("~/.horovod_lock")):
+            train_dataset = \
+                datasets.MNIST(data_dir, train=True, download=True,
+                               transform=transforms.Compose([
+                                   transforms.ToTensor(),
+                                   transforms.Normalize((0.1307,), (0.3081,))
+                               ]))
+        # Horovod: use DistributedSampler to partition the training data.
+        self.train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
+        self.train_loader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=batch_size,
+            sampler=self.train_sampler,
+            **kwargs)
+
+        model = Net()
+
+        # By default, Adasum doesn't need scaling up learning rate.
+        lr_scaler = hvd.size() if not use_adasum else 1
+
+        if self.use_cuda:
+            # Move model to GPU.
+            model.cuda()
+            # If using GPU Adasum allreduce, scale learning rate by local_size.
+            if use_adasum and hvd.nccl_built():
+                lr_scaler = hvd.local_size()
+
+        # Horovod: scale learning rate by lr_scaler.
+        optimizer = optim.SGD(
+            model.parameters(), lr=lr * lr_scaler, momentum=momentum)
+
+        # Horovod: wrap optimizer with DistributedOptimizer.
+        self.optimizer = hvd.DistributedOptimizer(
+            optimizer,
+            named_parameters=model.named_parameters(),
+            op=hvd.Adasum if use_adasum else hvd.Average)
+
+    def train(self, epoch):
+        self.model.train()
+        # Horovod: set epoch to sampler for shuffling.
+        self.train_sampler.set_epoch(epoch)
+        num_batches = len(self.train_loader)
+        for batch_idx, (data, target) in enumerate(self.train_loader):
+            if self.use_cuda:
+                data, target = data.cuda(), target.cuda()
+            self.optimizer.zero_grad()
+            output = self.model(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            self.optimizer.step()
+            if batch_idx == num_batches - 1:
+                return loss.item()
+
+
+def main(num_workers, use_gpu, num_epochs, config):
+    trainer = Trainer("horovod", use_gpu=use_gpu, num_workers=num_workers)
+    trainer.start()
+    workers = trainer.to_worker_group(TrainClass, config)
+    results = []
+    for epoch in range(num_epochs):
+        loss = ray.get([w.train.remote(epoch=epoch) for w in workers])
+        results.append(loss)
+    trainer.shutdown()
+    print(results)
+
+
+if __name__ == "__main__":
+    # Training settings
+    parser = argparse.ArgumentParser(
+        description="PyTorch MNIST Example",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)")
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=5,
+        metavar="N",
+        help="number of epochs to train (default: 5)")
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.01,
+        metavar="LR",
+        help="learning rate (default: 0.01)")
+    parser.add_argument(
+        "--momentum",
+        type=float,
+        default=0.5,
+        metavar="M",
+        help="SGD momentum (default: 0.5)")
+    parser.add_argument(
+        "--use-gpu",
+        action="store_true",
+        default=False,
+        help="enables CUDA training")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        metavar="S",
+        help="random seed (default: 42)")
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status")
+    parser.add_argument(
+        "--use-adasum",
+        action="store_true",
+        default=False,
+        help="use adasum algorithm to do reduction")
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=2,
+        help="Number of Ray workers to use for training.")
+    parser.add_argument(
+        "--data-dir",
+        help="location of the training dataset in the local filesystem ("
+        "will be downloaded if needed)")
+    parser.add_argument(
+        "--address",
+        required=False,
+        type=str,
+        default=None,
+        help="Address of Ray cluster.")
+
+    args = parser.parse_args()
+
+    if args.address:
+        ray.init(args.address)
+    else:
+        ray.init()
+
+    use_cuda = args.use_gpu if args.use_gpu is not None else False
+
+    kwargs = {
+        "data_dir": args.data_dir,
+        "seed": args.seed,
+        "use_cuda": use_cuda,
+        "batch_size": args.batch_size,
+        "use_adasum": args.use_adasum if args.use_adasum else False,
+        "lr": args.lr,
+        "momentum": args.momentum,
+        "log_interval": args.log_interval
+    }
+
+    main(
+        num_workers=args.num_workers,
+        use_gpu=use_cuda,
+        num_epochs=args.num_epochs,
+        config=kwargs)
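One observation on the example above: it defines a metric_average helper built on hvd.allreduce, but the training loop returns each worker's local last-batch loss, so main prints per-worker values. If a single cross-worker average were wanted instead, a variant along the following lines could be used; this is a hypothetical addition for illustration, not part of the commit:

# Hypothetical subclass of the example's TrainClass that reports the loss
# averaged across all Horovod workers, using the metric_average helper
# defined in the example above.
class AveragedTrainClass(TrainClass):
    def train(self, epoch):
        local_loss = super().train(epoch)
        return metric_average(local_loss, name="avg_loss")

Passing AveragedTrainClass to trainer.to_worker_group in place of TrainClass would then make every worker return the same averaged value.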
@@ -54,9 +54,9 @@ def ray_2_node_4_gpu():
 def gen_execute_special(special_f):
     def execute_async_special(self, f):
         """Runs f on worker 0, special_f on other workers."""
-        futures = [self.workers[0].execute.remote(f)]
+        futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)]
         for worker in self.workers[1:]:
-            futures.append(worker.execute.remote(special_f))
+            futures.append(worker._BaseWorkerMixin__execute.remote(special_f))
         return futures

     return execute_async_special
@@ -17,8 +17,11 @@ from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \
 from ray.util.sgd.v2.callbacks.callback import SGDCallback
 from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \
     tensorflow_mnist_train_func
+from ray.util.sgd.v2.examples.horovod.horovod_stateful_example import \
+    TrainClass as HorovodTrainClass
 from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \
-    as fashion_mnist_train_func
+    as \
+    fashion_mnist_train_func
 from ray.util.sgd.v2.examples.train_linear_example import train_func as \
     linear_train_func
@@ -87,7 +90,8 @@ def gen_execute_single_async_special(special_f):
         assert len(self.workers) == 2
         if i == 0 and hasattr(self, "should_fail") and self.should_fail:
             kwargs["train_func"] = special_f
-        return self.workers[i].execute.remote(f, *args, **kwargs)
+        return self.workers[i]._BaseWorkerMixin__execute.remote(
+            f, *args, **kwargs)

     return execute_single_async_special
@@ -630,6 +634,25 @@ def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus):
     assert worker_result[num_epochs - 1] < worker_result[0]


+def test_horovod_torch_mnist_stateful(ray_start_2_cpus):
+    num_workers = 2
+    num_epochs = 2
+    trainer = Trainer("horovod", num_workers)
+    workers = trainer.to_worker_group(
+        HorovodTrainClass, config={
+            "num_epochs": num_epochs,
+            "lr": 1e-3
+        })
+    results = []
+    for epoch in range(num_epochs):
+        results.append(ray.get([w.train.remote(epoch=epoch) for w in workers]))
+    trainer.shutdown()
+
+    assert len(results) == num_epochs
+    for i in range(num_workers):
+        assert results[num_epochs - 1][i] < results[0][i]
+
+
 def test_init_failure(ray_start_2_cpus):
     with pytest.raises(TypeError):
         Trainer(5)
@@ -1001,6 +1024,33 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra):
     trainer.shutdown()


+def test_to_worker_group(ray_start_2_cpus):
+    config = TestConfig()
+    trainer = Trainer(config, num_workers=2)
+
+    class Incrementer:
+        def __init__(self, starting=0):
+            self.count = starting
+
+        def increment(self):
+            self.count += 1
+
+        def get_count(self):
+            return self.count
+
+    workers = trainer.to_worker_group(Incrementer, starting=2)
+    assert ray.get([w.get_count.remote() for w in workers]) == [2, 2]
+
+    ray.get([w.increment.remote() for w in workers])
+    assert ray.get([w.get_count.remote() for w in workers]) == [3, 3]
+
+    ray.get(workers[0].increment.remote())
+    assert ray.get([w.get_count.remote() for w in workers]) == [4, 3]
+
+    ray.get(workers[1].increment.remote())
+    assert ray.get([w.get_count.remote() for w in workers]) == [4, 4]
+
+
 if __name__ == "__main__":
     import pytest
     import sys
@@ -4,8 +4,9 @@ import logging
 import os
 from pathlib import Path
 from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
-    Type, Iterator
+    Type

+from ray.actor import ActorHandle
 from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \
     InactiveWorkerGroupError, SGDBackendError, TrainingWorkerError
 from ray.util.sgd.v2.backends.horovod import HorovodConfig
@@ -18,6 +19,7 @@ from ray.util.sgd.v2.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \

 # Ray SGD should be usable even if Tune is not installed.
 from ray.util.sgd.v2.utils import construct_path
+from ray.util.sgd.v2.worker_group import WorkerGroup

 if TUNE_INSTALLED:
     from ray import tune
@@ -174,19 +176,12 @@ class Trainer:
         else:
             raise TypeError(f"Invalid type for backend: {type(backend)}.")

-    def start(self,
-              initialization_hook: Optional[Callable[[], None]] = None,
-              train_cls: Optional[S] = None,
-              *args,
-              **kwargs):
+    def start(self, initialization_hook: Optional[Callable[[], None]] = None):
         """Starts the training execution service.

         Args:
             initialization_hook (Optional[Callable]): The function to call on
                 each worker when it is instantiated.
-            train_cls (Optional[cls]): The training class that each worker
-                should be instantiated as.
-            args, kwargs: The arguments to pass into ``train_cls.__init__``.
         """
         self._executor.start(initialization_hook)
@@ -258,7 +253,7 @@ class Trainer:
             config: Optional[Dict[str, Any]] = None,
             checkpoint: Optional[Union[Dict, str, Path]] = None,
             checkpoint_strategy: Optional[CheckpointStrategy] = None
-    ) -> Iterator[List[Dict]]:
+    ) -> "SGDIterator":
         """Same as ``run`` except returns an iterator over the results.

         This is useful if you want to have more customization of what to do
@@ -343,36 +338,6 @@ class Trainer:
         else:  # num_params == 0
             return train_func

-    def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
-        """Executes a function for all instances of ``self.train_cls``.
-
-        Args:
-            func (Callable): The function that should be executed.
-                The first argument should be an instance of
-                ``self.train_cls``.
-            args, kwargs: The arguments to pass into ``func``.
-
-        Returns:
-            A list of results from ``func``. Each value in the
-            list corresponds to the output of ``func`` from
-            each worker.
-        """
-        raise NotImplementedError
-
-    def execute_single(self, func: Callable[..., T], *args, **kwargs) -> T:
-        """Executes a function on a single instance of ``self.train_cls``.
-
-        Args:
-            func (Callable): The function that should be executed.
-                The first argument should be an instance of
-                ``self.train_cls``.
-            args, kwargs: The arguments to pass into ``func``.
-
-        Returns:
-            The output of ``func`` from a single worker.
-        """
-        raise NotImplementedError
-
     @property
     def latest_run_dir(self) -> Optional[Path]:
         """Path to the log directory for the latest call to ``run()``.
@@ -444,6 +409,91 @@ class Trainer:
             self._num_workers, self._use_gpu,
             self._resources_per_worker)

+    def to_worker_group(self, train_cls: Type, *args,
+                        **kwargs) -> "SGDWorkerGroup":
+        """Returns Ray actors with the provided class and the backend started.
+
+        This is useful if you want to provide your own class for training
+        and have more control over execution, but still want to use Ray SGD
+        to setup the appropriate backend configurations (torch, tf, etc.).
+
+        .. code-block:: python
+
+            class Trainer:
+                def __init__(self, config):
+                    self.config = config
+
+                def train_epoch(self):
+                    ...
+                    return 1
+
+            config = {"lr": 0.1}
+            trainer = Trainer(num_workers=2, backend="torch")
+            workers = trainer.to_worker_group(train_cls=Trainer, config=config)
+            futures = [w.train_epoch.remote() for w in workers]
+            assert ray.get(futures) == [1, 1]
+            assert ray.get(workers[0].train_epoch.remote()) == 1
+            workers.shutdown()
+
+        Args:
+            train_cls (Type): The class definition to use for the Ray
+                actors/workers.
+            args, kwargs: Arguments to pass into the ``__init__`` of the
+                provided ``train_cls``.
+        """
+        if self._executor.is_started:
+            raise RuntimeError("The Trainer must not be active to use "
+                               "`to_worker_group`. Either shutdown the "
+                               "Trainer or don't start it in the first place.")
+        self._executor.start(
+            train_cls=train_cls, train_cls_args=args, train_cls_kwargs=kwargs)
+        return SGDWorkerGroup(self._executor.worker_group)
+
+
+class SGDWorkerGroup:
+    """A container for a group of Ray actors.
+
+    You should not instantiate this directly and only use this as the output
+    of ``Trainer.to_worker_group``. You can index or iterate this object like
+    you would a List.
+
+    .. code-block:: python
+
+        class Trainer:
+            def __init__(self, config):
+                self.config = config
+
+            def train_epoch(self):
+                ...
+                return 1
+
+        config = {"lr": 0.1}
+        trainer = Trainer(num_workers=2, backend="torch")
+        workers = trainer.to_worker_group(train_cls=Trainer, config=config)
+        futures = [w.train_epoch.remote() for w in workers]
+        assert ray.get(futures) == [1, 1]
+        assert ray.get(workers[0].train_epoch.remote()) == 1
+        workers.shutdown()`
+    """
+
+    def __init__(self, worker_group: WorkerGroup):
+        self._worker_group = worker_group
+
+    def __getitem__(self, item) -> ActorHandle:
+        return self._worker_group.workers[item]
+
+    def shutdown(self, patience_s: float = 5):
+        """Shutdown all the workers.
+
+        Args:
+            patience_s (float): Attempt a graceful shutdown
+                of the workers for this many seconds. Fallback to force kill
+                if graceful shutdown is not complete after this time. If
+                this is less than or equal to 0, immediately force kill all
+                workers.
+        """
+        self._worker_group.shutdown(patience_s=patience_s)
+
+
 class SGDIterator:
     """An iterator over SGD results. Returned by ``trainer.run_iterator``."""
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, List, TypeVar, Optional, Dict
+from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple

 import ray
 from ray.types import ObjectRef
@@ -9,10 +9,10 @@ T = TypeVar("T")
 logger = logging.getLogger(__name__)


-class BaseWorker:
+class BaseWorkerMixin:
     """A class to execute arbitrary functions. Does not hold any state."""

-    def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
+    def __execute(self, func: Callable[..., T], *args, **kwargs) -> T:
         """Executes the input function and returns the output.

         Args:
@@ -22,6 +22,21 @@ class BaseWorker:
         return func(*args, **kwargs)


+def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
+    """Create the executable class to use as the Ray actors."""
+    if not executable_cls:
+        return BaseWorkerMixin
+    elif issubclass(executable_cls, BaseWorkerMixin):
+        return executable_cls
+    else:
+
+        class _WrappedExecutable(executable_cls, BaseWorkerMixin):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+        return _WrappedExecutable
+
+
 class WorkerGroup:
     """Group of Ray Actors that can execute arbitrary functions.
@@ -43,6 +58,10 @@ class WorkerGroup:
             Dictionary specifying the extra resources that will be
             requested for each worker in addition to ``num_cpus_per_worker``
             and ``num_gpus_per_worker``.
+        actor_cls (Optional[Type]): If specified use this class as the
+            remote actors.
+        remote_cls_args, remote_cls_kwargs: If ``remote_cls`` is provided,
+            these args will be used for the worker initialization.

     Example:
@@ -55,12 +74,15 @@ class WorkerGroup:
         assert all(o == 1 for o in output)
     """

-    def __init__(self,
-                 num_workers: int = 1,
-                 num_cpus_per_worker: float = 1,
-                 num_gpus_per_worker: float = 0,
-                 additional_resources_per_worker: Optional[Dict[
-                     str, float]] = None):
+    def __init__(
+            self,
+            num_workers: int = 1,
+            num_cpus_per_worker: float = 1,
+            num_gpus_per_worker: float = 0,
+            additional_resources_per_worker: Optional[Dict[str, float]] = None,
+            actor_cls: Type = None,
+            actor_cls_args: Optional[Tuple] = None,
+            actor_cls_kwargs: Optional[Dict] = None):

         if num_workers <= 0:
             raise ValueError("The provided `num_workers` must be greater "
@@ -72,21 +94,32 @@ class WorkerGroup:
                              f"num_cpus_per_worker={num_cpus_per_worker} and "
                              f"num_gpus_per_worker={num_gpus_per_worker}.")

+        if (actor_cls_args or actor_cls_kwargs) and not actor_cls:
+            raise ValueError("`actor_cls_args` or `actor_class_kwargs` are "
+                             "passed in but no `actor_cls` is passed in.")
+
         self.num_workers = num_workers
         self.num_cpus_per_worker = num_cpus_per_worker
         self.num_gpus_per_worker = num_gpus_per_worker
         self.additional_resources_per_worker = additional_resources_per_worker
         self.workers = []
+        self._base_cls = create_executable_class(actor_cls)
+        assert issubclass(self._base_cls, BaseWorkerMixin)
+
+        self._actor_cls_args = actor_cls_args or []
+        self._actor_cls_kwargs = actor_cls_kwargs or {}
+
         # TODO(matt): Validate resources. Fast-fail if it is impossible to
         # handle the request, rather than hang indefinitely.
         self._remote_cls = ray.remote(
             num_cpus=self.num_cpus_per_worker,
             num_gpus=self.num_gpus_per_worker,
-            resources=self.additional_resources_per_worker)(BaseWorker)
+            resources=self.additional_resources_per_worker)(self._base_cls)
         self.start()

     def _create_worker(self):
-        return self._remote_cls.remote()
+        return self._remote_cls.remote(*self._actor_cls_args,
+                                       **self._actor_cls_kwargs)

     def start(self):
         """Starts all the workers in this worker group."""
@@ -146,7 +179,10 @@ class WorkerGroup:
                 "group has most likely been shut down. Please"
                 "create a new WorkerGroup or restart this one.")

-        return [w.execute.remote(func, *args, **kwargs) for w in self.workers]
+        return [
+            w._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
+            for w in self.workers
+        ]

     def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]:
         """Execute ``func`` on each worker and return the outputs of ``func``.
@@ -178,7 +214,8 @@ class WorkerGroup:
         if worker_index >= len(self.workers):
             raise ValueError(f"The provided worker_index {worker_index} is "
                              f"not valid for {self.num_workers} workers.")
-        return self.workers[worker_index].execute.remote(func, *args, **kwargs)
+        return self.workers[worker_index]._BaseWorkerMixin__execute.remote(
+            func, *args, **kwargs)

     def execute_single(self, worker_index: int, func: Callable[..., T], *args,
                        **kwargs) -> T:
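A note on the ``_BaseWorkerMixin__execute`` calls that appear throughout the worker group and tests: the commit renames the worker's entry point to the double-underscore method ``__execute`` ("add back name mangling" in the commit message), so Python mangles it to ``_BaseWorkerMixin__execute`` and it cannot collide with an ``execute`` method a user might define on their own training class. A minimal, Ray-independent sketch of the effect (the mixin and wrapper names mirror the diff; ``UserTrainClass`` is made up for illustration):

class BaseWorkerMixin:
    def __execute(self, func, *args, **kwargs):
        # Mangled at class-definition time to _BaseWorkerMixin__execute.
        return func(*args, **kwargs)


class UserTrainClass:
    def execute(self):
        # A user-defined method that happens to be called "execute".
        return "user execute"


class _WrappedExecutable(UserTrainClass, BaseWorkerMixin):
    # Mirrors create_executable_class: combine the user class with the mixin.
    pass


w = _WrappedExecutable()
assert w.execute() == "user execute"                 # user method untouched
assert w._BaseWorkerMixin__execute(lambda: 42) == 42  # mixin entry point

Because the mangled name is unique to the mixin, WorkerGroup can always reach its own entry point even when the wrapped training class defines methods with common names.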