[Ray SGD] [Hotfix] Worker group hotfix (#11008)

This commit is contained in:
Amog Kamsetty 2020-09-24 12:21:30 -07:00 committed by GitHub
parent 8c241d5f1d
commit 07bdf062b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletion

View file

@@ -96,6 +96,22 @@ def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_apply_all_workers(ray_start_2_cpus, num_workers, use_local):
    """Check that ``TorchTrainer.apply_all_workers`` runs ``fn`` and returns 1s.

    Builds a CPU-only trainer for each ``num_workers`` / ``use_local``
    combination, applies a trivial function through ``apply_all_workers``
    and verifies that every returned value equals 1.
    """

    def fn():
        return 1

    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False)

    try:
        results = trainer.apply_all_workers(fn)
        # ``all`` is vacuously true for an empty iterable, so make sure at
        # least one result came back before checking the values.
        assert results
        assert all(x == 1 for x in results)
    finally:
        # Shut the trainer down even if an assertion fails so its workers do
        # not linger into the next parametrized case (test_train does the
        # same at the end of its body).
        trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_multi_model(ray_start_2_cpus, num_workers, use_local):

View file

@@ -464,7 +464,7 @@ class LocalWorkerGroup(WorkerGroupInterface):
return [local_call] + ray.get(remote_calls)
def apply_all_workers(self, fn):
    """Run ``fn`` on the remote worker group and on the local worker.

    Args:
        fn: Callable forwarded to
            ``self.remote_worker_group._apply_all_workers`` and to
            ``self.local_worker.apply``.

    Returns:
        list: The local worker's result followed by the remote results
            obtained via ``ray.get``.
    """
    # Dispatch through the private ``_apply_all_workers`` only: its return
    # value is what gets passed to ``ray.get`` below (presumably unresolved
    # object refs -- confirm against the remote worker group). The earlier
    # call to the public ``apply_all_workers`` was immediately overwritten
    # and only caused ``fn`` to be dispatched to the remote workers twice.
    remote_calls = self.remote_worker_group._apply_all_workers(fn)
    local_call = self.local_worker.apply(fn)
    return [local_call] + ray.get(remote_calls)