diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index c0bd06926..fc992de3a 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -547,6 +547,9 @@ cdef class CoreWorker: check_status(self.core_worker.get().Objects().Delete( free_ids, local_only, delete_creating_tasks)) + def get_current_task_id(self): + return TaskID(self.core_worker.get().GetCurrentTaskId().Binary()) + def set_current_task_id(self, TaskID task_id): cdef: CTaskID c_task_id = task_id.native() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 3199ed451..f45af2bbf 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -78,4 +78,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: # TODO(edoakes): remove these once the Python core worker uses the task # interfaces void SetCurrentJobId(const CJobID &job_id) + CTaskID GetCurrentTaskId() void SetCurrentTaskId(const CTaskID &task_id) diff --git a/python/ray/worker.py b/python/ray/worker.py index 6047d0d7b..b8d6f6e85 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1941,68 +1941,7 @@ def connect(node, worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name) worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict) else: - raise Exception("This code should be unreachable.") - - # If this is a driver, set the current task ID, the task driver ID, and set - # the task index to 0. - if mode == SCRIPT_MODE: - # If the user provided an object_id_seed, then set the current task ID - # deterministically based on that seed (without altering the state of - # the user's random number generator). Otherwise, set the current task - # ID randomly to avoid object ID collisions. - numpy_state = np.random.get_state() - if node.object_id_seed is not None: - np.random.seed(node.object_id_seed) - else: - # Try to use true randomness. - np.random.seed(None) - # Reset the state of the numpy random number generator. - np.random.set_state(numpy_state) - - # Create an entry for the driver task in the task table. This task is - # added immediately with status RUNNING. This allows us to push errors - # related to this driver task back to the driver. For example, if the - # driver creates an object that is later evicted, we should notify the - # user that we're unable to reconstruct the object, since we cannot - # rerun the driver. - nil_actor_counter = 0 - - function_descriptor = FunctionDescriptor.for_driver_task() - driver_task_spec = ray._raylet.TaskSpec( - TaskID.for_driver_task(worker.current_job_id), - worker.current_job_id, - function_descriptor.get_function_descriptor_list(), - [], # arguments. - 0, # num_returns. - TaskID(worker.worker_id[:TaskID.size()]), # parent_task_id. - 0, # parent_counter. - ActorID.nil(), # actor_creation_id. - ObjectID.nil(), # actor_creation_dummy_object_id. - ObjectID.nil(), # previous_actor_task_dummy_object_id. - 0, # max_actor_reconstructions. - ActorID.nil(), # actor_id. - ActorHandleID.nil(), # actor_handle_id. - nil_actor_counter, # actor_counter. - [], # new_actor_handles. - {}, # resource_map. - {}, # placement_resource_map. - ) - task_table_data = ray._raylet.generate_gcs_task_table_data( - driver_task_spec) - - # Add the driver task to the task table. - ray.state.state._execute_command( - driver_task_spec.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), - ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), - driver_task_spec.task_id().binary(), - task_table_data, - ) - - # Set the driver's current task ID to the task ID assigned to the - # driver task. - worker.task_context.current_task_id = driver_task_spec.task_id() + raise ValueError("Invalid worker mode. Expected DRIVER or WORKER.") redis_address, redis_port = node.redis_address.split(":") gcs_options = ray._raylet.GcsClientOptions( @@ -2018,8 +1957,8 @@ def connect(node, gcs_options, node.get_logs_dir_path(), ) - worker.core_worker.set_current_job_id(worker.current_job_id) - worker.core_worker.set_current_task_id(worker.current_task_id) + worker.task_context.current_task_id = ( + worker.core_worker.get_current_task_id()) worker.raylet_client = ray._raylet.RayletClient(worker.core_worker) if driver_object_store_memory is not None: diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 42d6fb7e3..ec794897c 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -69,6 +69,28 @@ CoreWorker::CoreWorker( language_, rpc_server_port)); io_thread_ = std::thread(&CoreWorker::StartIOService, this); + + // Create an entry for the driver task in the task table. This task is + // added immediately with status RUNNING. This allows us to push errors + // related to this driver task back to the driver. For example, if the + // driver creates an object that is later evicted, we should notify the + // user that we're unable to reconstruct the object, since we cannot + // rerun the driver. + if (worker_type_ == WorkerType::DRIVER) { + TaskSpecBuilder builder; + std::vector empty_descriptor; + std::unordered_map empty_resources; + const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID()); + builder.SetCommonTaskSpec(task_id, language_, empty_descriptor, + worker_context_.GetCurrentJobID(), + TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), + 0, 0, empty_resources, empty_resources); + + std::shared_ptr data = std::make_shared(); + data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); + RAY_CHECK_OK(gcs_client_->raylet_task_table().Add(job_id, task_id, data, nullptr)); + worker_context_.SetCurrentTaskId(task_id); + } } CoreWorker::~CoreWorker() { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 3fcdbfdbb..829baac65 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -71,6 +71,8 @@ class CoreWorker { return *task_execution_interface_; } + const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); } + // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentJobId(const JobID &job_id) { worker_context_.SetCurrentJobId(job_id); }