[SGD] Make actor creation async (#19325)

* fix

* fix

* fix
This commit is contained in:
Amog Kamsetty 2021-10-12 16:15:59 -07:00 committed by GitHub
parent d99b095eac
commit 09d8049584
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -160,13 +160,6 @@ class WorkerGroup:
resources=self.additional_resources_per_worker)(self._base_cls)
self.start()
def _create_worker(self):
    """Create a single remote actor and wrap it in a ``Worker``.

    The actor is instantiated from ``self._remote_cls`` with the stored
    constructor args/kwargs. ``construct_metadata`` is then submitted to
    the actor through its ``__execute`` method, and this call waits on
    ``ray.get`` for that result before returning.

    Returns:
        Worker: The new actor handle paired with the fetched metadata.
    """
    handle = self._remote_cls.remote(*self._actor_cls_args,
                                     **self._actor_cls_kwargs)
    # ``_BaseWorkerMixin__execute`` is the name-mangled form of the
    # private ``__execute`` attribute, spelled out for access from here.
    metadata_ref = handle._BaseWorkerMixin__execute.remote(
        construct_metadata)
    return Worker(actor=handle, metadata=ray.get(metadata_ref))
def start(self):
"""Starts all the workers in this worker group."""
if self.workers and len(self.workers) > 0:
@ -301,8 +294,21 @@ class WorkerGroup:
Args:
num_workers (int): The number of workers to add.
"""
new_actors = []
new_actor_metadata = []
for _ in range(num_workers):
self.workers.append(self._create_worker())
actor = self._remote_cls.remote(*self._actor_cls_args,
**self._actor_cls_kwargs)
new_actors.append(actor)
new_actor_metadata.append(
actor._BaseWorkerMixin__execute.remote(construct_metadata))
# Get metadata from all actors.
metadata = ray.get(new_actor_metadata)
for i in range(len(new_actors)):
self.workers.append(
Worker(actor=new_actors[i], metadata=metadata[i]))
def __len__(self):
    """Return the number of entries currently held in ``self.workers``."""
    worker_count = len(self.workers)
    return worker_count