mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Eval WorkerSet crashes when trying to re-add a failed worker (eval set does not have local worker). (#26134)
This commit is contained in:
parent
d83bbda281
commit
ca913ff6d6
4 changed files with 147 additions and 51 deletions
|
@ -2105,7 +2105,9 @@ class Algorithm(Trainable):
|
||||||
removed_workers, new_workers = [], []
|
removed_workers, new_workers = [], []
|
||||||
# Search for failed workers and try to recover (restart) them.
|
# Search for failed workers and try to recover (restart) them.
|
||||||
if recreate:
|
if recreate:
|
||||||
removed_workers, new_workers = worker_set.recreate_failed_workers()
|
removed_workers, new_workers = worker_set.recreate_failed_workers(
|
||||||
|
local_worker_for_synching=self.workers.local_worker()
|
||||||
|
)
|
||||||
elif ignore:
|
elif ignore:
|
||||||
removed_workers = worker_set.remove_failed_workers()
|
removed_workers = worker_set.remove_failed_workers()
|
||||||
|
|
||||||
|
@ -2396,6 +2398,9 @@ class Algorithm(Trainable):
|
||||||
# Evaluation results.
|
# Evaluation results.
|
||||||
if "evaluation" in iteration_results:
|
if "evaluation" in iteration_results:
|
||||||
results["evaluation"] = iteration_results.pop("evaluation")
|
results["evaluation"] = iteration_results.pop("evaluation")
|
||||||
|
results["evaluation"]["num_healthy_workers"] = len(
|
||||||
|
self.evaluation_workers.remote_workers()
|
||||||
|
)
|
||||||
|
|
||||||
# Custom metrics and episode media.
|
# Custom metrics and episode media.
|
||||||
results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
|
results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
|
||||||
|
|
|
@ -15,7 +15,7 @@ from ray.rllib.utils.test_utils import (
|
||||||
|
|
||||||
class TestApexDQN(unittest.TestCase):
|
class TestApexDQN(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
ray.init(num_cpus=4)
|
ray.init(num_cpus=6)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
@ -130,30 +130,30 @@ class TestApexDQN(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _step_n_times(trainer, n: int):
|
def _step_n_times(algo, n: int):
|
||||||
"""Step trainer n times.
|
"""Step trainer n times.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
learning rate at the end of the execution.
|
learning rate at the end of the execution.
|
||||||
"""
|
"""
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
|
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
|
||||||
"cur_lr"
|
"cur_lr"
|
||||||
]
|
]
|
||||||
|
|
||||||
for _ in framework_iterator(config):
|
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
|
|
||||||
lr = _step_n_times(trainer, 3) # 50 timesteps
|
lr = _step_n_times(algo, 3) # 50 timesteps
|
||||||
# Close to 0.2
|
# Close to 0.2
|
||||||
self.assertGreaterEqual(lr, 0.1)
|
self.assertGreaterEqual(lr, 0.1)
|
||||||
|
|
||||||
lr = _step_n_times(trainer, 20) # 200 timesteps
|
lr = _step_n_times(algo, 20) # 200 timesteps
|
||||||
# LR Annealed to 0.001
|
# LR Annealed to 0.001
|
||||||
self.assertLessEqual(lr, 0.0011)
|
self.assertLessEqual(lr, 0.0011)
|
||||||
|
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
|
import gym
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import gym
|
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray.rllib import _register_all
|
|
||||||
from ray.rllib.algorithms.registry import get_algorithm_class
|
from ray.rllib.algorithms.registry import get_algorithm_class
|
||||||
from ray.rllib.utils.test_utils import framework_iterator
|
from ray.rllib.utils.test_utils import framework_iterator
|
||||||
from ray.tune.registry import register_env
|
from ray.tune.registry import register_env
|
||||||
|
@ -23,6 +21,13 @@ class FaultInjectEnv(gym.Env):
|
||||||
>>> bad_env = FaultInjectEnv(
|
>>> bad_env = FaultInjectEnv(
|
||||||
... EnvContext({"bad_indices": [1, 2]},
|
... EnvContext({"bad_indices": [1, 2]},
|
||||||
... worker_index=1, num_workers=3))
|
... worker_index=1, num_workers=3))
|
||||||
|
|
||||||
|
>>> from ray.rllib.env.env_context import EnvContext
|
||||||
|
>>> # This env will fail only on the first evaluation worker, not on the first
|
||||||
|
>>> # regular rollout worker.
|
||||||
|
>>> bad_env = FaultInjectEnv(
|
||||||
|
... EnvContext({"bad_indices": [1], "eval_only": True},
|
||||||
|
... worker_index=2, num_workers=5))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
@ -39,12 +44,12 @@ class FaultInjectEnv(gym.Env):
|
||||||
# Only fail on the original workers with the specified indices.
|
# Only fail on the original workers with the specified indices.
|
||||||
# Once on a recreated worker, don't fail anymore.
|
# Once on a recreated worker, don't fail anymore.
|
||||||
if (
|
if (
|
||||||
self.config.worker_index in self.config["bad_indices"]
|
self.config.worker_index in self.config.get("bad_indices", [])
|
||||||
and not self.config.recreated_worker
|
and not self.config.recreated_worker
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This is a simulated error from "
|
"This is a simulated error from "
|
||||||
f"worker-idx={self.config.worker_index}."
|
f"worker-idx={self.config.worker_index}!"
|
||||||
)
|
)
|
||||||
return self.env.step(action)
|
return self.env.step(action)
|
||||||
|
|
||||||
|
@ -58,53 +63,90 @@ class TestWorkerFailure(unittest.TestCase):
|
||||||
def tearDownClass(cls) -> None:
|
def tearDownClass(cls) -> None:
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
def do_test(self, alg: str, config: dict, fn=None):
|
def do_test(self, alg: str, config: dict, fn=None, eval_only=False):
|
||||||
fn = fn or self._do_test_fault_ignore
|
fn = fn or self._do_test_fault_ignore
|
||||||
try:
|
fn(alg, config, eval_only)
|
||||||
fn(alg, config)
|
|
||||||
finally:
|
|
||||||
_register_all() # re-register the evicted objects
|
|
||||||
|
|
||||||
def _do_test_fault_ignore(self, algo: str, config: dict):
|
def _do_test_fault_ignore(self, algo: str, config: dict, eval_only: bool = False):
|
||||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||||
algo_cls = get_algorithm_class(algo)
|
algo_cls = get_algorithm_class(algo)
|
||||||
|
|
||||||
# Test fault handling
|
# Test fault handling
|
||||||
config["num_workers"] = 2
|
if not eval_only:
|
||||||
config["ignore_worker_failures"] = True
|
config["num_workers"] = 2
|
||||||
# Make worker idx=1 fail. Other workers will be ok.
|
config["ignore_worker_failures"] = True
|
||||||
config["env_config"] = {"bad_indices": [1]}
|
# Make worker idx=1 fail. Other workers will be ok.
|
||||||
|
config["env_config"] = {"bad_indices": [1]}
|
||||||
|
else:
|
||||||
|
config["num_workers"] = 1
|
||||||
|
config["evaluation_num_workers"] = 2
|
||||||
|
config["evaluation_interval"] = 1
|
||||||
|
config["evaluation_config"] = {
|
||||||
|
"ignore_worker_failures": True,
|
||||||
|
"env_config": {
|
||||||
|
# Make worker idx=1 fail. Other workers will be ok.
|
||||||
|
"bad_indices": [1],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
|
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
|
||||||
algo = algo_cls(config=config, env="fault_env")
|
algo = algo_cls(config=config, env="fault_env")
|
||||||
result = algo.train()
|
result = algo.train()
|
||||||
self.assertTrue(result["num_healthy_workers"], 1)
|
if not eval_only:
|
||||||
|
self.assertTrue(result["num_healthy_workers"] == 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(result["num_healthy_workers"] == 1)
|
||||||
|
self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
|
||||||
algo.stop()
|
algo.stop()
|
||||||
|
|
||||||
def _do_test_fault_fatal(self, alg, config):
|
def _do_test_fault_fatal(self, alg, config, eval_only=False):
|
||||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||||
agent_cls = get_algorithm_class(alg)
|
agent_cls = get_algorithm_class(alg)
|
||||||
|
|
||||||
# Test raises real error when out of workers
|
# Test raises real error when out of workers.
|
||||||
config["num_workers"] = 2
|
if not eval_only:
|
||||||
config["ignore_worker_failures"] = True
|
config["num_workers"] = 2
|
||||||
# Make both worker idx=1 and 2 fail.
|
config["ignore_worker_failures"] = True
|
||||||
config["env_config"] = {"bad_indices": [1, 2]}
|
# Make both worker idx=1 and 2 fail.
|
||||||
|
config["env_config"] = {"bad_indices": [1, 2]}
|
||||||
|
else:
|
||||||
|
config["num_workers"] = 1
|
||||||
|
config["evaluation_num_workers"] = 1
|
||||||
|
config["evaluation_interval"] = 1
|
||||||
|
config["evaluation_config"] = {
|
||||||
|
"ignore_worker_failures": True,
|
||||||
|
# Make eval worker (index 1) fail.
|
||||||
|
"env_config": {
|
||||||
|
"bad_indices": [1],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
||||||
a = agent_cls(config=config, env="fault_env")
|
a = agent_cls(config=config, env="fault_env")
|
||||||
self.assertRaises(Exception, lambda: a.train())
|
self.assertRaises(Exception, lambda: a.train())
|
||||||
a.stop()
|
a.stop()
|
||||||
|
|
||||||
def _do_test_fault_fatal_but_recreate(self, alg, config):
|
def _do_test_fault_fatal_but_recreate(self, alg, config, eval_only=False):
|
||||||
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
register_env("fault_env", lambda c: FaultInjectEnv(c))
|
||||||
agent_cls = get_algorithm_class(alg)
|
agent_cls = get_algorithm_class(alg)
|
||||||
|
|
||||||
# Test raises real error when out of workers
|
# Test that failed workers get recreated and training continues.
|
||||||
config["num_workers"] = 2
|
if not eval_only:
|
||||||
config["recreate_failed_workers"] = True
|
config["num_workers"] = 2
|
||||||
# Make both worker idx=1 and 2 fail.
|
config["recreate_failed_workers"] = True
|
||||||
config["env_config"] = {"bad_indices": [1, 2]}
|
# Make both worker idx=1 and 2 fail.
|
||||||
|
config["env_config"] = {"bad_indices": [1, 2]}
|
||||||
|
else:
|
||||||
|
config["num_workers"] = 1
|
||||||
|
config["evaluation_num_workers"] = 1
|
||||||
|
config["evaluation_interval"] = 1
|
||||||
|
config["evaluation_config"] = {
|
||||||
|
"recreate_failed_workers": True,
|
||||||
|
# Make eval worker (index 1) fail.
|
||||||
|
"env_config": {
|
||||||
|
"bad_indices": [1],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
|
for _ in framework_iterator(config, frameworks=("tf2", "torch")):
|
||||||
a = agent_cls(config=config, env="fault_env")
|
a = agent_cls(config=config, env="fault_env")
|
||||||
|
@ -121,21 +163,29 @@ class TestWorkerFailure(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = a.train()
|
result = a.train()
|
||||||
self.assertTrue(result["num_healthy_workers"], 2)
|
if not eval_only:
|
||||||
self.assertTrue(
|
self.assertTrue(result["num_healthy_workers"] == 2)
|
||||||
all(
|
self.assertTrue(
|
||||||
ray.get(
|
all(
|
||||||
worker.apply.remote(
|
ray.get(
|
||||||
lambda w: w.recreated_worker
|
worker.apply.remote(
|
||||||
and w.env_context.recreated_worker
|
lambda w: w.recreated_worker
|
||||||
|
and w.env_context.recreated_worker
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
for worker in a.workers.remote_workers()
|
||||||
)
|
)
|
||||||
for worker in a.workers.remote_workers()
|
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
|
self.assertTrue(result["num_healthy_workers"] == 1)
|
||||||
|
self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
|
||||||
# This should also work several times.
|
# This should also work several times.
|
||||||
result = a.train()
|
result = a.train()
|
||||||
self.assertTrue(result["num_healthy_workers"], 2)
|
if not eval_only:
|
||||||
|
self.assertTrue(result["num_healthy_workers"] == 2)
|
||||||
|
else:
|
||||||
|
self.assertTrue(result["num_healthy_workers"] == 1)
|
||||||
|
self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
|
||||||
a.stop()
|
a.stop()
|
||||||
|
|
||||||
def test_fatal(self):
|
def test_fatal(self):
|
||||||
|
@ -188,6 +238,32 @@ class TestWorkerFailure(unittest.TestCase):
|
||||||
def test_async_sampling_option(self):
|
def test_async_sampling_option(self):
|
||||||
self.do_test("PG", {"optimizer": {}, "sample_async": True})
|
self.do_test("PG", {"optimizer": {}, "sample_async": True})
|
||||||
|
|
||||||
|
def test_eval_workers_failing_ignore(self):
|
||||||
|
# Test the case where one eval worker fails, but we choose to ignore the failure.
|
||||||
|
self.do_test(
|
||||||
|
"PG",
|
||||||
|
config={"model": {"fcnet_hiddens": [4]}},
|
||||||
|
eval_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_eval_workers_failing_recreate(self):
|
||||||
|
# Test the case where all eval workers fail, but we choose to recreate them.
|
||||||
|
self.do_test(
|
||||||
|
"PG",
|
||||||
|
config={"model": {"fcnet_hiddens": [4]}},
|
||||||
|
fn=self._do_test_fault_fatal_but_recreate,
|
||||||
|
eval_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_eval_workers_failing_fatal(self):
|
||||||
|
# Test the case where all eval workers fail (w/o recovery).
|
||||||
|
self.do_test(
|
||||||
|
"PG",
|
||||||
|
config={"model": {"fcnet_hiddens": [4]}},
|
||||||
|
fn=self._do_test_fault_fatal,
|
||||||
|
eval_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
|
|
@ -306,7 +306,19 @@ class WorkerSet:
|
||||||
)
|
)
|
||||||
return removed_workers
|
return removed_workers
|
||||||
|
|
||||||
def recreate_failed_workers(self) -> Tuple[List[ActorHandle], List[ActorHandle]]:
|
def recreate_failed_workers(
|
||||||
|
self, local_worker_for_synching: RolloutWorker
|
||||||
|
) -> Tuple[List[ActorHandle], List[ActorHandle]]:
|
||||||
|
"""Recreates any failed workers (after health check).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_worker_for_synching: RolloutWorker to use to synchronize the weights
|
||||||
|
after recreation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple consisting of two items: The list of removed workers and the list of
|
||||||
|
newly added ones.
|
||||||
|
"""
|
||||||
faulty_indices = self._worker_health_check()
|
faulty_indices = self._worker_health_check()
|
||||||
removed_workers = []
|
removed_workers = []
|
||||||
new_workers = []
|
new_workers = []
|
||||||
|
@ -329,14 +341,17 @@ class WorkerSet:
|
||||||
recreated_worker=True,
|
recreated_worker=True,
|
||||||
config=self._remote_config,
|
config=self._remote_config,
|
||||||
)
|
)
|
||||||
# Sync new worker from local one.
|
|
||||||
|
# Sync the new worker's weights and global vars from the provided worker.
|
||||||
new_worker.set_weights.remote(
|
new_worker.set_weights.remote(
|
||||||
weights=self.local_worker().get_weights(),
|
weights=local_worker_for_synching.get_weights(),
|
||||||
global_vars=self.local_worker().get_global_vars(),
|
global_vars=local_worker_for_synching.get_global_vars(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add new worker to list of remote workers.
|
# Add new worker to list of remote workers.
|
||||||
self._remote_workers[worker_index - 1] = new_worker
|
self._remote_workers[worker_index - 1] = new_worker
|
||||||
new_workers.append(new_worker)
|
new_workers.append(new_worker)
|
||||||
|
|
||||||
return removed_workers, new_workers
|
return removed_workers, new_workers
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
|
|
Loading…
Add table
Reference in a new issue