[tune] Node Fault Tolerance (#3238)
This PR introduces single-node fault tolerance for Tune.

## Previous behavior
- Actors were restarted without checking whether resources were available, which could cause problems after losing resources.

## New behavior
- RUNNING trials are resumed on another node on a best-effort basis (they run if resources are available).
- If the cluster is saturated, RUNNING trials on the failed node become PENDING and are queued.
- During recovery, TrialSchedulers and SearchAlgorithms are notified (via `trial_runner.stop_trial`) so that they do not wait or block for a trial that is no longer running.

## Remaining questions
- Should `last_result` be consistent during restore? Yes, but not for earlier trials (trials that have not yet been checkpointed).
- Waiting for some PRs to merge first (#3239).

Closes #2851.
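For context, a minimal sketch of how a user-side training script would opt into this recovery path: a trial only qualifies for restoration when `checkpoint_freq > 0` (see `Trial.should_recover` in this diff) and is retried up to `max_failures` times. This assumes the `tune.run_experiments` interface of this Ray/Tune version; the trainable class, experiment name, and redis address below are illustrative, not part of the PR.

```python
import time

import ray
from ray import tune


class MyTrainable(tune.Trainable):
    """Toy trainable with enough state to make recovery meaningful."""

    def _setup(self, config):
        self.state = {"iteration": 0}

    def _train(self):
        self.state["iteration"] += 1
        time.sleep(1)
        return {"my_metric": self.state["iteration"]}

    def _save(self, checkpoint_dir):
        # Returning a dict lets Tune store the checkpoint for later restore.
        return self.state

    def _restore(self, state):
        self.state = state


if __name__ == "__main__":
    # Placeholder address; attach to an existing multi-node cluster.
    ray.init(redis_address="<head-node-redis-address>")
    tune.register_trainable("my_trainable", MyTrainable)
    tune.run_experiments({
        "recoverable_experiment": {  # illustrative experiment name
            "run": "my_trainable",
            "stop": {"training_iteration": 100},
            # checkpoint_freq > 0 makes the trial eligible for recovery.
            "checkpoint_freq": 5,
            # Number of node failures a single trial may survive.
            "max_failures": 3,
        }
    })
```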
parent 3e33f6f71b
commit 784a6399b0
9 changed files with 300 additions and 19 deletions
@@ -5,7 +5,7 @@ FROM ray-project/deploy
 # This updates numpy to 1.14 and mutes errors from other libraries
 RUN conda install -y numpy
 RUN apt-get install -y zlib1g-dev
-RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras
+RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout
 RUN pip install -U h5py  # Mutes FutureWarnings
 RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
 RUN conda install pytorch-cpu torchvision-cpu -c pytorch
@@ -42,9 +42,10 @@ class Cluster(object):
         self.add_node(**head_node_args)
         if connect:
             redis_password = head_node_args.get("redis_password")
-            ray.init(
+            output_info = ray.init(
                 redis_address=self.redis_address,
                 redis_password=redis_password)
+            logger.info(output_info)
         if shutdown_at_exit:
             atexit.register(self.shutdown)
@@ -172,8 +173,10 @@ class Cluster(object):
         for node in all_nodes:
             self.remove_node(node)

-        if self.head_node is not None:
+        if self.head_node:
             self.remove_node(self.head_node)
+        else:
+            logger.warning("No headnode exists!")


 class Node(object):
@@ -6,7 +6,7 @@ import json
 import pytest
 try:
     import pytest_timeout
-except ModuleNotFoundError:
+except ImportError:
     pytest_timeout = None
 import time
@@ -278,6 +278,7 @@ class RayTrialExecutor(TrialExecutor):
     def save(self, trial, storage=Checkpoint.DISK):
         """Saves the trial's state to a checkpoint."""
         trial._checkpoint.storage = storage
+        trial._checkpoint.last_result = trial.last_result
         if storage == Checkpoint.MEMORY:
             trial._checkpoint.value = trial.runner.save_to_object.remote()
         else:
@@ -301,6 +302,8 @@ class RayTrialExecutor(TrialExecutor):
                 ray.get(trial.runner.restore_from_object.remote(value))
             else:
                 ray.get(trial.runner.restore.remote(value))
+            trial.last_result = checkpoint.last_result
+
             return True
         except Exception:
             logger.exception("Error restoring runner.")
python/ray/tune/test/cluster_tests.py (new file, 264 lines)
@@ -0,0 +1,264 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import time
import pytest
try:
    import pytest_timeout
except ImportError:
    pytest_timeout = None

from ray.test.cluster_utils import Cluster
import ray
from ray import tune
from ray.tune.error import TuneError
from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator

def register_test_trainable():
    class _Train(tune.Trainable):
        def _setup(self, config):
            self.state = {"hi": 1}

        def _train(self):
            self.state["hi"] += 1
            time.sleep(0.5)
            return {}

        def _save(self, path):
            return self.state

        def _restore(self, state):
            self.state = state

    tune.register_trainable("test", _Train)

@pytest.fixture
def start_connected_cluster():
    # Start the Ray processes.
    cluster = Cluster(
        initialize_head=True,
        connect=True,
        head_node_args={
            "resources": dict(CPU=1),
            "_internal_config": json.dumps({
                "num_heartbeats_timeout": 10
            })
        })
    register_test_trainable()
    yield cluster
    # The code after the yield will run as teardown code.
    ray.shutdown()
    cluster.shutdown()

@pytest.fixture
def start_connected_emptyhead_cluster():
    """Starts head with no resources."""
    cluster = Cluster(
        initialize_head=True,
        connect=True,
        head_node_args={
            "resources": dict(CPU=0),
            "_internal_config": json.dumps({
                "num_heartbeats_timeout": 10
            })
        })
    register_test_trainable()
    yield cluster
    # The code after the yield will run as teardown code.
    ray.shutdown()
    cluster.shutdown()

@pytest.mark.skipif(
    pytest_timeout is None,
    reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(10, method="thread")
def test_counting_resources(start_connected_cluster):
    """Tests that Tune accounting is consistent with actual cluster."""
    cluster = start_connected_cluster
    assert ray.global_state.cluster_resources()["CPU"] == 1
    nodes = []
    nodes += [cluster.add_node(resources=dict(CPU=1))]
    assert cluster.wait_for_nodes()
    assert ray.global_state.cluster_resources()["CPU"] == 2

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {"stopping_criterion": {"training_iteration": 10}}

    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
    for t in trials:
        runner.add_trial(t)

    runner.step()  # run 1
    cluster.remove_node(nodes.pop())
    assert cluster.wait_for_nodes()
    assert ray.global_state.cluster_resources()["CPU"] == 1
    runner.step()  # run 2

    for i in range(5):
        nodes += [cluster.add_node(resources=dict(CPU=1))]
    assert cluster.wait_for_nodes()
    assert ray.global_state.cluster_resources()["CPU"] == 6

    runner.step()  # 1 result

    for i in range(5):
        node = nodes.pop()
        cluster.remove_node(node)
    assert cluster.wait_for_nodes()
    assert ray.global_state.cluster_resources()["CPU"] == 1

@pytest.mark.skip("Add this test once reconstruction is fixed")
@pytest.mark.skipif(
    pytest_timeout is None,
    reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(10, method="thread")
def test_remove_node_before_result(start_connected_cluster):
    """Removing a node should cause a Trial to be requeued."""
    cluster = start_connected_cluster
    node = cluster.add_node(resources=dict(CPU=1))
    # TODO(rliaw): Make blocking an option?
    assert cluster.wait_for_nodes()

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {"stopping_criterion": {"training_iteration": 3}}
    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
    for t in trials:
        runner.add_trial(t)

    runner.step()  # run 1
    runner.step()  # run 2
    assert all(t.status == Trial.RUNNING for t in trials)

    runner.step()  # 1 result

    cluster.remove_node(node)
    cluster.wait_for_nodes()
    assert ray.global_state.cluster_resources()["CPU"] == 1

    runner.step()  # recover
    for i in range(5):
        runner.step()
    assert all(t.status == Trial.TERMINATED for t in trials)

    with pytest.raises(TuneError):
        runner.step()

@pytest.mark.skipif(
    pytest_timeout is None,
    reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(120, method="thread")
def test_trial_migration(start_connected_emptyhead_cluster):
    """Removing a node while cluster has space should migrate trial.

    The trial state should also be consistent with the checkpoint.
    """
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(resources=dict(CPU=1))
    assert cluster.wait_for_nodes()

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 3
        },
        "checkpoint_freq": 2,
        "max_failures": 2
    }

    # Test recovery of trial that hasn't been checkpointed
    t = Trial("test", **kwargs)
    runner.add_trial(t)
    runner.step()  # start
    runner.step()  # 1 result
    assert t.last_result is not None
    node2 = cluster.add_node(resources=dict(CPU=1))
    cluster.remove_node(node)
    assert cluster.wait_for_nodes()
    runner.step()  # Recovery step

    # TODO(rliaw): This assertion is not critical but will not pass
    # because checkpoint handling is messy and should be refactored
    # rather than hotfixed.
    # assert t.last_result is None, "Trial result not restored correctly."
    for i in range(3):
        runner.step()

    assert t.status == Trial.TERMINATED

    # Test recovery of trial that has been checkpointed
    t2 = Trial("test", **kwargs)
    runner.add_trial(t2)
    runner.step()  # start
    runner.step()  # 1 result
    runner.step()  # 2 result and checkpoint
    assert t2.has_checkpoint()
    node3 = cluster.add_node(resources=dict(CPU=1))
    cluster.remove_node(node2)
    assert cluster.wait_for_nodes()
    runner.step()  # Recovery step
    assert t2.last_result["training_iteration"] == 2
    for i in range(1):
        runner.step()

    assert t2.status == Trial.TERMINATED

    # Test recovery of trial that won't be checkpointed
    t3 = Trial("test", **{"stopping_criterion": {"training_iteration": 3}})
    runner.add_trial(t3)
    runner.step()  # start
    runner.step()  # 1 result
    cluster.add_node(resources=dict(CPU=1))
    cluster.remove_node(node3)
    assert cluster.wait_for_nodes()
    runner.step()  # Error handling step
    assert t3.status == Trial.ERROR

    with pytest.raises(TuneError):
        runner.step()

@pytest.mark.skipif(
    pytest_timeout is None,
    reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(120, method="thread")
def test_trial_requeue(start_connected_emptyhead_cluster):
    """Removing a node in full cluster causes Trial to be requeued."""
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(resources=dict(CPU=1))

    runner = TrialRunner(BasicVariantGenerator())
    kwargs = {
        "stopping_criterion": {
            "training_iteration": 5
        },
        "checkpoint_freq": 1,
        "max_failures": 1
    }

    trials = [Trial("test", **kwargs), Trial("test", **kwargs)]
    for t in trials:
        runner.add_trial(t)

    runner.step()  # start
    runner.step()  # 1 result

    cluster.remove_node(node)
    assert cluster.wait_for_nodes()
    runner.step()
    assert all(t.status == Trial.PENDING for t in trials)

    with pytest.raises(TuneError):
        runner.step()
@@ -85,9 +85,10 @@ class Checkpoint(object):
     MEMORY = "memory"
     DISK = "disk"

-    def __init__(self, storage, value):
+    def __init__(self, storage, value, last_result=None):
         self.storage = storage
         self.value = value
+        self.last_result = last_result

     @staticmethod
     def from_object(value=None):
@@ -277,6 +278,14 @@ class Trial(object):
     def has_checkpoint(self):
         return self._checkpoint.value is not None

+    def should_recover(self):
+        """Returns whether the trial qualifies for restoring.
+
+        This is if a checkpoint frequency is set, which includes settings
+        where there may not yet be a checkpoint.
+        """
+        return self.checkpoint_freq > 0
+
     def update_last_result(self, result, terminate=False):
         if terminate:
             result.update(done=True)
@@ -4,7 +4,6 @@ from __future__ import division
 from __future__ import print_function

 import logging
-import traceback

 from ray.tune.trial import Trial, Checkpoint
@@ -61,24 +60,24 @@ class TrialExecutor(object):
                                   "stop_trial() method")

     def restart_trial(self, trial, error_msg=None):
-        """Restarts the trial.
+        """Restarts or requeues the trial.

-        The state of the trial should restore from the last checkpoint.
+        The state of the trial should restore from the last checkpoint. Trial
+        is requeued if the cluster no longer has resources to accommodate it.

         Args:
             error_msg (str): Optional error message.
         """
-        try:
-            logger.info(
-                "Attempting to recover trial state from last checkpoint")
-            self.stop_trial(
-                trial, error=True, error_msg=error_msg, stop_logger=False)
-            trial.result_logger.flush()
+        self.stop_trial(
+            trial,
+            error=error_msg is not None,
+            error_msg=error_msg,
+            stop_logger=False)
+        trial.result_logger.flush()
+        if self.has_resources(trial.resources):
             self.start_trial(trial)
-        except Exception:
-            error_msg = traceback.format_exc()
-            logger.exception("Error recovering trial from checkpoint, abort.")
-            self.stop_trial(trial, error=True, error_msg=error_msg)
+        else:
+            trial.status = Trial.PENDING

     def continue_training(self, trial):
         """Continues the training of this trial."""
@@ -297,7 +297,7 @@ class TrialRunner(object):
             logger.exception("Error processing event.")
             error_msg = traceback.format_exc()
             if trial.status == Trial.RUNNING:
-                if trial.has_checkpoint() and \
+                if trial.should_recover() and \
                         trial.num_failures < trial.max_failures:
                     self._try_recover(trial, error_msg)
                 else:
@@ -261,6 +261,9 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/test/test_supported_spaces.py

+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
+    pytest /ray/python/ray/tune/test/cluster_tests.py
+
 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/test/test_env_with_subprocess.py