[jobs] Enable default port in http:// addresses (#22014)

Closes https://github.com/ray-project/ray/issues/22012
This commit is contained in:
Edward Oakes 2022-02-02 14:34:34 -06:00 committed by GitHub
parent 8bbc5b936a
commit e85bbfb338
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 30 deletions

View file

@ -26,6 +26,7 @@ from ray.dashboard.modules.job.common import (
JobLogsResponse,
uri_to_http_components,
)
from ray.ray_constants import DEFAULT_DASHBOARD_PORT
from ray.client_builder import _split_address
@ -51,9 +52,13 @@ def get_job_submission_client_cluster_info(
cookies: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
_use_tls: Optional[bool] = False,
) -> ClusterInfo:
"""Get address, cookies, and metadata used for JobSubmissionClient.
If no port is specified in `address`, the Ray dashboard default will be
inserted.
Args:
address (str): Address without the module prefix that is passed
to JobSubmissionClient.
@ -66,8 +71,23 @@ def get_job_submission_client_cluster_info(
ClusterInfo object consisting of address, cookies, and metadata
for JobSubmissionClient to use.
"""
scheme = "https" if _use_tls else "http"
split = address.split(":")
host = split[0]
if len(split) == 1:
port = DEFAULT_DASHBOARD_PORT
elif len(split) == 2:
port = int(split[1])
else:
raise ValueError(f"Invalid address: {address}.")
return ClusterInfo(
address="http://" + address, cookies=cookies, metadata=metadata, headers=headers
address=f"{scheme}://{host}:{port}",
cookies=cookies,
metadata=metadata,
headers=headers,
)
@ -80,19 +100,15 @@ def parse_cluster_info(
) -> ClusterInfo:
module_string, inner_address = _split_address(address.rstrip("/"))
# If user passes in a raw HTTP(S) address, just pass it through.
if module_string == "http" or module_string == "https":
return ClusterInfo(
address=address, cookies=cookies, metadata=metadata, headers=headers
)
# If user passes in a Ray address, convert it to HTTP.
elif module_string == "ray":
# If user passes http(s):// or ray://, go through normal parsing.
if module_string in {"http", "https", "ray"}:
return get_job_submission_client_cluster_info(
inner_address,
create_cluster_if_needed=create_cluster_if_needed,
cookies=cookies,
metadata=metadata,
headers=headers,
_use_tls=module_string == "https",
)
# Try to dynamically import the function to get cluster info.
else:

View file

@ -2,24 +2,26 @@ import logging
from pathlib import Path
import sys
import tempfile
from typing import Optional
import pytest
from unittest.mock import patch
import ray
from ray.dashboard.tests.conftest import * # noqa
from ray.tests.conftest import _ray_start
from ray._private.test_utils import (
format_web_url,
wait_for_condition,
wait_until_server_available,
)
from ray.dashboard.modules.job.common import CURRENT_VERSION, JobStatus
from ray.dashboard.modules.job.sdk import (
ClusterInfo,
JobSubmissionClient,
parse_cluster_info,
)
from unittest.mock import patch
from ray.dashboard.tests.conftest import * # noqa
from ray.ray_constants import DEFAULT_DASHBOARD_PORT
from ray.tests.conftest import _ray_start
from ray._private.test_utils import (
format_web_url,
wait_for_condition,
wait_until_server_available,
)
logger = logging.getLogger(__name__)
@ -319,26 +321,28 @@ def test_request_headers(job_sdk_client):
)
@pytest.mark.parametrize(
"address",
[
"http://127.0.0.1",
"https://127.0.0.1",
"ray://127.0.0.1",
"fake_module://127.0.0.1",
],
)
def test_parse_cluster_info(address: str):
if address.startswith("ray"):
@pytest.mark.parametrize("scheme", ["http", "https", "ray", "fake_module"])
@pytest.mark.parametrize("host", ["127.0.0.1", "localhost", "fake.dns.name"])
@pytest.mark.parametrize("port", [None, 8265, 10000])
def test_parse_cluster_info(scheme: str, host: str, port: Optional[int]):
address = f"{scheme}://{host}"
if port is not None:
address += f":{port}"
final_port = port if port is not None else DEFAULT_DASHBOARD_PORT
if scheme in {"http", "ray"}:
assert parse_cluster_info(address, False) == ClusterInfo(
address="http" + address[address.index("://") :],
address=f"http://{host}:{final_port}",
cookies=None,
metadata=None,
headers=None,
)
elif address.startswith("http") or address.startswith("https"):
elif scheme == "https":
assert parse_cluster_info(address, False) == ClusterInfo(
address=address, cookies=None, metadata=None, headers=None
address=f"https://{host}:{final_port}",
cookies=None,
metadata=None,
headers=None,
)
else:
with pytest.raises(RuntimeError):