From fb34928a2a477d97fc61f61302589695d307b22d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 1 Nov 2019 14:41:14 -0700 Subject: [PATCH] [minor] Perf optimizations for direct actor task submission (#6044) * merge optimizations * fix * fix memory err * optimize * fix tests * fix serialization of method handles * document weakref * fix check * bazel format * disable on 2 --- BUILD.bazel | 3 + python/ray/_raylet.pyx | 40 ++++++--- python/ray/actor.py | 89 +++++++++++-------- python/ray/ray_perf.py | 10 +-- python/ray/tests/test_actor.py | 5 ++ python/ray/worker.py | 4 +- src/ray/common/ray_object.h | 5 -- src/ray/core_worker/core_worker.cc | 20 ++--- src/ray/core_worker/core_worker.h | 3 +- .../memory_store/memory_store.cc | 16 ++-- .../memory_store/memory_store.h | 6 +- .../store_provider/memory_store_provider.cc | 13 ++- .../store_provider/memory_store_provider.h | 12 +-- .../store_provider/plasma_store_provider.cc | 21 +++-- .../store_provider/plasma_store_provider.h | 18 ++-- src/ray/core_worker/test/core_worker_test.cc | 10 ++- 16 files changed, 157 insertions(+), 118 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index e87a25fe1..f39dfadef 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -371,6 +371,7 @@ cc_library( ]), copts = COPTS, deps = [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ":core_worker_cc_proto", ":ray_common", @@ -413,6 +414,8 @@ cc_library( deps = [ ":core_worker_lib", ":gcs", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest_main", ], ) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 83d185282..baed93a85 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -683,21 +683,38 @@ cdef void push_objects_into_return_vector( c_vector[shared_ptr[CRayObject]] *returns): cdef: + c_string metadata_str = RAW_BUFFER_METADATA + c_string raw_data_str shared_ptr[CBuffer] data shared_ptr[CBuffer] metadata shared_ptr[CRayObject] ray_object int64_t data_size for serialized_object in py_objects: - data_size = serialized_object.total_bytes - data = dynamic_pointer_cast[ - CBuffer, LocalMemoryBuffer]( - make_shared[LocalMemoryBuffer](data_size)) - stream = pyarrow.FixedSizeBufferWriter( - pyarrow.py_buffer(Buffer.make(data))) - serialized_object.write_to(stream) - ray_object = make_shared[CRayObject](data, metadata) - returns.push_back(ray_object) + if isinstance(serialized_object, bytes): + data_size = len(serialized_object) + raw_data_str = serialized_object + data = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (raw_data_str.data()), raw_data_str.size())) + metadata = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (metadata_str.data()), metadata_str.size())) + ray_object = make_shared[CRayObject](data, metadata, True) + returns.push_back(ray_object) + else: + data_size = serialized_object.total_bytes + data = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer](data_size)) + metadata.reset() + stream = pyarrow.FixedSizeBufferWriter( + pyarrow.py_buffer(Buffer.make(data))) + serialized_object.write_to(stream) + ray_object = make_shared[CRayObject](data, metadata) + returns.push_back(ray_object) cdef class CoreWorker: @@ -981,7 +998,7 @@ cdef class CoreWorker: function_descriptor, args, int num_return_vals, - resources): + double num_method_cpus): cdef: CActorID c_actor_id = actor_id.native() @@ -992,7 +1009,8 @@ cdef class CoreWorker: c_vector[CObjectID] return_ids with self.profile_event(b"submit_task"): - prepare_resources(resources, &c_resources) + if num_method_cpus > 0: + c_resources[b"CPU"] = num_method_cpus task_options = CTaskOptions(num_return_vals, c_resources) ray_function = CRayFunction( LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) diff --git a/python/ray/actor.py b/python/ray/actor.py index 9f487d281..6accf35c0 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -7,6 +7,7 @@ import inspect import logging import six import sys +import weakref from abc import ABCMeta, abstractmethod from collections import namedtuple @@ -57,9 +58,8 @@ def method(*args, **kwargs): class ActorMethod(object): """A class used to invoke an actor method. - Note: This class is instantiated only while the actor method is being - invoked (so that it doesn't keep a reference to the actor handle and - prevent it from going out of scope). + Note: This class only keeps a weak ref to the actor, unless it has been + passed to a remote function. This avoids delays in GC of the actor. Attributes: _actor: A handle to the actor. @@ -75,8 +75,13 @@ class ActorMethod(object): "test_decorated_method" in "python/ray/tests/test_actor.py". """ - def __init__(self, actor, method_name, num_return_vals, decorator=None): - self._actor = actor + def __init__(self, + actor, + method_name, + num_return_vals, + decorator=None, + hardref=False): + self._actor_ref = weakref.ref(actor) self._method_name = method_name self._num_return_vals = num_return_vals # This is a decorator that is used to wrap the function invocation (as @@ -86,6 +91,11 @@ class ActorMethod(object): # and return the resulting ObjectIDs. self._decorator = decorator + # Acquire a hard ref to the actor, this is useful mainly when passing + # actor method handles to remote functions. + if hardref: + self._actor_hard_ref = actor + def __call__(self, *args, **kwargs): raise Exception("Actor methods cannot be called directly. Instead " "of running 'object.{}()', try " @@ -96,15 +106,14 @@ class ActorMethod(object): return self._remote(args, kwargs) def _remote(self, args=None, kwargs=None, num_return_vals=None): - if args is None: - args = [] - if kwargs is None: - kwargs = {} if num_return_vals is None: num_return_vals = self._num_return_vals def invocation(args, kwargs): - return self._actor._actor_method_call( + actor = self._actor_ref() + if actor is None: + raise RuntimeError("Lost reference to actor") + return actor._actor_method_call( self._method_name, args=args, kwargs=kwargs, @@ -116,6 +125,22 @@ class ActorMethod(object): return invocation(args, kwargs) + def __getstate__(self): + return { + "actor": self._actor_ref(), + "method_name": self._method_name, + "num_return_vals": self._num_return_vals, + "decorator": self._decorator, + } + + def __setstate__(self, state): + self.__init__( + state["actor"], + state["method_name"], + state["num_return_vals"], + state["decorator"], + hardref=True) + class ActorClassMetadata(object): """Metadata for an actor class. @@ -502,6 +527,14 @@ class ActorHandle(object): for method_name in self._ray_method_signatures.keys() } + for method_name in actor_method_names: + method = ActorMethod( + self, + method_name, + self._ray_method_num_return_vals[method_name], + decorator=self._ray_method_decorators.get(method_name)) + setattr(self, method_name, method) + def _actor_method_call(self, method_name, args=None, @@ -526,13 +559,15 @@ class ActorHandle(object): """ worker = ray.worker.get_global_worker() - worker.check_connected() - - function_signature = self._ray_method_signatures[method_name] args = args or [] kwargs = kwargs or {} + function_signature = self._ray_method_signatures[method_name] - list_args = signature.flatten_args(function_signature, args, kwargs) + if not args and not kwargs and not function_signature: + list_args = [] + else: + list_args = signature.flatten_args(function_signature, args, + kwargs) if worker.mode == ray.LOCAL_MODE: function = getattr(worker.actors[self._actor_id], method_name) object_ids = worker.local_mode_manager.execute( @@ -541,7 +576,7 @@ class ActorHandle(object): object_ids = worker.core_worker.submit_actor_task( self._ray_actor_id, self._ray_function_descriptor_lists[method_name], list_args, - num_return_vals, {"CPU": self._ray_actor_method_cpus}) + num_return_vals, self._ray_actor_method_cpus) if len(object_ids) == 1: object_ids = object_ids[0] @@ -554,30 +589,6 @@ class ActorHandle(object): def __dir__(self): return self._ray_actor_method_names - def __getattribute__(self, attr): - try: - # Check whether this is an actor method. - actor_method_names = object.__getattribute__( - self, "_ray_actor_method_names") - if attr in actor_method_names: - # We create the ActorMethod on the fly here so that the - # ActorHandle doesn't need a reference to the ActorMethod. - # The ActorMethod has a reference to the ActorHandle and - # this was causing cyclic references which were prevent - # object deallocation from behaving in a predictable - # manner. - return ActorMethod( - self, - attr, - self._ray_method_num_return_vals[attr], - decorator=self._ray_method_decorators.get(attr)) - except AttributeError: - pass - - # If the requested attribute is not a registered method, fall back - # to default __getattribute__. - return object.__getattribute__(self, attr) - def __repr__(self): return "Actor({}, {})".format(self._ray_class_name, self._actor_id.hex()) diff --git a/python/ray/ray_perf.py b/python/ray/ray_perf.py index f3ec0d85b..1bdaefc5a 100644 --- a/python/ray/ray_perf.py +++ b/python/ray/ray_perf.py @@ -10,19 +10,19 @@ import ray filter_pattern = os.environ.get("TESTS_TO_RUN", "") -@ray.remote +@ray.remote(num_cpus=0) class Actor(object): def small_value(self): - return 0 + return b"ok" def small_value_arg(self, x): - return 0 + return b"ok" def small_value_batch(self, n): ray.get([small_value.remote() for _ in range(n)]) -@ray.remote +@ray.remote(num_cpus=0) class Client(object): def __init__(self, servers): if not isinstance(servers, list): @@ -45,7 +45,7 @@ class Client(object): @ray.remote def small_value(): - return 0 + return b"ok" @ray.remote diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index e20a6a79f..daf9833a6 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -494,11 +494,16 @@ def test_actor_deletion(ray_start_regular): actors = None [ray.tests.utils.wait_for_pid_to_exit(pid) for pid in pids] + +@pytest.mark.skipif( + sys.version_info < (3, 0), reason="This test requires Python 3.") +def test_actor_method_deletion(ray_start_regular): @ray.remote class Actor(object): def method(self): return 1 + # TODO(ekl) this doesn't work in Python 2 after the weak ref method change. # Make sure that if we create an actor and call a method on it # immediately, the actor doesn't get killed before the method is # called. diff --git a/python/ray/worker.py b/python/ray/worker.py index 00c32906d..f96e8e3ba 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -292,8 +292,8 @@ class Worker(object): if isinstance(value, bytes): if return_buffer is not None: - raise NotImplementedError( - "returning raw buffers from direct actor calls") + return_buffer.append(value) + return # If the object is a byte array, skip serializing it and # use a special metadata to indicate it's raw binary. So # that this object can also be read by Java. diff --git a/src/ray/common/ray_object.h b/src/ray/common/ray_object.h index 163a338d1..04ff0e9e3 100644 --- a/src/ray/common/ray_object.h +++ b/src/ray/common/ray_object.h @@ -25,11 +25,6 @@ class RayObject { RayObject(const std::shared_ptr &data, const std::shared_ptr &metadata, bool copy_data = false) : data_(data), metadata_(metadata), has_data_copy_(copy_data) { - RAY_CHECK(!data || data_->Size()) - << "Zero-length buffers are not allowed when constructing a RayObject."; - RAY_CHECK(!metadata || metadata->Size()) - << "Zero-length buffers are not allowed when constructing a RayObject."; - if (has_data_copy_) { // If this object is required to hold a copy of the data, // make a copy if the passed in buffers don't already have a copy. diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 297d4c70b..8930865ec 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -39,8 +39,8 @@ void BuildCommonTaskSpec( // Group object ids according the the corresponding store providers. void GroupObjectIdsByStoreProvider(const std::vector &object_ids, - std::unordered_set *plasma_object_ids, - std::unordered_set *memory_object_ids) { + absl::flat_hash_set *plasma_object_ids, + absl::flat_hash_set *memory_object_ids) { // There are two cases: // - for task return objects from direct actor call, use memory store provider; // - all the others use plasma store provider. @@ -312,12 +312,12 @@ Status CoreWorker::Get(const std::vector &ids, int64_t timeout_ms, std::vector> *results) { results->resize(ids.size(), nullptr); - std::unordered_set plasma_object_ids; - std::unordered_set memory_object_ids; + absl::flat_hash_set plasma_object_ids; + absl::flat_hash_set memory_object_ids; GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids); bool got_exception = false; - std::unordered_map> result_map; + absl::flat_hash_map> result_map; auto start_time = current_time_ms(); RAY_RETURN_NOT_OK(plasma_store_provider_->Get(plasma_object_ids, timeout_ms, worker_context_.GetCurrentTaskID(), @@ -360,8 +360,8 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, "Number of objects to wait for must be between 1 and the number of ids."); } - std::unordered_set plasma_object_ids; - std::unordered_set memory_object_ids; + absl::flat_hash_set plasma_object_ids; + absl::flat_hash_set memory_object_ids; GroupObjectIdsByStoreProvider(ids, &plasma_object_ids, &memory_object_ids); if (plasma_object_ids.size() + memory_object_ids.size() != ids.size()) { @@ -377,7 +377,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, // a timeout of 0, but that does not address the situation where objects // become available on the second store provider while waiting on the first. - std::unordered_set ready; + absl::flat_hash_set ready; // Wait from both store providers with timeout set to 0. This is to avoid the case // where we might use up the entire timeout on trying to get objects from one store // provider before even trying another (which might have all of the objects available). @@ -421,8 +421,8 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, Status CoreWorker::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - std::unordered_set plasma_object_ids; - std::unordered_set memory_object_ids; + absl::flat_hash_set plasma_object_ids; + absl::flat_hash_set memory_object_ids; GroupObjectIdsByStoreProvider(object_ids, &plasma_object_ids, &memory_object_ids); RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index bf78f8e38..c7fd6a959 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -2,6 +2,7 @@ #define RAY_CORE_WORKER_CORE_WORKER_H #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/synchronization/mutex.h" @@ -421,7 +422,7 @@ class CoreWorker { std::unique_ptr direct_actor_submitter_; /// Map from actor ID to a handle to that actor. - std::unordered_map> actor_handles_; + absl::flat_hash_map> actor_handles_; /* Fields related to task execution. */ diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 0b645be75..ca21b81ec 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -9,10 +9,10 @@ namespace ray { /// A class that represents a `Get` request. class GetRequest { public: - GetRequest(std::unordered_set object_ids, size_t num_objects, + GetRequest(absl::flat_hash_set object_ids, size_t num_objects, bool remove_after_get); - const std::unordered_set &ObjectIds() const; + const absl::flat_hash_set &ObjectIds() const; /// Wait until all requested objects are available, or timeout happens. /// @@ -31,9 +31,9 @@ class GetRequest { void Wait(); /// The object IDs involved in this request. - const std::unordered_set object_ids_; + const absl::flat_hash_set object_ids_; /// The object information for the objects in this request. - std::unordered_map> objects_; + absl::flat_hash_map> objects_; /// Number of objects required. const size_t num_objects_; @@ -46,7 +46,7 @@ class GetRequest { std::condition_variable cv_; }; -GetRequest::GetRequest(std::unordered_set object_ids, size_t num_objects, +GetRequest::GetRequest(absl::flat_hash_set object_ids, size_t num_objects, bool remove_after_get) : object_ids_(std::move(object_ids)), num_objects_(num_objects), @@ -55,7 +55,7 @@ GetRequest::GetRequest(std::unordered_set object_ids, size_t num_objec RAY_CHECK(num_objects_ <= object_ids_.size()); } -const std::unordered_set &GetRequest::ObjectIds() const { return object_ids_; } +const absl::flat_hash_set &GetRequest::ObjectIds() const { return object_ids_; } bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; } @@ -144,8 +144,8 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, std::shared_ptr get_request; { - std::unordered_set remaining_ids; - std::unordered_set ids_to_remove; + absl::flat_hash_set remaining_ids; + absl::flat_hash_set ids_to_remove; std::unique_lock lock(lock_); // Check for existing objects and see if this get request can be fullfilled. diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index 3c4905db0..6cad239b9 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -1,6 +1,8 @@ #ifndef RAY_CORE_WORKER_MEMORY_STORE_H #define RAY_CORE_WORKER_MEMORY_STORE_H +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" @@ -45,10 +47,10 @@ class CoreWorkerMemoryStore { private: /// Map from object ID to `RayObject`. - std::unordered_map> objects_; + absl::flat_hash_map> objects_; /// Map from object ID to its get requests. - std::unordered_map>> + absl::flat_hash_map>> object_get_requests_; /// Protect the two maps above. diff --git a/src/ray/core_worker/store_provider/memory_store_provider.cc b/src/ray/core_worker/store_provider/memory_store_provider.cc index 17290a28a..549c7e317 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.cc +++ b/src/ray/core_worker/store_provider/memory_store_provider.cc @@ -23,9 +23,9 @@ Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object, } Status CoreWorkerMemoryStoreProvider::Get( - const std::unordered_set &object_ids, int64_t timeout_ms, + const absl::flat_hash_set &object_ids, int64_t timeout_ms, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception) { const std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; @@ -43,10 +43,9 @@ Status CoreWorkerMemoryStoreProvider::Get( return Status::OK(); } -Status CoreWorkerMemoryStoreProvider::Wait(const std::unordered_set &object_ids, - int num_objects, int64_t timeout_ms, - const TaskID &task_id, - std::unordered_set *ready) { +Status CoreWorkerMemoryStoreProvider::Wait( + const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, + const TaskID &task_id, absl::flat_hash_set *ready) { std::vector id_vector(object_ids.begin(), object_ids.end()); std::vector> result_objects; RAY_CHECK(object_ids.size() == id_vector.size()); @@ -63,7 +62,7 @@ Status CoreWorkerMemoryStoreProvider::Wait(const std::unordered_set &o } Status CoreWorkerMemoryStoreProvider::Delete( - const std::unordered_set &object_ids) { + const absl::flat_hash_set &object_ids) { std::vector object_id_vector(object_ids.begin(), object_ids.end()); store_->Delete(object_id_vector); return Status::OK(); diff --git a/src/ray/core_worker/store_provider/memory_store_provider.h b/src/ray/core_worker/store_provider/memory_store_provider.h index 68050472d..32ee88509 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.h +++ b/src/ray/core_worker/store_provider/memory_store_provider.h @@ -1,6 +1,8 @@ #ifndef RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H #define RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/status.h" @@ -21,18 +23,18 @@ class CoreWorkerMemoryStoreProvider { Status Put(const RayObject &object, const ObjectID &object_id); - Status Get(const std::unordered_set &object_ids, int64_t timeout_ms, + Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception); /// Note that `num_objects` must equal to number of items in `object_ids`. - Status Wait(const std::unordered_set &object_ids, int num_objects, + Status Wait(const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, const TaskID &task_id, - std::unordered_set *ready); + absl::flat_hash_set *ready); /// Note that `local_only` must be true, and `delete_creating_tasks` must be false here. - Status Delete(const std::unordered_set &object_ids); + Status Delete(const absl::flat_hash_set &object_ids); private: /// Implementation. diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index 0a3d8c3bc..2c6c9f96f 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -81,9 +81,9 @@ Status CoreWorkerPlasmaStoreProvider::Seal(const ObjectID &object_id) { } Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( - std::unordered_set &remaining, const std::vector &batch_ids, + absl::flat_hash_set &remaining, const std::vector &batch_ids, int64_t timeout_ms, bool fetch_only, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception) { RAY_RETURN_NOT_OK(raylet_client_->FetchOrReconstruct(batch_ids, fetch_only, task_id)); @@ -125,13 +125,13 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( } Status CoreWorkerPlasmaStoreProvider::Get( - const std::unordered_set &object_ids, int64_t timeout_ms, + const absl::flat_hash_set &object_ids, int64_t timeout_ms, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception) { int64_t batch_size = RayConfig::instance().worker_fetch_request_size(); std::vector batch_ids; - std::unordered_set remaining(object_ids.begin(), object_ids.end()); + absl::flat_hash_set remaining(object_ids.begin(), object_ids.end()); // First, attempt to fetch all of the required objects once without reconstructing. std::vector id_vector(object_ids.begin(), object_ids.end()); @@ -206,10 +206,9 @@ Status CoreWorkerPlasmaStoreProvider::Contains(const ObjectID &object_id, return Status::OK(); } -Status CoreWorkerPlasmaStoreProvider::Wait(const std::unordered_set &object_ids, - int num_objects, int64_t timeout_ms, - const TaskID &task_id, - std::unordered_set *ready) { +Status CoreWorkerPlasmaStoreProvider::Wait( + const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, + const TaskID &task_id, absl::flat_hash_set *ready) { std::vector id_vector(object_ids.begin(), object_ids.end()); bool should_break = false; @@ -240,7 +239,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(const std::unordered_set &o } Status CoreWorkerPlasmaStoreProvider::Delete( - const std::unordered_set &object_ids, bool local_only, + const absl::flat_hash_set &object_ids, bool local_only, bool delete_creating_tasks) { std::vector object_id_vector(object_ids.begin(), object_ids.end()); return raylet_client_->FreeObjects(object_id_vector, local_only, delete_creating_tasks); @@ -252,7 +251,7 @@ std::string CoreWorkerPlasmaStoreProvider::MemoryUsageString() { } void CoreWorkerPlasmaStoreProvider::WarnIfAttemptedTooManyTimes( - int num_attempts, const std::unordered_set &remaining) { + int num_attempts, const absl::flat_hash_set &remaining) { if (num_attempts % RayConfig::instance().object_store_get_warn_per_num_attempts() == 0) { std::ostringstream oss; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index d6c2de84d..a65f803a1 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -1,6 +1,8 @@ #ifndef RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H #define RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "plasma/client.h" #include "ray/common/buffer.h" #include "ray/common/id.h" @@ -33,18 +35,18 @@ class CoreWorkerPlasmaStoreProvider { Status Seal(const ObjectID &object_id); - Status Get(const std::unordered_set &object_ids, int64_t timeout_ms, + Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception); Status Contains(const ObjectID &object_id, bool *has_object); - Status Wait(const std::unordered_set &object_ids, int num_objects, + Status Wait(const absl::flat_hash_set &object_ids, int num_objects, int64_t timeout_ms, const TaskID &task_id, - std::unordered_set *ready); + absl::flat_hash_set *ready); - Status Delete(const std::unordered_set &object_ids, bool local_only, + Status Delete(const absl::flat_hash_set &object_ids, bool local_only, bool delete_creating_tasks); std::string MemoryUsageString(); @@ -67,9 +69,9 @@ class CoreWorkerPlasmaStoreProvider { /// exception. /// \return Status. Status FetchAndGetFromPlasmaStore( - std::unordered_set &remaining, const std::vector &batch_ids, + absl::flat_hash_set &remaining, const std::vector &batch_ids, int64_t timeout_ms, bool fetch_only, const TaskID &task_id, - std::unordered_map> *results, + absl::flat_hash_map> *results, bool *got_exception); /// Print a warning if we've attempted too many times, but some objects are still @@ -78,7 +80,7 @@ class CoreWorkerPlasmaStoreProvider { /// \param[in] num_attemps The number of attempted times. /// \param[in] remaining The remaining objects. static void WarnIfAttemptedTooManyTimes(int num_attempts, - const std::unordered_set &remaining); + const absl::flat_hash_set &remaining); const std::unique_ptr &raylet_client_; plasma::PlasmaClient store_client_; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 2c1a5edb0..246c933c8 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -2,6 +2,8 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "ray/common/buffer.h" #include "ray/common/ray_object.h" #include "ray/core_worker/context.h" @@ -644,8 +646,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { RAY_CHECK_OK(provider.Put(buffers[i], ids[i])); } - std::unordered_set wait_ids(ids.begin(), ids.end()); - std::unordered_set wait_results; + absl::flat_hash_set wait_ids(ids.begin(), ids.end()); + absl::flat_hash_set wait_results; ObjectID nonexistent_id = ObjectID::FromRandom(); wait_ids.insert(nonexistent_id); @@ -662,8 +664,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { // Test Get(). bool got_exception = false; - std::unordered_map> results; - std::unordered_set ids_set(ids.begin(), ids.end()); + absl::flat_hash_map> results; + absl::flat_hash_set ids_set(ids.begin(), ids.end()); RAY_CHECK_OK(provider.Get(ids_set, -1, RandomTaskId(), &results, &got_exception)); ASSERT_TRUE(!got_exception);