[SGD] v2 Horovod backend (#18047)

* [SGD] add Horovod backend

* address comments: set CUDA_VISIBLE_DEVICES, refactor code

* fix gpu test

* fix lint/test import

* address comments, add example cluster config

* delay horovod imports
matthewdeng 2021-08-31 12:54:59 -07:00 committed by GitHub
parent 6133a561e9
commit a3123b6860
13 changed files with 639 additions and 19 deletions


@@ -434,7 +434,7 @@
conditions: ["RAY_CI_SGD_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- SGD_TESTING=1 ./ci/travis/install-dependencies.sh
- SGD_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh
- bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=tf,-pytorch,-py37,-flaky,-client python/ray/util/sgd/...
- bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-tf,pytorch,-py37,-flaky,-client python/ray/util/sgd/...
- bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=client_unit_tests --test_env=RAY_CLIENT_MODE=1 python/ray/util/sgd/...


@@ -42,6 +42,8 @@ MOCK_MODULES = [
"gym.spaces",
"horovod",
"horovod.ray",
"horovod.ray.runner",
"horovod.ray.utils",
"hyperopt",
"hyperopt.hp"
"kubernetes",


@@ -118,7 +118,8 @@ def train_fn(data_dir=None,
def main(num_workers, use_gpu, **kwargs):
settings = RayExecutor.create_settings(timeout_s=30)
executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers)
executor.run(train_fn, kwargs=kwargs)
executor.start()
executor.run(train_fn, **kwargs)
if __name__ == "__main__":
@@ -133,7 +134,7 @@ if __name__ == "__main__":
metavar="N",
help="input batch size for training (default: 64)")
parser.add_argument(
"--epochs",
"--num-epochs",
type=int,
default=5,
metavar="N",
@@ -151,10 +152,10 @@
metavar="M",
help="SGD momentum (default: 0.5)")
parser.add_argument(
"--no-cuda",
"--use-cuda",
action="store_true",
default=False,
help="disables CUDA training")
help="enables CUDA training")
parser.add_argument(
"--seed",
type=int,
@@ -183,17 +184,10 @@
"will be downloaded if needed)")
parser.add_argument(
"--address",
require=False,
types=str,
default=None,
help="Address of Ray cluster.")
parser.add_argument(
"--server-address",
required=False,
type=str,
default=None,
required=False,
help="The address of server to connect to if using "
"Ray Client.")
help="Address of Ray cluster.")
args = parser.parse_args()
@@ -201,8 +195,6 @@ if __name__ == "__main__":
if args.address:
ray.init(args.address)
elif args.server_address:
ray.init(f"ray://{args.server_address}")
else:
ray.init()
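
The fix above adds the missing `executor.start()` call before `executor.run(...)` and drops the separate `--server-address` flag. Below is a condensed sketch of the corrected driver flow, assuming `horovod[ray]` is installed; the toy `train_fn` and the trailing `shutdown()` call are illustrative, not part of this diff.

```python
# Condensed sketch of the corrected Horovod RayExecutor flow.
import ray
from horovod.ray import RayExecutor

def train_fn():
    # Stand-in for the example's real training function.
    import horovod.torch as hvd
    hvd.init()
    print("Horovod rank:", hvd.rank())

ray.init()
settings = RayExecutor.create_settings(timeout_s=30)
executor = RayExecutor(settings, use_gpu=False, num_workers=2)
executor.start()        # workers must be started before run() -- the missing step
executor.run(train_fn)  # runs train_fn on every Horovod worker
executor.shutdown()
```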


@@ -1,4 +1,6 @@
import logging
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Callable, TypeVar, List, Optional, Dict, Union
@@ -12,7 +14,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
DEFAULT_RESULTS_DIR
from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
from ray.util.sgd.v2.utils import construct_path
from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids
from ray.util.sgd.v2.worker_group import WorkerGroup
T = TypeVar("T")
@@ -105,8 +107,64 @@ class BackendExecutor:
self._num_gpus_per_worker)
if initialization_hook:
self.worker_group.execute(initialization_hook)
if self._num_gpus_per_worker > 0:
self._setup_gpus()
self._backend.on_start(self.worker_group, self._backend_config)
def _setup_gpus(self):
"""Sets CUDA_VISIBLE_DEVICES on all workers.
For each worker, CUDA_VISIBLE_DEVICES will be set to the union of the
GPU IDs used by all workers on that worker's node.
This allows GPU workers on the same node to communicate with one
another.
Example:
Setup:
- Node1:
- Worker1: {0, 1}
- Worker2: {2, 3}
- Node2:
- Worker3: {0, 1}
CUDA_VISIBLE_DEVICES:
- Worker1: "0,1,2,3"
- Worker2: "0,1,2,3"
- Worker3: "0,1"
"""
def get_node_id_and_gpu():
node_id = get_node_id()
gpu_ids = get_gpu_ids()
return node_id, gpu_ids
node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)
node_id_to_worker_id = defaultdict(set)
node_id_to_gpu_ids = defaultdict(set)
for worker_id, (node_id, gpu_ids) in enumerate(node_ids_and_gpu_ids):
node_id_to_worker_id[node_id].add(worker_id)
node_id_to_gpu_ids[node_id].update(gpu_ids)
futures = []
for node_id, gpu_ids in node_id_to_gpu_ids.items():
all_gpu_ids = ",".join([str(gpu_id) for gpu_id in gpu_ids])
def set_gpu_ids():
os.environ["CUDA_VISIBLE_DEVICES"] = all_gpu_ids
for worker_id in node_id_to_worker_id[node_id]:
futures.append(
self.worker_group.execute_single_async(
worker_id, set_gpu_ids))
ray.get(futures)
def start_training(
self,
train_func: Callable[[], T],
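
To make the merge logic concrete, here is a standalone sketch that reproduces the docstring's example; the helper name is made up, and the explicit sort is added only to get a deterministic string (the implementation above joins the set in iteration order).

```python
# Standalone sketch of the per-node CUDA_VISIBLE_DEVICES merge in _setup_gpus().
from collections import defaultdict
from typing import List, Tuple

def visible_devices_per_worker(
        workers: List[Tuple[str, List[int]]]) -> List[str]:
    """workers is a list of (node_id, gpu_ids) pairs, one pair per worker."""
    node_to_gpus = defaultdict(set)
    for node_id, gpu_ids in workers:
        node_to_gpus[node_id].update(gpu_ids)
    # Every worker on a node gets the union of the GPU IDs claimed by all
    # workers on that node.
    return [
        ",".join(str(gpu) for gpu in sorted(node_to_gpus[node_id]))
        for node_id, _ in workers
    ]

# The docstring's example: Worker1 and Worker2 on Node1, Worker3 on Node2.
workers = [("node1", [0, 1]), ("node1", [2, 3]), ("node2", [0, 1])]
assert visible_devices_per_worker(workers) == ["0,1,2,3", "0,1,2,3", "0,1"]
```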


@@ -0,0 +1,101 @@
import logging
import os
from dataclasses import dataclass
from typing import Optional, Set
import ray
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface
from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
from ray.util.sgd.v2.worker_group import WorkerGroup
try:
from horovod.ray.runner import Coordinator
from horovod.ray.utils import detect_nics, nics_to_env_var
except ImportError:
Coordinator = None
detect_nics = None
nics_to_env_var = None
logger = logging.getLogger(__name__)
@dataclass
class HorovodConfig(BackendConfig):
"""Configurations for Horovod setup.
See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py # noqa: E501
Args:
nics (Optional[Set[str]]): Network interfaces that can be used for
communication.
verbose (int): Horovod logging verbosity.
"""
nics: Optional[Set[str]] = None
verbose: int = 1
def __post_init__(self):
if Coordinator is None:
raise ValueError(
"`horovod[ray]` is not installed. "
"Please install 'horovod[ray]' to use this backend.")
@property
def backend_cls(self):
return HorovodBackend
def init_env_vars(world_rank: int, world_size: int):
"""Initialize Horovod environment variables."""
os.environ["HOROVOD_HOSTNAME"] = get_node_id()
os.environ["HOROVOD_RANK"] = str(world_rank)
os.environ["HOROVOD_SIZE"] = str(world_size)
class HorovodBackend(BackendInterface):
def on_start(self, worker_group: WorkerGroup,
backend_config: HorovodConfig):
# TODO(matt): Implement placement group strategies in BackendExecutor.
# Initialize workers with Horovod environment variables
setup_futures = []
for rank in range(len(worker_group)):
setup_futures.append(
worker_group.execute_single_async(rank, init_env_vars, rank,
len(worker_group)))
ray.get(setup_futures)
# Use the Horovod Ray Coordinator, passing backend_config as the
# Horovod Settings object.
self.coordinator = Coordinator(backend_config)
# Get the node ID and hostname of every worker.
node_ids = worker_group.execute(get_node_id)
hostnames = worker_group.execute(get_hostname)
# Register each worker's hostname with the coordinator. This assumes
# the hostname and node ID lists share the same worker ordering.
for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
self.coordinator.register(hostname, node_id, rank)
all_info = self.coordinator.finalize_registration()
setup_futures = []
for rank, local_cross_env_var in all_info.items():
setup_futures.append(
worker_group.execute_single_async(rank, update_env_vars,
local_cross_env_var))
ray.get(setup_futures)
coordinator_envs = self.coordinator.establish_rendezvous()
nics = detect_nics(
backend_config,
all_host_names=list(self.coordinator.hostnames),
node_workers=worker_group.workers)
coordinator_envs.update(nics_to_env_var(nics))
worker_group.execute(update_env_vars, coordinator_envs)
def on_shutdown(self, worker_group: WorkerGroup,
backend_config: HorovodConfig):
# Currently no additional steps are needed
pass
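
A minimal usage sketch of the new backend, assuming `horovod[ray]` is installed; the worker count and the training function are illustrative.

```python
# Minimal sketch: running a Horovod function through the SGD v2 Trainer.
from ray.util.sgd.v2 import Trainer

def train_func():
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank()

trainer = Trainer("horovod", num_workers=2)
trainer.start()
print(trainer.run(train_func))  # e.g. [0, 1], one result per worker
trainer.shutdown()
# To tune the new config fields, a HorovodConfig instance can presumably be
# passed in place of the string, e.g. Trainer(HorovodConfig(verbose=2),
# num_workers=2), with HorovodConfig imported from
# ray.util.sgd.v2.backends.horovod.
```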


@@ -0,0 +1,55 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: horovod-cluster
# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers defaults to 0.
min_workers: 3
max_workers: 3
# Cloud-provider specific configuration.
provider:
type: aws
region: us-west-2
# How Ray will authenticate with newly launched nodes.
auth:
ssh_user: ubuntu
available_node_types:
ray.head.default:
min_workers: 0
max_workers: 0
resources: {}
node_config:
InstanceType: g3.8xlarge
ImageId: latest_dlami
InstanceMarketOptions:
MarketType: spot
BlockDeviceMappings:
- DeviceName: /dev/sda1
Ebs:
VolumeSize: 300
ray.worker.default:
min_workers: 3
max_workers: 3
resources: {}
node_config:
InstanceType: g3.8xlarge
ImageId: latest_dlami
InstanceMarketOptions:
MarketType: spot
BlockDeviceMappings:
- DeviceName: /dev/sda1
Ebs:
VolumeSize: 300
setup_commands:
# This replaces the standard anaconda Ray installation
- pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp37-cp37m-manylinux2014_x86_64.whl
- pip install ray[tune]
# Install Horovod
- HOROVOD_WITH_GLOO=1 HOROVOD_GPU_OPERATIONS=NCCL HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_TENSORFLOW=1 HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITH_PYTORCH=1 pip install torch torchvision horovod


@@ -0,0 +1,221 @@
import argparse
import os
import horovod.torch as hvd
import ray
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from filelock import FileLock
from ray.util.sgd.v2 import Trainer
from torchvision import datasets, transforms
def metric_average(val, name):
tensor = torch.tensor(val)
avg_tensor = hvd.allreduce(tensor, name=name)
return avg_tensor.item()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x)
def train_func(config):
data_dir = config.get("data_dir", None)
seed = config.get("seed", 42)
use_cuda = config.get("use_cuda", False)
batch_size = config.get("batch_size", 64)
use_adasum = config.get("use_adasum", False)
lr = config.get("lr", 0.01)
momentum = config.get("momentum", 0.5)
num_epochs = config.get("num_epochs", 10)
log_interval = config.get("log_interval", 10)
# Horovod: initialize library.
hvd.init()
torch.manual_seed(seed)
if use_cuda:
# Horovod: pin GPU to local rank.
torch.cuda.set_device(hvd.local_rank())
torch.cuda.manual_seed(seed)
# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(1)
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
data_dir = data_dir or "~/data"
with FileLock(os.path.expanduser("~/.horovod_lock")):
train_dataset = \
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
# Horovod: use DistributedSampler to partition the training data.
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, sampler=train_sampler, **kwargs)
model = Net()
# By default, Adasum doesn't need scaling up learning rate.
lr_scaler = hvd.size() if not use_adasum else 1
if use_cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if use_adasum and hvd.nccl_built():
lr_scaler = hvd.local_size()
# Horovod: scale learning rate by lr_scaler.
optimizer = optim.SGD(
model.parameters(), lr=lr * lr_scaler, momentum=momentum)
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(
optimizer,
named_parameters=model.named_parameters(),
op=hvd.Adasum if use_adasum else hvd.Average)
results = []
for epoch in range(1, num_epochs + 1):
model.train()
# Horovod: set epoch to sampler for shuffling.
train_sampler.set_epoch(epoch)
num_batches = len(train_loader)
for batch_idx, (data, target) in enumerate(train_loader):
if use_cuda:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
# Horovod: use train_sampler to determine the number of
# examples in this worker's partition.
print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch, batch_idx * len(data), len(train_sampler),
100. * batch_idx / len(train_loader), loss.item()))
if batch_idx == num_batches - 1:
results.append(loss.item())
return results
def main(num_workers, use_gpu, kwargs):
trainer = Trainer("horovod", use_gpu=use_gpu, num_workers=num_workers)
trainer.start()
loss_per_epoch = trainer.run(train_func, config=kwargs)
trainer.shutdown()
print(loss_per_epoch)
if __name__ == "__main__":
# Training settings
parser = argparse.ArgumentParser(
description="PyTorch MNIST Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)")
parser.add_argument(
"--num-epochs",
type=int,
default=5,
metavar="N",
help="number of epochs to train (default: 10)")
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)")
parser.add_argument(
"--momentum",
type=float,
default=0.5,
metavar="M",
help="SGD momentum (default: 0.5)")
parser.add_argument(
"--use-gpu",
action="store_true",
default=False,
help="enables CUDA training")
parser.add_argument(
"--seed",
type=int,
default=42,
metavar="S",
help="random seed (default: 42)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status")
parser.add_argument(
"--use-adasum",
action="store_true",
default=False,
help="use adasum algorithm to do reduction")
parser.add_argument(
"--num-workers",
type=int,
default=2,
help="Number of Ray workers to use for training.")
parser.add_argument(
"--data-dir",
help="location of the training dataset in the local filesystem ("
"will be downloaded if needed)")
parser.add_argument(
"--address",
required=False,
type=str,
default=None,
help="Address of Ray cluster.")
args = parser.parse_args()
if args.address:
ray.init(args.address)
else:
ray.init()
use_cuda = args.use_gpu if args.use_gpu is not None else False
kwargs = {
"data_dir": args.data_dir,
"seed": args.seed,
"use_cuda": use_cuda,
"batch_size": args.batch_size,
"use_adasum": args.use_adasum if args.use_adasum else False,
"lr": args.lr,
"momentum": args.momentum,
"num_epochs": args.num_epochs,
"log_interval": args.log_interval
}
main(num_workers=args.num_workers, use_gpu=use_cuda, kwargs=kwargs)
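
For reference, the shape of `loss_per_epoch` printed by `main()` above, as asserted by the tests added later in this PR; the numbers are illustrative.

```python
# Illustrative shape of the value returned by trainer.run(train_func, ...):
# one list per worker, containing the final-batch loss of each epoch.
loss_per_epoch = [
    [2.29, 1.84],  # worker 0
    [2.31, 1.88],  # worker 1
]
assert len(loss_per_epoch) == 2                  # num_workers
for worker_result in loss_per_epoch:
    assert worker_result[-1] < worker_result[0]  # loss decreases over epochs
```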


@@ -4,6 +4,7 @@ import pytest
from unittest.mock import patch
import ray
from ray.cluster_utils import Cluster
from ray.util.sgd import v2 as sgd
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor
from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
@@ -22,6 +23,34 @@ def ray_start_2_cpus():
ray.shutdown()
@pytest.fixture
def ray_2_node_2_gpu():
cluster = Cluster()
for _ in range(2):
cluster.add_node(num_cpus=2, num_gpus=2)
ray.init(address=cluster.address)
yield
ray.shutdown()
cluster.shutdown()
@pytest.fixture
def ray_2_node_4_gpu():
cluster = Cluster()
for _ in range(2):
cluster.add_node(num_cpus=2, num_gpus=4)
ray.init(address=cluster.address)
yield
ray.shutdown()
cluster.shutdown()
def gen_execute_special(special_f):
def execute_async_special(self, f):
"""Runs f on worker 0, special_f on worker 1."""
@@ -249,6 +278,80 @@ def test_torch_start_shutdown(ray_start_2_cpus, init_method):
assert not any(e.finish_training())
@pytest.mark.parametrize("worker_results", [(1, ["0"]), (2, ["0,1", "0,1"]),
(3, ["0", "0,1", "0,1"]),
(4, ["0,1", "0,1", "0,1", "0,1"])])
def test_cuda_visible_devices(ray_2_node_2_gpu, worker_results):
config = TestConfig()
def get_resources():
return os.environ["CUDA_VISIBLE_DEVICES"]
num_workers, expected_results = worker_results
e = BackendExecutor(
config,
num_workers=num_workers,
num_cpus_per_worker=0,
num_gpus_per_worker=1)
e.start()
e.start_training(get_resources)
results = e.finish_training()
results.sort()
assert results == expected_results
@pytest.mark.parametrize(
"worker_results",
[(1, ["0"]), (2, ["0", "0"]), (3, ["0,1", "0,1", "0,1"]),
(4, ["0,1", "0,1", "0,1", "0,1"]), (5, ["0", "0,1", "0,1", "0,1", "0,1"]),
(6, ["0", "0", "0,1", "0,1", "0,1", "0,1"]),
(7, ["0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1"]),
(8, ["0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1", "0,1"])])
def test_cuda_visible_devices_fractional(ray_2_node_2_gpu, worker_results):
config = TestConfig()
def get_resources():
return os.environ["CUDA_VISIBLE_DEVICES"]
num_workers, expected_results = worker_results
e = BackendExecutor(
config,
num_workers=num_workers,
num_cpus_per_worker=0,
num_gpus_per_worker=0.5)
e.start()
e.start_training(get_resources)
results = e.finish_training()
results.sort()
assert results == expected_results
@pytest.mark.parametrize("worker_results",
[(1, ["0,1"]), (2, ["0,1,2,3", "0,1,2,3"]),
(3, ["0,1", "0,1,2,3", "0,1,2,3"]),
(4, ["0,1,2,3", "0,1,2,3", "0,1,2,3", "0,1,2,3"])])
def test_cuda_visible_devices_multiple(ray_2_node_4_gpu, worker_results):
config = TestConfig()
def get_resources():
return os.environ["CUDA_VISIBLE_DEVICES"]
num_workers, expected_results = worker_results
e = BackendExecutor(
config,
num_workers=num_workers,
num_cpus_per_worker=0,
num_gpus_per_worker=2)
e.start()
e.start_training(get_resources)
results = e.finish_training()
results.sort()
assert results == expected_results
if __name__ == "__main__":
import pytest
import sys
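
A worked instance of the fractional-GPU parametrization above (`num_workers=5`, 0.5 GPU per worker on two 2-GPU nodes), under an assumed placement of four workers on one node and one worker on the other.

```python
# Worked example for test_cuda_visible_devices_fractional with num_workers=5,
# assuming Ray packs four half-GPU workers onto node A (claiming GPUs 0 and 1
# between them) and places the remaining worker on node B (GPU 0).
def fmt(gpu_ids):
    return ",".join(str(gpu) for gpu in sorted(gpu_ids))

node_a_union = {0, 1}  # union of GPUs claimed by the four co-located workers
node_b_union = {0}     # the lone worker's GPU on the other node
per_worker = [fmt(node_b_union)] + [fmt(node_a_union)] * 4
assert sorted(per_worker) == ["0", "0,1", "0,1", "0,1", "0,1"]
```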


@@ -2,6 +2,7 @@ import time
from pathlib import Path
from unittest.mock import patch
import horovod.torch as hvd_torch
import pytest
import ray
import ray.util.sgd.v2 as sgd
@@ -17,6 +18,9 @@ from ray.util.sgd.v2.examples.train_fashion_mnist import train_func as \
fashion_mnist_train_func
from ray.util.sgd.v2.examples.train_linear import train_func as \
linear_train_func
from ray.util.sgd.v2.examples.horovod.horovod_example import train_func as \
horovod_torch_train_func
from ray.util.sgd.v2.worker_group import WorkerGroup
@@ -537,6 +541,61 @@ def test_torch_fashion_mnist_gpu(ray_start_2_cpus_2_gpus):
assert result[-1] < result[0]
def test_horovod_simple(ray_start_2_cpus):
def simple_fn():
hvd_torch.init()
return hvd_torch.rank()
num_workers = 2
trainer = Trainer("horovod", num_workers)
trainer.start()
result = trainer.run(simple_fn)
trainer.shutdown()
assert result == list(range(num_workers))
def test_horovod_torch_mnist(ray_start_2_cpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers)
trainer.start()
results = trainer.run(
horovod_torch_train_func,
config={
"num_epochs": num_epochs,
"lr": 1e-3
})
trainer.shutdown()
assert len(results) == num_workers
for worker_result in results:
assert len(worker_result) == num_epochs
assert worker_result[num_epochs - 1] < worker_result[0]
@pytest.mark.skipif(
torch.cuda.device_count() < 2,
reason="Only run if multiple GPUs are available.")
def test_horovod_torch_mnist_gpu(ray_start_2_cpus_2_gpus):
num_workers = 2
num_epochs = 2
trainer = Trainer("horovod", num_workers, use_gpu=True)
trainer.start()
results = trainer.run(
horovod_torch_train_func,
config={
"num_epochs": num_epochs,
"lr": 1e-3
})
trainer.shutdown()
assert len(results) == num_workers
for worker_result in results:
assert len(worker_result) == num_epochs
assert worker_result[num_epochs - 1] < worker_result[0]
def test_init_failure(ray_start_2_cpus):
with pytest.raises(TypeError):
Trainer(5)


@@ -6,6 +6,7 @@ from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, \
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendExecutor, \
InactiveWorkerGroupError, SGDBackendError
from ray.util.sgd.v2.backends.horovod import HorovodConfig
from ray.util.sgd.v2.backends.tensorflow import TensorflowConfig
from ray.util.sgd.v2.backends.torch import TorchConfig
from ray.util.sgd.v2.callbacks.callback import SGDCallback
@@ -33,6 +34,7 @@ S = TypeVar("S")
logger = logging.getLogger(__name__)
BACKEND_NAME_TO_CONFIG_CLS = {
"horovod": HorovodConfig,
"tensorflow": TensorflowConfig,
"torch": TorchConfig
}
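
A small sketch of what the new mapping entry enables; note that instantiating `HorovodConfig` raises a `ValueError` when `horovod[ray]` is not installed, per its `__post_init__` check.

```python
# Sketch: resolving the "horovod" backend name to its config class.
from ray.util.sgd.v2.backends.horovod import HorovodConfig

BACKEND_NAME_TO_CONFIG_CLS = {"horovod": HorovodConfig}  # excerpt of the mapping

config = BACKEND_NAME_TO_CONFIG_CLS["horovod"]()  # defaults: nics=None, verbose=1
assert config.backend_cls.__name__ == "HorovodBackend"
```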


@@ -1,8 +1,9 @@
import os
from contextlib import closing
import socket
from pathlib import Path
from threading import Thread
from typing import Tuple
from typing import Tuple, Dict, List, Any
import ray
@@ -49,3 +50,28 @@ class PropagatingThread(Thread):
if self.exc:
raise self.exc
return self.ret
def get_node_id() -> str:
"""Returns the ID of the node that this worker is on."""
return ray.get_runtime_context().node_id.hex()
def get_hostname() -> str:
"""Returns the hostname that this worker is on."""
return socket.gethostname()
def get_gpu_ids() -> List[int]:
"""Return list of CUDA device IDs available to this worker."""
return ray.get_gpu_ids()
def update_env_vars(env_vars: Dict[str, Any]):
"""Updates the environment variables on this worker process.
Args:
env_vars (Dict): Environment variables to set.
"""
sanitized = {k: str(v) for k, v in env_vars.items()}
os.environ.update(sanitized)
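
A quick sketch of the new helpers that need no Ray cluster (`get_node_id` and `get_gpu_ids` additionally require an initialized Ray worker context); values shown are illustrative.

```python
# Quick check of the new utility helpers.
import os
from ray.util.sgd.v2.utils import get_hostname, update_env_vars

update_env_vars({"HOROVOD_RANK": 0, "HOROVOD_SIZE": 2})  # non-strings are stringified
assert os.environ["HOROVOD_RANK"] == "0"
print(get_hostname())  # e.g. "ip-172-31-0-12"
```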


@@ -14,8 +14,9 @@ class BaseWorker:
def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
"""Executes the input function and returns the output.
Args:
func(Callable): The function to execute.
func (Callable): The function to execute.
args, kwargs: The arguments to pass into func.
"""
return func(*args, **kwargs)