From de050e81874035a4279f007a913a7ebd32317209 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Thu, 16 Sep 2021 12:33:38 -0700 Subject: [PATCH] [SGD] v2 Class API (#18571) * wip * wip * add horovod example * add example * lint * fix * address comments * updates * lint * update example * address comment * address comment * update * fix * Update python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py Co-authored-by: matthewdeng * address comments * add back name mangling * fix tests * Update python/ray/util/sgd/v2/trainer.py * fix * lint * fix * fix docstring * Update python/ray/util/sgd/v2/tests/test_trainer.py Co-authored-by: matthewdeng * update Co-authored-by: matthewdeng --- python/ray/util/sgd/v2/backends/backend.py | 18 +- .../horovod/horovod_stateful_example.py | 225 ++++++++++++++++++ python/ray/util/sgd/v2/tests/test_backend.py | 4 +- python/ray/util/sgd/v2/tests/test_trainer.py | 54 ++++- python/ray/util/sgd/v2/trainer.py | 130 ++++++---- python/ray/util/sgd/v2/worker_group.py | 63 ++++- 6 files changed, 433 insertions(+), 61 deletions(-) create mode 100644 python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py index b870c78b1..98c7ee7c5 100644 --- a/python/ray/util/sgd/v2/backends/backend.py +++ b/python/ray/util/sgd/v2/backends/backend.py @@ -3,7 +3,7 @@ import logging import os from collections import defaultdict from pathlib import Path -from typing import Callable, TypeVar, List, Optional, Dict, Union +from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple import ray from ray import cloudpickle @@ -257,11 +257,21 @@ class BackendExecutor: self.checkpoint_manager.on_init() - def start(self, initialization_hook: Optional[Callable[[], None]] = None): + def start(self, + initialization_hook: Optional[Callable[[], None]] = None, + train_cls: Optional[Type] = None, + train_cls_args: Optional[Tuple] = None, + train_cls_kwargs: Optional[Dict] = None): """Starts the worker group.""" self.worker_group = WorkerGroup( - self._num_workers, self._num_cpus_per_worker, - self._num_gpus_per_worker, self._additional_resources_per_worker) + num_workers=self._num_workers, + num_cpus_per_worker=self._num_cpus_per_worker, + num_gpus_per_worker=self._num_gpus_per_worker, + additional_resources_per_worker=self. 
+ _additional_resources_per_worker, + actor_cls=train_cls, + actor_cls_args=train_cls_args, + actor_cls_kwargs=train_cls_kwargs) try: if initialization_hook: self._initialization_hook = initialization_hook diff --git a/python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py b/python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py new file mode 100644 index 000000000..20bcc9716 --- /dev/null +++ b/python/ray/util/sgd/v2/examples/horovod/horovod_stateful_example.py @@ -0,0 +1,225 @@ +import argparse +import os + +import horovod.torch as hvd +import ray +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data.distributed +from filelock import FileLock +from ray.util.sgd.v2 import Trainer +from torchvision import datasets, transforms + + +def metric_average(val, name): + tensor = torch.tensor(val) + avg_tensor = hvd.allreduce(tensor, name=name) + return avg_tensor.item() + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + +class TrainClass: + def __init__(self, config): + self.model = Net() + self.use_cuda = config.get("use_cuda", False) + + data_dir = config.get("data_dir", None) + seed = config.get("seed", 42) + batch_size = config.get("batch_size", 64) + use_adasum = config.get("use_adasum", False) + lr = config.get("lr", 0.01) + momentum = config.get("momentum", 0.5) + + # Horovod: initialize library. + hvd.init() + torch.manual_seed(seed) + + if self.use_cuda: + # Horovod: pin GPU to local rank. + torch.cuda.set_device(hvd.local_rank()) + torch.cuda.manual_seed(seed) + + # Horovod: limit # of CPU threads to be used per worker. + torch.set_num_threads(1) + + kwargs = {"num_workers": 1, "pin_memory": True} if self.use_cuda \ + else {} + data_dir = data_dir or "~/data" + with FileLock(os.path.expanduser("~/.horovod_lock")): + train_dataset = \ + datasets.MNIST(data_dir, train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) + # Horovod: use DistributedSampler to partition the training data. + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=hvd.size(), rank=hvd.rank()) + self.train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + sampler=self.train_sampler, + **kwargs) + + model = Net() + + # By default, Adasum doesn't need scaling up learning rate. + lr_scaler = hvd.size() if not use_adasum else 1 + + if self.use_cuda: + # Move model to GPU. + model.cuda() + # If using GPU Adasum allreduce, scale learning rate by local_size. + if use_adasum and hvd.nccl_built(): + lr_scaler = hvd.local_size() + + # Horovod: scale learning rate by lr_scaler. + optimizer = optim.SGD( + model.parameters(), lr=lr * lr_scaler, momentum=momentum) + + # Horovod: wrap optimizer with DistributedOptimizer. 
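+        # The wrapped optimizer averages gradients across all Horovod workers
+        # (or combines them with Adasum when --use-adasum is passed) before
+        # applying each optimizer step.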
+ self.optimizer = hvd.DistributedOptimizer( + optimizer, + named_parameters=model.named_parameters(), + op=hvd.Adasum if use_adasum else hvd.Average) + + def train(self, epoch): + self.model.train() + # Horovod: set epoch to sampler for shuffling. + self.train_sampler.set_epoch(epoch) + num_batches = len(self.train_loader) + for batch_idx, (data, target) in enumerate(self.train_loader): + if self.use_cuda: + data, target = data.cuda(), target.cuda() + self.optimizer.zero_grad() + output = self.model(data) + loss = F.nll_loss(output, target) + loss.backward() + self.optimizer.step() + if batch_idx == num_batches - 1: + return loss.item() + + +def main(num_workers, use_gpu, num_epochs, config): + trainer = Trainer("horovod", use_gpu=use_gpu, num_workers=num_workers) + trainer.start() + workers = trainer.to_worker_group(TrainClass, config) + results = [] + for epoch in range(num_epochs): + loss = ray.get([w.train.remote(epoch=epoch) for w in workers]) + results.append(loss) + trainer.shutdown() + print(results) + + +if __name__ == "__main__": + # Training settings + parser = argparse.ArgumentParser( + description="PyTorch MNIST Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)") + parser.add_argument( + "--num-epochs", + type=int, + default=5, + metavar="N", + help="number of epochs to train (default: 5)") + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)") + parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="enables CUDA training") + parser.add_argument( + "--seed", + type=int, + default=42, + metavar="S", + help="random seed (default: 42)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status") + parser.add_argument( + "--use-adasum", + action="store_true", + default=False, + help="use adasum algorithm to do reduction") + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of Ray workers to use for training.") + parser.add_argument( + "--data-dir", + help="location of the training dataset in the local filesystem (" + "will be downloaded if needed)") + parser.add_argument( + "--address", + required=False, + type=str, + default=None, + help="Address of Ray cluster.") + + args = parser.parse_args() + + if args.address: + ray.init(args.address) + else: + ray.init() + + use_cuda = args.use_gpu if args.use_gpu is not None else False + + kwargs = { + "data_dir": args.data_dir, + "seed": args.seed, + "use_cuda": use_cuda, + "batch_size": args.batch_size, + "use_adasum": args.use_adasum if args.use_adasum else False, + "lr": args.lr, + "momentum": args.momentum, + "log_interval": args.log_interval + } + + main( + num_workers=args.num_workers, + use_gpu=use_cuda, + num_epochs=args.num_epochs, + config=kwargs) diff --git a/python/ray/util/sgd/v2/tests/test_backend.py b/python/ray/util/sgd/v2/tests/test_backend.py index 0b7c96c27..010e0308b 100644 --- a/python/ray/util/sgd/v2/tests/test_backend.py +++ b/python/ray/util/sgd/v2/tests/test_backend.py @@ -54,9 +54,9 @@ def ray_2_node_4_gpu(): def gen_execute_special(special_f): def execute_async_special(self, f): """Runs f on worker 0, special_f on other workers.""" - 
futures = [self.workers[0].execute.remote(f)] + futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)] for worker in self.workers[1:]: - futures.append(worker.execute.remote(special_f)) + futures.append(worker._BaseWorkerMixin__execute.remote(special_f)) return futures return execute_async_special diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py index d65596f77..bc05353ae 100644 --- a/python/ray/util/sgd/v2/tests/test_trainer.py +++ b/python/ray/util/sgd/v2/tests/test_trainer.py @@ -17,8 +17,11 @@ from ray.util.sgd.v2.backends.backend import BackendConfig, Backend, \ from ray.util.sgd.v2.callbacks.callback import SGDCallback from ray.util.sgd.v2.examples.tensorflow_mnist_example import train_func as \ tensorflow_mnist_train_func +from ray.util.sgd.v2.examples.horovod.horovod_stateful_example import \ + TrainClass as HorovodTrainClass from ray.util.sgd.v2.examples.train_fashion_mnist_example import train_func \ - as fashion_mnist_train_func + as \ + fashion_mnist_train_func from ray.util.sgd.v2.examples.train_linear_example import train_func as \ linear_train_func @@ -87,7 +90,8 @@ def gen_execute_single_async_special(special_f): assert len(self.workers) == 2 if i == 0 and hasattr(self, "should_fail") and self.should_fail: kwargs["train_func"] = special_f - return self.workers[i].execute.remote(f, *args, **kwargs) + return self.workers[i]._BaseWorkerMixin__execute.remote( + f, *args, **kwargs) return execute_single_async_special @@ -630,6 +634,25 @@ def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus): assert worker_result[num_epochs - 1] < worker_result[0] +def test_horovod_torch_mnist_stateful(ray_start_2_cpus): + num_workers = 2 + num_epochs = 2 + trainer = Trainer("horovod", num_workers) + workers = trainer.to_worker_group( + HorovodTrainClass, config={ + "num_epochs": num_epochs, + "lr": 1e-3 + }) + results = [] + for epoch in range(num_epochs): + results.append(ray.get([w.train.remote(epoch=epoch) for w in workers])) + trainer.shutdown() + + assert len(results) == num_epochs + for i in range(num_workers): + assert results[num_epochs - 1][i] < results[0][i] + + def test_init_failure(ray_start_2_cpus): with pytest.raises(TypeError): Trainer(5) @@ -1001,6 +1024,33 @@ def test_gpu_requests(ray_start_4_cpus_4_gpus_4_extra): trainer.shutdown() +def test_to_worker_group(ray_start_2_cpus): + config = TestConfig() + trainer = Trainer(config, num_workers=2) + + class Incrementer: + def __init__(self, starting=0): + self.count = starting + + def increment(self): + self.count += 1 + + def get_count(self): + return self.count + + workers = trainer.to_worker_group(Incrementer, starting=2) + assert ray.get([w.get_count.remote() for w in workers]) == [2, 2] + + ray.get([w.increment.remote() for w in workers]) + assert ray.get([w.get_count.remote() for w in workers]) == [3, 3] + + ray.get(workers[0].increment.remote()) + assert ray.get([w.get_count.remote() for w in workers]) == [4, 3] + + ray.get(workers[1].increment.remote()) + assert ray.get([w.get_count.remote() for w in workers]) == [4, 4] + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/util/sgd/v2/trainer.py b/python/ray/util/sgd/v2/trainer.py index f30a7ec31..5cdea37b6 100644 --- a/python/ray/util/sgd/v2/trainer.py +++ b/python/ray/util/sgd/v2/trainer.py @@ -4,8 +4,9 @@ import logging import os from pathlib import Path from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \ - Type, Iterator + Type +from ray.actor import 
ActorHandle from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \ InactiveWorkerGroupError, SGDBackendError, TrainingWorkerError from ray.util.sgd.v2.backends.horovod import HorovodConfig @@ -18,6 +19,7 @@ from ray.util.sgd.v2.constants import TUNE_INSTALLED, DEFAULT_RESULTS_DIR, \ # Ray SGD should be usable even if Tune is not installed. from ray.util.sgd.v2.utils import construct_path +from ray.util.sgd.v2.worker_group import WorkerGroup if TUNE_INSTALLED: from ray import tune @@ -174,19 +176,12 @@ class Trainer: else: raise TypeError(f"Invalid type for backend: {type(backend)}.") - def start(self, - initialization_hook: Optional[Callable[[], None]] = None, - train_cls: Optional[S] = None, - *args, - **kwargs): + def start(self, initialization_hook: Optional[Callable[[], None]] = None): """Starts the training execution service. Args: initialization_hook (Optional[Callable]): The function to call on each worker when it is instantiated. - train_cls (Optional[cls]): The training class that each worker - should be instantiated as. - args, kwargs: The arguments to pass into ``train_cls.__init__``. """ self._executor.start(initialization_hook) @@ -258,7 +253,7 @@ class Trainer: config: Optional[Dict[str, Any]] = None, checkpoint: Optional[Union[Dict, str, Path]] = None, checkpoint_strategy: Optional[CheckpointStrategy] = None - ) -> Iterator[List[Dict]]: + ) -> "SGDIterator": """Same as ``run`` except returns an iterator over the results. This is useful if you want to have more customization of what to do @@ -343,36 +338,6 @@ class Trainer: else: # num_params == 0 return train_func - def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]: - """Executes a function for all instances of ``self.train_cls``. - - Args: - func (Callable): The function that should be executed. - The first argument should be an instance of - ``self.train_cls``. - args, kwargs: The arguments to pass into ``func``. - - Returns: - A list of results from ``func``. Each value in the - list corresponds to the output of ``func`` from - each worker. - """ - raise NotImplementedError - - def execute_single(self, func: Callable[..., T], *args, **kwargs) -> T: - """Executes a function on a single instance of ``self.train_cls``. - - Args: - func (Callable): The function that should be executed. - The first argument should be an instance of - ``self.train_cls``. - args, kwargs: The arguments to pass into ``func``. - - Returns: - The output of ``func`` from a single worker. - """ - raise NotImplementedError - @property def latest_run_dir(self) -> Optional[Path]: """Path to the log directory for the latest call to ``run()``. @@ -444,6 +409,91 @@ class Trainer: self._num_workers, self._use_gpu, self._resources_per_worker) + def to_worker_group(self, train_cls: Type, *args, + **kwargs) -> "SGDWorkerGroup": + """Returns Ray actors with the provided class and the backend started. + + This is useful if you want to provide your own class for training + and have more control over execution, but still want to use Ray SGD + to setup the appropriate backend configurations (torch, tf, etc.). + + .. code-block:: python + + class Trainer: + def __init__(self, config): + self.config = config + + def train_epoch(self): + ... 
+ return 1 + + config = {"lr": 0.1} + trainer = Trainer(num_workers=2, backend="torch") + workers = trainer.to_worker_group(train_cls=Trainer, config=config) + futures = [w.train_epoch.remote() for w in workers] + assert ray.get(futures) == [1, 1] + assert ray.get(workers[0].train_epoch.remote()) == 1 + workers.shutdown() + + Args: + train_cls (Type): The class definition to use for the Ray + actors/workers. + args, kwargs: Arguments to pass into the ``__init__`` of the + provided ``train_cls``. + """ + if self._executor.is_started: + raise RuntimeError("The Trainer must not be active to use " + "`to_worker_group`. Either shutdown the " + "Trainer or don't start it in the first place.") + self._executor.start( + train_cls=train_cls, train_cls_args=args, train_cls_kwargs=kwargs) + return SGDWorkerGroup(self._executor.worker_group) + + +class SGDWorkerGroup: + """A container for a group of Ray actors. + + You should not instantiate this directly and only use this as the output + of ``Trainer.to_worker_group``. You can index or iterate this object like + you would a List. + + .. code-block:: python + + class Trainer: + def __init__(self, config): + self.config = config + + def train_epoch(self): + ... + return 1 + + config = {"lr": 0.1} + trainer = Trainer(num_workers=2, backend="torch") + workers = trainer.to_worker_group(train_cls=Trainer, config=config) + futures = [w.train_epoch.remote() for w in workers] + assert ray.get(futures) == [1, 1] + assert ray.get(workers[0].train_epoch.remote()) == 1 + workers.shutdown()` + """ + + def __init__(self, worker_group: WorkerGroup): + self._worker_group = worker_group + + def __getitem__(self, item) -> ActorHandle: + return self._worker_group.workers[item] + + def shutdown(self, patience_s: float = 5): + """Shutdown all the workers. + + Args: + patience_s (float): Attempt a graceful shutdown + of the workers for this many seconds. Fallback to force kill + if graceful shutdown is not complete after this time. If + this is less than or equal to 0, immediately force kill all + workers. + """ + self._worker_group.shutdown(patience_s=patience_s) + class SGDIterator: """An iterator over SGD results. Returned by ``trainer.run_iterator``.""" diff --git a/python/ray/util/sgd/v2/worker_group.py b/python/ray/util/sgd/v2/worker_group.py index 9ea545f92..0f7188f81 100644 --- a/python/ray/util/sgd/v2/worker_group.py +++ b/python/ray/util/sgd/v2/worker_group.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, List, TypeVar, Optional, Dict +from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple import ray from ray.types import ObjectRef @@ -9,10 +9,10 @@ T = TypeVar("T") logger = logging.getLogger(__name__) -class BaseWorker: +class BaseWorkerMixin: """A class to execute arbitrary functions. Does not hold any state.""" - def execute(self, func: Callable[..., T], *args, **kwargs) -> T: + def __execute(self, func: Callable[..., T], *args, **kwargs) -> T: """Executes the input function and returns the output. 
Args: @@ -22,6 +22,21 @@ class BaseWorker: return func(*args, **kwargs) +def create_executable_class(executable_cls: Optional[Type] = None) -> Type: + """Create the executable class to use as the Ray actors.""" + if not executable_cls: + return BaseWorkerMixin + elif issubclass(executable_cls, BaseWorkerMixin): + return executable_cls + else: + + class _WrappedExecutable(executable_cls, BaseWorkerMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + return _WrappedExecutable + + class WorkerGroup: """Group of Ray Actors that can execute arbitrary functions. @@ -43,6 +58,10 @@ class WorkerGroup: Dictionary specifying the extra resources that will be requested for each worker in addition to ``num_cpus_per_worker`` and ``num_gpus_per_worker``. + actor_cls (Optional[Type]): If specified use this class as the + remote actors. + remote_cls_args, remote_cls_kwargs: If ``remote_cls`` is provided, + these args will be used for the worker initialization. Example: @@ -55,12 +74,15 @@ class WorkerGroup: assert all(o == 1 for o in output) """ - def __init__(self, - num_workers: int = 1, - num_cpus_per_worker: float = 1, - num_gpus_per_worker: float = 0, - additional_resources_per_worker: Optional[Dict[ - str, float]] = None): + def __init__( + self, + num_workers: int = 1, + num_cpus_per_worker: float = 1, + num_gpus_per_worker: float = 0, + additional_resources_per_worker: Optional[Dict[str, float]] = None, + actor_cls: Type = None, + actor_cls_args: Optional[Tuple] = None, + actor_cls_kwargs: Optional[Dict] = None): if num_workers <= 0: raise ValueError("The provided `num_workers` must be greater " @@ -72,21 +94,32 @@ class WorkerGroup: f"num_cpus_per_worker={num_cpus_per_worker} and " f"num_gpus_per_worker={num_gpus_per_worker}.") + if (actor_cls_args or actor_cls_kwargs) and not actor_cls: + raise ValueError("`actor_cls_args` or `actor_class_kwargs` are " + "passed in but no `actor_cls` is passed in.") + self.num_workers = num_workers self.num_cpus_per_worker = num_cpus_per_worker self.num_gpus_per_worker = num_gpus_per_worker self.additional_resources_per_worker = additional_resources_per_worker self.workers = [] + self._base_cls = create_executable_class(actor_cls) + assert issubclass(self._base_cls, BaseWorkerMixin) + + self._actor_cls_args = actor_cls_args or [] + self._actor_cls_kwargs = actor_cls_kwargs or {} + # TODO(matt): Validate resources. Fast-fail if it is impossible to # handle the request, rather than hang indefinitely. self._remote_cls = ray.remote( num_cpus=self.num_cpus_per_worker, num_gpus=self.num_gpus_per_worker, - resources=self.additional_resources_per_worker)(BaseWorker) + resources=self.additional_resources_per_worker)(self._base_cls) self.start() def _create_worker(self): - return self._remote_cls.remote() + return self._remote_cls.remote(*self._actor_cls_args, + **self._actor_cls_kwargs) def start(self): """Starts all the workers in this worker group.""" @@ -146,7 +179,10 @@ class WorkerGroup: "group has most likely been shut down. Please" "create a new WorkerGroup or restart this one.") - return [w.execute.remote(func, *args, **kwargs) for w in self.workers] + return [ + w._BaseWorkerMixin__execute.remote(func, *args, **kwargs) + for w in self.workers + ] def execute(self, func: Callable[..., T], *args, **kwargs) -> List[T]: """Execute ``func`` on each worker and return the outputs of ``func``. 
@@ -178,7 +214,8 @@ class WorkerGroup: if worker_index >= len(self.workers): raise ValueError(f"The provided worker_index {worker_index} is " f"not valid for {self.num_workers} workers.") - return self.workers[worker_index].execute.remote(func, *args, **kwargs) + return self.workers[worker_index]._BaseWorkerMixin__execute.remote( + func, *args, **kwargs) def execute_single(self, worker_index: int, func: Callable[..., T], *args, **kwargs) -> T:
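
Usage note: the following is a minimal end-to-end sketch of the Class API added by this
patch, distilled from the ``to_worker_group`` docstring and ``test_to_worker_group``
above. The ``MyTrainer`` class, its config dict, and the choice of the ``"torch"``
backend with two workers are illustrative assumptions (the torch backend requires
torch to be installed); any backend accepted by ``Trainer`` is used the same way.

.. code-block:: python

    import ray
    from ray.util.sgd.v2 import Trainer


    class MyTrainer:
        """Illustrative user-defined training class; any class can be used."""

        def __init__(self, config):
            self.config = config
            self.epochs_run = 0

        def train_epoch(self):
            # Real training logic would go here.
            self.epochs_run += 1
            return self.epochs_run


    ray.init()

    # Do not call trainer.start() before to_worker_group(): to_worker_group()
    # starts the worker group itself and raises if the Trainer is already active.
    trainer = Trainer(backend="torch", num_workers=2)
    workers = trainer.to_worker_group(MyTrainer, config={"lr": 0.1})

    # `workers` is an SGDWorkerGroup: it can be indexed or iterated, and each
    # element is a Ray actor handle, so methods are invoked with .remote() and
    # results are fetched with ray.get().
    results = ray.get([w.train_epoch.remote() for w in workers])
    assert results == [1, 1]

    workers.shutdown()
    ray.shutdown()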