Push driver task in core worker (#5752)

This commit is contained in:
Edward Oakes
2019-09-23 10:53:55 -05:00
committed by GitHub
parent 62bc30c1cf
commit 61e5d674be
5 changed files with 31 additions and 64 deletions
+3
View File
@@ -547,6 +547,9 @@ cdef class CoreWorker:
check_status(self.core_worker.get().Objects().Delete(
free_ids, local_only, delete_creating_tasks))
# Return the task ID currently reported by the underlying C++ core worker,
# wrapped as a Python-level TaskID built from the ID's binary form.
def get_current_task_id(self):
return TaskID(self.core_worker.get().GetCurrentTaskId().Binary())
def set_current_task_id(self, TaskID task_id):
cdef:
CTaskID c_task_id = task_id.native()
+1
View File
@@ -78,4 +78,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
# TODO(edoakes): remove these once the Python core worker uses the task
# interfaces
void SetCurrentJobId(const CJobID &job_id)
CTaskID GetCurrentTaskId()
void SetCurrentTaskId(const CTaskID &task_id)
+3 -64
View File
@@ -1941,68 +1941,7 @@ def connect(node,
worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name)
worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
else:
raise Exception("This code should be unreachable.")
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
if mode == SCRIPT_MODE:
# If the user provided an object_id_seed, then set the current task ID
# deterministically based on that seed (without altering the state of
# the user's random number generator). Otherwise, set the current task
# ID randomly to avoid object ID collisions.
numpy_state = np.random.get_state()
if node.object_id_seed is not None:
np.random.seed(node.object_id_seed)
else:
# Try to use true randomness.
np.random.seed(None)
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
# Create an entry for the driver task in the task table. This task is
# added immediately with status RUNNING. This allows us to push errors
# related to this driver task back to the driver. For example, if the
# driver creates an object that is later evicted, we should notify the
# user that we're unable to reconstruct the object, since we cannot
# rerun the driver.
nil_actor_counter = 0
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task_spec = ray._raylet.TaskSpec(
TaskID.for_driver_task(worker.current_job_id),
worker.current_job_id,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
TaskID(worker.worker_id[:TaskID.size()]), # parent_task_id.
0, # parent_counter.
ActorID.nil(), # actor_creation_id.
ObjectID.nil(), # actor_creation_dummy_object_id.
ObjectID.nil(), # previous_actor_task_dummy_object_id.
0, # max_actor_reconstructions.
ActorID.nil(), # actor_id.
ActorHandleID.nil(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
{}, # resource_map.
{}, # placement_resource_map.
)
task_table_data = ray._raylet.generate_gcs_task_table_data(
driver_task_spec)
# Add the driver task to the task table.
ray.state.state._execute_command(
driver_task_spec.task_id(),
"RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
driver_task_spec.task_id().binary(),
task_table_data,
)
# Set the driver's current task ID to the task ID assigned to the
# driver task.
worker.task_context.current_task_id = driver_task_spec.task_id()
raise ValueError("Invalid worker mode. Expected DRIVER or WORKER.")
redis_address, redis_port = node.redis_address.split(":")
gcs_options = ray._raylet.GcsClientOptions(
@@ -2018,8 +1957,8 @@ def connect(node,
gcs_options,
node.get_logs_dir_path(),
)
worker.core_worker.set_current_job_id(worker.current_job_id)
worker.core_worker.set_current_task_id(worker.current_task_id)
worker.task_context.current_task_id = (
worker.core_worker.get_current_task_id())
worker.raylet_client = ray._raylet.RayletClient(worker.core_worker)
if driver_object_store_memory is not None: