[Client] Add ray.client().disconnect() (#16021)

This commit is contained in:
Ian Rodney 2021-05-28 10:15:44 -07:00 committed by GitHub
parent 3d37e3a315
commit ec46794767
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 125 additions and 63 deletions

View file

@ -19,34 +19,33 @@ Named actors are only accessible within their namespaces.
.. code-block:: python
import
import ray
@ray.remote
class Actor:
pass
# Job 1 creates two actors, "orange" and "purple" in the "colors" namespace.
ray.client().namespace("colors").connect()
Actor.options(name="orange", lifetime="detached")
Actor.options(name="purple", lifetime="detached")
ray.util.disconnect()
with ray.client().namespace("colors").connect():
Actor.options(name="orange", lifetime="detached")
Actor.options(name="purple", lifetime="detached")
# Job 2 is now connecting to a different namespace.
ray.client().namespace("fruits").connect()
# This fails because "orange" was defined in the "colors" namespace.
ray.get_actor("orange")
# This succeeds because the name "orange" is unused in this namespace.
Actor.options(name="orange", lifetime="detached")
Actor.options(name="watermelon", lifetime="detached")
ray.util.disconnect()
with ray.client().namespace("fruits").connect():
# This fails because "orange" was defined in the "colors" namespace.
ray.get_actor("orange")
# This succeeds because the name "orange" is unused in this namespace.
Actor.options(name="orange", lifetime="detached")
Actor.options(name="watermelon", lifetime="detached")
# Job 3 connects to the original "colors" namespace
ray.client().namespace("colors").connect()
context = ray.client().namespace("colors").connect()
# This fails because "watermelon" was in the fruits namespace.
ray.get_actor("watermelon")
# This returns the "orange" actor we created in the first job, not the second.
ray.get_actor("orange")
ray.util.disconnect()
context.disconnect()
# We are manually managing the scope of the connection in this example.
Anonymous namespaces
@ -58,22 +57,22 @@ will not have access to actors in other namespaces.
.. code-block:: python
import
import ray
@ray.remote
class Actor:
pass
# Job 1 connects to an anonymous namespace by default
ray.client().connect()
ctx = ray.client().connect()
Actor.options(name="my_actor", lifetime="detached")
ray.util.disconnect()
ctx.disconnect()
# Job 2 connects to a _different_ anonymous namespace by default
ray.client().connect()
ctx = ray.client().connect()
# This succeeds because the second job is in its own namespace.
Actor.options(name="my_actor", lifetime="detached")
ray.util.disconnect()
ctx.disconnect()
.. note::

View file

@ -3,6 +3,7 @@ import importlib
import logging
from dataclasses import dataclass
from urllib.parse import urlparse
import sys
from typing import Any, Dict, Optional, Tuple
from ray.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
@ -13,17 +14,44 @@ logger = logging.getLogger(__name__)
@dataclass
class ClientContext:
    """
    Basic context manager for a ClientBuilder connection.

    Instances are returned by ``connect()``. Used in a ``with`` statement,
    the object yields itself on entry and calls ``disconnect()`` on exit.
    """
    dashboard_url: Optional[str]
    python_version: str
    ray_version: str
    ray_commit: str
    protocol_version: Optional[str]
    _num_clients: int

    def __enter__(self) -> "ClientContext":
        return self

    def __exit__(self, *exc) -> None:
        # Exception details are ignored; leaving the block always disconnects.
        self.disconnect()

    def disconnect(self) -> None:
        """
        Disconnect Ray.

        Depending on how this process is attached, this either disconnects
        from the remote Client Server or shuts the current driver down.
        Nothing happens when there is no node to disconnect from, or when the
        node owned by this process is the head node.
        """
        if ray.util.client.ray.is_connected():
            # Pure client attached to a server: drop the connection and
            # switch client mode off again.
            ray.util.client_connect.disconnect()
            ray._private.client_mode_hook._explicitly_disable_client_mode()
            return

        node = ray.worker.global_worker.node
        if node is None:
            # Already disconnected.
            return

        if not node.is_head():
            # Plain driver attached to an existing cluster.
            ray.shutdown()
            return

        logger.debug(
            "The current Ray Cluster is scoped to this process. "
            "Disconnecting is not possible as it will shutdown the "
            "cluster.")
class ClientBuilder:
"""
@ -56,7 +84,7 @@ class ClientBuilder:
self._job_config.set_ray_namespace(namespace)
return self
def connect(self) -> ClientInfo:
def connect(self) -> ClientContext:
"""
Begin a connection to the address passed in via ray.client(...).
@ -69,7 +97,7 @@ class ClientBuilder:
self.address, job_config=self._job_config)
dashboard_url = ray.get(
ray.remote(ray.worker.get_dashboard_url).remote())
return ClientInfo(
return ClientContext(
dashboard_url=dashboard_url,
python_version=client_info_dict["python_version"],
ray_version=client_info_dict["ray_version"],
@ -79,11 +107,19 @@ class ClientBuilder:
class _LocalClientBuilder(ClientBuilder):
    def connect(self) -> ClientContext:
        """
        Begin a connection to the address passed in via ray.client(...).

        Calls ``ray.init`` with this builder's address and job config and
        wraps the result in a ``ClientContext``. The Python version is taken
        from the running interpreter and the Ray version/commit from the
        imported ``ray`` package; no client protocol version is reported.

        Returns:
            A ``ClientContext`` describing the connection.
        """
        connection_dict = ray.init(
            address=self.address, job_config=self._job_config)
        return ClientContext(
            dashboard_url=connection_dict["webui_url"],
            python_version="{}.{}.{}".format(
                sys.version_info[0], sys.version_info[1], sys.version_info[2]),
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
            protocol_version=None,
            # `_num_clients` is a required ClientContext field (it has no
            # default); leaving it out raises TypeError at construction time.
            # NOTE(review): 1 is meant to count this driver itself - confirm
            # the intended value for connections that bypass the client
            # server.
            _num_clients=1)
def _split_address(address: str) -> Tuple[str, str]:
@ -103,7 +139,11 @@ def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
return _LocalClientBuilder(None)
if address is None:
try:
with open("/tmp/ray/current_cluster", "r") as f:
# NOTE: This is not placed in `Node::get_temp_dir_path`, because
# this file is accessed before the `Node` object is created.
cluster_file = os.path.join(ray._private.utils.get_user_temp_dir(),
"ray_current_cluster")
with open(cluster_file, "r") as f:
address = f.read()
print(address)
except FileNotFoundError:

View file

@ -90,15 +90,14 @@ print(ray.get_runtime_context().namespace)
def test_connect_to_cluster(ray_start_regular_shared):
server = ray_client_server.serve("localhost:50055")
client_info = ray.client("localhost:50055").connect()
assert client_info.dashboard_url == ray.worker.get_dashboard_url()
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert client_info.python_version == python_version
assert client_info.ray_version == ray.__version__
assert client_info.ray_commit == ray.__commit__
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
assert client_info.protocol_version == protocol_version
with ray.client("localhost:50055").connect() as client_context:
assert client_context.dashboard_url == ray.worker.get_dashboard_url()
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
assert client_context.python_version == python_version
assert client_context.ray_version == ray.__version__
assert client_context.ray_commit == ray.__commit__
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
assert client_context.protocol_version == protocol_version
server.stop(0)
subprocess.check_output("ray stop --force", shell=True)
@ -180,3 +179,37 @@ assert len(ray._private.services.find_redis_address()) == 1
retry_interval_ms=1000)
p1.kill()
subprocess.check_output("ray stop --force", shell=True)
def test_disconnect(call_ray_stop_only):
    """
    Check that leaving a connection context, or calling ``disconnect()``
    explicitly, leaves this process unable to use Ray, both for a Ray Client
    connection and for a direct driver connection.
    """
    subprocess.check_output(
        "ray start --head --ray-client-server-port=25555", shell=True)

    def assert_disconnected():
        # With no active connection a plain `ray.put` is expected to fail.
        with pytest.raises(ray.exceptions.RaySystemError):
            ray.put(300)

    # Connect via Ray Client
    with ray.client("localhost:25555").namespace("n1").connect():
        assert ray.get_runtime_context().namespace == "n1"
        assert ray.util.client.ray.is_connected()
    assert_disconnected()

    # Connect Directly via Driver
    with ray.client(None).namespace("n1").connect():
        assert ray.get_runtime_context().namespace == "n1"
        assert not ray.util.client.ray.is_connected()
    assert_disconnected()

    # Connect via Ray Client, managing the connection by hand this time
    ctx = ray.client("localhost:25555").namespace("n1").connect()
    assert ray.get_runtime_context().namespace == "n1"
    assert ray.util.client.ray.is_connected()
    ctx.disconnect()
    # Check idempotency
    ctx.disconnect()
    assert_disconnected()

View file

@ -85,14 +85,12 @@ def test_multiple_clients_use_different_drivers(call_ray_start):
"""
Test that each client uses a separate JobIDs and namespaces.
"""
ray.client("localhost:25001").connect()
job_id_one = ray.get_runtime_context().job_id
namespace_one = ray.get_runtime_context().namespace
ray.util.disconnect()
ray.client("localhost:25001").connect()
job_id_two = ray.get_runtime_context().job_id
namespace_two = ray.get_runtime_context().namespace
ray.util.disconnect()
with ray.client("localhost:25001").connect():
job_id_one = ray.get_runtime_context().job_id
namespace_one = ray.get_runtime_context().namespace
with ray.client("localhost:25001").connect():
job_id_two = ray.get_runtime_context().job_id
namespace_two = ray.get_runtime_context().namespace
assert job_id_one != job_id_two
assert namespace_one != namespace_two

View file

@ -62,14 +62,15 @@ def conda_envs():
check_remote_client_conda = """
import ray
ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).connect()
context = ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).\\
connect()
@ray.remote
def get_tf_version():
import tensorflow as tf
return tf.__version__
assert ray.get(get_tf_version.remote()) == "{tf_version}"
ray.util.disconnect()
context.disconnect()
"""
@ -96,9 +97,8 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs,
tf_versions = ["2.2.0", "2.3.0"]
for i, tf_version in enumerate(tf_versions):
try:
runtime_env = {"conda": f"tf-{tf_version}"}
ray.client("localhost:24001").env(runtime_env).connect()
runtime_env = {"conda": f"tf-{tf_version}"}
with ray.client("localhost:24001").env(runtime_env).connect():
assert ray.get(get_tf_version.remote()) == tf_version
actor_handle = TfVersionActor.remote()
assert ray.get(actor_handle.get_tf_version.remote()) == tf_version
@ -108,9 +108,6 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs,
other_tf_version = tf_versions[(i + 1) % 2]
run_string_as_driver(
check_remote_client_conda.format(tf_version=other_tf_version))
finally:
ray.util.disconnect()
ray._private.client_mode_hook._explicitly_disable_client_mode()
@pytest.mark.skipif(
@ -387,28 +384,23 @@ def test_conda_create_ray_client(call_ray_start):
]
}
}
try:
ray.client("localhost:24001").env(runtime_env).connect()
@ray.remote
def f():
import pip_install_test # noqa
return True
@ray.remote
def f():
import pip_install_test # noqa
return True
with ray.client("localhost:24001").env(runtime_env).connect():
with pytest.raises(ModuleNotFoundError):
# Ensure pip-install-test is not installed on the test machine
import pip_install_test # noqa
assert ray.get(f.remote())
ray.util.disconnect()
ray.client("localhost:24001").connect()
with ray.client("localhost:24001").connect():
with pytest.raises(ModuleNotFoundError):
# Ensure pip-install-test is not installed in a client that doesn't
# use the runtime_env
ray.get(f.remote())
finally:
ray.util.disconnect()
ray._private.client_mode_hook._explicitly_disable_client_mode()
@pytest.mark.skipif(