[SGD] Retry sgd.local_rank() (#18824)

* finish

* fix

* wip

* address comment

* update

* fix test

* fix failing test

* address comments

* fix test

* fix
Amog Kamsetty 2021-09-22 15:48:38 -07:00 committed by GitHub
parent 73c3cff18b
commit 00dd190df9
14 changed files with 224 additions and 56 deletions
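In short, this commit gives every SGD v2 worker a per-node local rank and exposes it as ``sgd.local_rank()``, which the Torch docs and examples below use to pick a CUDA device. A minimal usage sketch, assuming a single machine with two GPUs; the tiny ``nn.Linear`` model and the worker count are illustrative, not part of the commit:

    import torch
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel

    import ray.util.sgd.v2 as sgd
    from ray.util.sgd.v2 import Trainer


    def train_func():
        # Pin this worker to the GPU that matches its rank on the node.
        device = torch.device(f"cuda:{sgd.local_rank()}"
                              if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.set_device(device)

        model = nn.Linear(10, 1).to(device)  # placeholder model
        model = DistributedDataParallel(
            model,
            device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)
        # ... training loop elided ...
        return sgd.local_rank()


    trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
    trainer.start()
    print(trainer.run(train_func))  # e.g. [0, 1] on a single two-GPU node
    trainer.shutdown()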


@@ -153,6 +153,10 @@ system. Let's take following simple examples:
 results = trainer.run(train_func_distributed)
 trainer.shutdown()
+See :ref:`sgd-porting-code` for a more comprehensive example.
 .. group-tab:: TensorFlow
 This example shows how you can use RaySGD to set up `Multi-worker training
@@ -250,4 +254,7 @@ system. Let's take following simple examples:
 trainer.shutdown()
+See :ref:`sgd-porting-code` for a more comprehensive example.
 **Next steps:** Check out the :ref:`User Guide <sgd-user-guide>`!


@@ -83,6 +83,30 @@ training.
 sampler=DistributedSampler(dataset))
+**Step 3:** Set the proper CUDA device if you are using GPUs.
+
+If you are using GPUs, you need to make sure that the CUDA devices are properly set up inside your training function.
+This involves 3 steps:
+
+1. Use the local rank to set the default CUDA device for the worker.
+2. Move the model to the default CUDA device (or a specific CUDA device).
+3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``.
+
+.. code-block:: python
+
+    def train_func():
+        device = torch.device(f"cuda:{sgd.local_rank()}" if
+                              torch.cuda.is_available() else "cpu")
+        torch.cuda.set_device(device)
+
+        # Create model.
+        model = NeuralNetwork()
+        model = model.to(device)
+        model = DistributedDataParallel(
+            model,
+            device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)
+
 .. group-tab:: TensorFlow
 .. note::
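A small, hedged sanity check for the GPU setup steps above: each worker can report its ranks and the devices it sees, which helps verify that local ranks line up with CUDA devices. The two-worker count and the CPU fallback are assumptions for illustration:

    import os

    import torch
    import ray.util.sgd.v2 as sgd
    from ray.util.sgd.v2 import Trainer


    def check_devices():
        # Each worker reports where it landed and which devices it can see.
        return {
            "world_rank": sgd.world_rank(),
            "local_rank": sgd.local_rank(),
            "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES", ""),
            "cuda_available": torch.cuda.is_available(),
        }


    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    for info in trainer.run(check_devices):
        print(info)
    trainer.shutdown()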


@@ -73,7 +73,7 @@ py_test(
 py_test(
     name = "test_worker_group",
-    size = "small",
+    size = "medium",
     srcs = ["tests/test_worker_group.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":sgd_v2_lib"]


@@ -3,11 +3,11 @@ from ray.util.sgd.v2.backends import (BackendConfig, HorovodConfig,
 from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 from ray.util.sgd.v2.session import (load_checkpoint, save_checkpoint, report,
-                                     world_rank)
+                                     world_rank, local_rank)
 from ray.util.sgd.v2.trainer import Trainer, SGDIterator
 __all__ = [
     "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint",
-    "report", "save_checkpoint", "SGDCallback", "SGDIterator",
+    "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator",
     "TensorflowConfig", "TorchConfig", "Trainer", "world_rank"
 ]


@@ -15,8 +15,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
     TUNE_CHECKPOINT_ID
 from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
 from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
-from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids, \
-    check_for_failure
+from ray.util.sgd.v2.utils import construct_path, check_for_failure
 from ray.util.sgd.v2.worker_group import WorkerGroup

 if TUNE_INSTALLED:
@@ -309,12 +308,8 @@ class BackendExecutor:
         """
-        def get_node_id_and_gpu():
-            node_id = get_node_id()
-            gpu_ids = get_gpu_ids()
-            return node_id, gpu_ids
-        node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)
+        node_ids_and_gpu_ids = [(w.metadata.node_id, w.metadata.gpu_ids)
+                                for w in self.worker_group.workers]

         node_id_to_worker_id = defaultdict(set)
         node_id_to_gpu_ids = defaultdict(set)
@@ -336,6 +331,37 @@ class BackendExecutor:
                     worker_id, set_gpu_ids))
         ray.get(futures)

+    def _create_local_rank_map(self) -> Dict:
+        """Create mapping from worker world_rank to local_rank.
+
+        Example:
+            Worker 0: 0.0.0.0
+            Worker 1: 0.0.0.0
+            Worker 2: 0.0.0.1
+            Worker 3: 0.0.0.0
+            Worker 4: 0.0.0.1
+
+            Workers 0, 1, 3 are on 0.0.0.0.
+            Workers 2, 4 are on 0.0.0.1.
+
+            Expected Output:
+            {
+                0 -> 0,
+                1 -> 1,
+                2 -> 0,
+                3 -> 2,
+                4 -> 1
+            }
+        """
+        rank_mapping = {}
+        ip_dict = defaultdict(int)
+        for world_rank in range(len(self.worker_group)):
+            worker = self.worker_group.workers[world_rank]
+            node_ip = worker.metadata.node_ip
+            rank_mapping[world_rank] = ip_dict[node_ip]
+            ip_dict[node_ip] += 1
+        return rank_mapping
+
     def start_training(
             self,
             train_func: Callable[[], T],
@@ -371,11 +397,12 @@ class BackendExecutor:
             ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)

         # First initialize the session.
-        def initialize_session(world_rank, train_func, checkpoint):
+        def initialize_session(world_rank, local_rank, train_func, checkpoint):
             try:
                 init_session(
                     training_func=train_func,
                     world_rank=world_rank,
+                    local_rank=local_rank,
                     checkpoint=checkpoint,
                     detailed_autofilled_metrics=use_detailed_autofilled_metrics
                 )
@@ -388,6 +415,8 @@ class BackendExecutor:
         checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)

+        local_rank_map = self._create_local_rank_map()
+
         futures = []
         for world_rank in range(len(self.worker_group)):
             futures.append(
@@ -395,6 +424,7 @@ class BackendExecutor:
                     world_rank,
                     initialize_session,
                     world_rank=world_rank,
+                    local_rank=local_rank_map[world_rank],
                     train_func=train_func,
                     checkpoint=checkpoint_dict))
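The local-rank mapping added above can be illustrated in isolation. The following sketch reproduces the same counting scheme over a plain list of node IPs (the IPs are the made-up examples from the docstring; ``create_local_rank_map`` is a hypothetical standalone helper, not part of the commit):

    from collections import defaultdict

    def create_local_rank_map(worker_node_ips):
        """Assign each worker (indexed by world rank) a local rank on its node,
        counting workers per node IP in world-rank order."""
        rank_mapping = {}
        ip_dict = defaultdict(int)
        for world_rank, node_ip in enumerate(worker_node_ips):
            rank_mapping[world_rank] = ip_dict[node_ip]
            ip_dict[node_ip] += 1
        return rank_mapping

    # Matches the expected output in the docstring above.
    ips = ["0.0.0.0", "0.0.0.0", "0.0.0.1", "0.0.0.0", "0.0.0.1"]
    assert create_local_rank_map(ips) == {0: 0, 1: 1, 2: 0, 3: 2, 4: 1}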


@@ -5,7 +5,7 @@ from typing import Optional, Set
 import ray
 from ray.util.sgd.v2.backends.backend import BackendConfig, Backend
-from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
+from ray.util.sgd.v2.utils import update_env_vars
 from ray.util.sgd.v2.worker_group import WorkerGroup

 try:
@@ -44,9 +44,9 @@ class HorovodConfig(BackendConfig):
         return HorovodBackend

-def init_env_vars(world_rank: int, world_size: int):
+def init_env_vars(world_rank: int, world_size: int, node_id: str):
     """Initialize Horovod environment variables."""
-    os.environ["HOROVOD_HOSTNAME"] = get_node_id()
+    os.environ["HOROVOD_HOSTNAME"] = node_id
     os.environ["HOROVOD_RANK"] = str(world_rank)
     os.environ["HOROVOD_SIZE"] = str(world_size)
@@ -60,9 +60,11 @@ class HorovodBackend(Backend):
         # Initialize workers with Horovod environment variables
         setup_futures = []
         for rank in range(len(worker_group)):
+            worker_node_id = worker_group.workers[rank].metadata.node_id
             setup_futures.append(
                 worker_group.execute_single_async(rank, init_env_vars, rank,
-                                                  len(worker_group)))
+                                                  len(worker_group),
+                                                  worker_node_id))
         ray.get(setup_futures)

         # Use Horovod Ray Coordinator
@@ -70,8 +72,8 @@ class HorovodBackend(Backend):
         self.coordinator = Coordinator(backend_config)

         # Get all the hostnames of all workers
-        node_ids = worker_group.execute(get_node_id)
-        hostnames = worker_group.execute(get_hostname)
+        node_ids = [w.metadata.node_id for w in worker_group.workers]
+        hostnames = [w.metadata.hostname for w in worker_group.workers]

         # Register each hostname to the coordinator. assumes the hostname
         # ordering is the same.
         for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):


@@ -86,6 +86,9 @@ def train_func(config: Dict):
     lr = config["lr"]
     epochs = config["epochs"]

+    device = torch.device(f"cuda:{sgd.local_rank()}"
+                          if torch.cuda.is_available() else "cpu")
+
     # Create data loaders.
     train_dataloader = DataLoader(
         training_data,
@@ -97,10 +100,11 @@ def train_func(config: Dict):
         sampler=DistributedSampler(test_data))

     # Create model.
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = NeuralNetwork()
     model = model.to(device)
-    model = DistributedDataParallel(model)
+    model = DistributedDataParallel(
+        model,
+        device_ids=[device.index] if torch.cuda.is_available() else None)

     loss_fn = nn.CrossEntropyLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)


@@ -32,12 +32,14 @@ class Session:
     def __init__(self,
                  training_func: Callable,
                  world_rank: int,
+                 local_rank: int,
                  checkpoint: Optional[Dict] = None,
                  detailed_autofilled_metrics: bool = False):
         # The Thread object that is running the training function.
         self.training_thread = PropagatingThread(
             target=training_func, daemon=True)
         self.world_rank = world_rank
+        self.local_rank = local_rank
         self.loaded_checkpoint = checkpoint

         # This lock is used to control the execution of the training thread.
@@ -263,6 +265,29 @@ def world_rank() -> int:
     return session.world_rank

+def local_rank() -> int:
+    """Get the local rank of this worker (rank of the worker on its node).
+
+    .. code-block:: python
+
+        import time
+        from ray.util import sgd
+
+        def train_func():
+            if torch.cuda.is_available():
+                torch.cuda.set_device(sgd.local_rank())
+            ...
+
+        trainer = Trainer(backend="torch", use_gpu=True)
+        trainer.start()
+        trainer.run(train_func)
+        trainer.shutdown()
+    """
+    session = get_session()
+    return session.local_rank
+
 def load_checkpoint() -> Optional[Dict]:
     """Loads checkpoint data onto the worker.


@@ -54,9 +54,10 @@ def ray_2_node_4_gpu():
 def gen_execute_special(special_f):
     def execute_async_special(self, f):
         """Runs f on worker 0, special_f on other workers."""
-        futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)]
+        futures = [self.workers[0].actor._BaseWorkerMixin__execute.remote(f)]
         for worker in self.workers[1:]:
-            futures.append(worker._BaseWorkerMixin__execute.remote(special_f))
+            futures.append(
+                worker.actor._BaseWorkerMixin__execute.remote(special_f))
         return futures

     return execute_async_special
@@ -123,6 +124,18 @@ def test_train(ray_start_2_cpus, tmp_path):
     assert e.finish_training() == [1, 1]

+def test_local_ranks(ray_start_2_cpus, tmp_path):
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=2)
+    e.start()
+
+    def train():
+        return sgd.local_rank()
+
+    e.start_training(train, run_dir=tmp_path)
+    assert set(e.finish_training()) == {0, 1}
+
 def test_train_failure(ray_start_2_cpus, tmp_path):
     config = TestConfig()
     e = BackendExecutor(config, num_workers=2)


@@ -2,7 +2,8 @@ import time
 import pytest

 from ray.util.sgd.v2.session import init_session, shutdown_session, \
-    get_session, world_rank, report, save_checkpoint, TrainingResultType, \
+    get_session, world_rank, local_rank, report, save_checkpoint, \
+    TrainingResultType, \
     load_checkpoint
@@ -11,7 +12,7 @@ def session():
     def f():
         return 1

-    init_session(training_func=f, world_rank=0)
+    init_session(training_func=f, world_rank=0, local_rank=0)
     yield get_session()
     shutdown_session()
@@ -34,6 +35,13 @@ def test_world_rank(session):
         world_rank()

+def test_local_rank(session):
+    assert local_rank() == 0
+    shutdown_session()
+    with pytest.raises(ValueError):
+        local_rank()
+
 def test_train(session):
     session.start()
     output = session.finish()
@@ -45,7 +53,7 @@ def test_report():
         for i in range(2):
             report(loss=i)

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next().data["loss"] == 0
@@ -62,7 +70,7 @@ def test_report_fail():
             report(i)
         return 1

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next() is None
@@ -96,7 +104,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data["epoch"] == expected

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     validate_zero(0)
@@ -110,7 +118,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data == {}

-    init_session(training_func=train, world_rank=1)
+    init_session(training_func=train, world_rank=1, local_rank=1)
     session = get_session()
     session.start()
     validate_nonzero()
@@ -129,7 +137,7 @@ def test_load_checkpoint_after_save():
         checkpoint = load_checkpoint()
         assert checkpoint["epoch"] == i

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     for i in range(2):
@@ -145,7 +153,7 @@ def test_locking():
         import _thread
         _thread.interrupt_main()

-    init_session(training_func=train_1, world_rank=0)
+    init_session(training_func=train_1, world_rank=0, local_rank=0)
     session = get_session()
     with pytest.raises(KeyboardInterrupt):
         session.start()
@@ -156,7 +164,7 @@ def test_locking():
             report(loss=i)
         train_1()

-    init_session(training_func=train_2, world_rank=0)
+    init_session(training_func=train_2, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     time.sleep(3)


@@ -88,7 +88,7 @@ def gen_execute_single_async_special(special_f):
         assert len(self.workers) == 2
         if i == 0 and hasattr(self, "should_fail") and self.should_fail:
             kwargs["train_func"] = special_f
-        return self.workers[i]._BaseWorkerMixin__execute.remote(
+        return self.workers[i].actor._BaseWorkerMixin__execute.remote(
             f, *args, **kwargs)

     return execute_single_async_special
@@ -126,7 +126,7 @@ class KillCallback(SGDCallback):
         print(results)
         assert all(r["loss"] == 1 for r in results)
         if self.counter == self.fail_on:
-            ray.kill(self.worker_group.workers[0])
+            ray.kill(self.worker_group.workers[0].actor)
             time.sleep(3)
         self.counter += 1
@@ -752,6 +752,27 @@ def test_worker_failure_2(ray_start_2_cpus):
     assert results == [1, 1]

+def test_worker_failure_local_rank(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        return sgd.local_rank()
+
+    def train_actor_failure():
+        import sys
+        sys.exit(0)
+        return sgd.local_rank()
+
+    new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        results = trainer.run(train)
+        assert set(results) == {0, 1}
+
 def test_worker_start_failure(ray_start_2_cpus):
     test_config = TestConfig()


@@ -480,7 +480,7 @@ class SGDWorkerGroup:
         self._worker_group = worker_group

     def __getitem__(self, item) -> ActorHandle:
-        return self._worker_group.workers[item]
+        return self._worker_group.workers[item].actor

     def shutdown(self, patience_s: float = 5):
         """Shutdown all the workers.


@@ -87,21 +87,6 @@ class PropagatingThread(Thread):
         return self.ret

-def get_node_id() -> str:
-    """Returns the ID of the node that this worker is on."""
-    return ray.get_runtime_context().node_id.hex()
-
-def get_hostname() -> str:
-    """Returns the hostname that this worker is on."""
-    return socket.gethostname()
-
-def get_gpu_ids() -> List[int]:
-    """Return list of CUDA device IDs available to this worker."""
-    return ray.get_gpu_ids()
-
 def update_env_vars(env_vars: Dict[str, Any]):
     """Updates the environment variables on this worker process.


@@ -1,7 +1,10 @@
+import socket
+from dataclasses import dataclass
 import logging
 from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple

 import ray
+from ray.actor import ActorHandle
 from ray.types import ObjectRef

 T = TypeVar("T")
@@ -22,6 +25,32 @@ class BaseWorkerMixin:
         return func(*args, **kwargs)

+@dataclass
+class WorkerMetadata:
+    """Metadata for each worker/actor.
+
+    This information is expected to stay the same throughout the lifetime of
+    the actor.
+
+    Args:
+        node_id (str): ID of the node this worker is on.
+        node_ip (str): IP address of the node this worker is on.
+        hostname (str): Hostname that this worker is on.
+        gpu_ids (List[int]): List of CUDA IDs available to this worker.
+    """
+    node_id: str
+    node_ip: str
+    hostname: str
+    gpu_ids: Optional[List[int]]
+
+@dataclass
+class Worker:
+    """Class representing a Worker."""
+    actor: ActorHandle
+    metadata: WorkerMetadata
+
 def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     """Create the executable class to use as the Ray actors."""
     if not executable_cls:
@@ -37,6 +66,20 @@ def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     return _WrappedExecutable

+def construct_metadata() -> WorkerMetadata:
+    """Creates metadata for this worker.
+
+    This function is expected to be run on the actor.
+    """
+    node_id = ray.get_runtime_context().node_id.hex()
+    node_ip = ray.util.get_node_ip_address()
+    hostname = socket.gethostname()
+    gpu_ids = ray.get_gpu_ids()
+
+    return WorkerMetadata(
+        node_id=node_id, node_ip=node_ip, hostname=hostname, gpu_ids=gpu_ids)
+
 class WorkerGroup:
     """Group of Ray Actors that can execute arbitrary functions.
@@ -118,8 +161,11 @@ class WorkerGroup:
         self.start()

     def _create_worker(self):
-        return self._remote_cls.remote(*self._actor_cls_args,
-                                       **self._actor_cls_kwargs)
+        actor = self._remote_cls.remote(*self._actor_cls_args,
+                                        **self._actor_cls_kwargs)
+        actor_metadata = ray.get(
+            actor._BaseWorkerMixin__execute.remote(construct_metadata))
+        return Worker(actor=actor, metadata=actor_metadata)

     def start(self):
         """Starts all the workers in this worker group."""
@@ -145,9 +191,11 @@ class WorkerGroup:
         logger.debug(f"Shutting down {len(self.workers)} workers.")
         if patience_s <= 0:
             for worker in self.workers:
-                ray.kill(worker)
+                ray.kill(worker.actor)
         else:
-            done_refs = [w.__ray_terminate__.remote() for w in self.workers]
+            done_refs = [
+                w.actor.__ray_terminate__.remote() for w in self.workers
+            ]
             # Wait for actors to die gracefully.
             done, not_done = ray.wait(done_refs, timeout=patience_s)
             if not_done:
@@ -155,7 +203,7 @@ class WorkerGroup:
                 "force kill.")
                 # If all actors are not able to die gracefully, then kill them.
                 for worker in self.workers:
-                    ray.kill(worker)
+                    ray.kill(worker.actor)

         logger.debug("Shutdown successful.")
         self.workers = []
@@ -180,7 +228,7 @@ class WorkerGroup:
                 "create a new WorkerGroup or restart this one.")
         return [
-            w._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
+            w.actor._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
             for w in self.workers
         ]
@@ -214,7 +262,8 @@ class WorkerGroup:
         if worker_index >= len(self.workers):
             raise ValueError(f"The provided worker_index {worker_index} is "
                              f"not valid for {self.num_workers} workers.")
-        return self.workers[worker_index]._BaseWorkerMixin__execute.remote(
+        return self.workers[worker_index].actor._BaseWorkerMixin__execute\
+            .remote(
             func, *args, **kwargs)

     def execute_single(self, worker_index: int, func: Callable[..., T], *args,
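The Worker/WorkerMetadata split in this file means callers reach the Ray actor via ``worker.actor`` and read placement information locally via ``worker.metadata``. A hedged, Ray-free sketch of that pattern; ``workers_per_node`` and the sample metadata values are illustrative, not part of the commit:

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class WorkerMetadata:
        node_id: str
        node_ip: str
        hostname: str
        gpu_ids: Optional[List[int]]

    @dataclass
    class Worker:
        actor: object  # stands in for a Ray ActorHandle
        metadata: WorkerMetadata

    def workers_per_node(workers):
        # Group world ranks by node IP using the cached metadata only;
        # no remote calls are needed once the metadata has been collected.
        grouping = defaultdict(list)
        for world_rank, worker in enumerate(workers):
            grouping[worker.metadata.node_ip].append(world_rank)
        return dict(grouping)

    workers = [
        Worker(actor=None, metadata=WorkerMetadata("a", "0.0.0.0", "host-a", [0])),
        Worker(actor=None, metadata=WorkerMetadata("b", "0.0.0.1", "host-b", [0])),
        Worker(actor=None, metadata=WorkerMetadata("a", "0.0.0.0", "host-a", [1])),
    ]
    assert workers_per_node(workers) == {"0.0.0.0": [0, 2], "0.0.0.1": [1]}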