# Source: ray/release/ray_release/tests/test_glue.py
# (Captured from a web file viewer: "595 lines, 20 KiB, Python";
#  viewer controls "Raw | Normal View | History".)

import os
import shutil
import sys
import tempfile
import time
import unittest
from typing import Type, Callable
from unittest.mock import patch
from ray_release.alerts.handle import result_to_handle_map
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.cluster_manager.full import FullClusterManager
from ray_release.command_runner.command_runner import CommandRunner
from ray_release.config import Test
from ray_release.exception import (
ReleaseTestConfigError,
LocalEnvSetupError,
ClusterComputeCreateError,
ClusterEnvBuildError,
ClusterEnvBuildTimeout,
ClusterEnvCreateError,
ClusterCreationError,
ClusterStartupError,
ClusterStartupTimeout,
RemoteEnvSetupError,
CommandError,
PrepareCommandError,
CommandTimeout,
PrepareCommandTimeout,
TestCommandError,
TestCommandTimeout,
ResultsError,
LogsError,
ResultsAlert,
ClusterNodesWaitTimeout,
)
from ray_release.file_manager.file_manager import FileManager
from ray_release.glue import (
run_release_test,
type_str_to_command_runner,
command_runner_to_cluster_manager,
command_runner_to_file_manager,
)
from ray_release.logger import logger
from ray_release.reporter.reporter import Reporter
from ray_release.result import Result, ExitCode
from ray_release.tests.utils import MockSDK, APIDict
def _fail_on_call(error_type: Type[Exception] = RuntimeError, message: str = "Fail"):
def _fail(*args, **kwargs):
raise error_type(message)
return _fail
class MockReturn:
    """Mixin that lets tests override attributes through ``return_dict``.

    Attribute lookups consult ``return_dict`` first. When the looked-up name
    is a key there:

    * a callable value is *invoked immediately* (with no arguments) and its
      result becomes the attribute value, so a raising callable makes the
      attribute access itself raise;
    * any other value is wrapped in a function that ignores its arguments
      and returns that value, so ``obj.method(...)`` yields the canned value.

    Names not present in ``return_dict`` resolve through normal lookup.
    """

    # Fallback mapping used until a subclass or instance assigns its own.
    # NOTE(review): this dict is shared by every subclass; it is only read here.
    return_dict = {}

    def __getattribute__(self, item):
        overrides = object.__getattribute__(self, "return_dict")
        if item not in overrides:
            return object.__getattribute__(self, item)

        canned = overrides[item]
        if not isinstance(canned, Callable):
            return lambda *a, **kw: canned
        return canned()
class GlueTest(unittest.TestCase):
    """End-to-end tests for ``run_release_test`` with mocked components.

    ``setUp`` registers mock command runner / cluster manager / file manager
    classes under the ``"unit_test"`` runner type. Each mock reads canned
    behavior from a per-test dictionary (``self.command_runner_return``,
    ``self.cluster_manager_return``, ``self.file_manager_return``) through
    ``MockReturn``. A test lets the pipeline succeed up to some stage
    (``_succeed_until``), injects a failure at the next stage, and asserts on
    the raised exception and on ``result.return_code``.
    """

    def writeClusterEnv(self, content: str):
        """Write ``content`` to ``cluster_env.yaml`` in the test's tempdir."""
        with open(os.path.join(self.tempdir, "cluster_env.yaml"), "wt") as fp:
            fp.write(content)

    def writeClusterCompute(self, content: str):
        """Write ``content`` to ``cluster_compute.yaml`` in the test's tempdir."""
        with open(os.path.join(self.tempdir, "cluster_compute.yaml"), "wt") as fp:
            fp.write(content)

    def setUp(self) -> None:
        """Create the temp workspace, the mocks and the default ``Test`` config."""
        self.tempdir = tempfile.mkdtemp()
        self.sdk = MockSDK()

        self.sdk.returns["get_project"] = APIDict(
            result=APIDict(name="unit_test_project")
        )

        self.writeClusterEnv("{'env': true}")
        self.writeClusterCompute("{'compute': true}")

        # Driver setup scripts. The default test config points at the failing
        # one; ``_succeed_until`` switches it to the succeeding one.
        with open(os.path.join(self.tempdir, "driver_fail.sh"), "wt") as f:
            f.write("exit 1\n")
        with open(os.path.join(self.tempdir, "driver_succeed.sh"), "wt") as f:
            f.write("exit 0\n")

        # The glue code instantiates the mock classes below itself, so they
        # capture this test's state through these closure variables.
        this_sdk = self.sdk
        this_tempdir = self.tempdir

        self.cluster_manager_return = {}
        self.command_runner_return = {}
        self.file_manager_return = {}

        this_cluster_manager_return = self.cluster_manager_return
        this_command_runner_return = self.command_runner_return
        this_file_manager_return = self.file_manager_return

        class MockClusterManager(MockReturn, FullClusterManager):
            def __init__(self, test_name: str, project_id: str, sdk=None):
                # Always use this test's MockSDK; the ``sdk`` argument is ignored.
                super(MockClusterManager, self).__init__(
                    test_name, project_id, this_sdk
                )
                self.return_dict = this_cluster_manager_return

        class MockCommandRunner(MockReturn, CommandRunner):
            # Class-level fallback consulted while the base ``__init__`` runs,
            # before the instance attribute below is assigned. It must be the
            # command runner's dict (previously it was the cluster manager's
            # by mistake).
            return_dict = this_command_runner_return

            def __init__(
                self,
                cluster_manager: ClusterManager,
                file_manager: FileManager,
                working_dir: str,
            ):
                # Always run in this test's tempdir; ``working_dir`` is ignored.
                super(MockCommandRunner, self).__init__(
                    cluster_manager, file_manager, this_tempdir
                )
                self.return_dict = this_command_runner_return

        class MockFileManager(MockReturn, FileManager):
            def __init__(self, cluster_manager: ClusterManager):
                super(MockFileManager, self).__init__(cluster_manager)
                self.return_dict = this_file_manager_return

        self.mock_alert_return = None

        def mock_alerter(test: Test, result: Result):
            # Returning a non-None value makes the glue code raise an alert.
            return self.mock_alert_return

        # Register the mocks in the glue module's global lookup tables.
        result_to_handle_map["unit_test_alerter"] = mock_alerter
        type_str_to_command_runner["unit_test"] = MockCommandRunner
        command_runner_to_cluster_manager[MockCommandRunner] = MockClusterManager
        command_runner_to_file_manager[MockCommandRunner] = MockFileManager

        # A fresh MockCommandRunner class is created for every test, so
        # unregister after each test. Otherwise the module-level registries
        # keep growing and stay polluted once the suite finishes.
        self.addCleanup(result_to_handle_map.pop, "unit_test_alerter", None)
        self.addCleanup(type_str_to_command_runner.pop, "unit_test", None)
        self.addCleanup(command_runner_to_cluster_manager.pop, MockCommandRunner, None)
        self.addCleanup(command_runner_to_file_manager.pop, MockCommandRunner, None)

        self.test = Test(
            name="unit_test_end_to_end",
            run=dict(
                type="unit_test",
                prepare="prepare_cmd",
                script="test_cmd",
                wait_for_nodes=dict(num_nodes=4, timeout=40),
            ),
            working_dir=self.tempdir,
            cluster=dict(
                cluster_env="cluster_env.yaml", cluster_compute="cluster_compute.yaml"
            ),
            alert="unit_test_alerter",
            driver_setup="driver_fail.sh",
        )

        self.anyscale_project = "prj_unit12345678"
        self.ray_wheels_url = "http://mock.wheels/"

    def tearDown(self) -> None:
        """Remove the per-test temporary directory."""
        shutil.rmtree(self.tempdir)

    def _succeed_until(self, until: str):
        """Configure the mocks so every pipeline stage up to ``until`` succeeds.

        Stages, in pipeline order: ``"local_env"``, ``"driver_setup"``,
        ``"cluster_compute"``, ``"cluster_env"``, ``"cluster_start"``,
        ``"remote_env"``, ``"wait_for_nodes"``, ``"prepare_command"``,
        ``"test_command"``, ``"fetch_results"``, ``"get_last_logs"``.
        A call configures the named stage and every stage before it. Any
        other value (e.g. ``"complete"``) makes every stage succeed and
        clears the mock alert.
        """
        # These commands should succeed
        self.command_runner_return["prepare_local_env"] = None

        if until == "local_env":
            return

        self.test["driver_setup"] = "driver_succeed.sh"

        if until == "driver_setup":
            return

        self.cluster_manager_return["cluster_compute_id"] = "valid"
        self.cluster_manager_return["create_cluster_compute"] = None

        if until == "cluster_compute":
            return

        self.cluster_manager_return["cluster_env_id"] = "valid"
        self.cluster_manager_return["create_cluster_env"] = None
        self.cluster_manager_return["cluster_env_build_id"] = "valid"
        self.cluster_manager_return["build_cluster_env"] = None

        if until == "cluster_env":
            return

        self.cluster_manager_return["cluster_id"] = "valid"
        self.cluster_manager_return["start_cluster"] = None

        if until == "cluster_start":
            return

        self.command_runner_return["prepare_remote_env"] = None

        if until == "remote_env":
            return

        self.command_runner_return["wait_for_nodes"] = None

        if until == "wait_for_nodes":
            return

        self.command_runner_return["run_prepare_command"] = None

        if until == "prepare_command":
            return

        self.command_runner_return["run_command"] = None

        if until == "test_command":
            return

        self.command_runner_return["fetch_results"] = {
            "time_taken": 50,
            # 60s in the past; testTestCommandTimeoutLongRunning expects the
            # reported ``last_update_diff`` to be at least 60s.
            "last_update": time.time() - 60,
        }

        if until == "fetch_results":
            return

        self.command_runner_return["get_last_logs"] = "Lorem ipsum"

        if until == "get_last_logs":
            return

        self.mock_alert_return = None

    def _run(self, result: Result, **kwargs):
        """Run ``run_release_test`` on ``self.test``, recording into ``result``.

        Extra keyword arguments (e.g. ``cluster_env_id``, ``reporters``) are
        forwarded unchanged.
        """
        run_release_test(
            test=self.test,
            anyscale_project=self.anyscale_project,
            result=result,
            ray_wheels_url=self.ray_wheels_url,
            **kwargs
        )

    def testInvalidClusterEnv(self):
        """Cluster env loading errors map to ``ExitCode.CONFIG_ERROR``."""
        result = Result()

        # Any ReleaseTestConfigError
        with patch(
            "ray_release.glue.load_test_cluster_env",
            _fail_on_call(ReleaseTestConfigError),
        ), self.assertRaises(ReleaseTestConfigError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because file not found
        os.unlink(os.path.join(self.tempdir, "cluster_env.yaml"))
        with self.assertRaisesRegex(ReleaseTestConfigError, "Path not found"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because invalid jinja template
        self.writeClusterEnv("{{ INVALID")
        with self.assertRaisesRegex(ReleaseTestConfigError, "yaml template"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because invalid YAML (unterminated quoted scalar)
        self.writeClusterEnv("{'test': true, 'fail}")
        with self.assertRaisesRegex(ReleaseTestConfigError, "quoted scalar"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

    def testInvalidClusterCompute(self):
        """Cluster compute loading errors map to ``ExitCode.CONFIG_ERROR``."""
        result = Result()

        with patch(
            "ray_release.glue.load_test_cluster_compute",
            _fail_on_call(ReleaseTestConfigError),
        ), self.assertRaises(ReleaseTestConfigError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because file not found
        os.unlink(os.path.join(self.tempdir, "cluster_compute.yaml"))
        with self.assertRaisesRegex(ReleaseTestConfigError, "Path not found"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because invalid jinja template
        self.writeClusterCompute("{{ INVALID")
        with self.assertRaisesRegex(ReleaseTestConfigError, "yaml template"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

        # Fails because invalid YAML (unterminated quoted scalar)
        self.writeClusterCompute("{'test': true, 'fail}")
        with self.assertRaisesRegex(ReleaseTestConfigError, "quoted scalar"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CONFIG_ERROR.value)

    def testInvalidPrepareLocalEnv(self):
        """A failing local env setup maps to ``LOCAL_ENV_SETUP_ERROR``."""
        result = Result()

        self.command_runner_return["prepare_local_env"] = _fail_on_call(
            LocalEnvSetupError
        )
        with self.assertRaises(LocalEnvSetupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.LOCAL_ENV_SETUP_ERROR.value)

    def testDriverSetupFails(self):
        """A failing driver setup script is reported as a local env error."""
        result = Result()

        # Stop before ``driver_setup`` so the default driver_fail.sh is used.
        self._succeed_until("local_env")

        with self.assertRaises(LocalEnvSetupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.LOCAL_ENV_SETUP_ERROR.value)

    def testInvalidClusterIdOverride(self):
        """An unknown ``cluster_env_id`` override fails cluster env creation."""
        result = Result()

        self._succeed_until("driver_setup")

        self.sdk.returns["get_cluster_environment"] = None

        with self.assertRaises(ClusterEnvCreateError):
            self._run(result, cluster_env_id="existing")

        self.sdk.returns["get_cluster_environment"] = APIDict(
            result=APIDict(config_json={"overridden": True})
        )

        with self.assertRaises(Exception) as cm:  # Fail somewhere else
            self._run(result, cluster_env_id="existing")
            self.assertNotIsInstance(cm.exception, ClusterEnvCreateError)

    def testBuildConfigFailsClusterCompute(self):
        """Cluster compute creation errors map to ``CLUSTER_RESOURCE_ERROR``."""
        result = Result()

        # ``_succeed_until`` already makes ``prepare_local_env`` succeed.
        self._succeed_until("driver_setup")

        # Fails because API response faulty
        with self.assertRaisesRegex(ClusterComputeCreateError, "Unexpected"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

        # Fails for random cluster compute reason
        self.cluster_manager_return["create_cluster_compute"] = _fail_on_call(
            ClusterComputeCreateError, "Known"
        )
        with self.assertRaisesRegex(ClusterComputeCreateError, "Known"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

    def testBuildConfigFailsClusterEnv(self):
        """Cluster env create/build failures and timeouts get their exit codes."""
        result = Result()

        self._succeed_until("cluster_compute")

        # Fails because API response faulty
        with self.assertRaisesRegex(ClusterEnvCreateError, "Unexpected"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

        # Fails for random cluster env create reason
        self.cluster_manager_return["create_cluster_env"] = _fail_on_call(
            ClusterEnvCreateError, "Known"
        )
        with self.assertRaisesRegex(ClusterEnvCreateError, "Known"):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

        # Now, succeed creation but fail on cluster env build
        self.cluster_manager_return["cluster_env_id"] = "valid"
        self.cluster_manager_return["create_cluster_env"] = None
        self.cluster_manager_return["build_cluster_env"] = _fail_on_call(
            ClusterEnvBuildError
        )
        with self.assertRaises(ClusterEnvBuildError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_ENV_BUILD_ERROR.value)

        # Now, fail on cluster env timeout
        self.cluster_manager_return["build_cluster_env"] = _fail_on_call(
            ClusterEnvBuildTimeout
        )
        with self.assertRaises(ClusterEnvBuildTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_ENV_BUILD_TIMEOUT.value)

    def testStartClusterFails(self):
        """Cluster start failures get their exit codes and terminate the cluster."""
        result = Result()

        self._succeed_until("cluster_env")

        # Fails because API response faulty
        with self.assertRaises(ClusterCreationError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_RESOURCE_ERROR.value)

        self.cluster_manager_return["cluster_id"] = "valid"

        # Fail for random cluster startup reason
        self.cluster_manager_return["start_cluster"] = _fail_on_call(
            ClusterStartupError
        )
        with self.assertRaises(ClusterStartupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_STARTUP_ERROR.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

        # The counter is cumulative across runs, so remember it to check that
        # the next failure terminates the cluster again.
        terminations_before = self.sdk.call_counter["terminate_cluster"]

        # Fail for cluster startup timeout
        self.cluster_manager_return["start_cluster"] = _fail_on_call(
            ClusterStartupTimeout
        )
        with self.assertRaises(ClusterStartupTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_STARTUP_TIMEOUT.value)

        # Ensure cluster was terminated (again, for this failure)
        self.assertGreater(
            self.sdk.call_counter["terminate_cluster"], terminations_before
        )

    def testPrepareRemoteEnvFails(self):
        """A failing remote env setup maps to ``REMOTE_ENV_SETUP_ERROR``."""
        result = Result()

        self._succeed_until("cluster_start")

        self.command_runner_return["prepare_remote_env"] = _fail_on_call(
            RemoteEnvSetupError
        )
        with self.assertRaises(RemoteEnvSetupError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.REMOTE_ENV_SETUP_ERROR.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testWaitForNodesFails(self):
        """A node-wait timeout maps to ``CLUSTER_WAIT_TIMEOUT``."""
        result = Result()

        self._succeed_until("remote_env")

        # Wait for nodes command fails
        self.command_runner_return["wait_for_nodes"] = _fail_on_call(
            ClusterNodesWaitTimeout
        )
        with self.assertRaises(ClusterNodesWaitTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_WAIT_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testPrepareCommandFails(self):
        """Prepare command errors/timeouts are wrapped and get exit codes."""
        result = Result()

        self._succeed_until("wait_for_nodes")

        # Prepare command fails
        self.command_runner_return["run_prepare_command"] = _fail_on_call(CommandError)
        with self.assertRaises(PrepareCommandError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.PREPARE_ERROR.value)

        # Prepare command times out
        self.command_runner_return["run_prepare_command"] = _fail_on_call(
            CommandTimeout
        )
        with self.assertRaises(PrepareCommandTimeout):
            self._run(result)
        # Special case: Prepare commands are usually waiting for nodes
        # (this may change in the future!)
        self.assertEqual(result.return_code, ExitCode.CLUSTER_WAIT_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testTestCommandFails(self):
        """Test command errors/timeouts are wrapped and get exit codes."""
        result = Result()

        self._succeed_until("prepare_command")

        # Test command fails
        self.command_runner_return["run_command"] = _fail_on_call(CommandError)
        with self.assertRaises(TestCommandError):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_ERROR.value)

        # Test command times out
        self.command_runner_return["run_command"] = _fail_on_call(CommandTimeout)
        with self.assertRaises(TestCommandTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_TIMEOUT.value)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testTestCommandTimeoutLongRunning(self):
        """A command timeout is fatal unless the test is marked long-running."""
        result = Result()

        self._succeed_until("fetch_results")

        # Test command times out
        self.command_runner_return["run_command"] = _fail_on_call(CommandTimeout)
        with self.assertRaises(TestCommandTimeout):
            self._run(result)
        self.assertEqual(result.return_code, ExitCode.COMMAND_TIMEOUT.value)

        # But now set test to long running
        self.test["run"]["long_running"] = True
        self._run(result)  # Will not fail this time

        self.assertGreaterEqual(result.results["last_update_diff"], 60.0)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testFetchResultFails(self):
        """A results-fetch error is logged but the run still succeeds."""
        result = Result()

        self._succeed_until("test_command")

        self.command_runner_return["fetch_results"] = _fail_on_call(ResultsError)
        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result)
            self.assertTrue(any("Could not fetch results" in o for o in cm.output))
        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "finished")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testLastLogsFails(self):
        """A logs-fetch error is logged; the run succeeds with placeholder logs."""
        result = Result()

        self._succeed_until("fetch_results")

        self.command_runner_return["get_last_logs"] = _fail_on_call(LogsError)

        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result)
            self.assertTrue(any("Error fetching logs" in o for o in cm.output))
        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "finished")
        self.assertIn("No logs", result.last_logs)

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testAlertFails(self):
        """A non-None alerter return raises ``ResultsAlert``."""
        result = Result()

        self._succeed_until("get_last_logs")

        self.mock_alert_return = "Alert raised"

        with self.assertRaises(ResultsAlert):
            self._run(result)

        self.assertEqual(result.return_code, ExitCode.COMMAND_ALERT.value)
        self.assertEqual(result.status, "error")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)

    def testReportFails(self):
        """A failing reporter is logged but does not fail the run."""
        result = Result()

        self._succeed_until("complete")

        class FailReporter(Reporter):
            def report_result(self, test: Test, result: Result):
                raise RuntimeError

        with self.assertLogs(logger, "ERROR") as cm:
            self._run(result, reporters=[FailReporter()])
            self.assertTrue(any("Error reporting results" in o for o in cm.output))

        self.assertEqual(result.return_code, ExitCode.SUCCESS.value)
        self.assertEqual(result.status, "finished")

        # Ensure cluster was terminated
        self.assertGreaterEqual(self.sdk.call_counter["terminate_cluster"], 1)
if __name__ == "__main__":
    # Allow running this module directly: delegate to pytest in verbose mode
    # and propagate its exit status to the shell.
    import pytest

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)