ray/rllib/utils/actors.py
2020-01-02 17:42:13 -08:00

113 lines
3.3 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import logging
import os

import ray
logger = logging.getLogger(__name__)
class TaskPool:
    """Helper class for tracking the status of many in-flight actor tasks.

    Every tracked task is keyed by one object id. When a list of ids is
    registered, the first id is used as the key and the whole list is handed
    back once that first id is reported ready.
    """

    def __init__(self):
        # Key object id -> worker that was registered with it.
        self._tasks = {}
        # Key object id -> the value originally passed to add() (id or list).
        self._objects = {}
        # FIFO of (worker, obj_ids) pairs that completed() has reported but
        # completed_prefetch() has not yet handed to its caller.
        self._fetching = collections.deque()

    def add(self, worker, all_obj_ids):
        """Start tracking a task launched on ``worker``.

        Args:
            worker: The worker the task was submitted to.
            all_obj_ids: A single object id, or a list of object ids whose
                first element is used to track readiness.
        """
        if isinstance(all_obj_ids, list):
            obj_id = all_obj_ids[0]
        else:
            obj_id = all_obj_ids
        self._tasks[obj_id] = worker
        self._objects[obj_id] = all_obj_ids

    def completed(self, blocking_wait=False):
        """Yield ``(worker, obj_ids)`` for every tracked task that is ready.

        A yielded task is removed from the pool just before it is yielded.

        Args:
            blocking_wait: If True and the non-blocking poll finds nothing,
                wait (with a 10 second timeout) for one task to become ready.
        """
        pending = list(self._tasks)
        if pending:
            # First poll without blocking (timeout=0) for anything ready.
            ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
            if not ready and blocking_wait:
                ready, _ = ray.wait(pending, num_returns=1, timeout=10.0)
            for obj_id in ready:
                yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))

    def completed_prefetch(self, blocking_wait=False, max_yield=999):
        """Similar to completed but buffers results and caps the yield count.

        Newly completed tasks are appended to an internal FIFO; at most
        ``max_yield`` pairs are then yielded from the front of it. Pairs that
        are not yielded stay buffered for a later call.

        Assumes obj_id only is one id.

        Args:
            blocking_wait: Forwarded to completed().
            max_yield: Maximum number of pairs yielded by this call.
        """
        for worker, obj_id in self.completed(blocking_wait=blocking_wait):
            self._fetching.append((worker, obj_id))

        # Remove each pair from the buffer *before* yielding it. If the
        # consumer abandons this generator early, already-yielded pairs are
        # therefore never replayed and unyielded pairs are never lost.
        num_yielded = 0
        while self._fetching and num_yielded < max_yield:
            yield self._fetching.popleft()
            num_yielded += 1

    def reset_workers(self, workers):
        """Notify that some workers may be removed.

        Forgets every tracked or buffered task whose worker is not in
        ``workers``.

        Args:
            workers: Collection of workers that are still valid.
        """
        # Iterate over a snapshot because entries are deleted in the loop.
        for obj_id, ev in list(self._tasks.items()):
            if ev not in workers:
                del self._tasks[obj_id]
                del self._objects[obj_id]
        self._fetching = collections.deque(
            (ev, obj_id) for ev, obj_id in self._fetching if ev in workers)

    @property
    def count(self):
        """Number of tasks still tracked (excludes buffered results)."""
        return len(self._tasks)
def drop_colocated(actors):
    """Terminate the colocated actors and return the remaining ones.

    Args:
        actors: Actor handles to partition with split_colocated().

    Returns:
        The list of actors that split_colocated() reported as not colocated.
    """
    same_host, other_hosts = split_colocated(actors)
    for actor in same_host:
        # Ask every actor that shares this host to shut down.
        actor.__ray_terminate__.remote()
    return other_hosts
def split_colocated(actors):
    """Partition actors by whether they report the same host as this process.

    Each actor's ``get_host`` result is compared against the node name
    returned by ``os.uname()`` for the calling process.

    Args:
        actors: Actor handles exposing a remote ``get_host`` method.

    Returns:
        A ``(local, non_local)`` tuple of lists, each preserving input order.
    """
    this_host = os.uname()[1]
    reported_hosts = ray.get([actor.get_host.remote() for actor in actors])
    colocated, elsewhere = [], []
    for reported, actor in zip(reported_hosts, actors):
        bucket = colocated if reported == this_host else elsewhere
        bucket.append(actor)
    return colocated, elsewhere
def try_create_colocated(cls, args, count):
    """Launch ``count`` actors and keep only those that are colocated.

    Args:
        cls: Actor class; instances are created with ``cls.remote(*args)``.
        args: Positional arguments for the actor constructor.
        count: Number of actors to launch in this attempt.

    Returns:
        The list of launched actors that split_colocated() reported as local.
        All other launched actors are asked to terminate.
    """
    candidates = []
    for _ in range(count):
        candidates.append(cls.remote(*args))
    colocated, elsewhere = split_colocated(candidates)
    logger.info("Got {} colocated actors of {}".format(len(colocated), count))
    for actor in elsewhere:
        actor.__ray_terminate__.remote()
    return colocated
def create_colocated(cls, args, count):
    """Create exactly ``count`` colocated actors, retrying with larger batches.

    Attempt number ``k`` (1 through 9) launches ``count * k`` actors and keeps
    the colocated ones; attempts stop as soon as enough have been collected.

    Args:
        cls: Actor class passed through to try_create_colocated().
        args: Positional constructor arguments for each actor.
        count: Number of colocated actors required.

    Returns:
        A list of ``count`` colocated actors. Surplus colocated actors are
        asked to terminate.

    Raises:
        Exception: If fewer than ``count`` colocated actors were obtained
            after all attempts.
    """
    logger.info("Trying to create {} colocated actors".format(count))
    collected = []
    for multiplier in range(1, 10):
        if len(collected) >= count:
            break
        collected.extend(try_create_colocated(cls, args, count * multiplier))
    if len(collected) < count:
        raise Exception("Unable to create enough colocated actors, abort.")
    kept, surplus = collected[:count], collected[count:]
    for actor in surplus:
        actor.__ray_terminate__.remote()
    return kept