diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index a2f76ee59..41a5d4ef7 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -1078,6 +1078,26 @@ def chdir(d: str): os.chdir(old_dir) +def check_local_files_gced(cluster): + for node in cluster.list_all_nodes(): + for subdir in [ + "conda", "pip", "working_dir_files", "py_modules_files" + ]: + all_files = os.listdir( + os.path.join(node.get_runtime_env_dir_path(), subdir)) + # Check that there are no files remaining except for .lock files + # and generated requirements.txt files. + # TODO(architkulkarni): these files should get cleaned up too! + if len( + list( + filter(lambda f: not f.endswith((".lock", ".txt")), + all_files))) > 0: + print(str(all_files)) + return False + + return True + + def generate_runtime_env_dict(field, spec_format, tmp_path, pip_list=None): if pip_list is None: pip_list = ["pip-install-test==0.5"] diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 95af0c42a..1ff2427fd 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -211,6 +211,7 @@ py_test_module_list( files = [ "test_runtime_env_conda_and_pip.py", "test_runtime_env_conda_and_pip_2.py", + "test_runtime_env_conda_and_pip_3.py", "test_runtime_env_complicated.py" ], size = "large", diff --git a/python/ray/tests/test_runtime_env_conda_and_pip.py b/python/ray/tests/test_runtime_env_conda_and_pip.py index 42200109a..1295f2d1f 100644 --- a/python/ray/tests/test_runtime_env_conda_and_pip.py +++ b/python/ray/tests/test_runtime_env_conda_and_pip.py @@ -2,6 +2,7 @@ import os import pytest import sys from ray._private.test_utils import (wait_for_condition, chdir, + check_local_files_gced, generate_runtime_env_dict) import yaml @@ -16,24 +17,6 @@ if not os.environ.get("CI"): os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1" -def check_local_files_gced(cluster): - for node in cluster.list_all_nodes(): - for subdir in ["conda", "pip"]: - all_files = os.listdir( - os.path.join(node.get_runtime_env_dir_path(), subdir)) - # Check that there are no files remaining except for .lock files - # and generated requirements.txt files. - # TODO(architkulkarni): these files should get cleaned up too! - if len( - list( - filter(lambda f: not f.endswith((".lock", ".txt")), - all_files))) > 0: - print(str(all_files)) - return False - - return True - - @pytest.mark.skipif( os.environ.get("CI") and sys.platform != "linux", reason="Requires PR wheels built in CI, so only run on linux CI machines.") @@ -159,37 +142,5 @@ def test_detached_actor_gc(start_cluster, field, spec_format, tmp_path): wait_for_condition(lambda: check_local_files_gced(cluster), timeout=30) -# TODO(architkulkarni): fix bug #19602 and enable test. 
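For context, the shared `check_local_files_gced` helper is polled with `wait_for_condition` until every node's runtime_env subdirectories contain only `.lock`/`.txt` leftovers. A minimal usage sketch (illustrative, not part of the diff; it simply mirrors how the tests below call the helper, and the wrapper name is hypothetical):

    from ray._private.test_utils import check_local_files_gced, wait_for_condition

    def assert_runtime_env_files_gced(cluster, timeout=30):
        # check_local_files_gced() returns False while any conda/pip/working_dir_files/
        # py_modules_files entries remain on any node, so poll it until cleanup finishes.
        wait_for_condition(lambda: check_local_files_gced(cluster), timeout=timeout)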
-@pytest.mark.skip("Currently failing") -@pytest.mark.skipif( - os.environ.get("CI") and sys.platform != "linux", - reason="Requires PR wheels built in CI, so only run on linux CI machines.") -@pytest.mark.parametrize("field", ["conda", "pip"]) -@pytest.mark.parametrize("spec_format", ["file", "python_object"]) -def test_actor_level_gc(start_cluster, field, spec_format, tmp_path): - """Tests that actor-level working_dir is GC'd when the actor exits.""" - cluster, address = start_cluster - - ray.init(address) - - runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path) - - @ray.remote - class A: - def test_import(self): - import pip_install_test # noqa: F401 - return True - - NUM_ACTORS = 5 - actors = [ - A.options(runtime_env=runtime_env).remote() for _ in range(NUM_ACTORS) - ] - ray.get([a.test_import.remote() for a in actors]) - for i in range(5): - assert not check_local_files_gced(cluster) - ray.kill(actors[i]) - wait_for_condition(lambda: check_local_files_gced(cluster)) - - if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_conda_and_pip_3.py b/python/ray/tests/test_runtime_env_conda_and_pip_3.py new file mode 100644 index 000000000..145d3eb1b --- /dev/null +++ b/python/ray/tests/test_runtime_env_conda_and_pip_3.py @@ -0,0 +1,124 @@ +import os +import pytest +import sys +from ray._private.test_utils import ( + wait_for_condition, check_local_files_gced, generate_runtime_env_dict) +import ray + +if not os.environ.get("CI"): + # This flag turns on local development mode: it links against the current ray + # packages and falls back to the current Python site-packages for dependencies. + os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1" + + +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="Requires PR wheels built in CI, so only run on linux CI machines.") +@pytest.mark.parametrize("field", ["conda", "pip"]) +@pytest.mark.parametrize("spec_format", ["file", "python_object"]) +def test_actor_level_gc(start_cluster, field, spec_format, tmp_path): + """Tests that actor-level runtime env resources are GC'd when the actor exits.""" + cluster, address = start_cluster + + ray.init(address) + + runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path) + + @ray.remote + class A: + def test_import(self): + import pip_install_test # noqa: F401 + return True + + NUM_ACTORS = 5 + actors = [ + A.options(runtime_env=runtime_env).remote() for _ in range(NUM_ACTORS) + ] + ray.get([a.test_import.remote() for a in actors]) + for i in range(5): + assert not check_local_files_gced(cluster) + ray.kill(actors[i]) + wait_for_condition(lambda: check_local_files_gced(cluster)) + + +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="Requires PR wheels built in CI, so only run on linux CI machines.") +@pytest.mark.parametrize( + "ray_start_cluster", [ + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 0, + }, + }, + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 5, + }, + }, + ], + indirect=True) +@pytest.mark.parametrize("field", ["conda", "pip"]) +@pytest.mark.parametrize("spec_format", ["file", "python_object"]) +def test_task_level_gc(ray_start_cluster, field, spec_format, tmp_path): + """Tests that task-level runtime env resources are GC'd when the worker exits.""" + + cluster = ray_start_cluster + + soft_limit_zero = False + system_config = cluster.list_all_nodes()[0]._ray_params._system_config + if "num_workers_soft_limit" in system_config and \
+ system_config["num_workers_soft_limit"] == 0: + soft_limit_zero = True + + runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path) + + @ray.remote + def f(): + import pip_install_test # noqa: F401 + return True + + @ray.remote + class A: + def test_import(self): + import pip_install_test # noqa: F401 + return True + + # Start a task with runtime env + ray.get(f.options(runtime_env=runtime_env).remote()) + if soft_limit_zero: + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be gced because of an enough soft limit. + assert not check_local_files_gced(cluster) + + # Start a actor with runtime env + actor = A.options(runtime_env=runtime_env).remote() + ray.get(actor.test_import.remote()) + # Local files should not be gced + assert not check_local_files_gced(cluster) + + # Kill actor + ray.kill(actor) + if soft_limit_zero: + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be gced because of an enough soft limit. + assert not check_local_files_gced(cluster) + + # Start a task with runtime env + ray.get(f.options(runtime_env=runtime_env).remote()) + if soft_limit_zero: + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be gced because of an enough soft limit. + assert not check_local_files_gced(cluster) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/tests/test_runtime_env_working_dir.py b/python/ray/tests/test_runtime_env_working_dir.py index 058cb7eb8..6064622c0 100644 --- a/python/ray/tests/test_runtime_env_working_dir.py +++ b/python/ray/tests/test_runtime_env_working_dir.py @@ -7,6 +7,7 @@ import tempfile import pytest import ray +import time # This test requires you have AWS credentials set up (any AWS credentials will # do, this test only accesses a public bucket). @@ -59,6 +60,12 @@ def test_lazy_reads(start_cluster, tmp_working_dir, option: str): def reinit(): ray.shutdown() + # TODO(SongGuyang): Currently, reinit the driver will generate the same + # job id. And if we reinit immediately after shutdown, raylet may + # process new job started before old job finished in some cases. This + # inconsistency could disorder the URI reference and delete a valid + # runtime env. We sleep here to walk around this issue. + time.sleep(5) call_ray_init() @ray.remote @@ -135,6 +142,12 @@ def test_captured_import(start_cluster, tmp_working_dir, option: str): def reinit(): ray.shutdown() + # TODO(SongGuyang): Currently, reinit the driver will generate the same + # job id. And if we reinit immediately after shutdown, raylet may + # process new job started before old job finished in some cases. This + # inconsistency could disorder the URI reference and delete a valid + # runtime env. We sleep here to walk around this issue. + time.sleep(5) call_ray_init() # Import in the driver. 
diff --git a/python/ray/tests/test_runtime_env_working_dir_2.py b/python/ray/tests/test_runtime_env_working_dir_2.py index c8d32c0af..8973eb781 100644 --- a/python/ray/tests/test_runtime_env_working_dir_2.py +++ b/python/ray/tests/test_runtime_env_working_dir_2.py @@ -9,7 +9,8 @@ from ray._private.test_utils import run_string_as_driver import ray import ray.experimental.internal_kv as kv -from ray._private.test_utils import wait_for_condition, chdir +from ray._private.test_utils import (wait_for_condition, chdir, + check_local_files_gced) from ray._private.runtime_env import RAY_WORKER_DEV_EXCLUDES from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE # This test requires you have AWS credentials set up (any AWS credentials will @@ -239,20 +240,6 @@ def check_internal_kv_gced(): return len(kv._internal_kv_list("gcs://")) == 0 -def check_local_files_gced(cluster): - for node in cluster.list_all_nodes(): - for subdir in ["working_dir_files", "py_modules_files"]: - all_files = os.listdir( - os.path.join(node.get_runtime_env_dir_path(), subdir)) - # Check that there are no files remaining except for .lock files. - # TODO(edoakes): the lock files should get cleaned up too! - if len(list(filter(lambda f: not f.endswith(".lock"), - all_files))) > 0: - return False - - return True - - @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") @pytest.mark.parametrize("option", ["working_dir", "py_modules"]) @pytest.mark.parametrize( @@ -304,8 +291,6 @@ def test_job_level_gc(start_cluster, option: str, source: str): wait_for_condition(lambda: check_local_files_gced(cluster)) -# TODO(architkulkarni): fix bug #19602 and enable test. -@pytest.mark.skip("Currently failing.") @pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") @pytest.mark.parametrize("option", ["working_dir", "py_modules"]) def test_actor_level_gc(start_cluster, option: str): diff --git a/python/ray/tests/test_runtime_env_working_dir_3.py b/python/ray/tests/test_runtime_env_working_dir_3.py new file mode 100644 index 000000000..85003c4fd --- /dev/null +++ b/python/ray/tests/test_runtime_env_working_dir_3.py @@ -0,0 +1,147 @@ +import os +import sys + +import pytest + +import ray +from ray.exceptions import GetTimeoutError +from ray._private.test_utils import (wait_for_condition, + check_local_files_gced) +# This test requires you have AWS credentials set up (any AWS credentials will +# do, this test only accesses a public bucket). + +# This package contains a subdirectory called `test_module`. +# Calling `test_module.one()` should return `2`. +# If you find that confusing, take it up with @jiaodong... 
+S3_PACKAGE_URI = "s3://runtime-env-test/test_runtime_env.zip" + + +@pytest.mark.skipif( + os.environ.get("CI") and sys.platform != "linux", + reason="Requires PR wheels built in CI, so only run on linux CI machines.") +@pytest.mark.parametrize( + "ray_start_cluster", + [ + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 0, + }, + }, + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 5, + }, + }, + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 0, + # this delay will make worker start slow and time out + "testing_asio_delay_us": "InternalKVGcsService.grpc_server" + ".InternalKVGet=2000000:2000000", + "worker_register_timeout_seconds": 1, + }, + }, + { + "num_nodes": 1, + "_system_config": { + "num_workers_soft_limit": 5, + # this delay will make worker start slow and time out + "testing_asio_delay_us": "InternalKVGcsService.grpc_server" + ".InternalKVGet=2000000:2000000", + "worker_register_timeout_seconds": 1, + }, + }, + ], + indirect=True) +@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.") +@pytest.mark.parametrize("option", ["working_dir", "py_modules"]) +def test_task_level_gc(ray_start_cluster, option): + """Tests that task-level working_dir is GC'd when the worker exits.""" + + cluster = ray_start_cluster + + soft_limit_zero = False + worker_register_timeout = False + system_config = cluster.list_all_nodes()[0]._ray_params._system_config + if "num_workers_soft_limit" in system_config and \ + system_config["num_workers_soft_limit"] == 0: + soft_limit_zero = True + if "worker_register_timeout_seconds" in system_config and \ + system_config["worker_register_timeout_seconds"] != 0: + worker_register_timeout = True + + @ray.remote + def f(): + import test_module + test_module.one() + + @ray.remote(num_cpus=1) + class A: + def check(self): + import test_module + test_module.one() + + if option == "working_dir": + runtime_env = {"working_dir": S3_PACKAGE_URI} + else: + runtime_env = {"py_modules": [S3_PACKAGE_URI]} + + # Note: We should set a bigger timeout if downloads the s3 package slowly. + get_timeout = 10 + + # Start a task with runtime env + if worker_register_timeout: + with pytest.raises(GetTimeoutError): + ray.get( + f.options(runtime_env=runtime_env).remote(), + timeout=get_timeout) + else: + ray.get(f.options(runtime_env=runtime_env).remote()) + if soft_limit_zero or worker_register_timeout: + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be gced because of an enough soft limit. + assert not check_local_files_gced(cluster) + + # Start a actor with runtime env + actor = A.options(runtime_env=runtime_env).remote() + if worker_register_timeout: + with pytest.raises(GetTimeoutError): + ray.get(actor.check.remote(), timeout=get_timeout) + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + ray.get(actor.check.remote()) + assert not check_local_files_gced(cluster) + + # Kill actor + ray.kill(actor) + if soft_limit_zero or worker_register_timeout: + # Wait for worker exited and local files gced + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be gced because of an enough soft limit. 
+ assert not check_local_files_gced(cluster) + + # Start a task with the runtime env. + if worker_register_timeout: + with pytest.raises(GetTimeoutError): + ray.get( + f.options(runtime_env=runtime_env).remote(), + timeout=get_timeout) + else: + ray.get(f.options(runtime_env=runtime_env).remote()) + if soft_limit_zero or worker_register_timeout: + # Wait for the worker to exit and the local files to be GC'd. + wait_for_condition(lambda: check_local_files_gced(cluster)) + else: + # Local files should not be GC'd because the soft limit is high enough. + assert not check_local_files_gced(cluster) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index cf3919d51..20e421737 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -367,6 +367,10 @@ RAY_CONFIG(uint64_t, kill_idle_workers_interval_ms, 200) /// The idle time threshold for an idle worker to be killed. RAY_CONFIG(int64_t, idle_worker_killing_time_threshold_ms, 1000) +/// The soft limit on the number of workers. +/// -1 means num_cpus is used instead. +RAY_CONFIG(int64_t, num_workers_soft_limit, -1) + // The interval where metrics are exported in milliseconds. RAY_CONFIG(uint64_t, metrics_report_interval_ms, 10000) diff --git a/src/ray/common/runtime_env_manager.cc b/src/ray/common/runtime_env_manager.cc index a6861b0bd..6b96394db 100644 --- a/src/ray/common/runtime_env_manager.cc +++ b/src/ray/common/runtime_env_manager.cc @@ -23,7 +23,7 @@ void RuntimeEnvManager::AddURIReference(const std::string &hex_id, for (const auto &uri : uris) { uri_reference_[uri]++; id_to_uris_[hex_id].push_back(uri); - RAY_LOG(DEBUG) << "Added URI Reference " << uri; + RAY_LOG(DEBUG) << "Added URI Reference " << uri << " for id " << hex_id; } } @@ -35,6 +35,7 @@ const std::vector &RuntimeEnvManager::GetReferences( } void RuntimeEnvManager::RemoveURIReference(const std::string &hex_id) { + RAY_LOG(DEBUG) << "Subtracting 1 from URI Reference for id " << hex_id; if (!id_to_uris_.count(hex_id)) { return; } diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 41829612b..274f0dfad 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -171,7 +171,9 @@ int main(int argc, char *argv[]) { << node_manager_config.resource_config.ToString(); node_manager_config.node_manager_address = node_ip_address; node_manager_config.node_manager_port = node_manager_port; - node_manager_config.num_workers_soft_limit = num_cpus; + auto soft_limit_config = RayConfig::instance().num_workers_soft_limit(); + node_manager_config.num_workers_soft_limit = + soft_limit_config >= 0 ?
soft_limit_config : num_cpus; node_manager_config.num_initial_python_workers_for_first_job = num_initial_python_workers_for_first_job; node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 954ace379..337e0b2e9 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -299,13 +299,6 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self global_gc_throttler_(RayConfig::instance().global_gc_min_interval_s() * 1e9), local_gc_interval_ns_(RayConfig::instance().local_gc_interval_s() * 1e9), record_metrics_period_ms_(config.record_metrics_period_ms), - runtime_env_manager_( - /*deleter=*/[this](const std::string &uri, std::function cb) { - if (RayConfig::instance().runtime_env_skip_local_gc()) { - return cb(true); - } - return agent_manager_->DeleteURIs({uri}, cb); - }), next_resource_seq_no_(0) { RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_; RAY_CHECK(RayConfig::instance().raylet_heartbeat_period_milliseconds() > 0); @@ -554,11 +547,6 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_ << job_data.driver_pid() << " is dead: " << job_data.is_dead() << " driver address: " << job_data.driver_ip_address(); worker_pool_.HandleJobStarted(job_id, job_data.config()); - // NOTE: Technically `HandleJobStarted` isn't idempotent because we'll - // increment the ref count multiple times. This is fine because - // `HandleJobFinisehd` will also decrement the ref count multiple times. - runtime_env_manager_.AddURIReference(job_id.Hex(), - job_data.config().runtime_env_info()); // Tasks of this job may already arrived but failed to pop a worker because the job // config is not local yet. So we trigger dispatching again here to try to // reschedule these tasks. @@ -569,7 +557,6 @@ void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job RAY_LOG(DEBUG) << "HandleJobFinished " << job_id; RAY_CHECK(job_data.is_dead()); worker_pool_.HandleJobFinished(job_id); - runtime_env_manager_.RemoveURIReference(job_id.Hex()); } void NodeManager::FillNormalTaskResourceUsage(rpc::ResourcesData &resources_data) { @@ -1271,10 +1258,6 @@ void NodeManager::DisconnectClient(const std::shared_ptr &clie cluster_task_manager_->TaskFinished(worker, &task); } - if (worker->IsDetachedActor()) { - runtime_env_manager_.RemoveURIReference(actor_id.Hex()); - } - if (disconnect_type == rpc::WorkerExitType::SYSTEM_ERROR_EXIT) { // Push the error to driver. 
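For reference, the new `num_workers_soft_limit` config is driven from the Python side through `_system_config`, exactly as the tests above do; a minimal driver-side sketch (not part of the diff):

    import ray

    # With a soft limit of 0, idle workers exit promptly, so the raylet drops their
    # URI references and the local runtime_env files become eligible for GC.
    ray.init(_system_config={"num_workers_soft_limit": 0})
    # Leaving the config at its default of -1 makes the raylet fall back to num_cpus,
    # as wired up in main.cc above.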
const JobID &job_id = worker->GetAssignedJobId(); @@ -1955,8 +1938,6 @@ void NodeManager::FinishAssignedActorCreationTask(WorkerInterface &worker, auto job_id = task.GetTaskSpecification().JobId(); auto job_config = worker_pool_.GetJobConfig(job_id); RAY_CHECK(job_config); - runtime_env_manager_.AddURIReference(actor_id.Hex(), - task.GetTaskSpecification().RuntimeEnvInfo()); } } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 8965bba38..a508ed7e5 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -28,7 +28,6 @@ #include "ray/object_manager/object_manager.h" #include "ray/raylet/agent_manager.h" #include "ray/raylet_client/raylet_client.h" -#include "ray/common/runtime_env_manager.h" #include "ray/raylet/local_object_manager.h" #include "ray/raylet/scheduling/scheduling_ids.h" #include "ray/raylet/scheduling/cluster_resource_scheduler.h" @@ -747,9 +746,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Managers all bundle-related operations. std::shared_ptr placement_group_resource_manager_; - /// Manage all runtime env locally - RuntimeEnvManager runtime_env_manager_; - /// Next resource broadcast seq no. Non-incrementing sequence numbers /// indicate network issues (dropped/duplicated/ooo packets, etc). int64_t next_resource_seq_no_; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index c118d77ba..aa34231f5 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -34,6 +34,10 @@ DEFINE_stats(worker_register_time_ms, "end to end latency of register a worker p namespace { +// Add this prefix because the worker setup token is just a counter which is easy to +// duplicate with other ids. +static const std::string kWorkerSetupTokenPrefix = "worker_startup_token:"; + // A helper function to get a worker from a list. std::shared_ptr GetWorker( const std::unordered_set> &worker_pool, @@ -85,7 +89,14 @@ WorkerPool::WorkerPool(instrumented_io_context &io_service, const NodeID node_id num_initial_python_workers_for_first_job, maximum_startup_concurrency)), num_initial_python_workers_for_first_job_(num_initial_python_workers_for_first_job), periodical_runner_(io_service), - get_time_(get_time) { + get_time_(get_time), + runtime_env_manager_( + /*deleter=*/[this](const std::string &uri, std::function cb) { + if (RayConfig::instance().runtime_env_skip_local_gc()) { + return cb(true); + } + return agent_manager_->DeleteURIs({uri}, cb); + }) { RAY_CHECK(maximum_startup_concurrency > 0); // We need to record so that the metric exists. 
This way, we report that 0 // processes have started before a task runs on the node (as opposed to the @@ -182,11 +193,32 @@ void WorkerPool::update_worker_startup_token_counter() { worker_startup_token_counter_ += 1; } +void WorkerPool::AddStartingWorkerProcess( + State &state, const int workers_to_start, const rpc::WorkerType worker_type, + const Process &proc, const std::chrono::high_resolution_clock::time_point &start, + const rpc::RuntimeEnvInfo &runtime_env_info) { + state.starting_worker_processes.emplace( + worker_startup_token_counter_, + StartingWorkerProcessInfo{workers_to_start, workers_to_start, worker_type, proc, + start, runtime_env_info}); + runtime_env_manager_.AddURIReference( + kWorkerSetupTokenPrefix + std::to_string(worker_startup_token_counter_), + runtime_env_info); +} + +void WorkerPool::RemoveStartingWorkerProcess(State &state, + const StartupToken &proc_startup_token) { + state.starting_worker_processes.erase(proc_startup_token); + runtime_env_manager_.RemoveURIReference(kWorkerSetupTokenPrefix + + std::to_string(proc_startup_token)); +} + std::tuple WorkerPool::StartWorkerProcess( const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, PopWorkerStatus *status, const std::vector &dynamic_options, const int runtime_env_hash, const std::string &serialized_runtime_env_context, - const std::string &allocated_instances_serialized_json) { + const std::string &allocated_instances_serialized_json, + const rpc::RuntimeEnvInfo &runtime_env_info) { rpc::JobConfig *job_config = nullptr; if (!IsIOWorkerType(worker_type)) { RAY_CHECK(!job_id.IsNil()); @@ -415,13 +447,12 @@ std::tuple WorkerPool::StartWorkerProcess( stats::ProcessStartupTimeMs.Record(duration.count()); stats::NumWorkersStarted.Record(1); RAY_LOG(INFO) << "Started worker process of " << workers_to_start - << " worker(s) with pid " << proc.GetId(); + << " worker(s) with pid " << proc.GetId() << ", the token " + << worker_startup_token_counter_; MonitorStartingWorkerProcess(proc, worker_startup_token_counter_, language, worker_type); - state.starting_worker_processes.emplace( - worker_startup_token_counter_, - StartingWorkerProcessInfo{workers_to_start, workers_to_start, worker_type, proc, - start}); + AddStartingWorkerProcess(state, workers_to_start, worker_type, proc, start, + runtime_env_info); StartupToken worker_startup_token = worker_startup_token_counter_; update_worker_startup_token_counter(); if (IsIOWorkerType(worker_type)) { @@ -471,7 +502,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc, proc_startup_token, nullptr, status, &found, &used, &task_id); } - state.starting_worker_processes.erase(it); + RemoveStartingWorkerProcess(state, proc_startup_token); if (IsIOWorkerType(worker_type)) { // Mark the I/O worker as failed. auto &io_worker_state = GetIOWorkerStateFromWorkerType(worker_type, state); @@ -560,34 +591,53 @@ void WorkerPool::MarkPortAsFree(int port) { } } -void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) { - all_jobs_[job_id] = job_config; +static bool NeedToEagerInstallRuntimeEnv(const rpc::JobConfig &job_config) { if (job_config.has_runtime_env_info() && job_config.runtime_env_info().runtime_env_eager_install()) { auto const &runtime_env = job_config.runtime_env_info().serialized_runtime_env(); if (runtime_env != "{}" && runtime_env != "") { - RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id - << ". 
The runtime environment was " << runtime_env << "."; - CreateRuntimeEnv( - runtime_env, job_id, - [job_id](bool successful, const std::string &serialized_runtime_env_context) { - if (successful) { - RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " - << job_id << ". The result context was " - << serialized_runtime_env_context << "."; - } else { - RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job " - << job_id << "."; - } - }); + return true; } } + return false; +} + +void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) { + all_jobs_[job_id] = job_config; + if (NeedToEagerInstallRuntimeEnv(job_config)) { + auto const &runtime_env = job_config.runtime_env_info().serialized_runtime_env(); + // NOTE: Technically `HandleJobStarted` isn't idempotent because we'll + // increment the ref count multiple times. This is fine because + // `HandleJobFinished` will also decrement the ref count multiple times. + runtime_env_manager_.AddURIReference(job_id.Hex(), job_config.runtime_env_info()); + RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id + << ". The runtime environment was " << runtime_env << "."; + CreateRuntimeEnv( + runtime_env, job_id, + [job_id](bool successful, const std::string &serialized_runtime_env_context) { + if (successful) { + RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " << job_id + << ". The result context was " << serialized_runtime_env_context + << "."; + } else { + RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job " + << job_id << "."; + } + }); + } } void WorkerPool::HandleJobFinished(const JobID &job_id) { // Currently we don't erase the job from `all_jobs_` , as a workaround for // https://github.com/ray-project/ray/issues/11437. // unfinished_jobs_.erase(job_id); + auto job_config = GetJobConfig(job_id); + RAY_CHECK(job_config); + // Check eager install here because we only add URI reference when runtime + // env install really happens. + if (NeedToEagerInstallRuntimeEnv(*job_config)) { + runtime_env_manager_.RemoveURIReference(job_id.Hex()); + } finished_jobs_.insert(job_id); } @@ -608,7 +658,7 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr &worker if (it == state.starting_worker_processes.end()) { RAY_LOG(WARNING) << "Received a register request from an unknown worker shim process: " - << worker_shim_pid; + << worker_shim_pid << ", token: " << worker_startup_token; Status status = Status::Invalid("Unknown worker"); send_reply_callback(status, /*port=*/0); return status; @@ -651,9 +701,11 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr &worker) auto it = state.starting_worker_processes.find(worker_startup_token); if (it != state.starting_worker_processes.end()) { + runtime_env_manager_.AddURIReference(worker->WorkerId().Hex(), + it->second.runtime_env_info); it->second.num_starting_workers--; if (it->second.num_starting_workers == 0) { - state.starting_worker_processes.erase(it); + RemoveStartingWorkerProcess(state, worker_startup_token); // We may have slots to start more workers now. 
TryStartIOWorkers(worker->GetLanguage()); } @@ -1047,7 +1099,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec, auto [proc, startup_token] = StartWorkerProcess( task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status, dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env_context, - allocated_instances_serialized_json); + allocated_instances_serialized_json, task_spec.RuntimeEnvInfo()); if (status == PopWorkerStatus::OK) { RAY_CHECK(proc.IsValid()); WarnAboutSize(); @@ -1210,6 +1262,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker, rpc::WorkerExitType disconnect_type) { + runtime_env_manager_.RemoveURIReference(worker->WorkerId().Hex()); auto &state = GetStateForLanguage(worker->GetLanguage()); RAY_CHECK(RemoveWorker(state.registered_workers, worker)); RAY_UNUSED(RemoveWorker(state.pending_disconnection_workers, worker)); diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 447e602b2..5a02120bf 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -28,6 +28,7 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/asio/periodical_runner.h" #include "ray/common/client_connection.h" +#include "ray/common/runtime_env_manager.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" #include "ray/gcs/gcs_client/gcs_client.h" @@ -405,6 +406,7 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \param serialized_runtime_env_context The context of runtime env. /// \param allocated_instances_serialized_json The allocated resource instances // json string. + /// \param runtime_env_info The raw runtime env info. /// \return The process that we started and a token. If the token is less than 0, /// we didn't start a process. std::tuple StartWorkerProcess( @@ -413,7 +415,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { const std::vector &dynamic_options = {}, const int runtime_env_hash = 0, const std::string &serialized_runtime_env_context = "{}", - const std::string &allocated_instances_serialized_json = "{}"); + const std::string &allocated_instances_serialized_json = "{}", + const rpc::RuntimeEnvInfo &runtime_env_info = rpc::RuntimeEnvInfo()); /// The implementation of how to start a new worker process with command arguments. /// The lifetime of the process is tied to that of the returned object, @@ -463,6 +466,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { Process proc; /// The worker process start time. std::chrono::high_resolution_clock::time_point start_time; + /// The runtime env Info. + rpc::RuntimeEnvInfo runtime_env_info; }; struct TaskWaitingForWorkerInfo { @@ -620,6 +625,13 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { const std::function &callback, const std::string &serialized_allocated_resource_instances = "{}"); + void AddStartingWorkerProcess( + State &state, const int workers_to_start, const rpc::WorkerType worker_type, + const Process &proc, const std::chrono::high_resolution_clock::time_point &start, + const rpc::RuntimeEnvInfo &runtime_env_info); + + void RemoveStartingWorkerProcess(State &state, const StartupToken &proc_startup_token); + /// For Process class for managing subprocesses (e.g. reaping zombies). instrumented_io_context *io_service_; /// Node ID of the current node. 
@@ -684,6 +696,146 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// Agent manager. std::shared_ptr agent_manager_; + /// Manage all runtime env resources locally by URI reference. We add or remove URI + /// references in the cases below: + /// For a job with an eagerly installed runtime env: + /// - Add URI references when the job starts. + /// - Remove URI references when the job finishes. + /// For a worker process with a valid runtime env: + /// - Add URI references when the worker process starts. + /// - Remove URI references when all worker instances have registered or any + /// worker instance registration times out. + /// - Add URI references when a worker instance registers. + /// - Remove URI references when a worker instance disconnects. + /// + /// A normal state change flow is: + /// job level: + /// HandleJobStarted(ref + 1 = 1) -> HandleJobFinished(ref - 1 = 0) + /// worker level: + /// StartWorkerProcess(ref + 1 = 1) + /// -> OnWorkerStarted * 3 (ref + 3 = 4) + /// -> All worker instances registered (ref - 1 = 3) + /// -> DisconnectWorker * 3 (ref - 3 = 0) + /// + /// A state change flow for the worker timeout case is: + /// StartWorkerProcess(ref + 1 = 1) + /// -> OnWorkerStarted * 2 (ref + 2 = 3) + /// -> One worker registration times out, kill worker process (ref - 1 = 2) + /// -> DisconnectWorker * 2 (ref - 2 = 0) + /// + /// Note: "OnWorkerStarted * 3" means that three workers are started, assuming + /// that the worker process has three worker instances in total. + /// + /// An example of how the reference tables change: + /// + /// Start a job with an eagerly installed runtime env: + /// ray.init(runtime_env= + /// { + /// "conda": { + /// "dependencies": ["requests"] + /// }, + /// "eager_install": True + /// }) + /// Add the URI reference and get the reference tables: + /// +---------------------------------------------+ + /// | id | URIs | + /// +--------------------+------------------------+ + /// | job-id-hex-A | conda://{hash-A} | + /// +---------------------------------------------+ + /// +---------------------------------------------+ + /// | URI | count | + /// +--------------------+------------------------+ + /// | conda://{hash-A} | 1 | + /// +---------------------------------------------+ + /// + /// Create an actor with a runtime env: + /// @ray.remote + /// class actor: + /// def Foo(): + /// import my_module + /// pass + /// actor.options(runtime_env= + /// { + /// "conda": { + /// "dependencies": ["requests"] + /// }, + /// "py_modules": [ + /// "s3://bucket/my_module.zip" + /// ] + /// }).remote() + /// First, when the worker process starts, add URI references and get the reference + /// tables: + /// +-------------------------------------------------------------------+ + /// | id | URIs | + /// +----------------------+--------------------------------------------+ + /// | job-id-hex-A | conda://{hash-A} | + /// | worker-setup-token-A | conda://{hash-A},s3://bucket/my_module.zip | + /// +-------------------------------------------------------------------+ + /// +------------------------------------------------------+ + /// | URI | count | + /// +-----------------------------+------------------------+ + /// | conda://{hash-A} | 2 | + /// | s3://bucket/my_module.zip | 1 | + /// +------------------------------------------------------+ + /// Second, when a worker instance registers, add URI references for the worker + /// instance and get the reference tables: + /// +-------------------------------------------------------------------+ + /// | id
| URIs | + /// +----------------------+--------------------------------------------+ + /// | job-id-hex-A | conda://{hash-A} | + /// | worker-setup-token-A | conda://{hash-A},s3://bucket/my_module.zip | + /// | worker-id-hex-A | conda://{hash-A},s3://bucket/my_module.zip | + /// +-------------------------------------------------------------------+ + /// +------------------------------------------------------+ + /// | URI | count | + /// +-----------------------------+------------------------+ + /// | conda://{hash-A} | 3 | + /// | s3://bucket/my_module.zip | 2 | + /// +------------------------------------------------------+ + /// At the same time, we remove the URI references of the worker process, because a + /// Python worker process only contains one worker instance: + /// +-------------------------------------------------------------------+ + /// | id | URIs | + /// +----------------------+--------------------------------------------+ + /// | job-id-hex-A | conda://{hash-A} | + /// | worker-id-hex-A | conda://{hash-A},s3://bucket/my_module.zip | + /// +-------------------------------------------------------------------+ + /// +------------------------------------------------------+ + /// | URI | count | + /// +-----------------------------+------------------------+ + /// | conda://{hash-A} | 2 | + /// | s3://bucket/my_module.zip | 1 | + /// +------------------------------------------------------+ + /// + /// Next, when the actor is killed, remove the URI references of the worker instance: + /// +-------------------------------------------------------------------+ + /// | id | URIs | + /// +----------------------+--------------------------------------------+ + /// | job-id-hex-A | conda://{hash-A} | + /// +-------------------------------------------------------------------+ + /// +------------------------------------------------------+ + /// | URI | count | + /// +-----------------------------+------------------------+ + /// | conda://{hash-A} | 1 | + /// +------------------------------------------------------+ + /// + /// Finally, when the job finishes, remove its URI reference and get empty reference + /// tables: + /// +-------------------------------------------------------------------+ + /// | id | URIs | + /// +----------------------+--------------------------------------------+ + /// | | | + /// +-------------------------------------------------------------------+ + /// +------------------------------------------------------+ + /// | URI | count | + /// +-----------------------------+------------------------+ + /// | | | + /// +------------------------------------------------------+ + /// + /// Now, we can delete the runtime env resources safely.
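To make the flow described in the comment block above easier to follow, here is a small Python model of the bookkeeping that `RuntimeEnvManager` performs (illustrative only; the real class is the C++ one used by `WorkerPool`, and the deleter corresponds to the agent's `DeleteURIs` call):

    from collections import defaultdict

    class UriRefCounter:
        """Toy model of RuntimeEnvManager::AddURIReference / RemoveURIReference."""

        def __init__(self, delete_uris):
            self._uri_refs = defaultdict(int)     # URI -> reference count
            self._id_to_uris = defaultdict(list)  # owner id (job / startup token / worker) -> URIs
            self._delete_uris = delete_uris       # invoked when a count drops to zero

        def add(self, owner_id, uris):
            for uri in uris:
                self._uri_refs[uri] += 1
                self._id_to_uris[owner_id].append(uri)

        def remove(self, owner_id):
            for uri in self._id_to_uris.pop(owner_id, []):
                self._uri_refs[uri] -= 1
                if self._uri_refs[uri] == 0:
                    del self._uri_refs[uri]
                    self._delete_uris([uri])  # e.g. ask the agent to delete the local files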
+ + RuntimeEnvManager runtime_env_manager_; + /// Stats int64_t process_failed_job_config_missing_ = 0; int64_t process_failed_rate_limited_ = 0; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index c34a9d818..83613279a 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -30,6 +30,7 @@ int NUM_WORKERS_PER_PROCESS_JAVA = 3; int MAXIMUM_STARTUP_CONCURRENCY = 5; int MAX_IO_WORKER_SIZE = 2; int POOL_SIZE_SOFT_LIMIT = 5; +int WORKER_REGISTER_TIMEOUT_SECONDS = 3; JobID JOB_ID = JobID::FromInt(1); std::string BAD_RUNTIME_ENV = "bad runtime env"; @@ -76,6 +77,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface { instrumented_io_context &io_service_; }; +static std::unordered_set valid_uris; + class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface { public: void CreateRuntimeEnv(const rpc::CreateRuntimeEnvRequest &request, @@ -84,6 +87,14 @@ class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface { if (request.serialized_runtime_env() == BAD_RUNTIME_ENV) { reply.set_status(rpc::AGENT_RPC_STATUS_FAILED); } else { + rpc::RuntimeEnv runtime_env; + if (google::protobuf::util::JsonStringToMessage(request.serialized_runtime_env(), + &runtime_env) + .ok()) { + for (auto uri : runtime_env.uris().py_modules_uris()) { + valid_uris.emplace(uri); + } + } reply.set_status(rpc::AGENT_RPC_STATUS_OK); reply.set_serialized_runtime_env_context("{\"dummy\":\"dummy\"}"); } @@ -92,6 +103,9 @@ class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface { void DeleteURIs(const rpc::DeleteURIsRequest &request, const rpc::ClientCallback &callback) { + for (auto uri : request.uris()) { + valid_uris.erase(uri); + } rpc::DeleteURIsReply reply; reply.set_status(rpc::AGENT_RPC_STATUS_OK); callback(Status::OK(), reply); @@ -203,7 +217,8 @@ class WorkerPoolMock : public WorkerPool { Process proc, const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID, const rpc::WorkerType worker_type = rpc::WorkerType::WORKER, - int runtime_env_hash = 0, StartupToken worker_startup_token = 0) { + int runtime_env_hash = 0, StartupToken worker_startup_token = 0, + bool set_process = true) { std::function client_handler = [this](ClientConnection &client) { HandleNewClient(client); }; std::function, int64_t, @@ -225,7 +240,7 @@ class WorkerPoolMock : public WorkerPool { auto rpc_client = std::make_shared(instrumented_io_service_); worker->Connect(rpc_client); mock_worker_rpc_clients_.emplace(worker->WorkerId(), rpc_client); - if (!proc.IsNull()) { + if (set_process && !proc.IsNull()) { worker->SetProcess(proc); worker->SetShimProcess(proc); } @@ -245,13 +260,16 @@ class WorkerPoolMock : public WorkerPool { } // Create workers for processes and push them to worker pool. - void PushWorkers() { + // \param[in] timeout_worker_number Don't register some workers to simulate worker + // registration timeout. + void PushWorkers(int timeout_worker_number = 0) { auto processes = GetProcesses(); for (auto it = processes.begin(); it != processes.end(); ++it) { auto pushed_it = pushedProcesses_.find(it->first); if (pushed_it == pushedProcesses_.end()) { int runtime_env_hash = 0; bool is_java = false; + bool has_dynamic_options = false; // Parses runtime env hash to make sure the pushed workers can be popped out. 
for (auto command_args : it->second) { std::string runtime_env_key = "--runtime-env-hash="; @@ -264,14 +282,27 @@ class WorkerPoolMock : public WorkerPool { if (pos != std::string::npos) { is_java = true; } + pos = command_args.find("-X"); + if (pos != std::string::npos) { + has_dynamic_options = true; + } } // TODO(SongGuyang): support C++ language workers. - int num_workers = is_java ? NUM_WORKERS_PER_PROCESS_JAVA : 1; - for (int i = 0; i < num_workers; i++) { - auto worker = - CreateWorker(it->first, is_java ? Language::JAVA : Language::PYTHON, JOB_ID, - rpc::WorkerType::WORKER, runtime_env_hash, - startup_tokens_by_proc_[it->first]); + int num_workers = + (is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1; + RAY_CHECK(timeout_worker_number <= num_workers) + << "The timeout worker number cannot exceed the total number of workers"; + auto register_workers = num_workers - timeout_worker_number; + for (int i = 0; i < register_workers; i++) { + auto worker = CreateWorker( + it->first, is_java ? Language::JAVA : Language::PYTHON, JOB_ID, + rpc::WorkerType::WORKER, runtime_env_hash, + startup_tokens_by_proc_[it->first], + // Don't set process to ensure the `RegisterWorker` succeeds below. + false); + RAY_CHECK_OK(RegisterWorker(worker, it->first.GetId(), it->first.GetId(), + startup_tokens_by_proc_[it->first], + [](Status, int) {})); OnWorkerStarted(worker); PushAvailableWorker(worker); } @@ -284,9 +315,10 @@ class WorkerPoolMock : public WorkerPool { // worker synchronously. // \param[in] push_workers If true, tries to push the workers from the started // processes. - std::shared_ptr PopWorkerSync( - const TaskSpecification &task_spec, bool push_workers = true, - PopWorkerStatus *worker_status = nullptr) { + std::shared_ptr PopWorkerSync(const TaskSpecification &task_spec, + bool push_workers = true, + PopWorkerStatus *worker_status = nullptr, + int timeout_worker_number = 0) { std::shared_ptr popped_worker = nullptr; std::promise promise; this->PopWorker(task_spec, @@ -301,7 +333,7 @@ class WorkerPoolMock : public WorkerPool { return true; }); if (push_workers) { - PushWorkers(); + PushWorkers(timeout_worker_number); } promise.get_future().get(); return popped_worker; @@ -328,7 +360,9 @@ class WorkerPoolTest : public ::testing::Test { public: WorkerPoolTest() { RayConfig::instance().initialize( - R"({"worker_register_timeout_seconds": 3, "object_spilling_config": "dummy", "max_io_workers": )" + + R"({"worker_register_timeout_seconds": )" + + std::to_string(WORKER_REGISTER_TIMEOUT_SECONDS) + + R"(, "object_spilling_config": "dummy", "max_io_workers": )" + std::to_string(MAX_IO_WORKER_SIZE) + "}"); SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, @@ -344,7 +378,10 @@ class WorkerPoolTest : public ::testing::Test { StartMockAgent(); } - virtual void TearDown() { AssertNoLeaks(); } + virtual void TearDown() { + AssertNoLeaks(); + valid_uris.clear(); + } void AssertNoLeaks() { ASSERT_EQ(worker_pool_->pending_exit_idle_workers_.size(), 0); } @@ -448,12 +485,36 @@ class WorkerPoolTest : public ::testing::Test { std::unique_ptr worker_pool_; }; +static inline rpc::RuntimeEnvInfo ExampleRuntimeEnvInfo( + const std::vector uris, bool eager_install = false) { + rpc::RuntimeEnv runtime_env; + for (auto &uri : uris) { + runtime_env.mutable_uris()->mutable_py_modules_uris()->Add(std::string(uri)); + } + std::string runtime_env_string; + google::protobuf::util::MessageToJsonString(runtime_env, &runtime_env_string); + rpc::RuntimeEnvInfo 
runtime_env_info; + runtime_env_info.set_serialized_runtime_env(runtime_env_string); + for (auto &uri : uris) { + runtime_env_info.mutable_uris()->Add(std::string(uri)); + } + runtime_env_info.set_runtime_env_eager_install(eager_install); + return runtime_env_info; +} + +static inline rpc::RuntimeEnvInfo ExampleRuntimeEnvInfoFromString( + std::string serialized_runtime_env) { + rpc::RuntimeEnvInfo runtime_env_info; + runtime_env_info.set_serialized_runtime_env(serialized_runtime_env); + return runtime_env_info; +} + static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID, const ActorID actor_creation_id = ActorID::Nil(), const std::vector &dynamic_worker_options = {}, const TaskID &task_id = TaskID::FromRandom(JobID::Nil()), - const std::string serialized_runtime_env = "") { + const rpc::RuntimeEnvInfo runtime_env_info = rpc::RuntimeEnvInfo()) { rpc::TaskSpec message; message.set_job_id(job_id.Binary()); message.set_language(language); @@ -472,7 +533,7 @@ static inline TaskSpecification ExampleTaskSpec( } else { message.set_type(TaskType::NORMAL_TASK); } - message.mutable_runtime_env_info()->set_serialized_runtime_env(serialized_runtime_env); + message.mutable_runtime_env_info()->CopyFrom(runtime_env_info); return TaskSpecification(std::move(message)); } @@ -618,7 +679,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { rpc::JobConfig job_config = rpc::JobConfig(); job_config.add_code_search_path("/test/code_search_path"); - job_config.set_num_java_workers_per_process(1); + job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); job_config.add_jvm_options("-Xmx1g"); job_config.add_jvm_options("-Xms500m"); job_config.add_jvm_options("-Dmy-job.hello=world"); @@ -1320,10 +1381,10 @@ TEST_F(WorkerPoolTest, PopWorkerWithRuntimeEnv) { auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1); const auto actor_creation_task_spec = ExampleTaskSpec( ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})"); + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); const auto normal_task_spec = ExampleTaskSpec( ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})"); + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); const auto normal_task_spec_without_runtime_env = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {}); // Pop worker for actor creation task again. @@ -1348,6 +1409,167 @@ TEST_F(WorkerPoolTest, PopWorkerWithRuntimeEnv) { ASSERT_EQ(worker_pool_->GetProcessSize(), 3); } +TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceJobLevel) { + // First part, test start job with eager installed runtime env. + { + auto job_id = JobID::FromInt(12345); + std::string uri = "s3://123"; + auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true); + rpc::JobConfig job_config; + job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); + // Start job. + worker_pool_->HandleJobStarted(job_id, job_config); + ASSERT_EQ(valid_uris.size(), 1); + // Finish the job. + worker_pool_->HandleJobFinished(job_id); + ASSERT_EQ(valid_uris.size(), 0); + } + + // Second part, test start job without eager installed runtime env. 
+ { + auto job_id = JobID::FromInt(12345); + std::string uri = "s3://123"; + auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false); + rpc::JobConfig job_config; + job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); + // Start job. + worker_pool_->HandleJobStarted(job_id, job_config); + ASSERT_EQ(valid_uris.size(), 0); + // Finish the job. + worker_pool_->HandleJobFinished(job_id); + ASSERT_EQ(valid_uris.size(), 0); + } +} + +TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) { + auto job_id = JobID::FromInt(12345); + std::string uri = "s3://123"; + + // First part, test URI reference with eager install. + { + auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true); + rpc::JobConfig job_config; + job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); + // Start job with eager installed runtime env. + worker_pool_->HandleJobStarted(job_id, job_config); + ASSERT_EQ(valid_uris.size(), 1); + // Start actor with runtime env. + auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); + const auto actor_creation_task_spec = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id, + {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); + ASSERT_EQ(valid_uris.size(), 1); + // Start task with runtime env. + const auto normal_task_spec = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), + {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + auto popped_normal_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); + ASSERT_EQ(valid_uris.size(), 1); + // Disconnect actor worker. + worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT); + ASSERT_EQ(valid_uris.size(), 1); + // Disconnect task worker. + worker_pool_->DisconnectWorker(popped_normal_worker, rpc::WorkerExitType::IDLE_EXIT); + ASSERT_EQ(valid_uris.size(), 1); + // Finish the job. + worker_pool_->HandleJobFinished(job_id); + ASSERT_EQ(valid_uris.size(), 0); + } + + // Second part, test URI reference without eager install. + { + auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true); + auto runtime_env_info_without_eager_install = ExampleRuntimeEnvInfo({uri}, false); + rpc::JobConfig job_config; + job_config.mutable_runtime_env_info()->CopyFrom( + runtime_env_info_without_eager_install); + // Start job without eager installed runtime env. + worker_pool_->HandleJobStarted(job_id, job_config); + ASSERT_EQ(valid_uris.size(), 0); + // Start actor with runtime env. + auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 2); + const auto actor_creation_task_spec = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id, + {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info); + auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); + ASSERT_EQ(valid_uris.size(), 1); + // Start task with runtime env. + auto popped_normal_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); + ASSERT_EQ(valid_uris.size(), 1); + // Disconnect actor worker. + worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT); + ASSERT_EQ(valid_uris.size(), 1); + // Disconnect task worker. + worker_pool_->DisconnectWorker(popped_normal_worker, rpc::WorkerExitType::IDLE_EXIT); + ASSERT_EQ(valid_uris.size(), 0); + // Finish the job. 
+ worker_pool_->HandleJobFinished(job_id); + ASSERT_EQ(valid_uris.size(), 0); + } +} + +TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) { + auto job_id = JOB_ID; + std::string uri = "s3://567"; + auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false); + rpc::JobConfig job_config; + job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); + job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info); + // Start job without eager installed runtime env. + worker_pool_->HandleJobStarted(job_id, job_config); + ASSERT_EQ(valid_uris.size(), 0); + + // First part, test normal case with all worker registered. + { + // Start actors with runtime env. The Java actors will trigger a multi-worker process. + std::vector> workers; + for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { + auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1); + const auto actor_creation_task_spec = + ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {}, + TaskID::FromRandom(JobID::Nil()), runtime_env_info); + auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec); + ASSERT_NE(popped_actor_worker, nullptr); + workers.push_back(popped_actor_worker); + ASSERT_EQ(valid_uris.size(), 1); + } + // Make sure only one worker process has been started. + ASSERT_EQ(worker_pool_->GetProcessSize(), 1); + // Disconnect all actor workers. + for (auto &worker : workers) { + worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT); + } + ASSERT_EQ(valid_uris.size(), 0); + } + + // Second part, test corner case with some worker registration timeout. + { + // Start one actor with runtime env. The Java actor will trigger a multi-worker + // process. + auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); + const auto actor_creation_task_spec = + ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {}, + TaskID::FromRandom(JobID::Nil()), runtime_env_info); + PopWorkerStatus status; + // Only one worker registration. All the other worker registration times out. + auto popped_actor_worker = worker_pool_->PopWorkerSync( + actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1); + ASSERT_EQ(valid_uris.size(), 1); + // Disconnect actor worker. + worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT); + ASSERT_EQ(valid_uris.size(), 1); + // Sleep for a while to wait worker registration timeout. + std::this_thread::sleep_for( + std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1)); + ASSERT_EQ(valid_uris.size(), 0); + } + + // Finish the job. 
+ worker_pool_->HandleJobFinished(job_id); + ASSERT_EQ(valid_uris.size(), 0); +} + TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { /// /// Check that a worker can be popped only if there is a @@ -1356,15 +1578,18 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) { /// ASSERT_EQ(worker_pool_->GetProcessSize(), 0); auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1); - const auto actor_creation_task_spec_1 = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_1"); - const auto task_spec_1 = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_1"); - const auto task_spec_2 = ExampleTaskSpec( - ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), - /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_2"); + const auto actor_creation_task_spec_1 = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, + /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfoFromString("mock_runtime_env_1")); + const auto task_spec_1 = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), + /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfoFromString("mock_runtime_env_1")); + const auto task_spec_2 = + ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), + /*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), + ExampleRuntimeEnvInfoFromString("mock_runtime_env_2")); const WorkerCacheKey env1 = {"mock_runtime_env_1", {}}; const int runtime_env_hash_1 = env1.IntHash(); @@ -1510,9 +1735,9 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) { /* Test PopWorkerStatus RuntimeEnvCreationFailed */ // Create a task with bad runtime env. - const auto task_spec_with_bad_runtime_env = - ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), - {"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), BAD_RUNTIME_ENV); + const auto task_spec_with_bad_runtime_env = ExampleTaskSpec( + ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"}, + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString(BAD_RUNTIME_ENV)); popped_worker = worker_pool_->PopWorkerSync(task_spec_with_bad_runtime_env, true, &status); // PopWorker failed and the status is `RuntimeEnvCreationFailed`. @@ -1522,7 +1747,7 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) { // Create a task with available runtime env. const auto task_spec_with_runtime_env = ExampleTaskSpec( ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"}, - TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})"); + TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"})); popped_worker = worker_pool_->PopWorkerSync(task_spec_with_runtime_env, true, &status); // PopWorker success. ASSERT_NE(popped_worker, nullptr);
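Finally, the job-level reference exercised by `RuntimeEnvUriReferenceJobLevel` corresponds to a driver that requests eager installation, as in the `worker_pool.h` comment block above (a sketch only, not part of the diff):

    import ray

    # A job-level URI reference is taken in HandleJobStarted only when eager_install
    # is requested and the runtime env is non-empty; HandleJobFinished releases it.
    ray.init(runtime_env={
        "conda": {"dependencies": ["requests"]},
        "eager_install": True,
    })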