mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:48:31 +08:00
[Core Worker] implement ObjectInterface and add test framework (#4899)
This commit is contained in:
@@ -148,6 +148,9 @@ install:
|
||||
- ./ci/suppress_output bazel build //:stats_test -c opt
|
||||
- ./bazel-bin/stats_test
|
||||
|
||||
# core worker test.
|
||||
- ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh
|
||||
|
||||
# Raylet tests.
|
||||
- ./ci/suppress_output bash src/ray/test/run_object_manager_tests.sh
|
||||
- ./ci/suppress_output bazel test --build_tests_only --test_lang_filters=cc //:all
|
||||
|
||||
+6
-1
@@ -77,6 +77,7 @@ cc_library(
|
||||
"src/ray/raylet/mock_gcs_client.cc",
|
||||
"src/ray/raylet/monitor_main.cc",
|
||||
"src/ray/raylet/*_test.cc",
|
||||
"src/ray/raylet/main.cc",
|
||||
],
|
||||
),
|
||||
hdrs = glob([
|
||||
@@ -122,15 +123,18 @@ cc_library(
|
||||
deps = [
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
":raylet_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
# This test is run by src/ray/test/run_core_worker_tests.sh
|
||||
cc_binary(
|
||||
name = "core_worker_test",
|
||||
srcs = ["src/ray/core_worker/core_worker_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":core_worker_lib",
|
||||
":gcs",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
@@ -320,6 +324,7 @@ cc_library(
|
||||
":node_manager_fbs",
|
||||
":ray_util",
|
||||
"@boost//:asio",
|
||||
"@plasma//:plasma_client",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
+21
-1
@@ -3,6 +3,11 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "plasma/client.h"
|
||||
|
||||
namespace arrow {
|
||||
class Buffer;
|
||||
}
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -15,7 +20,7 @@ class Buffer {
|
||||
/// Size of this buffer.
|
||||
virtual size_t Size() const = 0;
|
||||
|
||||
virtual ~Buffer() {}
|
||||
virtual ~Buffer(){};
|
||||
|
||||
bool operator==(const Buffer &rhs) const {
|
||||
return this->Data() == rhs.Data() && this->Size() == rhs.Size();
|
||||
@@ -40,6 +45,21 @@ class LocalMemoryBuffer : public Buffer {
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
/// Represents a byte buffer for plasma object.
|
||||
class PlasmaBuffer : public Buffer {
|
||||
public:
|
||||
PlasmaBuffer(std::shared_ptr<arrow::Buffer> buffer) : buffer_(buffer) {}
|
||||
|
||||
uint8_t *Data() const override { return const_cast<uint8_t *>(buffer_->data()); }
|
||||
|
||||
size_t Size() const override { return buffer_->size(); }
|
||||
|
||||
private:
|
||||
/// shared_ptr to arrow buffer which can potentially hold a reference
|
||||
/// for the object (when it's a plasma::PlasmaBuffer).
|
||||
std::shared_ptr<arrow::Buffer> buffer_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_COMMON_BUFFER_H
|
||||
|
||||
@@ -45,13 +45,13 @@ class TaskArg {
|
||||
bool IsPassedByReference() const { return id_ != nullptr; }
|
||||
|
||||
/// Get the reference object ID.
|
||||
ObjectID &GetReference() {
|
||||
const ObjectID &GetReference() const {
|
||||
RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference.";
|
||||
return *id_;
|
||||
}
|
||||
|
||||
/// Get the value.
|
||||
std::shared_ptr<Buffer> GetValue() {
|
||||
std::shared_ptr<Buffer> GetValue() const {
|
||||
RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value.";
|
||||
return data_;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
|
||||
#include "context.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// per-thread context for core worker.
|
||||
struct WorkerThreadContext {
|
||||
WorkerThreadContext()
|
||||
: current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}
|
||||
|
||||
int GetNextTaskIndex() { return ++task_index; }
|
||||
|
||||
int GetNextPutIndex() { return ++put_index; }
|
||||
|
||||
const TaskID &GetCurrentTaskID() const { return current_task_id; }
|
||||
|
||||
void SetCurrentTask(const TaskID &task_id) {
|
||||
current_task_id = task_id;
|
||||
task_index = 0;
|
||||
put_index = 0;
|
||||
}
|
||||
|
||||
void SetCurrentTask(const raylet::TaskSpecification &spec) {
|
||||
SetCurrentTask(spec.TaskId());
|
||||
}
|
||||
|
||||
private:
|
||||
/// The task ID for current task.
|
||||
TaskID current_task_id;
|
||||
|
||||
/// Number of tasks that have been submitted from current task.
|
||||
int task_index;
|
||||
|
||||
/// Number of objects that have been put from current task.
|
||||
int put_index;
|
||||
};
|
||||
|
||||
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
|
||||
nullptr;
|
||||
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id)
|
||||
: worker_type(worker_type),
|
||||
worker_id(worker_type == WorkerType::DRIVER
|
||||
? ClientID::FromBinary(driver_id.Binary())
|
||||
: ClientID::FromRandom()),
|
||||
current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) {
|
||||
// For worker main thread which initializes the WorkerContext,
|
||||
// set task_id according to whether current worker is a driver.
|
||||
// (For other threads it's set to randmom ID via GetThreadContext).
|
||||
GetThreadContext().SetCurrentTask(
|
||||
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
|
||||
}
|
||||
|
||||
const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
|
||||
|
||||
const ClientID &WorkerContext::GetWorkerID() const { return worker_id; }
|
||||
|
||||
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
|
||||
|
||||
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
|
||||
|
||||
const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; }
|
||||
|
||||
const TaskID &WorkerContext::GetCurrentTaskID() const {
|
||||
return GetThreadContext().GetCurrentTaskID();
|
||||
}
|
||||
|
||||
void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) {
|
||||
current_driver_id = spec.DriverId();
|
||||
GetThreadContext().SetCurrentTask(spec);
|
||||
}
|
||||
|
||||
WorkerThreadContext &WorkerContext::GetThreadContext() {
|
||||
if (thread_context_ == nullptr) {
|
||||
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());
|
||||
}
|
||||
|
||||
return *thread_context_;
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,48 @@
|
||||
#ifndef RAY_CORE_WORKER_CONTEXT_H
|
||||
#define RAY_CORE_WORKER_CONTEXT_H
|
||||
|
||||
#include "common.h"
|
||||
#include "ray/raylet/task_spec.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct WorkerThreadContext;
|
||||
|
||||
class WorkerContext {
|
||||
public:
|
||||
WorkerContext(WorkerType worker_type, const DriverID &driver_id);
|
||||
|
||||
const WorkerType GetWorkerType() const;
|
||||
|
||||
const ClientID &GetWorkerID() const;
|
||||
|
||||
const DriverID &GetCurrentDriverID() const;
|
||||
|
||||
const TaskID &GetCurrentTaskID() const;
|
||||
|
||||
void SetCurrentTask(const raylet::TaskSpecification &spec);
|
||||
|
||||
int GetNextTaskIndex();
|
||||
|
||||
int GetNextPutIndex();
|
||||
|
||||
private:
|
||||
/// Type of the worker.
|
||||
const WorkerType worker_type;
|
||||
|
||||
/// ID for this worker.
|
||||
const ClientID worker_id;
|
||||
|
||||
/// Driver ID for this worker.
|
||||
DriverID current_driver_id;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext();
|
||||
|
||||
/// Per-thread worker context.
|
||||
static thread_local std::unique_ptr<WorkerThreadContext> thread_context_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_CONTEXT_H
|
||||
@@ -0,0 +1,39 @@
|
||||
#include "core_worker.h"
|
||||
#include "context.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
DriverID driver_id)
|
||||
: worker_type_(worker_type),
|
||||
language_(language),
|
||||
worker_context_(worker_type, driver_id),
|
||||
store_socket_(store_socket),
|
||||
raylet_socket_(raylet_socket),
|
||||
task_interface_(*this),
|
||||
object_interface_(*this),
|
||||
task_execution_interface_(*this) {}
|
||||
|
||||
Status CoreWorker::Connect() {
|
||||
// connect to plasma.
|
||||
RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_));
|
||||
|
||||
// connect to raylet.
|
||||
::Language lang = ::Language::PYTHON;
|
||||
if (language_ == ray::Language::JAVA) {
|
||||
lang = ::Language::JAVA;
|
||||
}
|
||||
|
||||
// TODO: currently RayletClient would crash in its constructor if it cannot
|
||||
// connect to Raylet after a number of retries, this needs to be changed
|
||||
// so that the worker (java/python .etc) can retrieve and handle the error
|
||||
// instead of crashing.
|
||||
raylet_client_ = std::unique_ptr<RayletClient>(
|
||||
new RayletClient(raylet_socket_, worker_context_.GetWorkerID(),
|
||||
(worker_type_ == ray::WorkerType::WORKER),
|
||||
worker_context_.GetCurrentDriverID(), lang));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
@@ -2,8 +2,10 @@
|
||||
#define RAY_CORE_WORKER_CORE_WORKER_H
|
||||
|
||||
#include "common.h"
|
||||
#include "context.h"
|
||||
#include "object_interface.h"
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "task_execution.h"
|
||||
#include "task_interface.h"
|
||||
|
||||
@@ -18,15 +20,12 @@ class CoreWorker {
|
||||
///
|
||||
/// \param[in] worker_type Type of this worker.
|
||||
/// \param[in] langauge Language of this worker.
|
||||
CoreWorker(const WorkerType worker_type, const Language language)
|
||||
: worker_type_(worker_type),
|
||||
language_(language),
|
||||
task_interface_(*this),
|
||||
object_interface_(*this),
|
||||
task_execution_interface_(*this) {}
|
||||
CoreWorker(const WorkerType worker_type, const Language language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
DriverID driver_id = DriverID::Nil());
|
||||
|
||||
/// Connect this worker to Raylet.
|
||||
Status Connect() { return Status::OK(); }
|
||||
/// Connect to raylet.
|
||||
Status Connect();
|
||||
|
||||
/// Type of this worker.
|
||||
enum WorkerType WorkerType() const { return worker_type_; }
|
||||
@@ -53,6 +52,21 @@ class CoreWorker {
|
||||
/// Language of this worker.
|
||||
const enum Language language_;
|
||||
|
||||
/// Worker context per thread.
|
||||
WorkerContext worker_context_;
|
||||
|
||||
/// Plasma store socket name.
|
||||
std::string store_socket_;
|
||||
|
||||
/// raylet socket name.
|
||||
std::string raylet_socket_;
|
||||
|
||||
/// Plasma store client.
|
||||
plasma::PlasmaClient store_client_;
|
||||
|
||||
/// Raylet client.
|
||||
std::unique_ptr<RayletClient> raylet_client_;
|
||||
|
||||
/// The `CoreWorkerTaskInterface` instance.
|
||||
CoreWorkerTaskInterface task_interface_;
|
||||
|
||||
@@ -61,6 +75,10 @@ class CoreWorker {
|
||||
|
||||
/// The `CoreWorkerTaskExecutionInterface` instance.
|
||||
CoreWorkerTaskExecutionInterface task_execution_interface_;
|
||||
|
||||
friend class CoreWorkerTaskInterface;
|
||||
friend class CoreWorkerObjectInterface;
|
||||
friend class CoreWorkerTaskExecutionInterface;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -1,20 +1,137 @@
|
||||
#include <thread>
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "context.h"
|
||||
#include "core_worker.h"
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "ray/thirdparty/hiredis/async.h"
|
||||
#include "ray/thirdparty/hiredis/hiredis.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
std::string store_executable;
|
||||
std::string raylet_executable;
|
||||
|
||||
ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); }
|
||||
|
||||
static void flushall_redis(void) {
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
freeReplyObject(redisCommand(context, "SET NumRedisShards 1"));
|
||||
freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380"));
|
||||
redisFree(context);
|
||||
}
|
||||
|
||||
class CoreWorkerTest : public ::testing::Test {
|
||||
public:
|
||||
CoreWorkerTest() : core_worker_(WorkerType::WORKER, Language::PYTHON) {}
|
||||
CoreWorkerTest(int num_nodes) {
|
||||
RAY_CHECK(num_nodes >= 0);
|
||||
if (num_nodes > 0) {
|
||||
raylet_socket_names_.resize(num_nodes);
|
||||
raylet_store_socket_names_.resize(num_nodes);
|
||||
}
|
||||
|
||||
// start plasma store.
|
||||
for (auto &store_socket : raylet_store_socket_names_) {
|
||||
store_socket = StartStore();
|
||||
}
|
||||
|
||||
// start raylet on each node
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
raylet_socket_names_[i] = StartRaylet(raylet_store_socket_names_[i], "127.0.0.1",
|
||||
"127.0.0.1", "\"CPU,4.0\"");
|
||||
}
|
||||
}
|
||||
|
||||
~CoreWorkerTest() {
|
||||
for (const auto &raylet_socket : raylet_socket_names_) {
|
||||
StopRaylet(raylet_socket);
|
||||
}
|
||||
|
||||
for (const auto &store_socket : raylet_store_socket_names_) {
|
||||
StopStore(store_socket);
|
||||
}
|
||||
}
|
||||
|
||||
std::string StartStore() {
|
||||
std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex();
|
||||
std::string store_pid = store_socket_name + ".pid";
|
||||
std::string plasma_command = store_executable + " -m 10000000 -s " +
|
||||
store_socket_name +
|
||||
" 1> /dev/null 2> /dev/null & echo $! > " + store_pid;
|
||||
RAY_LOG(INFO) << plasma_command;
|
||||
RAY_CHECK(system(plasma_command.c_str()) == 0);
|
||||
usleep(200 * 1000);
|
||||
return store_socket_name;
|
||||
}
|
||||
|
||||
void StopStore(std::string store_socket_name) {
|
||||
std::string store_pid = store_socket_name + ".pid";
|
||||
std::string kill_9 = "kill -9 `cat " + store_pid + "`";
|
||||
RAY_LOG(INFO) << kill_9;
|
||||
ASSERT_TRUE(system(kill_9.c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + store_socket_name).c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + store_socket_name + ".pid").c_str()) == 0);
|
||||
}
|
||||
|
||||
std::string StartRaylet(std::string store_socket_name, std::string node_ip_address,
|
||||
std::string redis_address, std::string resource) {
|
||||
std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex();
|
||||
std::string ray_start_cmd = raylet_executable;
|
||||
ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name)
|
||||
.append(" --store_socket_name=" + store_socket_name)
|
||||
.append(" --object_manager_port=0 --node_manager_port=0")
|
||||
.append(" --node_ip_address=" + node_ip_address)
|
||||
.append(" --redis_address=" + redis_address)
|
||||
.append(" --redis_port=6379")
|
||||
.append(" --num_initial_workers=0")
|
||||
.append(" --maximum_startup_concurrency=10")
|
||||
.append(" --static_resource_list=" + resource)
|
||||
.append(" --python_worker_command=NoneCmd")
|
||||
.append(" & echo $! > " + raylet_socket_name + ".pid");
|
||||
|
||||
RAY_LOG(INFO) << "Ray Start command: " << ray_start_cmd;
|
||||
RAY_CHECK(system(ray_start_cmd.c_str()) == 0);
|
||||
usleep(200 * 1000);
|
||||
return raylet_socket_name;
|
||||
}
|
||||
|
||||
void StopRaylet(std::string raylet_socket_name) {
|
||||
std::string raylet_pid = raylet_socket_name + ".pid";
|
||||
std::string kill_9 = "kill -9 `cat " + raylet_pid + "`";
|
||||
RAY_LOG(INFO) << kill_9;
|
||||
ASSERT_TRUE(system(kill_9.c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0);
|
||||
ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0);
|
||||
}
|
||||
|
||||
void SetUp() { flushall_redis(); }
|
||||
|
||||
void TearDown() {}
|
||||
|
||||
protected:
|
||||
CoreWorker core_worker_;
|
||||
std::vector<std::string> raylet_socket_names_;
|
||||
std::vector<std::string> raylet_store_socket_names_;
|
||||
};
|
||||
|
||||
TEST_F(CoreWorkerTest, TestTaskArg) {
|
||||
class ZeroNodeTest : public CoreWorkerTest {
|
||||
public:
|
||||
ZeroNodeTest() : CoreWorkerTest(0) {}
|
||||
};
|
||||
|
||||
class SingleNodeTest : public CoreWorkerTest {
|
||||
public:
|
||||
SingleNodeTest() : CoreWorkerTest(1) {}
|
||||
};
|
||||
|
||||
TEST_F(ZeroNodeTest, TestTaskArg) {
|
||||
// Test by-reference argument.
|
||||
ObjectID id = ObjectID::FromRandom();
|
||||
TaskArg by_ref = TaskArg::PassByReference(id);
|
||||
@@ -30,9 +147,100 @@ TEST_F(CoreWorkerTest, TestTaskArg) {
|
||||
ASSERT_EQ(*data, *buffer);
|
||||
}
|
||||
|
||||
TEST_F(CoreWorkerTest, TestAttributeGetters) {
|
||||
ASSERT_EQ(core_worker_.WorkerType(), WorkerType::WORKER);
|
||||
ASSERT_EQ(core_worker_.Language(), Language::PYTHON);
|
||||
TEST_F(ZeroNodeTest, TestAttributeGetters) {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", "",
|
||||
DriverID::FromRandom());
|
||||
ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER);
|
||||
ASSERT_EQ(core_worker.Language(), Language::PYTHON);
|
||||
}
|
||||
|
||||
TEST_F(ZeroNodeTest, TestWorkerContext) {
|
||||
auto driver_id = DriverID::FromRandom();
|
||||
|
||||
WorkerContext context(WorkerType::WORKER, driver_id);
|
||||
ASSERT_TRUE(context.GetCurrentTaskID().IsNil());
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 1);
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 2);
|
||||
ASSERT_EQ(context.GetNextPutIndex(), 1);
|
||||
ASSERT_EQ(context.GetNextPutIndex(), 2);
|
||||
|
||||
auto thread_func = [&context]() {
|
||||
// Verify that task_index, put_index are thread-local.
|
||||
ASSERT_TRUE(!context.GetCurrentTaskID().IsNil());
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 1);
|
||||
ASSERT_EQ(context.GetNextPutIndex(), 1);
|
||||
};
|
||||
|
||||
std::thread async_thread(thread_func);
|
||||
async_thread.join();
|
||||
|
||||
// Verify that these fields are thread-local.
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 3);
|
||||
ASSERT_EQ(context.GetNextPutIndex(), 3);
|
||||
}
|
||||
|
||||
TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
DriverID::FromRandom());
|
||||
RAY_CHECK_OK(core_worker.Connect());
|
||||
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
|
||||
std::vector<LocalMemoryBuffer> buffers;
|
||||
buffers.emplace_back(array1, sizeof(array1));
|
||||
buffers.emplace_back(array2, sizeof(array2));
|
||||
|
||||
std::vector<ObjectID> ids(buffers.size());
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
core_worker.Objects().Put(buffers[i], &ids[i]);
|
||||
}
|
||||
|
||||
// Test Get().
|
||||
std::vector<std::shared_ptr<Buffer>> results;
|
||||
core_worker.Objects().Get(ids, 0, &results);
|
||||
|
||||
ASSERT_EQ(results.size(), 2);
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
ASSERT_EQ(results[i]->Size(), buffers[i].Size());
|
||||
ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0);
|
||||
}
|
||||
|
||||
// Test Wait().
|
||||
ObjectID non_existent_id = ObjectID::FromRandom();
|
||||
std::vector<ObjectID> all_ids(ids);
|
||||
all_ids.push_back(non_existent_id);
|
||||
|
||||
std::vector<bool> wait_results;
|
||||
core_worker.Objects().Wait(all_ids, 2, -1, &wait_results);
|
||||
ASSERT_EQ(wait_results.size(), 3);
|
||||
ASSERT_EQ(wait_results, std::vector<bool>({true, true, false}));
|
||||
|
||||
core_worker.Objects().Wait(all_ids, 3, 100, &wait_results);
|
||||
ASSERT_EQ(wait_results.size(), 3);
|
||||
ASSERT_EQ(wait_results, std::vector<bool>({true, true, false}));
|
||||
|
||||
// Test Delete().
|
||||
// clear the reference held by PlasmaBuffer.
|
||||
results.clear();
|
||||
core_worker.Objects().Delete(ids, true, false);
|
||||
|
||||
// Note that Delete() calls RayletClient::FreeObjects and would not
|
||||
// wait for objects being deleted, so wait a while for plasma store
|
||||
// to process the command.
|
||||
usleep(200 * 1000);
|
||||
core_worker.Objects().Get(ids, 0, &results);
|
||||
ASSERT_EQ(results.size(), 2);
|
||||
ASSERT_TRUE(!results[0]);
|
||||
ASSERT_TRUE(!results[1]);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::store_executable = std::string(argv[1]);
|
||||
ray::raylet_executable = std::string(argv[2]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
@@ -1,25 +1,128 @@
|
||||
#include "object_interface.h"
|
||||
#include "context.h"
|
||||
#include "core_worker.h"
|
||||
#include "ray/ray_config.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID *object_id) {
|
||||
CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker)
|
||||
: core_worker_(core_worker) {}
|
||||
|
||||
Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) {
|
||||
ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(),
|
||||
core_worker_.worker_context_.GetNextPutIndex());
|
||||
*object_id = put_id;
|
||||
|
||||
auto plasma_id = put_id.ToPlasmaId();
|
||||
std::shared_ptr<arrow::Buffer> data;
|
||||
RAY_ARROW_RETURN_NOT_OK(
|
||||
core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data));
|
||||
memcpy(data->mutable_data(), buffer.Data(), buffer.Size());
|
||||
RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id));
|
||||
RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CoreWorkerObjectInterface::Get(const std::vector<ObjectID> &ids,
|
||||
int64_t timeout_ms, std::vector<Buffer> *results) {
|
||||
int64_t timeout_ms,
|
||||
std::vector<std::shared_ptr<Buffer>> *results) {
|
||||
(*results).resize(ids.size(), nullptr);
|
||||
|
||||
bool was_blocked = false;
|
||||
|
||||
std::unordered_map<ObjectID, int> unready;
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
unready.insert({ids[i], i});
|
||||
}
|
||||
|
||||
int num_attempts = 0;
|
||||
bool should_break = false;
|
||||
int64_t remaining_timeout = timeout_ms;
|
||||
// Repeat until we get all objects.
|
||||
while (!unready.empty() && !should_break) {
|
||||
std::vector<ObjectID> unready_ids;
|
||||
for (const auto &entry : unready) {
|
||||
unready_ids.push_back(entry.first);
|
||||
}
|
||||
|
||||
// For the initial fetch, we only fetch the objects, do not reconstruct them.
|
||||
bool fetch_only = num_attempts == 0;
|
||||
if (!fetch_only) {
|
||||
// If fetch_only is false, this worker will be blocked.
|
||||
was_blocked = true;
|
||||
}
|
||||
|
||||
// TODO: can call `fetchOrReconstruct` in batches as an optimization.
|
||||
RAY_CHECK_OK(core_worker_.raylet_client_->FetchOrReconstruct(
|
||||
unready_ids, fetch_only, core_worker_.worker_context_.GetCurrentTaskID()));
|
||||
|
||||
// Get the objects from the object store, and parse the result.
|
||||
int64_t get_timeout;
|
||||
if (remaining_timeout >= 0) {
|
||||
get_timeout =
|
||||
std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds());
|
||||
remaining_timeout -= get_timeout;
|
||||
should_break = remaining_timeout <= 0;
|
||||
} else {
|
||||
get_timeout = RayConfig::instance().get_timeout_milliseconds();
|
||||
}
|
||||
|
||||
std::vector<plasma::ObjectID> plasma_ids;
|
||||
for (const auto &id : unready_ids) {
|
||||
plasma_ids.push_back(id.ToPlasmaId());
|
||||
}
|
||||
|
||||
std::vector<plasma::ObjectBuffer> object_buffers;
|
||||
auto status =
|
||||
core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers);
|
||||
|
||||
for (int i = 0; i < object_buffers.size(); i++) {
|
||||
if (object_buffers[i].data != nullptr) {
|
||||
const auto &object_id = unready_ids[i];
|
||||
(*results)[unready[object_id]] =
|
||||
std::make_shared<PlasmaBuffer>(object_buffers[i].data);
|
||||
unready.erase(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
num_attempts += 1;
|
||||
// TODO: log a message if attempted too many times.
|
||||
}
|
||||
|
||||
if (was_blocked) {
|
||||
RAY_CHECK_OK(core_worker_.raylet_client_->NotifyUnblocked(
|
||||
core_worker_.worker_context_.GetCurrentTaskID()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CoreWorkerObjectInterface::Wait(const std::vector<ObjectID> &object_ids,
|
||||
int num_objects, int64_t timeout_ms,
|
||||
std::vector<bool> *results) {
|
||||
return Status::OK();
|
||||
WaitResultPair result_pair;
|
||||
auto status = core_worker_.raylet_client_->Wait(
|
||||
object_ids, num_objects, timeout_ms, false,
|
||||
core_worker_.worker_context_.GetCurrentTaskID(), &result_pair);
|
||||
std::unordered_set<ObjectID> ready_ids;
|
||||
for (const auto &entry : result_pair.first) {
|
||||
ready_ids.insert(entry);
|
||||
}
|
||||
|
||||
// TODO: change RayletClient::Wait() to return a bit set, so that we don't need
|
||||
// to do this translation.
|
||||
(*results).resize(object_ids.size());
|
||||
for (int i = 0; i < object_ids.size(); i++) {
|
||||
(*results)[i] = ready_ids.count(object_ids[i]) > 0;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
Status CoreWorkerObjectInterface::Delete(const std::vector<ObjectID> &object_ids,
|
||||
bool local_only, bool delete_creating_tasks) {
|
||||
return Status::OK();
|
||||
return core_worker_.raylet_client_->FreeObjects(object_ids, local_only,
|
||||
delete_creating_tasks);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define RAY_CORE_WORKER_OBJECT_INTERFACE_H
|
||||
|
||||
#include "common.h"
|
||||
#include "plasma/client.h"
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
@@ -13,14 +14,14 @@ class CoreWorker;
|
||||
/// The interface that contains all `CoreWorker` methods that are related to object store.
|
||||
class CoreWorkerObjectInterface {
|
||||
public:
|
||||
CoreWorkerObjectInterface(CoreWorker &core_worker) : core_worker_(core_worker) {}
|
||||
CoreWorkerObjectInterface(CoreWorker &core_worker);
|
||||
|
||||
/// Put an object into object store.
|
||||
///
|
||||
/// \param[in] buffer Data buffer of the object.
|
||||
/// \param[out] object_id Generated ID of the object.
|
||||
/// \return Status.
|
||||
Status Put(const Buffer &buffer, const ObjectID *object_id);
|
||||
Status Put(const Buffer &buffer, ObjectID *object_id);
|
||||
|
||||
/// Get a list of objects from the object store.
|
||||
///
|
||||
@@ -29,7 +30,7 @@ class CoreWorkerObjectInterface {
|
||||
/// \param[out] results Result list of objects data.
|
||||
/// \return Status.
|
||||
Status Get(const std::vector<ObjectID> &ids, int64_t timeout_ms,
|
||||
std::vector<Buffer> *results);
|
||||
std::vector<std::shared_ptr<Buffer>> *results);
|
||||
|
||||
/// Wait for a list of objects to appear in the object store.
|
||||
///
|
||||
|
||||
@@ -791,7 +791,7 @@ int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleString
|
||||
return REDISMODULE_OK;
|
||||
}
|
||||
|
||||
Status is_nil(bool *out, const std::string &data) {
|
||||
Status IsNil(bool *out, const std::string &data) {
|
||||
if (data.size() != kUniqueIDSize) {
|
||||
return Status::RedisError("Size of data doesn't match size of UniqueID");
|
||||
}
|
||||
@@ -836,7 +836,7 @@ int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
|
||||
static_cast<int>(update->test_state_bitmask());
|
||||
|
||||
bool is_nil_result;
|
||||
REPLY_AND_RETURN_IF_NOT_OK(is_nil(&is_nil_result, update->test_raylet_id()->str()));
|
||||
REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str()));
|
||||
if (!is_nil_result) {
|
||||
do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str();
|
||||
}
|
||||
|
||||
@@ -54,6 +54,17 @@
|
||||
// This macro is used to replace the "ARROW_CHECK_OK" macro.
|
||||
#define RAY_ARROW_CHECK_OK(s) RAY_ARROW_CHECK_OK_PREPEND(s, "Bad status")
|
||||
|
||||
// If arrow status is not ok, return a ray IOError status
|
||||
// with the error message.
|
||||
#define RAY_ARROW_RETURN_NOT_OK(s) \
|
||||
do { \
|
||||
::arrow::Status _s = (s); \
|
||||
if (RAY_PREDICT_FALSE(!_s.ok())) { \
|
||||
return ray::Status::IOError(_s.message()); \
|
||||
; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace ray {
|
||||
|
||||
enum class StatusCode : char {
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the root directory.
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
set -x
|
||||
|
||||
bazel build "//:core_worker_test" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server"
|
||||
|
||||
# Get the directory in which this script is executing.
|
||||
SCRIPT_DIR="`dirname \"$0\"`"
|
||||
RAY_ROOT="$SCRIPT_DIR/../../.."
|
||||
# Makes $RAY_ROOT an absolute path.
|
||||
RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`"
|
||||
if [ -z "$RAY_ROOT" ] ; then
|
||||
exit 1
|
||||
fi
|
||||
# Ensure we're in the right directory.
|
||||
if [ ! -d "$RAY_ROOT/python" ]; then
|
||||
echo "Unable to find root Ray directory. Has this script moved?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
REDIS_MODULE="./bazel-bin/libray_redis_module.so"
|
||||
LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}"
|
||||
STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server"
|
||||
RAYLET_EXEC="./bazel-bin/raylet"
|
||||
|
||||
# Allow cleanup commands to fail.
|
||||
bazel run //:redis-cli -- -p 6379 shutdown || true
|
||||
sleep 1s
|
||||
bazel run //:redis-cli -- -p 6380 shutdown || true
|
||||
sleep 1s
|
||||
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 &
|
||||
sleep 2s
|
||||
bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 &
|
||||
sleep 2s
|
||||
# Run tests.
|
||||
./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC
|
||||
sleep 1s
|
||||
bazel run //:redis-cli -- -p 6379 shutdown
|
||||
bazel run //:redis-cli -- -p 6380 shutdown
|
||||
sleep 1s
|
||||
|
||||
# Include raylet integration test once it's ready.
|
||||
# ./bazel-bin/object_manager_integration_test $STORE_EXEC
|
||||
Reference in New Issue
Block a user