Print warning when defining very large remote function or actor. (#2179)

* Print warning when defining very large remote function or actor.

* Add weak test.

* Check that warnings appear in test.

* Make wait_for_errors actually fail in failure_test.py.

* Use constants for error types.

* Fix
Authored by Robert Nishihara on 2018-06-09 19:59:15 -07:00; committed by Philipp Moritz
parent 1475600c81
commit 125fe1c09c
7 changed files with 161 additions and 56 deletions
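
For context, the warning this commit adds surfaces on the driver as a pushed error. A minimal sketch of triggering it (a hypothetical driver script, assuming a local cluster; the 10 MB threshold and the error type come from ray_constants below):

import time

import numpy as np
import ray

ray.init(num_workers=1)

# Capturing a ~160 MB array in the closure makes the pickled function large.
big_array = np.zeros(2 * 10**7)

@ray.remote
def uses_big_array():
    return big_array.sum()

# Defining the remote function exports it, which pushes a
# "pickling_large_object" error; give it a moment to land in Redis.
time.sleep(1)
print(ray.error_info())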

diff --git a/python/ray/actor.py b/python/ray/actor.py

@@ -10,6 +10,7 @@ import traceback

 import ray.cloudpickle as pickle
 import ray.local_scheduler
+import ray.ray_constants as ray_constants
 import ray.signature as signature
 import ray.worker
 from ray.utils import _random_string, is_cython, push_error_to_driver
@@ -164,7 +165,7 @@ def save_and_log_checkpoint(worker, actor):
         # Log the error message.
         ray.utils.push_error_to_driver(
             worker.redis_client,
-            "checkpoint",
+            ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
             driver_id=worker.task_driver_id.id(),
             data={
@@ -188,7 +189,7 @@ def restore_and_log_checkpoint(worker, actor):
         # Log the error message.
         ray.utils.push_error_to_driver(
             worker.redis_client,
-            "checkpoint",
+            ray_constants.CHECKPOINT_PUSH_ERROR,
             traceback_str,
             driver_id=worker.task_driver_id.id(),
             data={
@@ -330,7 +331,7 @@ def fetch_and_register_actor(actor_class_key, worker):
         # Log the error message.
         push_error_to_driver(
             worker.redis_client,
-            "register_actor_signatures",
+            ray_constants.REGISTER_ACTOR_PUSH_ERROR,
             traceback_str,
             driver_id,
             data={"actor_id": actor_id_str})
@@ -392,6 +393,20 @@ def export_actor_class(class_id, Class, actor_method_names,
         "actor_method_names": json.dumps(list(actor_method_names))
     }

+    if (len(actor_class_info["class"]) >
+            ray_constants.PICKLE_OBJECT_WARNING_SIZE):
+        warning_message = ("Warning: The actor {} has size {} when pickled. "
+                           "It will be stored in Redis, which could cause "
+                           "memory issues. This may mean that the actor "
+                           "definition uses a large array or other object."
+                           .format(actor_class_info["class_name"],
+                                   len(actor_class_info["class"])))
+        ray.utils.push_error_to_driver(
+            worker.redis_client,
+            ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
+            warning_message,
+            driver_id=worker.task_driver_id.id())
+
     if worker.mode is None:
         # This means that 'ray.init()' has not been called yet and so we must
         # cache the actor class definition and export it when 'ray.init()' is
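
For a quick driver-side estimate of whether a definition will trip the check above, the same measurement can be reproduced with ray.cloudpickle (the pickled_size helper is ours, for illustration; it is not part of the commit):

import numpy as np
import ray.cloudpickle as pickle

PICKLE_OBJECT_WARNING_SIZE = 10**7  # mirrors ray_constants below

def pickled_size(obj):
    # Serialized size in bytes, measured the same way the export paths do.
    return len(pickle.dumps(obj))

print(pickled_size(np.zeros(2 * 10**7)) > PICKLE_OBJECT_WARNING_SIZE)  # True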

diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py

@@ -12,6 +12,27 @@ def env_integer(key, default):
     return default


+# Different types of Ray errors that can be pushed to the driver.
+# TODO(rkn): These should be defined in flatbuffers and must be synced with
+# the existing C++ definitions.
+WAIT_FOR_CLASS_PUSH_ERROR = "wait_for_class"
+PICKLING_LARGE_OBJECT_PUSH_ERROR = "pickling_large_object"
+WAIT_FOR_FUNCTION_PUSH_ERROR = "wait_for_function"
+TASK_PUSH_ERROR = "task"
+REGISTER_REMOTE_FUNCTION_PUSH_ERROR = "register_remote_function"
+FUNCTION_TO_RUN_PUSH_ERROR = "function_to_run"
+VERSION_MISMATCH_PUSH_ERROR = "version_mismatch"
+CHECKPOINT_PUSH_ERROR = "checkpoint"
+REGISTER_ACTOR_PUSH_ERROR = "register_actor"
+WORKER_CRASH_PUSH_ERROR = "worker_crash"
+WORKER_DIED_PUSH_ERROR = "worker_died"
+PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction"
+HASH_MISMATCH_PUSH_ERROR = "object_hash_mismatch"
+
+# If a remote function or actor (or some other export) has serialized size
+# greater than this quantity, print a warning.
+PICKLE_OBJECT_WARNING_SIZE = 10**7
+
 # Abort autoscaling if more than this number of errors are encountered. This
 # is a safety feature to prevent e.g. runaway node launches.
 AUTOSCALER_MAX_NUM_FAILURES = env_integer("AUTOSCALER_MAX_NUM_FAILURES", 5)
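
Only the tail of env_integer is visible in this hunk. A plausible reading of the helper, consistent with the AUTOSCALER_MAX_NUM_FAILURES call above (the body is an assumption, not shown in the diff):

import os

def env_integer(key, default):
    # Prefer an integer override from the environment, else the default.
    if key in os.environ:
        return int(os.environ[key])
    return default

AUTOSCALER_MAX_NUM_FAILURES = env_integer("AUTOSCALER_MAX_NUM_FAILURES", 5)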

diff --git a/python/ray/worker.py b/python/ray/worker.py

@@ -28,6 +28,7 @@ import ray.services as services
 import ray.signature
 import ray.local_scheduler
 import ray.plasma
+import ray.ray_constants as ray_constants
 from ray.utils import random_string, binary_to_hex, is_cython

 # Import flatbuffer bindings.
@@ -415,7 +416,7 @@ class Worker(object):
                 if not warning_sent:
                     ray.utils.push_error_to_driver(
                         self.redis_client,
-                        "wait_for_class",
+                        ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
                         warning_message,
                         driver_id=self.task_driver_id.id())
                     warning_sent = True
@@ -637,6 +638,19 @@ class Worker(object):
         else:
             del function.__globals__[function.__name__]

+        if len(pickled_function) > ray_constants.PICKLE_OBJECT_WARNING_SIZE:
+            warning_message = ("Warning: The remote function {} has size {} "
+                               "when pickled. It will be stored in Redis, "
+                               "which could cause memory issues. This may "
+                               "mean that the function definition uses a "
+                               "large array or other object.".format(
+                                   function_name, len(pickled_function)))
+            ray.utils.push_error_to_driver(
+                self.redis_client,
+                ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
+                warning_message,
+                driver_id=self.task_driver_id.id())
+
         self.redis_client.hmset(
             key, {
                 "driver_id": self.task_driver_id.id(),
@@ -684,6 +698,22 @@ class Worker(object):
             # In this case, the function has already been exported, so
             # we don't need to export it again.
             return

+        if (len(pickled_function) >
+                ray_constants.PICKLE_OBJECT_WARNING_SIZE):
+            warning_message = ("Warning: The function {} has size {} when "
+                               "pickled. It will be stored in Redis, "
+                               "which could cause memory issues. This may "
+                               "mean that the remote function definition "
+                               "uses a large array or other object."
+                               .format(function.__name__,
+                                       len(pickled_function)))
+            ray.utils.push_error_to_driver(
+                self.redis_client,
+                ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
+                warning_message,
+                driver_id=self.task_driver_id.id())
+
         # Run the function on all workers.
         self.redis_client.hmset(
             key, {
@@ -735,7 +765,7 @@ class Worker(object):
                 if not warning_sent:
                     ray.utils.push_error_to_driver(
                         self.redis_client,
-                        "wait_for_function",
+                        ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
                         warning_message,
                         driver_id=driver_id)
                     warning_sent = True
@@ -896,7 +926,7 @@ class Worker(object):
             # Log the error message.
             ray.utils.push_error_to_driver(
                 self.redis_client,
-                "task",
+                ray_constants.TASK_PUSH_ERROR,
                 str(failure_object),
                 driver_id=self.task_driver_id.id(),
                 data={
@@ -1132,6 +1162,11 @@ def error_info(worker=global_worker):
     for error_key in error_keys:
         if error_applies_to_driver(error_key, worker=worker):
             error_contents = worker.redis_client.hgetall(error_key)
+            error_contents = {
+                "type": error_contents[b"type"].decode("ascii"),
+                "message": error_contents[b"message"].decode("ascii"),
+                "data": error_contents[b"data"].decode("ascii")
+            }
             errors.append(error_contents)

     return errors
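
With this change, each entry returned by ray.error_info() is a plain string-keyed dict rather than raw Redis bytes, which is what lets the tests below drop the b"..." keys and .decode("ascii") calls. An illustrative entry (values made up):

error = {
    "type": "task",  # one of the *_PUSH_ERROR strings from ray_constants
    "message": "Traceback (most recent call last): ...",
    "data": "{}",
}

# Callers can now filter directly on decoded strings.
assert error["type"] == "task"
assert "Traceback" in error["message"]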
@@ -1823,7 +1858,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
         # Log the error message.
         ray.utils.push_error_to_driver(
             worker.redis_client,
-            "register_remote_function",
+            ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
             traceback_str,
             driver_id=driver_id,
             data={
@@ -1868,7 +1903,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
                 and hasattr(function, "__name__")) else ""
         ray.utils.push_error_to_driver(
             worker.redis_client,
-            "function_to_run",
+            ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
             traceback_str,
             driver_id=driver_id,
             data={"name": name})
@@ -2028,7 +2063,7 @@ def connect(info,
         traceback_str = traceback.format_exc()
         ray.utils.push_error_to_driver(
             worker.redis_client,
-            "version_mismatch",
+            ray_constants.VERSION_MISMATCH_PUSH_ERROR,
             traceback_str,
             driver_id=None)
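
The size-check-and-push pattern now appears three times (the actor export in actor.py plus the two function-export paths above). A hypothetical helper that would fold them together; this is a sketch of a possible refactor, not code from this commit:

def warn_if_oversized_pickle(worker, pickled, name, kind="remote function"):
    # Hypothetical: push a PICKLING_LARGE_OBJECT_PUSH_ERROR warning when a
    # pickled export exceeds PICKLE_OBJECT_WARNING_SIZE.
    import ray.ray_constants as ray_constants
    import ray.utils

    if len(pickled) <= ray_constants.PICKLE_OBJECT_WARNING_SIZE:
        return
    warning_message = ("Warning: The {} {} has size {} when pickled. It will "
                       "be stored in Redis, which could cause memory issues."
                       .format(kind, name, len(pickled)))
    ray.utils.push_error_to_driver(
        worker.redis_client,
        ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
        warning_message,
        driver_id=worker.task_driver_id.id())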

diff --git a/test/actor_test.py b/test/actor_test.py

@@ -11,6 +11,7 @@ import time
 import unittest

 import ray
+import ray.ray_constants as ray_constants
 import ray.test.test_utils
@@ -1569,7 +1570,8 @@ class ActorReconstruction(unittest.TestCase):
         errors = ray.error_info()
         self.assertLess(0, len(errors))
         for error in errors:
-            self.assertEqual(error[b"type"], b"checkpoint")
+            self.assertEqual(error["type"],
+                             ray_constants.CHECKPOINT_PUSH_ERROR)

     @unittest.skipIf(
         os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
@@ -1599,7 +1601,8 @@ class ActorReconstruction(unittest.TestCase):
         errors = ray.error_info()
         self.assertEqual(len(errors), 1)
         for error in errors:
-            self.assertEqual(error[b"type"], b"checkpoint")
+            self.assertEqual(error["type"],
+                             ray_constants.CHECKPOINT_PUSH_ERROR)

     @unittest.skip("Fork/join consistency not yet implemented.")
     def testDistributedHandle(self):

diff --git a/test/failure_test.py b/test/failure_test.py

@@ -10,6 +10,7 @@ import tempfile
 import time
 import unittest

+import ray.ray_constants as ray_constants
 import ray.test.test_functions as test_functions

 if sys.version_info >= (3, 0):
@@ -17,7 +18,7 @@ if sys.version_info >= (3, 0):

 def relevant_errors(error_type):
-    return [info for info in ray.error_info() if info[b"type"] == error_type]
+    return [info for info in ray.error_info() if info["type"] == error_type]


 def wait_for_errors(error_type, num_errors, timeout=10):
@@ -26,7 +27,7 @@ def wait_for_errors(error_type, num_errors, timeout=10):
         if len(relevant_errors(error_type)) >= num_errors:
             return
         time.sleep(0.1)
-    print("Timing out of wait.")
+    raise Exception("Timing out of wait.")


 class TaskStatusTest(unittest.TestCase):
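
Only the interior of wait_for_errors is visible across these hunks. A plausible reconstruction of the whole helper after this change, relying on relevant_errors above (the loop header is an assumption, not shown in the diff):

import time

def wait_for_errors(error_type, num_errors, timeout=10):
    start_time = time.time()
    # Poll the driver's error table until enough errors of the given type
    # arrive; on timeout, fail loudly instead of just printing.
    while time.time() - start_time < timeout:
        if len(relevant_errors(error_type)) >= num_errors:
            return
        time.sleep(0.1)
    raise Exception("Timing out of wait.")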
@@ -39,11 +40,12 @@ class TaskStatusTest(unittest.TestCase):
         test_functions.throw_exception_fct1.remote()
         test_functions.throw_exception_fct1.remote()

-        wait_for_errors(b"task", 2)
-        self.assertEqual(len(relevant_errors(b"task")), 2)
-        for task in relevant_errors(b"task"):
-            self.assertIn(b"Test function 1 intentionally failed.",
-                          task.get(b"message"))
+        wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2)
+        self.assertEqual(
+            len(relevant_errors(ray_constants.TASK_PUSH_ERROR)), 2)
+        for task in relevant_errors(ray_constants.TASK_PUSH_ERROR):
+            self.assertIn("Test function 1 intentionally failed.",
+                          task.get("message"))

         x = test_functions.throw_exception_fct2.remote()
         try:
@@ -100,9 +102,9 @@ def temporary_helper_function():
         def g():
             return module.temporary_python_file()

-        wait_for_errors(b"register_remote_function", 2)
-        self.assertIn(b"No module named", ray.error_info()[0][b"message"])
-        self.assertIn(b"No module named", ray.error_info()[1][b"message"])
+        wait_for_errors(ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, 2)
+        self.assertIn("No module named", ray.error_info()[0]["message"])
+        self.assertIn("No module named", ray.error_info()[1]["message"])

         # Check that if we try to call the function it throws an exception and
         # does not hang.
@@ -122,13 +124,13 @@ def temporary_helper_function():
             raise Exception("Function to run failed.")

         ray.worker.global_worker.run_function_on_all_workers(f)
-        wait_for_errors(b"function_to_run", 2)
+        wait_for_errors(ray_constants.FUNCTION_TO_RUN_PUSH_ERROR, 2)

         # Check that the error message is in the task info.
         self.assertEqual(len(ray.error_info()), 2)
-        self.assertIn(b"Function to run failed.",
-                      ray.error_info()[0][b"message"])
-        self.assertIn(b"Function to run failed.",
-                      ray.error_info()[1][b"message"])
+        self.assertIn("Function to run failed.",
+                      ray.error_info()[0]["message"])
+        self.assertIn("Function to run failed.",
+                      ray.error_info()[1]["message"])

     def testFailImportingActor(self):
         ray.init(num_workers=2, driver_mode=ray.SILENT_MODE)
@@ -165,14 +167,14 @@ def temporary_helper_function():
         foo = Foo.remote()

         # Wait for the error to arrive.
-        wait_for_errors(b"register_actor", 1)
-        self.assertIn(b"No module named", ray.error_info()[0][b"message"])
+        wait_for_errors(ray_constants.REGISTER_ACTOR_PUSH_ERROR, 1)
+        self.assertIn("No module named", ray.error_info()[0]["message"])

         # Wait for the error from when the __init__ tries to run.
-        wait_for_errors(b"task", 1)
+        wait_for_errors(ray_constants.TASK_PUSH_ERROR, 1)
         self.assertIn(
-            b"failed to be imported, and so cannot execute this method",
-            ray.error_info()[1][b"message"])
+            "failed to be imported, and so cannot execute this method",
+            ray.error_info()[1]["message"])

         # Check that if we try to get the function it throws an exception and
         # does not hang.
@@ -180,10 +182,10 @@ def temporary_helper_function():
             ray.get(foo.get_val.remote())

         # Wait for the error from when the call to get_val.
-        wait_for_errors(b"task", 2)
+        wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2)
         self.assertIn(
-            b"failed to be imported, and so cannot execute this method",
-            ray.error_info()[2][b"message"])
+            "failed to be imported, and so cannot execute this method",
+            ray.error_info()[2]["message"])

         f.close()
@@ -215,17 +217,15 @@ class ActorTest(unittest.TestCase):
         a = FailedActor.remote()

         # Make sure that we get errors from a failed constructor.
-        wait_for_errors(b"task", 1)
+        wait_for_errors(ray_constants.TASK_PUSH_ERROR, 1)
         self.assertEqual(len(ray.error_info()), 1)
-        self.assertIn(error_message1,
-                      ray.error_info()[0][b"message"].decode("ascii"))
+        self.assertIn(error_message1, ray.error_info()[0]["message"])

         # Make sure that we get errors from a failed method.
         a.fail_method.remote()
-        wait_for_errors(b"task", 2)
+        wait_for_errors(ray_constants.TASK_PUSH_ERROR, 2)
         self.assertEqual(len(ray.error_info()), 2)
-        self.assertIn(error_message2,
-                      ray.error_info()[1][b"message"].decode("ascii"))
+        self.assertIn(error_message2, ray.error_info()[1]["message"])

     def testIncorrectMethodCalls(self):
         ray.init(num_workers=0, driver_mode=ray.SILENT_MODE)
@@ -283,8 +283,8 @@ class WorkerDeath(unittest.TestCase):
         # the task has successfully completed.
         f.remote()

-        wait_for_errors(b"worker_crash", 1)
-        wait_for_errors(b"worker_died", 1)
+        wait_for_errors(ray_constants.WORKER_CRASH_PUSH_ERROR, 1)
+        wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)
         self.assertEqual(len(ray.error_info()), 2)

     @unittest.skipIf(
@@ -300,11 +300,11 @@ class WorkerDeath(unittest.TestCase):
             f.remote()

-        wait_for_errors(b"worker_died", 1)
+        wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)

         self.assertEqual(len(ray.error_info()), 1)
         self.assertIn("died or was killed while executing the task",
-                      ray.error_info()[0][b"message"].decode("ascii"))
+                      ray.error_info()[0]["message"])

     @unittest.skipIf(
         os.environ.get("RAY_USE_XRAY") == "1",
@@ -325,7 +325,7 @@ class WorkerDeath(unittest.TestCase):
         [obj], _ = ray.wait([a.kill.remote()], timeout=5000)
         self.assertRaises(Exception, lambda: ray.get(obj))
         self.assertRaises(Exception, lambda: ray.get(consume.remote(obj)))
-        wait_for_errors(b"worker_died", 1)
+        wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)

     @unittest.skipIf(
         os.environ.get("RAY_USE_XRAY") == "1",
@@ -350,7 +350,7 @@ class WorkerDeath(unittest.TestCase):
         for obj in tasks1 + tasks2:
             self.assertRaises(Exception, lambda: ray.get(obj))

-        wait_for_errors(b"worker_died", 1)
+        wait_for_errors(ray_constants.WORKER_DIED_PUSH_ERROR, 1)

     @unittest.skipIf(
         os.environ.get("RAY_USE_XRAY") == "1",
@@ -419,7 +419,7 @@ class PutErrorTest(unittest.TestCase):
         put_arg_task.remote()

         # Make sure we receive the correct error message.
-        wait_for_errors(b"put_reconstruction", 1)
+        wait_for_errors(ray_constants.PUT_RECONSTRUCTION_PUSH_ERROR, 1)

     def testPutError2(self):
         # This is the same as the previous test, but it calls ray.put directly.
@@ -465,7 +465,7 @@ class PutErrorTest(unittest.TestCase):
         put_task.remote()

         # Make sure we receive the correct error message.
-        wait_for_errors(b"put_reconstruction", 1)
+        wait_for_errors(ray_constants.PUT_RECONSTRUCTION_PUSH_ERROR, 1)


 class ConfigurationTest(unittest.TestCase):
@@ -478,10 +478,39 @@ class ConfigurationTest(unittest.TestCase):
         ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)

-        wait_for_errors(b"version_mismatch", 1)
+        wait_for_errors(ray_constants.VERSION_MISMATCH_PUSH_ERROR, 1)

         ray.__version__ = ray_version


+class WarningTest(unittest.TestCase):
+    def tearDown(self):
+        ray.worker.cleanup()
+
+    def testExportLargeObjects(self):
+        import ray.ray_constants as ray_constants
+        ray.init(num_workers=1)
+
+        large_object = np.zeros(2 * ray_constants.PICKLE_OBJECT_WARNING_SIZE)
+
+        @ray.remote
+        def f():
+            large_object
+
+        # Make sure that a warning is generated.
+        wait_for_errors(ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, 1)
+
+        @ray.remote
+        class Foo(object):
+            def __init__(self):
+                large_object
+
+        Foo.remote()
+
+        # Make sure that a warning is generated.
+        wait_for_errors(ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, 2)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
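
To run just the new test class in isolation, something along these lines should work (the module name failure_test is assumed from this diff's layout, e.g. when run from the test/ directory):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "failure_test.WarningTest")
unittest.TextTestRunner(verbosity=2).run(suite)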

diff --git a/test/multi_node_test.py b/test/multi_node_test.py

@@ -68,8 +68,7 @@ class MultiNodeTest(unittest.TestCase):
         # Make sure we got the error.
         self.assertEqual(len(ray.error_info()), 1)
-        self.assertIn(error_string1,
-                      ray.error_info()[0][b"message"].decode("ascii"))
+        self.assertIn(error_string1, ray.error_info()[0]["message"])

         # Start another driver and make sure that it does not receive this
         # error. Make the other driver throw an error, and make sure it
@@ -97,7 +96,7 @@ while len(ray.error_info()) != 1:
     time.sleep(0.1)
 assert len(ray.error_info()) == 1
-assert "{}" in ray.error_info()[0][b"message"].decode("ascii")
+assert "{}" in ray.error_info()[0]["message"]

 print("success")
 """.format(self.redis_address, error_string2, error_string2)
@@ -109,8 +108,7 @@ print("success")
         # Make sure that the other error message doesn't show up for this
         # driver.
         self.assertEqual(len(ray.error_info()), 1)
-        self.assertIn(error_string1,
-                      ray.error_info()[0][b"message"].decode("ascii"))
+        self.assertIn(error_string1, ray.error_info()[0]["message"])

     def testRemoteFunctionIsolation(self):
         # This test will run multiple remote functions with the same names in
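
The middle hunk above edits a driver script that the test assembles as a format string and runs in a separate process. A self-contained sketch of that pattern (the address and message are placeholders, not values from the test):

import subprocess
import sys

driver_script = """
import time
import ray
ray.init(redis_address="{}")
while len(ray.error_info()) != 1:
    time.sleep(0.1)
assert "{}" in ray.error_info()[0]["message"]
print("success")
""".format("127.0.0.1:6379", "some error message")

# The test launches such a script in a fresh process and looks for "success".
out = subprocess.check_output([sys.executable, "-c", driver_script])
assert b"success" in out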

diff --git a/test/stress_tests.py b/test/stress_tests.py

@@ -8,6 +8,8 @@ import ray
 import numpy as np
 import time

+import ray.ray_constants as ray_constants
+

 class TaskTests(unittest.TestCase):
     def testSubmittingTasks(self):
@@ -451,7 +453,8 @@ class ReconstructionTests(unittest.TestCase):
         errors = self.wait_for_errors(error_check)
         # Make sure all the errors have the correct type.
         self.assertTrue(
-            all(error[b"type"] == b"object_hash_mismatch" for error in errors))
+            all(error["type"] == ray_constants.HASH_MISMATCH_PUSH_ERROR
+                for error in errors))

     @unittest.skipIf(
         os.environ.get('RAY_USE_NEW_GCS', False), "Hanging with new GCS API.")
@@ -497,7 +500,8 @@ class ReconstructionTests(unittest.TestCase):
         errors = self.wait_for_errors(error_check)
         self.assertTrue(
-            all(error[b"type"] == b"put_reconstruction" for error in errors))
+            all(error["type"] == ray_constants.PUT_RECONSTRUCTION_PUSH_ERROR
+                for error in errors))


 class ReconstructionTestsMultinode(ReconstructionTests):