From 51d16b8ff9db326b9dbb16f3eb357a7ae30018a5 Mon Sep 17 00:00:00 2001
From: Artur Niederfahrenhorst
Date: Mon, 29 Aug 2022 14:37:56 +0200
Subject: [PATCH] [RLlib] Test against failure of nodes, for example for practical use of spot instances. (#26676)

---
Review notes (free text between the three-dash line and the first file diff;
it is not part of the commit message or of the change itself):

 * Line structure was restored: one diff line per text line. The per-hunk
   line counts agree with the "@@" headers and with the diffstat below
   (13 / 221 / 9 changed lines; 236 insertions, 7 deletions).
 * Leading whitespace was rebuilt from Python and Bazel layout conventions.
   In the last hunk (FaultInjectEnv docstring) the spacing inside the
   doctest lines and the empty context line after "... )" are inferred;
   confirm them against the original blob before applying.
 * test_recreate_workers_on_next_iter checks the health-check result with a
   bare "assert", while the other tests use self.assertEqual for the same
   check.
 * In test_wait_for_nodes_on_failure the comment "kill one node after n
   seconds" sits above the code that starts the stepping thread; the node
   has already been removed earlier in that test.
 * The comments mentioning "a driver thread + 5 rollout workers" appear in
   tests that are configured with num_rollout_workers=6.
 * t.join() is called without a timeout, so the "td < 60" bound is only
   evaluated once ppo.step() has returned.
 * The header comment of test_node_failure.py reads "failing workers nodes".

 rllib/BUILD                                   |  13 +-
 rllib/tests/test_node_failure.py              | 221 ++++++++++++++++++
 .../tests/test_worker_failures.py             |   9 +-
 3 files changed, 236 insertions(+), 7 deletions(-)
 create mode 100644 rllib/tests/test_node_failure.py
 rename rllib/{algorithms => }/tests/test_worker_failures.py (99%)

diff --git a/rllib/BUILD b/rllib/BUILD
index df788946a..53d922365 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -780,10 +780,17 @@ py_test(
 )
 
 py_test(
-    name = "tests/test_worker_failures",
-    tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
+    name = "test_worker_failures",
+    tags = ["team:rllib", "tests_dir", "tests_dir_W"],
     size = "large",
-    srcs = ["algorithms/tests/test_worker_failures.py"]
+    srcs = ["tests/test_worker_failures.py"]
+)
+
+py_test(
+    name = "test_node_failure",
+    tags = ["team:rllib", "tests_dir", "tests_dir_N", "exclusive"],
+    size = "large",
+    srcs = ["tests/test_node_failure.py"],
 )
 
 py_test(
diff --git a/rllib/tests/test_node_failure.py b/rllib/tests/test_node_failure.py
new file mode 100644
index 000000000..d60792da0
--- /dev/null
+++ b/rllib/tests/test_node_failure.py
@@ -0,0 +1,221 @@
+# This workload tests RLlib's ability to recover from failing workers nodes
+import threading
+import time
+import unittest
+
+import ray
+from ray._private.test_utils import get_other_nodes
+from ray.cluster_utils import Cluster
+from ray.exceptions import RayActorError
+from ray.rllib.algorithms.ppo import PPO, PPOConfig
+
+num_redis_shards = 5
+redis_max_memory = 10 ** 8
+object_store_memory = 10 ** 8
+num_nodes = 3
+
+assert (
+    num_nodes * object_store_memory + num_redis_shards * redis_max_memory
+    < ray._private.utils.get_system_memory() / 2
+), (
+    "Make sure there is enough memory on this machine to run this "
+    "workload. We divide the system memory by 2 to provide a buffer."
+)
+
+
+class NodeFailureTests(unittest.TestCase):
+    def setUp(self):
+        # Simulate a cluster on one machine.
+        self.cluster = Cluster()
+
+        for i in range(num_nodes):
+            self.cluster.add_node(
+                redis_port=6379 if i == 0 else None,
+                num_redis_shards=num_redis_shards if i == 0 else None,
+                num_cpus=2,
+                num_gpus=0,
+                object_store_memory=object_store_memory,
+                redis_max_memory=redis_max_memory,
+                dashboard_host="0.0.0.0",
+            )
+        self.cluster.wait_for_nodes()
+        ray.init(address=self.cluster.address)
+
+    def tearDown(self):
+        ray.shutdown()
+        self.cluster.shutdown()
+
+    def test_fail_on_node_failure(self):
+        # We do not tolerate failing workers and stop training
+        config = (
+            PPOConfig()
+            .rollouts(
+                num_rollout_workers=6,
+                ignore_worker_failures=False,
+                recreate_failed_workers=False,
+            )
+            .training()
+        )
+        ppo = PPO(config=config, env="CartPole-v0")
+
+        # One step with all nodes up, enough to satisfy resource requirements
+        ppo.step()
+
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # Remove the first non-head node.
+        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
+        self.cluster.remove_node(node_to_kill)
+
+        # Check faulty worker indices
+        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers
+        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # Fail with a node down, resource requirements not satisfied anymore
+        with self.assertRaises(RayActorError):
+            ppo.step()
+
+    def test_continue_training_on_failure(self):
+        # We tolerate failing workers and don't pause training
+        config = (
+            PPOConfig()
+            .rollouts(
+                num_rollout_workers=6,
+                ignore_worker_failures=True,
+                recreate_failed_workers=False,
+            )
+            .training()
+        )
+        ppo = PPO(config=config, env="CartPole-v0")
+
+        # One step with all nodes up, enough to satisfy resource requirements
+        ppo.step()
+
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # Remove the first non-head node.
+        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
+        self.cluster.remove_node(node_to_kill)
+
+        # Check faulty worker indices
+        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers
+        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # One step with a node down, resource requirements not satisfied anymore
+        ppo.step()
+
+        # Training should have proceeded without errors, but two workers missing
+        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
+        self.assertEqual(len(ppo.workers._remote_workers), 4)
+
+    def test_recreate_workers_on_next_iter(self):
+        # We tolerate failing workers and pause training
+        config = (
+            PPOConfig()
+            .rollouts(
+                num_rollout_workers=6,
+                recreate_failed_workers=True,
+                validate_workers_after_construction=True,
+            )
+            .training()
+        )
+        ppo = PPO(config=config, env="CartPole-v0")
+
+        # One step with all nodes up, enough to satisfy resource requirements
+        ppo.step()
+
+        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # Remove the first non-head node.
+        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
+        self.cluster.remove_node(node_to_kill)
+
+        assert len(ppo.workers._worker_health_check()) == 2
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        # node comes back immediately.
+        self.cluster.add_node(
+            redis_port=None,
+            num_redis_shards=None,
+            num_cpus=2,
+            num_gpus=0,
+            object_store_memory=object_store_memory,
+            redis_max_memory=redis_max_memory,
+            dashboard_host="0.0.0.0",
+        )
+
+        # step() should restore the missing workers on the added node.
+        ppo.step()
+
+        # Workers should be back up, everything back to normal.
+        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+    def test_wait_for_nodes_on_failure(self):
+        # We tolerate failing workers and pause training
+        config = (
+            PPOConfig()
+            .rollouts(
+                num_rollout_workers=6,
+                ignore_worker_failures=False,
+                recreate_failed_workers=True,
+            )
+            .training()
+        )
+        ppo = PPO(config=config, env="CartPole-v0")
+
+        # One step with all nodes up, enough to satisfy resource requirements
+        ppo.step()
+
+        # Remove the first non-head node.
+        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
+        self.cluster.remove_node(node_to_kill)
+
+        # Check faulty worker indices
+        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers
+        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+        def _step_target():
+            # Resource requirements satisfied after approx 30s
+            ppo.step()
+
+        time_before_step = time.time()
+
+        # kill one node after n seconds
+        t = threading.Thread(target=_step_target)
+        t.start()
+
+        # Wait 30 seconds until the missing node reappears
+        time.sleep(30)
+        self.cluster.add_node(
+            redis_port=None,
+            num_redis_shards=None,
+            num_cpus=2,
+            num_gpus=0,
+            object_store_memory=object_store_memory,
+            redis_max_memory=redis_max_memory,
+            dashboard_host="0.0.0.0",
+        )
+
+        t.join()
+
+        td = time.time() - time_before_step
+
+        # TODO: Find out what values make sense here
+        self.assertGreaterEqual(td, 30, msg="Stepped before node was added.")
+        self.assertLess(td, 60, msg="Took too long to step after node was added.")
+
+        # Workers should be back up
+        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
+        self.assertEqual(len(ppo.workers._remote_workers), 6)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_worker_failures.py b/rllib/tests/test_worker_failures.py
similarity index 99%
rename from rllib/algorithms/tests/test_worker_failures.py
rename to rllib/tests/test_worker_failures.py
index fabddd9fc..71077afd8 100644
--- a/rllib/algorithms/tests/test_worker_failures.py
+++ b/rllib/tests/test_worker_failures.py
@@ -1,8 +1,9 @@
-from collections import defaultdict
-import gym
-import numpy as np
 import time
 import unittest
+from collections import defaultdict
+
+import gym
+import numpy as np
 
 import ray
 from ray.rllib.algorithms.pg import PG, PGConfig
@@ -50,7 +51,7 @@ class FaultInjectEnv(gym.Env):
         ...         {"bad_indices": [1, 2]},
         ...         worker_index=1,
         ...         num_workers=3,
-        ..      )
+        ...     )
         ... )
 
         >>> from ray.rllib.env.env_context import EnvContext