Revert "Revert "[Observability] Fix --follow lost connection when it is used for > 30 seconds" (#26162)" (#26163)

* Revert "Revert "[Observability] Fix --follow lost connection when it is used for > 30 seconds (#26080)" (#26162)"

This reverts commit 3017128d5e.
This commit is contained in:
SangBin Cho 2022-06-29 08:07:32 +09:00 committed by GitHub
parent 68315b34b4
commit def02bd4c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 5 deletions

View file

@@ -88,7 +88,10 @@ class LogsManager:
keep_alive=keep_alive,
lines=options.lines,
interval=options.interval,
timeout=options.timeout,
# If we keepalive logs connection, we shouldn't have timeout
# otherwise the stream will be terminated forcefully
# after the deadline is expired.
timeout=options.timeout if not keep_alive else None,
)
async for streamed_log in stream:

View file

@@ -187,7 +187,7 @@ class StateHead(dashboard_utils.DashboardHeadModule):
glob_filter = req.query.get("glob", "*")
node_id = req.query.get("node_id", None)
node_ip = req.query.get("node_ip", None)
timeout = req.query.get("timeout", DEFAULT_RPC_TIMEOUT)
timeout = int(req.query.get("timeout", DEFAULT_RPC_TIMEOUT))
# TODO(sang): Do input validation from the middleware instead.
if not node_id and not node_ip:

View file

@@ -560,6 +560,7 @@ def get_log(
pid: Optional[int] = None,
follow: bool = False,
tail: int = 100,
timeout: int = DEFAULT_RPC_TIMEOUT,
_interval: Optional[float] = None,
) -> Generator[str, None, None]:
if api_server_url is None:
@@ -569,6 +570,7 @@ def get_log(
)
media_type = "stream" if follow else "file"
options = GetLogOptions(
node_id=node_id,
node_ip=node_ip,
@@ -579,7 +581,7 @@ def get_log(
lines=tail,
interval=_interval,
media_type=media_type,
timeout=DEFAULT_RPC_TIMEOUT,
timeout=timeout,
)
options_dict = {}
for field in fields(options):
@@ -612,6 +614,7 @@ def list_logs(
node_id: str = None,
node_ip: str = None,
glob_filter: str = None,
timeout: int = DEFAULT_RPC_TIMEOUT,
) -> Dict[str, List[str]]:
if api_server_url is None:
assert ray.is_initialized()
@@ -629,6 +632,7 @@ def list_logs(
options_dict["node_id"] = node_id
if glob_filter:
options_dict["glob"] = glob_filter
options_dict["timeout"] = timeout
r = requests.get(
f"{api_server_url}/api/v0/logs?{urllib.parse.urlencode(options_dict)}"

View file

@@ -44,7 +44,7 @@ from ray.autoscaler._private.fake_multi_node.node_provider import FAKE_HEAD_NODE
from ray.autoscaler._private.kuberay.run_autoscaler import run_kuberay_autoscaler
from ray.dashboard.modules.job.cli import job_cli_group
from ray.experimental.state.api import get_log, list_logs
from ray.experimental.state.common import DEFAULT_LIMIT
from ray.experimental.state.common import DEFAULT_LIMIT, DEFAULT_RPC_TIMEOUT
from ray.util.annotations import PublicAPI
from ray.experimental.state.state_cli import (
@@ -2040,6 +2040,15 @@ def local_dump(
help="The interval to print new logs when `--follow` is specified.",
hidden=True,
)
@click.option(
"--timeout",
default=DEFAULT_RPC_TIMEOUT,
help=(
"Timeout in seconds for the API requests. "
f"Default is {DEFAULT_RPC_TIMEOUT}. If --follow is specified, "
"this option will be ignored."
),
)
def logs(
glob_filter,
node_ip: str,
@@ -2050,6 +2059,7 @@ def logs(
follow: bool,
tail: int,
interval: float,
timeout: int,
):
if task_id is not None:
raise NotImplementedError("--task-id is not yet supported")
@@ -2071,6 +2081,7 @@ def logs(
node_id=node_id,
node_ip=node_ip,
glob_filter=glob_filter,
timeout=timeout,
)
log_files_found = []
for _, log_files in logs.items():
@@ -2111,6 +2122,7 @@ def logs(
tail=tail,
follow=follow,
_interval=interval,
timeout=timeout,
):
print(chunk, end="", flush=True)

View file

@@ -363,13 +363,52 @@ async def test_logs_manager_stream_log(logs_manager):
keep_alive=True,
lines=10,
interval=0.5,
timeout=30,
timeout=None,
)
# Currently cannot test actor_id with AsyncMock.
# It will be tested by the integration test.
@pytest.mark.skipif(
sys.version_info < ASYNCMOCK_MIN_PYTHON_VER,
reason=f"unittest.mock.AsyncMock requires python {ASYNCMOCK_MIN_PYTHON_VER}"
" or higher",
)
@pytest.mark.asyncio
async def test_logs_manager_keepalive_no_timeout(logs_manager):
"""Test when --follow is specified, there's no timeout.
Related: https://github.com/ray-project/ray/issues/25721
"""
NUM_LOG_CHUNKS = 10
logs_client = logs_manager.data_source_client
logs_client.get_all_registered_agent_ids = MagicMock()
logs_client.get_all_registered_agent_ids.return_value = ["1", "2"]
logs_client.ip_to_node_id = MagicMock()
logs_client.stream_log.return_value = generate_logs_stream(NUM_LOG_CHUNKS)
# Test file_name, media_type="file", node_id
options = GetLogOptions(
timeout=30, media_type="stream", lines=10, node_id="1", filename="raylet.out"
)
async for chunk in logs_manager.stream_logs(options):
pass
# Make sure timeout == None when media_type == stream. This is to avoid
# closing the connection due to DEADLINE_EXCEEDED when --follow is specified.
logs_client.stream_log.assert_awaited_with(
node_id="1",
log_file_name="raylet.out",
keep_alive=True,
lines=10,
interval=None,
timeout=None,
)
# Integration tests