Only export actor classes once. (#510)

* Only export actor classes once.

* Fix linting.

* Fixes after rebase.
This commit is contained in:
Robert Nishihara 2017-05-09 19:49:23 -07:00 committed by Philipp Moritz
parent 118fac5619
commit b4788ae518
2 changed files with 67 additions and 36 deletions

View file

@ -20,6 +20,10 @@ def random_actor_id():
return ray.local_scheduler.ObjectID(random_string())
def random_actor_class_id():
return random_string()
def get_actor_method_function_id(attr):
"""Get the function ID corresponding to an actor method.
@ -36,15 +40,19 @@ def get_actor_method_function_id(attr):
return ray.local_scheduler.ObjectID(function_id)
def fetch_and_register_actor(key, worker):
"""Import an actor."""
(driver_id, actor_id_str, actor_name,
module, pickled_class, assigned_gpu_ids,
actor_method_names) = worker.redis_client.hmget(
key, ["driver_id", "actor_id", "name", "module", "class", "gpu_ids",
"actor_method_names"])
actor_id = ray.local_scheduler.ObjectID(actor_id_str)
actor_name = actor_name.decode("ascii")
def fetch_and_register_actor(actor_class_key, worker):
"""Import an actor.
This will be called by the worker's import thread when the worker receives
the actor_class export, assuming that the worker is an actor for that class.
"""
actor_id_str = worker.actor_id
(driver_id, class_id, class_name,
module, pickled_class, actor_method_names) = worker.redis_client.hmget(
actor_class_key, ["driver_id", "class_id", "class_name", "module",
"class", "actor_method_names"])
actor_name = class_name.decode("ascii")
module = module.decode("ascii")
actor_method_names = json.loads(actor_method_names.decode("ascii"))
@ -71,7 +79,7 @@ def fetch_and_register_actor(key, worker):
traceback_str = ray.worker.format_error_message(traceback.format_exc())
# Log the error message.
worker.push_error_to_driver(driver_id, "register_actor", traceback_str,
data={"actor_id": actor_id.id()})
data={"actor_id": actor_id_str})
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
@ -192,13 +200,25 @@ def select_local_scheduler(local_schedulers, num_gpus, worker):
return local_scheduler_id
def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
def export_actor_class(class_id, Class, actor_method_names, worker):
if worker.mode is None:
raise NotImplemented("TODO(pcm): Cache actors")
key = b"ActorClass:" + class_id
d = {"driver_id": worker.task_driver_id.id(),
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickling.dumps(Class),
"actor_method_names": json.dumps(list(actor_method_names))}
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
worker):
"""Export an actor to redis.
Args:
actor_id: The ID of the actor.
Class: Name of the class to be exported as an actor.
actor_method_names (list): A list of the names of this actor's methods.
num_cpus (int): The number of CPUs that this actor requires.
num_gpus (int): The number of GPUs that this actor requires.
@ -208,11 +228,12 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
raise Exception("Actors cannot be created before Ray has been started. "
"You can start Ray with 'ray.init()'.")
key = "Actor:{}".format(actor_id.id())
pickled_class = pickling.dumps(Class)
# For now, all actor methods have 1 return value.
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
# TODO(rkn): When we create a second actor, we are probably overwriting
# the values from the first actor here. This may or may not be a problem.
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = (1, num_cpus, 0)
@ -228,19 +249,11 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
worker)
assert local_scheduler_id is not None
d = {"driver_id": driver_id,
"actor_id": actor_id.id(),
"name": Class.__name__,
"module": Class.__module__,
"class": pickled_class,
"num_gpus": num_gpus,
"actor_method_names": json.dumps(list(actor_method_names))}
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
# We publish the actor notification after the call to hmset so that when the
# newly created actor queries Redis to find the number of GPUs assigned to
# it, that value is present.
# We must put the actor information in Redis before publishing the actor
# notification so that when the newly created actor attempts to fetch the
# information from Redis, it is already there.
worker.redis_client.hmset(key, {"class_id": class_id,
"num_gpus": num_gpus})
# Really we should encode this message as a flatbuffer object. However, we're
# having trouble getting that to work. It almost works, but in Python 2.7,
@ -258,6 +271,13 @@ def export_actor(actor_id, Class, actor_method_names, num_cpus, num_gpus,
def actor(*args, **kwargs):
def make_actor_decorator(num_cpus=1, num_gpus=0):
def make_actor(Class):
class_id = random_actor_class_id()
# The list exported will have length 0 if the class has not been exported
# yet, and length one if it has. This is just implementing a bool, but we
# don't use a bool because we need to modify it inside of the NewClass
# constructor.
exported = []
# The function actor_method_call gets called if somebody tries to call a
# method on their local actor stub object.
def actor_method_call(actor_id, attr, function_signature, *args,
@ -267,7 +287,6 @@ def actor(*args, **kwargs):
args = signature.extend_args(function_signature, args, kwargs)
function_id = get_actor_method_function_id(attr)
# TODO(pcm): Extend args with keyword args.
object_ids = ray.worker.global_worker.submit_task(function_id, "",
args,
actor_id=actor_id)
@ -296,7 +315,13 @@ def actor(*args, **kwargs):
self._ray_method_signatures[k] = signature.extract_signature(
v, ignore_first=True)
export_actor(self._ray_actor_id, Class,
# Export the actor class if it has not been exported yet.
if len(exported) == 0:
export_actor_class(class_id, Class, self._ray_actor_methods.keys(),
ray.worker.global_worker)
exported.append(0)
# Export the actor.
export_actor(self._ray_actor_id, class_id,
self._ray_actor_methods.keys(), num_cpus, num_gpus,
ray.worker.global_worker)
# Call __init__ as a remote function.
@ -350,4 +375,4 @@ def actor(*args, **kwargs):
"'ray.actor(num_gpus=1)'.")
ray.worker.global_worker.fetch_and_register["Actor"] = fetch_and_register_actor
ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor

View file

@ -447,7 +447,7 @@ class Worker(object):
self.mode = None
self.cached_remote_functions = []
self.cached_functions_to_run = []
self.fetch_and_register = {}
self.fetch_and_register_actor = None
self.actors = {}
# Use a defaultdict for the actor counts. If this is accessed with a
# missing key, the default value of 0 is returned, and that key value pair
@ -1297,12 +1297,12 @@ def import_thread(worker):
fetch_and_register_environment_variable(key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
fetch_and_execute_function_to_run(key, worker=worker)
elif key.startswith(b"Actor"):
# Only get the actor if the actor ID matches the actor ID of this
# worker.
actor_id, = worker.redis_client.hmget(key, "actor_id")
if worker.actor_id == actor_id:
worker.fetch_and_register["Actor"](key, worker)
elif key.startswith(b"ActorClass"):
# If this worker is an actor that is supposed to construct this class,
# fetch the actor and class information and construct the class.
class_id = key.split(b":", 1)[1]
if worker.actor_id != NIL_ACTOR_ID and worker.class_id == class_id:
worker.fetch_and_register_actor(key, worker)
else:
raise Exception("This code should be unreachable.")
num_imported += 1
@ -1479,6 +1479,12 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
# task.
worker.current_task_id = driver_task.task_id()
# If this is an actor, get the ID of the corresponding class for the actor.
if worker.actor_id != NIL_ACTOR_ID:
actor_key = "Actor:{}".format(worker.actor_id)
class_id = worker.redis_client.hget(actor_key, "class_id")
worker.class_id = class_id
# If this is a worker, then start a thread to import exports from the driver.
if mode == WORKER_MODE:
t = threading.Thread(target=import_thread, args=(worker,))