Push driver task in core worker (#5752)

This commit is contained in:
Edward Oakes 2019-09-23 10:53:55 -05:00 committed by GitHub
parent 62bc30c1cf
commit 61e5d674be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 64 deletions

View file

@ -547,6 +547,9 @@ cdef class CoreWorker:
check_status(self.core_worker.get().Objects().Delete(
free_ids, local_only, delete_creating_tasks))
def get_current_task_id(self):
return TaskID(self.core_worker.get().GetCurrentTaskId().Binary())
def set_current_task_id(self, TaskID task_id):
cdef:
CTaskID c_task_id = task_id.native()

View file

@ -78,4 +78,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
# TODO(edoakes): remove these once the Python core worker uses the task
# interfaces
void SetCurrentJobId(const CJobID &job_id)
CTaskID GetCurrentTaskId()
void SetCurrentTaskId(const CTaskID &task_id)

View file

@ -1941,68 +1941,7 @@ def connect(node,
worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name)
worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
else:
raise Exception("This code should be unreachable.")
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
if mode == SCRIPT_MODE:
# If the user provided an object_id_seed, then set the current task ID
# deterministically based on that seed (without altering the state of
# the user's random number generator). Otherwise, set the current task
# ID randomly to avoid object ID collisions.
numpy_state = np.random.get_state()
if node.object_id_seed is not None:
np.random.seed(node.object_id_seed)
else:
# Try to use true randomness.
np.random.seed(None)
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
# Create an entry for the driver task in the task table. This task is
# added immediately with status RUNNING. This allows us to push errors
# related to this driver task back to the driver. For example, if the
# driver creates an object that is later evicted, we should notify the
# user that we're unable to reconstruct the object, since we cannot
# rerun the driver.
nil_actor_counter = 0
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task_spec = ray._raylet.TaskSpec(
TaskID.for_driver_task(worker.current_job_id),
worker.current_job_id,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
TaskID(worker.worker_id[:TaskID.size()]), # parent_task_id.
0, # parent_counter.
ActorID.nil(), # actor_creation_id.
ObjectID.nil(), # actor_creation_dummy_object_id.
ObjectID.nil(), # previous_actor_task_dummy_object_id.
0, # max_actor_reconstructions.
ActorID.nil(), # actor_id.
ActorHandleID.nil(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
{}, # resource_map.
{}, # placement_resource_map.
)
task_table_data = ray._raylet.generate_gcs_task_table_data(
driver_task_spec)
# Add the driver task to the task table.
ray.state.state._execute_command(
driver_task_spec.task_id(),
"RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
driver_task_spec.task_id().binary(),
task_table_data,
)
# Set the driver's current task ID to the task ID assigned to the
# driver task.
worker.task_context.current_task_id = driver_task_spec.task_id()
raise ValueError("Invalid worker mode. Expected DRIVER or WORKER.")
redis_address, redis_port = node.redis_address.split(":")
gcs_options = ray._raylet.GcsClientOptions(
@ -2018,8 +1957,8 @@ def connect(node,
gcs_options,
node.get_logs_dir_path(),
)
worker.core_worker.set_current_job_id(worker.current_job_id)
worker.core_worker.set_current_task_id(worker.current_task_id)
worker.task_context.current_task_id = (
worker.core_worker.get_current_task_id())
worker.raylet_client = ray._raylet.RayletClient(worker.core_worker)
if driver_object_store_memory is not None:

View file

@ -69,6 +69,28 @@ CoreWorker::CoreWorker(
language_, rpc_server_port));
io_thread_ = std::thread(&CoreWorker::StartIOService, this);
// Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors
// related to this driver task back to the driver. For example, if the
// driver creates an object that is later evicted, we should notify the
// user that we're unable to reconstruct the object, since we cannot
// rerun the driver.
if (worker_type_ == WorkerType::DRIVER) {
TaskSpecBuilder builder;
std::vector<std::string> empty_descriptor;
std::unordered_map<std::string, double> empty_resources;
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
builder.SetCommonTaskSpec(task_id, language_, empty_descriptor,
worker_context_.GetCurrentJobID(),
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
0, 0, empty_resources, empty_resources);
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
RAY_CHECK_OK(gcs_client_->raylet_task_table().Add(job_id, task_id, data, nullptr));
worker_context_.SetCurrentTaskId(task_id);
}
}
CoreWorker::~CoreWorker() {

View file

@ -71,6 +71,8 @@ class CoreWorker {
return *task_execution_interface_;
}
const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); }
// TODO(edoakes): remove this once Python core worker uses the task interfaces.
void SetCurrentJobId(const JobID &job_id) { worker_context_.SetCurrentJobId(job_id); }