diff --git a/lib/python/ray/array/distributed/core.py b/lib/python/ray/array/distributed/core.py index a17e9f03e..6481762a6 100644 --- a/lib/python/ray/array/distributed/core.py +++ b/lib/python/ray/array/distributed/core.py @@ -57,6 +57,9 @@ class DistArray(object): a = self.assemble() return a[sliced] +# Register the DistArray class with Ray so that it knows how to serialize it. +ray.register_class(DistArray) + @ray.remote def assemble(a): return a.assemble() diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index cd2a90304..4a6e4189e 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -163,8 +163,21 @@ class RayReusables(object): Attributes: _names (List[str]): A list of the names of all the reusable variables. - _reusables (Dict[str, Reusable]): A dictionary mapping the name of the - reusable variables to the corresponding Reusable object. + _reinitializers (Dict[str, Callable]): A dictionary mapping the name of the + reusable variables to the corresponding reinitializer. + _running_remote_function_locally (bool): A flag used to indicate if a remote + function is running locally on the driver so that we can simulate the same + behavior as running a remote function remotely. + _reusables: A dictionary mapping the name of a reusable variable to the + value of the reusable variable. + _local_mode_reusables: A copy of _reusables used on the driver when running + remote functions locally on the driver. This is needed because there are + two ways in which reusable variables can be used on the driver. The first + is that the driver's copy can be manipulated. This copy is never reset + (think of the driver as a single long-running task). The second way is + that a remote function can be run locally on the driver, and this remote + function needs access to a copy of the reusable variable, and that copy + must be reinitialized after use. _cached_reusables (List[Tuple[str, Reusable]]): A list of pairs. The first element of each pair is the name of a reusable variable, and the second element is the Reusable object. This list is used to store reusable @@ -178,18 +191,54 @@ class RayReusables(object): def __init__(self): """Initialize a RayReusables object.""" self._names = set() + self._reinitializers = {} + self._running_remote_function_locally = False self._reusables = {} + self._local_mode_reusables = {} self._cached_reusables = [] self._used = set() - self._slots = ("_names", "_reusables", "_cached_reusables", "_used", "_slots", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__") + self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_reusables", "_local_mode_reusables", "_cached_reusables", "_used", "_slots", "_create_and_export", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__") # CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion. + def _create_and_export(self, name, reusable): + """Create a reusable variable and add export it to the workers. + + If ray.init has not been called yet, then store the reusable variable and + export it later then connect is called. + + Args: + name (str): The name of the reusable variable. + reusable (Reusable): The reusable object to use to create the reusable + variable. + """ + self._names.add(name) + self._reinitializers[name] = reusable.reinitializer + # Export the reusable variable to the workers if we are on the driver. If + # ray.init has not been called yet, then cache the reusable variable to + # export later. + if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + _export_reusable_variable(name, reusable) + elif _mode() is None: + self._cached_reusables.append((name, reusable)) + self._reusables[name] = reusable.initializer() + # We create a second copy of the reusable variable on the driver to use + # inside of remote functions that run locally. This occurs when we start Ray + # in PYTHON_MODE and when we call a remote function locally. + if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: + self._local_mode_reusables[name] = reusable.initializer() + def _reinitialize(self): """Reinitialize the reusable variables that the current task used.""" for name in self._used: - current_value = getattr(self, name) - new_value = self._reusables[name].reinitializer(current_value) - object.__setattr__(self, name, new_value) + current_value = self._reusables[name] + new_value = self._reinitializers[name](current_value) + # If we are on the driver, reset the copy of the reusable variable in the + # _local_mode_reusables dictionary. + if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: + assert self._running_remote_function_locally + self._local_mode_reusables[name] = new_value + else: + self._reusables[name] = new_value self._used.clear() # Reset the _used list. def __getattribute__(self, name): @@ -205,9 +254,16 @@ class RayReusables(object): return object.__getattribute__(self, name) if name in self._slots: return object.__getattribute__(self, name) + # Handle various fields that are not reusable variables. + if name not in self._names: + return object.__getattribute__(self, name) + # Make a note of the fact that the reusable variable has been used. if name in self._names and name not in self._used: self._used.add(name) - return object.__getattribute__(self, name) + if self._running_remote_function_locally: + return self._local_mode_reusables[name] + else: + return self._reusables[name] def __setattr__(self, name, value): """Set an attribute. This handles reusable variables as a special case. @@ -217,13 +273,14 @@ class RayReusables(object): called on the driver, then the functions for initializing and reinitializing the variable are shipped to the workers. + If this is called before ray.init has been run, then the reusable variable + will be cached and it will be created and exported when connect is called. + Args: name (str): The name of the attribute to set. This is either a whitelisted name or it is treated as the name of a reusable variable. value: If name is a whitelisted name, then value can be any value. If name - is the name of a reusable variable, then this is either the serialized - initializer code or it is a tuple of the serialized initializer and - reinitializer code. + is the name of a reusable variable, then this is a Reusable object. """ try: slots = self._slots @@ -236,13 +293,11 @@ class RayReusables(object): reusable = value if not issubclass(type(reusable), Reusable): raise Exception("To set a reusable variable, you must pass in a Reusable object") - self._names.add(name) - self._reusables[name] = reusable - if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: - _export_reusable_variable(name, reusable) - elif _mode() is None: - self._cached_reusables.append((name, reusable)) - object.__setattr__(self, name, reusable.initializer()) + # Create the reusable variable locally, and export it if possible. + self._create_and_export(name, reusable) + # Create an empty attribute with the name of the reusable variable. This + # allows the Python interpreter to do tab complete properly. + return object.__setattr__(self, name, None) def __delattr__(self, name): """We do not allow attributes of RayReusables to be deleted. @@ -299,6 +354,8 @@ class Worker(object): eventually does call connect, if it is a driver, it will export these functions to the scheduler. If cached_remote_functions is None, that means that connect has been called already. + cached_functions_to_run (List): A list of functions to run on all of the + workers that should be exported as soon as connect is called. """ def __init__(self): @@ -307,6 +364,7 @@ class Worker(object): self.handle = None self.mode = None self.cached_remote_functions = [] + self.cached_functions_to_run = [] def set_mode(self, mode): """Set the mode of the worker. @@ -430,12 +488,8 @@ class Worker(object): objectids = raylib.submit_task(self.handle, task_capsule) return objectids - def run_function_on_all_workers(self, function): - """Run arbitrary code on all of the workers. - - This function will first be run on the driver, and then it will be exported - to all of the workers to be run. It will also be run on any new workers that - register later. + def export_function_to_run_on_all_workers(self, function): + """Export this function and run it on all workers. Args: function (Callable): The function to run on all of the workers. It should @@ -444,12 +498,35 @@ class Worker(object): """ if self.mode not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: raise Exception("run_function_on_all_workers can only be called on a driver.") - # First run the function on the driver. - function(self) - # Then run the function on all of the workers. + # Run the function on all of the workers. if self.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: raylib.run_function_on_all_workers(self.handle, pickling.dumps(function)) + + def run_function_on_all_workers(self, function): + """Run arbitrary code on all of the workers. + + This function will first be run on the driver, and then it will be exported + to all of the workers to be run. It will also be run on any new workers that + register later. If ray.init has not been called yet, then cache the function + and export it later. + + Args: + function (Callable): The function to run on all of the workers. It should + not take any arguments. If it returns anything, its return values will + not be used. + """ + if self.mode not in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: + raise Exception("run_function_on_all_workers can only be called on a driver.") + # First run the function on the driver. + function(self) + # If ray.init has not been called yet, then cache the function and export it + # when connect is called. Otherwise, run the function on all workers. + if self.mode is None: + self.cached_functions_to_run.append(function) + else: + self.export_function_to_run_on_all_workers(function) + global_worker = Worker() """Worker: The global Worker object for this worker process. @@ -738,7 +815,7 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl _logger().addHandler(log_handler) _logger().setLevel(logging.DEBUG) _logger().propagate = False - if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE, raylib.PYTHON_MODE]: # Add the directory containing the script that is running to the Python # paths of the workers. Also add the current directory. Note that this # assumes that the directory structures on the machines in the clusters are @@ -747,14 +824,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl current_directory = os.path.abspath(os.path.curdir) worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory)) worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory)) + # Export cached functions_to_run. + for function in worker.cached_functions_to_run: + worker.export_function_to_run_on_all_workers(function) # Export cached remote functions to the workers. for function_name, function_to_export in worker.cached_remote_functions: raylib.export_remote_function(worker.handle, function_name, function_to_export) - # Export cached reusable variables to the workers. + # Export the cached reusable variables. for name, reusable_variable in reusables._cached_reusables: _export_reusable_variable(name, reusable_variable) # Initialize the serialization library. initialize_numbuf() + worker.cached_functions_to_run = None worker.cached_remote_functions = None reusables._cached_reusables = None @@ -766,6 +847,7 @@ def disconnect(worker=global_worker): # Reset the list of cached remote functions so that if more remote functions # are defined and then connect is called again, the remote functions will be # exported. This is mostly relevant for the tests. + worker.cached_functions_to_run = [] worker.cached_remote_functions = [] reusables._cached_reusables = [] @@ -786,6 +868,11 @@ def register_class(cls, pickle=False, worker=global_worker): Exception: An exception is raised if pickle=False and the class cannot be efficiently serialized by Ray. """ + # If the worker is not a driver, then return. We do this so that Python + # modules can register classes and these modules can be imported on workers + # without any trouble. + if worker.mode == raylib.WORKER_MODE: + return # Raise an exception if cls cannot be serialized efficiently by Ray. if not pickle: serialization.check_serializable(cls) @@ -1100,6 +1187,14 @@ def _logger(): """ return logger +def _reusables(): + """Return the reusables object. + + We use this wrapper because so that functions which use the reusables variable + can be pickled. + """ + return reusables + def _export_reusable_variable(name, reusable, worker=global_worker): """Export a reusable variable to the workers. This is only called by a driver. @@ -1133,7 +1228,13 @@ def remote(*args, **kwargs): # In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the # arguments to prevent the function call from mutating them and to match # the usual behavior of immutable remote objects. - return func(*copy.deepcopy(args)) + try: + _reusables()._running_remote_function_locally = True + result = func(*copy.deepcopy(args)) + finally: + _reusables()._reinitialize() + _reusables()._running_remote_function_locally = False + return result objectids = _submit_task(func_name, args) if len(objectids) == 1: return objectids[0] diff --git a/test/array_test.py b/test/array_test.py index c33579d02..5d00cbe61 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -47,7 +47,6 @@ class DistributedArrayTest(unittest.TestCase): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_workers=1) - ray.register_class(da.DistArray) a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) @@ -60,7 +59,6 @@ class DistributedArrayTest(unittest.TestCase): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_objstores=2, num_workers=10) - ray.register_class(da.DistArray) x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) diff --git a/test/runtest.py b/test/runtest.py index a88df8cc1..03b502143 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -451,6 +451,43 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testCachingFunctionsToRun(self): + # Test that we export functions to run on all workers before the driver is connected. + def f(worker): + sys.path.append(1) + ray.worker.global_worker.run_function_on_all_workers(f) + def f(worker): + sys.path.append(2) + ray.worker.global_worker.run_function_on_all_workers(f) + def g(worker): + sys.path.append(3) + ray.worker.global_worker.run_function_on_all_workers(g) + def f(worker): + sys.path.append(4) + ray.worker.global_worker.run_function_on_all_workers(f) + + ray.init(start_ray_local=True, num_workers=2) + + @ray.remote + def get_state(): + time.sleep(1) + return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1] + + res1 = get_state.remote() + res2 = get_state.remote() + self.assertEqual(ray.get(res1), (1, 2, 3, 4)) + self.assertEqual(ray.get(res2), (1, 2, 3, 4)) + + # Clean up the path on the workers. + def f(worker): + sys.path.pop() + sys.path.pop() + sys.path.pop() + sys.path.pop() + ray.worker.global_worker.run_function_on_all_workers(f) + + ray.worker.cleanup() + def testRunningFunctionOnAllWorkers(self): ray.init(start_ray_local=True, num_workers=1) @@ -493,7 +530,6 @@ class ReferenceCountingTest(unittest.TestCase): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) ray.init(start_ray_local=True, num_workers=1) - ray.register_class(da.DistArray) def check_not_deallocated(object_ids): reference_counts = ray.scheduler_info()["reference_counts"] @@ -638,6 +674,43 @@ class PythonModeTest(unittest.TestCase): ray.worker.cleanup() + def testReusableVariablesInPythonMode(self): + reload(test_functions) + ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE) + + def l_init(): + return [] + def l_reinit(l): + return [] + ray.reusables.l = ray.Reusable(l_init, l_reinit) + + @ray.remote + def use_l(): + l = ray.reusables.l + l.append(1) + return l + + # Get the local copy of the reusable variable. This should be stateful. + l = ray.reusables.l + assert_equal(l, []) + + # Make sure the remote function does what we expect. + assert_equal(ray.get(use_l.remote()), [1]) + assert_equal(ray.get(use_l.remote()), [1]) + + # Make sure the local copy of the reusable variable has not been mutated. + assert_equal(l, []) + l = ray.reusables.l + assert_equal(l, []) + + # Make sure that running a remote function does not reset the state of the + # local copy of the reusable variable. + l.append(2) + assert_equal(ray.get(use_l.remote()), [1]) + assert_equal(l, [2]) + + ray.worker.cleanup() + class PythonCExtensionTest(unittest.TestCase): def testReferenceCountNone(self): @@ -757,6 +830,44 @@ class ReusablesTest(unittest.TestCase): ray.worker.cleanup() + def testUsingReusablesOnDriver(self): + ray.init(start_ray_local=True, num_workers=1) + + # Test that we can add a variable to the key-value store. + + def foo_initializer(): + return [] + def foo_reinitializer(foo): + return [] + + ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer) + + @ray.remote + def use_foo(): + foo = ray.reusables.foo + foo.append(1) + return foo + + # Check that running a remote function does not reset the reusable variable + # on the driver. + foo = ray.reusables.foo + self.assertEqual(foo, []) + foo.append(2) + self.assertEqual(foo, [2]) + foo.append(3) + self.assertEqual(foo, [2, 3]) + + self.assertEqual(ray.get(use_foo.remote()), [1]) + self.assertEqual(ray.get(use_foo.remote()), [1]) + self.assertEqual(ray.get(use_foo.remote()), [1]) + + # Check that the copy of foo on the driver has not changed. + self.assertEqual(foo, [2, 3]) + foo = ray.reusables.foo + self.assertEqual(foo, [2, 3]) + + ray.worker.cleanup() + class ClusterAttachingTest(unittest.TestCase): def testAttachingToCluster(self):