ray/release/ray_release/command_runner/client_runner.py
Kai Fricke eca5bcfc87
[ci/release] Reload modules after installing matching Ray (#23227)
Apparently, ray gets imported somewhere before running the client runner (maybe from an anyscale package). This means that we need to reload the ray package after installing a matching local ray wheel.
Additionally, job submission should also install a matching local ray to match with the job submission server.
2022-03-16 15:44:43 +00:00
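The module reload described above might look roughly like the following sketch (the real logic lives in ray_release.wheels.install_matching_ray_locally; the helper name and details here are illustrative assumptions, not the actual implementation):

import importlib
import subprocess
import sys


def install_and_reload_ray(wheel_url: str) -> None:
    # Hypothetical helper: install the matching local Ray wheel, then reload
    # the already-imported ray package so that subsequent code picks up the
    # freshly installed version instead of the one cached in sys.modules
    # (e.g. imported indirectly via the anyscale package).
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", wheel_url])
    if "ray" in sys.modules:
        importlib.reload(sys.modules["ray"])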


import json
import os
import subprocess
import sys
import tempfile
import threading
import time
from collections import deque
from typing import Optional, Dict, Any

import ray

from ray_release.anyscale_util import LAST_LOGS_LENGTH
from ray_release.cluster_manager.cluster_manager import ClusterManager
from ray_release.exception import (
    ResultsError,
    LocalEnvSetupError,
    ClusterNodesWaitTimeout,
    CommandTimeout,
    ClusterStartupError,
    CommandError,
)
from ray_release.file_manager.file_manager import FileManager
from ray_release.logger import logger
from ray_release.command_runner.command_runner import CommandRunner
from ray_release.wheels import install_matching_ray_locally


def install_cluster_env_packages(cluster_env: Dict[Any, Any]):
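    """Install the cluster env's env vars and pip packages locally.

    This mirrors the cluster environment on the local machine so the Ray
    client driver script runs with matching dependencies.
    """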
    os.environ.update(cluster_env.get("env_vars", {}))
    packages = cluster_env["python"]["pip_packages"]
    logger.info(f"Installing cluster env packages locally: {packages}")

    for package in packages:
        subprocess.check_output(
            f"pip install -U {package}", shell=True, env=os.environ, text=True
        )


class ClientRunner(CommandRunner):
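    """Command runner that executes test commands via the Ray client.

    The test command runs as a local subprocess whose driver connects to the
    remote cluster through the address returned by the cluster manager.
    """
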
    def __init__(
        self,
        cluster_manager: ClusterManager,
        file_manager: FileManager,
        working_dir: str,
    ):
        super(ClientRunner, self).__init__(cluster_manager, file_manager, working_dir)

        self.last_logs = None
        self.result_output_json = tempfile.mktemp()

    def prepare_remote_env(self):
        # The remote environment is already defined by the cluster env,
        # so there is nothing to prepare on the cluster side.
        pass

    def prepare_local_env(self, ray_wheels_url: Optional[str] = None):
        # Install a Ray version matching the cluster, plus the cluster env's
        # pip packages, so the local driver is compatible with the cluster.
        try:
            install_matching_ray_locally(
                ray_wheels_url or os.environ.get("RAY_WHEELS", None)
            )
            install_cluster_env_packages(self.cluster_manager.cluster_env)
        except Exception as e:
            raise LocalEnvSetupError(f"Error setting up local environment: {e}") from e

    def wait_for_nodes(self, num_nodes: int, timeout: float = 900):
        ray_address = self.cluster_manager.get_cluster_address()
        try:
            if ray.is_initialized():
                ray.shutdown()

            ray.init(address=ray_address)

            start_time = time.monotonic()
            timeout_at = start_time + timeout
            next_status = start_time + 30

            nodes_up = len(ray.nodes())
            while nodes_up < num_nodes:
                now = time.monotonic()
                if now >= timeout_at:
                    raise ClusterNodesWaitTimeout(
                        f"Only {len(ray.nodes())}/{num_nodes} nodes are up after "
                        f"{timeout} seconds."
                    )

                if now >= next_status:
                    logger.info(
                        f"Waiting for nodes to come up: "
                        f"{len(ray.nodes())}/{num_nodes} "
                        f"({now - start_time:.2f} seconds, "
                        f"timeout: {timeout} seconds)."
                    )
                    next_status += 30

                time.sleep(1)
                nodes_up = len(ray.nodes())

            ray.shutdown()
        except Exception as e:
            raise ClusterStartupError(f"Exception when waiting for nodes: {e}") from e

        logger.info(f"All {num_nodes} nodes are up.")

    def get_last_logs(self) -> Optional[str]:
        return self.last_logs

    def fetch_results(self) -> Dict[str, Any]:
        try:
            with open(self.result_output_json, "rt") as fp:
                return json.load(fp)
        except Exception as e:
            raise ResultsError(
                f"Could not load local results from client command: {e}"
            ) from e

    def run_command(
        self, command: str, env: Optional[Dict] = None, timeout: float = 3600.0
    ) -> float:
        logger.info(
            f"Running command using Ray client on cluster "
            f"{self.cluster_manager.cluster_name}: {command}"
        )

        env = env or {}
        full_env = self.get_full_command_env(
            {
                **os.environ,
                **env,
                "RAY_ADDRESS": self.cluster_manager.get_cluster_address(),
                "RAY_JOB_NAME": "test_job",
                "PYTHONUNBUFFERED": "1",
            }
        )

        kill_event = threading.Event()

        def _kill_after(
            proc: subprocess.Popen,
            timeout: int = 30,
            kill_event: Optional[threading.Event] = None,
        ):
            # Watchdog: terminate the subprocess once the timeout expires.
            timeout_at = time.monotonic() + timeout
            while time.monotonic() < timeout_at:
                if proc.poll() is not None:
                    return
                time.sleep(1)
            logger.info(
                f"Client command timed out after {timeout} seconds, "
                f"killing subprocess."
            )
            if kill_event:
                kill_event.set()
            proc.terminate()

        start_time = time.monotonic()

        proc = subprocess.Popen(
            command,
            env=full_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            shell=True,
            text=True,
        )

        kill_thread = threading.Thread(
            target=_kill_after, args=(proc, timeout, kill_event)
        )
        kill_thread.start()

        proc.stdout.reconfigure(line_buffering=True)
        sys.stdout.reconfigure(line_buffering=True)

        logs = deque(maxlen=LAST_LOGS_LENGTH)
        for line in proc.stdout:
            logs.append(line)
            sys.stdout.write(line)
        proc.wait()
        sys.stdout.reconfigure(line_buffering=False)

        time_taken = time.monotonic() - start_time
        # Lines read from proc.stdout already end with a newline.
        self.last_logs = "".join(logs)

        return_code = proc.poll()
        if return_code == -15 or return_code == 15 or kill_event.is_set():
            # Process has been terminated (SIGTERM) by the timeout watchdog.
            raise CommandTimeout(f"Cluster command timed out after {timeout} seconds.")

        if return_code != 0:
            raise CommandError(f"Command returned non-success status: {return_code}")

        logger.info(
            f"Command finished with return code {return_code} after "
            f"{time_taken:.2f} seconds."
        )

        return time_taken
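For context, a typical call sequence against this runner might look like the following sketch; the cluster_manager and file_manager instances are assumed to be constructed elsewhere in ray_release and are placeholders here:

# Hypothetical usage sketch, not part of this module.
runner = ClientRunner(cluster_manager, file_manager, working_dir="/tmp/release_test")
runner.prepare_local_env()            # install matching Ray + cluster env packages
runner.wait_for_nodes(num_nodes=4)    # block until the cluster is fully up
duration = runner.run_command("python workload.py", timeout=1800)
results = runner.fetch_results()      # read the local JSON results file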