Enable to specify driver id by user. (#3084)

This commit is contained in:
Wang Qing 2018-11-03 10:01:50 +08:00 committed by Robert Nishihara
parent 5ce7ed7dad
commit ca7d4c2cf5
2 changed files with 38 additions and 1 deletions

View file

@ -1279,6 +1279,7 @@ def _init(address_info=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
driver_id=None,
plasma_store_socket_name=None,
raylet_socket_name=None,
temp_dir=None):
@ -1336,6 +1337,7 @@ def _init(address_info=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which is a Jupyter notebook.
driver_id: The ID of driver.
plasma_store_socket_name (str): If provided, it will specify the socket
name used by the plasma store.
raylet_socket_name (str): If provided, it will specify the socket path
@ -1455,6 +1457,7 @@ def _init(address_info=None,
if raylet_socket_name is not None:
raise Exception("When connecting to an existing cluster, "
"raylet_socket_name must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
@ -1485,6 +1488,7 @@ def _init(address_info=None,
object_id_seed=object_id_seed,
mode=driver_mode,
worker=global_worker,
driver_id=driver_id,
redis_password=redis_password)
return address_info
@ -1508,6 +1512,7 @@ def init(redis_address=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
driver_id=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
@ -1573,6 +1578,7 @@ def init(redis_address=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which is a Jupyter notebook.
driver_id: The ID of driver.
configure_logging: True if allow the logging cofiguration here.
Otherwise, the users may want to configure it by their own.
logging_level: Logging level, default will be loging.INFO.
@ -1638,6 +1644,7 @@ def init(redis_address=None,
huge_pages=huge_pages,
include_webui=include_webui,
object_store_memory=object_store_memory,
driver_id=driver_id,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@ -1829,6 +1836,7 @@ def connect(info,
object_id_seed=None,
mode=WORKER_MODE,
worker=global_worker,
driver_id=None,
redis_password=None):
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
@ -1839,6 +1847,7 @@ def connect(info,
deterministic.
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
LOCAL_MODE.
driver_id: The ID of driver. If it's None, then we will generate one.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
"""
@ -1846,8 +1855,20 @@ def connect(info,
error_message = "Perhaps you called ray.init twice by accident?"
assert not worker.connected, error_message
assert worker.cached_functions_to_run is not None, error_message
# Initialize some fields.
worker.worker_id = random_string()
if mode is WORKER_MODE:
worker.worker_id = random_string()
else:
# This is the code path of driver mode.
if driver_id is None:
driver_id = ray.ObjectID(random_string())
if not isinstance(driver_id, ray.ObjectID):
raise Exception(
"The type of given driver id must be ray.ObjectID.")
worker.worker_id = driver_id.id()
# When tasks are executed on remote workers in the context of multiple
# drivers, the task driver ID is used to keep track of which driver is

View file

@ -2246,6 +2246,22 @@ def test_workers(shutdown_only):
assert "stdout_file" in info
def test_specific_driver_id():
dummy_driver_id = ray.ObjectID(b"00112233445566778899")
ray.init(driver_id=dummy_driver_id)
@ray.remote
def f():
return ray.worker.global_worker.task_driver_id.id()
assert_equal(dummy_driver_id.id(), ray.worker.global_worker.worker_id)
task_driver_id = ray.get(f.remote())
assert_equal(dummy_driver_id.id(), task_driver_id)
ray.shutdown()
@pytest.fixture
def shutdown_only_with_initialization_check():
yield None