mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:48:31 +08:00
[direct call] In memory store (#5303)
This commit is contained in:
committed by
Philipp Moritz
parent
25b5bd1530
commit
134c6bd128
@@ -387,6 +387,7 @@ cc_library(
|
||||
[
|
||||
"src/ray/core_worker/*.cc",
|
||||
"src/ray/core_worker/store_provider/*.cc",
|
||||
"src/ray/core_worker/store_provider/memory_store/*.cc",
|
||||
"src/ray/core_worker/transport/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
@@ -397,6 +398,7 @@ cc_library(
|
||||
hdrs = glob([
|
||||
"src/ray/core_worker/*.h",
|
||||
"src/ray/core_worker/store_provider/*.h",
|
||||
"src/ray/core_worker/store_provider/memory_store/*.h",
|
||||
"src/ray/core_worker/transport/*.h",
|
||||
]),
|
||||
copts = COPTS,
|
||||
|
||||
+26
-3
@@ -20,6 +20,9 @@ class Buffer {
|
||||
/// Size of this buffer.
|
||||
virtual size_t Size() const = 0;
|
||||
|
||||
/// Whether this buffer owns the data.
|
||||
virtual bool OwnsData() const = 0;
|
||||
|
||||
virtual ~Buffer(){};
|
||||
|
||||
bool operator==(const Buffer &rhs) const {
|
||||
@@ -34,12 +37,21 @@ class Buffer {
|
||||
/// Represents a byte buffer in local memory.
|
||||
class LocalMemoryBuffer : public Buffer {
|
||||
public:
|
||||
LocalMemoryBuffer(uint8_t *data, size_t size, bool should_copy = false)
|
||||
: data_(data), size_(size) {
|
||||
if (should_copy) {
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param data The data pointer to the passed-in buffer.
|
||||
/// \param size The size of the passed in buffer.
|
||||
/// \param copy_data If true, data will be copied and owned by this buffer,
|
||||
/// otherwise the buffer only points to the given address.
|
||||
LocalMemoryBuffer(uint8_t *data, size_t size, bool copy_data = false)
|
||||
: has_data_copy_(copy_data) {
|
||||
if (copy_data) {
|
||||
buffer_.insert(buffer_.end(), data, data + size);
|
||||
data_ = buffer_.data();
|
||||
size_ = buffer_.size();
|
||||
} else {
|
||||
data_ = data;
|
||||
size_ = size;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,13 +59,22 @@ class LocalMemoryBuffer : public Buffer {
|
||||
|
||||
size_t Size() const override { return size_; }
|
||||
|
||||
bool OwnsData() const override { return has_data_copy_; }
|
||||
|
||||
~LocalMemoryBuffer() {}
|
||||
|
||||
private:
|
||||
/// Disable copy constructor and assignment, as default copy will
|
||||
/// cause invalid data_.
|
||||
LocalMemoryBuffer &operator=(const LocalMemoryBuffer &) = delete;
|
||||
LocalMemoryBuffer(const LocalMemoryBuffer &) = delete;
|
||||
|
||||
/// Pointer to the data.
|
||||
uint8_t *data_;
|
||||
/// Size of the buffer.
|
||||
size_t size_;
|
||||
/// Whether this buffer holds a copy of data.
|
||||
bool has_data_copy_;
|
||||
/// This is only valid when `should_copy` is true.
|
||||
std::vector<uint8_t> buffer_;
|
||||
};
|
||||
@@ -68,6 +89,8 @@ class PlasmaBuffer : public Buffer {
|
||||
|
||||
size_t Size() const override { return buffer_->size(); }
|
||||
|
||||
bool OwnsData() const override { return true; }
|
||||
|
||||
private:
|
||||
/// shared_ptr to arrow buffer which can potentially hold a reference
|
||||
/// for the object (when it's a plasma::PlasmaBuffer).
|
||||
|
||||
@@ -74,7 +74,7 @@ struct TaskInfo {
|
||||
const TaskType task_type;
|
||||
};
|
||||
|
||||
enum class StoreProviderType { LOCAL_PLASMA, PLASMA };
|
||||
enum class StoreProviderType { LOCAL_PLASMA, PLASMA, MEMORY };
|
||||
|
||||
enum class TaskTransportType { RAYLET, DIRECT_ACTOR };
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ std::unique_ptr<CoreWorkerStoreProvider> CoreWorkerObjectInterface::CreateStoreP
|
||||
RAY_LOG(FATAL) << "unknown store provider type " << static_cast<int>(type);
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -76,7 +76,7 @@ Status CoreWorkerLocalPlasmaStoreProvider::Wait(const std::vector<ObjectID> &obj
|
||||
|
||||
(*results).resize(object_ids.size());
|
||||
for (size_t i = 0; i < object_ids.size(); i++) {
|
||||
(*results)[i] = objects[i]->GetData() != nullptr;
|
||||
(*results)[i] = (objects[i] != nullptr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
#include <condition_variable>
|
||||
#include "ray/common/ray_config.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// A class that represents a `Get` request.
|
||||
class GetRequest {
|
||||
public:
|
||||
GetRequest(std::unordered_set<ObjectID> object_ids, bool remove_after_get);
|
||||
|
||||
const std::unordered_set<ObjectID> &ObjectIds() const;
|
||||
|
||||
/// Wait until all requested objects are available, or timeout happens.
|
||||
///
|
||||
/// \param timeout_ms The maximum time in milliseconds to wait for.
|
||||
/// \return Whether all requested objects are available.
|
||||
bool Wait(int64_t timeout_ms);
|
||||
/// Set the object content for the specific object id.
|
||||
void Set(const ObjectID &object_id, std::shared_ptr<RayObject> buffer);
|
||||
/// Get the object content for the specific object id.
|
||||
std::shared_ptr<RayObject> Get(const ObjectID &object_id) const;
|
||||
/// Whether this is a `get` request.
|
||||
bool ShouldRemoveObjects() const;
|
||||
|
||||
private:
|
||||
/// Wait until all requested objects are available.
|
||||
void Wait();
|
||||
|
||||
/// The object IDs involved in this request.
|
||||
std::unordered_set<ObjectID> object_ids_;
|
||||
/// The object information for the objects in this request.
|
||||
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> objects_;
|
||||
|
||||
// Whether the requested objects should be removed from store
|
||||
// after `get` returns.
|
||||
const bool remove_after_get_;
|
||||
// Whether all the requested objects are available.
|
||||
bool is_ready_;
|
||||
mutable std::mutex mutex_;
|
||||
std::condition_variable cv_;
|
||||
};
|
||||
|
||||
GetRequest::GetRequest(std::unordered_set<ObjectID> object_ids, bool remove_after_get)
|
||||
: object_ids_(std::move(object_ids)), remove_after_get_(remove_after_get) {}
|
||||
|
||||
const std::unordered_set<ObjectID> &GetRequest::ObjectIds() const { return object_ids_; }
|
||||
|
||||
bool GetRequest::ShouldRemoveObjects() const { return remove_after_get_; }
|
||||
|
||||
bool GetRequest::Wait(int64_t timeout_ms) {
|
||||
RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1);
|
||||
if (timeout_ms == -1) {
|
||||
// Wait forever until all objects are ready.
|
||||
Wait();
|
||||
return true;
|
||||
}
|
||||
|
||||
// Wait until all objects are ready, or the timeout expires.
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (!is_ready_) {
|
||||
auto status = cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms));
|
||||
if (status == std::cv_status::timeout) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GetRequest::Wait() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (!is_ready_) {
|
||||
cv_.wait(lock);
|
||||
}
|
||||
}
|
||||
|
||||
void GetRequest::Set(const ObjectID &object_id, std::shared_ptr<RayObject> object) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
objects_.emplace(object_id, object);
|
||||
if (objects_.size() == object_ids_.size()) {
|
||||
is_ready_ = true;
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CoreWorkerMemoryStore::CoreWorkerMemoryStore() {}
|
||||
|
||||
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
return Status::KeyError("object already exists");
|
||||
}
|
||||
|
||||
auto object_entry =
|
||||
std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), true);
|
||||
|
||||
bool should_add_entry = true;
|
||||
auto object_request_iter = object_get_requests_.find(object_id);
|
||||
if (object_request_iter != object_get_requests_.end()) {
|
||||
auto &get_requests = object_request_iter->second;
|
||||
for (auto &get_request : get_requests) {
|
||||
get_request->Set(object_id, object_entry);
|
||||
if (get_request->ShouldRemoveObjects()) {
|
||||
should_add_entry = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (should_add_entry) {
|
||||
// If there is no existing get request, then add the `RayObject` to map.
|
||||
objects_.emplace(object_id, object_entry);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
int64_t timeout_ms, bool remove_after_get,
|
||||
std::vector<std::shared_ptr<RayObject>> *results) {
|
||||
(*results).resize(object_ids.size(), nullptr);
|
||||
|
||||
std::shared_ptr<GetRequest> get_request;
|
||||
|
||||
{
|
||||
std::unordered_set<ObjectID> remaining_ids;
|
||||
std::unordered_set<ObjectID> ids_to_remove;
|
||||
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
// Check for existing objects and see if this get request can be fullfilled.
|
||||
for (int i = 0; i < object_ids.size(); i++) {
|
||||
const auto &object_id = object_ids[i];
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
(*results)[i] = iter->second;
|
||||
if (remove_after_get) {
|
||||
// Note that we cannot remove the object_id from `objects_` now,
|
||||
// because `object_ids` might have duplicate ids.
|
||||
ids_to_remove.insert(object_id);
|
||||
}
|
||||
} else {
|
||||
remaining_ids.insert(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &object_id : ids_to_remove) {
|
||||
objects_.erase(object_id);
|
||||
}
|
||||
|
||||
// Return if all the objects are obtained.
|
||||
if (remaining_ids.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Otherwise, create a GetRequest to track remaining objects.
|
||||
get_request =
|
||||
std::make_shared<GetRequest>(std::move(remaining_ids), remove_after_get);
|
||||
for (const auto &object_id : get_request->ObjectIds()) {
|
||||
object_get_requests_[object_id].push_back(get_request);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for remaining objects (or timeout).
|
||||
get_request->Wait(timeout_ms);
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
// Populate results.
|
||||
for (int i = 0; i < object_ids.size(); i++) {
|
||||
const auto &object_id = object_ids[i];
|
||||
if ((*results)[i] == nullptr) {
|
||||
(*results)[i] = get_request->Get(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove get request.
|
||||
for (const auto &object_id : get_request->ObjectIds()) {
|
||||
auto object_request_iter = object_get_requests_.find(object_id);
|
||||
if (object_request_iter != object_get_requests_.end()) {
|
||||
auto &get_requests = object_request_iter->second;
|
||||
// Erase get_request from the vector.
|
||||
auto it = std::find(get_requests.begin(), get_requests.end(), get_request);
|
||||
if (it != get_requests.end()) {
|
||||
get_requests.erase(it);
|
||||
// If the vector is empty, remove the object ID from the map.
|
||||
if (get_requests.empty()) {
|
||||
object_get_requests_.erase(object_request_iter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
for (const auto &object_id : object_ids) {
|
||||
objects_.erase(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,60 @@
|
||||
#ifndef RAY_CORE_WORKER_MEMORY_STORE_H
|
||||
#define RAY_CORE_WORKER_MEMORY_STORE_H
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class GetRequest;
|
||||
class CoreWorkerMemoryStore;
|
||||
|
||||
/// The class provides implementations for local process memory store.
|
||||
/// An example usage for this is to retrieve the returned objects from direct
|
||||
/// actor call (see direct_actor_transport.cc).
|
||||
class CoreWorkerMemoryStore {
|
||||
public:
|
||||
CoreWorkerMemoryStore();
|
||||
~CoreWorkerMemoryStore(){};
|
||||
|
||||
/// Put an object with specified ID into object store.
|
||||
///
|
||||
/// \param[in] object_id Object ID specified by user.
|
||||
/// \param[in] object The ray object.
|
||||
/// \return Status.
|
||||
Status Put(const ObjectID &object_id, const RayObject &object);
|
||||
|
||||
/// Get a list of objects from the object store.
|
||||
///
|
||||
/// \param[in] object_ids IDs of the objects to get. Duplicates are allowed.
|
||||
/// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative.
|
||||
/// \param[in] remove_after_get When to remove the objects from store after `Get`
|
||||
/// finishes.
|
||||
/// \param[out] results Result list of objects data.
|
||||
/// \return Status.
|
||||
Status Get(const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results);
|
||||
|
||||
/// Delete a list of objects from the object store.
|
||||
///
|
||||
/// \param[in] object_ids IDs of the objects to delete.
|
||||
/// \return Void.
|
||||
void Delete(const std::vector<ObjectID> &object_ids);
|
||||
|
||||
private:
|
||||
/// Map from object ID to `RayObject`.
|
||||
std::unordered_map<ObjectID, std::shared_ptr<RayObject>> objects_;
|
||||
|
||||
/// Map from object ID to its get requests.
|
||||
std::unordered_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
|
||||
object_get_requests_;
|
||||
|
||||
/// Protect the two maps above.
|
||||
std::mutex lock_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_MEMORY_STORE_H
|
||||
@@ -0,0 +1,59 @@
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include <condition_variable>
|
||||
#include "ray/common/ray_config.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
//
|
||||
// CoreWorkerMemoryStoreProvider functions
|
||||
//
|
||||
CoreWorkerMemoryStoreProvider::CoreWorkerMemoryStoreProvider(
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store)
|
||||
: store_(store) {
|
||||
RAY_CHECK(store != nullptr);
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Put(const RayObject &object,
|
||||
const ObjectID &object_id) {
|
||||
return store_->Put(object_id, object);
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Get(
|
||||
const std::vector<ObjectID> &object_ids, int64_t timeout_ms, const TaskID &task_id,
|
||||
std::vector<std::shared_ptr<RayObject>> *results) {
|
||||
return store_->Get(object_ids, timeout_ms, true, results);
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Wait(const std::vector<ObjectID> &object_ids,
|
||||
int num_objects, int64_t timeout_ms,
|
||||
const TaskID &task_id,
|
||||
std::vector<bool> *results) {
|
||||
if (num_objects != object_ids.size()) {
|
||||
return Status::Invalid("num_objects should equal to number of items in object_ids");
|
||||
}
|
||||
|
||||
(*results).resize(object_ids.size(), false);
|
||||
|
||||
std::vector<std::shared_ptr<RayObject>> result_objects;
|
||||
auto status = store_->Get(object_ids, timeout_ms, false, &result_objects);
|
||||
if (status.ok()) {
|
||||
RAY_CHECK(result_objects.size() == object_ids.size());
|
||||
for (int i = 0; i < object_ids.size(); i++) {
|
||||
(*results)[i] = (result_objects[i] != nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStoreProvider::Delete(const std::vector<ObjectID> &object_ids,
|
||||
bool local_only,
|
||||
bool delete_creating_tasks) {
|
||||
store_->Delete(object_ids);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,47 @@
|
||||
#ifndef RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H
|
||||
#define RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
/// The class provides implementations for accessing local process memory store.
|
||||
/// An example usage for this is to retrieve the returned objects from direct
|
||||
/// actor call (see direct_actor_transport.cc).
|
||||
class CoreWorkerMemoryStoreProvider : public CoreWorkerStoreProvider {
|
||||
public:
|
||||
CoreWorkerMemoryStoreProvider(std::shared_ptr<CoreWorkerMemoryStore> store);
|
||||
|
||||
/// See `CoreWorkerStoreProvider::Put` for semantics.
|
||||
Status Put(const RayObject &object, const ObjectID &object_id) override;
|
||||
|
||||
/// See `CoreWorkerStoreProvider::Get` for semantics.
|
||||
Status Get(const std::vector<ObjectID> &ids, int64_t timeout_ms, const TaskID &task_id,
|
||||
std::vector<std::shared_ptr<RayObject>> *results) override;
|
||||
|
||||
/// See `CoreWorkerStoreProvider::Wait` for semantics.
|
||||
/// Note that `num_objects` must equal to number of items in `object_ids`.
|
||||
Status Wait(const std::vector<ObjectID> &object_ids, int num_objects,
|
||||
int64_t timeout_ms, const TaskID &task_id,
|
||||
std::vector<bool> *results) override;
|
||||
|
||||
/// See `CoreWorkerStoreProvider::Delete` for semantics.
|
||||
/// Note that `local_only` must be true, and `delete_creating_tasks` must be false here.
|
||||
Status Delete(const std::vector<ObjectID> &object_ids, bool local_only = true,
|
||||
bool delete_creating_tasks = false) override;
|
||||
|
||||
private:
|
||||
/// Implementation.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> store_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_MEMORY_STORE_PROVIDER_H
|
||||
@@ -15,8 +15,23 @@ class RayObject {
|
||||
///
|
||||
/// \param[in] data Data of the ray object.
|
||||
/// \param[in] metadata Metadata of the ray object.
|
||||
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata)
|
||||
: data_(data), metadata_(metadata) {}
|
||||
/// \param[in] copy_data Whether this class should hold a copy of data.
|
||||
RayObject(const std::shared_ptr<Buffer> &data, const std::shared_ptr<Buffer> &metadata,
|
||||
bool copy_data = false)
|
||||
: data_(data), metadata_(metadata), has_data_copy_(copy_data) {
|
||||
if (has_data_copy_) {
|
||||
// If this object is required to hold a copy of the data,
|
||||
// make a copy if the passed in buffers don't already have a copy.
|
||||
if (data_ && !data_->OwnsData()) {
|
||||
data_ = std::make_shared<LocalMemoryBuffer>(data_->Data(), data_->Size(), true);
|
||||
}
|
||||
|
||||
if (metadata_ && !metadata_->OwnsData()) {
|
||||
metadata_ = std::make_shared<LocalMemoryBuffer>(metadata_->Data(),
|
||||
metadata_->Size(), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the data of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetData() const { return data_; };
|
||||
@@ -24,14 +39,23 @@ class RayObject {
|
||||
/// Return the metadata of the ray object.
|
||||
const std::shared_ptr<Buffer> &GetMetadata() const { return metadata_; };
|
||||
|
||||
uint64_t GetSize() const {
|
||||
uint64_t size = 0;
|
||||
size += (data_ != nullptr) ? data_->Size() : 0;
|
||||
size += (metadata_ != nullptr) ? metadata_->Size() : 0;
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Whether this object has metadata.
|
||||
bool HasMetadata() const { return metadata_ != nullptr && metadata_->Size() > 0; }
|
||||
|
||||
private:
|
||||
/// Data of the ray object.
|
||||
const std::shared_ptr<Buffer> data_;
|
||||
std::shared_ptr<Buffer> data_;
|
||||
/// Metadata of the ray object.
|
||||
const std::shared_ptr<Buffer> metadata_;
|
||||
std::shared_ptr<Buffer> metadata_;
|
||||
/// Whether this class holds a data copy.
|
||||
bool has_data_copy_;
|
||||
};
|
||||
|
||||
/// Provider interface for store access. Store provider should inherit from this class and
|
||||
|
||||
@@ -6,11 +6,14 @@
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
#include "ray/core_worker/store_provider/local_plasma_provider.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
@@ -158,6 +161,9 @@ class CoreWorkerTest : public ::testing::Test {
|
||||
|
||||
void TearDown() {}
|
||||
|
||||
// Test tore provider.
|
||||
void TestStoreProvider(StoreProviderType type);
|
||||
|
||||
// Test normal tasks.
|
||||
void TestNormalTask(const std::unordered_map<std::string, double> &resources);
|
||||
|
||||
@@ -446,6 +452,87 @@ void CoreWorkerTest::TestActorFailure(
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerTest::TestStoreProvider(StoreProviderType type) {
|
||||
std::unique_ptr<CoreWorkerStoreProvider> provider_ptr;
|
||||
std::shared_ptr<CoreWorkerMemoryStore> memory_store;
|
||||
|
||||
switch (type) {
|
||||
case StoreProviderType::LOCAL_PLASMA:
|
||||
provider_ptr = std::unique_ptr<CoreWorkerStoreProvider>(
|
||||
new CoreWorkerLocalPlasmaStoreProvider(raylet_store_socket_names_[0]));
|
||||
break;
|
||||
case StoreProviderType::MEMORY:
|
||||
memory_store = std::make_shared<CoreWorkerMemoryStore>();
|
||||
provider_ptr = std::unique_ptr<CoreWorkerStoreProvider>(
|
||||
new CoreWorkerMemoryStoreProvider(memory_store));
|
||||
break;
|
||||
default:
|
||||
RAY_LOG(FATAL) << "unspported store provider type " << static_cast<int>(type);
|
||||
break;
|
||||
}
|
||||
|
||||
auto &provider = *provider_ptr;
|
||||
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
|
||||
std::vector<RayObject> buffers;
|
||||
buffers.emplace_back(std::make_shared<LocalMemoryBuffer>(array1, sizeof(array1)),
|
||||
std::make_shared<LocalMemoryBuffer>(array1, sizeof(array1) / 2));
|
||||
buffers.emplace_back(std::make_shared<LocalMemoryBuffer>(array2, sizeof(array2)),
|
||||
std::make_shared<LocalMemoryBuffer>(array2, sizeof(array2) / 2));
|
||||
|
||||
std::vector<ObjectID> ids(buffers.size());
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
ids[i] = ObjectID::FromRandom();
|
||||
RAY_CHECK_OK(provider.Put(buffers[i], ids[i]));
|
||||
}
|
||||
|
||||
// Test Wait().
|
||||
std::vector<ObjectID> ids_with_duplicate;
|
||||
ids_with_duplicate.insert(ids_with_duplicate.end(), ids.begin(), ids.end());
|
||||
// add the same ids again to test `Get` with duplicate object ids.
|
||||
ids_with_duplicate.insert(ids_with_duplicate.end(), ids.begin(), ids.end());
|
||||
|
||||
std::vector<ObjectID> wait_ids(ids_with_duplicate);
|
||||
ObjectID non_existent_id = ObjectID::FromRandom();
|
||||
wait_ids.push_back(non_existent_id);
|
||||
|
||||
std::vector<bool> wait_results;
|
||||
RAY_CHECK_OK(provider.Wait(wait_ids, 5, 100, TaskID::FromRandom(), &wait_results));
|
||||
ASSERT_EQ(wait_results.size(), 5);
|
||||
ASSERT_EQ(wait_results, std::vector<bool>({true, true, true, true, false}));
|
||||
|
||||
// Test Get().
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
RAY_CHECK_OK(provider.Get(ids_with_duplicate, -1, TaskID::FromRandom(), &results));
|
||||
|
||||
ASSERT_EQ(results.size(), ids_with_duplicate.size());
|
||||
for (size_t i = 0; i < ids_with_duplicate.size(); i++) {
|
||||
const auto &expected = buffers[i % ids.size()];
|
||||
ASSERT_EQ(results[i]->GetData()->Size(), expected.GetData()->Size());
|
||||
ASSERT_EQ(memcmp(results[i]->GetData()->Data(), expected.GetData()->Data(),
|
||||
expected.GetData()->Size()),
|
||||
0);
|
||||
ASSERT_EQ(results[i]->GetMetadata()->Size(), expected.GetMetadata()->Size());
|
||||
ASSERT_EQ(memcmp(results[i]->GetMetadata()->Data(), expected.GetMetadata()->Data(),
|
||||
expected.GetMetadata()->Size()),
|
||||
0);
|
||||
}
|
||||
|
||||
// Test Delete().
|
||||
// clear the reference held.
|
||||
results.clear();
|
||||
|
||||
RAY_CHECK_OK(provider.Delete(ids, true, false));
|
||||
|
||||
usleep(200 * 1000);
|
||||
RAY_CHECK_OK(provider.Get(ids, 0, TaskID::FromRandom(), &results));
|
||||
ASSERT_EQ(results.size(), 2);
|
||||
ASSERT_TRUE(!results[0]);
|
||||
ASSERT_TRUE(!results[1]);
|
||||
}
|
||||
|
||||
class ZeroNodeTest : public CoreWorkerTest {
|
||||
public:
|
||||
ZeroNodeTest() : CoreWorkerTest(0) {}
|
||||
@@ -637,6 +724,10 @@ TEST_F(ZeroNodeTest, TestActorHandle) {
|
||||
ASSERT_EQ(handle1.NumForks(), handle2.NumForks());
|
||||
}
|
||||
|
||||
TEST_F(SingleNodeTest, TestMemoryStoreProvider) {
|
||||
TestStoreProvider(StoreProviderType::MEMORY);
|
||||
}
|
||||
|
||||
TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
@@ -705,14 +796,13 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
|
||||
std::vector<LocalMemoryBuffer> buffers;
|
||||
buffers.emplace_back(array1, sizeof(array1));
|
||||
buffers.emplace_back(array2, sizeof(array2));
|
||||
std::vector<std::shared_ptr<LocalMemoryBuffer>> buffers;
|
||||
buffers.emplace_back(std::make_shared<LocalMemoryBuffer>(array1, sizeof(array1)));
|
||||
buffers.emplace_back(std::make_shared<LocalMemoryBuffer>(array2, sizeof(array2)));
|
||||
|
||||
std::vector<ObjectID> ids(buffers.size());
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
RAY_CHECK_OK(worker1.Objects().Put(
|
||||
RayObject(std::make_shared<LocalMemoryBuffer>(buffers[i]), nullptr), &ids[i]));
|
||||
RAY_CHECK_OK(worker1.Objects().Put(RayObject(buffers[i], nullptr), &ids[i]));
|
||||
}
|
||||
|
||||
// Test Get() from remote node.
|
||||
@@ -721,8 +811,8 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
|
||||
|
||||
ASSERT_EQ(results.size(), 2);
|
||||
for (size_t i = 0; i < ids.size(); i++) {
|
||||
ASSERT_EQ(results[i]->GetData()->Size(), buffers[i].Size());
|
||||
ASSERT_EQ(*(results[i]->GetData()), buffers[i]);
|
||||
ASSERT_EQ(results[i]->GetData()->Size(), buffers[i]->Size());
|
||||
ASSERT_EQ(*(results[i]->GetData()), *buffers[i]);
|
||||
}
|
||||
|
||||
// Test Wait() from remote node.
|
||||
|
||||
Reference in New Issue
Block a user