mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Client] Add ray.client().disconnect() (#16021)
This commit is contained in:
parent
3d37e3a315
commit
ec46794767
5 changed files with 125 additions and 63 deletions
|
@ -19,34 +19,33 @@ Named actors are only accessible within their namespaces.
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import
|
||||
import ray
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
pass
|
||||
|
||||
# Job 1 creates two actors, "orange" and "purple" in the "colors" namespace.
|
||||
ray.client().namespace("colors").connect()
|
||||
Actor.options(name="orange", lifetime="detached")
|
||||
Actor.options(name="purple", lifetime="detached")
|
||||
ray.util.disconnect()
|
||||
with ray.client().namespace("colors").connect():
|
||||
Actor.options(name="orange", lifetime="detached")
|
||||
Actor.options(name="purple", lifetime="detached")
|
||||
|
||||
# Job 2 is now connecting to a different namespace.
|
||||
ray.client().namespace("fruits").connect()
|
||||
# This fails because "orange" was defined in the "colors" namespace.
|
||||
ray.get_actor("orange")
|
||||
# This succceeds because the name "orange" is unused in this namespace.
|
||||
Actor.options(name="orange", lifetime="detached")
|
||||
Actor.options(name="watermelon", lifetime="detached")
|
||||
ray.util.disconnect()
|
||||
with ray.client().namespace("fruits").connect():
|
||||
# This fails because "orange" was defined in the "colors" namespace.
|
||||
ray.get_actor("orange")
|
||||
# This succceeds because the name "orange" is unused in this namespace.
|
||||
Actor.options(name="orange", lifetime="detached")
|
||||
Actor.options(name="watermelon", lifetime="detached")
|
||||
|
||||
# Job 3 connects to the original "colors" namespace
|
||||
ray.client().namespace("colors").connect()
|
||||
context = ray.client().namespace("colors").connect()
|
||||
# This fails because "watermelon" was in the fruits namespace.
|
||||
ray.get_actor("watermelon")
|
||||
# This returns the "orange" actor we created in the first job, not the second.
|
||||
ray.get_actor("orange")
|
||||
ray.util.disconnect()
|
||||
context.disconnect()
|
||||
# We are manually managing the scope of the connection in this example.
|
||||
|
||||
|
||||
Anonymous namespaces
|
||||
|
@ -58,22 +57,22 @@ will not have access to actors in other namespaces.
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import
|
||||
import ray
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
pass
|
||||
|
||||
# Job 1 connects to an anonymous namespace by default
|
||||
ray.client().connect()
|
||||
ctx = ray.client().connect()
|
||||
Actor.options(name="my_actor", lifetime="detached")
|
||||
ray.util.disconnect()
|
||||
ctx.disconnect()
|
||||
|
||||
# Job 2 connects to an _different_ anonymous namespace by default
|
||||
ray.client().connect()
|
||||
ctx = ray.client().connect()
|
||||
# This succeeds because the second job is in its own namespace.
|
||||
Actor.options(name="my_actor", lifetime="detached")
|
||||
ray.util.disconnect()
|
||||
ctx.disconnect()
|
||||
|
||||
.. note::
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import importlib
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from ray.ray_constants import RAY_ADDRESS_ENVIRONMENT_VARIABLE
|
||||
|
@ -13,17 +14,44 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@dataclass
|
||||
class ClientInfo:
|
||||
class ClientContext:
|
||||
"""
|
||||
Basic information of the remote server for a given Ray Client connection.
|
||||
Basic context manager for a ClientBuilder connection.
|
||||
"""
|
||||
dashboard_url: Optional[str]
|
||||
python_version: str
|
||||
ray_version: str
|
||||
ray_commit: str
|
||||
protocol_version: str
|
||||
protocol_version: Optional[str]
|
||||
_num_clients: int
|
||||
|
||||
def __enter__(self) -> "ClientContext":
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self.disconnect()
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect Ray. This either disconnects from the remote Client Server
|
||||
or shuts the current driver down.
|
||||
"""
|
||||
if ray.util.client.ray.is_connected():
|
||||
# This is only a client connected to a server.
|
||||
ray.util.client_connect.disconnect()
|
||||
ray._private.client_mode_hook._explicitly_disable_client_mode()
|
||||
elif ray.worker.global_worker.node is None:
|
||||
# Already disconnected.
|
||||
return
|
||||
elif ray.worker.global_worker.node.is_head():
|
||||
logger.debug(
|
||||
"The current Ray Cluster is scoped to this process. "
|
||||
"Disconnecting is not possible as it will shutdown the "
|
||||
"cluster.")
|
||||
else:
|
||||
# This is only a driver connected to an existing cluster.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
class ClientBuilder:
|
||||
"""
|
||||
|
@ -56,7 +84,7 @@ class ClientBuilder:
|
|||
self._job_config.set_ray_namespace(namespace)
|
||||
return self
|
||||
|
||||
def connect(self) -> ClientInfo:
|
||||
def connect(self) -> ClientContext:
|
||||
"""
|
||||
Begin a connection to the address passed in via ray.client(...).
|
||||
|
||||
|
@ -69,7 +97,7 @@ class ClientBuilder:
|
|||
self.address, job_config=self._job_config)
|
||||
dashboard_url = ray.get(
|
||||
ray.remote(ray.worker.get_dashboard_url).remote())
|
||||
return ClientInfo(
|
||||
return ClientContext(
|
||||
dashboard_url=dashboard_url,
|
||||
python_version=client_info_dict["python_version"],
|
||||
ray_version=client_info_dict["ray_version"],
|
||||
|
@ -79,11 +107,19 @@ class ClientBuilder:
|
|||
|
||||
|
||||
class _LocalClientBuilder(ClientBuilder):
|
||||
def connect(self) -> ClientInfo:
|
||||
def connect(self) -> ClientContext:
|
||||
"""
|
||||
Begin a connection to the address passed in via ray.client(...).
|
||||
"""
|
||||
return ray.init(address=self.address, job_config=self._job_config)
|
||||
connection_dict = ray.init(
|
||||
address=self.address, job_config=self._job_config)
|
||||
return ClientContext(
|
||||
dashboard_url=connection_dict["webui_url"],
|
||||
python_version="{}.{}.{}".format(
|
||||
sys.version_info[0], sys.version_info[1], sys.version_info[2]),
|
||||
ray_version=ray.__version__,
|
||||
ray_commit=ray.__commit__,
|
||||
protocol_version=None)
|
||||
|
||||
|
||||
def _split_address(address: str) -> Tuple[str, str]:
|
||||
|
@ -103,7 +139,11 @@ def _get_builder_from_address(address: Optional[str]) -> ClientBuilder:
|
|||
return _LocalClientBuilder(None)
|
||||
if address is None:
|
||||
try:
|
||||
with open("/tmp/ray/current_cluster", "r") as f:
|
||||
# NOTE: This is not placed in `Node::get_temp_dir_path`, because
|
||||
# this file is accessed before the `Node` object is created.
|
||||
cluster_file = os.path.join(ray._private.utils.get_user_temp_dir(),
|
||||
"ray_current_cluster")
|
||||
with open(cluster_file, "r") as f:
|
||||
address = f.read()
|
||||
print(address)
|
||||
except FileNotFoundError:
|
||||
|
|
|
@ -90,15 +90,14 @@ print(ray.get_runtime_context().namespace)
|
|||
|
||||
def test_connect_to_cluster(ray_start_regular_shared):
|
||||
server = ray_client_server.serve("localhost:50055")
|
||||
client_info = ray.client("localhost:50055").connect()
|
||||
|
||||
assert client_info.dashboard_url == ray.worker.get_dashboard_url()
|
||||
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
|
||||
assert client_info.python_version == python_version
|
||||
assert client_info.ray_version == ray.__version__
|
||||
assert client_info.ray_commit == ray.__commit__
|
||||
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
|
||||
assert client_info.protocol_version == protocol_version
|
||||
with ray.client("localhost:50055").connect() as client_context:
|
||||
assert client_context.dashboard_url == ray.worker.get_dashboard_url()
|
||||
python_version = ".".join([str(x) for x in list(sys.version_info)[:3]])
|
||||
assert client_context.python_version == python_version
|
||||
assert client_context.ray_version == ray.__version__
|
||||
assert client_context.ray_commit == ray.__commit__
|
||||
protocol_version = ray.util.client.CURRENT_PROTOCOL_VERSION
|
||||
assert client_context.protocol_version == protocol_version
|
||||
|
||||
server.stop(0)
|
||||
subprocess.check_output("ray stop --force", shell=True)
|
||||
|
@ -180,3 +179,37 @@ assert len(ray._private.services.find_redis_address()) == 1
|
|||
retry_interval_ms=1000)
|
||||
p1.kill()
|
||||
subprocess.check_output("ray stop --force", shell=True)
|
||||
|
||||
|
||||
def test_disconnect(call_ray_stop_only):
|
||||
subprocess.check_output(
|
||||
"ray start --head --ray-client-server-port=25555", shell=True)
|
||||
with ray.client("localhost:25555").namespace("n1").connect():
|
||||
# Connect via Ray Client
|
||||
namespace = ray.get_runtime_context().namespace
|
||||
assert namespace == "n1"
|
||||
assert ray.util.client.ray.is_connected()
|
||||
|
||||
with pytest.raises(ray.exceptions.RaySystemError):
|
||||
ray.put(300)
|
||||
|
||||
with ray.client(None).namespace("n1").connect():
|
||||
# Connect Directly via Driver
|
||||
namespace = ray.get_runtime_context().namespace
|
||||
assert namespace == "n1"
|
||||
assert not ray.util.client.ray.is_connected()
|
||||
|
||||
with pytest.raises(ray.exceptions.RaySystemError):
|
||||
ray.put(300)
|
||||
|
||||
ctx = ray.client("localhost:25555").namespace("n1").connect()
|
||||
# Connect via Ray Client
|
||||
namespace = ray.get_runtime_context().namespace
|
||||
assert namespace == "n1"
|
||||
assert ray.util.client.ray.is_connected()
|
||||
ctx.disconnect()
|
||||
# Check idempotency
|
||||
ctx.disconnect()
|
||||
|
||||
with pytest.raises(ray.exceptions.RaySystemError):
|
||||
ray.put(300)
|
||||
|
|
|
@ -85,14 +85,12 @@ def test_multiple_clients_use_different_drivers(call_ray_start):
|
|||
"""
|
||||
Test that each client uses a separate JobIDs and namespaces.
|
||||
"""
|
||||
ray.client("localhost:25001").connect()
|
||||
job_id_one = ray.get_runtime_context().job_id
|
||||
namespace_one = ray.get_runtime_context().namespace
|
||||
ray.util.disconnect()
|
||||
ray.client("localhost:25001").connect()
|
||||
job_id_two = ray.get_runtime_context().job_id
|
||||
namespace_two = ray.get_runtime_context().namespace
|
||||
ray.util.disconnect()
|
||||
with ray.client("localhost:25001").connect():
|
||||
job_id_one = ray.get_runtime_context().job_id
|
||||
namespace_one = ray.get_runtime_context().namespace
|
||||
with ray.client("localhost:25001").connect():
|
||||
job_id_two = ray.get_runtime_context().job_id
|
||||
namespace_two = ray.get_runtime_context().namespace
|
||||
|
||||
assert job_id_one != job_id_two
|
||||
assert namespace_one != namespace_two
|
||||
|
|
|
@ -62,14 +62,15 @@ def conda_envs():
|
|||
|
||||
check_remote_client_conda = """
|
||||
import ray
|
||||
ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).connect()
|
||||
context = ray.client("localhost:24001").env({{"conda" : "tf-{tf_version}"}}).\\
|
||||
connect()
|
||||
@ray.remote
|
||||
def get_tf_version():
|
||||
import tensorflow as tf
|
||||
return tf.__version__
|
||||
|
||||
assert ray.get(get_tf_version.remote()) == "{tf_version}"
|
||||
ray.util.disconnect()
|
||||
context.disconnect()
|
||||
"""
|
||||
|
||||
|
||||
|
@ -96,9 +97,8 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs,
|
|||
|
||||
tf_versions = ["2.2.0", "2.3.0"]
|
||||
for i, tf_version in enumerate(tf_versions):
|
||||
try:
|
||||
runtime_env = {"conda": f"tf-{tf_version}"}
|
||||
ray.client("localhost:24001").env(runtime_env).connect()
|
||||
runtime_env = {"conda": f"tf-{tf_version}"}
|
||||
with ray.client("localhost:24001").env(runtime_env).connect():
|
||||
assert ray.get(get_tf_version.remote()) == tf_version
|
||||
actor_handle = TfVersionActor.remote()
|
||||
assert ray.get(actor_handle.get_tf_version.remote()) == tf_version
|
||||
|
@ -108,9 +108,6 @@ def test_client_tasks_and_actors_inherit_from_driver(conda_envs,
|
|||
other_tf_version = tf_versions[(i + 1) % 2]
|
||||
run_string_as_driver(
|
||||
check_remote_client_conda.format(tf_version=other_tf_version))
|
||||
finally:
|
||||
ray.util.disconnect()
|
||||
ray._private.client_mode_hook._explicitly_disable_client_mode()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -387,28 +384,23 @@ def test_conda_create_ray_client(call_ray_start):
|
|||
]
|
||||
}
|
||||
}
|
||||
try:
|
||||
ray.client("localhost:24001").env(runtime_env).connect()
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
import pip_install_test # noqa
|
||||
return True
|
||||
@ray.remote
|
||||
def f():
|
||||
import pip_install_test # noqa
|
||||
return True
|
||||
|
||||
with ray.client("localhost:24001").env(runtime_env).connect():
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
# Ensure pip-install-test is not installed on the test machine
|
||||
import pip_install_test # noqa
|
||||
assert ray.get(f.remote())
|
||||
|
||||
ray.util.disconnect()
|
||||
ray.client("localhost:24001").connect()
|
||||
with ray.client("localhost:24001").connect():
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
# Ensure pip-install-test is not installed in a client that doesn't
|
||||
# use the runtime_env
|
||||
ray.get(f.remote())
|
||||
finally:
|
||||
ray.util.disconnect()
|
||||
ray._private.client_mode_hook._explicitly_disable_client_mode()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
|
Loading…
Add table
Reference in a new issue