Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[SGD] Retry sgd.local_rank() (#18824)
* finish
* fix
* wip
* address comment
* update
* fix test
* fix failing test
* address comments
* fix test
* fix
This commit is contained in:
parent
5e9cb232c7
commit
f71cfca439
14 changed files with 224 additions and 56 deletions
@@ -153,6 +153,10 @@ system. Let's take following simple examples:
             results = trainer.run(train_func_distributed)
             trainer.shutdown()
+
+        See :ref:`sgd-porting-code` for a more comprehensive example.
+
     .. group-tab:: TensorFlow

         This example shows how you can use RaySGD to set up `Multi-worker training
@@ -250,4 +254,7 @@ system. Let's take following simple examples:
             trainer.shutdown()
+
+        See :ref:`sgd-porting-code` for a more comprehensive example.
+
 **Next steps:** Check out the :ref:`User Guide <sgd-user-guide>`!

@@ -83,6 +83,30 @@ training.
                     sampler=DistributedSampler(dataset))

+    **Step 3:** Set the proper CUDA device if you are using GPUs.
+
+    If you are using GPUs, you need to make sure the CUDA devices are properly set up inside your training function.
+
+    This involves 3 steps:
+
+    1. Use the local rank to set the default CUDA device for the worker.
+    2. Move the model to the default CUDA device (or a specific CUDA device).
+    3. Specify ``device_ids`` when wrapping in ``DistributedDataParallel``.
+
+    .. code-block:: python
+
+        def train_func():
+            device = torch.device(f"cuda:{sgd.local_rank()}" if
+                                  torch.cuda.is_available() else "cpu")
+            torch.cuda.set_device(device)
+
+            # Create model.
+            model = NeuralNetwork()
+            model = model.to(device)
+            model = DistributedDataParallel(
+                model,
+                device_ids=[sgd.local_rank()] if torch.cuda.is_available() else None)
+
 .. group-tab:: TensorFlow

     .. note::

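For context, the documented pattern above plugs into the rest of the RaySGD v2 API roughly as follows. This is an illustrative sketch only, not part of the patch: the Linear model is a placeholder, and the imports assume the public ray.util.sgd.v2 entry points (local_rank and Trainer) that this diff exports.

    import torch
    from torch.nn.parallel import DistributedDataParallel
    from ray.util.sgd import v2 as sgd_v2

    def train_func():
        # Pin this worker to the GPU matching its local rank, then wrap the model.
        device = torch.device(f"cuda:{sgd_v2.local_rank()}"
                              if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.set_device(device)
        model = torch.nn.Linear(4, 1).to(device)  # placeholder model
        model = DistributedDataParallel(
            model,
            device_ids=[sgd_v2.local_rank()] if torch.cuda.is_available() else None)
        # ... training loop omitted ...

    trainer = sgd_v2.Trainer(backend="torch", use_gpu=True)
    trainer.start()
    trainer.run(train_func)
    trainer.shutdown()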
@@ -73,7 +73,7 @@ py_test(

 py_test(
     name = "test_worker_group",
-    size = "small",
+    size = "medium",
     srcs = ["tests/test_worker_group.py"],
     tags = ["team:ml", "exclusive"],
     deps = [":sgd_v2_lib"]

@@ -3,11 +3,11 @@ from ray.util.sgd.v2.backends import (BackendConfig, HorovodConfig,
 from ray.util.sgd.v2.callbacks import SGDCallback
 from ray.util.sgd.v2.checkpoint import CheckpointStrategy
 from ray.util.sgd.v2.session import (load_checkpoint, save_checkpoint, report,
-                                     world_rank)
+                                     world_rank, local_rank)
 from ray.util.sgd.v2.trainer import Trainer, SGDIterator

 __all__ = [
     "BackendConfig", "CheckpointStrategy", "HorovodConfig", "load_checkpoint",
-    "report", "save_checkpoint", "SGDCallback", "SGDIterator",
+    "local_rank", "report", "save_checkpoint", "SGDCallback", "SGDIterator",
     "TensorflowConfig", "TorchConfig", "Trainer", "world_rank"
 ]

@@ -15,8 +15,7 @@ from ray.util.sgd.v2.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
     TUNE_CHECKPOINT_ID
 from ray.util.sgd.v2.session import TrainingResultType, TrainingResult
 from ray.util.sgd.v2.session import init_session, get_session, shutdown_session
-from ray.util.sgd.v2.utils import construct_path, get_node_id, get_gpu_ids, \
-    check_for_failure
+from ray.util.sgd.v2.utils import construct_path, check_for_failure
 from ray.util.sgd.v2.worker_group import WorkerGroup

 if TUNE_INSTALLED:
@@ -309,12 +308,8 @@ class BackendExecutor:

         """

-        def get_node_id_and_gpu():
-            node_id = get_node_id()
-            gpu_ids = get_gpu_ids()
-            return node_id, gpu_ids
-
-        node_ids_and_gpu_ids = self.worker_group.execute(get_node_id_and_gpu)
+        node_ids_and_gpu_ids = [(w.metadata.node_id, w.metadata.gpu_ids)
+                                for w in self.worker_group.workers]

         node_id_to_worker_id = defaultdict(set)
         node_id_to_gpu_ids = defaultdict(set)
@@ -336,6 +331,37 @@ class BackendExecutor:
                     worker_id, set_gpu_ids))
         ray.get(futures)

+    def _create_local_rank_map(self) -> Dict:
+        """Create mapping from worker world_rank to local_rank.
+
+        Example:
+            Worker 0: 0.0.0.0
+            Worker 1: 0.0.0.0
+            Worker 2: 0.0.0.1
+            Worker 3: 0.0.0.0
+            Worker 4: 0.0.0.1
+
+            Workers 0, 1, 3 are on 0.0.0.0.
+            Workers 2, 4 are on 0.0.0.1.
+
+            Expected Output:
+            {
+                0 -> 0,
+                1 -> 1,
+                2 -> 0,
+                3 -> 2,
+                4 -> 1
+            }
+        """
+        rank_mapping = {}
+        ip_dict = defaultdict(int)
+        for world_rank in range(len(self.worker_group)):
+            worker = self.worker_group.workers[world_rank]
+            node_ip = worker.metadata.node_ip
+            rank_mapping[world_rank] = ip_dict[node_ip]
+            ip_dict[node_ip] += 1
+        return rank_mapping
+
     def start_training(
             self,
             train_func: Callable[[], T],
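The mapping logic above can be exercised on its own. Here is a minimal standalone sketch (assuming a plain list of per-worker node IPs instead of the WorkerGroup objects used in the real method) that reproduces the docstring example:

    from collections import defaultdict

    def local_rank_map(node_ips):
        # node_ips[i] is the IP of the node hosting the worker with world rank i.
        rank_mapping = {}
        next_local_rank = defaultdict(int)  # next local rank to hand out, per node
        for world_rank, node_ip in enumerate(node_ips):
            rank_mapping[world_rank] = next_local_rank[node_ip]
            next_local_rank[node_ip] += 1
        return rank_mapping

    # Workers 0, 1, 3 on 0.0.0.0; workers 2, 4 on 0.0.0.1 (as in the docstring).
    assert local_rank_map(
        ["0.0.0.0", "0.0.0.0", "0.0.0.1", "0.0.0.0", "0.0.0.1"]) == {
            0: 0, 1: 1, 2: 0, 3: 2, 4: 1}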
@@ -371,11 +397,12 @@ class BackendExecutor:
             ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)

         # First initialize the session.
-        def initialize_session(world_rank, train_func, checkpoint):
+        def initialize_session(world_rank, local_rank, train_func, checkpoint):
             try:
                 init_session(
                     training_func=train_func,
                     world_rank=world_rank,
+                    local_rank=local_rank,
                     checkpoint=checkpoint,
                     detailed_autofilled_metrics=use_detailed_autofilled_metrics
                 )
@@ -388,6 +415,8 @@ class BackendExecutor:

         checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)

+        local_rank_map = self._create_local_rank_map()
+
         futures = []
         for world_rank in range(len(self.worker_group)):
             futures.append(
@@ -395,6 +424,7 @@ class BackendExecutor:
                     world_rank,
                     initialize_session,
                     world_rank=world_rank,
+                    local_rank=local_rank_map[world_rank],
                     train_func=train_func,
                     checkpoint=checkpoint_dict))

@@ -5,7 +5,7 @@ from typing import Optional, Set

 import ray
 from ray.util.sgd.v2.backends.backend import BackendConfig, Backend
-from ray.util.sgd.v2.utils import get_node_id, get_hostname, update_env_vars
+from ray.util.sgd.v2.utils import update_env_vars
 from ray.util.sgd.v2.worker_group import WorkerGroup

 try:
@@ -44,9 +44,9 @@ class HorovodConfig(BackendConfig):
         return HorovodBackend


-def init_env_vars(world_rank: int, world_size: int):
+def init_env_vars(world_rank: int, world_size: int, node_id: str):
     """Initialize Horovod environment variables."""
-    os.environ["HOROVOD_HOSTNAME"] = get_node_id()
+    os.environ["HOROVOD_HOSTNAME"] = node_id
     os.environ["HOROVOD_RANK"] = str(world_rank)
     os.environ["HOROVOD_SIZE"] = str(world_size)

@@ -60,9 +60,11 @@ class HorovodBackend(Backend):
         # Initialize workers with Horovod environment variables
         setup_futures = []
         for rank in range(len(worker_group)):
+            worker_node_id = worker_group.workers[rank].metadata.node_id
             setup_futures.append(
                 worker_group.execute_single_async(rank, init_env_vars, rank,
-                                                  len(worker_group)))
+                                                  len(worker_group),
+                                                  worker_node_id))
         ray.get(setup_futures)

         # Use Horovod Ray Coordinator
@@ -70,8 +72,8 @@ class HorovodBackend(Backend):
         self.coordinator = Coordinator(backend_config)

         # Get all the hostnames of all workers
-        node_ids = worker_group.execute(get_node_id)
-        hostnames = worker_group.execute(get_hostname)
+        node_ids = [w.metadata.node_id for w in worker_group.workers]
+        hostnames = [w.metadata.hostname for w in worker_group.workers]
         # Register each hostname to the coordinator. assumes the hostname
         # ordering is the same.
         for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)):
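Concretely, the per-worker environment setup above reduces to the following standalone sketch (the rank, size, and node id values here are made up for illustration):

    import os

    def init_env_vars(world_rank: int, world_size: int, node_id: str):
        """Set the Horovod rendezvous variables for this worker process."""
        os.environ["HOROVOD_HOSTNAME"] = node_id
        os.environ["HOROVOD_RANK"] = str(world_rank)
        os.environ["HOROVOD_SIZE"] = str(world_size)

    # e.g. the worker with world rank 3 of 4, on the node whose Ray node id
    # (hex string) doubles as the Horovod "hostname":
    init_env_vars(world_rank=3, world_size=4, node_id="f3a9c2d4e5b6")
    assert os.environ["HOROVOD_RANK"] == "3"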
@@ -86,6 +86,9 @@ def train_func(config: Dict):
     lr = config["lr"]
     epochs = config["epochs"]

+    device = torch.device(f"cuda:{sgd.local_rank()}"
+                          if torch.cuda.is_available() else "cpu")
+
     # Create data loaders.
     train_dataloader = DataLoader(
         training_data,
@@ -97,10 +100,11 @@ def train_func(config: Dict):
         sampler=DistributedSampler(test_data))

     # Create model.
-    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = NeuralNetwork()
     model = model.to(device)
-    model = DistributedDataParallel(model)
+    model = DistributedDataParallel(
+        model,
+        device_ids=[device.index] if torch.cuda.is_available() else None)

     loss_fn = nn.CrossEntropyLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
@@ -32,12 +32,14 @@ class Session:
     def __init__(self,
                  training_func: Callable,
                  world_rank: int,
+                 local_rank: int,
                  checkpoint: Optional[Dict] = None,
                  detailed_autofilled_metrics: bool = False):
         # The Thread object that is running the training function.
         self.training_thread = PropagatingThread(
             target=training_func, daemon=True)
         self.world_rank = world_rank
+        self.local_rank = local_rank
         self.loaded_checkpoint = checkpoint

         # This lock is used to control the execution of the training thread.
@@ -263,6 +265,29 @@ def world_rank() -> int:
     return session.world_rank


+def local_rank() -> int:
+    """Get the local rank of this worker (rank of the worker on its node).
+
+    .. code-block:: python
+
+        import time
+        from ray.util import sgd
+
+        def train_func():
+            if torch.cuda.is_available():
+                torch.cuda.set_device(sgd.local_rank())
+            ...
+
+        trainer = Trainer(backend="torch", use_gpu=True)
+        trainer.start()
+        trainer.run(train_func)
+        trainer.shutdown()
+
+    """
+    session = get_session()
+    return session.local_rank
+
+
 def load_checkpoint() -> Optional[Dict]:
     """Loads checkpoint data onto the worker.

@@ -54,9 +54,10 @@ def ray_2_node_4_gpu():
 def gen_execute_special(special_f):
     def execute_async_special(self, f):
         """Runs f on worker 0, special_f on other workers."""
-        futures = [self.workers[0]._BaseWorkerMixin__execute.remote(f)]
+        futures = [self.workers[0].actor._BaseWorkerMixin__execute.remote(f)]
         for worker in self.workers[1:]:
-            futures.append(worker._BaseWorkerMixin__execute.remote(special_f))
+            futures.append(
+                worker.actor._BaseWorkerMixin__execute.remote(special_f))
         return futures

     return execute_async_special
@@ -123,6 +124,18 @@ def test_train(ray_start_2_cpus, tmp_path):
     assert e.finish_training() == [1, 1]


+def test_local_ranks(ray_start_2_cpus, tmp_path):
+    config = TestConfig()
+    e = BackendExecutor(config, num_workers=2)
+    e.start()
+
+    def train():
+        return sgd.local_rank()
+
+    e.start_training(train, run_dir=tmp_path)
+    assert set(e.finish_training()) == {0, 1}
+
+
 def test_train_failure(ray_start_2_cpus, tmp_path):
     config = TestConfig()
     e = BackendExecutor(config, num_workers=2)
@@ -2,7 +2,8 @@ import time

 import pytest
 from ray.util.sgd.v2.session import init_session, shutdown_session, \
-    get_session, world_rank, report, save_checkpoint, TrainingResultType, \
+    get_session, world_rank, local_rank, report, save_checkpoint, \
+    TrainingResultType, \
     load_checkpoint


@@ -11,7 +12,7 @@ def session():
     def f():
         return 1

-    init_session(training_func=f, world_rank=0)
+    init_session(training_func=f, world_rank=0, local_rank=0)
     yield get_session()
     shutdown_session()

@@ -34,6 +35,13 @@ def test_world_rank(session):
         world_rank()


+def test_local_rank(session):
+    assert local_rank() == 0
+    shutdown_session()
+    with pytest.raises(ValueError):
+        local_rank()
+
+
 def test_train(session):
     session.start()
     output = session.finish()
@@ -45,7 +53,7 @@ def test_report():
         for i in range(2):
             report(loss=i)

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next().data["loss"] == 0
@@ -62,7 +70,7 @@ def test_report_fail():
             report(i)
         return 1

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     assert session.get_next() is None
@@ -96,7 +104,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data["epoch"] == expected

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     validate_zero(0)
@@ -110,7 +118,7 @@ def test_checkpoint():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data == {}

-    init_session(training_func=train, world_rank=1)
+    init_session(training_func=train, world_rank=1, local_rank=1)
     session = get_session()
     session.start()
     validate_nonzero()
@@ -129,7 +137,7 @@ def test_load_checkpoint_after_save():
         checkpoint = load_checkpoint()
         assert checkpoint["epoch"] == i

-    init_session(training_func=train, world_rank=0)
+    init_session(training_func=train, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     for i in range(2):
@@ -145,7 +153,7 @@ def test_locking():
         import _thread
         _thread.interrupt_main()

-    init_session(training_func=train_1, world_rank=0)
+    init_session(training_func=train_1, world_rank=0, local_rank=0)
     session = get_session()
     with pytest.raises(KeyboardInterrupt):
         session.start()
@@ -156,7 +164,7 @@ def test_locking():
             report(loss=i)
         train_1()

-    init_session(training_func=train_2, world_rank=0)
+    init_session(training_func=train_2, world_rank=0, local_rank=0)
     session = get_session()
     session.start()
     time.sleep(3)
@@ -88,7 +88,7 @@ def gen_execute_single_async_special(special_f):
         assert len(self.workers) == 2
         if i == 0 and hasattr(self, "should_fail") and self.should_fail:
             kwargs["train_func"] = special_f
-        return self.workers[i]._BaseWorkerMixin__execute.remote(
+        return self.workers[i].actor._BaseWorkerMixin__execute.remote(
             f, *args, **kwargs)

     return execute_single_async_special
@@ -126,7 +126,7 @@ class KillCallback(SGDCallback):
         print(results)
         assert all(r["loss"] == 1 for r in results)
         if self.counter == self.fail_on:
-            ray.kill(self.worker_group.workers[0])
+            ray.kill(self.worker_group.workers[0].actor)
             time.sleep(3)
         self.counter += 1

@@ -752,6 +752,27 @@ def test_worker_failure_2(ray_start_2_cpus):
     assert results == [1, 1]


+def test_worker_failure_local_rank(ray_start_2_cpus):
+    test_config = TestConfig()
+
+    def train():
+        return sgd.local_rank()
+
+    def train_actor_failure():
+        import sys
+        sys.exit(0)
+        return sgd.local_rank()
+
+    new_backend_executor_cls = gen_new_backend_executor(train_actor_failure)
+
+    with patch.object(ray.util.sgd.v2.trainer, "BackendExecutor",
+                      new_backend_executor_cls):
+        trainer = Trainer(test_config, num_workers=2)
+        trainer.start()
+        results = trainer.run(train)
+        assert set(results) == {0, 1}
+
+
 def test_worker_start_failure(ray_start_2_cpus):
     test_config = TestConfig()

@@ -480,7 +480,7 @@ class SGDWorkerGroup:
         self._worker_group = worker_group

     def __getitem__(self, item) -> ActorHandle:
-        return self._worker_group.workers[item]
+        return self._worker_group.workers[item].actor

     def shutdown(self, patience_s: float = 5):
         """Shutdown all the workers.
@@ -87,21 +87,6 @@ class PropagatingThread(Thread):
         return self.ret


-def get_node_id() -> str:
-    """Returns the ID of the node that this worker is on."""
-    return ray.get_runtime_context().node_id.hex()
-
-
-def get_hostname() -> str:
-    """Returns the hostname that this worker is on."""
-    return socket.gethostname()
-
-
-def get_gpu_ids() -> List[int]:
-    """Return list of CUDA device IDs available to this worker."""
-    return ray.get_gpu_ids()
-
-
 def update_env_vars(env_vars: Dict[str, Any]):
     """Updates the environment variables on this worker process.

@@ -1,7 +1,10 @@
+import socket
+from dataclasses import dataclass
 import logging
 from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple

 import ray
+from ray.actor import ActorHandle
 from ray.types import ObjectRef

 T = TypeVar("T")
@@ -22,6 +25,32 @@ class BaseWorkerMixin:
         return func(*args, **kwargs)


+@dataclass
+class WorkerMetadata:
+    """Metadata for each worker/actor.
+
+    This information is expected to stay the same throughout the lifetime of
+    actor.
+
+    Args:
+        node_id (str): ID of the node this worker is on.
+        node_ip (str): IP address of the node this worker is on.
+        hostname (str): Hostname that this worker is on.
+        gpu_ids (List[int]): List of CUDA IDs available to this worker.
+    """
+    node_id: str
+    node_ip: str
+    hostname: str
+    gpu_ids: Optional[List[int]]
+
+
+@dataclass
+class Worker:
+    """Class representing a Worker."""
+    actor: ActorHandle
+    metadata: WorkerMetadata
+
+
 def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     """Create the executable class to use as the Ray actors."""
     if not executable_cls:
@@ -37,6 +66,20 @@ def create_executable_class(executable_cls: Optional[Type] = None) -> Type:
     return _WrappedExecutable


+def construct_metadata() -> WorkerMetadata:
+    """Creates metadata for this worker.
+
+    This function is expected to be run on the actor.
+    """
+    node_id = ray.get_runtime_context().node_id.hex()
+    node_ip = ray.util.get_node_ip_address()
+    hostname = socket.gethostname()
+    gpu_ids = ray.get_gpu_ids()
+
+    return WorkerMetadata(
+        node_id=node_id, node_ip=node_ip, hostname=hostname, gpu_ids=gpu_ids)
+
+
 class WorkerGroup:
     """Group of Ray Actors that can execute arbitrary functions.

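To make the new data model concrete, here is a small self-contained sketch (plain Python, no Ray actors; the sample node ids, IPs, and hostnames are invented) showing how per-worker metadata can be grouped by node, which is what the local-rank mapping and the Horovod hostname registration above rely on:

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class WorkerMetadata:
        node_id: str
        node_ip: str
        hostname: str
        gpu_ids: Optional[List[int]]

    @dataclass
    class Worker:
        actor: object  # an ActorHandle in the real code; None here
        metadata: WorkerMetadata

    workers = [
        Worker(None, WorkerMetadata("a" * 8, "0.0.0.0", "host-a", [0])),
        Worker(None, WorkerMetadata("a" * 8, "0.0.0.0", "host-a", [1])),
        Worker(None, WorkerMetadata("b" * 8, "0.0.0.1", "host-b", [0])),
    ]

    # Group world ranks by node IP, as the backend does when assigning local ranks.
    by_node = defaultdict(list)
    for world_rank, w in enumerate(workers):
        by_node[w.metadata.node_ip].append(world_rank)

    print(dict(by_node))  # {'0.0.0.0': [0, 1], '0.0.0.1': [2]}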
@@ -118,8 +161,11 @@ class WorkerGroup:
             self.start()

     def _create_worker(self):
-        return self._remote_cls.remote(*self._actor_cls_args,
-                                       **self._actor_cls_kwargs)
+        actor = self._remote_cls.remote(*self._actor_cls_args,
+                                        **self._actor_cls_kwargs)
+        actor_metadata = ray.get(
+            actor._BaseWorkerMixin__execute.remote(construct_metadata))
+        return Worker(actor=actor, metadata=actor_metadata)

     def start(self):
         """Starts all the workers in this worker group."""
@@ -145,9 +191,11 @@ class WorkerGroup:
         logger.debug(f"Shutting down {len(self.workers)} workers.")
         if patience_s <= 0:
             for worker in self.workers:
-                ray.kill(worker)
+                ray.kill(worker.actor)
         else:
-            done_refs = [w.__ray_terminate__.remote() for w in self.workers]
+            done_refs = [
+                w.actor.__ray_terminate__.remote() for w in self.workers
+            ]
             # Wait for actors to die gracefully.
             done, not_done = ray.wait(done_refs, timeout=patience_s)
             if not_done:
@@ -155,7 +203,7 @@ class WorkerGroup:
                 "force kill.")
             # If all actors are not able to die gracefully, then kill them.
             for worker in self.workers:
-                ray.kill(worker)
+                ray.kill(worker.actor)

         logger.debug("Shutdown successful.")
         self.workers = []
@@ -180,7 +228,7 @@ class WorkerGroup:
                 "create a new WorkerGroup or restart this one.")

         return [
-            w._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
+            w.actor._BaseWorkerMixin__execute.remote(func, *args, **kwargs)
             for w in self.workers
         ]

@@ -214,7 +262,8 @@ class WorkerGroup:
         if worker_index >= len(self.workers):
             raise ValueError(f"The provided worker_index {worker_index} is "
                              f"not valid for {self.num_workers} workers.")
-        return self.workers[worker_index]._BaseWorkerMixin__execute.remote(
+        return self.workers[worker_index].actor._BaseWorkerMixin__execute\
+            .remote(
             func, *args, **kwargs)

     def execute_single(self, worker_index: int, func: Callable[..., T], *args,