Register Class Only Creates Entry in Redis Once (#1038)

Don't export the same custom class definition multiple times.
This commit is contained in:
Wapaul1 2017-09-30 15:30:27 -07:00 committed by Robert Nishihara
parent 16e82b43d1
commit 97b3355adc

View file

@ -291,8 +291,9 @@ class Worker(object):
"be incorrect in some cases."
.format(type(e.example_object)))
print(warning_message)
except serialization.RayNotDictionarySerializable:
_register_class(type(e.example_object), pickle=True)
except (serialization.RayNotDictionarySerializable,
pickle.pickle.PicklingError):
_register_class(type(e.example_object), use_pickle=True)
warning_message = ("WARNING: Falling back to serializing "
"objects of type {} by using pickle. "
"This may be inefficient."
@ -519,7 +520,7 @@ class Worker(object):
# actually run the function locally.
pickled_function = pickle.dumps(function)
function_to_run_id = random_string()
function_to_run_id = hashlib.sha1(pickled_function).digest()
key = b"FunctionsToRun:" + function_to_run_id
# First run the function on the driver. Pass in the number of
# workers on this node that have already started executing this
@ -527,13 +528,25 @@ class Worker(object):
# counter starts at 0.
counter = self.redis_client.hincrby(self.node_ip_address,
key, 1) - 1
# We always run the task locally.
function({"counter": counter, "worker": self})
# Check if the function has already been put into redis.
function_exported = self.redis_client.setnx(b"Lock:" + key, 1)
if not function_exported:
# In this case, the function has already been exported, so
# we don't need to export it again.
return
# Run the function on all workers.
self.redis_client.hmset(key,
{"driver_id": self.task_driver_id.id(),
"function_id": function_to_run_id,
"function": pickled_function})
self.redis_client.rpush("Exports", key)
# TODO(rkn): If the worker fails after it calls setnx and before it
# successfully completes the hmset and rpush, then the program will
# most likely hang. This could be fixed by making these three
# operations into a transaction (or by implementing a custom
# command that does all three things).
def push_error_to_driver(self, driver_id, error_type, message, data=None):
"""Push an error message to the driver to be printed in the background.
@ -1047,9 +1060,9 @@ def _initialize_serialization(worker=global_worker):
_register_class(RayGetError)
_register_class(RayGetArgumentError)
# Tell Ray to serialize lambdas with pickle.
_register_class(type(lambda: 0), pickle=True)
_register_class(type(lambda: 0), use_pickle=True)
# Tell Ray to serialize types with pickle.
_register_class(type(int), pickle=True)
_register_class(type(int), use_pickle=True)
def get_address_info_from_redis_helper(redis_address, node_ip_address):
@ -1884,12 +1897,12 @@ def disconnect(worker=global_worker):
worker.serialization_context = pyarrow.SerializationContext()
def register_class(cls, pickle=False, worker=global_worker):
def register_class(cls, use_pickle=False, worker=global_worker):
raise Exception("The function ray.register_class is deprecated. It should "
"be safe to remove any calls to this function.")
def _register_class(cls, pickle=False, worker=global_worker):
def _register_class(cls, use_pickle=False, worker=global_worker):
"""Enable serialization and deserialization for a particular class.
This method runs the register_class function defined below on every worker,
@ -1898,21 +1911,28 @@ def _register_class(cls, pickle=False, worker=global_worker):
Args:
cls (type): The class that ray should serialize.
pickle (bool): If False then objects of this class will be serialized
by turning their __dict__ fields into a dictionary. If True, then
objects of this class will be serialized using pickle.
use_pickle (bool): If False then objects of this class will be
serialized by turning their __dict__ fields into a dictionary. If
True, then objects of this class will be serialized using pickle.
Raises:
Exception: An exception is raised if use_pickle=False and the class
cannot be efficiently serialized by Ray.
"""
class_id = random_string()
if not use_pickle:
# In this case, the class ID will be used to deduplicate the class
# across workers.
class_id = hashlib.sha1(pickle.dumps(cls)).digest()
else:
# In this case, the class ID only needs to be meaningful on this worker
# and not across workers.
class_id = random_string()
def register_class_for_serialization(worker_info):
worker_info["worker"].serialization_context.register_type(
cls, class_id, pickle=pickle)
cls, class_id, pickle=use_pickle)
if not pickle:
if not use_pickle:
# Raise an exception if cls cannot be serialized efficiently by Ray.
serialization.check_serializable(cls)
worker.run_function_on_all_workers(register_class_for_serialization)