diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index b0b690776..436512e02 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -188,12 +188,12 @@ class JobInfoAccessor { virtual Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) = 0; - /// Subscribe to finished jobs. + /// Subscribe to job updates. /// - /// \param subscribe Callback that will be called each time when a job finishes. + /// \param subscribe Callback that will be called each time when a job updates. /// \param done Callback that will be called when subscription is complete. /// \return Status - virtual Status AsyncSubscribeToFinishedJobs( + virtual Status AsyncSubscribeAll( const SubscribeCallback &subscribe, const StatusCallback &done) = 0; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index 6f0efebca..9c27d1dee 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -59,20 +59,32 @@ Status ServiceBasedJobInfoAccessor::AsyncMarkFinished(const JobID &job_id, return Status::OK(); } -Status ServiceBasedJobInfoAccessor::AsyncSubscribeToFinishedJobs( +Status ServiceBasedJobInfoAccessor::AsyncSubscribeAll( const SubscribeCallback &subscribe, const StatusCallback &done) { RAY_CHECK(subscribe != nullptr); + fetch_all_data_operation_ = [this, subscribe](const StatusCallback &done) { + auto callback = [subscribe, done]( + const Status &status, + const std::vector &job_info_list) { + for (auto &job_info : job_info_list) { + subscribe(JobID::FromBinary(job_info.job_id()), job_info); + } + if (done) { + done(status); + } + }; + RAY_CHECK_OK(AsyncGetAll(callback)); + }; subscribe_operation_ = [this, subscribe](const StatusCallback &done) { auto on_subscribe = [subscribe](const std::string &id, const std::string &data) { JobTableData job_data; job_data.ParseFromString(data); - if (job_data.is_dead()) { - 
subscribe(JobID::FromBinary(id), job_data); - } + subscribe(JobID::FromBinary(id), job_data); }; return client_impl_->GetGcsPubSub().SubscribeAll(JOB_CHANNEL, on_subscribe, done); }; - return subscribe_operation_(done); + return subscribe_operation_( + [this, done](const Status &status) { fetch_all_data_operation_(done); }); } void ServiceBasedJobInfoAccessor::AsyncResubscribe(bool is_pubsub_server_restarted) { diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 08479f0ed..8cda97f50 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -43,15 +43,18 @@ class ServiceBasedJobInfoAccessor : public JobInfoAccessor { Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; - Status AsyncSubscribeToFinishedJobs( - const SubscribeCallback &subscribe, - const StatusCallback &done) override; + Status AsyncSubscribeAll(const SubscribeCallback &subscribe, + const StatusCallback &done) override; Status AsyncGetAll(const MultiItemCallback &callback) override; void AsyncResubscribe(bool is_pubsub_server_restarted) override; private: + /// Save the fetch data operation in this function, so we can call it again when GCS + /// server restarts from a failure. + FetchDataOperation fetch_all_data_operation_; + /// Save the subscribe operation in this function, so we can call it again when PubSub /// server restarts from a failure. 
SubscribeOperation subscribe_operation_; diff --git a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc index ce15ffc7e..47f247309 100644 --- a/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc +++ b/src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc @@ -101,10 +101,10 @@ class ServiceBasedGcsClientTest : public ::testing::Test { RAY_LOG(INFO) << "GCS service restarted, port = " << gcs_server_->GetPort(); } - bool SubscribeToFinishedJobs( + bool SubscribeToAllJobs( const gcs::SubscribeCallback &subscribe) { std::promise promise; - RAY_CHECK_OK(gcs_client_->Jobs().AsyncSubscribeToFinishedJobs( + RAY_CHECK_OK(gcs_client_->Jobs().AsyncSubscribeAll( subscribe, [&promise](Status status) { promise.set_value(status.ok()); })); return WaitReady(promise.get_future(), timeout_ms_); } @@ -515,17 +515,16 @@ TEST_F(ServiceBasedGcsClientTest, TestJobInfo) { JobID add_job_id = JobID::FromInt(1); auto job_table_data = Mocker::GenJobTableData(add_job_id); - // Subscribe to finished jobs. - std::atomic finished_job_count(0); - auto on_subscribe = [&finished_job_count](const JobID &job_id, - const gcs::JobTableData &data) { - finished_job_count++; + // Subscribe to all jobs. + std::atomic job_updates(0); + auto on_subscribe = [&job_updates](const JobID &job_id, const gcs::JobTableData &data) { + job_updates++; }; - ASSERT_TRUE(SubscribeToFinishedJobs(on_subscribe)); + ASSERT_TRUE(SubscribeToAllJobs(on_subscribe)); ASSERT_TRUE(AddJob(job_table_data)); ASSERT_TRUE(MarkJobFinished(add_job_id)); - WaitPendingDone(finished_job_count, 1); + WaitPendingDone(job_updates, 2); } TEST_F(ServiceBasedGcsClientTest, TestActorInfo) { @@ -862,18 +861,18 @@ TEST_F(ServiceBasedGcsClientTest, TestJobTableResubscribe) { JobID job_id = JobID::FromInt(1); auto job_table_data = Mocker::GenJobTableData(job_id); - // Subscribe to finished jobs. + // Subscribe to all jobs. 
std::atomic job_update_count(0); auto subscribe = [&job_update_count](const JobID &id, const rpc::JobTableData &result) { ++job_update_count; }; - ASSERT_TRUE(SubscribeToFinishedJobs(subscribe)); + ASSERT_TRUE(SubscribeToAllJobs(subscribe)); RestartGcsServer(); ASSERT_TRUE(AddJob(job_table_data)); ASSERT_TRUE(MarkJobFinished(job_id)); - WaitPendingDone(job_update_count, 1); + WaitPendingDone(job_update_count, 2); } TEST_F(ServiceBasedGcsClientTest, TestActorTableResubscribe) { diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index b70b2ad91..b5d60f555 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -25,11 +25,14 @@ void GcsJobManager::HandleAddJob(const rpc::AddJobRequest &request, JobID job_id = JobID::FromBinary(request.data().job_id()); RAY_LOG(INFO) << "Adding job, job id = " << job_id << ", driver pid = " << request.data().driver_pid(); - auto on_done = [job_id, request, reply, send_reply_callback](const Status &status) { + auto on_done = [this, job_id, request, reply, + send_reply_callback](const Status &status) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to add job, job id = " << job_id << ", driver pid = " << request.data().driver_pid(); } else { + RAY_CHECK_OK(gcs_pub_sub_->Publish(JOB_CHANNEL, job_id.Binary(), + request.data().SerializeAsString(), nullptr)); RAY_LOG(INFO) << "Finished adding job, job id = " << job_id << ", driver pid = " << request.data().driver_pid(); } diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index 564623fa5..988276bc2 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -34,13 +34,15 @@ namespace gcs { /// \return The job table data created by this method. 
inline std::shared_ptr CreateJobTableData( const ray::JobID &job_id, bool is_dead, int64_t timestamp, - const std::string &driver_ip_address, int64_t driver_pid) { + const std::string &driver_ip_address, int64_t driver_pid, + const ray::rpc::JobConfigs &job_configs = {}) { auto job_info_ptr = std::make_shared(); job_info_ptr->set_job_id(job_id.Binary()); job_info_ptr->set_is_dead(is_dead); job_info_ptr->set_timestamp(timestamp); job_info_ptr->set_driver_ip_address(driver_ip_address); job_info_ptr->set_driver_pid(driver_pid); + *job_info_ptr->mutable_configs() = job_configs; return job_info_ptr; } diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index 800f90e0f..3ca27345d 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -349,15 +349,10 @@ Status RedisJobInfoAccessor::DoAsyncAppend(const std::shared_ptr & return client_impl_->job_table().Append(job_id, job_id, data_ptr, on_done); } -Status RedisJobInfoAccessor::AsyncSubscribeToFinishedJobs( +Status RedisJobInfoAccessor::AsyncSubscribeAll( const SubscribeCallback &subscribe, const StatusCallback &done) { RAY_CHECK(subscribe != nullptr); - auto on_subscribe = [subscribe](const JobID &job_id, const JobTableData &job_data) { - if (job_data.is_dead()) { - subscribe(job_id, job_data); - } - }; - return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), on_subscribe, done); + return job_sub_executor_.AsyncSubscribeAll(ClientID::Nil(), subscribe, done); } RedisTaskInfoAccessor::RedisTaskInfoAccessor(RedisGcsClient *client_impl) diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 8d1878d23..b02b464bd 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -175,9 +175,8 @@ class RedisJobInfoAccessor : public JobInfoAccessor { Status AsyncMarkFinished(const JobID &job_id, const StatusCallback &callback) override; - Status AsyncSubscribeToFinishedJobs( - const SubscribeCallback &subscribe, - const StatusCallback &done) 
override; + Status AsyncSubscribeAll(const SubscribeCallback &subscribe, + const StatusCallback &done) override; Status AsyncGetAll(const MultiItemCallback &callback) override { return Status::NotImplemented("AsyncGetAll not implemented"); diff --git a/src/ray/gcs/test/redis_job_info_accessor_test.cc b/src/ray/gcs/test/redis_job_info_accessor_test.cc index ddf1830dc..31dc69393 100644 --- a/src/ray/gcs/test/redis_job_info_accessor_test.cc +++ b/src/ray/gcs/test/redis_job_info_accessor_test.cc @@ -45,8 +45,9 @@ TEST_F(RedisJobInfoAccessorTest, AddAndSubscribe) { auto on_subscribe = [this](const JobID &job_id, const JobTableData &data) { const auto it = id_to_data_.find(job_id); RAY_CHECK(it != id_to_data_.end()); - ASSERT_TRUE(data.is_dead()); - --subscribe_pending_count_; + if (data.is_dead()) { + --subscribe_pending_count_; + } }; auto on_done = [this](Status status) { @@ -55,7 +56,7 @@ TEST_F(RedisJobInfoAccessorTest, AddAndSubscribe) { }; ++pending_count_; - RAY_CHECK_OK(job_accessor.AsyncSubscribeToFinishedJobs(on_subscribe, on_done)); + RAY_CHECK_OK(job_accessor.AsyncSubscribeAll(on_subscribe, on_done)); WaitPendingDone(wait_pending_timeout_); WaitPendingDone(subscribe_pending_count_, wait_pending_timeout_); diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index e765cce28..7633f6559 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -283,6 +283,21 @@ message TaskLeaseData { uint64 timeout = 4; } +message JobConfigs { + // The initial Python workers to start per node. If a negative value is specified, it + // falls back to `num_cpus`. + int32 num_initial_python_workers = 1; + // The initial Java workers to start per node. If a negative value is specified, it + // falls back to `num_cpus`. + int32 num_initial_java_workers = 2; + // Environment variables to be set on worker processes. + map worker_env = 3; + // The number of java workers per worker process. 
+ uint32 num_java_workers_per_process = 4; + // The jvm options for java workers of the job. + repeated string jvm_options = 5; +} + message JobTableData { // The job ID. bytes job_id = 1; @@ -294,6 +309,8 @@ message JobTableData { string driver_ip_address = 4; // Process ID of the driver running this job. int64 driver_pid = 5; + // The configs of this job. + JobConfigs configs = 6; } // This table stores the actor checkpoint data. An actor checkpoint diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d67090730..0c867e971 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -282,10 +282,14 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to job updates. const auto job_subscribe_handler = [this](const JobID &job_id, const JobTableData &job_data) { - HandleJobFinished(job_id, job_data); + if (!job_data.is_dead()) { + HandleJobStarted(job_id, job_data); + } else { + HandleJobFinished(job_id, job_data); + } }; RAY_RETURN_NOT_OK( - gcs_client_->Jobs().AsyncSubscribeToFinishedJobs(job_subscribe_handler, nullptr)); + gcs_client_->Jobs().AsyncSubscribeAll(job_subscribe_handler, nullptr)); // Start sending heartbeats to the GCS. last_heartbeat_at_ms_ = current_time_ms(); @@ -320,6 +324,13 @@ void NodeManager::KillWorker(std::shared_ptr worker) { }); } +void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_data) { + RAY_LOG(DEBUG) << "HandleJobStarted " << job_id; + RAY_CHECK(!job_data.is_dead()); + + // TODO(kfstorm): Spawn job initial workers in a later PR. 
+} + void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job_data) { RAY_LOG(DEBUG) << "HandleJobFinished " << job_id; RAY_CHECK(job_data.is_dead()); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 8ac626fe2..58fb51a7e 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -441,6 +441,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void HandleObjectMissing(const ObjectID &object_id); + /// Handles the event that a job is started. + /// + /// \param job_id ID of the started job. + /// \param job_data Data associated with the started job. + /// \return Void + void HandleJobStarted(const JobID &job_id, const JobTableData &job_data); + /// Handles the event that a job is finished. /// /// \param job_id ID of the finished job.