mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
[Ray SGD] [Hotfix] Worker group hotfix (#11008)
This commit is contained in:
parent
8c241d5f1d
commit
07bdf062b9
2 changed files with 17 additions and 1 deletions
|
@@ -96,6 +96,22 @@ def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811
|
|||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_apply_all_workers(ray_start_2_cpus, num_workers, use_local):
    """Check that ``TorchTrainer.apply_all_workers`` returns each worker's result.

    A trivial function returning ``1`` is applied through the trainer, and
    every collected result must equal ``1``. The test is parametrized over
    the number of workers and over ``use_local``.
    """

    def fn():
        return 1

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False)

    try:
        results = trainer.apply_all_workers(fn)
        assert all(x == 1 for x in results)
    finally:
        # Always release the trainer's workers, even when the assertion
        # fails, so later parametrizations start from a clean state.
        trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_multi_model(ray_start_2_cpus, num_workers, use_local):
|
||||
|
|
|
@@ -464,7 +464,7 @@ class LocalWorkerGroup(WorkerGroupInterface):
|
|||
return [local_call] + ray.get(remote_calls)
|
||||
|
||||
def apply_all_workers(self, fn):
    """Apply ``fn`` on the local worker and on the remote worker group.

    Args:
        fn: Callable forwarded to ``self.local_worker.apply`` and to the
            remote worker group's ``_apply_all_workers``.

    Returns:
        list: The local worker's result first, followed by the remote
        results after they are resolved with ``ray.get``.
    """
    # Call the underscore-prefixed variant only: its return value is passed
    # to ``ray.get`` below, so it presumably yields unresolved object refs
    # (NOTE(review): confirm against RemoteWorkerGroup). The public
    # ``apply_all_workers`` call that used to precede this line was
    # redundant and its result was discarded.
    remote_calls = self.remote_worker_group._apply_all_workers(fn)
    local_call = self.local_worker.apply(fn)
    return [local_call] + ray.get(remote_calls)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue