[runtime env] add and remove uri reference in worker pool (#20789)

Currently, the URI reference logic in the raylet is:
- For the job level, add a URI reference when the job starts and remove it when the job finishes.
- For the actor level, add and remove URI references for detached actors only.

In this PR, the logic is optimized to:
- For the job level, first check whether the runtime env should be installed eagerly. If so, add the URI reference when the job starts and remove it when the job finishes.
- For the actor level (see the sketch after this list):
    * First, add a URI reference for the starting worker process, to prevent the runtime env from being GC'd before the worker registers.
    * Second, add a URI reference for each worker instance of the worker process. This reference is removed when the worker disconnects.

- Besides, we move the `RuntimeEnvManager` instance from `node_manager` to `worker_pool`.
- Enable the test `test_actor_level_gc` and add new tests in Python and in the worker pool C++ test.
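
Below is a minimal, illustrative sketch of the two scenarios this logic covers, written from the user's point of view. It only reuses pieces that already appear in this PR (the `pip-install-test` package from the tests and the `eager_install` flag from the worker pool header comment); it is not part of the diff.

```python
import ray

# Job level: with eager_install=True, the raylet installs the runtime env when
# the job starts, so the worker pool adds a job-level URI reference in
# HandleJobStarted and removes it in HandleJobFinished.
ray.init(runtime_env={
    "pip": ["pip-install-test==0.5"],
    "eager_install": True,
})


# Actor level: the worker pool adds a URI reference for the starting worker
# process (so the env cannot be GC'd before the worker registers), adds another
# reference when the worker instance registers, and removes the references when
# the worker disconnects.
@ray.remote
class A:
    def test_import(self):
        import pip_install_test  # noqa: F401
        return True


a = A.options(runtime_env={"pip": ["pip-install-test==0.5"]}).remote()
ray.get(a.test_import.remote())
ray.kill(a)  # once the worker disconnects, the local env files can be GC'd
```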
Guyang Song 2021-12-16 17:00:05 +08:00 committed by GitHub
parent a778741db6
commit 32cf19a881
15 changed files with 808 additions and 153 deletions

@ -1078,6 +1078,26 @@ def chdir(d: str):
os.chdir(old_dir)
def check_local_files_gced(cluster):
for node in cluster.list_all_nodes():
for subdir in [
"conda", "pip", "working_dir_files", "py_modules_files"
]:
all_files = os.listdir(
os.path.join(node.get_runtime_env_dir_path(), subdir))
# Check that there are no files remaining except for .lock files
# and generated requirements.txt files.
# TODO(architkulkarni): these files should get cleaned up too!
if len(
list(
filter(lambda f: not f.endswith((".lock", ".txt")),
all_files))) > 0:
print(str(all_files))
return False
return True
def generate_runtime_env_dict(field, spec_format, tmp_path, pip_list=None):
if pip_list is None:
pip_list = ["pip-install-test==0.5"]

@ -211,6 +211,7 @@ py_test_module_list(
files = [
"test_runtime_env_conda_and_pip.py",
"test_runtime_env_conda_and_pip_2.py",
"test_runtime_env_conda_and_pip_3.py",
"test_runtime_env_complicated.py"
],
size = "large",

@ -2,6 +2,7 @@ import os
import pytest
import sys
from ray._private.test_utils import (wait_for_condition, chdir,
check_local_files_gced,
generate_runtime_env_dict)
import yaml
@ -16,24 +17,6 @@ if not os.environ.get("CI"):
os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1"
def check_local_files_gced(cluster):
for node in cluster.list_all_nodes():
for subdir in ["conda", "pip"]:
all_files = os.listdir(
os.path.join(node.get_runtime_env_dir_path(), subdir))
# Check that there are no files remaining except for .lock files
# and generated requirements.txt files.
# TODO(architkulkarni): these files should get cleaned up too!
if len(
list(
filter(lambda f: not f.endswith((".lock", ".txt")),
all_files))) > 0:
print(str(all_files))
return False
return True
@pytest.mark.skipif(
os.environ.get("CI") and sys.platform != "linux",
reason="Requires PR wheels built in CI, so only run on linux CI machines.")
@ -159,37 +142,5 @@ def test_detached_actor_gc(start_cluster, field, spec_format, tmp_path):
wait_for_condition(lambda: check_local_files_gced(cluster), timeout=30)
# TODO(architkulkarni): fix bug #19602 and enable test.
@pytest.mark.skip("Currently failing")
@pytest.mark.skipif(
os.environ.get("CI") and sys.platform != "linux",
reason="Requires PR wheels built in CI, so only run on linux CI machines.")
@pytest.mark.parametrize("field", ["conda", "pip"])
@pytest.mark.parametrize("spec_format", ["file", "python_object"])
def test_actor_level_gc(start_cluster, field, spec_format, tmp_path):
"""Tests that actor-level working_dir is GC'd when the actor exits."""
cluster, address = start_cluster
ray.init(address)
runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path)
@ray.remote
class A:
def test_import(self):
import pip_install_test # noqa: F401
return True
NUM_ACTORS = 5
actors = [
A.options(runtime_env=runtime_env).remote() for _ in range(NUM_ACTORS)
]
ray.get([a.test_import.remote() for a in actors])
for i in range(5):
assert not check_local_files_gced(cluster)
ray.kill(actors[i])
wait_for_condition(lambda: check_local_files_gced(cluster))
if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))

@ -0,0 +1,124 @@
import os
import pytest
import sys
from ray._private.test_utils import (
wait_for_condition, check_local_files_gced, generate_runtime_env_dict)
import ray
if not os.environ.get("CI"):
# This flag turns on local development mode, which links against the current
# ray packages and falls back to the current Python site for all dependencies.
os.environ["RAY_RUNTIME_ENV_LOCAL_DEV_MODE"] = "1"
@pytest.mark.skipif(
os.environ.get("CI") and sys.platform != "linux",
reason="Requires PR wheels built in CI, so only run on linux CI machines.")
@pytest.mark.parametrize("field", ["conda", "pip"])
@pytest.mark.parametrize("spec_format", ["file", "python_object"])
def test_actor_level_gc(start_cluster, field, spec_format, tmp_path):
"""Tests that actor-level working_dir is GC'd when the actor exits."""
cluster, address = start_cluster
ray.init(address)
runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path)
@ray.remote
class A:
def test_import(self):
import pip_install_test # noqa: F401
return True
NUM_ACTORS = 5
actors = [
A.options(runtime_env=runtime_env).remote() for _ in range(NUM_ACTORS)
]
ray.get([a.test_import.remote() for a in actors])
for i in range(5):
assert not check_local_files_gced(cluster)
ray.kill(actors[i])
wait_for_condition(lambda: check_local_files_gced(cluster))
@pytest.mark.skipif(
os.environ.get("CI") and sys.platform != "linux",
reason="Requires PR wheels built in CI, so only run on linux CI machines.")
@pytest.mark.parametrize(
"ray_start_cluster", [
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 0,
},
},
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 5,
},
},
],
indirect=True)
@pytest.mark.parametrize("field", ["conda", "pip"])
@pytest.mark.parametrize("spec_format", ["file", "python_object"])
def test_task_level_gc(ray_start_cluster, field, spec_format, tmp_path):
"""Tests that task-level working_dir is GC'd when the actor exits."""
cluster = ray_start_cluster
soft_limit_zero = False
system_config = cluster.list_all_nodes()[0]._ray_params._system_config
if "num_workers_soft_limit" in system_config and \
system_config["num_workers_soft_limit"] == 0:
soft_limit_zero = True
runtime_env = generate_runtime_env_dict(field, spec_format, tmp_path)
@ray.remote
def f():
import pip_install_test # noqa: F401
return True
@ray.remote
class A:
def test_import(self):
import pip_install_test # noqa: F401
return True
# Start a task with runtime env
ray.get(f.options(runtime_env=runtime_env).remote())
if soft_limit_zero:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
# Start an actor with the runtime env.
actor = A.options(runtime_env=runtime_env).remote()
ray.get(actor.test_import.remote())
# Local files should not be gced
assert not check_local_files_gced(cluster)
# Kill actor
ray.kill(actor)
if soft_limit_zero:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
# Start a task with runtime env
ray.get(f.options(runtime_env=runtime_env).remote())
if soft_limit_zero:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))

@ -7,6 +7,7 @@ import tempfile
import pytest
import ray
import time
# This test requires you have AWS credentials set up (any AWS credentials will
# do, this test only accesses a public bucket).
@ -59,6 +60,12 @@ def test_lazy_reads(start_cluster, tmp_working_dir, option: str):
def reinit():
ray.shutdown()
# TODO(SongGuyang): Currently, reinitializing the driver generates the same
# job id, and if we reinit immediately after shutdown, the raylet may in some
# cases process the new job's start before the old job has finished. This
# inconsistency could corrupt the URI reference counts and delete a valid
# runtime env. We sleep here to work around this issue.
time.sleep(5)
call_ray_init()
@ray.remote
@ -135,6 +142,12 @@ def test_captured_import(start_cluster, tmp_working_dir, option: str):
def reinit():
ray.shutdown()
# TODO(SongGuyang): Currently, reinitializing the driver generates the same
# job id, and if we reinit immediately after shutdown, the raylet may in some
# cases process the new job's start before the old job has finished. This
# inconsistency could corrupt the URI reference counts and delete a valid
# runtime env. We sleep here to work around this issue.
time.sleep(5)
call_ray_init()
# Import in the driver.

@ -9,7 +9,8 @@ from ray._private.test_utils import run_string_as_driver
import ray
import ray.experimental.internal_kv as kv
from ray._private.test_utils import wait_for_condition, chdir
from ray._private.test_utils import (wait_for_condition, chdir,
check_local_files_gced)
from ray._private.runtime_env import RAY_WORKER_DEV_EXCLUDES
from ray._private.runtime_env.packaging import GCS_STORAGE_MAX_SIZE
# This test requires you have AWS credentials set up (any AWS credentials will
@ -239,20 +240,6 @@ def check_internal_kv_gced():
return len(kv._internal_kv_list("gcs://")) == 0
def check_local_files_gced(cluster):
for node in cluster.list_all_nodes():
for subdir in ["working_dir_files", "py_modules_files"]:
all_files = os.listdir(
os.path.join(node.get_runtime_env_dir_path(), subdir))
# Check that there are no files remaining except for .lock files.
# TODO(edoakes): the lock files should get cleaned up too!
if len(list(filter(lambda f: not f.endswith(".lock"),
all_files))) > 0:
return False
return True
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
@pytest.mark.parametrize("option", ["working_dir", "py_modules"])
@pytest.mark.parametrize(
@ -304,8 +291,6 @@ def test_job_level_gc(start_cluster, option: str, source: str):
wait_for_condition(lambda: check_local_files_gced(cluster))
# TODO(architkulkarni): fix bug #19602 and enable test.
@pytest.mark.skip("Currently failing.")
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
@pytest.mark.parametrize("option", ["working_dir", "py_modules"])
def test_actor_level_gc(start_cluster, option: str):

@ -0,0 +1,147 @@
import os
import sys
import pytest
import ray
from ray.exceptions import GetTimeoutError
from ray._private.test_utils import (wait_for_condition,
check_local_files_gced)
# This test requires you have AWS credentials set up (any AWS credentials will
# do, this test only accesses a public bucket).
# This package contains a subdirectory called `test_module`.
# Calling `test_module.one()` should return `2`.
# If you find that confusing, take it up with @jiaodong...
S3_PACKAGE_URI = "s3://runtime-env-test/test_runtime_env.zip"
@pytest.mark.skipif(
os.environ.get("CI") and sys.platform != "linux",
reason="Requires PR wheels built in CI, so only run on linux CI machines.")
@pytest.mark.parametrize(
"ray_start_cluster",
[
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 0,
},
},
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 5,
},
},
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 0,
# This delay makes worker startup slow so that registration times out.
"testing_asio_delay_us": "InternalKVGcsService.grpc_server"
".InternalKVGet=2000000:2000000",
"worker_register_timeout_seconds": 1,
},
},
{
"num_nodes": 1,
"_system_config": {
"num_workers_soft_limit": 5,
# This delay makes worker startup slow so that registration times out.
"testing_asio_delay_us": "InternalKVGcsService.grpc_server"
".InternalKVGet=2000000:2000000",
"worker_register_timeout_seconds": 1,
},
},
],
indirect=True)
@pytest.mark.skipif(sys.platform == "win32", reason="Fail to create temp dir.")
@pytest.mark.parametrize("option", ["working_dir", "py_modules"])
def test_task_level_gc(ray_start_cluster, option):
"""Tests that task-level working_dir is GC'd when the worker exits."""
cluster = ray_start_cluster
soft_limit_zero = False
worker_register_timeout = False
system_config = cluster.list_all_nodes()[0]._ray_params._system_config
if "num_workers_soft_limit" in system_config and \
system_config["num_workers_soft_limit"] == 0:
soft_limit_zero = True
if "worker_register_timeout_seconds" in system_config and \
system_config["worker_register_timeout_seconds"] != 0:
worker_register_timeout = True
@ray.remote
def f():
import test_module
test_module.one()
@ray.remote(num_cpus=1)
class A:
def check(self):
import test_module
test_module.one()
if option == "working_dir":
runtime_env = {"working_dir": S3_PACKAGE_URI}
else:
runtime_env = {"py_modules": [S3_PACKAGE_URI]}
# Note: We should set a bigger timeout if the S3 package downloads slowly.
get_timeout = 10
# Start a task with runtime env
if worker_register_timeout:
with pytest.raises(GetTimeoutError):
ray.get(
f.options(runtime_env=runtime_env).remote(),
timeout=get_timeout)
else:
ray.get(f.options(runtime_env=runtime_env).remote())
if soft_limit_zero or worker_register_timeout:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
# Start an actor with the runtime env.
actor = A.options(runtime_env=runtime_env).remote()
if worker_register_timeout:
with pytest.raises(GetTimeoutError):
ray.get(actor.check.remote(), timeout=get_timeout)
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
ray.get(actor.check.remote())
assert not check_local_files_gced(cluster)
# Kill actor
ray.kill(actor)
if soft_limit_zero or worker_register_timeout:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
# Start a task with runtime env
if worker_register_timeout:
with pytest.raises(GetTimeoutError):
ray.get(
f.options(runtime_env=runtime_env).remote(),
timeout=get_timeout)
else:
ray.get(f.options(runtime_env=runtime_env).remote())
if soft_limit_zero or worker_register_timeout:
# Wait for the worker to exit and the local files to be GC'd.
wait_for_condition(lambda: check_local_files_gced(cluster))
else:
# Local files should not be GC'd because the soft limit is large enough.
assert not check_local_files_gced(cluster)
if __name__ == "__main__":
sys.exit(pytest.main(["-sv", __file__]))

@ -367,6 +367,10 @@ RAY_CONFIG(uint64_t, kill_idle_workers_interval_ms, 200)
/// The idle time threshold for an idle worker to be killed.
RAY_CONFIG(int64_t, idle_worker_killing_time_threshold_ms, 1000)
/// The soft limit of the number of workers.
/// -1 means using num_cpus instead.
RAY_CONFIG(int64_t, num_workers_soft_limit, -1)
// The interval where metrics are exported in milliseconds.
RAY_CONFIG(uint64_t, metrics_report_interval_ms, 10000)

@ -23,7 +23,7 @@ void RuntimeEnvManager::AddURIReference(const std::string &hex_id,
for (const auto &uri : uris) {
uri_reference_[uri]++;
id_to_uris_[hex_id].push_back(uri);
RAY_LOG(DEBUG) << "Added URI Reference " << uri;
RAY_LOG(DEBUG) << "Added URI Reference " << uri << " for id " << hex_id;
}
}
@ -35,6 +35,7 @@ const std::vector<std::string> &RuntimeEnvManager::GetReferences(
}
void RuntimeEnvManager::RemoveURIReference(const std::string &hex_id) {
RAY_LOG(DEBUG) << "Subtracting 1 from URI Reference for id " << hex_id;
if (!id_to_uris_.count(hex_id)) {
return;
}

@ -171,7 +171,9 @@ int main(int argc, char *argv[]) {
<< node_manager_config.resource_config.ToString();
node_manager_config.node_manager_address = node_ip_address;
node_manager_config.node_manager_port = node_manager_port;
node_manager_config.num_workers_soft_limit = num_cpus;
auto soft_limit_config = RayConfig::instance().num_workers_soft_limit();
node_manager_config.num_workers_soft_limit =
soft_limit_config >= 0 ? soft_limit_config : num_cpus;
node_manager_config.num_initial_python_workers_for_first_job =
num_initial_python_workers_for_first_job;
node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency;

@ -299,13 +299,6 @@ NodeManager::NodeManager(instrumented_io_context &io_service, const NodeID &self
global_gc_throttler_(RayConfig::instance().global_gc_min_interval_s() * 1e9),
local_gc_interval_ns_(RayConfig::instance().local_gc_interval_s() * 1e9),
record_metrics_period_ms_(config.record_metrics_period_ms),
runtime_env_manager_(
/*deleter=*/[this](const std::string &uri, std::function<void(bool)> cb) {
if (RayConfig::instance().runtime_env_skip_local_gc()) {
return cb(true);
}
return agent_manager_->DeleteURIs({uri}, cb);
}),
next_resource_seq_no_(0) {
RAY_LOG(INFO) << "Initializing NodeManager with ID " << self_node_id_;
RAY_CHECK(RayConfig::instance().raylet_heartbeat_period_milliseconds() > 0);
@ -554,11 +547,6 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_
<< job_data.driver_pid() << " is dead: " << job_data.is_dead()
<< " driver address: " << job_data.driver_ip_address();
worker_pool_.HandleJobStarted(job_id, job_data.config());
// NOTE: Technically `HandleJobStarted` isn't idempotent because we'll
// increment the ref count multiple times. This is fine because
// `HandleJobFinisehd` will also decrement the ref count multiple times.
runtime_env_manager_.AddURIReference(job_id.Hex(),
job_data.config().runtime_env_info());
// Tasks of this job may already arrived but failed to pop a worker because the job
// config is not local yet. So we trigger dispatching again here to try to
// reschedule these tasks.
@ -569,7 +557,6 @@ void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job
RAY_LOG(DEBUG) << "HandleJobFinished " << job_id;
RAY_CHECK(job_data.is_dead());
worker_pool_.HandleJobFinished(job_id);
runtime_env_manager_.RemoveURIReference(job_id.Hex());
}
void NodeManager::FillNormalTaskResourceUsage(rpc::ResourcesData &resources_data) {
@ -1271,10 +1258,6 @@ void NodeManager::DisconnectClient(const std::shared_ptr<ClientConnection> &clie
cluster_task_manager_->TaskFinished(worker, &task);
}
if (worker->IsDetachedActor()) {
runtime_env_manager_.RemoveURIReference(actor_id.Hex());
}
if (disconnect_type == rpc::WorkerExitType::SYSTEM_ERROR_EXIT) {
// Push the error to driver.
const JobID &job_id = worker->GetAssignedJobId();
@ -1955,8 +1938,6 @@ void NodeManager::FinishAssignedActorCreationTask(WorkerInterface &worker,
auto job_id = task.GetTaskSpecification().JobId();
auto job_config = worker_pool_.GetJobConfig(job_id);
RAY_CHECK(job_config);
runtime_env_manager_.AddURIReference(actor_id.Hex(),
task.GetTaskSpecification().RuntimeEnvInfo());
}
}

@ -28,7 +28,6 @@
#include "ray/object_manager/object_manager.h"
#include "ray/raylet/agent_manager.h"
#include "ray/raylet_client/raylet_client.h"
#include "ray/common/runtime_env_manager.h"
#include "ray/raylet/local_object_manager.h"
#include "ray/raylet/scheduling/scheduling_ids.h"
#include "ray/raylet/scheduling/cluster_resource_scheduler.h"
@ -747,9 +746,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// Managers all bundle-related operations.
std::shared_ptr<PlacementGroupResourceManager> placement_group_resource_manager_;
/// Manage all runtime env locally
RuntimeEnvManager runtime_env_manager_;
/// Next resource broadcast seq no. Non-incrementing sequence numbers
/// indicate network issues (dropped/duplicated/ooo packets, etc).
int64_t next_resource_seq_no_;

@ -34,6 +34,10 @@ DEFINE_stats(worker_register_time_ms, "end to end latency of register a worker p
namespace {
// Add this prefix because the worker setup token is just a counter, which could
// easily collide with other ids.
static const std::string kWorkerSetupTokenPrefix = "worker_startup_token:";
// A helper function to get a worker from a list.
std::shared_ptr<ray::raylet::WorkerInterface> GetWorker(
const std::unordered_set<std::shared_ptr<ray::raylet::WorkerInterface>> &worker_pool,
@ -85,7 +89,14 @@ WorkerPool::WorkerPool(instrumented_io_context &io_service, const NodeID node_id
num_initial_python_workers_for_first_job, maximum_startup_concurrency)),
num_initial_python_workers_for_first_job_(num_initial_python_workers_for_first_job),
periodical_runner_(io_service),
get_time_(get_time) {
get_time_(get_time),
runtime_env_manager_(
/*deleter=*/[this](const std::string &uri, std::function<void(bool)> cb) {
if (RayConfig::instance().runtime_env_skip_local_gc()) {
return cb(true);
}
return agent_manager_->DeleteURIs({uri}, cb);
}) {
RAY_CHECK(maximum_startup_concurrency > 0);
// We need to record so that the metric exists. This way, we report that 0
// processes have started before a task runs on the node (as opposed to the
@ -182,11 +193,32 @@ void WorkerPool::update_worker_startup_token_counter() {
worker_startup_token_counter_ += 1;
}
void WorkerPool::AddStartingWorkerProcess(
State &state, const int workers_to_start, const rpc::WorkerType worker_type,
const Process &proc, const std::chrono::high_resolution_clock::time_point &start,
const rpc::RuntimeEnvInfo &runtime_env_info) {
state.starting_worker_processes.emplace(
worker_startup_token_counter_,
StartingWorkerProcessInfo{workers_to_start, workers_to_start, worker_type, proc,
start, runtime_env_info});
runtime_env_manager_.AddURIReference(
kWorkerSetupTokenPrefix + std::to_string(worker_startup_token_counter_),
runtime_env_info);
}
void WorkerPool::RemoveStartingWorkerProcess(State &state,
const StartupToken &proc_startup_token) {
state.starting_worker_processes.erase(proc_startup_token);
runtime_env_manager_.RemoveURIReference(kWorkerSetupTokenPrefix +
std::to_string(proc_startup_token));
}
std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
const Language &language, const rpc::WorkerType worker_type, const JobID &job_id,
PopWorkerStatus *status, const std::vector<std::string> &dynamic_options,
const int runtime_env_hash, const std::string &serialized_runtime_env_context,
const std::string &allocated_instances_serialized_json) {
const std::string &allocated_instances_serialized_json,
const rpc::RuntimeEnvInfo &runtime_env_info) {
rpc::JobConfig *job_config = nullptr;
if (!IsIOWorkerType(worker_type)) {
RAY_CHECK(!job_id.IsNil());
@ -415,13 +447,12 @@ std::tuple<Process, StartupToken> WorkerPool::StartWorkerProcess(
stats::ProcessStartupTimeMs.Record(duration.count());
stats::NumWorkersStarted.Record(1);
RAY_LOG(INFO) << "Started worker process of " << workers_to_start
<< " worker(s) with pid " << proc.GetId();
<< " worker(s) with pid " << proc.GetId() << ", the token "
<< worker_startup_token_counter_;
MonitorStartingWorkerProcess(proc, worker_startup_token_counter_, language,
worker_type);
state.starting_worker_processes.emplace(
worker_startup_token_counter_,
StartingWorkerProcessInfo{workers_to_start, workers_to_start, worker_type, proc,
start});
AddStartingWorkerProcess(state, workers_to_start, worker_type, proc, start,
runtime_env_info);
StartupToken worker_startup_token = worker_startup_token_counter_;
update_worker_startup_token_counter();
if (IsIOWorkerType(worker_type)) {
@ -471,7 +502,7 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc,
proc_startup_token, nullptr, status, &found,
&used, &task_id);
}
state.starting_worker_processes.erase(it);
RemoveStartingWorkerProcess(state, proc_startup_token);
if (IsIOWorkerType(worker_type)) {
// Mark the I/O worker as failed.
auto &io_worker_state = GetIOWorkerStateFromWorkerType(worker_type, state);
@ -560,34 +591,53 @@ void WorkerPool::MarkPortAsFree(int port) {
}
}
void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) {
all_jobs_[job_id] = job_config;
static bool NeedToEagerInstallRuntimeEnv(const rpc::JobConfig &job_config) {
if (job_config.has_runtime_env_info() &&
job_config.runtime_env_info().runtime_env_eager_install()) {
auto const &runtime_env = job_config.runtime_env_info().serialized_runtime_env();
if (runtime_env != "{}" && runtime_env != "") {
RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id
<< ". The runtime environment was " << runtime_env << ".";
CreateRuntimeEnv(
runtime_env, job_id,
[job_id](bool successful, const std::string &serialized_runtime_env_context) {
if (successful) {
RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job "
<< job_id << ". The result context was "
<< serialized_runtime_env_context << ".";
} else {
RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job "
<< job_id << ".";
}
});
return true;
}
}
return false;
}
void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) {
all_jobs_[job_id] = job_config;
if (NeedToEagerInstallRuntimeEnv(job_config)) {
auto const &runtime_env = job_config.runtime_env_info().serialized_runtime_env();
// NOTE: Technically `HandleJobStarted` isn't idempotent because we'll
// increment the ref count multiple times. This is fine because
// `HandleJobFinished` will also decrement the ref count multiple times.
runtime_env_manager_.AddURIReference(job_id.Hex(), job_config.runtime_env_info());
RAY_LOG(INFO) << "[Eagerly] Start install runtime environment for job " << job_id
<< ". The runtime environment was " << runtime_env << ".";
CreateRuntimeEnv(
runtime_env, job_id,
[job_id](bool successful, const std::string &serialized_runtime_env_context) {
if (successful) {
RAY_LOG(INFO) << "[Eagerly] Create runtime env successful for job " << job_id
<< ". The result context was " << serialized_runtime_env_context
<< ".";
} else {
RAY_LOG(ERROR) << "[Eagerly] Couldn't create a runtime environment for job "
<< job_id << ".";
}
});
}
}
void WorkerPool::HandleJobFinished(const JobID &job_id) {
// Currently we don't erase the job from `all_jobs_` , as a workaround for
// https://github.com/ray-project/ray/issues/11437.
// unfinished_jobs_.erase(job_id);
auto job_config = GetJobConfig(job_id);
RAY_CHECK(job_config);
// Check eager install here because we only add URI reference when runtime
// env install really happens.
if (NeedToEagerInstallRuntimeEnv(*job_config)) {
runtime_env_manager_.RemoveURIReference(job_id.Hex());
}
finished_jobs_.insert(job_id);
}
@ -608,7 +658,7 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker
if (it == state.starting_worker_processes.end()) {
RAY_LOG(WARNING)
<< "Received a register request from an unknown worker shim process: "
<< worker_shim_pid;
<< worker_shim_pid << ", token: " << worker_startup_token;
Status status = Status::Invalid("Unknown worker");
send_reply_callback(status, /*port=*/0);
return status;
@ -651,9 +701,11 @@ void WorkerPool::OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker)
auto it = state.starting_worker_processes.find(worker_startup_token);
if (it != state.starting_worker_processes.end()) {
runtime_env_manager_.AddURIReference(worker->WorkerId().Hex(),
it->second.runtime_env_info);
it->second.num_starting_workers--;
if (it->second.num_starting_workers == 0) {
state.starting_worker_processes.erase(it);
RemoveStartingWorkerProcess(state, worker_startup_token);
// We may have slots to start more workers now.
TryStartIOWorkers(worker->GetLanguage());
}
@ -1047,7 +1099,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
auto [proc, startup_token] = StartWorkerProcess(
task_spec.GetLanguage(), rpc::WorkerType::WORKER, task_spec.JobId(), &status,
dynamic_options, task_spec.GetRuntimeEnvHash(), serialized_runtime_env_context,
allocated_instances_serialized_json);
allocated_instances_serialized_json, task_spec.RuntimeEnvInfo());
if (status == PopWorkerStatus::OK) {
RAY_CHECK(proc.IsValid());
WarnAboutSize();
@ -1210,6 +1262,7 @@ void WorkerPool::PrestartWorkers(const TaskSpecification &task_spec, int64_t bac
bool WorkerPool::DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker,
rpc::WorkerExitType disconnect_type) {
runtime_env_manager_.RemoveURIReference(worker->WorkerId().Hex());
auto &state = GetStateForLanguage(worker->GetLanguage());
RAY_CHECK(RemoveWorker(state.registered_workers, worker));
RAY_UNUSED(RemoveWorker(state.pending_disconnection_workers, worker));

@ -28,6 +28,7 @@
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/asio/periodical_runner.h"
#include "ray/common/client_connection.h"
#include "ray/common/runtime_env_manager.h"
#include "ray/common/task/task.h"
#include "ray/common/task/task_common.h"
#include "ray/gcs/gcs_client/gcs_client.h"
@ -405,6 +406,7 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
/// \param serialized_runtime_env_context The context of runtime env.
/// \param allocated_instances_serialized_json The allocated resource instances
// json string.
/// \param runtime_env_info The raw runtime env info.
/// \return The process that we started and a token. If the token is less than 0,
/// we didn't start a process.
std::tuple<Process, StartupToken> StartWorkerProcess(
@ -413,7 +415,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
const std::vector<std::string> &dynamic_options = {},
const int runtime_env_hash = 0,
const std::string &serialized_runtime_env_context = "{}",
const std::string &allocated_instances_serialized_json = "{}");
const std::string &allocated_instances_serialized_json = "{}",
const rpc::RuntimeEnvInfo &runtime_env_info = rpc::RuntimeEnvInfo());
/// The implementation of how to start a new worker process with command arguments.
/// The lifetime of the process is tied to that of the returned object,
@ -463,6 +466,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
Process proc;
/// The worker process start time.
std::chrono::high_resolution_clock::time_point start_time;
/// The runtime env Info.
rpc::RuntimeEnvInfo runtime_env_info;
};
struct TaskWaitingForWorkerInfo {
@ -620,6 +625,13 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
const std::function<void(bool, const std::string &)> &callback,
const std::string &serialized_allocated_resource_instances = "{}");
void AddStartingWorkerProcess(
State &state, const int workers_to_start, const rpc::WorkerType worker_type,
const Process &proc, const std::chrono::high_resolution_clock::time_point &start,
const rpc::RuntimeEnvInfo &runtime_env_info);
void RemoveStartingWorkerProcess(State &state, const StartupToken &proc_startup_token);
/// For Process class for managing subprocesses (e.g. reaping zombies).
instrumented_io_context *io_service_;
/// Node ID of the current node.
@ -684,6 +696,146 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
/// Agent manager.
std::shared_ptr<AgentManager> agent_manager_;
/// Manage all runtime env resources locally by URI reference. We add or remove URI
/// reference in the cases below:
/// For a job with an eagerly installed runtime env:
/// - Add a URI reference when the job starts.
/// - Remove the URI reference when the job finishes.
/// For a worker process with a valid runtime env:
/// - Add a URI reference when the worker process starts.
/// - Remove that URI reference when all worker instances of the process have
/// registered, or when any worker instance registration times out.
/// - Add a URI reference when a worker instance registers.
/// - Remove that URI reference when the worker instance disconnects.
///
/// A normal state change flow is:
/// job level:
/// HandleJobStarted(ref + 1 = 1) -> HandleJobFinished(ref - 1 = 0)
/// worker level:
/// StartWorkerProcess(ref + 1 = 1)
/// -> OnWorkerStarted * 3 (ref + 3 = 4)
/// -> All worker instances registered (ref - 1 = 3)
/// -> DisconnectWorker * 3 (ref - 3 = 0)
///
/// A state change flow for worker timeout case is:
/// StartWorkerProcess(ref + 1 = 1)
/// -> OnWorkerStarted * 2 (ref + 2 = 3)
/// -> One worker registration times out, kill worker process (ref - 1 = 2)
/// -> DisconnectWorker * 2 (ref - 2 = 0)
///
/// Note: "OnWorkerStarted * 3" means that three workers are started. And we assume
/// that the worker process has tree worker instances totally.
///
/// An example to show reference table changes:
///
/// Start a job with eager installed runtime env:
/// ray.init(runtime_env=
/// {
/// "conda": {
/// "dependencies": ["requests"]
/// },
/// "eager_install": True
/// })
/// Add URI reference and get the reference tables:
/// +---------------------------------------------+
/// | id | URIs |
/// +--------------------+------------------------+
/// | job-id-hex-A | conda://{hash-A} |
/// +---------------------------------------------+
/// +---------------------------------------------+
/// | URI | count |
/// +--------------------+------------------------+
/// | conda://{hash-A} | 1 |
/// +---------------------------------------------+
///
/// Create actor with runtime env:
/// @ray.remote
/// class actor:
/// def Foo():
/// import my_module
/// pass
/// actor.options(runtime_env=
/// {
/// "conda": {
/// "dependencies": ["requests"]
/// },
/// "py_modules": [
/// "s3://bucket/my_module.zip"
/// ]
/// }).remote()
/// First step when worker process started, add URI reference and get the reference
/// tables:
/// +-------------------------------------------------------------------+
/// | id | URIs |
/// +----------------------+--------------------------------------------+
/// | job-id-hex-A | conda://{hash-A} |
/// | worker-setup-token-A | conda://{hash-A},s3://bucket/my_module.zip |
/// +-------------------------------------------------------------------+
/// +------------------------------------------------------+
/// | URI | count |
/// +-----------------------------+------------------------+
/// | conda://{hash-A} | 2 |
/// | s3://bucket/my_module.zip | 1 |
/// +------------------------------------------------------+
/// Second step when worker instance registers, add URI reference for worker
/// instance and get the reference tables:
/// +-------------------------------------------------------------------+
/// | id | URIs |
/// +----------------------+--------------------------------------------+
/// | job-id-hex-A | conda://{hash-A} |
/// | worker-setup-token-A | conda://{hash-A},s3://bucket/my_module.zip |
/// | worker-id-hex-A | conda://{hash-A},s3://bucket/my_module.zip |
/// +-------------------------------------------------------------------+
/// +------------------------------------------------------+
/// | URI | count |
/// +-----------------------------+------------------------+
/// | conda://{hash-A} | 3 |
/// | s3://bucket/my_module.zip | 2 |
/// +------------------------------------------------------+
/// At the same time, remove the URI reference for the worker process, because a
/// Python worker process contains only one worker instance:
/// +-------------------------------------------------------------------+
/// | id | URIs |
/// +----------------------+--------------------------------------------+
/// | job-id-hex-A | conda://{hash-A} |
/// | worker-id-hex-A | conda://{hash-A},s3://bucket/my_module.zip |
/// +-------------------------------------------------------------------+
/// +------------------------------------------------------+
/// | URI | count |
/// +-----------------------------+------------------------+
/// | conda://{hash-A} | 2 |
/// | s3://bucket/my_module.zip | 1 |
/// +------------------------------------------------------+
///
/// Next, when the actor is killed, remove URI reference for worker instance:
/// +-------------------------------------------------------------------+
/// | id | URIs |
/// +----------------------+--------------------------------------------+
/// | job-id-hex-A | conda://{hash-A} |
/// +-------------------------------------------------------------------+
/// +------------------------------------------------------+
/// | URI | count |
/// +-----------------------------+------------------------+
/// | conda://{hash-A} | 1 |
/// +------------------------------------------------------+
/// Finally, when the job is finished, remove the URI reference and get an empty
/// reference table:
/// +-------------------------------------------------------------------+
/// | id | URIs |
/// +----------------------+--------------------------------------------+
/// | | |
/// +-------------------------------------------------------------------+
/// +------------------------------------------------------+
/// | URI | count |
/// +-----------------------------+------------------------+
/// | | |
/// +------------------------------------------------------+
///
/// Now, we can delete the runtime env resources safely.
RuntimeEnvManager runtime_env_manager_;
/// Stats
int64_t process_failed_job_config_missing_ = 0;
int64_t process_failed_rate_limited_ = 0;

@ -30,6 +30,7 @@ int NUM_WORKERS_PER_PROCESS_JAVA = 3;
int MAXIMUM_STARTUP_CONCURRENCY = 5;
int MAX_IO_WORKER_SIZE = 2;
int POOL_SIZE_SOFT_LIMIT = 5;
int WORKER_REGISTER_TIMEOUT_SECONDS = 3;
JobID JOB_ID = JobID::FromInt(1);
std::string BAD_RUNTIME_ENV = "bad runtime env";
@ -76,6 +77,8 @@ class MockWorkerClient : public rpc::CoreWorkerClientInterface {
instrumented_io_context &io_service_;
};
static std::unordered_set<std::string> valid_uris;
class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface {
public:
void CreateRuntimeEnv(const rpc::CreateRuntimeEnvRequest &request,
@ -84,6 +87,14 @@ class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface {
if (request.serialized_runtime_env() == BAD_RUNTIME_ENV) {
reply.set_status(rpc::AGENT_RPC_STATUS_FAILED);
} else {
rpc::RuntimeEnv runtime_env;
if (google::protobuf::util::JsonStringToMessage(request.serialized_runtime_env(),
&runtime_env)
.ok()) {
for (auto uri : runtime_env.uris().py_modules_uris()) {
valid_uris.emplace(uri);
}
}
reply.set_status(rpc::AGENT_RPC_STATUS_OK);
reply.set_serialized_runtime_env_context("{\"dummy\":\"dummy\"}");
}
@ -92,6 +103,9 @@ class MockRuntimeEnvAgentClient : public rpc::RuntimeEnvAgentClientInterface {
void DeleteURIs(const rpc::DeleteURIsRequest &request,
const rpc::ClientCallback<rpc::DeleteURIsReply> &callback) {
for (auto uri : request.uris()) {
valid_uris.erase(uri);
}
rpc::DeleteURIsReply reply;
reply.set_status(rpc::AGENT_RPC_STATUS_OK);
callback(Status::OK(), reply);
@ -203,7 +217,8 @@ class WorkerPoolMock : public WorkerPool {
Process proc, const Language &language = Language::PYTHON,
const JobID &job_id = JOB_ID,
const rpc::WorkerType worker_type = rpc::WorkerType::WORKER,
int runtime_env_hash = 0, StartupToken worker_startup_token = 0) {
int runtime_env_hash = 0, StartupToken worker_startup_token = 0,
bool set_process = true) {
std::function<void(ClientConnection &)> client_handler =
[this](ClientConnection &client) { HandleNewClient(client); };
std::function<void(std::shared_ptr<ClientConnection>, int64_t,
@ -225,7 +240,7 @@ class WorkerPoolMock : public WorkerPool {
auto rpc_client = std::make_shared<MockWorkerClient>(instrumented_io_service_);
worker->Connect(rpc_client);
mock_worker_rpc_clients_.emplace(worker->WorkerId(), rpc_client);
if (!proc.IsNull()) {
if (set_process && !proc.IsNull()) {
worker->SetProcess(proc);
worker->SetShimProcess(proc);
}
@ -245,13 +260,16 @@ class WorkerPoolMock : public WorkerPool {
}
// Create workers for processes and push them to worker pool.
void PushWorkers() {
// \param[in] timeout_worker_number The number of workers to leave unregistered,
// in order to simulate worker registration timeouts.
void PushWorkers(int timeout_worker_number = 0) {
auto processes = GetProcesses();
for (auto it = processes.begin(); it != processes.end(); ++it) {
auto pushed_it = pushedProcesses_.find(it->first);
if (pushed_it == pushedProcesses_.end()) {
int runtime_env_hash = 0;
bool is_java = false;
bool has_dynamic_options = false;
// Parses runtime env hash to make sure the pushed workers can be popped out.
for (auto command_args : it->second) {
std::string runtime_env_key = "--runtime-env-hash=";
@ -264,14 +282,27 @@ class WorkerPoolMock : public WorkerPool {
if (pos != std::string::npos) {
is_java = true;
}
pos = command_args.find("-X");
if (pos != std::string::npos) {
has_dynamic_options = true;
}
}
// TODO(SongGuyang): support C++ language workers.
int num_workers = is_java ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
for (int i = 0; i < num_workers; i++) {
auto worker =
CreateWorker(it->first, is_java ? Language::JAVA : Language::PYTHON, JOB_ID,
rpc::WorkerType::WORKER, runtime_env_hash,
startup_tokens_by_proc_[it->first]);
int num_workers =
(is_java && !has_dynamic_options) ? NUM_WORKERS_PER_PROCESS_JAVA : 1;
RAY_CHECK(timeout_worker_number <= num_workers)
<< "The timeout worker number cannot exceed the total number of workers";
auto register_workers = num_workers - timeout_worker_number;
for (int i = 0; i < register_workers; i++) {
auto worker = CreateWorker(
it->first, is_java ? Language::JAVA : Language::PYTHON, JOB_ID,
rpc::WorkerType::WORKER, runtime_env_hash,
startup_tokens_by_proc_[it->first],
// Don't set process to ensure the `RegisterWorker` succeeds below.
false);
RAY_CHECK_OK(RegisterWorker(worker, it->first.GetId(), it->first.GetId(),
startup_tokens_by_proc_[it->first],
[](Status, int) {}));
OnWorkerStarted(worker);
PushAvailableWorker(worker);
}
@ -284,9 +315,10 @@ class WorkerPoolMock : public WorkerPool {
// worker synchronously.
// \param[in] push_workers If true, tries to push the workers from the started
// processes.
std::shared_ptr<WorkerInterface> PopWorkerSync(
const TaskSpecification &task_spec, bool push_workers = true,
PopWorkerStatus *worker_status = nullptr) {
std::shared_ptr<WorkerInterface> PopWorkerSync(const TaskSpecification &task_spec,
bool push_workers = true,
PopWorkerStatus *worker_status = nullptr,
int timeout_worker_number = 0) {
std::shared_ptr<WorkerInterface> popped_worker = nullptr;
std::promise<bool> promise;
this->PopWorker(task_spec,
@ -301,7 +333,7 @@ class WorkerPoolMock : public WorkerPool {
return true;
});
if (push_workers) {
PushWorkers();
PushWorkers(timeout_worker_number);
}
promise.get_future().get();
return popped_worker;
@ -328,7 +360,9 @@ class WorkerPoolTest : public ::testing::Test {
public:
WorkerPoolTest() {
RayConfig::instance().initialize(
R"({"worker_register_timeout_seconds": 3, "object_spilling_config": "dummy", "max_io_workers": )" +
R"({"worker_register_timeout_seconds": )" +
std::to_string(WORKER_REGISTER_TIMEOUT_SECONDS) +
R"(, "object_spilling_config": "dummy", "max_io_workers": )" +
std::to_string(MAX_IO_WORKER_SIZE) + "}");
SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA,
@ -344,7 +378,10 @@ class WorkerPoolTest : public ::testing::Test {
StartMockAgent();
}
virtual void TearDown() { AssertNoLeaks(); }
virtual void TearDown() {
AssertNoLeaks();
valid_uris.clear();
}
void AssertNoLeaks() { ASSERT_EQ(worker_pool_->pending_exit_idle_workers_.size(), 0); }
@ -448,12 +485,36 @@ class WorkerPoolTest : public ::testing::Test {
std::unique_ptr<WorkerPoolMock> worker_pool_;
};
static inline rpc::RuntimeEnvInfo ExampleRuntimeEnvInfo(
const std::vector<std::string> uris, bool eager_install = false) {
rpc::RuntimeEnv runtime_env;
for (auto &uri : uris) {
runtime_env.mutable_uris()->mutable_py_modules_uris()->Add(std::string(uri));
}
std::string runtime_env_string;
google::protobuf::util::MessageToJsonString(runtime_env, &runtime_env_string);
rpc::RuntimeEnvInfo runtime_env_info;
runtime_env_info.set_serialized_runtime_env(runtime_env_string);
for (auto &uri : uris) {
runtime_env_info.mutable_uris()->Add(std::string(uri));
}
runtime_env_info.set_runtime_env_eager_install(eager_install);
return runtime_env_info;
}
static inline rpc::RuntimeEnvInfo ExampleRuntimeEnvInfoFromString(
std::string serialized_runtime_env) {
rpc::RuntimeEnvInfo runtime_env_info;
runtime_env_info.set_serialized_runtime_env(serialized_runtime_env);
return runtime_env_info;
}
static inline TaskSpecification ExampleTaskSpec(
const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON,
const JobID &job_id = JOB_ID, const ActorID actor_creation_id = ActorID::Nil(),
const std::vector<std::string> &dynamic_worker_options = {},
const TaskID &task_id = TaskID::FromRandom(JobID::Nil()),
const std::string serialized_runtime_env = "") {
const rpc::RuntimeEnvInfo runtime_env_info = rpc::RuntimeEnvInfo()) {
rpc::TaskSpec message;
message.set_job_id(job_id.Binary());
message.set_language(language);
@ -472,7 +533,7 @@ static inline TaskSpecification ExampleTaskSpec(
} else {
message.set_type(TaskType::NORMAL_TASK);
}
message.mutable_runtime_env_info()->set_serialized_runtime_env(serialized_runtime_env);
message.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
return TaskSpecification(std::move(message));
}
@ -618,7 +679,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
rpc::JobConfig job_config = rpc::JobConfig();
job_config.add_code_search_path("/test/code_search_path");
job_config.set_num_java_workers_per_process(1);
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
job_config.add_jvm_options("-Xmx1g");
job_config.add_jvm_options("-Xms500m");
job_config.add_jvm_options("-Dmy-job.hello=world");
@ -1320,10 +1381,10 @@ TEST_F(WorkerPoolTest, PopWorkerWithRuntimeEnv) {
auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1);
const auto actor_creation_task_spec = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id, {"XXX=YYY"},
TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})");
TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"}));
const auto normal_task_spec = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {"XXX=YYY"},
TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})");
TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"}));
const auto normal_task_spec_without_runtime_env =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(), {});
// Pop worker for actor creation task again.
@ -1348,6 +1409,167 @@ TEST_F(WorkerPoolTest, PopWorkerWithRuntimeEnv) {
ASSERT_EQ(worker_pool_->GetProcessSize(), 3);
}
TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceJobLevel) {
// First part, test start job with eager installed runtime env.
{
auto job_id = JobID::FromInt(12345);
std::string uri = "s3://123";
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true);
rpc::JobConfig job_config;
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
// Start job.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(valid_uris.size(), 1);
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(valid_uris.size(), 0);
}
// Second part, test start job without eager installed runtime env.
{
auto job_id = JobID::FromInt(12345);
std::string uri = "s3://123";
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false);
rpc::JobConfig job_config;
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
// Start job.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(valid_uris.size(), 0);
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(valid_uris.size(), 0);
}
}
TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWorkerLevel) {
auto job_id = JobID::FromInt(12345);
std::string uri = "s3://123";
// First part, test URI reference with eager install.
{
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true);
rpc::JobConfig job_config;
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
// Start job with eager installed runtime env.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(valid_uris.size(), 1);
// Start actor with runtime env.
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id,
{"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info);
auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_EQ(valid_uris.size(), 1);
// Start task with runtime env.
const auto normal_task_spec =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(),
{"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info);
auto popped_normal_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_EQ(valid_uris.size(), 1);
// Disconnect actor worker.
worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(valid_uris.size(), 1);
// Disconnect task worker.
worker_pool_->DisconnectWorker(popped_normal_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(valid_uris.size(), 1);
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(valid_uris.size(), 0);
}
// Second part, test URI reference without eager install.
{
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, true);
auto runtime_env_info_without_eager_install = ExampleRuntimeEnvInfo({uri}, false);
rpc::JobConfig job_config;
job_config.mutable_runtime_env_info()->CopyFrom(
runtime_env_info_without_eager_install);
// Start job without eager installed runtime env.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(valid_uris.size(), 0);
// Start actor with runtime env.
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 2);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_creation_id,
{"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), runtime_env_info);
auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_EQ(valid_uris.size(), 1);
// Start task with runtime env.
auto popped_normal_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_EQ(valid_uris.size(), 1);
// Disconnect actor worker.
worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(valid_uris.size(), 1);
// Disconnect task worker.
worker_pool_->DisconnectWorker(popped_normal_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(valid_uris.size(), 0);
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(valid_uris.size(), 0);
}
}
TEST_F(WorkerPoolTest, RuntimeEnvUriReferenceWithMultipleWorkers) {
auto job_id = JOB_ID;
std::string uri = "s3://567";
auto runtime_env_info = ExampleRuntimeEnvInfo({uri}, false);
rpc::JobConfig job_config;
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
job_config.mutable_runtime_env_info()->CopyFrom(runtime_env_info);
// Start job without eager installed runtime env.
worker_pool_->HandleJobStarted(job_id, job_config);
ASSERT_EQ(valid_uris.size(), 0);
// First part, test normal case with all worker registered.
{
// Start actors with runtime env. The Java actors will trigger a multi-worker process.
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), i + 1);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {},
TaskID::FromRandom(JobID::Nil()), runtime_env_info);
auto popped_actor_worker = worker_pool_->PopWorkerSync(actor_creation_task_spec);
ASSERT_NE(popped_actor_worker, nullptr);
workers.push_back(popped_actor_worker);
ASSERT_EQ(valid_uris.size(), 1);
}
// Make sure only one worker process has been started.
ASSERT_EQ(worker_pool_->GetProcessSize(), 1);
// Disconnect all actor workers.
for (auto &worker : workers) {
worker_pool_->DisconnectWorker(worker, rpc::WorkerExitType::IDLE_EXIT);
}
ASSERT_EQ(valid_uris.size(), 0);
}
// Second part, test corner case with some worker registration timeout.
{
// Start one actor with runtime env. The Java actor will trigger a multi-worker
// process.
auto actor_creation_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
const auto actor_creation_task_spec =
ExampleTaskSpec(ActorID::Nil(), Language::JAVA, job_id, actor_creation_id, {},
TaskID::FromRandom(JobID::Nil()), runtime_env_info);
PopWorkerStatus status;
// Only one worker registers; all the other workers' registrations time out.
auto popped_actor_worker = worker_pool_->PopWorkerSync(
actor_creation_task_spec, true, &status, NUM_WORKERS_PER_PROCESS_JAVA - 1);
ASSERT_EQ(valid_uris.size(), 1);
// Disconnect actor worker.
worker_pool_->DisconnectWorker(popped_actor_worker, rpc::WorkerExitType::IDLE_EXIT);
ASSERT_EQ(valid_uris.size(), 1);
// Sleep for a while to wait for the worker registration timeout.
std::this_thread::sleep_for(
std::chrono::seconds(WORKER_REGISTER_TIMEOUT_SECONDS + 1));
ASSERT_EQ(valid_uris.size(), 0);
}
// Finish the job.
worker_pool_->HandleJobFinished(job_id);
ASSERT_EQ(valid_uris.size(), 0);
}
TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) {
///
/// Check that a worker can be popped only if there is a
@ -1356,15 +1578,18 @@ TEST_F(WorkerPoolTest, CacheWorkersByRuntimeEnvHash) {
///
ASSERT_EQ(worker_pool_->GetProcessSize(), 0);
auto actor_creation_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1);
const auto actor_creation_task_spec_1 = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id,
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_1");
const auto task_spec_1 = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(),
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_1");
const auto task_spec_2 = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(),
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()), "mock_runtime_env_2");
const auto actor_creation_task_spec_1 =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, actor_creation_id,
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()),
ExampleRuntimeEnvInfoFromString("mock_runtime_env_1"));
const auto task_spec_1 =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(),
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()),
ExampleRuntimeEnvInfoFromString("mock_runtime_env_1"));
const auto task_spec_2 =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, JOB_ID, ActorID::Nil(),
/*dynamic_options=*/{}, TaskID::FromRandom(JobID::Nil()),
ExampleRuntimeEnvInfoFromString("mock_runtime_env_2"));
const WorkerCacheKey env1 = {"mock_runtime_env_1", {}};
const int runtime_env_hash_1 = env1.IntHash();
@ -1510,9 +1735,9 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) {
/* Test PopWorkerStatus RuntimeEnvCreationFailed */
// Create a task with bad runtime env.
const auto task_spec_with_bad_runtime_env =
ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(),
{"XXX=YYY"}, TaskID::FromRandom(JobID::Nil()), BAD_RUNTIME_ENV);
const auto task_spec_with_bad_runtime_env = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"},
TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfoFromString(BAD_RUNTIME_ENV));
popped_worker =
worker_pool_->PopWorkerSync(task_spec_with_bad_runtime_env, true, &status);
// PopWorker failed and the status is `RuntimeEnvCreationFailed`.
@ -1522,7 +1747,7 @@ TEST_F(WorkerPoolTest, PopWorkerStatus) {
// Create a task with available runtime env.
const auto task_spec_with_runtime_env = ExampleTaskSpec(
ActorID::Nil(), Language::PYTHON, job_id, ActorID::Nil(), {"XXX=YYY"},
TaskID::FromRandom(JobID::Nil()), R"({"uris": "XXX"})");
TaskID::FromRandom(JobID::Nil()), ExampleRuntimeEnvInfo({"XXX"}));
popped_worker = worker_pool_->PopWorkerSync(task_spec_with_runtime_env, true, &status);
// PopWorker success.
ASSERT_NE(popped_worker, nullptr);