[RLlib] Test against failure of nodes, for example for practical use of spot instances. (#26676)

commit 51d16b8ff9 (parent 2ce80d8163)
3 changed files with 236 additions and 7 deletions
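All of the scenarios below revolve around the same rollout-worker fault-tolerance flags. As a minimal sketch (not part of the commit; it assumes the RLlib config API this commit targets, with values taken from the new tests), the tested modes differ only in how the flags are combined:

    from ray.rllib.algorithms.ppo import PPO, PPOConfig

    # Fault-tolerance modes exercised by the new tests:
    #   ignore_worker_failures=False, recreate_failed_workers=False:
    #       a lost worker aborts training (RayActorError).
    #   ignore_worker_failures=True, recreate_failed_workers=False:
    #       training continues on the surviving workers.
    #   recreate_failed_workers=True:
    #       failed workers are restored on a subsequent iteration.
    config = (
        PPOConfig()
        .rollouts(
            num_rollout_workers=6,
            ignore_worker_failures=False,
            recreate_failed_workers=True,
        )
    )
    ppo = PPO(config=config, env="CartPole-v0")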
rllib/BUILD: 13 changed lines
@@ -780,10 +780,17 @@ py_test(
 )
 
 py_test(
-    name = "tests/test_worker_failures",
-    tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
+    name = "test_worker_failures",
+    tags = ["team:rllib", "tests_dir", "tests_dir_W"],
     size = "large",
-    srcs = ["algorithms/tests/test_worker_failures.py"]
+    srcs = ["tests/test_worker_failures.py"]
 )
 
+py_test(
+    name = "test_node_failure",
+    tags = ["team:rllib", "tests_dir", "tests_dir_N", "exclusive"],
+    size = "large",
+    srcs = ["tests/test_node_failure.py"],
+)
+
 py_test(
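Assuming a standard Bazel setup of the repository, the new target can then be run in isolation, e.g. with: bazel test //rllib:test_node_failure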
rllib/tests/test_node_failure.py: new file, 221 lines
@@ -0,0 +1,221 @@
# This workload tests RLlib's ability to recover from failing worker nodes
import threading
import time
import unittest

import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.exceptions import RayActorError
from ray.rllib.algorithms.ppo import PPO, PPOConfig

num_redis_shards = 5
redis_max_memory = 10 ** 8
object_store_memory = 10 ** 8
num_nodes = 3

assert (
    num_nodes * object_store_memory + num_redis_shards * redis_max_memory
    < ray._private.utils.get_system_memory() / 2
), (
    "Make sure there is enough memory on this machine to run this "
    "workload. We divide the system memory by 2 to provide a buffer."
)


class NodeFailureTests(unittest.TestCase):
    def setUp(self):
        # Simulate a cluster on one machine.
        self.cluster = Cluster()

        for i in range(num_nodes):
            self.cluster.add_node(
                redis_port=6379 if i == 0 else None,
                num_redis_shards=num_redis_shards if i == 0 else None,
                num_cpus=2,
                num_gpus=0,
                object_store_memory=object_store_memory,
                redis_max_memory=redis_max_memory,
                dashboard_host="0.0.0.0",
            )
        self.cluster.wait_for_nodes()
        ray.init(address=self.cluster.address)

    def tearDown(self):
        ray.shutdown()
        self.cluster.shutdown()

    def test_fail_on_node_failure(self):
        # We do not tolerate failing workers and stop training.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=False,
                recreate_failed_workers=False,
            )
            .training()
        )
        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()

        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices.
        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Fail with a node down, resource requirements not satisfied anymore.
        with self.assertRaises(RayActorError):
            ppo.step()

    def test_continue_training_on_failure(self):
        # We tolerate failing workers and don't pause training.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=True,
                recreate_failed_workers=False,
            )
            .training()
        )
        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()

        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices.
        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # One step with a node down, resource requirements not satisfied anymore.
        ppo.step()

        # Training should have proceeded without errors, but two workers are missing.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 4)

    def test_recreate_workers_on_next_iter(self):
        # We tolerate failing workers and pause training until they are recreated.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                recreate_failed_workers=True,
                validate_workers_after_construction=True,
            )
            .training()
        )
        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()

        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        # The node comes back immediately.
        self.cluster.add_node(
            redis_port=None,
            num_redis_shards=None,
            num_cpus=2,
            num_gpus=0,
            object_store_memory=object_store_memory,
            redis_max_memory=redis_max_memory,
            dashboard_host="0.0.0.0",
        )

        # step() should restore the missing workers on the added node.
        ppo.step()

        # Workers should be back up, everything back to normal.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

    def test_wait_for_nodes_on_failure(self):
        # We do not ignore failing workers, but recreate them, so training
        # pauses until the workers are back.
        config = (
            PPOConfig()
            .rollouts(
                num_rollout_workers=6,
                ignore_worker_failures=False,
                recreate_failed_workers=True,
            )
            .training()
        )
        ppo = PPO(config=config, env="CartPole-v0")

        # One step with all nodes up, enough to satisfy resource requirements.
        ppo.step()

        # Remove the first non-head node.
        node_to_kill = get_other_nodes(self.cluster, exclude_head=True)[0]
        self.cluster.remove_node(node_to_kill)

        # Check faulty worker indices.
        # All nodes have 2 CPUs and we require a driver thread + 5 rollout workers.
        self.assertEqual(len(ppo.workers._worker_health_check()), 2)
        self.assertEqual(len(ppo.workers._remote_workers), 6)

        def _step_target():
            # Resource requirements are only satisfied again after approx. 30s,
            # once the missing node has been re-added below.
            ppo.step()

        time_before_step = time.time()

        # Run step() in a background thread; it should block until the
        # missing node is back.
        t = threading.Thread(target=_step_target)
        t.start()

        # Wait 30 seconds, then bring the missing node back.
        time.sleep(30)
        self.cluster.add_node(
            redis_port=None,
            num_redis_shards=None,
            num_cpus=2,
            num_gpus=0,
            object_store_memory=object_store_memory,
            redis_max_memory=redis_max_memory,
            dashboard_host="0.0.0.0",
        )

        t.join()

        td = time.time() - time_before_step

        # TODO: Find out what values make sense here.
        self.assertGreaterEqual(td, 30, msg="Stepped before node was added.")
        self.assertLess(td, 60, msg="Took too long to step after node was added.")

        # Workers should be back up.
        self.assertEqual(len(ppo.workers._worker_health_check()), 0)
        self.assertEqual(len(ppo.workers._remote_workers), 6)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
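For orientation, the kill-and-restore cycle that all four tests share can be condensed as follows (a sketch, not part of the commit; it reuses only the Cluster and test-utility calls that appear in the file above):

    import ray
    from ray._private.test_utils import get_other_nodes
    from ray.cluster_utils import Cluster

    # Simulate a three-node cluster on one machine.
    cluster = Cluster()
    for i in range(3):
        cluster.add_node(num_cpus=2)
    cluster.wait_for_nodes()
    ray.init(address=cluster.address)

    # Take down one non-head node (the "spot instance" going away) ...
    node = get_other_nodes(cluster, exclude_head=True)[0]
    cluster.remove_node(node)

    # ... and later add a replacement node back.
    cluster.add_node(num_cpus=2)

    ray.shutdown()
    cluster.shutdown()

As the __main__ block shows, the file can also be run directly, e.g. with: python -m pytest -v rllib/tests/test_node_failure.py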
rllib/tests/test_worker_failures.py

@@ -1,8 +1,9 @@
-from collections import defaultdict
-import gym
-import numpy as np
 import time
 import unittest
+from collections import defaultdict
+
+import gym
+import numpy as np
 
 import ray
 from ray.rllib.algorithms.pg import PG, PGConfig
@@ -50,7 +51,7 @@ class FaultInjectEnv(gym.Env):
     ...     {"bad_indices": [1, 2]},
     ...     worker_index=1,
     ...     num_workers=3,
-    .. )
+    ... )
     ... )
 
     >>> from ray.rllib.env.env_context import EnvContext
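This hunk only fixes a typo in the doctest of the pre-existing FaultInjectEnv helper (".. )" to "... )"). As a sketch of what the corrected snippet amounts to (the call shape is taken from the doctest itself), the env is configured through an EnvContext:

    from ray.rllib.env.env_context import EnvContext

    # Worker indices 1 and 2 are marked "bad"; this context describes
    # worker 1 of 3, which would therefore inject a fault.
    ctx = EnvContext({"bad_indices": [1, 2]}, worker_index=1, num_workers=3)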