[State Observability] Ray log CLI / API (#25481)

This PR implements the basic log APIs. Higher-level APIs (e.g. ray logs actors) will be implemented after the internal API review is done.

# If there's only 1 match, print the file's content. Otherwise, print all files that match the glob.
ray logs [glob_filter] --node-id=[head node by default]

Args:
    --tail: Tail the last X lines
    --follow: Follow the new logs
    --actor-id: The actor id
    --pid --node-ip: For worker logs
    --node-id: The node id of the log
    --interval: When --follow is specified, new logs are printed at this interval. (should we remove it?)
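
Example invocations (illustrative; actual file names and addresses depend on the cluster):

    # Print the list of log files on the head node.
    ray logs
    # Single match: print the file's content, tailing the last 100 lines.
    ray logs gcs_server.out --tail 100
    # Follow a worker log, identified by pid and node ip.
    ray logs --node-ip=127.0.0.1 --pid=1234 --follow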
SangBin Cho 2022-06-13 21:52:57 +09:00 committed by GitHub
parent ca10530a1a
commit 856bea31fb
9 changed files with 492 additions and 59 deletions


@@ -99,7 +99,7 @@ class LogAgentV1Grpc(
bytes, end = tail(f, lines)
yield reporter_pb2.StreamLogReply(data=bytes + b"\n")
if request.keep_alive:
interval = request.interval if request.interval else 0.5
interval = request.interval if request.interval else 1
f.seek(end)
while not context.done():
await asyncio.sleep(interval)
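
The change above bumps the default poll interval for keep_alive streams from 0.5s to 1s. For illustration, a minimal sketch of the same tail-then-follow pattern on a plain file (standalone names, not the agent's actual API):

    import asyncio

    async def tail_then_follow(path: str, lines: int, interval: float = 1.0):
        """Yield the last `lines` lines, then poll for newly appended data."""
        with open(path, "rb") as f:
            # Read everything and keep only the tail; the real agent seeks
            # backwards from the end instead of reading the whole file.
            data = f.read()
            yield b"\n".join(data.splitlines()[-lines:]) + b"\n"
            end = f.tell()
            while True:
                await asyncio.sleep(interval)  # default poll interval, as in the diff
                f.seek(end)
                new = f.read()
                if new:
                    end = f.tell()
                    yield new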


@@ -70,9 +70,8 @@ class LogsManager:
Async generator of streamed logs in bytes.
"""
node_id = options.node_id or self.ip_to_node_id(options.node_ip)
self._verify_node_registered(node_id)
log_file_name = await self.resolve_filename(
log_file_name, node_id = await self.resolve_filename(
node_id=node_id,
log_filename=options.filename,
actor_id=options.actor_id,
@@ -133,6 +132,13 @@ class LogsManager:
f"Worker ID for Actor ID {actor_id} not found. "
"Actor is not scheduled yet."
)
node_id = actor_data["address"].get("rayletId")
if not node_id:
raise ValueError(
f"Node ID for Actor ID {actor_id} not found. "
"Actor is not scheduled yet."
)
self._verify_node_registered(node_id)
# List all worker logs that match actor's worker id.
log_files = await self.list_logs(
@@ -149,6 +155,7 @@ class LogsManager:
elif task_id:
raise NotImplementedError("task_id is not supported yet.")
elif pid:
self._verify_node_registered(node_id)
log_files = await self.list_logs(node_id, timeout, glob_filter=f"*{pid}*")
for filename in log_files["worker_out"]:
# worker-[worker_id]-[job_id]-[pid].log
@@ -170,7 +177,7 @@ class LogsManager:
f"\tpid: {pid}\n"
)
return log_filename
return log_filename, node_id
def _categorize_log_files(self, log_files: List[str]) -> Dict[str, List[str]]:
"""Categorize the given log files after filterieng them out using a given glob.


@@ -19,6 +19,7 @@ from ray.experimental.state.common import (
ListApiOptions,
GetLogOptions,
DEFAULT_RPC_TIMEOUT,
DEFAULT_LIMIT,
)
from ray.experimental.state.exception import DataSourceUnavailable
from ray.experimental.state.state_manager import StateDataSourceClient
@@ -211,38 +212,43 @@ class StateHead(dashboard_utils.DashboardHeadModule):
@routes.get("/api/v0/logs/{media_type}")
async def get_logs(self, req: aiohttp.web.Request):
"""
If `media_type = stream`, creates an HTTP stream that is kept alive while
the HTTP connection is not closed. If `media_type = file`, the stream
ends once all the requested lines of the file are transmitted.
"""
# TODO(sang): We need better error handling for streaming
# when we refactor the server framework.
options = GetLogOptions(
timeout=req.query.get("timeout", DEFAULT_RPC_TIMEOUT),
node_id=req.query.get("node_id"),
node_ip=req.query.get("node_ip"),
timeout=int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT)),
node_id=req.query.get("node_id", None),
node_ip=req.query.get("node_ip", None),
media_type=req.match_info.get("media_type", "file"),
filename=req.query.get("filename"),
actor_id=req.query.get("actor_id"),
task_id=req.query.get("task_id"),
pid=req.query.get("pid"),
lines=req.query.get("lines", 1000),
interval=req.query.get("interval"),
filename=req.query.get("filename", None),
actor_id=req.query.get("actor_id", None),
task_id=req.query.get("task_id", None),
pid=req.query.get("pid", None),
lines=req.query.get("lines", DEFAULT_LIMIT),
interval=req.query.get("interval", None),
)
response = aiohttp.web.StreamResponse()
response.content_type = "text/plain"
await response.prepare(req)
# try-except here in order to properly handle ongoing HTTP stream
# NOTE: The first byte indicates the success / failure of the individual
# stream. If the first byte is b"1", the stream succeeded.
# If it is b"0", the stream failed.
try:
async for logs_in_bytes in self._log_api.stream_logs(options):
await response.write(logs_in_bytes)
logs_to_stream = bytearray(b"1")
logs_to_stream.extend(logs_in_bytes)
await response.write(bytes(logs_to_stream))
await response.write_eof()
return response
except Exception as e:
logger.exception(e)
await response.write(b"Closing HTTP stream due to internal server error:\n")
await response.write(str(e).encode())
error_msg = bytearray(b"0")
error_msg.extend(
f"Closing HTTP stream due to internal server error.\n{e}".encode()
)
await response.write(bytes(error_msg))
await response.write_eof()
return response
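
Each chunk of the HTTP stream is framed by a one-byte status prefix: b"1" marks a successful payload, b"0" marks an error message. A minimal sketch of both sides of that framing (standalone helpers for illustration, not the handler's actual names):

    def frame_chunk(payload: bytes, ok: bool = True) -> bytes:
        """Prefix a chunk with the one-byte status marker used by the stream."""
        return (b"1" if ok else b"0") + payload

    def deframe_chunk(chunk: bytes) -> str:
        """Strip the status byte; raise if the server signaled an error."""
        if chunk.startswith(b"1"):
            return chunk[1:].decode("utf-8")
        assert chunk.startswith(b"0")
        raise RuntimeError(chunk[1:].decode("utf-8"))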


@@ -232,6 +232,8 @@ class StateAPIManager:
result = []
for message in reply.node_info_list:
data = self._message_to_dict(message=message, fields_to_decode=["node_id"])
data["node_ip"] = data["node_manager_address"]
data = filter_fields(data, NodeState)
result.append(data)
result = self._filter(result, option.filters, NodeState)


@@ -1,17 +1,24 @@
import requests
import warnings
import urllib
from typing import List, Tuple
from typing import List, Tuple, Optional, Dict, Generator
from dataclasses import fields
import ray
from ray.experimental.state.common import (
SupportedFilterType,
ListApiOptions,
GetLogOptions,
DEFAULT_RPC_TIMEOUT,
DEFAULT_LIMIT,
)
from ray.experimental.state.exception import RayStateApiException
"""
List APIs
"""
# TODO(sang): Replace it with auto-generated methods.
def _list(
@@ -193,3 +200,103 @@ def list_runtime_envs(
api_server_url=api_server_url,
_explain=_explain,
)
"""
Log APIs
"""
def get_log(
api_server_url: str = None,
node_id: Optional[str] = None,
node_ip: Optional[str] = None,
filename: Optional[str] = None,
actor_id: Optional[str] = None,
task_id: Optional[str] = None,
pid: Optional[int] = None,
follow: bool = False,
tail: int = 100,
_interval: Optional[float] = None,
) -> Generator[str, None, None]:
if api_server_url is None:
assert ray.is_initialized()
api_server_url = (
f"http://{ray.worker.global_worker.node.address_info['webui_url']}"
)
media_type = "stream" if follow else "file"
options = GetLogOptions(
node_id=node_id,
node_ip=node_ip,
filename=filename,
actor_id=actor_id,
task_id=task_id,
pid=pid,
lines=tail,
interval=_interval,
media_type=media_type,
timeout=DEFAULT_RPC_TIMEOUT,
)
options_dict = {}
for field in fields(options):
option_val = getattr(options, field.name)
if option_val:
options_dict[field.name] = option_val
with requests.get(
f"{api_server_url}/api/v0/logs/{media_type}?"
f"{urllib.parse.urlencode(options_dict)}",
stream=True,
) as r:
if r.status_code != 200:
raise RayStateApiException(r.text)
for bytes in r.iter_content(chunk_size=None):
bytes = bytearray(bytes)
# First byte 1 means success.
if bytes.startswith(b"1"):
bytes.pop(0)
logs = bytes.decode("utf-8")
else:
assert bytes.startswith(b"0")
error_msg = bytes.decode("utf-8")
raise RayStateApiException(error_msg)
yield logs
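
A sketch of how get_log might be consumed (assumes a running Ray cluster so the API server URL can be resolved; raylet.out is just an illustrative file name):

    import ray
    from ray.experimental.state.api import get_log, list_nodes

    ray.init()
    # Resolve the head node's id, then print the last 20 lines of raylet.out.
    head_node_id = list(list_nodes().values())[0]["node_id"]
    for chunk in get_log(node_id=head_node_id, filename="raylet.out", tail=20):
        print(chunk, end="")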
def list_logs(
api_server_url: str = None,
node_id: str = None,
node_ip: str = None,
glob_filter: str = None,
) -> Dict[str, List[str]]:
if api_server_url is None:
assert ray.is_initialized()
api_server_url = (
f"http://{ray.worker.global_worker.node.address_info['webui_url']}"
)
if not glob_filter:
glob_filter = "*"
options_dict = {}
if node_ip:
options_dict["node_ip"] = node_ip
if node_id:
options_dict["node_id"] = node_id
if glob_filter:
options_dict["glob"] = glob_filter
r = requests.get(
f"{api_server_url}/api/v0/logs?{urllib.parse.urlencode(options_dict)}"
)
r.raise_for_status()
response = r.json()
if response["result"] is False:
raise RayStateApiException(
"API server internal error. See dashboard.log file for more details. "
f"Error: {response['msg']}"
)
return response["data"]["result"]


@@ -89,14 +89,17 @@ class GetLogOptions:
self.interval = float(self.interval)
self.lines = int(self.lines)
if self.task_id:
raise NotImplementedError("task_id is not supported yet.")
if self.media_type == "file":
assert self.interval is None
if self.media_type not in ["file", "stream"]:
raise ValueError(f"Invalid media type: {self.media_type}")
if not (self.node_id or self.node_ip):
if not (self.node_id or self.node_ip) and not (self.actor_id or self.task_id):
raise ValueError(
"Both node_id and node_ip is not given. "
"At least one of the should be provided."
"node_id or node_ip should be provided."
"Please provide at least one of them."
)
if self.node_id and self.node_ip:
raise ValueError(
@@ -145,6 +148,7 @@ class PlacementGroupState(StateSchema):
@dataclass(init=True)
class NodeState(StateSchema):
node_id: str
node_ip: str
state: str
node_name: str
resources_total: dict
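
A sketch of the GetLogOptions validation rules above (passing every field explicitly, since the dataclass defaults are not shown in this hunk; "abc123" is a hypothetical node id):

    from ray.experimental.state.common import GetLogOptions

    # Valid: a node is identified and media_type is "file" or "stream".
    GetLogOptions(
        timeout=30, node_id="abc123", node_ip=None, media_type="file",
        filename="raylet.out", actor_id=None, task_id=None, pid=None,
        lines=100, interval=None,
    )
    # Invalid: no node, actor, or task identifier -> ValueError.
    try:
        GetLogOptions(
            timeout=30, node_id=None, node_ip=None, media_type="file",
            filename="raylet.out", actor_id=None, task_id=None, pid=None,
            lines=100, interval=None,
        )
    except ValueError as e:
        print(e)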


@@ -37,6 +37,28 @@ def _get_available_formats() -> List[str]:
return [format_enum.value for format_enum in AvailableFormat]
def get_api_server_url():
address = services.canonicalize_bootstrap_address(None)
gcs_client = GcsClient(address=address, nums_reconnect_retry=0)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
api_server_url = ray._private.utils.internal_kv_get_with_retry(
gcs_client,
ray_constants.DASHBOARD_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
num_retries=20,
)
if api_server_url is None:
raise ValueError(
(
"Couldn't obtain the API server address from GCS. It is likely that "
"the GCS server is down. Check gcs_server.[out | err] to see if it is "
"still alive."
)
)
return api_server_url
def get_state_api_output_to_print(
state_data: Union[dict, list], *, format: AvailableFormat = AvailableFormat.DEFAULT
):
@@ -73,24 +95,7 @@ def _should_explain(format: AvailableFormat):
@click.group("list")
@click.pass_context
def list_state_cli_group(ctx):
address = services.canonicalize_bootstrap_address(None)
gcs_client = GcsClient(address=address, nums_reconnect_retry=0)
ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
api_server_url = ray._private.utils.internal_kv_get_with_retry(
gcs_client,
ray_constants.DASHBOARD_ADDRESS,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
num_retries=20,
)
if api_server_url is None:
raise ValueError(
(
"Couldn't obtain the API server address from GCS. It is likely that "
"the GCS server is down. Check gcs_server.[out | err] to see if it is "
"still alive."
)
)
api_server_url = get_api_server_url()
assert use_gcs_for_bootstrap()
ctx.ensure_object(dict)


@@ -42,7 +42,16 @@ from ray.internal.internal_api import memory_summary
from ray.internal.storage import _load_class
from ray.autoscaler._private.cli_logger import add_click_logging_options, cli_logger, cf
from ray.dashboard.modules.job.cli import job_cli_group
from ray.experimental.state.state_cli import list_state_cli_group
from ray.experimental.state.api import (
get_log,
list_logs,
)
from ray.experimental.state.state_cli import (
list_state_cli_group,
get_api_server_url,
get_state_api_output_to_print,
)
from ray.experimental.state.common import DEFAULT_LIMIT
from distutils.dir_util import copy_tree
logger = logging.getLogger(__name__)
@@ -1937,6 +1946,150 @@ def local_dump(
)
@cli.command(hidden=True)
@click.argument(
"glob_filter",
required=False,
default="*",
)
@click.option(
"-ip",
"--node-ip",
required=False,
type=str,
default=None,
help="Filters the logs by this ip address.",
)
@click.option(
"--node-id",
"-id",
required=False,
type=str,
default=None,
help="Filters the logs by this NodeID.",
)
@click.option(
"--pid",
"-pid",
required=False,
type=str,
default=None,
help="Retrieves the logs from the process with this pid.",
)
@click.option(
"--actor-id",
"-a",
required=False,
type=str,
default=None,
help="Retrieves the logs corresponding to this ActorID.",
)
@click.option(
"--task-id",
"-t",
required=False,
type=str,
default=None,
help="Retrieves the logs corresponding to this TaskID.",
)
@click.option(
"--follow",
"-f",
required=False,
type=bool,
is_flag=True,
help="Streams the log file as it is updated instead of just tailing.",
)
@click.option(
"--tail",
required=False,
type=int,
default=None,
help="Number of lines to tail from log. -1 indicates fetching the whole file.",
)
@click.option(
"--interval",
required=False,
type=float,
default=None,
help="The interval to print new logs when `--follow` is specified.",
hidden=True,
)
def logs(
glob_filter,
node_ip: str,
node_id: str,
pid: str,
actor_id: str,
task_id: str,
follow: bool,
tail: int,
interval: float,
):
if task_id is not None:
raise NotImplementedError("--task-id is not yet supported")
api_server_url = f"http://{get_api_server_url().decode()}"
# If neither node id nor ip is provided, default to the head node.
if node_id is None and node_ip is None:
address = ray._private.services.canonicalize_bootstrap_address(None)
node_ip = address.split(":")[0]
filename = None
# A worker log (pid) or an actor log (actor_id) identifies a unique file.
match_unique = pid is not None or actor_id is not None
# If there's no unique match, try listing logs based on the glob filter.
if not match_unique:
logs = list_logs(
api_server_url=api_server_url,
node_id=node_id,
node_ip=node_ip,
glob_filter=glob_filter,
)
log_files_found = []
for _, log_files in logs.items():
for log_file in log_files:
log_files_found.append(log_file)
# If there's only 1 file, it's a unique match.
if len(log_files_found) == 1:
filename = log_files_found[0]
match_unique = True
# Otherwise, print a list of log files.
else:
if node_id:
print(f"Node ID: {node_id}")
elif node_ip:
print(f"Node IP: {node_ip}")
print(get_state_api_output_to_print(logs))
# If there's a unique match, print the log file.
if match_unique:
if not tail:
tail = 0 if follow else DEFAULT_LIMIT
if tail > 0:
print(
f"--- Log has been truncated to last {tail} lines."
" Use `--tail` flag to toggle. ---\n"
)
for chunk in get_log(
api_server_url=api_server_url,
node_id=node_id,
node_ip=node_ip,
filename=filename,
actor_id=actor_id,
task_id=task_id,
pid=pid,
tail=tail,
follow=follow,
_interval=interval,
):
print(chunk, end="", flush=True)
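
A sketch of exercising the command programmatically, mirroring the tests added later in this commit (assumes a running cluster):

    import ray
    import ray.scripts.scripts as scripts
    from click.testing import CliRunner

    ray.init()
    runner = CliRunner()
    # One glob match, so the file's content is printed directly.
    result = runner.invoke(scripts.logs, ["raylet.out", "--tail", "10"])
    print(result.output)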
@cli.command()
@click.argument("cluster_config_file", required=False, type=str)
@click.option(


@@ -14,6 +14,7 @@ else:
import ray
from click.testing import CliRunner
from ray._private.test_utils import (
format_web_url,
wait_until_server_available,
@@ -26,10 +27,11 @@ from ray.core.generated.reporter_pb2 import StreamLogReply, ListLogsReply
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.dashboard.modules.log.log_agent import tail as tail_file
from ray.dashboard.modules.log.log_manager import LogsManager
from ray.experimental.state.api import list_nodes, list_workers
from ray.experimental.state.api import list_nodes, list_workers, list_logs, get_log
from ray.experimental.state.common import GetLogOptions
from ray.experimental.state.exception import DataSourceUnavailable
from ray.experimental.state.state_manager import StateDataSourceClient
import ray.scripts.scripts as scripts
from ray._private.test_utils import wait_for_condition
ASYNCMOCK_MIN_PYTHON_VER = (3, 8)
@@ -162,8 +164,11 @@ async def test_logs_manager_resolve_file(logs_manager):
"""
Test filename is given.
"""
logs_client = logs_manager.data_source_client
logs_client.get_all_registered_agent_ids = MagicMock()
logs_client.get_all_registered_agent_ids.return_value = [node_id.hex()]
expected_filename = "filename"
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id,
log_filename=expected_filename,
actor_id=None,
@@ -173,6 +178,7 @@ async def test_logs_manager_resolve_file(logs_manager):
timeout=10,
)
assert log_file_name == expected_filename
assert n == node_id
"""
Test actor id is given.
"""
@@ -185,7 +191,7 @@ async def test_logs_manager_resolve_file(logs_manager):
return None
assert False, "Not reachable."
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id,
log_filename=None,
actor_id=actor_id,
@@ -199,7 +205,7 @@ async def test_logs_manager_resolve_file(logs_manager):
actor_id = ActorID(b"2" * 16)
with pytest.raises(ValueError):
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id,
log_filename=None,
actor_id=actor_id,
@@ -216,7 +222,7 @@ async def test_logs_manager_resolve_file(logs_manager):
logs_manager.list_logs.return_value = {
"worker_out": [f"worker-{worker_id.hex()}-123-123.out"]
}
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id.hex(),
log_filename=None,
actor_id=actor_id,
@@ -229,13 +235,14 @@ async def test_logs_manager_resolve_file(logs_manager):
node_id.hex(), 10, glob_filter=f"*{worker_id.hex()}*"
)
assert log_file_name == f"worker-{worker_id.hex()}-123-123.out"
assert n == node_id.hex()
"""
Test task id is given.
"""
with pytest.raises(NotImplementedError):
task_id = TaskID(b"2" * 24)
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id.hex(),
log_filename=None,
actor_id=None,
@@ -269,7 +276,7 @@ async def test_logs_manager_resolve_file(logs_manager):
logs_manager.list_logs = AsyncMock()
# Provide the wrong pid.
logs_manager.list_logs.return_value = {"worker_out": [f"worker-123-123-{pid}.out"]}
log_file_name = await logs_manager.resolve_filename(
log_file_name, n = await logs_manager.resolve_filename(
node_id=node_id.hex(),
log_filename=None,
actor_id=None,
@@ -484,20 +491,20 @@ def test_logs_stream_and_tail(ray_start_with_dashboard):
actor = Actor.remote()
ray.get(actor.write_log.remote([test_log_text.format("XXXXXX")]))
# def verify_actor_stream_log():
# Test stream and fetching by actor id
stream_response = requests.get(
webui_url
+ f"/api/v0/logs/stream?node_id={node_id}&lines=2"
+ "/api/v0/logs/stream?&lines=2"
+ f"&actor_id={actor._ray_actor_id.hex()}",
stream=True,
)
if stream_response.status_code != 200:
raise ValueError(stream_response.content.decode("utf-8"))
stream_iterator = stream_response.iter_content(chunk_size=None)
# NOTE: Prefix 1 indicates the stream has succeeded.
assert (
next(stream_iterator).decode("utf-8")
== ":actor_name:Actor\n" + test_log_text.format("XXXXXX") + "\n"
== "1:actor_name:Actor\n" + test_log_text.format("XXXXXX") + "\n"
)
streamed_string = ""
@@ -512,18 +519,20 @@ def test_logs_stream_and_tail(ray_start_with_dashboard):
for s in strings:
string += s + "\n"
streamed_string += string
assert next(stream_iterator).decode("utf-8") == string
# NOTE: Prefix 1 indicates the stream has succeeded.
assert next(stream_iterator).decode("utf-8") == "1" + string
del stream_response
# Test tailing log by actor id
LINES = 150
file_response = requests.get(
webui_url
+ f"/api/v0/logs/file?node_id={node_id}&lines={LINES}"
+ f"/api/v0/logs/file?&lines={LINES}"
+ "&actor_id="
+ actor._ray_actor_id.hex(),
).content.decode("utf-8")
assert file_response == "\n".join(streamed_string.split("\n")[-(LINES + 1) :])
# NOTE: Prefix 1 indicates the stream has succeeded.
assert file_response == "1" + "\n".join(streamed_string.split("\n")[-(LINES + 1) :])
# Test query by pid & node_ip instead of actor id.
node_ip = list(ray.nodes())[0]["NodeManagerAddress"]
@@ -533,7 +542,147 @@ def test_logs_stream_and_tail(ray_start_with_dashboard):
+ f"/api/v0/logs/file?node_ip={node_ip}&lines={LINES}"
+ f"&pid={pid}",
).content.decode("utf-8")
assert file_response == "\n".join(streamed_string.split("\n")[-(LINES + 1) :])
# NOTE: Prefix 1 indicates the stream has succeeded.
assert file_response == "1" + "\n".join(streamed_string.split("\n")[-(LINES + 1) :])
def test_log_list(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
ray.init(address=cluster.address)
def verify():
head_node = list(list_nodes().values())[0]
# When a glob filter is not provided, it should return all logs.
logs = list_logs(node_id=head_node["node_id"])
assert "raylet" in logs
assert "gcs_server" in logs
assert "dashboard" in logs
assert "agent" in logs
assert "internal" in logs
assert "driver" in logs
assert "autoscaler" in logs
# Test glob works.
logs = list_logs(node_id=head_node["node_id"], glob_filter="raylet*")
assert len(logs) == 1
return True
wait_for_condition(verify)
def test_log_get(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
ray.init(address=cluster.address)
head_node = list(list_nodes().values())[0]
cluster.add_node(num_cpus=1)
@ray.remote(num_cpus=1)
class Actor:
def print(self, i):
for _ in range(i):
print("1")
def getpid(self):
import os
return os.getpid()
"""
Test filename match
"""
def verify():
# By default, node id should be configured to the head node.
for log in get_log(
node_id=head_node["node_id"], filename="raylet.out", tail=10
):
# + 1 since the last line is just empty.
assert len(log.split("\n")) == 11
return True
wait_for_condition(verify)
"""
Test worker pid / IP match
"""
a = Actor.remote()
pid = ray.get(a.getpid.remote())
ray.get(a.print.remote(20))
def verify():
# By default, node id should be configured to the head node.
for log in get_log(node_ip=head_node["node_ip"], pid=pid, tail=10):
# + 1 since the last line is just empty.
assert len(log.split("\n")) == 11
return True
wait_for_condition(verify)
"""
Test actor logs.
"""
actor_id = a._actor_id.hex()
def verify():
# By default, node id should be configured to the head node.
for log in get_log(actor_id=actor_id, tail=10):
# + 1 since the last line is just empty.
assert len(log.split("\n")) == 11
return True
wait_for_condition(verify)
with pytest.raises(NotImplementedError):
for _ in get_log(task_id=123, tail=10):
pass
def test_log_cli(shutdown_only):
ray.init(num_cpus=1)
runner = CliRunner()
# Test the head node is chosen by default.
def verify():
result = runner.invoke(scripts.logs)
print(result.output)
assert result.exit_code == 0
assert "raylet.out" in result.output
assert "raylet.err" in result.output
assert "gcs_server.out" in result.output
assert "gcs_server.err" in result.output
return True
wait_for_condition(verify)
# Test when there's only 1 match, it prints logs.
def verify():
result = runner.invoke(scripts.logs, ["raylet.out"])
assert result.exit_code == 0
print(result.output)
assert "raylet.out" not in result.output
assert "raylet.err" not in result.output
assert "gcs_server.out" not in result.output
assert "gcs_server.err" not in result.output
# Make sure it prints the log message.
assert "NodeManager server started" in result.output
return True
wait_for_condition(verify)
# Test when there's more than 1 match, it prints a list of logs.
def verify():
result = runner.invoke(scripts.logs, ["raylet.*"])
assert result.exit_code == 0
print(result.output)
assert "raylet.out" in result.output
assert "raylet.err" in result.output
assert "gcs_server.out" not in result.output
assert "gcs_server.err" not in result.output
return True
wait_for_condition(verify)
if __name__ == "__main__":