mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
Enable to specify driver id by user. (#3084)
This commit is contained in:
parent
5ce7ed7dad
commit
ca7d4c2cf5
2 changed files with 38 additions and 1 deletions
|
@ -1279,6 +1279,7 @@ def _init(address_info=None,
|
|||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=True,
|
||||
driver_id=None,
|
||||
plasma_store_socket_name=None,
|
||||
raylet_socket_name=None,
|
||||
temp_dir=None):
|
||||
|
@ -1336,6 +1337,7 @@ def _init(address_info=None,
|
|||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which is a Jupyter notebook.
|
||||
driver_id: The ID of driver.
|
||||
plasma_store_socket_name (str): If provided, it will specify the socket
|
||||
name used by the plasma store.
|
||||
raylet_socket_name (str): If provided, it will specify the socket path
|
||||
|
@ -1455,6 +1457,7 @@ def _init(address_info=None,
|
|||
if raylet_socket_name is not None:
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"raylet_socket_name must not be provided.")
|
||||
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
|
@ -1485,6 +1488,7 @@ def _init(address_info=None,
|
|||
object_id_seed=object_id_seed,
|
||||
mode=driver_mode,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id,
|
||||
redis_password=redis_password)
|
||||
return address_info
|
||||
|
||||
|
@ -1508,6 +1512,7 @@ def init(redis_address=None,
|
|||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=True,
|
||||
driver_id=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
|
@ -1573,6 +1578,7 @@ def init(redis_address=None,
|
|||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which is a Jupyter notebook.
|
||||
driver_id: The ID of driver.
|
||||
configure_logging: True if allow the logging cofiguration here.
|
||||
Otherwise, the users may want to configure it by their own.
|
||||
logging_level: Logging level, default will be loging.INFO.
|
||||
|
@ -1638,6 +1644,7 @@ def init(redis_address=None,
|
|||
huge_pages=huge_pages,
|
||||
include_webui=include_webui,
|
||||
object_store_memory=object_store_memory,
|
||||
driver_id=driver_id,
|
||||
plasma_store_socket_name=plasma_store_socket_name,
|
||||
raylet_socket_name=raylet_socket_name,
|
||||
temp_dir=temp_dir)
|
||||
|
@ -1829,6 +1836,7 @@ def connect(info,
|
|||
object_id_seed=None,
|
||||
mode=WORKER_MODE,
|
||||
worker=global_worker,
|
||||
driver_id=None,
|
||||
redis_password=None):
|
||||
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
|
||||
|
||||
|
@ -1839,6 +1847,7 @@ def connect(info,
|
|||
deterministic.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
|
||||
LOCAL_MODE.
|
||||
driver_id: The ID of driver. If it's None, then we will generate one.
|
||||
redis_password (str): Prevents external clients without the password
|
||||
from connecting to Redis if provided.
|
||||
"""
|
||||
|
@ -1846,8 +1855,20 @@ def connect(info,
|
|||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
assert not worker.connected, error_message
|
||||
assert worker.cached_functions_to_run is not None, error_message
|
||||
|
||||
# Initialize some fields.
|
||||
worker.worker_id = random_string()
|
||||
if mode is WORKER_MODE:
|
||||
worker.worker_id = random_string()
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = ray.ObjectID(random_string())
|
||||
|
||||
if not isinstance(driver_id, ray.ObjectID):
|
||||
raise Exception(
|
||||
"The type of given driver id must be ray.ObjectID.")
|
||||
|
||||
worker.worker_id = driver_id.id()
|
||||
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
|
|
|
@ -2246,6 +2246,22 @@ def test_workers(shutdown_only):
|
|||
assert "stdout_file" in info
|
||||
|
||||
|
||||
def test_specific_driver_id():
|
||||
dummy_driver_id = ray.ObjectID(b"00112233445566778899")
|
||||
ray.init(driver_id=dummy_driver_id)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return ray.worker.global_worker.task_driver_id.id()
|
||||
|
||||
assert_equal(dummy_driver_id.id(), ray.worker.global_worker.worker_id)
|
||||
|
||||
task_driver_id = ray.get(f.remote())
|
||||
assert_equal(dummy_driver_id.id(), task_driver_id)
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shutdown_only_with_initialization_check():
|
||||
yield None
|
||||
|
|
Loading…
Add table
Reference in a new issue