From 97b3355adc3213ab86e4c42594f279dc10dedf1e Mon Sep 17 00:00:00 2001 From: Wapaul1 Date: Sat, 30 Sep 2017 15:30:27 -0700 Subject: [PATCH] Register Class Only Creates Entry in Redis Once (#1038) Don't export the same custom class definition multiple times. --- python/ray/worker.py | 46 +++++++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/python/ray/worker.py b/python/ray/worker.py index 61b8f1709..c60205899 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -291,8 +291,9 @@ class Worker(object): "be incorrect in some cases." .format(type(e.example_object))) print(warning_message) - except serialization.RayNotDictionarySerializable: - _register_class(type(e.example_object), pickle=True) + except (serialization.RayNotDictionarySerializable, + pickle.pickle.PicklingError): + _register_class(type(e.example_object), use_pickle=True) warning_message = ("WARNING: Falling back to serializing " "objects of type {} by using pickle. " "This may be inefficient." @@ -519,7 +520,7 @@ class Worker(object): # actually run the function locally. pickled_function = pickle.dumps(function) - function_to_run_id = random_string() + function_to_run_id = hashlib.sha1(pickled_function).digest() key = b"FunctionsToRun:" + function_to_run_id # First run the function on the driver. Pass in the number of # workers on this node that have already started executing this @@ -527,13 +528,25 @@ class Worker(object): # counter starts at 0. counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1 + # We always run the task locally. function({"counter": counter, "worker": self}) + # Check if the function has already been put into redis. + function_exported = self.redis_client.setnx(b"Lock:" + key, 1) + if not function_exported: + # In this case, the function has already been exported, so + # we don't need to export it again. + return # Run the function on all workers. self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(), "function_id": function_to_run_id, "function": pickled_function}) self.redis_client.rpush("Exports", key) + # TODO(rkn): If the worker fails after it calls setnx and before it + # successfully completes the hmset and rpush, then the program will + # most likely hang. This could be fixed by making these three + # operations into a transaction (or by implementing a custom + # command that does all three things). def push_error_to_driver(self, driver_id, error_type, message, data=None): """Push an error message to the driver to be printed in the background. @@ -1047,9 +1060,9 @@ def _initialize_serialization(worker=global_worker): _register_class(RayGetError) _register_class(RayGetArgumentError) # Tell Ray to serialize lambdas with pickle. - _register_class(type(lambda: 0), pickle=True) + _register_class(type(lambda: 0), use_pickle=True) # Tell Ray to serialize types with pickle. - _register_class(type(int), pickle=True) + _register_class(type(int), use_pickle=True) def get_address_info_from_redis_helper(redis_address, node_ip_address): @@ -1884,12 +1897,12 @@ def disconnect(worker=global_worker): worker.serialization_context = pyarrow.SerializationContext() -def register_class(cls, pickle=False, worker=global_worker): +def register_class(cls, use_pickle=False, worker=global_worker): raise Exception("The function ray.register_class is deprecated. It should " "be safe to remove any calls to this function.") -def _register_class(cls, pickle=False, worker=global_worker): +def _register_class(cls, use_pickle=False, worker=global_worker): """Enable serialization and deserialization for a particular class. This method runs the register_class function defined below on every worker, @@ -1898,21 +1911,28 @@ def _register_class(cls, pickle=False, worker=global_worker): Args: cls (type): The class that ray should serialize. - pickle (bool): If False then objects of this class will be serialized - by turning their __dict__ fields into a dictionary. If True, then - objects of this class will be serialized using pickle. + use_pickle (bool): If False then objects of this class will be + serialized by turning their __dict__ fields into a dictionary. If + True, then objects of this class will be serialized using pickle. Raises: Exception: An exception is raised if pickle=False and the class cannot be efficiently serialized by Ray. """ - class_id = random_string() + if not use_pickle: + # In this case, the class ID will be used to deduplicate the class + # across workers. + class_id = hashlib.sha1(pickle.dumps(cls)).digest() + else: + # In this case, the class ID only needs to be meaningful on this worker + # and not across workers. + class_id = random_string() def register_class_for_serialization(worker_info): worker_info["worker"].serialization_context.register_type( - cls, class_id, pickle=pickle) + cls, class_id, pickle=use_pickle) - if not pickle: + if not use_pickle: # Raise an exception if cls cannot be serialized efficiently by Ray. serialization.check_serializable(cls) worker.run_function_on_all_workers(register_class_for_serialization)