diff --git a/doc/source/namespaces.rst b/doc/source/namespaces.rst index 132f153f9..ded7efcf2 100644 --- a/doc/source/namespaces.rst +++ b/doc/source/namespaces.rst @@ -19,34 +19,33 @@ Named actors are only accessible within their namespaces. .. code-block:: python - import + import ray @ray.remote class Actor: pass # Job 1 creates two actors, "orange" and "purple" in the "colors" namespace. - ray.client().namespace("colors").connect() - Actor.options(name="orange", lifetime="detached") - Actor.options(name="purple", lifetime="detached") - ray.util.disconnect() + with ray.client().namespace("colors").connect(): + Actor.options(name="orange", lifetime="detached") + Actor.options(name="purple", lifetime="detached") # Job 2 is now connecting to a different namespace. - ray.client().namespace("fruits").connect() - # This fails because "orange" was defined in the "colors" namespace. - ray.get_actor("orange") - # This succceeds because the name "orange" is unused in this namespace. - Actor.options(name="orange", lifetime="detached") - Actor.options(name="watermelon", lifetime="detached") - ray.util.disconnect() + with ray.client().namespace("fruits").connect(): + # This fails because "orange" was defined in the "colors" namespace. + ray.get_actor("orange") + # This succceeds because the name "orange" is unused in this namespace. + Actor.options(name="orange", lifetime="detached") + Actor.options(name="watermelon", lifetime="detached") # Job 3 connects to the original "colors" namespace - ray.client().namespace("colors").connect() + context = ray.client().namespace("colors").connect() # This fails because "watermelon" was in the fruits namespace. ray.get_actor("watermelon") # This returns the "orange" actor we created in the first job, not the second. ray.get_actor("orange") - ray.util.disconnect() + context.disconnect() + # We are manually managing the scope of the connection in this example. Anonymous namespaces @@ -58,22 +57,22 @@ will not have access to actors in other namespaces. .. code-block:: python - import + import ray @ray.remote class Actor: pass # Job 1 connects to an anonymous namespace by default - ray.client().connect() + ctx = ray.client().connect() Actor.options(name="my_actor", lifetime="detached") - ray.util.disconnect() + ctx.disconnect() # Job 2 connects to an _different_ anonymous namespace by default - ray.client().connect() + ctx = ray.client().connect() # This succeeds because the second job is in its own namespace. Actor.options(name="my_actor", lifetime="detached") - ray.util.disconnect() + ctx.disconnect() .. note:: diff --git a/python/ray/client_builder.py b/python/ray/client_builder.py index 869366d9a..507223da5 100644 --- a/python/ray/client_builder.py +++ b/python/ray/client_builder.py @@ -3,6 +3,7 @@ import importlib import logging from dataclasses import dataclass from urllib.parse import urlparse +import sys from typing import Any, Dict, Optional, Tuple from ray.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE @@ -13,17 +14,44 @@ logger = logging.getLogger(__name__) @dataclass -class ClientInfo: +class ClientContext: """ - Basic information of the remote server for a given Ray Client connection. + Basic context manager for a ClientBuilder connection. """ dashboard_url: Optional[str] python_version: str ray_version: str ray_commit: str - protocol_version: str + protocol_version: Optional[str] _num_clients: int + def __enter__(self) -> "ClientContext": + return self + + def __exit__(self, *exc) -> None: + self.disconnect() + + def disconnect(self) -> None: + """ + Disconnect Ray. This either disconnects from the remote Client Server + or shuts the current driver down. + """ + if ray.util.client.ray.is_connected(): + # This is only a client connected to a server. + ray.util.client_connect.disconnect() + ray._private.client_mode_hook._explicitly_disable_client_mode() + elif ray.worker.global_worker.node is None: + # Already disconnected. + return + elif ray.worker.global_worker.node.is_head(): + logger.debug( + "The current Ray Cluster is scoped to this process. " + "Disconnecting is not possible as it will shutdown the " + "cluster.") + else: + # This is only a driver connected to an existing cluster. + ray.shutdown() + class ClientBuilder: """ @@ -56,7 +84,7 @@ class ClientBuilder: self._job_config.set_ray_namespace(namespace) return self - def connect(self) -> ClientInfo: + def connect(self) -> ClientContext: """ Begin a connection to the address passed in via ray.client(...). @@ -69,7 +97,7 @@ class ClientBuilder: self.address, job_config=self._job_config) dashboard_url = ray.get( ray.remote(ray.worker.get_dashboard_url).remote()) - return ClientInfo( + return ClientContext( dashboard_url=dashboard_url, python_version=client_info_dict["python_version"], ray_version=client_info_dict["ray_version"], @@ -79,11 +107,19 @@ class ClientBuilder: class _LocalClientBuilder(ClientBuilder): - def connect(self) -> ClientInfo: + def connect(self) -> ClientContext: """ Begin a connection to the address passed in via ray.client(...). """ - return ray.init(address=self.address, job_config=self._job_config) + connection_dict = ray.init( + address=self.address, job_config=self._job_config) + return ClientContext( + dashboard_url=connection_dict["webui_url"], + python_version="{}.{}.{}".format( + sys.version_info[0], sys.version_info[1], sys.version_info[2]), + ray_version=ray.__version__, + ray_commit=ray.__commit__, + protocol_version=None) def _split_address(address: str) -> Tuple[str, str]: @@ -103,7 +139,11 @@ def _get_builder_from_address(address: Optional[str]) -> ClientBuilder: return _LocalClientBuilder(None) if address is None: try: - with open("/tmp/ray/current_cluster", "r") as f: + # NOTE: This is not placed in `Node::get_temp_dir_path`, because + # this file is accessed before the `Node` object is created. + cluster_file = os.path.join(ray._private.utils.get_user_temp_dir(), + "ray_current_cluster") + with open(cluster_file, "r") as f: address = f.read() print(address) except FileNotFoundError: diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py index 0594c3305..4bf29e7a1 100644 --- a/python/ray/tests/test_client_builder.py +++ b/python/ray/tests/test_client_builder.py @@ -90,15 +90,14 @@ print(ray.get_runtime_context().namespace) def test_connect_to_cluster(ray_start_regular_shared): server = ray_client_server.serve("localhost:50055") - client_info = ray.client("localhost:50055").connect() - - assert client_info.dashboard_url == ray.worker.get_dashboard_url() - python_version = ".".join([str(x) for x in list(sys.version_info)[:3]]) - assert client_info.python_version == python_version - assert client_info.ray_version == ray.__version__ - assert client_info.ray_commit == ray.__commit__ - protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION - assert client_info.protocol_version == protocol_version + with ray.client("localhost:50055").connect() as client_context: + assert client_context.dashboard_url == ray.worker.get_dashboard_url() + python_version = ".".join([str(x) for x in list(sys.version_info)[:3]]) + assert client_context.python_version == python_version + assert client_context.ray_version == ray.__version__ + assert client_context.ray_commit == ray.__commit__ + protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION + assert client_context.protocol_version == protocol_version server.stop(0) subprocess.check_output("ray stop --force", shell=True) @@ -180,3 +179,37 @@ assert len(ray._private.services.find_redis_address()) == 1 retry_interval_ms=1000) p1.kill() subprocess.check_output("ray stop --force", shell=True) + + +def test_disconnect(call_ray_stop_only): + subprocess.check_output( + "ray start --head --ray-client-server-port=25555", shell=True) + with ray.client("localhost:25555").namespace("n1").connect(): + # Connect via Ray Client + namespace = ray.get_runtime_context().namespace + assert namespace == "n1" + assert ray.util.client.ray.is_connected() + + with pytest.raises(ray.exceptions.RaySystemError): + ray.put(300) + + with ray.client(None).namespace("n1").connect(): + # Connect Directly via Driver + namespace = ray.get_runtime_context().namespace + assert namespace == "n1" + assert not ray.util.client.ray.is_connected() + + with pytest.raises(ray.exceptions.RaySystemError): + ray.put(300) + + ctx = ray.client("localhost:25555").namespace("n1").connect() + # Connect via Ray Client + namespace = ray.get_runtime_context().namespace + assert namespace == "n1" + assert ray.util.client.ray.is_connected() + ctx.disconnect() + # Check idempotency + ctx.disconnect() + + with pytest.raises(ray.exceptions.RaySystemError): + ray.put(300) diff --git a/python/ray/tests/test_client_proxy.py b/python/ray/tests/test_client_proxy.py index 9f9fd512e..2a1daf323 100644 --- a/python/ray/tests/test_client_proxy.py +++ b/python/ray/tests/test_client_proxy.py @@ -85,14 +85,12 @@ def test_multiple_clients_use_different_drivers(call_ray_start): """ Test that each client uses a separate JobIDs and namespaces. """ - ray.client("localhost:25001").connect() - job_id_one = ray.get_runtime_context().job_id - namespace_one = ray.get_runtime_context().namespace - ray.util.disconnect() - ray.client("localhost:25001").connect() - job_id_two = ray.get_runtime_context().job_id - namespace_two = ray.get_runtime_context().namespace - ray.util.disconnect() + with ray.client("localhost:25001").connect(): + job_id_one = ray.get_runtime_context().job_id + namespace_one = ray.get_runtime_context().namespace + with ray.client("localhost:25001").connect(): + job_id_two = ray.get_runtime_context().job_id + namespace_two = ray.get_runtime_context().namespace assert job_id_one != job_id_two assert namespace_one != namespace_two diff --git a/python/ray/tests/test_runtime_env_complicated.py b/python/ray/tests/test_runtime_env_complicated.py index 435cd0f10..f16b9c7c4 100644 --- a/python/ray/tests/test_runtime_env_complicated.py +++ b/python/ray/tests/test_runtime_env_complicated.py @@ -62,14 +62,15 @@ def conda_envs(): check_remote_client_conda = """ import ray -ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).connect() +context = ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).\\ +connect() @ray.remote def get_tf_version(): import tensorflow as tf return tf.__version__ assert ray.get(get_tf_version.remote()) == "{tf_version}" -ray.util.disconnect() +context.disconnect() """ @@ -96,9 +97,8 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs, tf_versions = ["2.2.0", "2.3.0"] for i, tf_version in enumerate(tf_versions): - try: - runtime_env = {"conda": f"tf-{tf_version}"} - ray.client("localhost:24001").env(runtime_env).connect() + runtime_env = {"conda": f"tf-{tf_version}"} + with ray.client("localhost:24001").env(runtime_env).connect(): assert ray.get(get_tf_version.remote()) == tf_version actor_handle = TfVersionActor.remote() assert ray.get(actor_handle.get_tf_version.remote()) == tf_version @@ -108,9 +108,6 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs, other_tf_version = tf_versions[(i + 1) % 2] run_string_as_driver( check_remote_client_conda.format(tf_version=other_tf_version)) - finally: - ray.util.disconnect() - ray._private.client_mode_hook._explicitly_disable_client_mode() @pytest.mark.skipif( @@ -387,28 +384,23 @@ def test_conda_create_ray_client(call_ray_start): ] } } - try: - ray.client("localhost:24001").env(runtime_env).connect() - @ray.remote - def f(): - import pip_install_test # noqa - return True + @ray.remote + def f(): + import pip_install_test # noqa + return True + with ray.client("localhost:24001").env(runtime_env).connect(): with pytest.raises(ModuleNotFoundError): # Ensure pip-install-test is not installed on the test machine import pip_install_test # noqa assert ray.get(f.remote()) - ray.util.disconnect() - ray.client("localhost:24001").connect() + with ray.client("localhost:24001").connect(): with pytest.raises(ModuleNotFoundError): # Ensure pip-install-test is not installed in a client that doesn't # use the runtime_env ray.get(f.remote()) - finally: - ray.util.disconnect() - ray._private.client_mode_hook._explicitly_disable_client_mode() @pytest.mark.skipif(