From 856bea31fbadbd51775abfc1cda5b475c594e525 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Mon, 13 Jun 2022 21:52:57 +0900
Subject: [PATCH] [State Observability] Ray log CLI / API (#25481)

This PR implements the basic log APIs. Higher-level APIs (e.g., `ray logs actors`) will be implemented after the internal API review is done.

# If there's only 1 match, print the file's content. Otherwise, print all files that match the glob.
ray logs [glob_filter] --node-id=[head node by default]

Args:
  --tail: Tail the last X lines.
  --follow: Follow new logs as they are written.
  --actor-id: The actor ID.
  --pid, --node-ip: For worker logs.
  --node-id: The node ID of the log.
  --interval: When --follow is specified, print new logs at this interval. (should we remove it?)

---
 dashboard/modules/log/log_agent.py         |   2 +-
 dashboard/modules/log/log_manager.py       |  13 +-
 dashboard/modules/state/state_head.py      |  42 ++---
 dashboard/state_aggregator.py              |   2 +
 python/ray/experimental/state/api.py       | 109 ++++++++++++-
 python/ray/experimental/state/common.py    |  10 +-
 python/ray/experimental/state/state_cli.py |  41 ++---
 python/ray/scripts/scripts.py              | 155 +++++++++++++++++-
 python/ray/tests/test_state_api_log.py     | 177 +++++++++++++++++++--
 9 files changed, 492 insertions(+), 59 deletions(-)

diff --git a/dashboard/modules/log/log_agent.py b/dashboard/modules/log/log_agent.py index 99983e777..a3abeeb54 100644 --- a/dashboard/modules/log/log_agent.py +++ b/dashboard/modules/log/log_agent.py @@ -99,7 +99,7 @@ class LogAgentV1Grpc( bytes, end = tail(f, lines) yield reporter_pb2.StreamLogReply(data=bytes + b"\n") if request.keep_alive: - interval = request.interval if request.interval else 0.5 + interval = request.interval if request.interval else 1 f.seek(end) while not context.done(): await asyncio.sleep(interval) diff --git a/dashboard/modules/log/log_manager.py b/dashboard/modules/log/log_manager.py index 769e4d42f..42e253a29 100644 --- a/dashboard/modules/log/log_manager.py +++ b/dashboard/modules/log/log_manager.py @@ -70,9 +70,8 @@ class LogsManager: Async generator of streamed logs in bytes. """ node_id = options.node_id or self.ip_to_node_id(options.node_ip) - self._verify_node_registered(node_id) - log_file_name = await self.resolve_filename( + log_file_name, node_id = await self.resolve_filename( node_id=node_id, log_filename=options.filename, actor_id=options.actor_id, @@ -133,6 +132,13 @@ class LogsManager: f"Worker ID for Actor ID {actor_id} not found. " "Actor is not scheduled yet." ) + node_id = actor_data["address"].get("rayletId") + if not node_id: + raise ValueError( + f"Node ID for Actor ID {actor_id} not found. " + "Actor is not scheduled yet." + ) + self._verify_node_registered(node_id) # List all worker logs that match actor's worker id. log_files = await self.list_logs( @@ -149,6 +155,7 @@ class LogsManager: elif task_id: raise NotImplementedError("task_id is not supported yet.") elif pid: + self._verify_node_registered(node_id) log_files = await self.list_logs(node_id, timeout, glob_filter=f"*{pid}*") for filename in log_files["worker_out"]: # worker-[worker_id]-[job_id]-[pid].log @@ -170,7 +177,7 @@ class LogsManager: f"\tpid: {pid}\n" ) - return log_filename + return log_filename, node_id def _categorize_log_files(self, log_files: List[str]) -> Dict[str, List[str]]: """Categorize the given log files after filtering them out using a given glob.
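A quick usage sketch of the Python API added in python/ray/experimental/state/api.py (a minimal example; the actor ID is a hypothetical placeholder, and the calls mirror those exercised in python/ray/tests/test_state_api_log.py below):

import ray
from ray.experimental.state.api import get_log, list_logs, list_nodes

ray.init()
head_node = list(list_nodes().values())[0]

# List the log files on the head node that match a glob (defaults to "*").
print(list_logs(node_id=head_node["node_id"], glob_filter="raylet*"))

# Tail the last 10 lines of an actor's worker log. Passing follow=True
# switches to the streaming endpoint and keeps yielding new lines.
for chunk in get_log(actor_id="<actor_id_hex>", tail=10):  # hypothetical ID
    print(chunk, end="", flush=True)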
diff --git a/dashboard/modules/state/state_head.py b/dashboard/modules/state/state_head.py index 3898c6ebd..536cded83 100644 --- a/dashboard/modules/state/state_head.py +++ b/dashboard/modules/state/state_head.py @@ -19,6 +19,7 @@ from ray.experimental.state.common import ( ListApiOptions, GetLogOptions, DEFAULT_RPC_TIMEOUT, + DEFAULT_LIMIT, ) from ray.experimental.state.exception import DataSourceUnavailable from ray.experimental.state.state_manager import StateDataSourceClient @@ -211,38 +212,43 @@ class StateHead(dashboard_utils.DashboardHeadModule): @routes.get("/api/v0/logs/{media_type}") async def get_logs(self, req: aiohttp.web.Request): - """ - If `media_type = stream`, creates HTTP stream which is either kept alive while - the HTTP connection is not closed. Else, if `media_type = file`, the stream - ends once all the lines in the file requested are transmitted. - """ + # TODO(sang): We need better error handling for streaming + # when we refactor the server framework. options = GetLogOptions( - timeout=req.query.get("timeout", DEFAULT_RPC_TIMEOUT), - node_id=req.query.get("node_id"), - node_ip=req.query.get("node_ip"), + timeout=int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)), + node_id=req.query.get("node_id", None), + node_ip=req.query.get("node_ip", None), media_type=req.match_info.get("media_type", "file"), - filename=req.query.get("filename"), - actor_id=req.query.get("actor_id"), - task_id=req.query.get("task_id"), - pid=req.query.get("pid"), - lines=req.query.get("lines", 1000), - interval=req.query.get("interval"), + filename=req.query.get("filename", None), + actor_id=req.query.get("actor_id", None), + task_id=req.query.get("task_id", None), + pid=req.query.get("pid", None), + lines=req.query.get("lines", DEFAULT_LIMIT), + interval=req.query.get("interval", None), ) response = aiohttp.web.StreamResponse() response.content_type = "text/plain" await response.prepare(req) - # try-except here in order to properly handle ongoing HTTP stream + # NOTE: The first byte indicates the success / failure of the individual + # stream. If the first byte is b"1", the stream was successful. + # If it is b"0", the stream failed.
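+        # Clients are expected to strip this status byte before consuming the + # payload; get_log() in python/ray/experimental/state/api.py does this.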
try: async for logs_in_bytes in self._log_api.stream_logs(options): - await response.write(logs_in_bytes) + logs_to_stream = bytearray(b"1") + logs_to_stream.extend(logs_in_bytes) + await response.write(bytes(logs_to_stream)) await response.write_eof() return response except Exception as e: logger.exception(e) - await response.write(b"Closing HTTP stream due to internal server error:\n") - await response.write(str(e).encode()) + error_msg = bytearray(b"0") + error_msg.extend( + f"Closing HTTP stream due to internal server error.\n{e}".encode() + ) + + await response.write(bytes(error_msg)) await response.write_eof() return response diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index a88ef8437..11422c2a7 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -232,6 +232,8 @@ class StateAPIManager: result = [] for message in reply.node_info_list: data = self._message_to_dict(message=message, fields_to_decode=["node_id"]) + data["node_ip"] = data["node_manager_address"] + data = filter_fields(data, NodeState) result.append(data) result = self._filter(result, option.filters, NodeState) diff --git a/python/ray/experimental/state/api.py b/python/ray/experimental/state/api.py index 2c7d95433..2e8a98c78 100644 --- a/python/ray/experimental/state/api.py +++ b/python/ray/experimental/state/api.py @@ -1,17 +1,24 @@ import requests import warnings +import urllib -from typing import List, Tuple +from typing import List, Tuple, Optional, Dict, Generator +from dataclasses import fields import ray from ray.experimental.state.common import ( SupportedFilterType, ListApiOptions, + GetLogOptions, DEFAULT_RPC_TIMEOUT, DEFAULT_LIMIT, ) from ray.experimental.state.exception import RayStateApiException +""" +List APIs +""" + # TODO(sang): Replace it with auto-generated methods. def _list( @@ -193,3 +200,103 @@ def list_runtime_envs( api_server_url=api_server_url, _explain=_explain, ) + + +""" +Log APIs +""" + + +def get_log( + api_server_url: str = None, + node_id: Optional[str] = None, + node_ip: Optional[str] = None, + filename: Optional[str] = None, + actor_id: Optional[str] = None, + task_id: Optional[str] = None, + pid: Optional[int] = None, + follow: bool = False, + tail: int = 100, + _interval: Optional[float] = None, +) -> Generator[str, None, None]: + if api_server_url is None: + assert ray.is_initialized() + api_server_url = ( + f"http://{ray.worker.global_worker.node.address_info['webui_url']}" + ) + + media_type = "stream" if follow else "file" + options = GetLogOptions( + node_id=node_id, + node_ip=node_ip, + filename=filename, + actor_id=actor_id, + task_id=task_id, + pid=pid, + lines=tail, + interval=_interval, + media_type=media_type, + timeout=DEFAULT_RPC_TIMEOUT, + ) + options_dict = {} + for field in fields(options): + option_val = getattr(options, field.name) + if option_val: + options_dict[field.name] = option_val + + with requests.get( + f"{api_server_url}/api/v0/logs/{media_type}?" + f"{urllib.parse.urlencode(options_dict)}", + stream=True, + ) as r: + if r.status_code != 200: + raise RayStateApiException(r.text) + for bytes in r.iter_content(chunk_size=None): + bytes = bytearray(bytes) + # First byte 1 means success. 
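+            # A first byte of b"0" means the stream failed; the rest of the + # chunk is the server's error message, raised below as a RayStateApiException.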
+ if bytes.startswith(b"1"): + bytes.pop(0) + logs = bytes.decode("utf-8") + else: + assert bytes.startswith(b"0") + error_msg = bytes.decode("utf-8") + raise RayStateApiException(error_msg) + yield logs + + +def list_logs( + api_server_url: str = None, + node_id: str = None, + node_ip: str = None, + glob_filter: str = None, +) -> Dict[str, List[str]]: + if api_server_url is None: + assert ray.is_initialized() + api_server_url = ( + f"http://{ray.worker.global_worker.node.address_info['webui_url']}" + ) + + if not glob_filter: + glob_filter = "*" + + options_dict = {} + if node_ip: + options_dict["node_ip"] = node_ip + if node_id: + options_dict["node_id"] = node_id + if glob_filter: + options_dict["glob"] = glob_filter + + r = requests.get( + f"{api_server_url}/api/v0/logs?{urllib.parse.urlencode(options_dict)}" + ) + r.raise_for_status() + + response = r.json() + if response["result"] is False: + raise RayStateApiException( + "API server internal error. See dashboard.log file for more details. " + f"Error: {response['msg']}" + ) + + return response["data"]["result"] diff --git a/python/ray/experimental/state/common.py b/python/ray/experimental/state/common.py index a888d88a6..3914363fc 100644 --- a/python/ray/experimental/state/common.py +++ b/python/ray/experimental/state/common.py @@ -89,14 +89,17 @@ class GetLogOptions: self.interval = float(self.interval) self.lines = int(self.lines) + if self.task_id: + raise NotImplementedError("task_id is not supported yet.") + if self.media_type == "file": assert self.interval is None if self.media_type not in ["file", "stream"]: raise ValueError(f"Invalid media type: {self.media_type}") - if not (self.node_id or self.node_ip): + if not (self.node_id or self.node_ip) and not (self.actor_id or self.task_id): raise ValueError( - "Both node_id and node_ip is not given. " - "At least one of the should be provided." + "At least one of node_id, node_ip, or actor_id " + "should be provided." ) if self.node_id and self.node_ip: raise ValueError( @@ -145,6 +148,7 @@ class PlacementGroupState(StateSchema): @dataclass(init=True) class NodeState(StateSchema): node_id: str + node_ip: str state: str node_name: str resources_total: dict diff --git a/python/ray/experimental/state/state_cli.py b/python/ray/experimental/state/state_cli.py index b7813be95..5a0b93e66 100644 --- a/python/ray/experimental/state/state_cli.py +++ b/python/ray/experimental/state/state_cli.py @@ -37,6 +37,28 @@ def _get_available_formats() -> List[str]: return [format_enum.value for format_enum in AvailableFormat] +def get_api_server_url(): + address = services.canonicalize_bootstrap_address(None) + gcs_client = GcsClient(address=address, nums_reconnect_retry=0) + ray.experimental.internal_kv._initialize_internal_kv(gcs_client) + api_server_url = ray._private.utils.internal_kv_get_with_retry( + gcs_client, + ray_constants.DASHBOARD_ADDRESS, + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + num_retries=20, + ) + + if api_server_url is None: + raise ValueError( + ( + "Couldn't obtain the API server address from GCS. It is likely that " + "the GCS server is down. Check gcs_server.[out | err] to see if it is " + "still alive."
+ ) + ) + return api_server_url + + def get_state_api_output_to_print( state_data: Union[dict, list], *, format: AvailableFormat = AvailableFormat.DEFAULT ): @@ -73,24 +95,7 @@ def _should_explain(format: AvailableFormat): @click.group("list") @click.pass_context def list_state_cli_group(ctx): - address = services.canonicalize_bootstrap_address(None) - gcs_client = GcsClient(address=address, nums_reconnect_retry=0) - ray.experimental.internal_kv._initialize_internal_kv(gcs_client) - api_server_url = ray._private.utils.internal_kv_get_with_retry( - gcs_client, - ray_constants.DASHBOARD_ADDRESS, - namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - num_retries=20, - ) - - if api_server_url is None: - raise ValueError( - ( - "Couldn't obtain the API server address from GCS. It is likely that " - "the GCS server is down. Check gcs_server.[out | err] to see if it is " - "still alive." - ) - ) + api_server_url = get_api_server_url() assert use_gcs_for_bootstrap() ctx.ensure_object(dict) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 65f400e04..4c9474e61 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -42,7 +42,16 @@ from ray.internal.internal_api import memory_summary from ray.internal.storage import _load_class from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf from ray.dashboard.modules.job.cli import job_cli_group -from ray.experimental.state.state_cli import list_state_cli_group +from ray.experimental.state.api import ( + get_log, + list_logs, +) +from ray.experimental.state.state_cli import ( + list_state_cli_group, + get_api_server_url, + get_state_api_output_to_print, +) +from ray.experimental.state.common import DEFAULT_LIMIT from distutils.dir_util import copy_tree logger = logging.getLogger(__name__) @@ -1937,6 +1946,150 @@ def local_dump( ) +@cli.command(hidden=True) +@click.argument( + "glob_filter", + required=False, + default="*", +) +@click.option( + "-ip", + "--node-ip", + required=False, + type=str, + default=None, + help="Filters the logs by this ip address.", +) +@click.option( + "--node-id", + "-id", + required=False, + type=str, + default=None, + help="Filters the logs by this NodeID.", +) +@click.option( + "--pid", + "-pid", + required=False, + type=str, + default=None, + help="Retrieves the logs from the process with this pid.", +) +@click.option( + "--actor-id", + "-a", + required=False, + type=str, + default=None, + help="Retrieves the logs corresponding to this ActorID.", +) +@click.option( + "--task-id", + "-t", + required=False, + type=str, + default=None, + help="Retrieves the logs corresponding to this TaskID.", +) +@click.option( + "--follow", + "-f", + required=False, + type=bool, + is_flag=True, + help="Streams the log file as it is updated instead of just tailing.", +) +@click.option( + "--tail", + required=False, + type=int, + default=None, + help="Number of lines to tail from log. 
-1 indicates fetching the whole file.", +) +@click.option( + "--interval", + required=False, + type=float, + default=None, + help="The interval to print new logs when `--follow` is specified.", + hidden=True, +) +def logs( + glob_filter, + node_ip: str, + node_id: str, + pid: str, + actor_id: str, + task_id: str, + follow: bool, + tail: int, + interval: float, +): + if task_id is not None: + raise NotImplementedError("--task-id is not yet supported") + + api_server_url = f"http://{get_api_server_url().decode()}" + + # If neither the node id nor the node ip is provided, choose the head node as a default. + if node_id is None and node_ip is None: + address = ray._private.services.canonicalize_bootstrap_address(None) + node_ip = address.split(":")[0] + + filename = None + match_unique = pid is not None or actor_id is not None # Worker or actor log. + # If there's no unique match, try listing logs based on the glob filter. + if not match_unique: + logs = list_logs( + api_server_url=api_server_url, + node_id=node_id, + node_ip=node_ip, + glob_filter=glob_filter, + ) + log_files_found = [] + for _, log_files in logs.items(): + for log_file in log_files: + log_files_found.append(log_file) + + # If there's only 1 file, that means there's a unique match. + if len(log_files_found) == 1: + filename = log_files_found[0] + match_unique = True + # Otherwise, print a list of log files. + else: + if node_id: + print(f"Node ID: {node_id}") + elif node_ip: + print(f"Node IP: {node_ip}") + print(get_state_api_output_to_print(logs)) + + # If there's a unique match, print the log file. + if match_unique: + if not tail: + tail = 0 if follow else DEFAULT_LIMIT + + if tail > 0: + print( + f"--- Log has been truncated to last {tail} lines." + " Use the `--tail` flag to adjust. ---\n" + ) + + for chunk in get_log( + api_server_url=api_server_url, + node_id=node_id, + node_ip=node_ip, + filename=filename, + actor_id=actor_id, + task_id=task_id, + pid=pid, + tail=tail, + follow=follow, + _interval=interval, + ): + print(chunk, end="", flush=True) + + @cli.command() @click.argument("cluster_config_file", required=False, type=str) @click.option( diff --git a/python/ray/tests/test_state_api_log.py b/python/ray/tests/test_state_api_log.py index 9ffda4c01..b4e156eab 100644 --- a/python/ray/tests/test_state_api_log.py +++ b/python/ray/tests/test_state_api_log.py @@ -14,6 +14,7 @@ else: import ray +from click.testing import CliRunner from ray._private.test_utils import ( format_web_url, wait_until_server_available, @@ -26,10 +27,11 @@ from ray.core.generated.reporter_pb2 import StreamLogReply, ListLogsReply from ray.core.generated.gcs_pb2 import ActorTableData from ray.dashboard.modules.log.log_agent import tail as tail_file from ray.dashboard.modules.log.log_manager import LogsManager -from ray.experimental.state.api import list_nodes, list_workers +from ray.experimental.state.api import list_nodes, list_workers, list_logs, get_log from ray.experimental.state.common import GetLogOptions from ray.experimental.state.exception import DataSourceUnavailable from ray.experimental.state.state_manager import StateDataSourceClient +import ray.scripts.scripts as scripts from ray._private.test_utils import wait_for_condition ASYNCMOCK_MIN_PYTHON_VER = (3, 8) @@ -162,8 +164,11 @@ async def test_logs_manager_resolve_file(logs_manager): """ Test filename is given.
""" + logs_client = logs_manager.data_source_client + logs_client.get_all_registered_agent_ids = MagicMock() + logs_client.get_all_registered_agent_ids.return_value = [node_id.hex()] expected_filename = "filename" - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id, log_filename=expected_filename, actor_id=None, @@ -173,6 +178,7 @@ async def test_logs_manager_resolve_file(logs_manager): timeout=10, ) assert log_file_name == expected_filename + assert n == node_id """ Test actor id is given. """ @@ -185,7 +191,7 @@ async def test_logs_manager_resolve_file(logs_manager): return None assert False, "Not reachable." - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id, log_filename=None, actor_id=actor_id, @@ -199,7 +205,7 @@ async def test_logs_manager_resolve_file(logs_manager): actor_id = ActorID(b"2" * 16) with pytest.raises(ValueError): - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id, log_filename=None, actor_id=actor_id, @@ -216,7 +222,7 @@ async def test_logs_manager_resolve_file(logs_manager): logs_manager.list_logs.return_value = { "worker_out": [f"worker-{worker_id.hex()}-123-123.out"] } - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id.hex(), log_filename=None, actor_id=actor_id, @@ -229,13 +235,14 @@ async def test_logs_manager_resolve_file(logs_manager): node_id.hex(), 10, glob_filter=f"*{worker_id.hex()}*" ) assert log_file_name == f"worker-{worker_id.hex()}-123-123.out" + assert n == node_id.hex() """ Test task id is given. """ with pytest.raises(NotImplementedError): task_id = TaskID(b"2" * 24) - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id.hex(), log_filename=None, actor_id=None, @@ -269,7 +276,7 @@ async def test_logs_manager_resolve_file(logs_manager): logs_manager.list_logs = AsyncMock() # Provide the wrong pid. logs_manager.list_logs.return_value = {"worker_out": [f"worker-123-123-{pid}.out"]} - log_file_name = await logs_manager.resolve_filename( + log_file_name, n = await logs_manager.resolve_filename( node_id=node_id.hex(), log_filename=None, actor_id=None, @@ -484,20 +491,20 @@ def test_logs_stream_and_tail(ray_start_with_dashboard): actor = Actor.remote() ray.get(actor.write_log.remote([test_log_text.format("XXXXXX")])) - # def verify_actor_stream_log(): # Test stream and fetching by actor id stream_response = requests.get( webui_url - + f"/api/v0/logs/stream?node_id={node_id}&lines=2" + + "/api/v0/logs/stream?&lines=2" + f"&actor_id={actor._ray_actor_id.hex()}", stream=True, ) if stream_response.status_code != 200: raise ValueError(stream_response.content.decode("utf-8")) stream_iterator = stream_response.iter_content(chunk_size=None) + # NOTE: Prefix 1 indicates the stream has succeeded. assert ( next(stream_iterator).decode("utf-8") - == ":actor_name:Actor\n" + test_log_text.format("XXXXXX") + "\n" + == "1:actor_name:Actor\n" + test_log_text.format("XXXXXX") + "\n" ) streamed_string = "" @@ -512,18 +519,20 @@ def test_logs_stream_and_tail(ray_start_with_dashboard): for s in strings: string += s + "\n" streamed_string += string - assert next(stream_iterator).decode("utf-8") == string + # NOTE: Prefix 1 indicates the stream has succeeded. 
+ assert next(stream_iterator).decode("utf-8") == "1" + string del stream_response # Test tailing log by actor id LINES = 150 file_response = requests.get( webui_url - + f"/api/v0/logs/file?node_id={node_id}&lines={LINES}" + + f"/api/v0/logs/file?&lines={LINES}" + "&actor_id=" + actor._ray_actor_id.hex(), ).content.decode("utf-8") - assert file_response == "\n".join(streamed_string.split("\n")[-(LINES + 1) :]) + # NOTE: Prefix 1 indicates the stream has succeeded. + assert file_response == "1" + "\n".join(streamed_string.split("\n")[-(LINES + 1) :]) # Test query by pid & node_ip instead of actor id. node_ip = list(ray.nodes())[0]["NodeManagerAddress"] @@ -533,7 +542,147 @@ def test_logs_stream_and_tail(ray_start_with_dashboard): + f"/api/v0/logs/file?node_ip={node_ip}&lines={LINES}" + f"&pid={pid}", ).content.decode("utf-8") - assert file_response == "\n".join(streamed_string.split("\n")[-(LINES + 1) :]) + # NOTE: Prefix 1 indicates the stream has succeeded. + assert file_response == "1" + "\n".join(streamed_string.split("\n")[-(LINES + 1) :]) + + +def test_log_list(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + + def verify(): + head_node = list(list_nodes().values())[0] + # When glob filter is not provided, it should provide all logs + logs = list_logs(node_id=head_node["node_id"]) + assert "raylet" in logs + assert "gcs_server" in logs + assert "dashboard" in logs + assert "agent" in logs + assert "internal" in logs + assert "driver" in logs + assert "autoscaler" in logs + + # Test glob works. + logs = list_logs(node_id=head_node["node_id"], glob_filter="raylet*") + assert len(logs) == 1 + return True + + wait_for_condition(verify) + + +def test_log_get(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0) + ray.init(address=cluster.address) + head_node = list(list_nodes().values())[0] + cluster.add_node(num_cpus=1) + + @ray.remote(num_cpus=1) + class Actor: + def print(self, i): + for _ in range(i): + print("1") + + def getpid(self): + import os + + return os.getpid() + + """ + Test filename match + """ + + def verify(): + # By default, node id should be configured to the head node. + for log in get_log( + node_id=head_node["node_id"], filename="raylet.out", tail=10 + ): + # + 1 since the last line is just empty. + assert len(log.split("\n")) == 11 + return True + + wait_for_condition(verify) + + """ + Test worker pid / IP match + """ + a = Actor.remote() + pid = ray.get(a.getpid.remote()) + ray.get(a.print.remote(20)) + + def verify(): + # By default, node id should be configured to the head node. + for log in get_log(node_ip=head_node["node_ip"], pid=pid, tail=10): + # + 1 since the last line is just empty. + assert len(log.split("\n")) == 11 + return True + + wait_for_condition(verify) + + """ + Test actor logs. + """ + actor_id = a._actor_id.hex() + + def verify(): + # By default, node id should be configured to the head node. + for log in get_log(actor_id=actor_id, tail=10): + # + 1 since the last line is just empty. + assert len(log.split("\n")) == 11 + return True + + wait_for_condition(verify) + + with pytest.raises(NotImplementedError): + for _ in get_log(task_id=123, tail=10): + pass + + +def test_log_cli(shutdown_only): + ray.init(num_cpus=1) + runner = CliRunner() + + # Test the head node is chosen by default. 
+ def verify(): + result = runner.invoke(scripts.logs) + print(result.output) + assert result.exit_code == 0 + assert "raylet.out" in result.output + assert "raylet.err" in result.output + assert "gcs_server.out" in result.output + assert "gcs_server.err" in result.output + return True + + wait_for_condition(verify) + + # Test when there's only 1 match, it prints logs. + def verify(): + result = runner.invoke(scripts.logs, ["raylet.out"]) + assert result.exit_code == 0 + print(result.output) + assert "raylet.out" not in result.output + assert "raylet.err" not in result.output + assert "gcs_server.out" not in result.output + assert "gcs_server.err" not in result.output + # Make sure it prints the log message. + assert "NodeManager server started" in result.output + return True + + wait_for_condition(verify) + + # Test when there's more than 1 match, it prints a list of logs. + def verify(): + result = runner.invoke(scripts.logs, ["raylet.*"]) + assert result.exit_code == 0 + print(result.output) + assert "raylet.out" in result.output + assert "raylet.err" in result.output + assert "gcs_server.out" not in result.output + assert "gcs_server.err" not in result.output + return True + + wait_for_condition(verify) if __name__ == "__main__":