mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Fix bug in which worker import counters were treated incorrectly. (#28)
* Fix bug in which worker import counters were treated incorrectly. * Fix bug in which cached functions-to-run were double counted as exports. This also runs the functions-to-run on the driver only after ray.init is called. * Only define reusable variables locally after ray.init has been called. * Remove flaky reference counting tests. It's not clear that these tests make sense. * Make numbuf pip install verbose. * Export cached reusable variables before cached remote functions. * Fix bug causing the worker to hang sometimes. This happens when the worker is trying to run a task, but it hasn't imported enough imports to run the task, so it continually acquires and releases a lock while checking if it has enough imports. However, for some reason, the import thread is waiting to acquire the same lock and never does so (or takes a very long time to do so). By dropping the lock before sleeping, this makes it easier for other threads to acquire the lock. * Acquire locks using 'with' statements. * Fix possible test failure. * Try to start Redis multiple times with different random ports if the original attempt failed. * Fix test in which we redefine a remote function.
This commit is contained in:
parent
1147c4d34b
commit
90f88af902
5 changed files with 97 additions and 123 deletions
|
@ -41,4 +41,4 @@ elif [[ $platform == "macosx" ]]; then
|
|||
sudo pip install --upgrade git+git://github.com/cloudpipe/cloudpickle.git@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4 # We use the latest version of cloudpickle because it can serialize named tuples.
|
||||
fi
|
||||
|
||||
sudo pip install --upgrade git+git://github.com/ray-project/numbuf.git@d1974afbab9f0f1bcf8af15a8c476d868ad31aff
|
||||
sudo pip install --upgrade --verbose git+git://github.com/ray-project/numbuf.git@d1974afbab9f0f1bcf8af15a8c476d868ad31aff
|
||||
|
|
|
@ -64,11 +64,37 @@ def cleanup():
|
|||
print("Ray did not shut down properly.")
|
||||
all_processes = []
|
||||
|
||||
def start_redis(port, cleanup=True):
|
||||
def start_redis(num_retries=20, cleanup=True):
|
||||
"""Start a Redis server.
|
||||
|
||||
Args:
|
||||
num_retries (int): The number of times to attempt to start Redis.
|
||||
cleanup (bool): True if using Ray in local mode. If cleanup is true, then
|
||||
this process will be killed by serices.cleanup() when the Python process
|
||||
that imported services exits.
|
||||
|
||||
Returns:
|
||||
The port used by Redis.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if Redis could not be started.
|
||||
"""
|
||||
redis_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../common/thirdparty/redis-3.2.3/src/redis-server")
|
||||
counter = 0
|
||||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Redis failed to start, retrying now.")
|
||||
port = new_port()
|
||||
p = subprocess.Popen([redis_filepath, "--port", str(port), "--loglevel", "warning"])
|
||||
time.sleep(0.1)
|
||||
# Check if Redis successfully started (or at least if it the executable did
|
||||
# not exit within 0.1 seconds).
|
||||
if p.poll() is None:
|
||||
if cleanup:
|
||||
all_processes.append(p)
|
||||
return port
|
||||
counter += 1
|
||||
raise Exception("Couldn't start Redis.")
|
||||
|
||||
def start_local_scheduler(redis_address, plasma_store_name, cleanup=True):
|
||||
local_scheduler_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../photon/build/photon_scheduler")
|
||||
|
@ -150,9 +176,8 @@ def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, worker_path=None
|
|||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "default_worker.py")
|
||||
# Start Redis.
|
||||
redis_port = new_port()
|
||||
redis_port = start_redis(cleanup=True)
|
||||
redis_address = address(node_ip_address, redis_port)
|
||||
start_redis(redis_port, cleanup=True)
|
||||
time.sleep(0.1)
|
||||
# Start Plasma.
|
||||
object_store_name, object_store_manager_name, object_store_manager_port = start_objstore(node_ip_address, redis_address, cleanup=True)
|
||||
|
|
|
@ -217,14 +217,11 @@ class RayReusables(object):
|
|||
self._local_mode_reusables = {}
|
||||
self._cached_reusables = []
|
||||
self._used = set()
|
||||
self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_reusables", "_local_mode_reusables", "_cached_reusables", "_used", "_slots", "_create_and_export", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_reusables", "_local_mode_reusables", "_cached_reusables", "_used", "_slots", "_create_reusable_variable", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
|
||||
# CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.
|
||||
|
||||
def _create_and_export(self, name, reusable):
|
||||
"""Create a reusable variable and add export it to the workers.
|
||||
|
||||
If ray.init has not been called yet, then store the reusable variable and
|
||||
export it later then connect is called.
|
||||
def _create_reusable_variable(self, name, reusable):
|
||||
"""Create a reusable variable locally.
|
||||
|
||||
Args:
|
||||
name (str): The name of the reusable variable.
|
||||
|
@ -233,13 +230,6 @@ class RayReusables(object):
|
|||
"""
|
||||
self._names.add(name)
|
||||
self._reinitializers[name] = reusable.reinitializer
|
||||
# Export the reusable variable to the workers if we are on the driver. If
|
||||
# ray.init has not been called yet, then cache the reusable variable to
|
||||
# export later.
|
||||
if _mode() in [SCRIPT_MODE, SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
elif _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
self._reusables[name] = reusable.initializer()
|
||||
# We create a second copy of the reusable variable on the driver to use
|
||||
# inside of remote functions that run locally. This occurs when we start Ray
|
||||
|
@ -313,11 +303,21 @@ class RayReusables(object):
|
|||
reusable = value
|
||||
if not issubclass(type(reusable), Reusable):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
# Create the reusable variable locally, and export it if possible.
|
||||
self._create_and_export(name, reusable)
|
||||
# If ray.init has not been called, cache the reusable variable to export
|
||||
# later. Otherwise, export the reusable variable to the workers and define
|
||||
# it locally.
|
||||
if _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
else:
|
||||
# If we are on the driver, export the reusable variable to all the
|
||||
# workers.
|
||||
if _mode() in [SCRIPT_MODE, SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
# Define the reusable variable locally.
|
||||
self._create_reusable_variable(name, reusable)
|
||||
# Create an empty attribute with the name of the reusable variable. This
|
||||
# allows the Python interpreter to do tab complete properly.
|
||||
return object.__setattr__(self, name, None)
|
||||
object.__setattr__(self, name, None)
|
||||
|
||||
def __delattr__(self, name):
|
||||
"""We do not allow attributes of RayReusables to be deleted.
|
||||
|
@ -482,21 +482,6 @@ class Worker(object):
|
|||
|
||||
return task.returns()
|
||||
|
||||
def export_function_to_run_on_all_workers(self, function):
|
||||
"""Export this function and run it on all workers.
|
||||
|
||||
Args:
|
||||
function (Callable): The function to run on all of the workers. It should
|
||||
not take any arguments. If it returns anything, its return values will
|
||||
not be used.
|
||||
"""
|
||||
if self.mode not in [SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# Run the function on all of the workers.
|
||||
if self.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
self.run_function_on_all_workers(function)
|
||||
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
|
@ -512,13 +497,14 @@ class Worker(object):
|
|||
"""
|
||||
if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
|
||||
raise Exception("run_function_on_all_workers can only be called on a driver.")
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# If ray.init has not been called yet, then cache the function and export it
|
||||
# when connect is called. Otherwise, run the function on all workers.
|
||||
if self.mode is None:
|
||||
self.cached_functions_to_run.append(function)
|
||||
else:
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# Run the function on all workers.
|
||||
function_to_run_id = random_string()
|
||||
key = "FunctionsToRun:{}".format(function_to_run_id)
|
||||
self.redis_client.hmset(key, {"function_id": function_to_run_id,
|
||||
|
@ -685,26 +671,20 @@ def print_error_messages(worker):
|
|||
num_errors_printed = 0
|
||||
|
||||
# Get the exports that occurred before the call to psubscribe.
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
with worker.lock:
|
||||
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
|
||||
for error_key in error_keys:
|
||||
error_message = worker.redis_client.hget(error_key, "message")
|
||||
print(error_message)
|
||||
num_errors_printed += 1
|
||||
finally:
|
||||
worker.lock.release()
|
||||
|
||||
try:
|
||||
for msg in worker.error_message_pubsub_client.listen():
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
with worker.lock:
|
||||
for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1):
|
||||
error_message = worker.redis_client.hget(error_key, "message")
|
||||
print(error_message)
|
||||
num_errors_printed += 1
|
||||
finally:
|
||||
worker.lock.release()
|
||||
except redis.ConnectionError:
|
||||
# When Redis terminates the listen call will throw a ConnectionError, which
|
||||
# we catch here.
|
||||
|
@ -712,9 +692,10 @@ def print_error_messages(worker):
|
|||
|
||||
def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
"""Import a remote function."""
|
||||
function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter = worker.redis_client.hmget(key, ["function_id", "name", "function", "num_return_vals", "module", "driver_export_counter"])
|
||||
function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter = worker.redis_client.hmget(key, ["function_id", "name", "function", "num_return_vals", "module", "function_export_counter"])
|
||||
function_id = photon.ObjectID(function_id_str)
|
||||
num_return_vals = int(num_return_vals)
|
||||
function_export_counter = int(function_export_counter)
|
||||
try:
|
||||
function = pickling.loads(serialized_function)
|
||||
except:
|
||||
|
@ -785,8 +766,7 @@ def import_thread(worker):
|
|||
worker.worker_import_counter = 0
|
||||
|
||||
# Get the exports that occurred before the call to psubscribe.
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
with worker.lock:
|
||||
export_keys = worker.redis_client.lrange("Exports", 0, -1)
|
||||
for key in export_keys:
|
||||
if key.startswith("RemoteFunction"):
|
||||
|
@ -799,12 +779,9 @@ def import_thread(worker):
|
|||
raise Exception("This code should be unreachable.")
|
||||
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
|
||||
worker.worker_import_counter += 1
|
||||
finally:
|
||||
worker.lock.release()
|
||||
|
||||
for msg in worker.import_pubsub_client.listen():
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
with worker.lock:
|
||||
if msg["type"] == "psubscribe":
|
||||
continue
|
||||
assert msg["data"] == "rpush"
|
||||
|
@ -822,8 +799,6 @@ def import_thread(worker):
|
|||
raise Exception("This code should be unreachable.")
|
||||
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
|
||||
worker.worker_import_counter += 1
|
||||
finally:
|
||||
worker.lock.release()
|
||||
|
||||
def connect(address_info, mode=WORKER_MODE, worker=global_worker):
|
||||
"""Connect this worker to the scheduler and an object store.
|
||||
|
@ -886,15 +861,25 @@ def connect(address_info, mode=WORKER_MODE, worker=global_worker):
|
|||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
|
||||
# TODO(rkn): Here we first export functions to run, then reusable variables,
|
||||
# then remote functions. The order matters. For example, one of the
|
||||
# functions to run may set the Python path, which is needed to import a
|
||||
# module used to define a reusable variable, which in turn is used inside a
|
||||
# remote function. We may want to change the order to simply be the order in
|
||||
# which the exports were defined on the driver. In addition, we will need to
|
||||
# retain the ability to decide what the first few exports are (mostly to set
|
||||
# the Python path). Additionally, note that the first exports to be defined
|
||||
# on the driver will be the ones defined in separate modules that are
|
||||
# imported by the driver.
|
||||
# Export cached functions_to_run.
|
||||
for function in worker.cached_functions_to_run:
|
||||
worker.export_function_to_run_on_all_workers(function)
|
||||
worker.run_function_on_all_workers(function)
|
||||
# Export cached reusable variables to the workers.
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
reusables.__setattr__(name, reusable_variable)
|
||||
# Export cached remote functions to the workers.
|
||||
for function_id, func_name, func, num_return_vals in worker.cached_remote_functions:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals, worker)
|
||||
# Export cached reusable variables to the workers.
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions = None
|
||||
reusables._cached_reusables = None
|
||||
|
@ -1098,19 +1083,13 @@ def main_loop(worker=global_worker):
|
|||
# Check that the number of imports we have is at least as great as the
|
||||
# export counter for the task. If not, wait until we have imported enough.
|
||||
while True:
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
if worker.functions.has_key(function_id.id()) and worker.function_export_counters[function_id.id()] <= worker.worker_import_counter:
|
||||
with worker.lock:
|
||||
if worker.functions.has_key(function_id.id()) and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter):
|
||||
break
|
||||
time.sleep(0.001)
|
||||
finally:
|
||||
worker.lock.release()
|
||||
# Execute the task.
|
||||
try:
|
||||
worker.lock.acquire()
|
||||
with worker.lock:
|
||||
process_task(task)
|
||||
finally:
|
||||
worker.lock.release()
|
||||
|
||||
def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
|
|
@ -301,7 +301,7 @@ class PlasmaClient(object):
|
|||
break
|
||||
return message_data
|
||||
|
||||
def start_plasma_manager(store_name, manager_name, redis_address, num_retries=5, use_valgrind=False, run_profiler=False):
|
||||
def start_plasma_manager(store_name, manager_name, redis_address, num_retries=20, use_valgrind=False, run_profiler=False):
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
|
||||
Args:
|
||||
|
@ -316,8 +316,7 @@ def start_plasma_manager(store_name, manager_name, redis_address, num_retries=5,
|
|||
listening on.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the manager could not be properly
|
||||
started.
|
||||
Exception: An exception is raised if the manager could not be started.
|
||||
"""
|
||||
plasma_manager_executable = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../build/plasma_manager")
|
||||
port = None
|
||||
|
|
|
@ -266,7 +266,13 @@ class APITest(unittest.TestCase):
|
|||
@ray.remote
|
||||
def f(x):
|
||||
return x + 10
|
||||
self.assertEqual(ray.get(f.remote(0)), 10)
|
||||
while True:
|
||||
val = ray.get(f.remote(0))
|
||||
self.assertTrue((val == 10) or (val == 1))
|
||||
if val == 10:
|
||||
break
|
||||
else:
|
||||
print("Still using old definition of f, trying again.")
|
||||
|
||||
# Test that we can close over plain old data.
|
||||
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
|
||||
|
@ -411,13 +417,19 @@ class APITest(unittest.TestCase):
|
|||
sys.path.append("fake_directory")
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
@ray.remote
|
||||
def get_path():
|
||||
def get_path1():
|
||||
return sys.path
|
||||
self.assertEqual("fake_directory", ray.get(get_path.remote())[-1])
|
||||
self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1])
|
||||
def f(worker):
|
||||
sys.path.pop(-1)
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
self.assertTrue("fake_directory" not in ray.get(get_path.remote()))
|
||||
# Create a second remote function to guarantee that when we call
|
||||
# get_path2.remote(), the second function to run will have been run on the
|
||||
# worker.
|
||||
@ray.remote
|
||||
def get_path2():
|
||||
return sys.path
|
||||
self.assertTrue("fake_directory" not in ray.get(get_path2.remote()))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
|
@ -483,47 +495,6 @@ class PythonModeTest(unittest.TestCase):
|
|||
|
||||
ray.worker.cleanup()
|
||||
|
||||
class PythonCExtensionTest(unittest.TestCase):
|
||||
|
||||
# def testReferenceCountNone(self):
|
||||
# ray.init(start_ray_local=True, num_workers=1)
|
||||
#
|
||||
# # Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
# @ray.remote
|
||||
# def f():
|
||||
# return sys.getrefcount(None)
|
||||
# first_count = ray.get(f.remote())
|
||||
# second_count = ray.get(f.remote())
|
||||
# self.assertEqual(first_count, second_count)
|
||||
#
|
||||
# ray.worker.cleanup()
|
||||
|
||||
def testReferenceCountTrue(self):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
@ray.remote
|
||||
def f():
|
||||
return sys.getrefcount(True)
|
||||
first_count = ray.get(f.remote())
|
||||
second_count = ray.get(f.remote())
|
||||
self.assertEqual(first_count, second_count)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testReferenceCountFalse(self):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
@ray.remote
|
||||
def f():
|
||||
return sys.getrefcount(False)
|
||||
first_count = ray.get(f.remote())
|
||||
second_count = ray.get(f.remote())
|
||||
self.assertEqual(first_count, second_count)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
class ReusablesTest(unittest.TestCase):
|
||||
|
||||
def testReusables(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue