diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 4db5d7e25..3eadcec52 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -221,6 +221,7 @@ py_test_module_list( py_test_module_list( files = [ "test_failure_3.py", + "test_chaos.py", "test_reference_counting_2.py", ], size = "large", diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 62a2a81ed..7a5edb60a 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -7,13 +7,18 @@ import pytest import subprocess import json import time +import threading + +import grpc import ray -from ray.cluster_utils import Cluster +from ray.cluster_utils import Cluster, AutoscalingCluster from ray._private.services import REDIS_EXECUTABLE, _start_redis_instance from ray._private.test_utils import init_error_pubsub, setup_tls, teardown_tls import ray.util.client.server.server as ray_client_server import ray._private.gcs_utils as gcs_utils +from ray.core.generated import node_manager_pb2 +from ray.core.generated import node_manager_pb2_grpc @pytest.fixture @@ -397,3 +402,72 @@ def unstable_spilling_config(request, tmp_path): ]) def slow_spilling_config(request, tmp_path): yield create_object_spilling_config(request, tmp_path) + + +@pytest.fixture +def ray_start_chaos_cluster(request): + """Returns the cluster and chaos thread. + + Run chaos_thread.start() to start the chaos testing. + NOTE: `cluster` is not thread-safe. `cluster` + shouldn't be modified by other thread once + chaos_thread.start() is called. + """ + os.environ["RAY_num_heartbeats_timeout"] = "5" + os.environ["RAY_raylet_heartbeat_period_milliseconds"] = "100" + param = getattr(request, "param", {}) + kill_interval = param.get("kill_interval", 2) + # Config of workers that are re-started. + head_resources = param["head_resources"] + worker_node_types = param["worker_node_types"] + timeout = param["timeout"] + + # Use the shutdown RPC instead of signals because we can't + # raise a signal in a non-main thread. + def kill_raylet(ip, port, graceful=False): + raylet_address = f"{ip}:{port}" + channel = grpc.insecure_channel(raylet_address) + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) + print(f"Sending a shutdown request to {ip}:{port}") + stub.ShutdownRaylet( + node_manager_pb2.ShutdownRayletRequest(graceful=graceful)) + + cluster = AutoscalingCluster(head_resources, worker_node_types) + cluster.start() + ray.init("auto") + nodes = ray.nodes() + assert len(nodes) == 1 + head_node_port = nodes[0]["NodeManagerPort"] + killed_port = set() + + def run_chaos_cluster(): + start = time.time() + while True: + node_to_kill_ip = None + node_to_kill_port = None + for node in ray.nodes(): + addr = node["NodeManagerAddress"] + port = node["NodeManagerPort"] + if (node["Alive"] and port != head_node_port + and port not in killed_port): + node_to_kill_ip = addr + node_to_kill_port = port + break + + if node_to_kill_port is not None: + kill_raylet(node_to_kill_ip, node_to_kill_port, graceful=False) + killed_port.add(node_to_kill_port) + time.sleep(kill_interval) + print(len(ray.nodes())) + if time.time() - start > timeout: + break + assert len(killed_port) > 0, ( + "None of nodes are killed by the conftest. It is a bug.") + + chaos_thread = threading.Thread(target=run_chaos_cluster) + yield chaos_thread + chaos_thread.join() + ray.shutdown() + cluster.shutdown() + del os.environ["RAY_num_heartbeats_timeout"] + del os.environ["RAY_raylet_heartbeat_period_milliseconds"] diff --git a/python/ray/tests/test_chaos.py b/python/ray/tests/test_chaos.py new file mode 100644 index 000000000..2e7e7e00e --- /dev/null +++ b/python/ray/tests/test_chaos.py @@ -0,0 +1,138 @@ +import sys +import random +import string + +import ray + +import numpy as np +import pytest +import time + +from ray.data.impl.progress_bar import ProgressBar +from ray._private.test_utils import get_all_log_message + + +def assert_no_system_failure(p, total_lines, timeout): + # Get logs for 20 seconds. + logs = get_all_log_message(p, total_lines, timeout=timeout) + for log in logs: + assert "SIG" not in log, ("There's the segfault or SIGBART reported.") + assert "Check failed" not in log, ( + "There's the check failure reported.") + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@pytest.mark.parametrize( + "ray_start_chaos_cluster", [{ + "kill_interval": 3, + "timeout": 45, + "head_resources": { + "CPU": 0 + }, + "worker_node_types": { + "cpu_node": { + "resources": { + "CPU": 8, + }, + "node_config": {}, + "min_workers": 0, + "max_workers": 4, + }, + }, + }], + indirect=True) +def test_chaos_task_retry(ray_start_chaos_cluster, log_pubsub): + chaos_test_thread = ray_start_chaos_cluster + p = log_pubsub + chaos_test_thread.start() + + # Chaos testing. + @ray.remote(max_retries=-1) + def task(): + def generate_data(size_in_kb=10): + return np.zeros(1024 * size_in_kb, dtype=np.uint8) + + a = "" + for _ in range(100000): + a = a + random.choice(string.ascii_letters) + return generate_data(size_in_kb=50) + + @ray.remote(max_retries=-1) + def invoke_nested_task(): + time.sleep(0.8) + return ray.get(task.remote()) + + # 50MB of return values. + TOTAL_TASKS = 300 + + pb = ProgressBar("Chaos test sanity check", TOTAL_TASKS) + results = [invoke_nested_task.remote() for _ in range(TOTAL_TASKS)] + start = time.time() + pb.block_until_complete(results) + runtime_with_failure = time.time() - start + print(f"Runtime when there are many failures: {runtime_with_failure}") + pb.close() + + chaos_test_thread.join() + assert_no_system_failure(p, 10000, 10) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") +@pytest.mark.parametrize( + "ray_start_chaos_cluster", [{ + "kill_interval": 30, + "timeout": 30, + "head_resources": { + "CPU": 0 + }, + "worker_node_types": { + "cpu_node": { + "resources": { + "CPU": 8, + }, + "node_config": {}, + "min_workers": 0, + "max_workers": 4, + }, + }, + }], + indirect=True) +def test_chaos_actor_retry(ray_start_chaos_cluster, log_pubsub): + chaos_test_thread = ray_start_chaos_cluster + # p = log_pubsub + chaos_test_thread.start() + + # Chaos testing. + @ray.remote(num_cpus=1, max_restarts=-1, max_task_retries=-1) + class Actor: + def __init__(self): + self.letter_dict = set() + + def add(self, letter): + self.letter_dict.add(letter) + + def get(self): + return self.letter_dict + + NUM_CPUS = 32 + TOTAL_TASKS = 300 + + pb = ProgressBar("Chaos test sanity check", TOTAL_TASKS * NUM_CPUS) + actors = [Actor.remote() for _ in range(NUM_CPUS)] + results = [] + for a in actors: + results.extend([a.add.remote(str(i)) for i in range(TOTAL_TASKS)]) + start = time.time() + pb.fetch_until_complete(results) + runtime_with_failure = time.time() - start + print(f"Runtime when there are many failures: {runtime_with_failure}") + pb.close() + chaos_test_thread.join() + # TODO(sang): Currently, there are lots of SIGBART with + # plasma client failures. Fix it. + # assert_no_system_failure(p, 10000, 10) + + +if __name__ == "__main__": + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_failure_3.py b/python/ray/tests/test_failure_3.py index 4673e0c40..9c362cf1a 100644 --- a/python/ray/tests/test_failure_3.py +++ b/python/ray/tests/test_failure_3.py @@ -1,8 +1,5 @@ -import threading import os import sys -import random -import string import ray @@ -11,7 +8,6 @@ import pytest import time from ray._private.test_utils import SignalActor -from ray.data.impl.progress_bar import ProgressBar @pytest.mark.parametrize( @@ -110,70 +106,6 @@ def test_async_actor_task_retries(ray_start_regular): assert ray.get(ref_3) == 3 -@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.") -def test_task_retry_mini_integration(ray_start_cluster): - """Test nested tasks with infinite retry and - keep killing nodes while retrying is happening. - - It is the sanity check test for larger scale chaos testing. - """ - cluster = ray_start_cluster - NUM_NODES = 3 - NUM_CPUS = 8 - # head node. - cluster.add_node(num_cpus=0, resources={"head": 1}) - ray.init(address=cluster.address) - workers = [] - for _ in range(NUM_NODES): - workers.append( - cluster.add_node(num_cpus=NUM_CPUS, resources={"worker": 1})) - - @ray.remote(max_retries=-1, resources={"worker": 0.1}) - def task(): - def generate_data(size_in_kb=10): - return np.zeros(1024 * size_in_kb, dtype=np.uint8) - - a = "" - for _ in range(100000): - a = a + random.choice(string.ascii_letters) - return generate_data(size_in_kb=50) - - @ray.remote(max_retries=-1, resources={"worker": 0.1}) - def invoke_nested_task(): - time.sleep(0.8) - return ray.get(task.remote()) - - # 50MB of return values. - TOTAL_TASKS = 500 - - def run_chaos_test(): - # Chaos testing. - pb = ProgressBar("Chaos test sanity check", TOTAL_TASKS) - results = [invoke_nested_task.remote() for _ in range(TOTAL_TASKS)] - start = time.time() - pb.block_until_complete(results) - runtime_with_failure = time.time() - start - print(f"Runtime when there are many failures: {runtime_with_failure}") - pb.close() - - x = threading.Thread(target=run_chaos_test) - x.start() - - kill_interval = 2 - start = time.time() - while True: - worker_to_kill = workers.pop(0) - pid = worker_to_kill.all_processes["raylet"][0].process.pid - # SIGKILL - os.kill(pid, 9) - workers.append( - cluster.add_node(num_cpus=NUM_CPUS, resources={"worker": 1})) - time.sleep(kill_interval) - if time.time() - start > 30: - break - x.join() - - if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__]))