[ClientBuilder] Verify Module has ClientBuilder Class (#16076)

This commit is contained in:
Ian Rodney 2021-06-02 09:19:44 -07:00 committed by GitHub
parent c53893cb13
commit 4116c8c3f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 5 deletions

View file

@ -2,8 +2,8 @@ import os
import importlib
import logging
from dataclasses import dataclass
from urllib.parse import urlparse
import sys
from typing import Any, Dict, Optional, Tuple
from ray.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
@ -128,9 +128,10 @@ def _split_address(address: str) -> Tuple[str, str]:
"""
if "://" not in address:
address = "ray://" + address
url_object = urlparse(address)
module_string = url_object.scheme
inner_address = address.replace(module_string + "://", "", 1)
# NOTE: We use a custom splitting function instead of urllib because
# PEP allows "underscores" in a module names, while URL schemes do not
# allow them.
module_string, inner_address = address.split("://", maxsplit=1)
return (module_string, inner_address)
@ -151,7 +152,14 @@ def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
pass
return _LocalClientBuilder(address)
module_string, inner_address = _split_address(address)
module = importlib.import_module(module_string)
try:
module = importlib.import_module(module_string)
except Exception:
raise RuntimeError(
f"Module: {module_string} does not exist.\n"
f"This module was parsed from Address: {address}") from None
assert "ClientBuilder" in dir(module), (f"Module: {module_string} does "
"not have ClientBuilder.")
return module.ClientBuilder(inner_address)

View file

@ -2,6 +2,7 @@ import os
import pytest
import subprocess
import sys
from unittest.mock import patch, Mock
import ray
import ray.util.client.server.server as ray_client_server
@ -25,6 +26,9 @@ def test_split_address(address):
specified_other_module = f"module://{address}"
assert client_builder._split_address(specified_other_module) == ("module",
address)
non_url_compliant_module = f"module_test://{address}"
assert client_builder._split_address(non_url_compliant_module) == (
"module", address)
@pytest.mark.parametrize(
@ -182,6 +186,41 @@ assert len(ray._private.services.find_redis_address()) == 1
subprocess.check_output("ray stop --force", shell=True)
def test_non_existent_modules():
exception = None
try:
ray.client("badmodule://address")
except RuntimeError as e:
exception = e
assert exception is not None, "Bad Module did not raise RuntimeException"
assert "does not exist" in str(exception)
def test_module_lacks_client_builder():
mock_importlib = Mock()
def mock_import_module(module_string):
if module_string == "ray":
return ray
else:
# Mock() does not have a `ClientBuilder` in its scope
return Mock()
mock_importlib.import_module = mock_import_module
with patch("ray.client_builder.importlib", mock_importlib):
assert isinstance(ray.client(""), ray.ClientBuilder)
assert isinstance(ray.client("ray://"), ray.ClientBuilder)
exception = None
try:
ray.client("othermodule://")
except AssertionError as e:
exception = e
assert exception is not None, ("Module without ClientBuilder did not "
"raise AssertionError")
assert "does not have ClientBuilder" in str(exception)
def test_disconnect(call_ray_stop_only):
subprocess.check_output(
"ray start --head --ray-client-server-port=25555", shell=True)