diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index aa50a8898..1383f5edc 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -75,6 +75,7 @@ from ray.includes.ray_config cimport RayConfig import ray import ray.experimental.signal as ray_signal +import ray.memory_monitor as memory_monitor import ray.ray_constants as ray_constants from ray import profiling from ray.exceptions import ( @@ -178,13 +179,6 @@ cdef c_vector[CObjectID] ObjectIDsToVector(object_ids): return result -def compute_put_id(TaskID task_id, int64_t put_index): - if put_index < 1 or put_index > CObjectID.MaxObjectIndex(): - raise ValueError("The range of 'put_index' should be [1, %d]" - % CObjectID.MaxObjectIndex()) - return ObjectID(CObjectID.ForPut(task_id.native(), put_index, 0).Binary()) - - def compute_task_id(ObjectID object_id): return TaskID(object_id.native().TaskId().Binary()) @@ -460,26 +454,6 @@ cdef deserialize_args( return ray.signature.recover_args(args) -cdef _check_worker_state(worker, CTaskType task_type, JobID job_id): - assert worker.current_task_id.is_nil() - assert worker.task_context.task_index == 0 - assert worker.task_context.put_index == 1 - - # If this worker is not an actor, check that `current_job_id` - # was reset when the worker finished the previous task. - if task_type in [TASK_TYPE_NORMAL_TASK, - TASK_TYPE_ACTOR_CREATION_TASK]: - assert worker.current_job_id.is_nil() - # Set the driver ID of the current running task. This is - # needed so that if the task throws an exception, we propagate - # the error message to the correct driver. - worker.current_job_id = job_id - else: - # If this worker is an actor, current_job_id wasn't reset. - # Check that current task's driver ID equals the previous - # one. - assert worker.current_job_id == job_id - cdef _store_task_outputs(worker, return_ids, outputs): for i in range(len(return_ids)): @@ -494,14 +468,12 @@ cdef _store_task_outputs(worker, return_ids, outputs): "from a remote function, but the corresponding " "ObjectID does not exist in the local object store.") else: - worker.put_object(return_id, output) + worker.put_object(output, object_id=return_id) cdef execute_task( CTaskType task_type, const CRayFunction &ray_function, - const CJobID &c_job_id, - const CActorID &c_actor_id, const unordered_map[c_string, double] &c_resources, const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, @@ -510,14 +482,10 @@ cdef execute_task( worker = ray.worker.global_worker - actor_id = ActorID(c_actor_id.Binary()) - job_id = JobID(c_job_id.Binary()) + actor_id = worker.core_worker.get_actor_id() + job_id = worker.core_worker.get_current_job_id() task_id = worker.core_worker.get_current_task_id() - # Check that the worker is in the expected state to execute the task. - _check_worker_state(worker, task_type, job_id) - worker.task_context.current_task_id = task_id - # Automatically restrict the GPUs available to this task. ray.utils.set_cuda_visible_devices(ray.get_gpu_ids()) @@ -525,7 +493,6 @@ cdef execute_task( ray_function.GetFunctionDescriptor()) if task_type == TASK_TYPE_ACTOR_CREATION_TASK: - worker.actor_id = actor_id actor_class = worker.function_actor_manager.load_actor_class( job_id, function_descriptor) worker.actors[actor_id] = actor_class.__new__(actor_class) @@ -556,7 +523,7 @@ cdef execute_task( ray_constants.from_memory_units( dereference(c_resources.find(b"memory")).second)) if c_resources.find(b"object_store_memory") != c_resources.end(): - worker._set_object_store_client_options( + worker.core_worker.set_object_store_client_options( worker_name, int(ray_constants.from_memory_units( dereference( @@ -613,19 +580,10 @@ cdef execute_task( # Send signal with the error. ray_signal.send(ray_signal.ErrorSignal(str(failure_object))) - # Reset the state fields so the next task can run. - worker.task_context.current_task_id = TaskID.nil() - worker.core_worker.set_current_task_id(TaskID.nil()) - worker.task_context.task_index = 0 - worker.task_context.put_index = 1 - # Don't need to reset `current_job_id` if the worker is an # actor. Because the following tasks should all have the # same driver id. if task_type == TASK_TYPE_NORMAL_TASK: - worker.current_job_id = JobID.nil() - worker.core_worker.set_current_job_id(JobID.nil()) - # Reset signal counters so that the next task can get # all past signals. ray_signal.reset() @@ -646,8 +604,6 @@ cdef execute_task( cdef CRayStatus task_execution_handler( CTaskType task_type, const CRayFunction &ray_function, - const CJobID &c_job_id, - const CActorID &c_actor_id, const unordered_map[c_string, double] &c_resources, const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, @@ -658,8 +614,7 @@ cdef CRayStatus task_execution_handler( try: # The call to execute_task should never raise an exception. If it # does, that indicates that there was an unexpected internal error. - execute_task(task_type, ray_function, c_job_id, - c_actor_id, c_resources, c_args, + execute_task(task_type, ray_function, c_resources, c_args, c_arg_reference_ids, c_return_ids, returns) except Exception: traceback_str = traceback.format_exc() + ( @@ -715,22 +670,11 @@ cdef class CoreWorker: def get_current_task_id(self): return TaskID(self.core_worker.get().GetCurrentTaskId().Binary()) - def set_current_task_id(self, TaskID task_id): - cdef: - CTaskID c_task_id = task_id.native() - - with nogil: - self.core_worker.get().SetCurrentTaskId(c_task_id) - def get_current_job_id(self): return JobID(self.core_worker.get().GetCurrentJobId().Binary()) - def set_current_job_id(self, JobID job_id): - cdef: - CJobID c_job_id = job_id.native() - - with nogil: - self.core_worker.get().SetCurrentJobId(c_job_id) + def get_actor_id(self): + return ActorID(self.core_worker.get().GetActorId().Binary()) def get_objects(self, object_ids, TaskID current_task_id, int64_t timeout_ms=-1): @@ -756,87 +700,108 @@ cdef class CoreWorker: return has_object - def put_serialized_object(self, serialized_object, ObjectID object_id, - int memcopy_threads=6): - cdef: - shared_ptr[CBuffer] data - shared_ptr[CBuffer] metadata - CObjectID c_object_id = object_id.native() - size_t data_size - - data_size = serialized_object.total_bytes - - with nogil: - check_status(self.core_worker.get().Objects().Create( - metadata, data_size, c_object_id, &data)) + cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, + size_t data_size, ObjectID object_id, + CObjectID *c_object_id, shared_ptr[CBuffer] *data): + delay = ray_constants.DEFAULT_PUT_OBJECT_DELAY + for attempt in reversed( + range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)): + try: + if object_id is None: + with nogil: + check_status(self.core_worker.get().Objects().Create( + metadata, data_size, c_object_id, data)) + else: + c_object_id[0] = object_id.native() + with nogil: + check_status(self.core_worker.get().Objects().Create( + metadata, data_size, c_object_id[0], data)) + break + except ObjectStoreFullError as e: + if attempt: + logger.warning("Waiting {} seconds for space to free up " + "in the object store.".format(delay)) + time.sleep(delay) + delay *= 2 + else: + self.dump_object_store_memory_usage() + raise e # If data is nullptr, that means the ObjectID already existed, # which we ignore. # TODO(edoakes): this is hacky, we should return the error instead # and deal with it here. - if not data: - return + return data.get() == NULL - stream = pyarrow.FixedSizeBufferWriter( - pyarrow.py_buffer(Buffer.make(data))) - stream.set_memcopy_threads(memcopy_threads) - serialized_object.write_to(stream) + def put_serialized_object(self, serialized_object, ObjectID object_id=None, + int memcopy_threads=6): + cdef: + CObjectID c_object_id + shared_ptr[CBuffer] data + shared_ptr[CBuffer] metadata - with nogil: - check_status(self.core_worker.get().Objects().Seal(c_object_id)) + object_already_exists = self._create_put_buffer( + metadata, serialized_object.total_bytes, + object_id, &c_object_id, &data) + if not object_already_exists: + stream = pyarrow.FixedSizeBufferWriter( + pyarrow.py_buffer(Buffer.make(data))) + stream.set_memcopy_threads(memcopy_threads) + serialized_object.write_to(stream) - def put_raw_buffer(self, c_string value, ObjectID object_id, + with nogil: + check_status( + self.core_worker.get().Objects().Seal(c_object_id)) + + return ObjectID(c_object_id.Binary()) + + def put_raw_buffer(self, c_string value, ObjectID object_id=None, int memcopy_threads=6): cdef: c_string metadata_str = RAW_BUFFER_METADATA - CObjectID c_object_id = object_id.native() + CObjectID c_object_id shared_ptr[CBuffer] data shared_ptr[CBuffer] metadata = dynamic_pointer_cast[ CBuffer, LocalMemoryBuffer]( make_shared[LocalMemoryBuffer]( (metadata_str.data()), metadata_str.size())) - with nogil: - check_status(self.core_worker.get().Objects().Create( - metadata, value.size(), c_object_id, &data)) + object_already_exists = self._create_put_buffer( + metadata, value.size(), object_id, &c_object_id, &data) + if not object_already_exists: + stream = pyarrow.FixedSizeBufferWriter( + pyarrow.py_buffer(Buffer.make(data))) + stream.set_memcopy_threads(memcopy_threads) + stream.write(pyarrow.py_buffer(value)) - stream = pyarrow.FixedSizeBufferWriter( - pyarrow.py_buffer(Buffer.make(data))) - stream.set_memcopy_threads(memcopy_threads) - stream.write(pyarrow.py_buffer(value)) + with nogil: + check_status( + self.core_worker.get().Objects().Seal(c_object_id)) - with nogil: - check_status(self.core_worker.get().Objects().Seal(c_object_id)) + return ObjectID(c_object_id.Binary()) - def put_pickle5_buffers(self, ObjectID object_id, c_string inband, - Pickle5Writer writer, - int memcopy_threads): + def put_pickle5_buffers(self, c_string inband, + Pickle5Writer writer, ObjectID object_id=None, + int memcopy_threads=6): cdef: - shared_ptr[CBuffer] data + CObjectID c_object_id c_string metadata_str = PICKLE5_BUFFER_METADATA + shared_ptr[CBuffer] data shared_ptr[CBuffer] metadata = dynamic_pointer_cast[ CBuffer, LocalMemoryBuffer]( make_shared[LocalMemoryBuffer]( (metadata_str.data()), metadata_str.size())) - CObjectID c_object_id = object_id.native() - size_t data_size - data_size = writer.get_total_bytes(inband) + object_already_exists = self._create_put_buffer( + metadata, writer.get_total_bytes(inband), + object_id, &c_object_id, &data) + if not object_already_exists: + writer.write_to(inband, data, memcopy_threads) + with nogil: + check_status( + self.core_worker.get().Objects().Seal(c_object_id)) - with nogil: - check_status(self.core_worker.get().Objects().Create( - metadata, data_size, c_object_id, &data)) - - # If data is nullptr, that means the ObjectID already existed, - # which we ignore. - # TODO(edoakes): this is hacky, we should return the error instead - # and deal with it here. - if not data: - return - - writer.write_to(inband, data, memcopy_threads) - with nogil: - check_status(self.core_worker.get().Objects().Seal(c_object_id)) + return ObjectID(c_object_id.Binary()) def wait(self, object_ids, int num_returns, int64_t timeout_ms, TaskID current_task_id): @@ -860,7 +825,7 @@ cdef class CoreWorker: else: not_ready.append(object_id) - return (ready, not_ready) + return ready, not_ready def free_objects(self, object_ids, c_bool local_only, c_bool delete_creating_tasks): @@ -871,20 +836,27 @@ cdef class CoreWorker: check_status(self.core_worker.get().Objects().Delete( free_ids, local_only, delete_creating_tasks)) - def set_object_store_client_options(self, c_string client_name, + def set_object_store_client_options(self, client_name, int64_t limit_bytes): - with nogil: + try: + logger.debug("Setting plasma memory limit to {} for {}".format( + limit_bytes, client_name)) check_status(self.core_worker.get().Objects().SetClientOptions( - client_name, limit_bytes)) + client_name.encode("ascii"), limit_bytes)) + except RayError as e: + self.dump_object_store_memory_usage() + raise memory_monitor.RayOutOfMemoryError( + "Failed to set object_store_memory={} for {}. The " + "plasma store may have insufficient memory remaining " + "to satisfy this limit (30% of object store memory is " + "permanently reserved for shared usage). The current " + "object store memory status is:\n\n{}".format( + limit_bytes, client_name, e)) - def object_store_memory_usage_string(self): - cdef: - c_string message - - with nogil: - message = self.core_worker.get().Objects().MemoryUsageString() - - return message.decode("utf-8") + def dump_object_store_memory_usage(self): + message = self.core_worker.get().Objects().MemoryUsageString() + logger.warning("Local object store memory usage:\n{}\n".format( + message.decode("utf-8"))) def submit_task(self, function_descriptor, diff --git a/python/ray/experimental/serve/queues.py b/python/ray/experimental/serve/queues.py index f72d87f32..997d61ab4 100644 --- a/python/ray/experimental/serve/queues.py +++ b/python/ray/experimental/serve/queues.py @@ -3,7 +3,7 @@ from collections import defaultdict, deque import numpy as np import ray -from ray.experimental.serve.utils import get_custom_object_id, logger +from ray.experimental.serve.utils import logger class Query: @@ -17,7 +17,7 @@ class Query: self.request_context = request_context if result_object_id is None: - self.result_object_id = get_custom_object_id() + self.result_object_id = ray.ObjectID.from_random() else: self.result_object_id = result_object_id @@ -25,7 +25,7 @@ class Query: class WorkIntent: def __init__(self, work_object_id=None): if work_object_id is None: - self.work_object_id = get_custom_object_id() + self.work_object_id = ray.ObjectID.from_random() else: self.work_object_id = work_object_id @@ -160,7 +160,7 @@ class CentralizedQueues: work_queue.popleft(), ) ray.worker.global_worker.put_object( - work.work_object_id, request) + request, work.work_object_id) @ray.remote diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py index af072e131..eff3a9c5c 100644 --- a/python/ray/experimental/serve/task_runner.py +++ b/python/ray/experimental/serve/task_runner.py @@ -123,12 +123,12 @@ class RayServeMixin: start_timestamp = time.time() try: result = self.__call__(*args, **kwargs) - ray.worker.global_worker.put_object(result_object_id, result) + ray.worker.global_worker.put_object(result, result_object_id) except Exception as e: wrapped_exception = wrap_to_ray_error(e) self._serve_metric_error_counter += 1 - ray.worker.global_worker.put_object(result_object_id, - wrapped_exception) + ray.worker.global_worker.put_object(wrapped_exception, + result_object_id) self._serve_metric_latency_list.append(time.time() - start_timestamp) serve_context.web = False diff --git a/python/ray/experimental/serve/tests/test_queue.py b/python/ray/experimental/serve/tests/test_queue.py index 49bd05103..61d7d336c 100644 --- a/python/ray/experimental/serve/tests/test_queue.py +++ b/python/ray/experimental/serve/tests/test_queue.py @@ -12,7 +12,7 @@ def test_single_prod_cons_queue(serve_instance): assert got_work.request_args == 1 assert got_work.request_kwargs == "kwargs" - ray.worker.global_worker.put_object(got_work.result_object_id, 2) + ray.worker.global_worker.put_object(2, got_work.result_object_id) assert ray.get(ray.ObjectID(result_object_id)) == 2 @@ -24,7 +24,7 @@ def test_alter_backend(serve_instance): work_object_id = q.dequeue_request("backend-1") got_work = ray.get(ray.ObjectID(work_object_id)) assert got_work.request_args == 1 - ray.worker.global_worker.put_object(got_work.result_object_id, 2) + ray.worker.global_worker.put_object(2, got_work.result_object_id) assert ray.get(ray.ObjectID(result_object_id)) == 2 q.set_traffic("svc", {"backend-2": 1}) @@ -32,7 +32,7 @@ def test_alter_backend(serve_instance): work_object_id = q.dequeue_request("backend-2") got_work = ray.get(ray.ObjectID(work_object_id)) assert got_work.request_args == 1 - ray.worker.global_worker.put_object(got_work.result_object_id, 2) + ray.worker.global_worker.put_object(2, got_work.result_object_id) assert ray.get(ray.ObjectID(result_object_id)) == 2 diff --git a/python/ray/experimental/serve/utils.py b/python/ray/experimental/serve/utils.py index 62fd663e6..60a0679fc 100644 --- a/python/ray/experimental/serve/utils.py +++ b/python/ray/experimental/serve/utils.py @@ -3,8 +3,6 @@ import logging from pygments import formatters, highlight, lexers -import ray - def _get_logger(): logger = logging.getLogger("ray.serve") @@ -32,15 +30,6 @@ class BytesEncoder(json.JSONEncoder): return super().default(o) -def get_custom_object_id(): - """Use ray worker API to get computed ObjectID""" - worker = ray.worker.global_worker - object_id = ray._raylet.compute_put_id(worker.current_task_id, - worker.task_context.put_index) - worker.task_context.put_index += 1 - return object_id - - def pformat_color_json(d): """Use pygments to pretty format and colroize dictionary""" formatted_json = json.dumps(d, sort_keys=True, indent=4) diff --git a/python/ray/experimental/streaming/batched_queue.py b/python/ray/experimental/streaming/batched_queue.py index 13cbf5673..9ba0c2795 100644 --- a/python/ray/experimental/streaming/batched_queue.py +++ b/python/ray/experimental/streaming/batched_queue.py @@ -141,8 +141,8 @@ class BatchedQueue(object): if not self.write_buffer: return batch_id = self._batch_id(self.write_batch_offset) - ray.worker.global_worker.put_object( - ray.ObjectID(batch_id), self.write_buffer) + ray.worker.global_worker.put_object(self.write_buffer, + ray.ObjectID(batch_id)) logger.debug("[writer] Flush batch {} offset {} size {}".format( self.write_batch_offset, self.write_item_offset, len(self.write_buffer))) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index ddead813e..3e8754e68 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -58,6 +58,9 @@ cdef extern from "ray/core_worker/object_interface.h" nogil: CRayStatus SetClientOptions(c_string client_name, int64_t limit) CRayStatus Put(const CRayObject &object, CObjectID *object_id) CRayStatus Put(const CRayObject &object, const CObjectID &object_id) + CRayStatus Create(const shared_ptr[CBuffer] &metadata, + const size_t data_size, CObjectID *object_id, + shared_ptr[CBuffer] *data) CRayStatus Create(const shared_ptr[CBuffer] &metadata, const size_t data_size, const CObjectID &object_id, shared_ptr[CBuffer] *data) @@ -81,8 +84,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus ( CTaskType task_type, const CRayFunction &ray_function, - const CJobID &job_id, - const CActorID &actor_id, const unordered_map[c_string, double] &resources, const c_vector[shared_ptr[CRayObject]] &args, const c_vector[CObjectID] &arg_reference_ids, @@ -113,12 +114,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: # TODO(edoakes): remove this once the raylet client is no longer used # directly. CRayletClient &GetRayletClient() - # TODO(edoakes): remove these once the Python core worker uses the task - # interfaces CJobID GetCurrentJobId() - void SetCurrentJobId(const CJobID &job_id) CTaskID GetCurrentTaskId() - void SetCurrentTaskId(const CTaskID &task_id) const CActorID &GetActorId() CTaskID GetCallerId() const ResourceMappingType &GetResourceIDs() const diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 6d515efa8..7672cbbd8 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -2281,7 +2281,7 @@ def test_actor_reconstruction_without_task(ray_start_regular): # put a new object in plasma store. global_worker = ray.worker.global_worker if not global_worker.core_worker.object_exists(obj_id): - global_worker.put_object(obj_id, 1) + global_worker.put_object(1, obj_id) break def get_pid(self): diff --git a/python/ray/tests/test_component_failures.py b/python/ray/tests/test_component_failures.py index b7e5e0a62..b71b30bfb 100644 --- a/python/ray/tests/test_component_failures.py +++ b/python/ray/tests/test_component_failures.py @@ -59,7 +59,7 @@ def test_dying_worker_get(ray_start_2_cpus): assert len(ready_ids) == 0 # Seal the object so the store attempts to notify the worker that the # get has been fulfilled. - ray.worker.global_worker.put_object(x_id, 1) + ray.worker.global_worker.put_object(1, x_id) time.sleep(0.1) # Make sure that nothing has died. @@ -102,7 +102,7 @@ ray.get(ray.ObjectID(ray.utils.hex_to_binary("{}"))) assert len(ready_ids) == 0 # Seal the object so the store attempts to notify the worker that the # get has been fulfilled. - ray.worker.global_worker.put_object(x_id, 1) + ray.worker.global_worker.put_object(1, x_id) time.sleep(0.1) # Make sure that nothing has died. @@ -142,7 +142,7 @@ def test_dying_worker_wait(ray_start_2_cpus): time.sleep(0.1) # Create the object. - ray.worker.global_worker.put_object(x_id, 1) + ray.worker.global_worker.put_object(1, x_id) time.sleep(0.1) # Make sure that nothing has died. @@ -185,7 +185,7 @@ ray.wait([ray.ObjectID(ray.utils.hex_to_binary("{}"))]) assert len(ready_ids) == 0 # Seal the object so the store attempts to notify the worker that the # wait can return. - ray.worker.global_worker.put_object(x_id, 1) + ray.worker.global_worker.put_object(1, x_id) time.sleep(0.1) # Make sure that nothing has died. diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py index 54cf186a3..e3c73189b 100644 --- a/python/ray/tests/test_dynres.py +++ b/python/ray/tests/test_dynres.py @@ -331,7 +331,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster): @ray.remote def wait_func(running_oid, wait_oid): # Signal that the task is running - ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) # Make the task wait till signalled by driver ray.get(ray.ObjectID(wait_oid)) @@ -351,7 +351,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster): ray.get(set_res.remote(res_name, updated_capacity, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) ray.get(task) # Check if scheduler state is consistent by launching a task requiring @@ -410,7 +410,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): @ray.remote def wait_func(running_oid, wait_oid): # Signal that the task is running - ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) # Make the task wait till signalled by driver ray.get(ray.ObjectID(wait_oid)) @@ -430,7 +430,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): ray.get(set_res.remote(res_name, updated_capacity, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) ray.get(task) # Check if scheduler state is consistent by launching a task requiring @@ -492,7 +492,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster): @ray.remote def wait_func(running_oid, wait_oid): # Signal that the task is running - ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(running_oid)) # Make the task wait till signalled by driver ray.get(ray.ObjectID(wait_oid)) @@ -512,7 +512,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster): ray.get(delete_res.remote(res_name, target_node_id)) # Signal task to complete - ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.worker.global_worker.put_object(1, ray.ObjectID(WAIT_OBJECT_ID_STR)) ray.get(task) # Check if scheduler state is consistent by launching a task requiring diff --git a/python/ray/tests/test_memory_limits.py b/python/ray/tests/test_memory_limits.py index 75d40b661..f4a32a37f 100644 --- a/python/ray/tests/test_memory_limits.py +++ b/python/ray/tests/test_memory_limits.py @@ -76,7 +76,8 @@ class TestMemoryLimits(unittest.TestCase): print("Raised exception", type(e), e) raise e finally: - print(ray.worker.global_worker.dump_object_store_memory_usage()) + print(ray.worker.global_worker.core_worker. + dump_object_store_memory_usage()) ray.shutdown() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index e2695f158..8324d4ab2 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -8,7 +8,6 @@ import subprocess import time import ray -from ray.utils import _random_string from ray.tests.utils import ( RayTestTimeoutException, run_string_as_driver, @@ -556,14 +555,12 @@ print("success") # Create some drivers and let them exit and make sure everything is # still alive. for _ in range(3): - nonexistent_id_bytes = _random_string() - nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) + nonexistent_id = ray.ObjectID.from_random() driver_script = driver_script_template.format(address, - nonexistent_id_hex) + nonexistent_id.hex()) out = run_string_as_driver(driver_script) # Simulate the nonexistent dependency becoming available. - ray.worker.global_worker.put_object( - ray.ObjectID(nonexistent_id_bytes), None) + ray.worker.global_worker.put_object(None, nonexistent_id) # Make sure the first driver ran to completion. assert "success" in out @@ -583,14 +580,12 @@ print("success") # Create some drivers and let them exit and make sure everything is # still alive. for _ in range(3): - nonexistent_id_bytes = _random_string() - nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) + nonexistent_id = ray.ObjectID.from_random() driver_script = driver_script_template.format(address, - nonexistent_id_hex) + nonexistent_id.hex()) out = run_string_as_driver(driver_script) # Simulate the nonexistent dependency becoming available. - ray.worker.global_worker.put_object( - ray.ObjectID(nonexistent_id_bytes), None) + ray.worker.global_worker.put_object(None, nonexistent_id) # Make sure the first driver ran to completion. assert "success" in out diff --git a/python/ray/worker.py b/python/ray/worker.py index c87954d8a..bc7b3edfc 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -42,7 +42,6 @@ from ray import ( ActorID, JobID, ObjectID, - TaskID, ) from ray import import_thread from ray import profiling @@ -141,10 +140,6 @@ class Worker(object): # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} self.function_actor_manager = FunctionActorManager(self) - # Identity of the job that this worker is processing. - # It is a JobID. - self.current_job_id = JobID.nil() - self._task_context = threading.local() # This event is checked regularly by all of the threads so that they # know when to exit. self.threads_stopped = threading.Event() @@ -175,46 +170,20 @@ class Worker(object): return self.node.use_pickle @property - def task_context(self): - """A thread-local that contains the following attributes. + def current_job_id(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_current_job_id() + return JobID.nil() - current_task_id: For the main thread, this field is the ID of this - worker's current running task; for other threads, this field is a - fake random ID. - task_index: The number of tasks that have been submitted from the - current task. - put_index: The number of objects that have been put from the current - task. - """ - if not hasattr(self._task_context, "initialized"): - # Initialize task_context for the current thread. - if ray.utils.is_main_thread(): - # If this is running on the main thread, initialize it to - # NIL. The actual value will set when the worker receives - # a task from raylet backend. - self._task_context.current_task_id = TaskID.nil() - else: - # If this is running on a separate thread, then the mapping - # to the current task ID may not be correct. Generate a - # random task ID so that the backend can differentiate - # between different threads. - self._task_context.current_task_id = TaskID.for_fake_task() - if getattr(self, "_multithreading_warned", False) is not True: - logger.warning( - "Calling ray.get or ray.wait in a separate thread " - "may lead to deadlock if the main thread blocks on " - "this thread and there are not enough resources to " - "execute more tasks") - self._multithreading_warned = True - - self._task_context.task_index = 0 - self._task_context.put_index = 1 - self._task_context.initialized = True - return self._task_context + @property + def actor_id(self): + if hasattr(self, "core_worker"): + return self.core_worker.get_actor_id() + return ActorID.nil() @property def current_task_id(self): - return self.task_context.current_task_id + return self.core_worker.get_current_task_id() @property def current_session_and_job(self): @@ -283,19 +252,111 @@ class Worker(object): """ self.mode = mode - def store_and_register(self, object_id, value, depth=100): + def put_object(self, value, object_id=None): + """Put value in the local object store with object id `objectid`. + + This assumes that the value for `objectid` has not yet been placed in + the local object store. If the plasma store is full, the worker will + automatically retry up to DEFAULT_PUT_OBJECT_RETRIES times. Each + retry will delay for an exponentially doubling amount of time, + starting with DEFAULT_PUT_OBJECT_DELAY. After this, exception + will be raised. + + Args: + value: The value to put in the object store. + object_id (object_id.ObjectID): The object ID of the value to be + put. If None, one will be generated. + + Returns: + object_id.ObjectID: The object ID the object was put under. + + Raises: + ray.exceptions.ObjectStoreFullError: This is raised if the attempt + to store the object fails because the object store is full even + after multiple retries. + """ + # Make sure that the value is not an object ID. + if isinstance(value, ObjectID): + raise TypeError( + "Calling 'put' on an ray.ObjectID is not allowed " + "(similarly, returning an ray.ObjectID from a remote " + "function is not allowed). If you really want to " + "do this, you can wrap the ray.ObjectID in a list and " + "call 'put' on it (or return it).") + + if isinstance(value, bytes): + # If the object is a byte array, skip serializing it and + # use a special metadata to indicate it's raw binary. So + # that this object can also be read by Java. + return self.core_worker.put_raw_buffer( + value, + object_id=object_id, + memcopy_threads=self.memcopy_threads) + + if self.use_pickle: + return self._serialize_and_put_pickle5(value, object_id=object_id) + else: + return self._serialize_and_put_pyarrow(value, object_id=object_id) + + def _serialize_and_put_pickle5(self, value, object_id=None): + """Serialize an object using pickle5 and store it in the object store. + + Args: + value: The value to put in the object store. + object_id: The ID of the object to store. If none, one will be + generated. + + Raises: + Exception: An exception is raised if the attempt to store the + object fails. This can happen if the object store is full. + """ + writer = Pickle5Writer() + if ray.cloudpickle.FAST_CLOUDPICKLE_USED: + inband = pickle.dumps( + value, protocol=5, buffer_callback=writer.buffer_callback) + else: + inband = pickle.dumps(value) + return self.core_worker.put_pickle5_buffers( + inband, + writer, + object_id=object_id, + memcopy_threads=self.memcopy_threads) + + def _serialize_and_put_pyarrow(self, value, object_id=None): + """Wraps `store_and_register` with cases for existence and pickling. + + Args: + object_id (object_id.ObjectID): The object ID of the value to be + put. + value: The value to put in the object store. + """ + try: + serialized_value = self._serialize_with_pyarrow(value) + except TypeError: + # TypeError can happen because one of the members of the object + # may not be serializable for cloudpickle. So we need + # these extra fallbacks here to start from the beginning. + # Hopefully the object could have a `__reduce__` method. + _register_custom_serializer(type(value), use_pickle=True) + logger.warning("WARNING: Serializing the class {} failed, " + "falling back to cloudpickle.".format(type(value))) + serialized_value = self._serialize_with_pyarrow(value) + + return self.core_worker.put_serialized_object( + serialized_value, + object_id=object_id, + memcopy_threads=self.memcopy_threads) + + def _serialize_with_pyarrow(self, value, depth=100): """Store an object and attempt to register its class if needed. Args: - object_id: The ID of the object to store. value: The value to put in the object store. depth: The maximum number of classes to recursively register. Raises: - Exception: An exception is raised if the attempt to store the - object fails. This can happen if there is already an object - with the same ID in the object store or if the object store is - full. + Exception: An exception is raised if the attempt to serialize the + object fails. """ counter = 0 while True: @@ -306,20 +367,9 @@ class Worker(object): "type {}.".format(type(value))) counter += 1 try: - if isinstance(value, bytes): - # If the object is a byte array, skip serializing it and - # use a special metadata to indicate it's raw binary. So - # that this object can also be read by Java. - self.core_worker.put_raw_buffer( - value, object_id, memcopy_threads=self.memcopy_threads) - else: - serialization_context = self.get_serialization_context( - self.current_job_id) - self.core_worker.put_serialized_object( - pyarrow.serialize(value, serialization_context), - object_id, - memcopy_threads=self.memcopy_threads) - break + serialization_context = self.get_serialization_context( + self.current_job_id) + return pyarrow.serialize(value, serialization_context) except pyarrow.SerializationCallbackError as e: cls_type = type(e.example_object) try: @@ -352,121 +402,6 @@ class Worker(object): "locally.".format(cls_type)) logger.warning(warning_message) - def put_object(self, object_id, value): - """Put value in the local object store with object id `objectid`. - - This assumes that the value for `objectid` has not yet been placed in - the local object store. If the plasma store is full, the worker will - automatically retry up to DEFAULT_PUT_OBJECT_RETRIES times. Each - retry will delay for an exponentially doubling amount of time, - starting with DEFAULT_PUT_OBJECT_DELAY. After this, exception - will be raised. - - Args: - object_id (object_id.ObjectID): The object ID of the value to be - put. - value: The value to put in the object store. - - Raises: - ray.exceptions.ObjectStoreFullError: This is raised if the attempt - to store the object fails because the object store is full even - after multiple retries. - """ - # Make sure that the value is not an object ID. - if isinstance(value, ObjectID): - raise TypeError( - "Calling 'put' on an ray.ObjectID is not allowed " - "(similarly, returning an ray.ObjectID from a remote " - "function is not allowed). If you really want to " - "do this, you can wrap the ray.ObjectID in a list and " - "call 'put' on it (or return it).") - - delay = ray_constants.DEFAULT_PUT_OBJECT_DELAY - for attempt in reversed( - range(ray_constants.DEFAULT_PUT_OBJECT_RETRIES)): - try: - if self.use_pickle: - self.store_with_plasma(object_id, value) - else: - self._try_store_and_register(object_id, value) - break - except ObjectStoreFullError as e: - if attempt: - logger.warning("Waiting {} seconds for space to free up " - "in the object store.".format(delay)) - time.sleep(delay) - delay *= 2 - else: - self.dump_object_store_memory_usage() - raise e - - def dump_object_store_memory_usage(self): - """Prints object store debug string to stdout.""" - logger.warning("Local object store memory usage:\n{}\n".format( - self.core_worker.object_store_memory_usage_string())) - - def store_with_plasma(self, object_id, value): - """Serialize and store an object. - - Args: - object_id: The ID of the object to store. - value: The value to put in the object store. - - Raises: - Exception: An exception is raised if the attempt to store the - object fails. This can happen if there is already an object - with the same ID in the object store or if the object store is - full. - """ - try: - if isinstance(value, bytes): - # If the object is a byte array, skip serializing it and - # use a special metadata to indicate it's raw binary. So - # that this object can also be read by Java. - self.core_worker.put_raw_buffer( - value, object_id, memcopy_threads=self.memcopy_threads) - else: - writer = Pickle5Writer() - if ray.cloudpickle.FAST_CLOUDPICKLE_USED: - inband = pickle.dumps( - value, - protocol=5, - buffer_callback=writer.buffer_callback) - else: - inband = pickle.dumps(value) - self.core_worker.put_pickle5_buffers(object_id, inband, writer, - self.memcopy_threads) - except pyarrow.plasma.PlasmaObjectExists: - # The object already exists in the object store, so there is no - # need to add it again. TODO(rkn): We need to compare hashes - # and make sure that the objects are in fact the same. We also - # should return an error code to caller instead of printing a - # message. - logger.info("The object with ID {} already exists " - "in the object store.".format(object_id)) - - def _try_store_and_register(self, object_id, value): - """Wraps `store_and_register` with cases for existence and pickling. - - Args: - object_id (object_id.ObjectID): The object ID of the value to be - put. - value: The value to put in the object store. - """ - try: - self.store_and_register(object_id, value) - except TypeError: - # TypeError can happen because one of the members of the object - # may not be serializable for cloudpickle. So we need - # these extra fallbacks here to start from the beginning. - # Hopefully the object could have a `__reduce__` method. - _register_custom_serializer(type(value), use_pickle=True) - warning_message = ("WARNING: Serializing the class {} failed, " - "falling back to cloudpickle.".format( - type(value))) - logger.warning(warning_message) - self.store_and_register(object_id, value) - def deserialize_objects(self, data_metadata_pairs, object_ids, @@ -674,22 +609,6 @@ class Worker(object): return ray.signature.recover_args(arguments) - def _set_object_store_client_options(self, name, object_store_memory): - try: - logger.debug("Setting plasma memory limit to {} for {}".format( - object_store_memory, name)) - self.core_worker.set_object_store_client_options( - name.encode("ascii"), object_store_memory) - except RayError as e: - self.dump_object_store_memory_usage() - raise memory_monitor.RayOutOfMemoryError( - "Failed to set object_store_memory={} for {}. The " - "plasma store may have insufficient memory remaining " - "to satisfy this limit (30% of object store memory is " - "permanently reserved for shared usage). The current " - "object store memory status is:\n\n{}".format( - object_store_memory, name, e)) - def main_loop(self): """The main loop a worker runs to receive and execute tasks.""" @@ -1461,11 +1380,9 @@ def connect(node, if not isinstance(job_id, JobID): raise TypeError("The type of given job id must be JobID.") - worker.current_job_id = job_id # All workers start out as non-actors. A worker can be turned into an actor # after it is created. - worker.actor_id = ActorID.nil() worker.node = node worker.set_mode(mode) @@ -1560,24 +1477,22 @@ def connect(node, (mode == SCRIPT_MODE), node.plasma_store_socket_name, node.raylet_socket_name, - worker.current_job_id, + job_id, gcs_options, node.get_logs_dir_path(), node.node_ip_address, ) - worker.task_context.current_task_id = ( - worker.core_worker.get_current_task_id()) worker.raylet_client = ray._raylet.RayletClient(worker.core_worker) if driver_object_store_memory is not None: - worker._set_object_store_client_options( + worker.core_worker.set_object_store_client_options( "ray_driver_{}".format(os.getpid()), driver_object_store_memory) # Put something in the plasma store so that subsequent plasma store # accesses will be faster. Currently the first access is always slow, and # we don't want the user to experience this. temporary_object_id = ray.ObjectID(np.random.bytes(20)) - worker.put_object(temporary_object_id, 1) + worker.put_object(1, object_id=temporary_object_id) ray.internal.free([temporary_object_id]) # Start the import thread @@ -1944,7 +1859,7 @@ def get(object_ids): if isinstance(value, RayError): last_task_error_raise_time = time.time() if isinstance(value, ray.exceptions.UnreconstructableError): - worker.dump_object_store_memory_usage() + worker.core_worker.dump_object_store_memory_usage() if isinstance(value, RayTaskError): raise value.as_instanceof_cause() else: @@ -1981,12 +1896,8 @@ def put(value, weakref=False): if worker.mode == LOCAL_MODE: object_id = worker.local_mode_manager.put_object(value) else: - object_id = ray._raylet.compute_put_id( - worker.current_task_id, - worker.task_context.put_index, - ) try: - worker.put_object(object_id, value) + object_id = worker.put_object(value) except ObjectStoreFullError: logger.info( "Put failed since the value was either too large or the " @@ -1995,7 +1906,6 @@ def put(value, weakref=False): "ray.put(value, weakref=True) to allow object data to " "be evicted early.") raise - worker.task_context.put_index += 1 # Pin the object buffer with the returned id. This avoids put returns # from getting evicted out from under the id. # TODO(edoakes): we should be able to avoid this extra IPC by holding diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 863dc8c37..6d3ab7e8a 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -18,17 +18,22 @@ struct WorkerThreadContext { return current_task_; } - void SetCurrentTaskId(const TaskID &task_id) { - current_task_id_ = task_id; - task_index_ = 0; - put_index_ = 0; - } + void SetCurrentTaskId(const TaskID &task_id) { current_task_id_ = task_id; } void SetCurrentTask(const TaskSpecification &task_spec) { + RAY_CHECK(current_task_id_.IsNil()); + RAY_CHECK(task_index_ == 0); + RAY_CHECK(put_index_ == 0); SetCurrentTaskId(task_spec.TaskId()); current_task_ = std::make_shared(task_spec); } + void ResetCurrentTask(const TaskSpecification &task_spec) { + SetCurrentTaskId(TaskID::Nil()); + task_index_ = 0; + put_index_ = 0; + } + private: /// The task ID for current task. TaskID current_task_id_; @@ -55,9 +60,9 @@ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) // For worker main thread which initializes the WorkerContext, // set task_id according to whether current worker is a driver. // (For other threads it's set to random ID via GetThreadContext). - GetThreadContext().SetCurrentTaskId((worker_type_ == WorkerType::DRIVER) - ? TaskID::ForDriverTask(job_id) - : TaskID::Nil()); + GetThreadContext(true).SetCurrentTaskId((worker_type_ == WorkerType::DRIVER) + ? TaskID::ForDriverTask(job_id) + : TaskID::Nil()); } const WorkerType WorkerContext::GetWorkerType() const { return worker_type_; } @@ -74,26 +79,38 @@ const TaskID &WorkerContext::GetCurrentTaskID() const { return GetThreadContext().GetCurrentTaskID(); } -// TODO(edoakes): remove this once Python core worker uses the task interfaces. void WorkerContext::SetCurrentJobId(const JobID &job_id) { current_job_id_ = job_id; } -// TODO(edoakes): remove this once Python core worker uses the task interfaces. void WorkerContext::SetCurrentTaskId(const TaskID &task_id) { GetThreadContext().SetCurrentTaskId(task_id); } void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { - SetCurrentJobId(task_spec.JobId()); GetThreadContext().SetCurrentTask(task_spec); - if (task_spec.IsActorCreationTask()) { + if (task_spec.IsNormalTask()) { + RAY_CHECK(current_job_id_.IsNil()); + SetCurrentJobId(task_spec.JobId()); + } else if (task_spec.IsActorCreationTask()) { + RAY_CHECK(current_job_id_.IsNil()); + SetCurrentJobId(task_spec.JobId()); RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); current_actor_use_direct_call_ = task_spec.IsDirectCall(); - } - if (task_spec.IsActorTask()) { + } else if (task_spec.IsActorTask()) { + RAY_CHECK(current_job_id_ == task_spec.JobId()); RAY_CHECK(current_actor_id_ == task_spec.ActorId()); + } else { + RAY_CHECK(false); } } + +void WorkerContext::ResetCurrentTask(const TaskSpecification &task_spec) { + GetThreadContext().ResetCurrentTask(task_spec); + if (task_spec.IsNormalTask()) { + SetCurrentJobId(JobID::Nil()); + } +} + std::shared_ptr WorkerContext::GetCurrentTask() const { return GetThreadContext().GetCurrentTask(); } @@ -104,9 +121,21 @@ bool WorkerContext::CurrentActorUseDirectCall() const { return current_actor_use_direct_call_; } -WorkerThreadContext &WorkerContext::GetThreadContext() { +WorkerThreadContext &WorkerContext::GetThreadContext(bool for_main_thread) { + // Flag used to ensure that we only print a warning about multithreading once per + // process. + static bool multithreading_warning_printed = false; + if (thread_context_ == nullptr) { thread_context_ = std::unique_ptr(new WorkerThreadContext()); + if (!for_main_thread && !multithreading_warning_printed) { + std::cout << "WARNING: " + << "Calling ray.get or ray.wait in a separate thread " + << "may lead to deadlock if the main thread blocks on " + << "this thread and there are not enough resources to " + << "execute more tasks." << std::endl; + multithreading_warning_printed = true; + } } return *thread_context_; diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 19cb3b81f..aadd1e447 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -28,6 +28,8 @@ class WorkerContext { void SetCurrentTask(const TaskSpecification &task_spec); + void ResetCurrentTask(const TaskSpecification &task_spec); + std::shared_ptr GetCurrentTask() const; const ActorID &GetCurrentActorID() const; @@ -46,7 +48,7 @@ class WorkerContext { bool current_actor_use_direct_call_; private: - static WorkerThreadContext &GetThreadContext(); + static WorkerThreadContext &GetThreadContext(bool for_main_thread = false); /// Per-thread worker context. static thread_local std::unique_ptr thread_context_; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index a7d5b9490..f8ef88a65 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -81,18 +81,12 @@ class CoreWorker { // Get the resource IDs available to this worker (as assigned by the raylet). const ResourceMappingType GetResourceIDs() const; - // TODO(edoakes): remove this once Python core worker uses the task interfaces. const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); } - // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentTaskId(const TaskID &task_id); - // TODO(edoakes): remove this once Python core worker uses the task interfaces. const JobID &GetCurrentJobId() const { return worker_context_.GetCurrentJobID(); } - // TODO(edoakes): remove this once Python core worker uses the task interfaces. - void SetCurrentJobId(const JobID &job_id) { worker_context_.SetCurrentJobId(job_id); } - void SetActorId(const ActorID &actor_id) { RAY_CHECK(actor_id_.IsNil()); actor_id_ = actor_id; diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index a530f59e8..5465b0e97 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -39,7 +39,6 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork auto task_execution_callback = [](ray::TaskType task_type, const ray::RayFunction &ray_function, - const JobID &job_id, const ActorID &actor_id, const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_reference_ids, diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index 586e67393..e5582442d 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -56,11 +56,10 @@ Status CoreWorkerObjectInterface::SetClientOptions(std::string name, } Status CoreWorkerObjectInterface::Put(const RayObject &object, ObjectID *object_id) { - ObjectID put_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), - worker_context_.GetNextPutIndex(), - static_cast(TaskTransportType::RAYLET)); - *object_id = put_id; - return Put(object, put_id); + *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), + worker_context_.GetNextPutIndex(), + static_cast(TaskTransportType::RAYLET)); + return Put(object, *object_id); } Status CoreWorkerObjectInterface::Put(const RayObject &object, @@ -71,6 +70,15 @@ Status CoreWorkerObjectInterface::Put(const RayObject &object, return store_providers_[StoreProviderType::PLASMA]->Put(object, object_id); } +Status CoreWorkerObjectInterface::Create(const std::shared_ptr &metadata, + const size_t data_size, ObjectID *object_id, + std::shared_ptr *data) { + *object_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(), + worker_context_.GetNextPutIndex(), + static_cast(TaskTransportType::RAYLET)); + return Create(metadata, data_size, *object_id, data); +} + Status CoreWorkerObjectInterface::Create(const std::shared_ptr &metadata, const size_t data_size, const ObjectID &object_id, diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 050400357..da5900810 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -48,6 +48,19 @@ class CoreWorkerObjectInterface { /// \return Status. Status Put(const RayObject &object, const ObjectID &object_id); + /// Create and return a buffer in the object store that can be directly written + /// into. After writing to the buffer, the caller must call `Seal()` to finalize + /// the object. The `Create()` and `Seal()` combination is an alternative interface + /// to `Put()` that allows frontends to avoid an extra copy when possible. + /// + /// \param[in] metadata Metadata of the object to be written. + /// \param[in] data_size Size of the object to be written. + /// \param[out] object_id Object ID generated for the put. + /// \param[out] data Buffer for the user to write the object into. + /// \return Status. + Status Create(const std::shared_ptr &metadata, const size_t data_size, + ObjectID *object_id, std::shared_ptr *data); + /// Create and return a buffer in the object store that can be directly written /// into. After writing to the buffer, the caller must call `Seal()` to finalize /// the object. The `Create()` and `Seal()` combination is an alternative interface diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index 5e027f3c3..bf2bee181 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -62,24 +62,24 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask( } Status status; - ActorID actor_id = ActorID::Nil(); TaskType task_type = TaskType::NORMAL_TASK; if (task_spec.IsActorCreationTask()) { RAY_CHECK(return_ids.size() > 0); return_ids.pop_back(); - actor_id = task_spec.ActorCreationId(); task_type = TaskType::ACTOR_CREATION_TASK; - core_worker_.SetActorId(actor_id); + core_worker_.SetActorId(task_spec.ActorCreationId()); } else if (task_spec.IsActorTask()) { RAY_CHECK(return_ids.size() > 0); return_ids.pop_back(); - actor_id = task_spec.ActorId(); task_type = TaskType::ACTOR_TASK; } - status = task_execution_callback_(task_type, func, task_spec.JobId(), actor_id, + status = task_execution_callback_(task_type, func, task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids, return_ids, results); + core_worker_.SetCurrentTaskId(TaskID::Nil()); + worker_context_.ResetCurrentTask(task_spec); + // TODO(zhijunfu): // 1. Check and handle failure. // 2. Save or load checkpoint. diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index 20e39c405..ad07fdd61 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -27,8 +27,7 @@ class CoreWorkerTaskExecutionInterface { // Callback that must be implemented and provided by the language-specific worker // frontend to execute tasks and return their results. using TaskExecutionCallback = std::function &required_resources, const std::vector> &args, const std::vector &arg_reference_ids, diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index f60a8e221..a88593b0d 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -25,8 +25,7 @@ class MockWorker { : worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, JobID::FromInt(1), gcs_options, /*log_dir=*/"", /*node_id_address=*/"127.0.0.1", - std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, - _9)) {} + std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {} void Run() { // Start executing tasks. @@ -35,7 +34,6 @@ class MockWorker { private: Status ExecuteTask(TaskType task_type, const RayFunction &ray_function, - const JobID &job_id, const ActorID &actor_id, const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_reference_ids,