[scheduler] Include depth and function descriptor in scheduling class (#20004)

This commit is contained in:
Alex Wu 2021-11-05 08:19:48 -07:00 committed by GitHub
parent 3d5cbc6e62
commit 146b3d6bcc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 200 additions and 42 deletions

View file

@ -44,12 +44,15 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation,
TaskSpecBuilder builder;
std::string task_name =
invocation.name.empty() ? functionDescriptor->DefaultTaskName() : invocation.name;
// TODO (Alex): Properly set the depth here?
builder.SetCommonTaskSpec(invocation.task_id, task_name, rpc::Language::CPP,
functionDescriptor, local_mode_ray_tuntime_.GetCurrentJobID(),
local_mode_ray_tuntime_.GetCurrentTaskId(), 0,
local_mode_ray_tuntime_.GetCurrentTaskId(), address, 1,
required_resources, required_placement_resources,
std::make_pair(PlacementGroupID::Nil(), -1), true, "");
std::make_pair(PlacementGroupID::Nil(), -1), true, "",
/*depth=*/0);
if (invocation.task_type == TaskType::NORMAL_TASK) {
} else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) {
invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID();

View file

@ -14,6 +14,7 @@ import ray
from ray.internal.internal_api import memory_summary
import ray.util.accelerators
import ray.cluster_utils
from ray._private.test_utils import fetch_prometheus
from ray._private.test_utils import (wait_for_condition, new_scheduler_enabled,
Semaphore, object_memory_usage,
@ -674,6 +675,57 @@ def test_gpu_scheduling_liveness(ray_start_cluster):
ray.get(trainer.train.remote(), timeout=30)
@pytest.mark.parametrize(
"ray_start_regular", [{
"_system_config": {
"metrics_report_interval_ms": 1000,
"complex_scheduling_class": True
}
}],
indirect=True)
def test_scheduling_class_depth(ray_start_regular):
node_info = ray.nodes()[0]
metrics_export_port = node_info["MetricsExportPort"]
addr = node_info["NodeManagerAddress"]
prom_addr = f"{addr}:{metrics_export_port}"
@ray.remote(num_cpus=1000)
def infeasible():
pass
@ray.remote(num_cpus=0)
def start_infeasible(n):
if n == 1:
ray.get(infeasible.remote())
ray.get(start_infeasible.remote(n - 1))
start_infeasible.remote(1)
infeasible.remote()
# We expect the 2 calls to `infeasible` to be separate scheduling classes
# because one has depth=1, and the other has depth=2.
metric_name = "ray_internal_num_infeasible_scheduling_classes"
def make_condition(n):
def condition():
_, metric_names, metric_samples = fetch_prometheus([prom_addr])
if metric_name in metric_names:
for sample in metric_samples:
if sample.name == metric_name and sample.value == n:
return True
return False
return condition
wait_for_condition(make_condition(2))
start_infeasible.remote(2)
wait_for_condition(make_condition(3))
start_infeasible.remote(4)
wait_for_condition(make_condition(4))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))

View file

@ -103,6 +103,13 @@ RAY_CONFIG(bool, preallocate_plasma_memory, false)
/// even balancing of load. Low values (min 0.0) encourage more load spreading.
RAY_CONFIG(float, scheduler_spread_threshold, 0.5);
// TODO (Alex): Remove this feature flag once scheduling class capping is
// implemented.
/// Whether to include function descriptors, and depth in the
// scheduling class. / This causes tasks to be queued differently, so it may
// effect scheduling / behavior.
RAY_CONFIG(bool, complex_scheduling_class, false)
// The max allowed size in bytes of a return object from direct actor calls.
// Objects larger than this size will be spilled/promoted to plasma.
RAY_CONFIG(int64_t, max_direct_call_object_size, 100 * 1024)

View file

@ -37,7 +37,8 @@ SchedulingClassDescriptor &TaskSpecification::GetSchedulingClassDescriptor(
return it->second;
}
SchedulingClass TaskSpecification::GetSchedulingClass(const ResourceSet &sched_cls) {
SchedulingClass TaskSpecification::GetSchedulingClass(
const SchedulingClassDescriptor &sched_cls) {
SchedulingClass sched_cls_id;
absl::MutexLock lock(&mutex_);
auto it = sched_cls_to_id_.find(sched_cls);
@ -52,7 +53,7 @@ SchedulingClass TaskSpecification::GetSchedulingClass(const ResourceSet &sched_c
<< " types of tasks seen, this may reduce performance.";
}
sched_cls_to_id_[sched_cls] = sched_cls_id;
sched_id_to_cls_[sched_cls_id] = sched_cls;
sched_id_to_cls_.emplace(sched_cls_id, sched_cls);
} else {
sched_cls_id = it->second;
}
@ -90,12 +91,18 @@ void TaskSpecification::ComputeResources() {
}
if (!IsActorTask()) {
bool complex_scheduling_class = RayConfig::instance().complex_scheduling_class();
// There is no need to compute `SchedulingClass` for actor tasks since
// the actor tasks need not be scheduled.
const auto &resource_set = GetRequiredResources();
const auto &function_descriptor = complex_scheduling_class
? FunctionDescriptor()
: FunctionDescriptorBuilder::Empty();
auto depth = complex_scheduling_class ? GetDepth() : 0;
auto sched_cls_desc =
SchedulingClassDescriptor(resource_set, function_descriptor, depth);
// Map the scheduling class descriptor to an integer for performance.
auto sched_cls = GetRequiredPlacementResources();
sched_cls_id_ = GetSchedulingClass(sched_cls);
sched_cls_id_ = GetSchedulingClass(sched_cls_desc);
}
}
@ -240,6 +247,8 @@ std::string TaskSpecification::GetDebuggerBreakpoint() const {
return message_->debugger_breakpoint();
}
int64_t TaskSpecification::GetDepth() const { return message_->depth(); }
bool TaskSpecification::IsDriverTask() const {
return message_->type() == TaskType::DRIVER_TASK;
}
@ -360,7 +369,7 @@ std::string TaskSpecification::DebugString() const {
stream << ", task_id=" << TaskId() << ", task_name=" << GetName()
<< ", job_id=" << JobId() << ", num_args=" << NumArgs()
<< ", num_returns=" << NumReturns();
<< ", num_returns=" << NumReturns() << ", depth=" << GetDepth();
if (IsActorCreationTask()) {
// Print actor creation task spec.

View file

@ -29,10 +29,51 @@
extern "C" {
#include "ray/thirdparty/sha256.h"
}
namespace ray {
typedef int SchedulingClass;
struct SchedulingClassDescriptor {
public:
explicit SchedulingClassDescriptor(ResourceSet rs, FunctionDescriptor fd, int64_t d)
: resource_set(std::move(rs)), function_descriptor(std::move(fd)), depth(d) {}
ResourceSet resource_set;
FunctionDescriptor function_descriptor;
int64_t depth;
bool operator==(const SchedulingClassDescriptor &other) const {
return depth == other.depth && resource_set == other.resource_set &&
function_descriptor == other.function_descriptor;
}
std::string DebugString() const {
std::stringstream buffer;
buffer << "{"
<< "depth=" << depth << " "
<< "function_descriptor=" << function_descriptor->ToString() << " "
<< "resource_set="
<< "{";
for (const auto &pair : resource_set.GetResourceMap()) {
buffer << pair.first << " : " << pair.second << ", ";
}
buffer << "}}";
return buffer.str();
}
};
} // namespace ray
namespace std {
template <>
struct hash<ray::SchedulingClassDescriptor> {
size_t operator()(const ray::SchedulingClassDescriptor &sched_cls) const {
size_t hash = std::hash<ray::ResourceSet>()(sched_cls.resource_set);
hash ^= sched_cls.function_descriptor->Hash();
hash ^= sched_cls.depth;
return hash;
}
};
} // namespace std
namespace ray {
typedef ResourceSet SchedulingClassDescriptor;
typedef int SchedulingClass;
/// ConcurrencyGroup is a group of actor methods that shares
/// a executing thread pool.
@ -186,6 +227,11 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
std::string GetDebuggerBreakpoint() const;
/// Return the depth of this task. The depth of a graph, is the number of
/// `f.remote()` calls from the driver.
/// \return The depth.
int64_t GetDepth() const;
bool IsDriverTask() const;
Language GetLanguage() const;
@ -245,7 +291,7 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
static SchedulingClassDescriptor &GetSchedulingClassDescriptor(SchedulingClass id);
// Compute a static key that represents the given resource shape.
static SchedulingClass GetSchedulingClass(const ResourceSet &sched_cls);
static SchedulingClass GetSchedulingClass(const SchedulingClassDescriptor &sched_cls);
// Placement Group bundle that this task or actor creation is associated with.
const BundleID PlacementGroupBundleId() const;

View file

@ -104,7 +104,7 @@ class TaskSpecBuilder {
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const BundleID &bundle_id, bool placement_group_capture_child_tasks,
const std::string &debugger_breakpoint,
const std::string &debugger_breakpoint, int64_t depth,
const std::string &serialized_runtime_env = "{}",
const std::vector<std::string> &runtime_env_uris = {},
const std::string &concurrency_group_name = "") {
@ -128,6 +128,7 @@ class TaskSpecBuilder {
message_->set_placement_group_capture_child_tasks(
placement_group_capture_child_tasks);
message_->set_debugger_breakpoint(debugger_breakpoint);
message_->set_depth(depth);
message_->mutable_runtime_env()->set_serialized_runtime_env(serialized_runtime_env);
for (const std::string &uri : runtime_env_uris) {
message_->mutable_runtime_env()->add_uris(uri);

View file

@ -144,6 +144,14 @@ ObjectIDIndexType WorkerContext::GetNextPutIndex() {
return GetThreadContext().GetNextPutIndex();
}
int64_t WorkerContext::GetTaskDepth() const {
auto task_spec = GetCurrentTask();
if (task_spec) {
return task_spec->GetDepth();
}
return 0;
}
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id_; }
const TaskID &WorkerContext::GetCurrentTaskID() const {

View file

@ -82,6 +82,8 @@ class WorkerContext {
// Returns the next put object index; used to calculate ObjectIDs for puts.
ObjectIDIndexType GetNextPutIndex();
int64_t GetTaskDepth() const;
protected:
// allow unit test to set.
bool current_actor_is_direct_call_ = false;

View file

@ -1684,7 +1684,8 @@ void CoreWorker::BuildCommonTaskSpec(
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const BundleID &bundle_id, bool placement_group_capture_child_tasks,
const std::string &debugger_breakpoint, const std::string &serialized_runtime_env,
const std::string &debugger_breakpoint, int64_t depth,
const std::string &serialized_runtime_env,
const std::vector<std::string> &runtime_env_uris,
const std::string &concurrency_group_name) {
// Build common task spec.
@ -1692,7 +1693,7 @@ void CoreWorker::BuildCommonTaskSpec(
task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id,
current_task_id, task_index, caller_id, address, num_returns, required_resources,
required_placement_resources, bundle_id, placement_group_capture_child_tasks,
debugger_breakpoint,
debugger_breakpoint, depth,
// TODO(SongGuyang): Move the logic of `prepare_runtime_env` from Python to Core
// Worker. A common process is needed.
// If runtime env is not provided, use job config. Only for Java and C++ because it
@ -1724,12 +1725,13 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
auto task_name = task_options.name.empty()
? function.GetFunctionDescriptor()->DefaultTaskName()
: task_options.name;
int64_t depth = worker_context_.GetTaskDepth() + 1;
// TODO(ekl) offload task building onto a thread pool for performance
BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, task_options.num_returns,
constrained_resources, required_resources, placement_options,
placement_group_capture_child_tasks, debugger_breakpoint,
placement_group_capture_child_tasks, debugger_breakpoint, depth,
task_options.serialized_runtime_env, task_options.runtime_env_uris);
builder.SetNormalTaskSpec(max_retries, retry_exceptions);
TaskSpecification task_spec = builder.Build();
@ -1780,12 +1782,13 @@ Status CoreWorker::CreateActor(const RayFunction &function,
actor_name.empty()
? function.GetFunctionDescriptor()->DefaultTaskName()
: actor_name + ":" + function.GetFunctionDescriptor()->CallString();
int64_t depth = worker_context_.GetTaskDepth() + 1;
BuildCommonTaskSpec(builder, job_id, actor_creation_task_id, task_name,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, 1, new_resource,
new_placement_resources, actor_creation_options.placement_options,
actor_creation_options.placement_group_capture_child_tasks,
"", /* debugger_breakpoint */
"" /* debugger_breakpoint */, depth,
actor_creation_options.serialized_runtime_env,
actor_creation_options.runtime_env_uris);
@ -1967,14 +1970,19 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitActorTask(
const auto task_name = task_options.name.empty()
? function.GetFunctionDescriptor()->DefaultTaskName()
: task_options.name;
// Depth shouldn't matter for an actor task, but for consistency it should be
// the same as the actor creation task's depth.
int64_t depth = worker_context_.GetTaskDepth();
BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
rpc_address_, function, args, num_returns, task_options.resources,
required_resources, std::make_pair(PlacementGroupID::Nil(), -1),
true, /* placement_group_capture_child_tasks */
"", /* debugger_breakpoint */
"{}", /* serialized_runtime_env */
{}, /* runtime_env_uris */
true, /* placement_group_capture_child_tasks */
"", /* debugger_breakpoint */
depth, /*depth*/
"{}", /* serialized_runtime_env */
{}, /* runtime_env_uris */
task_options.concurrency_group_name);
// NOTE: placement_group_capture_child_tasks and runtime_env will
// be ignored in the actor because we should always follow the actor's option.

View file

@ -1064,7 +1064,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources,
const BundleID &bundle_id, bool placement_group_capture_child_tasks,
const std::string &debugger_breakpoint, const std::string &serialized_runtime_env,
const std::string &debugger_breakpoint, int64_t depth,
const std::string &serialized_runtime_env,
const std::vector<std::string> &runtime_env_uris,
const std::string &concurrency_group_name = "");
void SetCurrentTaskId(const TaskID &task_id);

View file

@ -531,7 +531,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
builder.SetCommonTaskSpec(RandomTaskId(), options.name, function.GetLanguage(),
function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0,
RandomTaskId(), address, num_returns, resources, resources,
std::make_pair(PlacementGroupID::Nil(), -1), true, "");
std::make_pair(PlacementGroupID::Nil(), -1), true, "", 0);
// Set task arguments.
for (const auto &arg : args) {
builder.AddArg(*arg);

View file

@ -29,8 +29,19 @@ namespace core {
// be better to use a mock clock or lease manager interface, but that's high
// overhead for the very simple timeout logic we currently have.
int64_t kLongTimeout = 1024 * 1024 * 1024;
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
const FunctionDescriptor &function_descriptor);
const FunctionDescriptor &function_descriptor,
int64_t depth = 0,
std::string serialized_runtime_env = "") {
TaskSpecBuilder builder;
rpc::Address empty_address;
builder.SetCommonTaskSpec(
TaskID::Nil(), "dummy_task", Language::PYTHON, function_descriptor, JobID::Nil(),
TaskID::Nil(), 0, TaskID::Nil(), empty_address, 1, resources, resources,
std::make_pair(PlacementGroupID::Nil(), -1), true, serialized_runtime_env, depth);
return builder.Build();
}
// Calls BuildTaskSpec with empty resources map and empty function descriptor
TaskSpecification BuildEmptyTaskSpec();
@ -492,17 +503,6 @@ TEST(LocalDependencyResolverTest, TestInlinedObjectIds) {
ASSERT_EQ(task_finisher->num_contained_ids, 2);
}
TaskSpecification BuildTaskSpec(const std::unordered_map<std::string, double> &resources,
const FunctionDescriptor &function_descriptor) {
TaskSpecBuilder builder;
rpc::Address empty_address;
builder.SetCommonTaskSpec(TaskID::Nil(), "dummy_task", Language::PYTHON,
function_descriptor, JobID::Nil(), TaskID::Nil(), 0,
TaskID::Nil(), empty_address, 1, resources, resources,
std::make_pair(PlacementGroupID::Nil(), -1), true, "");
return builder.Build();
}
TaskSpecification BuildEmptyTaskSpec() {
std::unordered_map<std::string, double> empty_resources;
FunctionDescriptor empty_descriptor =
@ -1215,6 +1215,7 @@ void TestSchedulingKey(const std::shared_ptr<CoreWorkerMemoryStore> store,
}
TEST(DirectTaskTransportTest, TestSchedulingKeys) {
RayConfig::instance().complex_scheduling_class() = true;
auto store = std::make_shared<CoreWorkerMemoryStore>();
std::unordered_map<std::string, double> resources1({{"a", 1.0}});
@ -1230,11 +1231,23 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) {
BuildTaskSpec(resources1, descriptor1),
BuildTaskSpec(resources2, descriptor1));
// Tasks with different function descriptors do not request different worker leases.
RAY_LOG(INFO) << "Test different descriptors";
// Tasks with different functions should request different worker leases.
RAY_LOG(INFO) << "Test different functions";
TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1),
BuildTaskSpec(resources1, descriptor2),
BuildTaskSpec(resources2, descriptor1));
BuildTaskSpec(resources1, descriptor1),
BuildTaskSpec(resources1, descriptor2));
// Tasks with different depths should request different worker leases.
RAY_LOG(INFO) << "Test different depths";
TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1, 0),
BuildTaskSpec(resources1, descriptor1, 0),
BuildTaskSpec(resources1, descriptor1, 1));
// Tasks with different runtime envs do not request different workers.
RAY_LOG(INFO) << "Test different runtimes";
TestSchedulingKey(store, BuildTaskSpec(resources1, descriptor1, 0, "a"),
BuildTaskSpec(resources1, descriptor1, 0, "b"),
BuildTaskSpec(resources1, descriptor1, 1, "a"));
ObjectID direct1 = ObjectID::FromRandom();
ObjectID direct2 = ObjectID::FromRandom();

View file

@ -45,7 +45,7 @@ struct Mocker {
Language::PYTHON, function_descriptor, job_id,
TaskID::Nil(), 0, TaskID::Nil(), owner_address, 1,
required_resources, required_placement_resources,
std::make_pair(PlacementGroupID::Nil(), -1), true, "");
std::make_pair(PlacementGroupID::Nil(), -1), true, "", 0);
builder.SetActorCreationTaskSpec(actor_id, {}, max_restarts,
/*max_task_retries=*/0, {}, 1, detached, name);
return builder.Build();

View file

@ -225,6 +225,9 @@ message TaskSpec {
string concurrency_group_name = 24;
// Whether application-level errors (exceptions) should be retried.
bool retry_exceptions = 25;
// The depth of the task. The driver has depth 0, anything it calls has depth
// 1, etc.
int64 depth = 26;
}
message Bundle {

View file

@ -726,7 +726,7 @@ void ClusterTaskManager::FillResourceUsage(
}
const auto &resources =
TaskSpecification::GetSchedulingClassDescriptor(scheduling_class)
.GetResourceMap();
.resource_set.GetResourceMap();
const auto &queue = pair.second;
const auto &count = queue.size();
@ -761,7 +761,7 @@ void ClusterTaskManager::FillResourceUsage(
}
const auto &resources =
TaskSpecification::GetSchedulingClassDescriptor(scheduling_class)
.GetResourceMap();
.resource_set.GetResourceMap();
const auto &queue = pair.second;
const auto &count = queue.size();
@ -792,7 +792,7 @@ void ClusterTaskManager::FillResourceUsage(
}
const auto &resources =
TaskSpecification::GetSchedulingClassDescriptor(scheduling_class)
.GetResourceMap();
.resource_set.GetResourceMap();
const auto &queue = pair.second;
const auto &count = queue.size();
@ -1017,6 +1017,7 @@ void ClusterTaskManager::RecordMetrics() {
stats::NumReceivedTasks.Record(metric_tasks_queued_);
stats::NumDispatchedTasks.Record(metric_tasks_dispatched_);
stats::NumSpilledTasks.Record(metric_tasks_spilled_);
stats::NumInfeasibleSchedulingClasses.Record(infeasible_tasks_.size());
metric_tasks_queued_ = 0;
metric_tasks_dispatched_ = 0;

View file

@ -134,7 +134,7 @@ RayTask CreateTask(const std::unordered_map<std::string, double> &required_resou
FunctionDescriptorBuilder::BuildPython("", "", "", ""),
job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0,
required_resources, {},
std::make_pair(PlacementGroupID::Nil(), -1), true, "",
std::make_pair(PlacementGroupID::Nil(), -1), true, "", 0,
serialized_runtime_env, runtime_env_uris);
if (!args.empty()) {

View file

@ -134,6 +134,10 @@ static Gauge NumInfeasibleTasks(
"internal_num_infeasible_tasks",
"The number of tasks in the scheduler that are in the 'infeasible' state.", "tasks");
static Gauge NumInfeasibleSchedulingClasses(
"internal_num_infeasible_scheduling_classes",
"The number of unique scheduling classes that are infeasible.", "tasks");
static Gauge SpillingBandwidthMB("object_spilling_bandwidth_mb",
"Bandwidth of object spilling.", "MB");