diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 85b02de62..e9db97ba5 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -38,7 +38,6 @@ cdef class BaseID: cdef class ObjectID(BaseID): cdef: CObjectID data - object buffer_ref # Flag indicating whether or not this object ID was added to the set # of active IDs in the core worker so we know whether we should clean # it up. diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 86d78dee9..c602fda9a 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -844,7 +844,8 @@ cdef class CoreWorker: if object_id is None: with nogil: check_status(self.core_worker.get().Create( - metadata, data_size, c_object_id, data)) + metadata, data_size, + c_object_id, data)) else: c_object_id[0] = object_id.native() with nogil: @@ -869,20 +870,29 @@ cdef class CoreWorker: return data.get() == NULL def put_serialized_object(self, serialized_object, - ObjectID object_id=None): + ObjectID object_id=None, + c_bool pin_object=True): cdef: CObjectID c_object_id shared_ptr[CBuffer] data shared_ptr[CBuffer] metadata + # The object won't be pinned if an ObjectID is provided by the + # user (because we can't track its lifetime to unpin). Note that + # the API to do this isn't supported as a public API. 
+ c_bool owns_object = object_id is None + metadata = string_to_buffer(serialized_object.metadata) total_bytes = serialized_object.total_bytes object_already_exists = self._create_put_buffer( - metadata, total_bytes, object_id, &c_object_id, &data) + metadata, total_bytes, object_id, + &c_object_id, &data) if not object_already_exists: write_serialized_object(serialized_object, data) with nogil: check_status( - self.core_worker.get().Seal(c_object_id)) + self.core_worker.get().Seal( + c_object_id, owns_object, pin_object)) + return ObjectID(c_object_id.Binary()) def wait(self, object_ids, int num_returns, int64_t timeout_ms, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 136666b65..82c252a51 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -133,12 +133,13 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus Put(const CRayObject &object, CObjectID *object_id) CRayStatus Put(const CRayObject &object, const CObjectID &object_id) CRayStatus Create(const shared_ptr[CBuffer] &metadata, - const size_t data_size, CObjectID *object_id, - shared_ptr[CBuffer] *data) + const size_t data_size, + CObjectID *object_id, shared_ptr[CBuffer] *data) CRayStatus Create(const shared_ptr[CBuffer] &metadata, const size_t data_size, const CObjectID &object_id, shared_ptr[CBuffer] *data) - CRayStatus Seal(const CObjectID &object_id) + CRayStatus Seal(const CObjectID &object_id, c_bool owns_object, + c_bool pin_object) CRayStatus Get(const c_vector[CObjectID] &ids, int64_t timeout_ms, c_vector[shared_ptr[CRayObject]] *results) CRayStatus Contains(const CObjectID &object_id, c_bool *has_object) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 5b596bf44..6a54b7846 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -176,12 +176,6 @@ cdef class ObjectID(BaseID): def task_id(self): return 
TaskID(self.data.TaskId().Binary()) - def set_buffer_ref(self, ref): - self.buffer_ref = ref - - def get_buffer_ref(self): - return self.buffer_ref - cdef size_t hash(self): return self.data.Hash() diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 6a6fc627a..0cc437ed9 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -5,41 +5,7 @@ from __future__ import print_function import ray.worker from ray import profiling -__all__ = ["free", "pin_object_data"] - - -def pin_object_data(object_id): - """Pin the object data referenced by this object id in memory. - - The object data cannot be evicted while there exists a Python reference to - the object id passed to this function. In order to pin the object, we will - also download the object to the current node (this overhead is unavoidable - for now without a distributed ref counting solution). - - Examples: - >>> x_id = f.remote() - >>> x_id = pin_object_id(x_id) # x pinned, cannot be evicted - >>> del x_id # x can be evicted again - - Note that ray will automatically do this for objects created with - ray.put() already, unless you ray.put with weakref=True. - """ - worker = ray.worker.get_global_worker() - - object_id.set_buffer_ref( - worker.core_worker.get_objects([object_id], worker.current_task_id)) - - -def unpin_object_data(object_id): - """Unpin an object pinned by pin_object_id. 
- - Examples: - >>> x_id = f.remote() - >>> pin_object_id(x_id) - >>> unpin_object_id(x_id) # as if the pin didn't happen - """ - - object_id.set_buffer_ref(None) +__all__ = ["free"] def free(object_ids, local_only=False, delete_creating_tasks=False): @@ -81,7 +47,6 @@ def free(object_ids, local_only=False, delete_creating_tasks=False): if not isinstance(object_id, ray.ObjectID): raise TypeError("Attempting to call `free` on the value {}, " "which is not an ray.ObjectID.".format(object_id)) - unpin_object_data(object_id) if ray.worker._mode() == ray.worker.LOCAL_MODE: worker.local_mode_manager.free(object_ids) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index c38ed1138..94b94933c 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -188,9 +188,9 @@ py_test( ) py_test( - name = "test_garbage_collection", - size = "small", - srcs = ["test_garbage_collection.py"], + name = "test_reference_counting", + size = "medium", + srcs = ["test_reference_counting.py"], tags = ["exclusive"], deps = ["//:ray_lib"], ) diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index f7d56f735..79c7f5a9e 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -8,6 +8,7 @@ import logging import os import setproctitle import shutil +import json import sys import socket import subprocess @@ -418,7 +419,12 @@ def test_initialized_local_mode(shutdown_only_with_initialization_check): def test_wait_reconstruction(shutdown_only): - ray.init(num_cpus=1, object_store_memory=int(10**8)) + ray.init( + num_cpus=1, + object_store_memory=int(10**8), + _internal_config=json.dumps({ + "object_pinning_enabled": 0 + })) @ray.remote def f(): @@ -577,21 +583,21 @@ def test_shutdown_disconnect_global_state(): "ray_start_object_store_memory", [150 * 1024 * 1024], indirect=True) def test_put_pins_object(ray_start_object_store_memory): x_id = ray.put("HI") - x_copy = ray.ObjectID(x_id.binary()) - assert 
ray.get(x_copy) == "HI" + x_binary = x_id.binary() + assert ray.get(ray.ObjectID(x_binary)) == "HI" # x cannot be evicted since x_id pins it for _ in range(10): ray.put(np.zeros(10 * 1024 * 1024)) assert ray.get(x_id) == "HI" - assert ray.get(x_copy) == "HI" + assert ray.get(ray.ObjectID(x_binary)) == "HI" - # now it can be evicted since x_id pins it but x_copy does not + # now it can be evicted since x_id pins it but x_binary does not del x_id for _ in range(10): ray.put(np.zeros(10 * 1024 * 1024)) with pytest.raises(ray.exceptions.UnreconstructableError): - ray.get(x_copy) + ray.get(ray.ObjectID(x_binary)) # weakref put y_id = ray.put("HI", weakref=True) @@ -600,14 +606,6 @@ def test_put_pins_object(ray_start_object_store_memory): with pytest.raises(ray.exceptions.UnreconstructableError): ray.get(y_id) - @ray.remote - def check_no_buffer_ref(x): - assert x[0].get_buffer_ref() is None - - z_id = ray.put("HI") - assert z_id.get_buffer_ref() is not None - ray.get(check_no_buffer_ref.remote([z_id])) - @pytest.mark.parametrize( "ray_start_object_store_memory", [150 * 1024 * 1024], indirect=True) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 0669e0a98..b6d08f308 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -936,11 +936,10 @@ def test_direct_call_serialized_id_eviction(ray_start_cluster): @ray.remote def get(obj_ids): - print("get", obj_ids) obj_id = obj_ids[0] assert (isinstance(ray.get(obj_id), np.ndarray)) - # Evict the object. - ray.internal.free(obj_ids) + # Wait for the object to be evicted. 
+ ray.internal.free(obj_id) while ray.worker.global_worker.core_worker.object_exists(obj_id): time.sleep(1) with pytest.raises(ray.exceptions.UnreconstructableError): @@ -948,7 +947,9 @@ def test_direct_call_serialized_id_eviction(ray_start_cluster): print("get done", obj_ids) obj = large_object.remote() - ray.get(get.remote([obj])) + result = get.remote([obj]) + ray.internal.free(obj) + ray.get(result) @pytest.mark.parametrize( diff --git a/python/ray/tests/test_garbage_collection.py b/python/ray/tests/test_garbage_collection.py deleted file mode 100644 index b6ce12501..000000000 --- a/python/ray/tests/test_garbage_collection.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding: utf-8 -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import numpy as np -import time -import logging -import pytest - -import ray -import ray.cluster_utils -import ray.test_utils - -logger = logging.getLogger(__name__) - - -def test_basic_gc(shutdown_only): - ray.init( - object_store_memory=100 * 1024 * 1024, - use_pickle=True, - _internal_config=json.dumps({ - "worker_heartbeat_timeout_milliseconds": 500, - "raylet_max_active_object_ids": 1000 - })) - - @ray.remote - def shuffle(input): - return np.random.shuffle(input) - - @ray.remote - class Actor: - def __init__(self): - # Hold a long-lived reference to a ray.put object. This should not - # be garbage collected while the actor is alive. - self.large_object = ray.put( - np.zeros(25 * 1024 * 1024, dtype=np.uint8), weakref=True) - - def get_large_object(self): - return ray.get(self.large_object) - - actor = Actor.remote() - - # Fill up the object store with short-lived objects. These should be - # evicted before the long-lived object whose reference is held by - # the actor. 
- for batch in range(10): - intermediate_result = shuffle.remote( - np.zeros(10 * 1024 * 1024, dtype=np.uint8)) - ray.get(intermediate_result) - - # The ray.get below would fail with only LRU eviction, as the object - # that was ray.put by the actor would have been evicted. - ray.get(actor.get_large_object.remote()) - - -@pytest.mark.skip(reason="This test currently fails on Travis.") -def test_pending_task_dependency(shutdown_only): - ray.init(object_store_memory=100 * 1024 * 1024, use_pickle=True) - - @ray.remote - def pending(input1, input2): - return - - @ray.remote - def slow(): - time.sleep(5) - - # The object that is ray.put here will go out of scope immediately, so if - # pending task dependencies aren't considered, it will be evicted before - # the ray.get below due to the subsequent ray.puts that fill up the object - # store. - np_array = np.zeros(40 * 1024 * 1024, dtype=np.uint8) - oid = pending.remote(ray.put(np_array), slow.remote()) - - for _ in range(2): - ray.put(np_array) - - ray.get(oid) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_memory_limits.py b/python/ray/tests/test_memory_limits.py index 03a1c3861..091f1b4bd 100644 --- a/python/ray/tests/test_memory_limits.py +++ b/python/ray/tests/test_memory_limits.py @@ -15,7 +15,7 @@ class LightActor: pass def sample(self): - return np.zeros(1 * MB, dtype=np.uint8) + return np.zeros(5 * MB, dtype=np.uint8) @ray.remote @@ -29,9 +29,8 @@ class GreedyActor: class TestMemoryLimits(unittest.TestCase): def testWithoutQuota(self): + self._run(100 * MB, None, None) self.assertRaises(OBJECT_EVICTED, lambda: self._run(None, None, None)) - self.assertRaises(OBJECT_EVICTED, - lambda: self._run(100 * MB, None, None)) self.assertRaises(OBJECT_EVICTED, lambda: self._run(None, 100 * MB, None)) @@ -66,14 +65,11 @@ class TestMemoryLimits(unittest.TestCase): z = ray.put("hi", weakref=True) a = 
LightActor._remote(object_store_memory=a_quota) b = GreedyActor._remote(object_store_memory=b_quota) - oids = [z] for _ in range(5): r_a = a.sample.remote() for _ in range(20): new_oid = b.sample.remote() - oids.append(new_oid) ray.get(new_oid) - oids.append(r_a) ray.get(r_a) ray.get(z) except Exception as e: diff --git a/python/ray/tests/test_reference_counting.py b/python/ray/tests/test_reference_counting.py index b5367afdc..c8258848d 100644 --- a/python/ray/tests/test_reference_counting.py +++ b/python/ray/tests/test_reference_counting.py @@ -4,10 +4,12 @@ from __future__ import division from __future__ import print_function import os +import json import copy import tempfile import numpy as np import time +import pytest import logging import uuid @@ -27,7 +29,7 @@ def _check_refcounts(expected): assert submitted == actual[object_id]["submitted"] -def check_refcounts(expected, timeout=1): +def check_refcounts(expected, timeout=10): start = time.time() while True: try: @@ -156,7 +158,99 @@ def test_dependency_refcounts(ray_start_regular): check_refcounts({}) +def test_basic_pinning(shutdown_only): + ray.init(object_store_memory=100 * 1024 * 1024) + + @ray.remote + def f(array): + return np.sum(array) + + @ray.remote + class Actor(object): + def __init__(self): + # Hold a long-lived reference to a ray.put object's ID. The object + # should not be garbage collected while the actor is alive because + # the object is pinned by the raylet. + self.large_object = ray.put( + np.zeros(25 * 1024 * 1024, dtype=np.uint8)) + + def get_large_object(self): + return ray.get(self.large_object) + + actor = Actor.remote() + + # Fill up the object store with short-lived objects. These should be + # evicted before the long-lived object whose reference is held by + # the actor. 
+ for batch in range(10): + intermediate_result = f.remote( + np.zeros(10 * 1024 * 1024, dtype=np.uint8)) + ray.get(intermediate_result) + + # The ray.get below would fail with only LRU eviction, as the object + # that was ray.put by the actor would have been evicted. + ray.get(actor.get_large_object.remote()) + + +def test_pending_task_dependency_pinning(shutdown_only): + ray.init(object_store_memory=100 * 1024 * 1024, use_pickle=True) + + @ray.remote + def pending(input1, input2): + return + + @ray.remote + def slow(dep): + pass + + # The object that is ray.put here will go out of scope immediately, so if + # pending task dependencies aren't considered, it will be evicted before + # the ray.get below due to the subsequent ray.puts that fill up the object + # store. + np_array = np.zeros(40 * 1024 * 1024, dtype=np.uint8) + random_id = ray.ObjectID.from_random() + oid = pending.remote(np_array, slow.remote(random_id)) + + for _ in range(2): + ray.put(np_array) + + ray.worker.global_worker.put_object(None, object_id=random_id) + ray.get(oid) + + +def test_feature_flag(shutdown_only): + ray.init( + object_store_memory=100 * 1024 * 1024, + _internal_config=json.dumps({ + "object_pinning_enabled": 0 + })) + + @ray.remote + def f(array): + return np.sum(array) + + @ray.remote + class Actor(object): + def __init__(self): + self.large_object = ray.put( + np.zeros(25 * 1024 * 1024, dtype=np.uint8)) + + def get_large_object(self): + return ray.get(self.large_object) + + actor = Actor.remote() + + for batch in range(10): + intermediate_result = f.remote( + np.zeros(10 * 1024 * 1024, dtype=np.uint8)) + ray.get(intermediate_result) + + # The ray.get below fails with only LRU eviction, as the object + # that was ray.put by the actor should have been evicted. 
+ with pytest.raises(ray.exceptions.RayTimeoutError): + ray.get(actor.get_large_object.remote(), timeout=1) + + if __name__ == "__main__": - import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 7a6382767..07c338127 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -242,7 +242,7 @@ class Worker: """ self.mode = mode - def put_object(self, value, object_id=None): + def put_object(self, value, object_id=None, pin_object=True): """Put value in the local object store with object id `objectid`. This assumes that the value for `objectid` has not yet been placed in @@ -256,6 +256,7 @@ class Worker: value: The value to put in the object store. object_id (object_id.ObjectID): The object ID of the value to be put. If None, one will be generated. + pin_object: If set, the object will be pinned at the raylet. Returns: object_id.ObjectID: The object ID the object was put under. @@ -276,7 +277,7 @@ class Worker: serialized_value = self.get_serialization_context().serialize(value) return self.core_worker.put_serialized_object( - serialized_value, object_id=object_id) + serialized_value, object_id=object_id, pin_object=pin_object) def deserialize_objects(self, data_metadata_pairs, @@ -1519,7 +1520,7 @@ def put(value, weakref=False): object_id = worker.local_mode_manager.put_object(value) else: try: - object_id = worker.put_object(value) + object_id = worker.put_object(value, pin_object=not weakref) except ObjectStoreFullError: logger.info( "Put failed since the value was either too large or the " @@ -1528,16 +1529,6 @@ def put(value, weakref=False): "ray.put(value, weakref=True) to allow object data to " "be evicted early.") raise - # Pin the object buffer with the returned id. This avoids put returns - # from getting evicted out from under the id. 
- # TODO(edoakes): we should be able to avoid this extra IPC by holding - # a reference to the buffer created when putting the object, but the - # buffer returned by the plasma store create method doesn't prevent - # the object from being evicted. - if not weakref and not worker.mode == LOCAL_MODE: - object_id.set_buffer_ref( - worker.core_worker.get_objects([object_id], - worker.current_task_id)) return object_id diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 8538fb44e..70b9ba667 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -40,6 +40,10 @@ RAY_CONFIG(int64_t, debug_dump_period_milliseconds, 10000) /// type of task from starving other types (see issue #3664). RAY_CONFIG(bool, fair_queueing_enabled, true) +/// Whether to enable object pinning for plasma objects. When this is +/// enabled, objects in scope in the cluster will not be LRU evicted. +RAY_CONFIG(bool, object_pinning_enabled, true) + /// Whether to enable the new scheduler. The new scheduler is designed /// only to work with direct calls. Once direct calls afre becoming /// the default, this scheduler will also become the default. 
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index ef72538c3..0ae140469 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -131,9 +131,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, raylet_task_receiver_ = std::unique_ptr(new CoreWorkerRayletTaskReceiver( worker_context_.GetWorkerID(), local_raylet_client_, execute_task, exit)); - direct_task_receiver_ = - std::unique_ptr(new CoreWorkerDirectTaskReceiver( - worker_context_, task_execution_service_, execute_task, exit)); + direct_task_receiver_ = std::unique_ptr( + new CoreWorkerDirectTaskReceiver(worker_context_, local_raylet_client_, + task_execution_service_, execute_task, exit)); } // Start RPC server after all the task receivers are properly initialized. @@ -149,8 +149,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, node_ip_address, node_manager_port, *client_call_manager_); ClientID local_raylet_id; local_raylet_client_ = std::shared_ptr(new raylet::RayletClient( - std::move(grpc_client), raylet_socket, - WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), + std::move(grpc_client), raylet_socket, worker_context_.GetWorkerID(), (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_, &local_raylet_id, core_worker_server_.GetPort())); connected_ = true; @@ -160,6 +159,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, rpc_address_.set_ip_address(node_ip_address); rpc_address_.set_port(core_worker_server_.GetPort()); rpc_address_.set_raylet_id(local_raylet_id.Binary()); + rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary()); // Set timer to periodically send heartbeats containing active object IDs to the raylet. // If the heartbeat timeout is < 0, the heartbeats are disabled. 
@@ -227,12 +227,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, new rpc::CoreWorkerClient(ip_address, port, *client_call_manager_)); }; direct_actor_submitter_ = std::unique_ptr( - new CoreWorkerDirectActorTaskSubmitter(client_factory, memory_store_, + new CoreWorkerDirectActorTaskSubmitter(rpc_address_, client_factory, memory_store_, task_manager_)); direct_task_submitter_ = std::unique_ptr(new CoreWorkerDirectTaskSubmitter( - local_raylet_client_, client_factory, + rpc_address_, local_raylet_client_, client_factory, [this](const std::string ip_address, int port) { auto grpc_client = rpc::NodeManagerWorkerClient::make(ip_address, port, *client_call_manager_); @@ -244,7 +244,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, future_resolver_.reset(new FutureResolver(memory_store_, client_factory)); // Unfortunately the raylet client has to be constructed after the receivers. if (direct_task_receiver_ != nullptr) { - direct_task_receiver_->Init(*local_raylet_client_, client_factory, rpc_address_); + direct_task_receiver_->Init(client_factory, rpc_address_); } } @@ -385,7 +385,10 @@ Status CoreWorker::Put(const RayObject &object, ObjectID *object_id) { worker_context_.GetNextPutIndex(), static_cast(TaskTransportType::RAYLET)); reference_counter_->AddOwnedObject(*object_id, GetCallerId(), rpc_address_); - return Put(object, *object_id); + RAY_RETURN_NOT_OK(Put(object, *object_id)); + // Tell the raylet to pin the object **after** it is created. 
+ RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {*object_id})); + return Status::OK(); } Status CoreWorker::Put(const RayObject &object, const ObjectID &object_id) { @@ -408,8 +411,16 @@ Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t return plasma_store_provider_->Create(metadata, data_size, object_id, data); } -Status CoreWorker::Seal(const ObjectID &object_id) { - return plasma_store_provider_->Seal(object_id); +Status CoreWorker::Seal(const ObjectID &object_id, bool owns_object, bool pin_object) { + RAY_RETURN_NOT_OK(plasma_store_provider_->Seal(object_id)); + if (owns_object) { + reference_counter_->AddOwnedObject(object_id, GetCallerId(), rpc_address_); + if (pin_object) { + // Tell the raylet to pin the object **after** it is created. + RAY_CHECK_OK(local_raylet_client_->PinObjectIDs(rpc_address_, {object_id})); + } + } + return Status::OK(); } Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_ms, @@ -608,6 +619,9 @@ Status CoreWorker::Delete(const std::vector &object_ids, bool local_on absl::flat_hash_set memory_object_ids; GroupObjectIdsByStoreProvider(object_ids, &plasma_object_ids, &memory_object_ids); + // TODO(edoakes): what are the desired semantics for deleting from a non-owner? + // Should we just delete locally or ping the owner and delete globally? 
+ reference_counter_->DeleteReferences(object_ids); memory_store_->Delete(memory_object_ids, &plasma_object_ids); RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, delete_creating_tasks)); @@ -932,7 +946,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, continue; } if (return_objects->at(i)->GetData()->IsPlasmaBuffer()) { - if (!Seal(return_ids[i]).ok()) { + if (!Seal(return_ids[i], /*owns_object=*/false, /*pin_object=*/false).ok()) { RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to seal object " << return_ids[i] << " in store: " << status.message(); } @@ -1101,6 +1115,30 @@ void CoreWorker::HandleGetObjectStatus(const rpc::GetObjectStatusRequest &reques } } +void CoreWorker::HandleWaitForObjectEviction( + const rpc::WaitForObjectEvictionRequest &request, + rpc::WaitForObjectEvictionReply *reply, rpc::SendReplyCallback send_reply_callback) { + if (HandleWrongRecipient(WorkerID::FromBinary(request.intended_worker_id()), + send_reply_callback)) { + return; + } + + // Send a response to trigger unpinning the object when it is no longer in scope. + auto respond = [send_reply_callback](const ObjectID &object_id) { + RAY_LOG(DEBUG) << "Replying to HandleWaitForObjectEviction for " << object_id; + send_reply_callback(Status::OK(), nullptr, nullptr); + }; + + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + // Returns true if the object was present and the callback was added. It might have + // already been evicted by the time we get this request, in which case we should + // respond immediately so the raylet unpins the object. 
+ if (!reference_counter_->SetDeleteCallback(object_id, respond)) { + RAY_LOG(DEBUG) << "ObjectID reference already gone for " << object_id; + respond(object_id); + } +} + void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 115aff752..f78e13a80 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -33,6 +33,7 @@ RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \ RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) \ RAY_CORE_WORKER_RPC_HANDLER(GetObjectStatus, 9999) \ + RAY_CORE_WORKER_RPC_HANDLER(WaitForObjectEviction, 9999) \ RAY_CORE_WORKER_RPC_HANDLER(KillActor, 9999) \ RAY_CORE_WORKER_RPC_HANDLER(GetCoreWorkerStats, 100) @@ -219,8 +220,14 @@ class CoreWorker { /// a corresponding `Create()` call and then writing into the returned buffer. /// /// \param[in] object_id Object ID corresponding to the object. + /// \param[in] owns_object Whether or not this worker owns the object. If true, + /// the object will be added as owned to the reference counter as an + /// owned object and this worker will be responsible for managing its + /// lifetime. + /// \param[in] pin_object Whether or not to pin the object at the local raylet. This + /// only applies when owns_object is true. /// \return Status. - Status Seal(const ObjectID &object_id); + Status Seal(const ObjectID &object_id, bool owns_object, bool pin_object); /// Get a list of objects from the object store. Objects that failed to be retrieved /// will be returned as nullptrs. @@ -410,6 +417,11 @@ class CoreWorker { rpc::GetObjectStatusReply *reply, rpc::SendReplyCallback send_reply_callback); + /// Implements gRPC server handler. 
+ void HandleWaitForObjectEviction(const rpc::WaitForObjectEvictionRequest &request, + rpc::WaitForObjectEvictionReply *reply, + rpc::SendReplyCallback send_reply_callback); + /// Implements gRPC server handler. void HandleKillActor(const rpc::KillActorRequest &request, rpc::KillActorReply *reply, rpc::SendReplyCallback send_reply_callback); diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index f799d67e2..0e81a48ac 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -27,28 +27,25 @@ void ReferenceCounter::AddOwnedObject(const ObjectID &object_id, const TaskID &o void ReferenceCounter::AddLocalReference(const ObjectID &object_id) { absl::MutexLock lock(&mutex_); - auto entry = object_id_refs_.find(object_id); - if (entry == object_id_refs_.end()) { - // TODO: Once ref counting is implemented, we should always know how the - // ObjectID was created, so there should always be an entry. - entry = object_id_refs_.emplace(object_id, Reference()).first; + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + // NOTE: ownership info for these objects must be added later via AddBorrowedObject. 
+ it = object_id_refs_.emplace(object_id, Reference()).first; } - entry->second.local_ref_count++; + it->second.local_ref_count++; } void ReferenceCounter::RemoveLocalReference(const ObjectID &object_id, std::vector *deleted) { absl::MutexLock lock(&mutex_); - auto entry = object_id_refs_.find(object_id); - if (entry == object_id_refs_.end()) { + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " << object_id; return; } - if (--entry->second.local_ref_count == 0 && - entry->second.submitted_task_ref_count == 0) { - object_id_refs_.erase(entry); - deleted->push_back(object_id); + if (--it->second.local_ref_count == 0 && it->second.submitted_task_ref_count == 0) { + DeleteReferenceInternal(it, deleted); } } @@ -56,13 +53,13 @@ void ReferenceCounter::AddSubmittedTaskReferences( const std::vector &object_ids) { absl::MutexLock lock(&mutex_); for (const ObjectID &object_id : object_ids) { - auto entry = object_id_refs_.find(object_id); - if (entry == object_id_refs_.end()) { - // TODO: Once ref counting is implemented, we should always know how the - // ObjectID was created, so there should always be an entry. - entry = object_id_refs_.emplace(object_id, Reference()).first; + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + // This happens if a large argument is transparently passed by reference + // because we don't hold a Python reference to its ObjectID. 
+ it = object_id_refs_.emplace(object_id, Reference()).first; } - entry->second.submitted_task_ref_count++; + it->second.submitted_task_ref_count++; } } @@ -70,16 +67,14 @@ void ReferenceCounter::RemoveSubmittedTaskReferences( const std::vector &object_ids, std::vector *deleted) { absl::MutexLock lock(&mutex_); for (const ObjectID &object_id : object_ids) { - auto entry = object_id_refs_.find(object_id); - if (entry == object_id_refs_.end()) { + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { RAY_LOG(WARNING) << "Tried to decrease ref count for nonexistent object ID: " << object_id; return; } - if (--entry->second.submitted_task_ref_count == 0 && - entry->second.local_ref_count == 0) { - object_id_refs_.erase(entry); - deleted->push_back(object_id); + if (--it->second.submitted_task_ref_count == 0 && it->second.local_ref_count == 0) { + DeleteReferenceInternal(it, deleted); } } } @@ -101,6 +96,41 @@ bool ReferenceCounter::GetOwner(const ObjectID &object_id, TaskID *owner_id, } } +void ReferenceCounter::DeleteReferences(const std::vector &object_ids) { + absl::MutexLock lock(&mutex_); + for (const ObjectID &object_id : object_ids) { + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + return; + } + DeleteReferenceInternal(it, nullptr); + } +} + +void ReferenceCounter::DeleteReferenceInternal( + absl::flat_hash_map::iterator it, + std::vector *deleted) { + if (it->second.on_delete) { + it->second.on_delete(it->first); + } + if (deleted) { + deleted->push_back(it->first); + } + object_id_refs_.erase(it); +} + +bool ReferenceCounter::SetDeleteCallback( + const ObjectID &object_id, const std::function callback) { + absl::MutexLock lock(&mutex_); + auto it = object_id_refs_.find(object_id); + if (it == object_id_refs_.end()) { + return false; + } + RAY_CHECK(!it->second.on_delete); + it->second.on_delete = callback; + return true; +} + bool ReferenceCounter::HasReference(const ObjectID &object_id) const { 
absl::MutexLock lock(&mutex_); return object_id_refs_.find(object_id) != object_id_refs_.end(); @@ -134,21 +164,4 @@ ReferenceCounter::GetAllReferenceCounts() const { return all_ref_counts; } -void ReferenceCounter::LogDebugString() const { - absl::MutexLock lock(&mutex_); - - RAY_LOG(DEBUG) << "ReferenceCounter state:"; - if (object_id_refs_.empty()) { - RAY_LOG(DEBUG) << "\tEMPTY"; - return; - } - - for (const auto &entry : object_id_refs_) { - RAY_LOG(DEBUG) << "\t" << entry.first.Hex(); - RAY_LOG(DEBUG) << "\t\tlocal refcount: " << entry.second.local_ref_count; - RAY_LOG(DEBUG) << "\t\tsubmitted task refcount: " - << entry.second.submitted_task_ref_count; - } -} - } // namespace ray diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index d18cc6be6..f17199718 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -36,7 +36,8 @@ class ReferenceCounter { /// dependencies to a submitted task. /// /// \param[in] object_ids The object IDs to add references for. - void AddSubmittedTaskReferences(const std::vector &object_ids); + void AddSubmittedTaskReferences(const std::vector &object_ids) + LOCKS_EXCLUDED(mutex_); /// Remove references for the provided object IDs that correspond to them being /// dependencies to a submitted task. This should be called when inlined @@ -45,7 +46,8 @@ class ReferenceCounter { /// \param[in] object_ids The object IDs to remove references for. /// \param[out] deleted The object IDs whos reference counts reached zero. void RemoveSubmittedTaskReferences(const std::vector &object_ids, - std::vector *deleted); + std::vector *deleted) + LOCKS_EXCLUDED(mutex_); /// Add an object that we own. The object may depend on other objects. /// Dependencies for each ObjectID must be set at most once. 
The local @@ -73,9 +75,23 @@ class ReferenceCounter { void AddBorrowedObject(const ObjectID &object_id, const TaskID &owner_id, const rpc::Address &owner_address) LOCKS_EXCLUDED(mutex_); + /// Get the owner ID and address of the given object. + /// + /// \param[in] object_id The ID of the object to look up. + /// \param[out] owner_id The TaskID of the object owner. + /// \param[out] owner_address The address of the object owner. bool GetOwner(const ObjectID &object_id, TaskID *owner_id, rpc::Address *owner_address) const LOCKS_EXCLUDED(mutex_); + /// Manually delete the objects from the reference counter. + void DeleteReferences(const std::vector &object_ids) LOCKS_EXCLUDED(mutex_); + + /// Sets the callback that will be run when the object goes out of scope. + /// Returns true if the object was in scope and the callback was added, else false. + bool SetDeleteCallback(const ObjectID &object_id, + const std::function callback) + LOCKS_EXCLUDED(mutex_); + /// Returns the total number of ObjectIDs currently in scope. size_t NumObjectIDsInScope() const LOCKS_EXCLUDED(mutex_); @@ -90,9 +106,6 @@ class ReferenceCounter { std::unordered_map> GetAllReferenceCounts() const LOCKS_EXCLUDED(mutex_); - /// Dumps information about all currently tracked references to RAY_LOG(DEBUG). - void LogDebugString() const LOCKS_EXCLUDED(mutex_); - private: /// Metadata for an ObjectID reference in the language frontend. struct Reference { @@ -113,20 +126,15 @@ class ReferenceCounter { /// if we do not know the object's owner (because distributed ref counting /// is not yet implemented). absl::optional> owner; + /// Callback that will be called when this ObjectID no longer has references. + std::function on_delete; }; - /// Helper function with the same semantics as AddReference to allow adding a reference - /// while already holding mutex_. 
- void AddLocalReferenceInternal(const ObjectID &object_id) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - /// Recursive helper function for decreasing reference counts. Will recursively call - /// itself on any dependencies whose reference count reaches zero as a result of - /// removing the reference. - /// - /// \param[in] object_id The object to to decrement the count for. - /// \param[in] deleted List to store objects that hit zero ref count. - void RemoveReferenceRecursive(const ObjectID &object_id, std::vector *deleted) + /// Helper method to delete an entry from the reference map and run any necessary + /// callbacks. Assumes that the entry is in object_id_refs_ and invalidates the + /// iterator. + void DeleteReferenceInternal(absl::flat_hash_map::iterator entry, + std::vector *deleted) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Protects access to the reference counting state. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 69a41302d..079e441c2 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -1,29 +1,26 @@ -#include -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "ray/common/buffer.h" -#include "ray/common/ray_object.h" -#include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" -#include "ray/core_worker/transport/direct_actor_transport.h" - -#include "ray/core_worker/store_provider/memory_store/memory_store.h" - -#include "ray/raylet/raylet_client.h" -#include "src/ray/protobuf/core_worker.pb.h" -#include "src/ray/protobuf/gcs.pb.h" -#include "src/ray/util/test_util.h" #include #include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" #include "hiredis/async.h" #include "hiredis/hiredis.h" +#include 
"ray/common/buffer.h" +#include "ray/common/ray_object.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/transport/direct_actor_transport.h" +#include "ray/raylet/raylet_client.h" #include "ray/util/test_util.h" +#include "src/ray/protobuf/core_worker.pb.h" +#include "src/ray/protobuf/gcs.pb.h" +#include "src/ray/util/test_util.h" namespace { @@ -954,10 +951,13 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); - ASSERT_TRUE(worker1.Get(ids, 0, &results).IsTimedOut()); - ASSERT_EQ(results.size(), 2); - ASSERT_TRUE(!results[0]); - ASSERT_TRUE(!results[1]); + // TODO(edoakes): this currently fails because the object is pinned on the + // creating node. Should be fixed or removed once we decide the semantics + // for Delete() with pinning. + // ASSERT_TRUE(worker1.Get(ids, 0, &results).IsTimedOut()); + // ASSERT_EQ(results.size(), 2); + // ASSERT_TRUE(!results[0]); + // ASSERT_TRUE(!results[1]); } TEST_F(SingleNodeTest, TestNormalTaskLocal) { diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index 5af72d3c0..38bff92bc 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -63,9 +63,11 @@ class DirectActorTransportTest : public ::testing::Test { : worker_client_(std::shared_ptr(new MockWorkerClient())), store_(std::shared_ptr(new CoreWorkerMemoryStore())), task_finisher_(std::make_shared()), - submitter_([&](const std::string ip, int port) { return worker_client_; }, store_, + submitter_(address_, + [&](const std::string ip, int port) { return worker_client_; }, store_, task_finisher_) {} + rpc::Address address_; std::shared_ptr worker_client_; std::shared_ptr store_; std::shared_ptr task_finisher_; diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc 
b/src/ray/core_worker/test/direct_task_transport_test.cc index 983eb78e6..e70116ed3 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -251,12 +251,13 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r } TEST(DirectTaskTransportTest, TestSubmitOneTask) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; @@ -281,12 +282,13 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { } TEST(DirectTaskTransportTest, TestHandleTaskFailure) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; @@ -304,12 +306,13 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { } TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, 
factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; @@ -348,12 +351,13 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { } TEST(DirectTaskTransportTest, TestReuseWorkerLease) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; @@ -395,12 +399,13 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { } TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; @@ -432,12 +437,13 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { } TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = 
std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; @@ -459,6 +465,7 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { } TEST(DirectTaskTransportTest, TestSpillback) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); @@ -473,9 +480,9 @@ TEST(DirectTaskTransportTest, TestSpillback) { return client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory, - store, task_finisher, ClientID::Nil(), - kLongTimeout); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, + lease_client_factory, store, task_finisher, + ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); @@ -507,6 +514,7 @@ TEST(DirectTaskTransportTest, TestSpillback) { } TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); @@ -522,9 +530,9 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { }; auto task_finisher = std::make_shared(); auto local_raylet_id = ClientID::FromRandom(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, lease_client_factory, - store, task_finisher, local_raylet_id, - kLongTimeout); + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, + lease_client_factory, store, task_finisher, + local_raylet_id, kLongTimeout); std::unordered_map empty_resources; std::vector empty_descriptor; TaskSpecification task = BuildTaskSpec(empty_resources, 
empty_descriptor); @@ -565,11 +573,12 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { void TestSchedulingKey(const std::shared_ptr store, const TaskSpecification &same1, const TaskSpecification &same2, const TaskSpecification &different) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); ASSERT_TRUE(submitter.SubmitTask(same1).ok()); @@ -663,12 +672,13 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { } TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { + rpc::Address address; auto raylet_client = std::make_shared(); auto worker_client = std::make_shared(); auto store = std::make_shared(); auto factory = [&](const std::string &addr, int port) { return worker_client; }; auto task_finisher = std::make_shared(); - CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, nullptr, store, + CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), /*lease_timeout_ms=*/5); std::unordered_map empty_resources; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 0da279185..7c9b85cf0 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -42,6 +42,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spe const auto &actor_id = task_spec.ActorId(); auto request = std::unique_ptr(new rpc::PushTaskRequest); + request->mutable_caller_address()->CopyFrom(rpc_address_); // NOTE(swang): CopyFrom is needed because if 
we use Swap here and the task // fails, then the task data will be gone when the TaskManager attempts to // access the task. @@ -77,6 +78,7 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectActor(const ActorID &actor_id, // Update the mapping so new RPCs go out with the right intended worker id. worker_ids_[actor_id] = address.worker_id(); // Create a new connection to the actor. + // TODO(edoakes): are these clients cleaned up properly? if (rpc_clients_.count(actor_id) == 0) { rpc_clients_[actor_id] = std::shared_ptr( client_factory_(address.ip_address(), address.port())); @@ -168,13 +170,11 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c return (iter != rpc_clients_.end()); } -void CoreWorkerDirectTaskReceiver::Init(raylet::RayletClient &raylet_client, - rpc::ClientFactoryFn client_factory, +void CoreWorkerDirectTaskReceiver::Init(rpc::ClientFactoryFn client_factory, rpc::Address rpc_address) { - waiter_.reset(new DependencyWaiterImpl(raylet_client)); + waiter_.reset(new DependencyWaiterImpl(*local_raylet_client_)); rpc_address_ = rpc_address; client_factory_ = client_factory; - local_raylet_client_ = raylet_client; } void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) { @@ -252,7 +252,9 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( } } - auto accept_callback = [this, reply, send_reply_callback, task_spec, resource_ids]() { + const rpc::Address &caller_address = request.caller_address(); + auto accept_callback = [this, caller_address, reply, send_reply_callback, task_spec, + resource_ids]() { // We have posted an exit task onto the main event loop, // so shouldn't bother executing any further work. 
if (exiting_) return; @@ -268,6 +270,7 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( auto status = task_handler_(task_spec, resource_ids, &return_objects); bool objects_valid = return_objects.size() == num_returns; if (objects_valid) { + std::vector plasma_return_ids; for (size_t i = 0; i < return_objects.size(); i++) { auto return_object = reply->add_return_objects(); ObjectID id = ObjectID::ForTaskReturn( @@ -279,6 +282,7 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( const auto &result = return_objects[i]; if (result == nullptr || result->GetData()->IsPlasmaBuffer()) { return_object->set_in_plasma(true); + plasma_return_ids.push_back(id); } else { if (result->GetData() != nullptr) { return_object->set_data(result->GetData()->Data(), result->GetData()->Size()); @@ -289,7 +293,15 @@ void CoreWorkerDirectTaskReceiver::HandlePushTask( } } } - + // If we spilled any return objects to plasma, notify the raylet to pin them. + // The raylet will then coordinate with the caller to manage the objects' + // lifetimes. + // TODO(edoakes): the plasma objects could be evicted between creating them + // here and when raylet pins them. + if (!plasma_return_ids.empty()) { + RAY_CHECK_OK( + local_raylet_client_->PinObjectIDs(caller_address, plasma_return_ids)); + } if (task_spec.IsActorCreationTask()) { RAY_LOG(INFO) << "Actor creation task finished, task_id: " << task_spec.TaskId() << ", actor_id: " << task_spec.ActorCreationId(); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 56e0d1401..bb6f1d39a 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -36,10 +36,12 @@ const int kMaxReorderWaitSeconds = 30; // This class is thread-safe. 
class CoreWorkerDirectActorTaskSubmitter { public: - CoreWorkerDirectActorTaskSubmitter(rpc::ClientFactoryFn client_factory, + CoreWorkerDirectActorTaskSubmitter(rpc::Address rpc_address, + rpc::ClientFactoryFn client_factory, std::shared_ptr store, std::shared_ptr task_finisher) - : client_factory_(client_factory), + : rpc_address_(rpc_address), + client_factory_(client_factory), resolver_(store, task_finisher), task_finisher_(task_finisher) {} @@ -102,6 +104,9 @@ class CoreWorkerDirectActorTaskSubmitter { /// Mutex to proect the various maps below. mutable absl::Mutex mu_; + /// Address of our RPC server. + rpc::Address rpc_address_; + /// Map from actor id to rpc client. This only includes actors that we send tasks to. /// We use shared_ptr to enable shared_from_this for pending client callbacks. /// @@ -171,14 +176,14 @@ class DependencyWaiter { class DependencyWaiterImpl : public DependencyWaiter { public: - DependencyWaiterImpl(raylet::RayletClient &raylet_client) - : raylet_client_(raylet_client) {} + DependencyWaiterImpl(raylet::RayletClient &local_raylet_client) + : local_raylet_client_(local_raylet_client) {} void Wait(const std::vector &dependencies, std::function on_dependencies_available) override { auto tag = next_request_id_++; requests_[tag] = on_dependencies_available; - raylet_client_.WaitForDirectActorCallArgs(dependencies, tag); + local_raylet_client_.WaitForDirectActorCallArgs(dependencies, tag); } /// Fulfills the callback stored by Wait(). @@ -192,7 +197,7 @@ class DependencyWaiterImpl : public DependencyWaiter { private: int64_t next_request_id_ = 0; std::unordered_map> requests_; - raylet::RayletClient &raylet_client_; + raylet::RayletClient &local_raylet_client_; }; /// Wraps a thread-pool to block posts until the pool has free slots. 
This is used @@ -431,10 +436,12 @@ class CoreWorkerDirectTaskReceiver { std::vector> *return_objects)>; CoreWorkerDirectTaskReceiver(WorkerContext &worker_context, + std::shared_ptr &local_raylet_client, boost::asio::io_service &main_io_service, const TaskHandler &task_handler, const std::function &exit_handler) : worker_context_(worker_context), + local_raylet_client_(local_raylet_client), task_handler_(task_handler), exit_handler_(exit_handler), task_main_io_service_(main_io_service) {} @@ -448,8 +455,7 @@ class CoreWorkerDirectTaskReceiver { } /// Initialize this receiver. This must be called prior to use. - void Init(raylet::RayletClient &client, rpc::ClientFactoryFn client_factory, - rpc::Address rpc_address); + void Init(rpc::ClientFactoryFn client_factory, rpc::Address rpc_address); /// Handle a `PushTask` request. /// @@ -488,6 +494,9 @@ class CoreWorkerDirectTaskReceiver { rpc::ClientFactoryFn client_factory_; /// Address of our RPC server. rpc::Address rpc_address_; + /// Reference to the core worker's raylet client. This is a pointer ref so that it + /// can be initialized by core worker after this class is constructed. + std::shared_ptr &local_raylet_client_; /// Shared waiter for dependencies required by incoming tasks. std::unique_ptr waiter_; /// Queue of pending requests per actor handle. @@ -510,8 +519,6 @@ class CoreWorkerDirectTaskReceiver { /// The fiber semaphore used to limit the number of concurrent fibers /// running at once. 
std::shared_ptr fiber_rate_limiter_; - - boost::optional local_raylet_client_; }; } // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc index fd130b414..7fc563cac 100644 --- a/src/ray/core_worker/transport/direct_task_transport.cc +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -1,4 +1,5 @@ #include "ray/core_worker/transport/direct_task_transport.h" + #include "ray/core_worker/transport/dependency_resolver.h" #include "ray/core_worker/transport/direct_actor_transport.h" @@ -169,6 +170,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask( // NOTE(swang): CopyFrom is needed because if we use Swap here and the task // fails, then the task data will be gone when the TaskManager attempts to // access the task. + request->mutable_caller_address()->CopyFrom(rpc_address_); request->mutable_task_spec()->CopyFrom(task_spec.GetMessage()); request->mutable_resource_mapping()->CopyFrom(assigned_resources); request->set_intended_worker_id(addr.worker_id.Binary()); diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h index aec5918e9..3a2193a7b 100644 --- a/src/ray/core_worker/transport/direct_task_transport.h +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -35,13 +35,15 @@ using SchedulingKey = std::tuple, ActorID // This class is thread-safe. 
class CoreWorkerDirectTaskSubmitter { public: - CoreWorkerDirectTaskSubmitter(std::shared_ptr lease_client, + CoreWorkerDirectTaskSubmitter(rpc::Address rpc_address, + std::shared_ptr lease_client, rpc::ClientFactoryFn client_factory, LeaseClientFactoryFn lease_client_factory, std::shared_ptr store, std::shared_ptr task_finisher, ClientID local_raylet_id, int64_t lease_timeout_ms) - : local_lease_client_(lease_client), + : rpc_address_(rpc_address), + local_lease_client_(lease_client), client_factory_(client_factory), lease_client_factory_(lease_client_factory), resolver_(store, task_finisher), @@ -101,6 +103,9 @@ class CoreWorkerDirectTaskSubmitter { const google::protobuf::RepeatedPtrField &assigned_resources); + /// Address of our RPC server. + rpc::Address rpc_address_; + // Client that can be used to lease and return workers from the local raylet. std::shared_ptr local_lease_client_; diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 36e610dfd..7c77857ac 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -65,22 +65,24 @@ message ReturnObject { message PushTaskRequest { // The ID of the worker this message is intended for. bytes intended_worker_id = 1; + // Address of the caller. + Address caller_address = 2; // The task to be pushed. - TaskSpec task_spec = 2; + TaskSpec task_spec = 3; // The sequence number of the task for this client. This must increase // sequentially starting from zero for each actor handle. The server // will guarantee tasks execute in this sequence, waiting for any // out-of-order request messages to arrive as necessary. // If set to -1, ordering is disabled and the task executes immediately. // This mode of behaviour is used for direct task submission only. - int64 sequence_number = 3; + int64 sequence_number = 4; // The max sequence number the client has processed responses for. 
This // is a performance optimization that allows the client to tell the server // to cancel any PushTaskRequests with seqno <= this value, rather than // waiting for the server to time out waiting for missing messages. - int64 client_processed_up_to = 4; + int64 client_processed_up_to = 5; // Resource mapping ids assigned to the worker executing the task. - repeated ResourceMapEntry resource_mapping = 5; + repeated ResourceMapEntry resource_mapping = 6; } message PushTaskReply { @@ -117,6 +119,16 @@ message GetObjectStatusReply { ObjectStatus status = 1; } +message WaitForObjectEvictionRequest { + // The ID of the worker this message is intended for. + bytes intended_worker_id = 1; + // ObjectID of the pinned object. + bytes object_id = 2; +} + +message WaitForObjectEvictionReply { +} + message KillActorRequest { // ID of the actor that is intended to be killed. bytes intended_actor_id = 1; @@ -145,6 +157,10 @@ service CoreWorkerService { returns (DirectActorCallArgWaitCompleteReply); // Ask the object's owner about the object's current status. rpc GetObjectStatus(GetObjectStatusRequest) returns (GetObjectStatusReply); + // Notify the object's owner that it has been pinned by a raylet. Replying + // to this message indicates that the raylet should unpin the object. + rpc WaitForObjectEviction(WaitForObjectEvictionRequest) + returns (WaitForObjectEvictionReply); // Request that the worker shut down without completing outstanding work. rpc KillActor(KillActorRequest) returns (KillActorReply); // Get metrics from core workers. diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index e98cfd8a7..225a9f90b 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -45,6 +45,16 @@ message ForwardTaskRequest { message ForwardTaskReply { } +message PinObjectIDsRequest { + // Address of the owner to ask when to unpin the objects. + Address owner_address = 1; + // ObjectIDs to pin. 
+ repeated bytes object_ids = 2; +} + +message PinObjectIDsReply { +} + message GetNodeStatsRequest { } @@ -72,6 +82,8 @@ service NodeManagerService { rpc ReturnWorker(ReturnWorkerRequest) returns (ReturnWorkerReply); // Forward a task and its uncommitted lineage to the remote node manager. rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); + // Pin the provided object IDs. + rpc PinObjectIDs(PinObjectIDsRequest) returns (PinObjectIDsReply); // Get the current node stats. rpc GetNodeStats(GetNodeStatsRequest) returns (GetNodeStatsReply); } diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index efdc548e4..82912b59d 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -1,5 +1,6 @@ #include +#include "gflags/gflags.h" #include "ray/common/id.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" @@ -7,8 +8,6 @@ #include "ray/raylet/raylet.h" #include "ray/stats/stats.h" -#include "gflags/gflags.h" - DEFINE_string(raylet_socket_name, "", "The socket name of raylet."); DEFINE_string(store_socket_name, "", "The socket name of object store."); DEFINE_int32(object_manager_port, -1, "The port of object manager."); @@ -125,6 +124,8 @@ int main(int argc, char *argv[]) { RayConfig::instance().debug_dump_period_milliseconds(); node_manager_config.fair_queueing_enabled = RayConfig::instance().fair_queueing_enabled(); + node_manager_config.object_pinning_enabled = + RayConfig::instance().object_pinning_enabled(); node_manager_config.max_lineage_size = RayConfig::instance().max_lineage_size(); node_manager_config.store_socket_name = store_socket_name; node_manager_config.temp_dir = temp_dir; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a24f3244d..f418a9b28 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -4,6 +4,7 @@ #include #include +#include "ray/common/buffer.h" #include "ray/common/common_protocol.h" #include "ray/common/id.h" #include
"ray/common/status.h" @@ -83,6 +84,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, heartbeat_period_(std::chrono::milliseconds(config.heartbeat_period_ms)), debug_dump_period_(config.debug_dump_period_ms), fair_queueing_enabled_(config.fair_queueing_enabled), + object_pinning_enabled_(config.object_pinning_enabled), temp_dir_(config.temp_dir), object_manager_profile_timer_(io_service), initial_config_(config), @@ -2966,6 +2968,82 @@ std::string compact_tag_string(const opencensus::stats::ViewDescriptor &view, return result.str(); } +void NodeManager::HandlePinObjectIDsRequest(const rpc::PinObjectIDsRequest &request, + rpc::PinObjectIDsReply *reply, + rpc::SendReplyCallback send_reply_callback) { + if (!object_pinning_enabled_) { + send_reply_callback(Status::OK(), nullptr, nullptr); + return; + } + WorkerID worker_id = WorkerID::FromBinary(request.owner_address().worker_id()); + auto it = worker_rpc_clients_.find(worker_id); + if (it == worker_rpc_clients_.end()) { + auto client = std::unique_ptr( + new rpc::CoreWorkerClient(request.owner_address().ip_address(), + request.owner_address().port(), client_call_manager_)); + it = worker_rpc_clients_ + .emplace(worker_id, + std::make_pair, size_t>( + std::move(client), 0)) + .first; + } + + // Pin the objects in plasma by getting them and holding a reference to + // the returned buffer. + // NOTE: the caller must ensure that the objects already exist in plasma before + // sending a PinObjectIDs request. 
+ std::vector plasma_ids; + plasma_ids.reserve(request.object_ids_size()); + for (const auto &object_id_binary : request.object_ids()) { + plasma_ids.push_back(plasma::ObjectID::from_binary(object_id_binary)); + } + std::vector plasma_results; + if (!store_client_.Get(plasma_ids, /*timeout_ms=*/0, &plasma_results).ok()) { + RAY_LOG(WARNING) << "Failed to get objects to be pinned from object store."; + send_reply_callback(Status::Invalid("Failed to get objects."), nullptr, nullptr); + return; + } + + // Pin the requested objects until the owner notifies us that the objects can be + // unpinned by responding to the WaitForObjectEviction message. + // TODO(edoakes): we should be batching these requests instead of sending one per + // pinned object. + size_t i = 0; + for (const auto &object_id_binary : request.object_ids()) { + ObjectID object_id = ObjectID::FromBinary(object_id_binary); + + RAY_LOG(DEBUG) << "Pinning object " << object_id; + pinned_objects_.emplace( + object_id, std::unique_ptr(new RayObject( + std::make_shared(plasma_results[i].data), + std::make_shared(plasma_results[i].metadata)))); + i++; + + // Send a long-running RPC request to the owner for each object. When we get a + // response or the RPC fails (due to the owner crashing), unpin the object. + rpc::WaitForObjectEvictionRequest wait_request; + wait_request.set_object_id(object_id_binary); + wait_request.set_intended_worker_id(request.owner_address().worker_id()); + worker_rpc_clients_[worker_id].second++; + RAY_CHECK_OK(it->second.first->WaitForObjectEviction( + wait_request, [this, worker_id, object_id]( + Status status, const rpc::WaitForObjectEvictionReply &reply) { + if (!status.ok()) { + RAY_LOG(WARNING) << "Worker " << worker_id << " failed. Unpinning object " + << object_id; + } + RAY_LOG(DEBUG) << "Unpinning object " << object_id; + pinned_objects_.erase(object_id); + + // Remove the cached worker client if there are no more pending requests. 
+ if (--worker_rpc_clients_[worker_id].second == 0) { + worker_rpc_clients_.erase(worker_id); + } + })); + } + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void NodeManager::HandleNodeStatsRequest(const rpc::GetNodeStatsRequest &request, rpc::GetNodeStatsReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index c3fdd130c..ddfda825f 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -8,6 +8,7 @@ #include "ray/rpc/node_manager/node_manager_server.h" #include "ray/rpc/node_manager/node_manager_client.h" #include "ray/common/task/task.h" +#include "ray/common/ray_object.h" #include "ray/common/client_connection.h" #include "ray/common/task/task_common.h" #include "ray/common/task/scheduling_resources.h" @@ -56,6 +57,8 @@ struct NodeManagerConfig { uint64_t debug_dump_period_ms; /// Whether to enable fair queueing between task classes in raylet. bool fair_queueing_enabled; + /// Whether to enable pinning for plasma objects. + bool object_pinning_enabled; /// the maximum lineage size. uint64_t max_lineage_size; /// The store socket name. @@ -545,6 +548,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::ForwardTaskReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `PinObjectIDs` request. + void HandlePinObjectIDsRequest(const rpc::PinObjectIDsRequest &request, + rpc::PinObjectIDsReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `NodeStats` request. void HandleNodeStatsRequest(const rpc::GetNodeStatsRequest &request, rpc::GetNodeStatsReply *reply, @@ -567,9 +575,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler { ClientID self_node_id_; boost::asio::io_service &io_service_; ObjectManager &object_manager_; - /// A Plasma object store client. 
This is used exclusively for creating new - /// objects in the object store (e.g., for actor tasks that can't be run - /// because the actor died). + /// A Plasma object store client. This is used for creating new objects in + /// the object store (e.g., for actor tasks that can't be run because the + /// actor died) and to pin objects that are in scope in the cluster. plasma::PlasmaClient store_client_; /// A client connection to the GCS. std::shared_ptr gcs_client_; @@ -583,6 +591,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler { int64_t debug_dump_period_; /// Whether to enable fair queueing between task classes in raylet. bool fair_queueing_enabled_; + /// Whether to enable pinning for plasma objects. + bool object_pinning_enabled_; /// Whether we have printed out a resource deadlock warning. bool resource_deadlock_warned_ = false; /// Whether we have recorded any metrics yet. @@ -658,6 +668,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Queue of lease requests that should be scheduled onto workers. std::deque> tasks_to_dispatch_; + /// Cache of gRPC clients to workers (not necessarily running on this node). + /// Also includes the number of inflight requests to each worker - when this + /// reaches zero, the client will be deleted and a new one will need to be created + /// for any subsequent requests. 
+ absl::flat_hash_map, size_t>> + worker_rpc_clients_; + + absl::flat_hash_map> pinned_objects_; + /// XXX void WaitForTaskArgsRequests(std::pair &work); }; diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 6429d2ef0..176c72fae 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -428,4 +428,14 @@ Status raylet::RayletClient::ReturnWorker(int worker_port, const WorkerID &worke }); } +Status raylet::RayletClient::PinObjectIDs(const rpc::Address &caller_address, + const std::vector &object_ids) { + rpc::PinObjectIDsRequest request; + request.mutable_owner_address()->CopyFrom(caller_address); + for (const ObjectID &object_id : object_ids) { + request.add_object_ids(object_id.Binary()); + } + return grpc_client_->PinObjectIDs(request, nullptr); +} + } // namespace ray diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index ebf65bb4d..e5618c42f 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -3,12 +3,12 @@ #include #include + +#include #include #include #include -#include - #include "ray/common/status.h" #include "ray/common/task/task_spec.h" #include "ray/rpc/node_manager/node_manager_client.h" @@ -249,6 +249,9 @@ class RayletClient : public WorkerLeaseInterface { ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, bool disconnect_worker) override; + ray::Status PinObjectIDs(const rpc::Address &caller_address, + const std::vector &object_ids); + WorkerID GetWorkerID() const { return worker_id_; } JobID GetJobID() const { return job_id_; } diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index aae2d8473..93e438148 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -1,10 +1,10 @@ #ifndef RAY_RPC_NODE_MANAGER_CLIENT_H #define RAY_RPC_NODE_MANAGER_CLIENT_H -#include - #include +#include + #include 
"ray/common/status.h" #include "ray/rpc/grpc_client.h" #include "ray/util/logging.h" @@ -73,6 +73,9 @@ class NodeManagerWorkerClient /// Return a worker lease. RPC_CLIENT_METHOD(NodeManagerService, ReturnWorker, grpc_client_, ) + /// Notify the raylet to pin the provided object IDs. + RPC_CLIENT_METHOD(NodeManagerService, PinObjectIDs, grpc_client_, ) + private: /// Constructor. /// diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index c1caf0ffb..9c292884d 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -3,7 +3,6 @@ #include "ray/rpc/grpc_server.h" #include "ray/rpc/server_call.h" - #include "src/ray/protobuf/node_manager.grpc.pb.h" #include "src/ray/protobuf/node_manager.pb.h" @@ -36,6 +35,10 @@ class NodeManagerServiceHandler { ForwardTaskReply *reply, SendReplyCallback send_reply_callback) = 0; + virtual void HandlePinObjectIDsRequest(const PinObjectIDsRequest &request, + PinObjectIDsReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandleNodeStatsRequest(const GetNodeStatsRequest &request, GetNodeStatsReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -81,6 +84,13 @@ class NodeManagerGrpcService : public GrpcService { service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq, main_service_)); + std::unique_ptr pin_object_ids_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestPinObjectIDs, + service_handler_, &NodeManagerServiceHandler::HandlePinObjectIDsRequest, cq, + main_service_)); + std::unique_ptr node_stats_call_factory( new ServerCallFactoryImpl( @@ -95,6 +105,8 @@ class NodeManagerGrpcService : public GrpcService { std::move(release_worker_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( std::move(forward_task_call_factory), 100); + server_call_factories_and_concurrencies->emplace_back( + 
std::move(pin_object_ids_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( std::move(node_stats_call_factory), 1); } diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index 7fd0f8f69..d52b10584 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -121,6 +121,13 @@ class CoreWorkerClientInterface { return Status::NotImplemented(""); } + /// Notify the owner of an object that the object has been pinned. + virtual ray::Status WaitForObjectEviction( + const WaitForObjectEvictionRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + /// Tell this actor to exit immediately. virtual ray::Status KillActor(const KillActorRequest &request, const ClientCallback &callback) { @@ -161,6 +168,8 @@ class CoreWorkerClient : public std::enable_shared_from_this, RPC_CLIENT_METHOD(CoreWorkerService, KillActor, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, WaitForObjectEviction, grpc_client_, override) + RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override) ray::Status PushActorTask(std::unique_ptr request,