From ca913ff6d66ad4ec76424892cffc2ad76f6cc3a7 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 30 Jun 2022 13:25:22 +0200 Subject: [PATCH] [RLlib] Eval WorkerSet crashes when trying to re-add a failed worker (eval set does not have local worker). (#26134) --- rllib/algorithms/algorithm.py | 7 +- .../apex_dqn/tests/test_apex_dqn.py | 16 +- .../algorithms/tests/test_worker_failures.py | 152 +++++++++++++----- rllib/evaluation/worker_set.py | 23 ++- 4 files changed, 147 insertions(+), 51 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 28863e341..97390bddc 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -2105,7 +2105,9 @@ class Algorithm(Trainable): removed_workers, new_workers = [], [] # Search for failed workers and try to recover (restart) them. if recreate: - removed_workers, new_workers = worker_set.recreate_failed_workers() + removed_workers, new_workers = worker_set.recreate_failed_workers( + local_worker_for_synching=self.workers.local_worker() + ) elif ignore: removed_workers = worker_set.remove_failed_workers() @@ -2396,6 +2398,9 @@ class Algorithm(Trainable): # Evaluation results. if "evaluation" in iteration_results: results["evaluation"] = iteration_results.pop("evaluation") + results["evaluation"]["num_healthy_workers"] = len( + self.evaluation_workers.remote_workers() + ) # Custom metrics and episode media. results["custom_metrics"] = iteration_results.pop("custom_metrics", {}) diff --git a/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py b/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py index a88e70441..7d97a4ddf 100644 --- a/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py +++ b/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py @@ -15,7 +15,7 @@ from ray.rllib.utils.test_utils import ( class TestApexDQN(unittest.TestCase): def setUp(self): - ray.init(num_cpus=4) + ray.init(num_cpus=6) def tearDown(self): ray.shutdown() @@ -130,30 +130,30 @@ class TestApexDQN(unittest.TestCase): ) ) - def _step_n_times(trainer, n: int): + def _step_n_times(algo, n: int): """Step trainer n times. Returns: learning rate at the end of the execution. """ for _ in range(n): - results = trainer.train() + results = algo.train() return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ "cur_lr" ] - for _ in framework_iterator(config): - trainer = config.build(env="CartPole-v0") + for _ in framework_iterator(config, frameworks=("torch", "tf")): + algo = config.build(env="CartPole-v0") - lr = _step_n_times(trainer, 3) # 50 timesteps + lr = _step_n_times(algo, 3) # 50 timesteps # Close to 0.2 self.assertGreaterEqual(lr, 0.1) - lr = _step_n_times(trainer, 20) # 200 timesteps + lr = _step_n_times(algo, 20) # 200 timesteps # LR Annealed to 0.001 self.assertLessEqual(lr, 0.0011) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/tests/test_worker_failures.py b/rllib/algorithms/tests/test_worker_failures.py index 2c12eadd4..dbef7d211 100644 --- a/rllib/algorithms/tests/test_worker_failures.py +++ b/rllib/algorithms/tests/test_worker_failures.py @@ -1,9 +1,7 @@ +import gym import unittest -import gym - import ray -from ray.rllib import _register_all from ray.rllib.algorithms.registry import get_algorithm_class from ray.rllib.utils.test_utils import framework_iterator from ray.tune.registry import register_env @@ -23,6 +21,13 @@ class FaultInjectEnv(gym.Env): >>> bad_env = FaultInjectEnv( ... EnvContext({"bad_indices": [1, 2]}, ... worker_index=1, num_workers=3)) + + >>> from ray.rllib.env.env_context import EnvContext + >>> # This env will fail only on the first evaluation worker, not on the first + >>> # regular rollout worker. + >>> bad_env = FaultInjectEnv( + ... EnvContext({"bad_indices": [1], "eval_only": True}, + ... worker_index=2, num_workers=5)) """ def __init__(self, config): @@ -39,12 +44,12 @@ class FaultInjectEnv(gym.Env): # Only fail on the original workers with the specified indices. # Once on a recreated worker, don't fail anymore. if ( - self.config.worker_index in self.config["bad_indices"] + self.config.worker_index in self.config.get("bad_indices", []) and not self.config.recreated_worker ): raise ValueError( "This is a simulated error from " - f"worker-idx={self.config.worker_index}." + f"worker-idx={self.config.worker_index}!" ) return self.env.step(action) @@ -58,53 +63,90 @@ class TestWorkerFailure(unittest.TestCase): def tearDownClass(cls) -> None: ray.shutdown() - def do_test(self, alg: str, config: dict, fn=None): + def do_test(self, alg: str, config: dict, fn=None, eval_only=False): fn = fn or self._do_test_fault_ignore - try: - fn(alg, config) - finally: - _register_all() # re-register the evicted objects + fn(alg, config, eval_only) - def _do_test_fault_ignore(self, algo: str, config: dict): + def _do_test_fault_ignore(self, algo: str, config: dict, eval_only: bool = False): register_env("fault_env", lambda c: FaultInjectEnv(c)) algo_cls = get_algorithm_class(algo) # Test fault handling - config["num_workers"] = 2 - config["ignore_worker_failures"] = True - # Make worker idx=1 fail. Other workers will be ok. - config["env_config"] = {"bad_indices": [1]} + if not eval_only: + config["num_workers"] = 2 + config["ignore_worker_failures"] = True + # Make worker idx=1 fail. Other workers will be ok. + config["env_config"] = {"bad_indices": [1]} + else: + config["num_workers"] = 1 + config["evaluation_num_workers"] = 2 + config["evaluation_interval"] = 1 + config["evaluation_config"] = { + "ignore_worker_failures": True, + "env_config": { + # Make worker idx=1 fail. Other workers will be ok. + "bad_indices": [1], + }, + } for _ in framework_iterator(config, frameworks=("tf2", "torch")): algo = algo_cls(config=config, env="fault_env") result = algo.train() - self.assertTrue(result["num_healthy_workers"], 1) + if not eval_only: + self.assertTrue(result["num_healthy_workers"] == 1) + else: + self.assertTrue(result["num_healthy_workers"] == 1) + self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1) algo.stop() - def _do_test_fault_fatal(self, alg, config): + def _do_test_fault_fatal(self, alg, config, eval_only=False): register_env("fault_env", lambda c: FaultInjectEnv(c)) agent_cls = get_algorithm_class(alg) - # Test raises real error when out of workers - config["num_workers"] = 2 - config["ignore_worker_failures"] = True - # Make both worker idx=1 and 2 fail. - config["env_config"] = {"bad_indices": [1, 2]} + # Test raises real error when out of workers. + if not eval_only: + config["num_workers"] = 2 + config["ignore_worker_failures"] = True + # Make both worker idx=1 and 2 fail. + config["env_config"] = {"bad_indices": [1, 2]} + else: + config["num_workers"] = 1 + config["evaluation_num_workers"] = 1 + config["evaluation_interval"] = 1 + config["evaluation_config"] = { + "ignore_worker_failures": True, + # Make eval worker (index 1) fail. + "env_config": { + "bad_indices": [1], + }, + } for _ in framework_iterator(config, frameworks=("torch", "tf")): a = agent_cls(config=config, env="fault_env") self.assertRaises(Exception, lambda: a.train()) a.stop() - def _do_test_fault_fatal_but_recreate(self, alg, config): + def _do_test_fault_fatal_but_recreate(self, alg, config, eval_only=False): register_env("fault_env", lambda c: FaultInjectEnv(c)) agent_cls = get_algorithm_class(alg) - # Test raises real error when out of workers - config["num_workers"] = 2 - config["recreate_failed_workers"] = True - # Make both worker idx=1 and 2 fail. - config["env_config"] = {"bad_indices": [1, 2]} + # Test raises real error when out of workers. + if not eval_only: + config["num_workers"] = 2 + config["recreate_failed_workers"] = True + # Make both worker idx=1 and 2 fail. + config["env_config"] = {"bad_indices": [1, 2]} + else: + config["num_workers"] = 1 + config["evaluation_num_workers"] = 1 + config["evaluation_interval"] = 1 + config["evaluation_config"] = { + "recreate_failed_workers": True, + # Make eval worker (index 1) fail. + "env_config": { + "bad_indices": [1], + }, + } for _ in framework_iterator(config, frameworks=("tf2", "torch")): a = agent_cls(config=config, env="fault_env") @@ -121,21 +163,29 @@ class TestWorkerFailure(unittest.TestCase): ) ) result = a.train() - self.assertTrue(result["num_healthy_workers"], 2) - self.assertTrue( - all( - ray.get( - worker.apply.remote( - lambda w: w.recreated_worker - and w.env_context.recreated_worker + if not eval_only: + self.assertTrue(result["num_healthy_workers"] == 2) + self.assertTrue( + all( + ray.get( + worker.apply.remote( + lambda w: w.recreated_worker + and w.env_context.recreated_worker + ) ) + for worker in a.workers.remote_workers() ) - for worker in a.workers.remote_workers() ) - ) + else: + self.assertTrue(result["num_healthy_workers"] == 1) + self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1) # This should also work several times. result = a.train() - self.assertTrue(result["num_healthy_workers"], 2) + if not eval_only: + self.assertTrue(result["num_healthy_workers"] == 2) + else: + self.assertTrue(result["num_healthy_workers"] == 1) + self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1) a.stop() def test_fatal(self): @@ -188,6 +238,32 @@ class TestWorkerFailure(unittest.TestCase): def test_async_sampling_option(self): self.do_test("PG", {"optimizer": {}, "sample_async": True}) + def test_eval_workers_failing_ignore(self): + # Test the case where one eval worker fails, but we chose to ignore. + self.do_test( + "PG", + config={"model": {"fcnet_hiddens": [4]}}, + eval_only=True, + ) + + def test_eval_workers_failing_recreate(self): + # Test the case where all eval workers fail, but we chose to recover. + self.do_test( + "PG", + config={"model": {"fcnet_hiddens": [4]}}, + fn=self._do_test_fault_fatal_but_recreate, + eval_only=True, + ) + + def test_eval_workers_failing_fatal(self): + # Test the case where all eval workers fail (w/o recovery). + self.do_test( + "PG", + config={"model": {"fcnet_hiddens": [4]}}, + fn=self._do_test_fault_fatal, + eval_only=True, + ) + if __name__ == "__main__": import sys diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 330ed90cf..811fe5f8f 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -306,7 +306,19 @@ class WorkerSet: ) return removed_workers - def recreate_failed_workers(self) -> Tuple[List[ActorHandle], List[ActorHandle]]: + def recreate_failed_workers( + self, local_worker_for_synching: RolloutWorker + ) -> Tuple[List[ActorHandle], List[ActorHandle]]: + """Recreates any failed workers (after health check). + + Args: + local_worker_for_synching: RolloutWorker to use to synchronize the weights + after recreation. + + Returns: + A tuple consisting of two items: The list of removed workers and the list of + newly added ones. + """ faulty_indices = self._worker_health_check() removed_workers = [] new_workers = [] @@ -329,14 +341,17 @@ class WorkerSet: recreated_worker=True, config=self._remote_config, ) - # Sync new worker from local one. + + # Sync new worker from provided one (or local one). new_worker.set_weights.remote( - weights=self.local_worker().get_weights(), - global_vars=self.local_worker().get_global_vars(), + weights=local_worker_for_synching.get_weights(), + global_vars=local_worker_for_synching.get_global_vars(), ) + # Add new worker to list of remote workers. self._remote_workers[worker_index - 1] = new_worker new_workers.append(new_worker) + return removed_workers, new_workers def stop(self) -> None: