diff --git a/src/ray/gcs/asio.cc b/src/ray/gcs/asio.cc index 8e71fb838..e449a95bd 100644 --- a/src/ray/gcs/asio.cc +++ b/src/ray/gcs/asio.cc @@ -65,62 +65,49 @@ void RedisAsioClient::operate() { if (read_requested_ && !read_in_progress_) { read_in_progress_ = true; socket_.async_read_some(boost::asio::null_buffers(), - boost::bind(&RedisAsioClient::handle_read, this, - boost::asio::placeholders::error)); + boost::bind(&RedisAsioClient::handle_io, this, + boost::asio::placeholders::error, false)); } if (write_requested_ && !write_in_progress_) { write_in_progress_ = true; socket_.async_write_some(boost::asio::null_buffers(), - boost::bind(&RedisAsioClient::handle_write, this, - boost::asio::placeholders::error)); + boost::bind(&RedisAsioClient::handle_io, this, + boost::asio::placeholders::error, true)); } } -void RedisAsioClient::handle_read(boost::system::error_code error_code) { +void RedisAsioClient::handle_io(boost::system::error_code error_code, bool write) { RAY_CHECK(!error_code || error_code == boost::asio::error::would_block || - error_code == boost::asio::error::connection_reset) - << "handle_read(error_code = " << error_code << ")"; - read_in_progress_ = false; - redis_async_context_.RedisAsyncHandleRead(); + error_code == boost::asio::error::connection_reset || + error_code == boost::asio::error::operation_aborted) + << "handle_io(error_code = " << error_code << ")"; + (write ? write_in_progress_ : read_in_progress_) = false; + if (error_code != boost::asio::error::operation_aborted) { + if (!redis_async_context_.GetRawRedisAsyncContext()) { + RAY_LOG(FATAL) << "redis_async_context_ must not be NULL"; + } + write ? redis_async_context_.RedisAsyncHandleWrite() + : redis_async_context_.RedisAsyncHandleRead(); + } if (error_code == boost::asio::error::would_block) { operate(); } } -void RedisAsioClient::handle_write(boost::system::error_code error_code) { - RAY_CHECK(!error_code || error_code == boost::asio::error::would_block || - error_code == boost::asio::error::connection_reset) - << "handle_write(error_code = " << error_code << ")"; - write_in_progress_ = false; - redis_async_context_.RedisAsyncHandleWrite(); - - if (error_code == boost::asio::error::would_block) { - operate(); - } -} - -void RedisAsioClient::add_read() { +void RedisAsioClient::add_io(bool write) { // Because redis commands are non-thread safe, dispatch the operation to backend thread. - io_service_.dispatch([this]() { - read_requested_ = true; + io_service_.dispatch([this, write]() { + (write ? write_requested_ : read_requested_) = true; operate(); }); } -void RedisAsioClient::del_read() { read_requested_ = false; } - -void RedisAsioClient::add_write() { - // Because redis commands are non-thread safe, dispatch the operation to backend thread. - io_service_.dispatch([this]() { - write_requested_ = true; - operate(); - }); +void RedisAsioClient::del_io(bool write) { + (write ? write_requested_ : read_requested_) = false; } -void RedisAsioClient::del_write() { write_requested_ = false; } - void RedisAsioClient::cleanup() {} static inline RedisAsioClient *cast_to_client(void *private_data) { @@ -129,19 +116,19 @@ static inline RedisAsioClient *cast_to_client(void *private_data) { } extern "C" void call_C_addRead(void *private_data) { - cast_to_client(private_data)->add_read(); + cast_to_client(private_data)->add_io(false); } extern "C" void call_C_delRead(void *private_data) { - cast_to_client(private_data)->del_read(); + cast_to_client(private_data)->del_io(false); } extern "C" void call_C_addWrite(void *private_data) { - cast_to_client(private_data)->add_write(); + cast_to_client(private_data)->add_io(true); } extern "C" void call_C_delWrite(void *private_data) { - cast_to_client(private_data)->del_write(); + cast_to_client(private_data)->del_io(true); } extern "C" void call_C_cleanup(void *private_data) { diff --git a/src/ray/gcs/asio.h b/src/ray/gcs/asio.h index 7f5d30c69..0bac1d189 100644 --- a/src/ray/gcs/asio.h +++ b/src/ray/gcs/asio.h @@ -60,12 +60,9 @@ class RedisAsioClient { void operate(); - void handle_read(boost::system::error_code ec); - void handle_write(boost::system::error_code ec); - void add_read(); - void del_read(); - void add_write(); - void del_write(); + void handle_io(boost::system::error_code ec, bool write); + void add_io(bool write); + void del_io(bool write); void cleanup(); private: