mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
Register Class Only Creates Entry in Redis Once (#1038)
Don't export the same custom class definition multiple times.
This commit is contained in:
parent 16e82b43d1
commit 97b3355adc
1 changed file with 33 additions and 13 deletions
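The change below boils down to one pattern: derive the Redis key for an exported definition from a content hash of its pickled bytes, then take a SETNX lock on that key so only the first caller actually writes and publishes the export. A minimal standalone sketch of that pattern follows, assuming the redis-py client; export_once, the Definitions: prefix, and the definition field are illustrative names, not identifiers from this repository.

import hashlib
import pickle

import redis  # assumes the redis-py package is installed


def export_once(redis_client, definition, export_list=b"Exports"):
    """Export a pickled definition at most once, however many callers race."""
    payload = pickle.dumps(definition)
    # A content hash makes every process that sees the same definition compute
    # the same key, which is what lets duplicates be detected at all.
    definition_id = hashlib.sha1(payload).digest()
    key = b"Definitions:" + definition_id
    # SETNX succeeds only for the first caller; later callers skip the export.
    if not redis_client.setnx(b"Lock:" + key, 1):
        return False
    redis_client.hset(key, b"definition", payload)
    redis_client.rpush(export_list, key)
    return True


class Example(object):  # a stand-in class to export
    pass


# Usage (requires a local Redis server):
#   client = redis.StrictRedis(host="localhost", port=6379)
#   export_once(client, Example)  # -> True, performs the export
#   export_once(client, Example)  # -> False, already exported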
@@ -291,8 +291,9 @@ class Worker(object):
                                        "be incorrect in some cases."
                                        .format(type(e.example_object)))
                     print(warning_message)
-                except serialization.RayNotDictionarySerializable:
-                    _register_class(type(e.example_object), pickle=True)
+                except (serialization.RayNotDictionarySerializable,
+                        pickle.pickle.PicklingError):
+                    _register_class(type(e.example_object), use_pickle=True)
                     warning_message = ("WARNING: Falling back to serializing "
                                        "objects of type {} by using pickle. "
                                        "This may be inefficient."
@@ -519,7 +520,7 @@ class Worker(object):
            # actually run the function locally.
            pickled_function = pickle.dumps(function)

-            function_to_run_id = random_string()
+            function_to_run_id = hashlib.sha1(pickled_function).digest()
            key = b"FunctionsToRun:" + function_to_run_id
            # First run the function on the driver. Pass in the number of
            # workers on this node that have already started executing this
@@ -527,13 +528,25 @@ class Worker(object):
            # counter starts at 0.
            counter = self.redis_client.hincrby(self.node_ip_address,
                                                key, 1) - 1
            # We always run the task locally.
            function({"counter": counter, "worker": self})
+            # Check if the function has already been put into redis.
+            function_exported = self.redis_client.setnx(b"Lock:" + key, 1)
+            if not function_exported:
+                # In this case, the function has already been exported, so
+                # we don't need to export it again.
+                return
            # Run the function on all workers.
            self.redis_client.hmset(key,
                                    {"driver_id": self.task_driver_id.id(),
                                     "function_id": function_to_run_id,
                                     "function": pickled_function})
            self.redis_client.rpush("Exports", key)
+            # TODO(rkn): If the worker fails after it calls setnx and before it
+            # successfully completes the hmset and rpush, then the program will
+            # most likely hang. This could be fixed by making these three
+            # operations into a transaction (or by implementing a custom
+            # command that does all three things).

    def push_error_to_driver(self, driver_id, error_type, message, data=None):
        """Push an error message to the driver to be printed in the background.
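The TODO in this hunk notes that a worker crash between setnx and the hmset/rpush leaves the lock held with nothing exported, and suggests a transaction or a custom command. One possible shape for that custom command, sketched with redis-py's register_script and a small Lua body; this is not part of the commit, and export_function_once is a hypothetical helper whose key layout only loosely mirrors the diff.

import redis  # assumes the redis-py package

# SETNX, HMSET, and RPUSH run inside one Lua call, so Redis applies them
# atomically: either the lock is taken and the export is fully published,
# or nothing happens at all.
EXPORT_ONCE_LUA = """
if redis.call('SETNX', KEYS[1], 1) == 1 then
    redis.call('HMSET', KEYS[2],
               'driver_id', ARGV[1],
               'function_id', ARGV[2],
               'function', ARGV[3])
    redis.call('RPUSH', 'Exports', KEYS[2])
    return 1
end
return 0
"""


def export_function_once(redis_client, key, driver_id, function_id,
                         pickled_function):
    """Return 1 if this call performed the export, 0 if it had already happened."""
    script = redis_client.register_script(EXPORT_ONCE_LUA)
    return script(keys=[b"Lock:" + key, key],
                  args=[driver_id, function_id, pickled_function])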
@@ -1047,9 +1060,9 @@ def _initialize_serialization(worker=global_worker):
    _register_class(RayGetError)
    _register_class(RayGetArgumentError)
    # Tell Ray to serialize lambdas with pickle.
-    _register_class(type(lambda: 0), pickle=True)
+    _register_class(type(lambda: 0), use_pickle=True)
    # Tell Ray to serialize types with pickle.
-    _register_class(type(int), pickle=True)
+    _register_class(type(int), use_pickle=True)


def get_address_info_from_redis_helper(redis_address, node_ip_address):
@@ -1884,12 +1897,12 @@ def disconnect(worker=global_worker):
    worker.serialization_context = pyarrow.SerializationContext()


-def register_class(cls, pickle=False, worker=global_worker):
+def register_class(cls, use_pickle=False, worker=global_worker):
    raise Exception("The function ray.register_class is deprecated. It should "
                    "be safe to remove any calls to this function.")


-def _register_class(cls, pickle=False, worker=global_worker):
+def _register_class(cls, use_pickle=False, worker=global_worker):
    """Enable serialization and deserialization for a particular class.

    This method runs the register_class function defined below on every worker,
@@ -1898,21 +1911,28 @@ def _register_class(cls, pickle=False, worker=global_worker):

    Args:
        cls (type): The class that ray should serialize.
-        pickle (bool): If False then objects of this class will be serialized
-            by turning their __dict__ fields into a dictionary. If True, then
-            objects of this class will be serialized using pickle.
+        use_pickle (bool): If False then objects of this class will be
+            serialized by turning their __dict__ fields into a dictionary. If
+            True, then objects of this class will be serialized using pickle.

    Raises:
        Exception: An exception is raised if pickle=False and the class cannot
            be efficiently serialized by Ray.
    """
-    class_id = random_string()
+    if not use_pickle:
+        # In this case, the class ID will be used to deduplicate the class
+        # across workers.
+        class_id = hashlib.sha1(pickle.dumps(cls)).digest()
+    else:
+        # In this case, the class ID only needs to be meaningful on this worker
+        # and not across workers.
+        class_id = random_string()

    def register_class_for_serialization(worker_info):
        worker_info["worker"].serialization_context.register_type(
-            cls, class_id, pickle=pickle)
+            cls, class_id, pickle=use_pickle)

-    if not pickle:
+    if not use_pickle:
        # Raise an exception if cls cannot be serialized efficiently by Ray.
        serialization.check_serializable(cls)
    worker.run_function_on_all_workers(register_class_for_serialization)
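A quick, self-contained check of the property the new class_id derivation relies on: hashing the pickled class yields the same ID in every process that imports the same definition, so repeated registrations collapse onto one Redis entry, while a random ID can never match across workers. Foo is a stand-in class and os.urandom stands in for random_string(); neither is Ray code.

import hashlib
import os
import pickle


class Foo(object):
    pass


# Pickling a top-level class serializes a reference to it, so the digest is
# stable across calls and across processes that import the same module.
id_a = hashlib.sha1(pickle.dumps(Foo)).digest()
id_b = hashlib.sha1(pickle.dumps(Foo)).digest()
assert id_a == id_b

# A fresh random identifier differs every time, so it cannot deduplicate.
assert os.urandom(20) != os.urandom(20)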