Fix the resetting of reusable variables on the driver and cache functions to run on all workers. (#446)

* Properly reset reusable variables on the driver when remote functions are run locally on the driver.

* Cache functions that are passed to run_function_on_all_workers before ray.init is called, and export them to the workers once the driver connects.
This commit is contained in:
Robert Nishihara 2016-10-12 22:17:22 -07:00 committed by Philipp Moritz
parent 1c3aaf7189
commit 0a44145906
4 changed files with 245 additions and 32 deletions

View file

@ -57,6 +57,9 @@ class DistArray(object):
a = self.assemble()
return a[sliced]
# Register the DistArray class with Ray so that it knows how to serialize it.
ray.register_class(DistArray)
@ray.remote
def assemble(a):
    """Remote wrapper that returns the result of ``a.assemble()``.

    Args:
        a: An object exposing an ``assemble`` method (presumably a DistArray;
            confirm against callers).

    Returns:
        Whatever ``a.assemble()`` returns.
    """
    assembled = a.assemble()
    return assembled

View file

@ -163,8 +163,21 @@ class RayReusables(object):
Attributes:
_names (List[str]): A list of the names of all the reusable variables.
_reusables (Dict[str, Reusable]): A dictionary mapping the name of the
reusable variables to the corresponding Reusable object.
_reinitializers (Dict[str, Callable]): A dictionary mapping the name of the
reusable variables to the corresponding reinitializer.
_running_remote_function_locally (bool): A flag used to indicate if a remote
function is running locally on the driver so that we can simulate the same
behavior as running a remote function remotely.
_reusables: A dictionary mapping the name of a reusable variable to the
value of the reusable variable.
_local_mode_reusables: A copy of _reusables used on the driver when running
remote functions locally on the driver. This is needed because there are
two ways in which reusable variables can be used on the driver. The first
is that the driver's copy can be manipulated. This copy is never reset
(think of the driver as a single long-running task). The second way is
that a remote function can be run locally on the driver, and this remote
function needs access to a copy of the reusable variable, and that copy
must be reinitialized after use.
_cached_reusables (List[Tuple[str, Reusable]]): A list of pairs. The first
element of each pair is the name of a reusable variable, and the second
element is the Reusable object. This list is used to store reusable
@ -178,18 +191,54 @@ class RayReusables(object):
def __init__(self):
"""Initialize a RayReusables object."""
self._names = set()
self._reinitializers = {}
self._running_remote_function_locally = False
self._reusables = {}
self._local_mode_reusables = {}
self._cached_reusables = []
self._used = set()
self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_reusables", "_local_mode_reusables", "_cached_reusables", "_used", "_slots", "_create_and_export", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
# CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.
def _create_and_export(self, name, reusable):
    """Create a reusable variable and export it to the workers if possible.

    If ray.init has not been called yet (the mode is None), the reusable
    variable is cached so that it can be exported later, when connect is
    called.

    Args:
        name (str): The name of the reusable variable.
        reusable (Reusable): The Reusable object used to create the reusable
            variable.
    """
    # Record the variable and remember how to reset it after a task uses it.
    self._names.add(name)
    self._reinitializers[name] = reusable.reinitializer

    mode = _mode()
    on_connected_driver = mode in (raylib.SCRIPT_MODE, raylib.SILENT_MODE)

    if on_connected_driver:
        # We are on a driver that can reach the workers: ship the variable now.
        _export_reusable_variable(name, reusable)
    elif mode is None:
        # ray.init has not run yet: remember the variable and export it later.
        self._cached_reusables.append((name, reusable))

    # The primary copy of the variable's value.
    self._reusables[name] = reusable.initializer()

    # On the driver, keep a second, independently initialized copy for remote
    # functions that run locally. This covers Ray started in PYTHON_MODE as
    # well as remote functions called locally on the driver.
    if on_connected_driver or mode == raylib.PYTHON_MODE:
        self._local_mode_reusables[name] = reusable.initializer()
def _reinitialize(self):
"""Reinitialize the reusable variables that the current task used."""
for name in self._used:
current_value = getattr(self, name)
new_value = self._reusables[name].reinitializer(current_value)
object.__setattr__(self, name, new_value)
current_value = self._reusables[name]
new_value = self._reinitializers[name](current_value)
# If we are on the driver, reset the copy of the reusable variable in the
# _local_mode_reusables dictionary.
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
assert self._running_remote_function_locally
self._local_mode_reusables[name] = new_value
else:
self._reusables[name] = new_value
self._used.clear() # Reset the _used list.
def __getattribute__(self, name):
@ -205,9 +254,16 @@ class RayReusables(object):
return object.__getattribute__(self, name)
if name in self._slots:
return object.__getattribute__(self, name)
# Handle various fields that are not reusable variables.
if name not in self._names:
return object.__getattribute__(self, name)
# Make a note of the fact that the reusable variable has been used.
if name in self._names and name not in self._used:
self._used.add(name)
return object.__getattribute__(self, name)
if self._running_remote_function_locally:
return self._local_mode_reusables[name]
else:
return self._reusables[name]
def __setattr__(self, name, value):
"""Set an attribute. This handles reusable variables as a special case.
@ -217,13 +273,14 @@ class RayReusables(object):
called on the driver, then the functions for initializing and reinitializing
the variable are shipped to the workers.
If this is called before ray.init has been run, then the reusable variable
will be cached and it will be created and exported when connect is called.
Args:
name (str): The name of the attribute to set. This is either a whitelisted
name or it is treated as the name of a reusable variable.
value: If name is a whitelisted name, then value can be any value. If name
is the name of a reusable variable, then this is either the serialized
initializer code or it is a tuple of the serialized initializer and
reinitializer code.
is the name of a reusable variable, then this is a Reusable object.
"""
try:
slots = self._slots
@ -236,13 +293,11 @@ class RayReusables(object):
reusable = value
if not issubclass(type(reusable), Reusable):
raise Exception("To set a reusable variable, you must pass in a Reusable object")
self._names.add(name)
self._reusables[name] = reusable
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
_export_reusable_variable(name, reusable)
elif _mode() is None:
self._cached_reusables.append((name, reusable))
object.__setattr__(self, name, reusable.initializer())
# Create the reusable variable locally, and export it if possible.
self._create_and_export(name, reusable)
# Create an empty attribute with the name of the reusable variable. This
# allows the Python interpreter to do tab completion properly.
return object.__setattr__(self, name, None)
def __delattr__(self, name):
"""We do not allow attributes of RayReusables to be deleted.
@ -299,6 +354,8 @@ class Worker(object):
eventually does call connect, if it is a driver, it will export these
functions to the scheduler. If cached_remote_functions is None, that means
that connect has been called already.
cached_functions_to_run (List): A list of functions to run on all of the
workers that should be exported as soon as connect is called.
"""
def __init__(self):
@ -307,6 +364,7 @@ class Worker(object):
self.handle = None
self.mode = None
self.cached_remote_functions = []
self.cached_functions_to_run = []
def set_mode(self, mode):
"""Set the mode of the worker.
@ -430,12 +488,8 @@ class Worker(object):
objectids = raylib.submit_task(self.handle, task_capsule)
return objectids
def run_function_on_all_workers(self, function):
"""Run arbitrary code on all of the workers.
This function will first be run on the driver, and then it will be exported
to all of the workers to be run. It will also be run on any new workers that
register later.
def export_function_to_run_on_all_workers(self, function):
"""Export this function and run it on all workers.
Args:
function (Callable): The function to run on all of the workers. It should
@ -444,12 +498,35 @@ class Worker(object):
"""
if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
raise Exception("run_function_on_all_workers can only be called on a driver.")
# First run the function on the driver.
function(self)
# Then run the function on all of the workers.
# Run the function on all of the workers.
if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
def run_function_on_all_workers(self, function):
    """Run arbitrary code on the driver and then on all of the workers.

    The function is first run locally with this Worker object as its argument.
    If ray.init has not been called yet (self.mode is None), the function is
    appended to self.cached_functions_to_run so that it can be exported when
    connect is called. Otherwise it is exported immediately through
    export_function_to_run_on_all_workers. Workers that register later are
    presumably handled by the export path, not by this method.

    Args:
        function (Callable): The function to run. It is called locally with
            this worker as its only argument, and its return value is not
            used.

    Raises:
        Exception: If the worker's mode is not None, SCRIPT_MODE, SILENT_MODE
            or PYTHON_MODE.
    """
    allowed_modes = [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]
    if self.mode not in allowed_modes:
        raise Exception("run_function_on_all_workers can only be called on a driver.")

    # The driver always runs the function itself first.
    function(self)

    if self.mode is not None:
        # Already connected: hand the function to the workers right away.
        self.export_function_to_run_on_all_workers(function)
        return

    # Not connected yet: keep the function until connect exports it.
    self.cached_functions_to_run.append(function)
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
@ -738,7 +815,7 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
_logger().addHandler(log_handler)
_logger().setLevel(logging.DEBUG)
_logger().propagate = False
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]:
# Add the directory containing the script that is running to the Python
# paths of the workers. Also add the current directory. Note that this
# assumes that the directory structures on the machines in the clusters are
@ -747,14 +824,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
# Export cached functions_to_run.
for function in worker.cached_functions_to_run:
worker.export_function_to_run_on_all_workers(function)
# Export cached remote functions to the workers.
for function_name, function_to_export in worker.cached_remote_functions:
raylib.export_remote_function(worker.handle, function_name, function_to_export)
# Export cached reusable variables to the workers.
# Export the cached reusable variables.
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
# Initialize the serialization library.
initialize_numbuf()
worker.cached_functions_to_run = None
worker.cached_remote_functions = None
reusables._cached_reusables = None
@ -766,6 +847,7 @@ def disconnect(worker=global_worker):
# Reset the list of cached remote functions so that if more remote functions
# are defined and then connect is called again, the remote functions will be
# exported. This is mostly relevant for the tests.
worker.cached_functions_to_run = []
worker.cached_remote_functions = []
reusables._cached_reusables = []
@ -786,6 +868,11 @@ def register_class(cls, pickle=False, worker=global_worker):
Exception: An exception is raised if pickle=False and the class cannot be
efficiently serialized by Ray.
"""
# If the worker is not a driver, then return. We do this so that Python
# modules can register classes and these modules can be imported on workers
# without any trouble.
if worker.mode == raylib.WORKER_MODE:
return
# Raise an exception if cls cannot be serialized efficiently by Ray.
if not pickle:
serialization.check_serializable(cls)
@ -1100,6 +1187,14 @@ def _logger():
"""
return logger
def _reusables():
    """Return the module-level reusables object.

    Going through this accessor, rather than referencing the reusables
    variable directly, is what allows functions that use it to be pickled.
    """
    return reusables
def _export_reusable_variable(name, reusable, worker=global_worker):
"""Export a reusable variable to the workers. This is only called by a driver.
@ -1133,7 +1228,13 @@ def remote(*args, **kwargs):
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
try:
_reusables()._running_remote_function_locally = True
result = func(*copy.deepcopy(args))
finally:
_reusables()._reinitialize()
_reusables()._running_remote_function_locally = False
return result
objectids = _submit_task(func_name, args)
if len(objectids) == 1:
return objectids[0]

View file

@ -47,7 +47,6 @@ class DistributedArrayTest(unittest.TestCase):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
reload(module)
ray.init(start_ray_local=True, num_workers=1)
ray.register_class(da.DistArray)
a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
@ -60,7 +59,6 @@ class DistributedArrayTest(unittest.TestCase):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
reload(module)
ray.init(start_ray_local=True, num_objstores=2, num_workers=10)
ray.register_class(da.DistArray)
x = da.zeros.remote([9, 25, 51], "float")
assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51]))

View file

@ -451,6 +451,43 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testCachingFunctionsToRun(self):
    # Functions handed to run_function_on_all_workers before the driver is
    # connected must still be exported to, and run on, every worker. Note that
    # f is redefined between registrations, so each call registers a distinct
    # function object.
    def f(worker):
        sys.path.append(1)
    ray.worker.global_worker.run_function_on_all_workers(f)

    def f(worker):
        sys.path.append(2)
    ray.worker.global_worker.run_function_on_all_workers(f)

    def g(worker):
        sys.path.append(3)
    ray.worker.global_worker.run_function_on_all_workers(g)

    def f(worker):
        sys.path.append(4)
    ray.worker.global_worker.run_function_on_all_workers(f)

    ray.init(start_ray_local=True, num_workers=2)

    @ray.remote
    def get_state():
        # Presumably the sleep keeps one worker busy so that the two tasks land
        # on different workers -- confirm.
        time.sleep(1)
        return tuple(sys.path[-4:])

    # Submit both tasks before fetching either result.
    pending = [get_state.remote(), get_state.remote()]
    for object_id in pending:
        self.assertEqual(ray.get(object_id), (1, 2, 3, 4))

    # Remove the four entries that were appended to the path.
    def f(worker):
        for _ in range(4):
            sys.path.pop()
    ray.worker.global_worker.run_function_on_all_workers(f)

    ray.worker.cleanup()
def testRunningFunctionOnAllWorkers(self):
ray.init(start_ray_local=True, num_workers=1)
@ -493,7 +530,6 @@ class ReferenceCountingTest(unittest.TestCase):
for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]:
reload(module)
ray.init(start_ray_local=True, num_workers=1)
ray.register_class(da.DistArray)
def check_not_deallocated(object_ids):
reference_counts = ray.scheduler_info()["reference_counts"]
@ -638,6 +674,43 @@ class PythonModeTest(unittest.TestCase):
ray.worker.cleanup()
def testReusableVariablesInPythonMode(self):
    # In PYTHON_MODE remote functions run locally on the driver. They must see
    # a freshly reinitialized reusable variable on every call, while the
    # driver's own copy stays stateful and is never reset.
    reload(test_functions)
    ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE)

    def l_init():
        return []

    def l_reinit(l):
        return []

    ray.reusables.l = ray.Reusable(l_init, l_reinit)

    @ray.remote
    def use_l():
        values = ray.reusables.l
        values.append(1)
        return values

    # The driver's own copy of the reusable variable starts out empty.
    driver_copy = ray.reusables.l
    assert_equal(driver_copy, [])

    # Each locally executed remote call starts from a reset variable.
    for _ in range(2):
        assert_equal(ray.get(use_l.remote()), [1])

    # The remote calls must not have mutated the driver's copy.
    assert_equal(driver_copy, [])
    driver_copy = ray.reusables.l
    assert_equal(driver_copy, [])

    # Running a remote function must not reset state held in the driver's
    # copy either.
    driver_copy.append(2)
    assert_equal(ray.get(use_l.remote()), [1])
    assert_equal(driver_copy, [2])

    ray.worker.cleanup()
class PythonCExtensionTest(unittest.TestCase):
def testReferenceCountNone(self):
@ -757,6 +830,44 @@ class ReusablesTest(unittest.TestCase):
ray.worker.cleanup()
def testUsingReusablesOnDriver(self):
    # The driver's copy of a reusable variable is stateful and must be left
    # untouched by remote functions that use the same variable.
    ray.init(start_ray_local=True, num_workers=1)

    # Test that we can add a variable to the key-value store.
    def foo_initializer():
        return []

    def foo_reinitializer(foo):
        return []

    ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)

    @ray.remote
    def use_foo():
        values = ray.reusables.foo
        values.append(1)
        return values

    # Mutate the driver's copy step by step and check each state.
    driver_foo = ray.reusables.foo
    self.assertEqual(driver_foo, [])
    driver_foo.append(2)
    self.assertEqual(driver_foo, [2])
    driver_foo.append(3)
    self.assertEqual(driver_foo, [2, 3])

    # Every remote call sees a freshly reset variable.
    for _ in range(3):
        self.assertEqual(ray.get(use_foo.remote()), [1])

    # The driver's copy has not changed, whether read through the old
    # reference or fetched again.
    self.assertEqual(driver_foo, [2, 3])
    driver_foo = ray.reusables.foo
    self.assertEqual(driver_foo, [2, 3])

    ray.worker.cleanup()
class ClusterAttachingTest(unittest.TestCase):
def testAttachingToCluster(self):