mirror of
https://github.com/vale981/ray
synced 2025-03-12 06:06:39 -04:00
Enable to specify driver id by user. (#3084)
This commit is contained in:
parent
5ce7ed7dad
commit
ca7d4c2cf5
2 changed files with 38 additions and 1 deletions
|
@ -1279,6 +1279,7 @@ def _init(address_info=None,
|
||||||
plasma_directory=None,
|
plasma_directory=None,
|
||||||
huge_pages=False,
|
huge_pages=False,
|
||||||
include_webui=True,
|
include_webui=True,
|
||||||
|
driver_id=None,
|
||||||
plasma_store_socket_name=None,
|
plasma_store_socket_name=None,
|
||||||
raylet_socket_name=None,
|
raylet_socket_name=None,
|
||||||
temp_dir=None):
|
temp_dir=None):
|
||||||
|
@ -1336,6 +1337,7 @@ def _init(address_info=None,
|
||||||
Store with hugetlbfs support. Requires plasma_directory.
|
Store with hugetlbfs support. Requires plasma_directory.
|
||||||
include_webui: Boolean flag indicating whether to start the web
|
include_webui: Boolean flag indicating whether to start the web
|
||||||
UI, which is a Jupyter notebook.
|
UI, which is a Jupyter notebook.
|
||||||
|
driver_id: The ID of driver.
|
||||||
plasma_store_socket_name (str): If provided, it will specify the socket
|
plasma_store_socket_name (str): If provided, it will specify the socket
|
||||||
name used by the plasma store.
|
name used by the plasma store.
|
||||||
raylet_socket_name (str): If provided, it will specify the socket path
|
raylet_socket_name (str): If provided, it will specify the socket path
|
||||||
|
@ -1455,6 +1457,7 @@ def _init(address_info=None,
|
||||||
if raylet_socket_name is not None:
|
if raylet_socket_name is not None:
|
||||||
raise Exception("When connecting to an existing cluster, "
|
raise Exception("When connecting to an existing cluster, "
|
||||||
"raylet_socket_name must not be provided.")
|
"raylet_socket_name must not be provided.")
|
||||||
|
|
||||||
# Get the node IP address if one is not provided.
|
# Get the node IP address if one is not provided.
|
||||||
if node_ip_address is None:
|
if node_ip_address is None:
|
||||||
node_ip_address = services.get_node_ip_address(redis_address)
|
node_ip_address = services.get_node_ip_address(redis_address)
|
||||||
|
@ -1485,6 +1488,7 @@ def _init(address_info=None,
|
||||||
object_id_seed=object_id_seed,
|
object_id_seed=object_id_seed,
|
||||||
mode=driver_mode,
|
mode=driver_mode,
|
||||||
worker=global_worker,
|
worker=global_worker,
|
||||||
|
driver_id=driver_id,
|
||||||
redis_password=redis_password)
|
redis_password=redis_password)
|
||||||
return address_info
|
return address_info
|
||||||
|
|
||||||
|
@ -1508,6 +1512,7 @@ def init(redis_address=None,
|
||||||
plasma_directory=None,
|
plasma_directory=None,
|
||||||
huge_pages=False,
|
huge_pages=False,
|
||||||
include_webui=True,
|
include_webui=True,
|
||||||
|
driver_id=None,
|
||||||
configure_logging=True,
|
configure_logging=True,
|
||||||
logging_level=logging.INFO,
|
logging_level=logging.INFO,
|
||||||
logging_format=ray_constants.LOGGER_FORMAT,
|
logging_format=ray_constants.LOGGER_FORMAT,
|
||||||
|
@ -1573,6 +1578,7 @@ def init(redis_address=None,
|
||||||
Store with hugetlbfs support. Requires plasma_directory.
|
Store with hugetlbfs support. Requires plasma_directory.
|
||||||
include_webui: Boolean flag indicating whether to start the web
|
include_webui: Boolean flag indicating whether to start the web
|
||||||
UI, which is a Jupyter notebook.
|
UI, which is a Jupyter notebook.
|
||||||
|
driver_id: The ID of driver.
|
||||||
configure_logging: True if allow the logging cofiguration here.
|
configure_logging: True if allow the logging cofiguration here.
|
||||||
Otherwise, the users may want to configure it by their own.
|
Otherwise, the users may want to configure it by their own.
|
||||||
logging_level: Logging level, default will be loging.INFO.
|
logging_level: Logging level, default will be loging.INFO.
|
||||||
|
@ -1638,6 +1644,7 @@ def init(redis_address=None,
|
||||||
huge_pages=huge_pages,
|
huge_pages=huge_pages,
|
||||||
include_webui=include_webui,
|
include_webui=include_webui,
|
||||||
object_store_memory=object_store_memory,
|
object_store_memory=object_store_memory,
|
||||||
|
driver_id=driver_id,
|
||||||
plasma_store_socket_name=plasma_store_socket_name,
|
plasma_store_socket_name=plasma_store_socket_name,
|
||||||
raylet_socket_name=raylet_socket_name,
|
raylet_socket_name=raylet_socket_name,
|
||||||
temp_dir=temp_dir)
|
temp_dir=temp_dir)
|
||||||
|
@ -1829,6 +1836,7 @@ def connect(info,
|
||||||
object_id_seed=None,
|
object_id_seed=None,
|
||||||
mode=WORKER_MODE,
|
mode=WORKER_MODE,
|
||||||
worker=global_worker,
|
worker=global_worker,
|
||||||
|
driver_id=None,
|
||||||
redis_password=None):
|
redis_password=None):
|
||||||
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
|
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
|
||||||
|
|
||||||
|
@ -1839,6 +1847,7 @@ def connect(info,
|
||||||
deterministic.
|
deterministic.
|
||||||
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
|
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
|
||||||
LOCAL_MODE.
|
LOCAL_MODE.
|
||||||
|
driver_id: The ID of driver. If it's None, then we will generate one.
|
||||||
redis_password (str): Prevents external clients without the password
|
redis_password (str): Prevents external clients without the password
|
||||||
from connecting to Redis if provided.
|
from connecting to Redis if provided.
|
||||||
"""
|
"""
|
||||||
|
@ -1846,8 +1855,20 @@ def connect(info,
|
||||||
error_message = "Perhaps you called ray.init twice by accident?"
|
error_message = "Perhaps you called ray.init twice by accident?"
|
||||||
assert not worker.connected, error_message
|
assert not worker.connected, error_message
|
||||||
assert worker.cached_functions_to_run is not None, error_message
|
assert worker.cached_functions_to_run is not None, error_message
|
||||||
|
|
||||||
# Initialize some fields.
|
# Initialize some fields.
|
||||||
worker.worker_id = random_string()
|
if mode is WORKER_MODE:
|
||||||
|
worker.worker_id = random_string()
|
||||||
|
else:
|
||||||
|
# This is the code path of driver mode.
|
||||||
|
if driver_id is None:
|
||||||
|
driver_id = ray.ObjectID(random_string())
|
||||||
|
|
||||||
|
if not isinstance(driver_id, ray.ObjectID):
|
||||||
|
raise Exception(
|
||||||
|
"The type of given driver id must be ray.ObjectID.")
|
||||||
|
|
||||||
|
worker.worker_id = driver_id.id()
|
||||||
|
|
||||||
# When tasks are executed on remote workers in the context of multiple
|
# When tasks are executed on remote workers in the context of multiple
|
||||||
# drivers, the task driver ID is used to keep track of which driver is
|
# drivers, the task driver ID is used to keep track of which driver is
|
||||||
|
|
|
@ -2246,6 +2246,22 @@ def test_workers(shutdown_only):
|
||||||
assert "stdout_file" in info
|
assert "stdout_file" in info
|
||||||
|
|
||||||
|
|
||||||
|
def test_specific_driver_id():
|
||||||
|
dummy_driver_id = ray.ObjectID(b"00112233445566778899")
|
||||||
|
ray.init(driver_id=dummy_driver_id)
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
def f():
|
||||||
|
return ray.worker.global_worker.task_driver_id.id()
|
||||||
|
|
||||||
|
assert_equal(dummy_driver_id.id(), ray.worker.global_worker.worker_id)
|
||||||
|
|
||||||
|
task_driver_id = ray.get(f.remote())
|
||||||
|
assert_equal(dummy_driver_id.id(), task_driver_id)
|
||||||
|
|
||||||
|
ray.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def shutdown_only_with_initialization_check():
|
def shutdown_only_with_initialization_check():
|
||||||
yield None
|
yield None
|
||||||
|
|
Loading…
Add table
Reference in a new issue