[RLlib] Test against node failures, e.g. for the practical use of spot instances. (#26676)

Artur Niederfahrenhorst 2022-08-29 14:37:56 +02:00 committed by GitHub
parent 2ce80d8163
commit 51d16b8ff9
3 changed files with 236 additions and 7 deletions
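
For context, the new test exercises RLlib's worker fault-tolerance switches. As a rough sketch (illustrative only, not part of this diff; the option names are taken from the tests added below), a user would enable them on an algorithm config like this:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .rollouts(
        num_rollout_workers=6,
        # Keep training even if individual rollout workers die.
        ignore_worker_failures=True,
        # Re-create failed workers on the next training iteration.
        recreate_failed_workers=True,
    )
)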

rllib/BUILD

@@ -780,10 +780,17 @@ py_test(
 )
 
 py_test(
-    name = "tests/test_worker_failures",
-    tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
+    name = "test_worker_failures",
+    tags = ["team:rllib", "tests_dir", "tests_dir_W"],
     size = "large",
-    srcs = ["algorithms/tests/test_worker_failures.py"]
+    srcs = ["tests/test_worker_failures.py"]
 )
 
+py_test(
+    name = "test_node_failure",
+    tags = ["team:rllib", "tests_dir", "tests_dir_N", "exclusive"],
+    size = "large",
+    srcs = ["tests/test_node_failure.py"],
+)
+
 py_test(
rllib/tests/test_node_failure.py

@@ -0,0 +1,221 @@
# This workload tests RLlib's ability to recover from failing worker nodes.
import threading
import time
import unittest

import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.exceptions import RayActorError
from ray.rllib.algorithms.ppo import PPO, PPOConfig

num_redis_shards = 5
redis_max_memory = 10**8
object_store_memory = 10**8
num_nodes = 3

assert (
    num_nodes * object_store_memory + num_redis_shards * redis_max_memory
    < ray._private.utils.get_system_memory() / 2
), (
    "Make sure there is enough memory on this machine to run this "
    "workload. We divide the system memory by 2 to provide a buffer."
)
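
# With the constants above, the cluster reserves num_nodes * 10**8 bytes of
# object store memory plus num_redis_shards * 10**8 bytes of redis memory,
# i.e. 8 * 10**8 bytes (~0.8 GB) in total; the factor-two buffer in the
# assert therefore requires roughly 1.6 GB of system memory.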

class NodeFailureTests(unittest.TestCase):
    def setUp(self):
        # Simulate a cluster on one machine.
        self.cluster = Cluster()
        for i in range(num_nodes):
            # Only the head node (i == 0) runs redis.
            self.cluster.add_node(
                redis_port=6379 if i == 0 else None,
                num_redis_shards=num_redis_shards if i == 0 else None,
                num_cpus=2,
                num_gpus=0,
                object_store_memory=object_store_memory,
                redis_max_memory=redis_max_memory,
                dashboard_host="0.0.0.0",
            )
        self.cluster.wait_for_nodes()
        ray.init(address=self.cluster.address)

    def tearDown(self):
        ray.shutdown()
        self.cluster.shutdown()
    def test_fail_on_node_failure(self):
        # We do not tolerate failing workers and stop training.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=False,
                recreate_failed_workers=False,
            )
            .training()
        )

        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices: each node has 2 CPUs, so removing one
        # node takes down 2 of the 6 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Fail with a node down: resource requirements are no longer satisfied.
        with self.assertRaises(RayActorError):
            ppo.step()
    def test_continue_training_on_failure(self):
        # We tolerate failing workers and don't pause training.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=True,
                recreate_failed_workers=False,
            )
            .training()
        )

        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices: each node has 2 CPUs, so removing one
        # node takes down 2 of the 6 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # One step with a node down; resource requirements are no longer
        # satisfied, but training continues with the remaining workers.
        ppo.step()

        # Training should have proceeded without errors, but with two workers
        # missing.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 4)
    def test_recreate_workers_on_next_iter(self):
        # We tolerate failing workers and recreate them on the next iteration.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                recreate_failed_workers=True,
                validate_workers_after_construction=True,
            )
            .training()
        )

        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # The node comes back immediately.
        self.cluster.add_node(
            redis_port=None,
            num_redis_shards=None,
            num_cpus=2,
            num_gpus=0,
            object_store_memory=object_store_memory,
            redis_max_memory=redis_max_memory,
            dashboard_host="0.0.0.0",
        )

        # step() should restore the missing workers on the added node.
        ppo.step()

        # Workers should be back up, everything back to normal.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)
    def test_wait_for_nodes_on_failure(self):
        # We do not tolerate failing workers, but recreate them; training
        # pauses until the required resources are available again.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=False,
                recreate_failed_workers=True,
            )
            .training()
        )

        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices: each node has 2 CPUs, so removing one
        # node takes down 2 of the 6 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        def _step_target():
            # Resource requirements are satisfied again after approx. 30s,
            # once the replacement node below has been added.
            ppo.step()

        time_before_step = time.time()

        # Step in a background thread while the node is still down.
        t = threading.Thread(target=_step_target)
        t.start()

        # Wait 30 seconds before the missing node reappears.
        time.sleep(30)
        self.cluster.add_node(
            redis_port=None,
            num_redis_shards=None,
            num_cpus=2,
            num_gpus=0,
            object_store_memory=object_store_memory,
            redis_max_memory=redis_max_memory,
            dashboard_host="0.0.0.0",
        )
        t.join()

        td = time.time() - time_before_step
        # TODO: Find out what values make sense here.
        self.assertGreaterEqual(td, 30, msg="Stepped before node was added.")
        self.assertLess(td, 60, msg="Took too long to step after node was added.")

        # Workers should be back up.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))

rllib/tests/test_worker_failures.py

@@ -1,8 +1,9 @@
-from collections import defaultdict
-import gym
-import numpy as np
 import time
 import unittest
+
+from collections import defaultdict
+import gym
+import numpy as np
 
 import ray
 from ray.rllib.algorithms.pg import PG, PGConfig
@@ -50,7 +51,7 @@ class FaultInjectEnv(gym.Env):
 ...     {"bad_indices": [1, 2]},
 ...     worker_index=1,
 ...     num_workers=3,
-..  )
+... )
 ... )
 >>> from ray.rllib.env.env_context import EnvContext