diff --git a/python/ray/worker.py b/python/ray/worker.py index b503aabeb..8eee516ab 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1279,6 +1279,7 @@ def _init(address_info=None, plasma_directory=None, huge_pages=False, include_webui=True, + driver_id=None, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1336,6 +1337,7 @@ def _init(address_info=None, Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. + driver_id: The ID of driver. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1455,6 +1457,7 @@ def _init(address_info=None, if raylet_socket_name is not None: raise Exception("When connecting to an existing cluster, " "raylet_socket_name must not be provided.") + # Get the node IP address if one is not provided. if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) @@ -1485,6 +1488,7 @@ def _init(address_info=None, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, + driver_id=driver_id, redis_password=redis_password) return address_info @@ -1508,6 +1512,7 @@ def init(redis_address=None, plasma_directory=None, huge_pages=False, include_webui=True, + driver_id=None, configure_logging=True, logging_level=logging.INFO, logging_format=ray_constants.LOGGER_FORMAT, @@ -1573,6 +1578,7 @@ def init(redis_address=None, Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. + driver_id: The ID of driver. configure_logging: True if allow the logging cofiguration here. Otherwise, the users may want to configure it by their own. logging_level: Logging level, default will be loging.INFO. @@ -1638,6 +1644,7 @@ def init(redis_address=None, huge_pages=huge_pages, include_webui=include_webui, object_store_memory=object_store_memory, + driver_id=driver_id, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) @@ -1829,6 +1836,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, + driver_id=None, redis_password=None): """Connect this worker to the local scheduler, to Plasma, and to Redis. @@ -1839,6 +1847,7 @@ def connect(info, deterministic. mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. + driver_id: The ID of driver. If it's None, then we will generate one. redis_password (str): Prevents external clients without the password from connecting to Redis if provided. """ @@ -1846,8 +1855,20 @@ def connect(info, error_message = "Perhaps you called ray.init twice by accident?" assert not worker.connected, error_message assert worker.cached_functions_to_run is not None, error_message + # Initialize some fields. - worker.worker_id = random_string() + if mode is WORKER_MODE: + worker.worker_id = random_string() + else: + # This is the code path of driver mode. + if driver_id is None: + driver_id = ray.ObjectID(random_string()) + + if not isinstance(driver_id, ray.ObjectID): + raise Exception( + "The type of given driver id must be ray.ObjectID.") + + worker.worker_id = driver_id.id() # When tasks are executed on remote workers in the context of multiple # drivers, the task driver ID is used to keep track of which driver is diff --git a/test/runtest.py b/test/runtest.py index a7c4b3e16..ab006ebac 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2246,6 +2246,22 @@ def test_workers(shutdown_only): assert "stdout_file" in info +def test_specific_driver_id(): + dummy_driver_id = ray.ObjectID(b"00112233445566778899") + ray.init(driver_id=dummy_driver_id) + + @ray.remote + def f(): + return ray.worker.global_worker.task_driver_id.id() + + assert_equal(dummy_driver_id.id(), ray.worker.global_worker.worker_id) + + task_driver_id = ray.get(f.remote()) + assert_equal(dummy_driver_id.id(), task_driver_id) + + ray.shutdown() + + @pytest.fixture def shutdown_only_with_initialization_check(): yield None