mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Fix the resetting of reusable variables on the driver and cache functions to run on all workers. (#446)
* Properly reset reusable variables on the driver when remote functions are run locally on the driver. * Cache functions to run on all workers that occur before ray.init is called.
This commit is contained in:
parent
1c3aaf7189
commit
0a44145906
4 changed files with 245 additions and 32 deletions
|
@ -57,6 +57,9 @@ class DistArray(object):
|
|||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
# Register the DistArray class with Ray so that it knows how to serialize it.
|
||||
ray.register_class(DistArray)
|
||||
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
|
|
@ -163,8 +163,21 @@ class RayReusables(object):
|
|||
|
||||
Attributes:
|
||||
_names (List[str]): A list of the names of all the reusable variables.
|
||||
_reusables (Dict[str, Reusable]): A dictionary mapping the name of the
|
||||
reusable variables to the corresponding Reusable object.
|
||||
_reinitializers (Dict[str, Callable]): A dictionary mapping the name of the
|
||||
reusable variables to the corresponding reinitializer.
|
||||
_running_remote_function_locally (bool): A flag used to indicate if a remote
|
||||
function is running locally on the driver so that we can simulate the same
|
||||
behavior as running a remote function remotely.
|
||||
_reusables: A dictionary mapping the name of a reusable variable to the
|
||||
value of the reusable variable.
|
||||
_local_mode_reusables: A copy of _reusables used on the driver when running
|
||||
remote functions locally on the driver. This is needed because there are
|
||||
two ways in which reusable variables can be used on the driver. The first
|
||||
is that the driver's copy can be manipulated. This copy is never reset
|
||||
(think of the driver as a single long-running task). The second way is
|
||||
that a remote function can be run locally on the driver, and this remote
|
||||
function needs access to a copy of the reusable variable, and that copy
|
||||
must be reinitialized after use.
|
||||
_cached_reusables (List[Tuple[str, Reusable]]): A list of pairs. The first
|
||||
element of each pair is the name of a reusable variable, and the second
|
||||
element is the Reusable object. This list is used to store reusable
|
||||
|
@ -178,18 +191,54 @@ class RayReusables(object):
|
|||
def __init__(self):
|
||||
"""Initialize a RayReusables object."""
|
||||
self._names = set()
|
||||
self._reinitializers = {}
|
||||
self._running_remote_function_locally = False
|
||||
self._reusables = {}
|
||||
self._local_mode_reusables = {}
|
||||
self._cached_reusables = []
|
||||
self._used = set()
|
||||
self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_reusables", "_local_mode_reusables", "_cached_reusables", "_used", "_slots", "_create_and_export", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.
|
||||
|
||||
def _create_and_export(self, name, reusable):
|
||||
"""Create a reusable variable and add export it to the workers.
|
||||
|
||||
If ray.init has not been called yet, then store the reusable variable and
|
||||
export it later then connect is called.
|
||||
|
||||
Args:
|
||||
name (str): The name of the reusable variable.
|
||||
reusable (Reusable): The reusable object to use to create the reusable
|
||||
variable.
|
||||
"""
|
||||
self._names.add(name)
|
||||
self._reinitializers[name] = reusable.reinitializer
|
||||
# Export the reusable variable to the workers if we are on the driver. If
|
||||
# ray.init has not been called yet, then cache the reusable variable to
|
||||
# export later.
|
||||
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
elif _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
self._reusables[name] = reusable.initializer()
|
||||
# We create a second copy of the reusable variable on the driver to use
|
||||
# inside of remote functions that run locally. This occurs when we start Ray
|
||||
# in PYTHON_MODE and when we call a remote function locally.
|
||||
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
self._local_mode_reusables[name] = reusable.initializer()
|
||||
|
||||
def _reinitialize(self):
|
||||
"""Reinitialize the reusable variables that the current task used."""
|
||||
for name in self._used:
|
||||
current_value = getattr(self, name)
|
||||
new_value = self._reusables[name].reinitializer(current_value)
|
||||
object.__setattr__(self, name, new_value)
|
||||
current_value = self._reusables[name]
|
||||
new_value = self._reinitializers[name](current_value)
|
||||
# If we are on the driver, reset the copy of the reusable variable in the
|
||||
# _local_mode_reusables dictionary.
|
||||
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
assert self._running_remote_function_locally
|
||||
self._local_mode_reusables[name] = new_value
|
||||
else:
|
||||
self._reusables[name] = new_value
|
||||
self._used.clear() # Reset the _used list.
|
||||
|
||||
def __getattribute__(self, name):
|
||||
|
@ -205,9 +254,16 @@ class RayReusables(object):
|
|||
return object.__getattribute__(self, name)
|
||||
if name in self._slots:
|
||||
return object.__getattribute__(self, name)
|
||||
# Handle various fields that are not reusable variables.
|
||||
if name not in self._names:
|
||||
return object.__getattribute__(self, name)
|
||||
# Make a note of the fact that the reusable variable has been used.
|
||||
if name in self._names and name not in self._used:
|
||||
self._used.add(name)
|
||||
return object.__getattribute__(self, name)
|
||||
if self._running_remote_function_locally:
|
||||
return self._local_mode_reusables[name]
|
||||
else:
|
||||
return self._reusables[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Set an attribute. This handles reusable variables as a special case.
|
||||
|
@ -217,13 +273,14 @@ class RayReusables(object):
|
|||
called on the driver, then the functions for initializing and reinitializing
|
||||
the variable are shipped to the workers.
|
||||
|
||||
If this is called before ray.init has been run, then the reusable variable
|
||||
will be cached and it will be created and exported when connect is called.
|
||||
|
||||
Args:
|
||||
name (str): The name of the attribute to set. This is either a whitelisted
|
||||
name or it is treated as the name of a reusable variable.
|
||||
value: If name is a whitelisted name, then value can be any value. If name
|
||||
is the name of a reusable variable, then this is either the serialized
|
||||
initializer code or it is a tuple of the serialized initializer and
|
||||
reinitializer code.
|
||||
is the name of a reusable variable, then this is a Reusable object.
|
||||
"""
|
||||
try:
|
||||
slots = self._slots
|
||||
|
@ -236,13 +293,11 @@ class RayReusables(object):
|
|||
reusable = value
|
||||
if not issubclass(type(reusable), Reusable):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
self._names.add(name)
|
||||
self._reusables[name] = reusable
|
||||
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
elif _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
object.__setattr__(self, name, reusable.initializer())
|
||||
# Create the reusable variable locally, and export it if possible.
|
||||
self._create_and_export(name, reusable)
|
||||
# Create an empty attribute with the name of the reusable variable. This
|
||||
# allows the Python interpreter to do tab complete properly.
|
||||
return object.__setattr__(self, name, None)
|
||||
|
||||
def __delattr__(self, name):
|
||||
"""We do not allow attributes of RayReusables to be deleted.
|
||||
|
@ -299,6 +354,8 @@ class Worker(object):
|
|||
eventually does call connect, if it is a driver, it will export these
|
||||
functions to the scheduler. If cached_remote_functions is None, that means
|
||||
that connect has been called already.
|
||||
cached_functions_to_run (List): A list of functions to run on all of the
|
||||
workers that should be exported as soon as connect is called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -307,6 +364,7 @@ class Worker(object):
|
|||
self.handle = None
|
||||
self.mode = None
|
||||
self.cached_remote_functions = []
|
||||
self.cached_functions_to_run = []
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""Set the mode of the worker.
|
||||
|
@ -430,12 +488,8 @@ class Worker(object):
|
|||
objectids = raylib.submit_task(self.handle, task_capsule)
|
||||
return objectids
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
This function will first be run on the driver, and then it will be exported
|
||||
to all of the workers to be run. It will also be run on any new workers that
|
||||
register later.
|
||||
def export_function_to_run_on_all_workers(self, function):
|
||||
"""Export this function and run it on all workers.
|
||||
|
||||
Args:
|
||||
function (Callable): The function to run on all of the workers. It should
|
||||
|
@ -444,12 +498,35 @@ class Worker(object):
|
|||
"""
|
||||
if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# Then run the function on all of the workers.
|
||||
# Run the function on all of the workers.
|
||||
if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
|
||||
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
This function will first be run on the driver, and then it will be exported
|
||||
to all of the workers to be run. It will also be run on any new workers that
|
||||
register later. If ray.init has not been called yet, then cache the function
|
||||
and export it later.
|
||||
|
||||
Args:
|
||||
function (Callable): The function to run on all of the workers. It should
|
||||
not take any arguments. If it returns anything, its return values will
|
||||
not be used.
|
||||
"""
|
||||
if self.mode not in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# If ray.init has not been called yet, then cache the function and export it
|
||||
# when connect is called. Otherwise, run the function on all workers.
|
||||
if self.mode is None:
|
||||
self.cached_functions_to_run.append(function)
|
||||
else:
|
||||
self.export_function_to_run_on_all_workers(function)
|
||||
|
||||
global_worker = Worker()
|
||||
"""Worker: The global Worker object for this worker process.
|
||||
|
||||
|
@ -738,7 +815,7 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
|
|||
_logger().addHandler(log_handler)
|
||||
_logger().setLevel(logging.DEBUG)
|
||||
_logger().propagate = False
|
||||
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
|
||||
# Add the directory containing the script that is running to the Python
|
||||
# paths of the workers. Also add the current directory. Note that this
|
||||
# assumes that the directory structures on the machines in the clusters are
|
||||
|
@ -747,14 +824,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
|
|||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
|
||||
# Export cached functions_to_run.
|
||||
for function in worker.cached_functions_to_run:
|
||||
worker.export_function_to_run_on_all_workers(function)
|
||||
# Export cached remote functions to the workers.
|
||||
for function_name, function_to_export in worker.cached_remote_functions:
|
||||
raylib.export_remote_function(worker.handle, function_name, function_to_export)
|
||||
# Export cached reusable variables to the workers.
|
||||
# Export the cached reusable variables.
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
# Initialize the serialization library.
|
||||
initialize_numbuf()
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions = None
|
||||
reusables._cached_reusables = None
|
||||
|
||||
|
@ -766,6 +847,7 @@ def disconnect(worker=global_worker):
|
|||
# Reset the list of cached remote functions so that if more remote functions
|
||||
# are defined and then connect is called again, the remote functions will be
|
||||
# exported. This is mostly relevant for the tests.
|
||||
worker.cached_functions_to_run = []
|
||||
worker.cached_remote_functions = []
|
||||
reusables._cached_reusables = []
|
||||
|
||||
|
@ -786,6 +868,11 @@ def register_class(cls, pickle=False, worker=global_worker):
|
|||
Exception: An exception is raised if pickle=False and the class cannot be
|
||||
efficiently serialized by Ray.
|
||||
"""
|
||||
# If the worker is not a driver, then return. We do this so that Python
|
||||
# modules can register classes and these modules can be imported on workers
|
||||
# without any trouble.
|
||||
if worker.mode == raylib.WORKER_MODE:
|
||||
return
|
||||
# Raise an exception if cls cannot be serialized efficiently by Ray.
|
||||
if not pickle:
|
||||
serialization.check_serializable(cls)
|
||||
|
@ -1100,6 +1187,14 @@ def _logger():
|
|||
"""
|
||||
return logger
|
||||
|
||||
def _reusables():
|
||||
"""Return the reusables object.
|
||||
|
||||
We use this wrapper because so that functions which use the reusables variable
|
||||
can be pickled.
|
||||
"""
|
||||
return reusables
|
||||
|
||||
def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
"""Export a reusable variable to the workers. This is only called by a driver.
|
||||
|
||||
|
@ -1133,7 +1228,13 @@ def remote(*args, **kwargs):
|
|||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
try:
|
||||
_reusables()._running_remote_function_locally = True
|
||||
result = func(*copy.deepcopy(args))
|
||||
finally:
|
||||
_reusables()._reinitialize()
|
||||
_reusables()._running_remote_function_locally = False
|
||||
return result
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
|
|
|
@ -47,7 +47,6 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
|
||||
reload(module)
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
ray.register_class(da.DistArray)
|
||||
|
||||
a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
|
@ -60,7 +59,6 @@ class DistributedArrayTest(unittest.TestCase):
|
|||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
|
||||
reload(module)
|
||||
ray.init(start_ray_local=True, num_objstores=2, num_workers=10)
|
||||
ray.register_class(da.DistArray)
|
||||
|
||||
x = da.zeros.remote([9, 25, 51], "float")
|
||||
assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51]))
|
||||
|
|
113
test/runtest.py
113
test/runtest.py
|
@ -451,6 +451,43 @@ class APITest(unittest.TestCase):
|
|||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testCachingFunctionsToRun(self):
|
||||
# Test that we export functions to run on all workers before the driver is connected.
|
||||
def f(worker):
|
||||
sys.path.append(1)
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
def f(worker):
|
||||
sys.path.append(2)
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
def g(worker):
|
||||
sys.path.append(3)
|
||||
ray.worker.global_worker.run_function_on_all_workers(g)
|
||||
def f(worker):
|
||||
sys.path.append(4)
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
ray.init(start_ray_local=True, num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
def get_state():
|
||||
time.sleep(1)
|
||||
return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1]
|
||||
|
||||
res1 = get_state.remote()
|
||||
res2 = get_state.remote()
|
||||
self.assertEqual(ray.get(res1), (1, 2, 3, 4))
|
||||
self.assertEqual(ray.get(res2), (1, 2, 3, 4))
|
||||
|
||||
# Clean up the path on the workers.
|
||||
def f(worker):
|
||||
sys.path.pop()
|
||||
sys.path.pop()
|
||||
sys.path.pop()
|
||||
sys.path.pop()
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testRunningFunctionOnAllWorkers(self):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
|
@ -493,7 +530,6 @@ class ReferenceCountingTest(unittest.TestCase):
|
|||
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
|
||||
reload(module)
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
ray.register_class(da.DistArray)
|
||||
|
||||
def check_not_deallocated(object_ids):
|
||||
reference_counts = ray.scheduler_info()["reference_counts"]
|
||||
|
@ -638,6 +674,43 @@ class PythonModeTest(unittest.TestCase):
|
|||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testReusableVariablesInPythonMode(self):
|
||||
reload(test_functions)
|
||||
ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE)
|
||||
|
||||
def l_init():
|
||||
return []
|
||||
def l_reinit(l):
|
||||
return []
|
||||
ray.reusables.l = ray.Reusable(l_init, l_reinit)
|
||||
|
||||
@ray.remote
|
||||
def use_l():
|
||||
l = ray.reusables.l
|
||||
l.append(1)
|
||||
return l
|
||||
|
||||
# Get the local copy of the reusable variable. This should be stateful.
|
||||
l = ray.reusables.l
|
||||
assert_equal(l, [])
|
||||
|
||||
# Make sure the remote function does what we expect.
|
||||
assert_equal(ray.get(use_l.remote()), [1])
|
||||
assert_equal(ray.get(use_l.remote()), [1])
|
||||
|
||||
# Make sure the local copy of the reusable variable has not been mutated.
|
||||
assert_equal(l, [])
|
||||
l = ray.reusables.l
|
||||
assert_equal(l, [])
|
||||
|
||||
# Make sure that running a remote function does not reset the state of the
|
||||
# local copy of the reusable variable.
|
||||
l.append(2)
|
||||
assert_equal(ray.get(use_l.remote()), [1])
|
||||
assert_equal(l, [2])
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
class PythonCExtensionTest(unittest.TestCase):
|
||||
|
||||
def testReferenceCountNone(self):
|
||||
|
@ -757,6 +830,44 @@ class ReusablesTest(unittest.TestCase):
|
|||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testUsingReusablesOnDriver(self):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Test that we can add a variable to the key-value store.
|
||||
|
||||
def foo_initializer():
|
||||
return []
|
||||
def foo_reinitializer(foo):
|
||||
return []
|
||||
|
||||
ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)
|
||||
|
||||
@ray.remote
|
||||
def use_foo():
|
||||
foo = ray.reusables.foo
|
||||
foo.append(1)
|
||||
return foo
|
||||
|
||||
# Check that running a remote function does not reset the reusable variable
|
||||
# on the driver.
|
||||
foo = ray.reusables.foo
|
||||
self.assertEqual(foo, [])
|
||||
foo.append(2)
|
||||
self.assertEqual(foo, [2])
|
||||
foo.append(3)
|
||||
self.assertEqual(foo, [2, 3])
|
||||
|
||||
self.assertEqual(ray.get(use_foo.remote()), [1])
|
||||
self.assertEqual(ray.get(use_foo.remote()), [1])
|
||||
self.assertEqual(ray.get(use_foo.remote()), [1])
|
||||
|
||||
# Check that the copy of foo on the driver has not changed.
|
||||
self.assertEqual(foo, [2, 3])
|
||||
foo = ray.reusables.foo
|
||||
self.assertEqual(foo, [2, 3])
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
class ClusterAttachingTest(unittest.TestCase):
|
||||
|
||||
def testAttachingToCluster(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue