mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Push driver task in core worker (#5752)
This commit is contained in:
parent
62bc30c1cf
commit
61e5d674be
5 changed files with 31 additions and 64 deletions
|
@ -547,6 +547,9 @@ cdef class CoreWorker:
|
|||
check_status(self.core_worker.get().Objects().Delete(
|
||||
free_ids, local_only, delete_creating_tasks))
|
||||
|
||||
def get_current_task_id(self):
|
||||
return TaskID(self.core_worker.get().GetCurrentTaskId().Binary())
|
||||
|
||||
def set_current_task_id(self, TaskID task_id):
|
||||
cdef:
|
||||
CTaskID c_task_id = task_id.native()
|
||||
|
|
|
@ -78,4 +78,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
|||
# TODO(edoakes): remove these once the Python core worker uses the task
|
||||
# interfaces
|
||||
void SetCurrentJobId(const CJobID &job_id)
|
||||
CTaskID GetCurrentTaskId()
|
||||
void SetCurrentTaskId(const CTaskID &task_id)
|
||||
|
|
|
@ -1941,68 +1941,7 @@ def connect(node,
|
|||
worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name)
|
||||
worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
# If this is a driver, set the current task ID, the task driver ID, and set
|
||||
# the task index to 0.
|
||||
if mode == SCRIPT_MODE:
|
||||
# If the user provided an object_id_seed, then set the current task ID
|
||||
# deterministically based on that seed (without altering the state of
|
||||
# the user's random number generator). Otherwise, set the current task
|
||||
# ID randomly to avoid object ID collisions.
|
||||
numpy_state = np.random.get_state()
|
||||
if node.object_id_seed is not None:
|
||||
np.random.seed(node.object_id_seed)
|
||||
else:
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
# Reset the state of the numpy random number generator.
|
||||
np.random.set_state(numpy_state)
|
||||
|
||||
# Create an entry for the driver task in the task table. This task is
|
||||
# added immediately with status RUNNING. This allows us to push errors
|
||||
# related to this driver task back to the driver. For example, if the
|
||||
# driver creates an object that is later evicted, we should notify the
|
||||
# user that we're unable to reconstruct the object, since we cannot
|
||||
# rerun the driver.
|
||||
nil_actor_counter = 0
|
||||
|
||||
function_descriptor = FunctionDescriptor.for_driver_task()
|
||||
driver_task_spec = ray._raylet.TaskSpec(
|
||||
TaskID.for_driver_task(worker.current_job_id),
|
||||
worker.current_job_id,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
TaskID(worker.worker_id[:TaskID.size()]), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ActorID.nil(), # actor_creation_id.
|
||||
ObjectID.nil(), # actor_creation_dummy_object_id.
|
||||
ObjectID.nil(), # previous_actor_task_dummy_object_id.
|
||||
0, # max_actor_reconstructions.
|
||||
ActorID.nil(), # actor_id.
|
||||
ActorHandleID.nil(), # actor_handle_id.
|
||||
nil_actor_counter, # actor_counter.
|
||||
[], # new_actor_handles.
|
||||
{}, # resource_map.
|
||||
{}, # placement_resource_map.
|
||||
)
|
||||
task_table_data = ray._raylet.generate_gcs_task_table_data(
|
||||
driver_task_spec)
|
||||
|
||||
# Add the driver task to the task table.
|
||||
ray.state.state._execute_command(
|
||||
driver_task_spec.task_id(),
|
||||
"RAY.TABLE_ADD",
|
||||
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
|
||||
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
|
||||
driver_task_spec.task_id().binary(),
|
||||
task_table_data,
|
||||
)
|
||||
|
||||
# Set the driver's current task ID to the task ID assigned to the
|
||||
# driver task.
|
||||
worker.task_context.current_task_id = driver_task_spec.task_id()
|
||||
raise ValueError("Invalid worker mode. Expected DRIVER or WORKER.")
|
||||
|
||||
redis_address, redis_port = node.redis_address.split(":")
|
||||
gcs_options = ray._raylet.GcsClientOptions(
|
||||
|
@ -2018,8 +1957,8 @@ def connect(node,
|
|||
gcs_options,
|
||||
node.get_logs_dir_path(),
|
||||
)
|
||||
worker.core_worker.set_current_job_id(worker.current_job_id)
|
||||
worker.core_worker.set_current_task_id(worker.current_task_id)
|
||||
worker.task_context.current_task_id = (
|
||||
worker.core_worker.get_current_task_id())
|
||||
worker.raylet_client = ray._raylet.RayletClient(worker.core_worker)
|
||||
|
||||
if driver_object_store_memory is not None:
|
||||
|
|
|
@ -69,6 +69,28 @@ CoreWorker::CoreWorker(
|
|||
language_, rpc_server_port));
|
||||
|
||||
io_thread_ = std::thread(&CoreWorker::StartIOService, this);
|
||||
|
||||
// Create an entry for the driver task in the task table. This task is
|
||||
// added immediately with status RUNNING. This allows us to push errors
|
||||
// related to this driver task back to the driver. For example, if the
|
||||
// driver creates an object that is later evicted, we should notify the
|
||||
// user that we're unable to reconstruct the object, since we cannot
|
||||
// rerun the driver.
|
||||
if (worker_type_ == WorkerType::DRIVER) {
|
||||
TaskSpecBuilder builder;
|
||||
std::vector<std::string> empty_descriptor;
|
||||
std::unordered_map<std::string, double> empty_resources;
|
||||
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
|
||||
builder.SetCommonTaskSpec(task_id, language_, empty_descriptor,
|
||||
worker_context_.GetCurrentJobID(),
|
||||
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
|
||||
0, 0, empty_resources, empty_resources);
|
||||
|
||||
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
|
||||
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
|
||||
RAY_CHECK_OK(gcs_client_->raylet_task_table().Add(job_id, task_id, data, nullptr));
|
||||
worker_context_.SetCurrentTaskId(task_id);
|
||||
}
|
||||
}
|
||||
|
||||
CoreWorker::~CoreWorker() {
|
||||
|
|
|
@ -71,6 +71,8 @@ class CoreWorker {
|
|||
return *task_execution_interface_;
|
||||
}
|
||||
|
||||
/// Returns a const reference to the current task ID as reported by the
/// worker context (`worker_context_.GetCurrentTaskID()`).
const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); }
|
||||
|
||||
// TODO(edoakes): remove this once Python core worker uses the task interfaces.
|
||||
/// Forwards `job_id` to the worker context's `SetCurrentJobId`.
void SetCurrentJobId(const JobID &job_id) { worker_context_.SetCurrentJobId(job_id); }
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue