From 7097ba393b473b1bf2efc6cc432658d4c9242f26 Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Tue, 12 Feb 2019 00:39:38 +0800 Subject: [PATCH] protect raylet against bad messages (#4003) * protect raylet against bad messages * address comments * linting and regression test --- python/ray/includes/ray_config.pxd | 2 +- python/ray/includes/ray_config.pxi | 4 +- src/ray/common/client_connection.cc | 65 +++++++++++++++++++++--- src/ray/common/client_connection.h | 15 +++++- src/ray/ray_config_def.h | 8 ++- src/ray/raylet/client_connection_test.cc | 43 ++++++++++++++++ src/ray/raylet/format/node_manager.fbs | 7 +++ src/ray/raylet/node_manager.cc | 48 ++++++++++++----- src/ray/raylet/node_manager.h | 12 +++++ src/ray/raylet/raylet_client.cc | 10 ++-- test/runtest.py | 24 +++++++++ 11 files changed, 206 insertions(+), 32 deletions(-) diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 906addc99..809313479 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -8,7 +8,7 @@ cdef extern from "ray/ray_config.h" nogil: @staticmethod RayConfig &instance() - int64_t ray_protocol_version() const + int64_t ray_cookie() const int64_t handler_warning_timeout_ms() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 72b3e1e3d..cb7fa53c3 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -2,8 +2,8 @@ from ray.includes.ray_config cimport RayConfig cdef class Config: @staticmethod - def ray_protocol_version(): - return RayConfig.instance().ray_protocol_version() + def ray_cookie(): + return RayConfig.instance().ray_cookie() @staticmethod def handler_warning_timeout_ms(): diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 51d3859d5..2dfdbcd33 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -2,6 +2,7 @@ #include #include +#include #include "ray/ray_config.h" #include "ray/raylet/format/node_manager_generated.h" @@ -101,8 +102,8 @@ ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, bytes_written_ += length; std::vector message_buffers; - auto write_version = RayConfig::instance().ray_protocol_version(); - message_buffers.push_back(boost::asio::buffer(&write_version, sizeof(write_version))); + auto write_cookie = RayConfig::instance().ray_cookie(); + message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie))); message_buffers.push_back(boost::asio::buffer(&type, sizeof(type))); message_buffers.push_back(boost::asio::buffer(&length, sizeof(length))); message_buffers.push_back(boost::asio::buffer(message, length)); @@ -117,7 +118,7 @@ void ServerConnection::WriteMessageAsync( bytes_written_ += length; auto write_buffer = std::unique_ptr(new AsyncWriteBuffer()); - write_buffer->write_version = RayConfig::instance().ray_protocol_version(); + write_buffer->write_cookie = RayConfig::instance().ray_cookie(); write_buffer->write_type = type; write_buffer->write_length = length; write_buffer->write_message.resize(length); @@ -147,8 +148,8 @@ void ServerConnection::DoAsyncWrites() { std::vector message_buffers; int num_messages = 0; for (const auto &write_buffer : async_write_queue_) { - message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version, - sizeof(write_buffer->write_version))); + message_buffers.push_back(boost::asio::buffer(&write_buffer->write_cookie, + sizeof(write_buffer->write_cookie))); message_buffers.push_back( boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type))); message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length, @@ -202,6 +203,7 @@ ClientConnection::ClientConnection( const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type) : ServerConnection(std::move(socket)), + client_id_(ClientID::nil()), message_handler_(message_handler), debug_label_(debug_label), message_type_enum_names_(message_type_enum_names), @@ -222,7 +224,7 @@ void ClientConnection::ProcessMessages() { // Wait for a message header from the client. The message header includes the // protocol version, the message type, and the length of the message. std::vector header; - header.push_back(boost::asio::buffer(&read_version_, sizeof(read_version_))); + header.push_back(boost::asio::buffer(&read_cookie_, sizeof(read_cookie_))); header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_))); header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_))); boost::asio::async_read( @@ -241,8 +243,12 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code & return; } - // If there was no error, make sure the protocol version matches. - RAY_CHECK(read_version_ == RayConfig::instance().ray_protocol_version()); + // If there was no error, make sure the ray cookie matches. + if (!CheckRayCookie()) { + ServerConnection::Close(); + return; + } + // Resize the message buffer to match the received length. read_message_.resize(read_length_); ServerConnection::bytes_read_ += read_length_; @@ -253,6 +259,49 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code & shared_ClientConnection_from_this(), boost::asio::placeholders::error)); } +template +bool ClientConnection::CheckRayCookie() { + if (read_cookie_ == RayConfig::instance().ray_cookie()) { + return true; + } + + // Cookie is not matched. + // Only assert if the message is coming from a known remote endpoint, + // which is indicated by a non-nil client ID. This is to protect raylet + // against miscellaneous connections. We did see cases where bad data + // is received from local unknown program which crashes raylet. + std::ostringstream ss; + ss << " ray cookie mismatch for received message. " + << "received cookie: " << read_cookie_ << ", debug label: " << debug_label_ + << ", remote client ID: " << client_id_; + auto remote_endpoint_info = RemoteEndpointInfo(); + if (!remote_endpoint_info.empty()) { + ss << ", remote endpoint info: " << remote_endpoint_info; + } + + if (!client_id_.is_nil()) { + // This is from a known client, which indicates a bug. + RAY_LOG(FATAL) << ss.str(); + } else { + // It's not from a known client, log this message, and stop processing the connection. + RAY_LOG(WARNING) << ss.str(); + } + return false; +} + +template +std::string ClientConnection::RemoteEndpointInfo() { + return std::string(); +} + +template <> +std::string ClientConnection::RemoteEndpointInfo() { + const auto &remote_endpoint = + ServerConnection::socket_.remote_endpoint(); + return remote_endpoint.address().to_string() + ":" + + std::to_string(remote_endpoint.port()); +} + template void ClientConnection::ProcessMessage(const boost::system::error_code &error) { if (error) { diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 22d4a8ba0..b44026233 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -83,7 +83,7 @@ class ServerConnection : public std::enable_shared_from_this /// A message that is queued for writing asynchronously. struct AsyncWriteBuffer { - int64_t write_version; + int64_t write_cookie; int64_t write_type; uint64_t write_length; std::vector write_message; @@ -184,6 +184,17 @@ class ClientConnection : public ServerConnection { /// Process an error from reading the message header, then process the /// message from the client. void ProcessMessage(const boost::system::error_code &error); + /// Check if the ray cookie in a received message is correct. Note, if the cookie + /// is wrong and the remote endpoint is known, raylet process will crash. If the remote + /// endpoint is unknown, this method will only print a warning. + /// + /// \return If the cookie is correct. + bool CheckRayCookie(); + /// Return information about IP and port for the remote endpoint. For local connection + /// this returns an empty string. + /// + /// \return Information of remote endpoint. + std::string RemoteEndpointInfo(); /// The ClientID of the remote client. ClientID client_id_; @@ -197,7 +208,7 @@ class ClientConnection : public ServerConnection { /// The value for disconnect client message. int64_t error_message_type_; /// Buffers for the current message being read from the client. - int64_t read_version_; + int64_t read_cookie_; int64_t read_type_; uint64_t read_length_; std::vector read_message_; diff --git a/src/ray/ray_config_def.h b/src/ray/ray_config_def.h index 138d474bc..bfe09b5a5 100644 --- a/src/ray/ray_config_def.h +++ b/src/ray/ray_config_def.h @@ -9,8 +9,12 @@ // 1. You must update the file "ray/python/ray/includes/ray_config.pxd". // 2. You must update the file "ray/python/ray/includes/ray_config.pxi". -/// In theory, this is used to detect Ray version mismatches. -RAY_CONFIG(int64_t, ray_protocol_version, 0x0000000000000000); +/// In theory, this is used to detect Ray cookie mismatches. +/// This magic number (hex for "RAY") is used instead of zero, rationale is +/// that it could still be possible that some random program sends an int64_t +/// which is zero, but it's much less likely that a program sends this +/// particular magic number. +RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000); /// The duration that a single handler on the event loop can take before a /// warning is logged that the handler is taking too long. diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc index 4ef236bc1..361a04de1 100644 --- a/src/ray/raylet/client_connection_test.cc +++ b/src/ray/raylet/client_connection_test.cc @@ -18,6 +18,17 @@ class ClientConnectionTest : public ::testing::Test { boost::asio::local::connect_pair(in_, out_); } + ray::Status WriteBadMessage(std::shared_ptr conn, + int64_t type, int64_t length, const uint8_t *message) { + std::vector message_buffers; + auto write_cookie = 123456; // incorrect version. + message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie))); + message_buffers.push_back(boost::asio::buffer(&type, sizeof(type))); + message_buffers.push_back(boost::asio::buffer(&length, sizeof(length))); + message_buffers.push_back(boost::asio::buffer(message, length)); + return conn->WriteBuffer(message_buffers); + } + protected: boost::asio::io_service io_service_; boost::asio::local::stream_protocol::socket in_; @@ -147,6 +158,38 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { io_service_.run(); } +TEST_F(ClientConnectionTest, ProcessBadMessage) { + const uint8_t arr[5] = {1, 2, 3, 4, 5}; + int num_messages = 0; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler message_handler = + [&arr, &num_messages](std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + ASSERT_TRUE(!std::memcmp(arr, message, 5)); + num_messages += 1; + }; + + auto writer = LocalClientConnection::Create( + client_handler, message_handler, std::move(in_), "writer", {}, error_message_type_); + + auto reader = + LocalClientConnection::Create(client_handler, message_handler, std::move(out_), + "reader", {}, error_message_type_); + + // If client ID is set, bad message would crash the test. + // reader->SetClientID(UniqueID::from_random()); + + // Intentionally write a message with incorrect cookie. + // Verify it won't crash as long as client ID is not set. + RAY_CHECK_OK(WriteBadMessage(writer, 0, 5, arr)); + reader->ProcessMessages(); + io_service_.run(); + ASSERT_EQ(num_messages, 0); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 7725b9bc7..289d02170 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -71,6 +71,8 @@ enum MessageType:int { PushProfileEventsRequest, // Free the objects in objects store. FreeObjectsInObjectStoreRequest, + // A node manager requests to connect to another node manager. + ConnectClient, } table TaskExecutionSpecification { @@ -204,3 +206,8 @@ table FreeObjectsRequest { // List of object ids we'll delete from object store. object_ids: [string]; } + +table ConnectClient { + // ID of the connecting client. + client_id: string; +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d05f28dd0..adef01209 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -44,7 +44,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, std::shared_ptr gcs_client, std::shared_ptr object_directory, plasma::PlasmaClient &store_client) - : io_service_(io_service), + : client_id_(gcs_client->client_table().GetLocalClientId()), + io_service_(io_service), object_manager_(object_manager), store_client_(store_client), gcs_client_(std::move(gcs_client)), @@ -338,13 +339,8 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { } // Establish a new NodeManager connection to this GCS client. - RAY_LOG(DEBUG) << "[ClientAdded] Trying to connect to client " << client_id << " at " - << client_data.node_manager_address << ":" - << client_data.node_manager_port; - - boost::asio::ip::tcp::socket socket(io_service_); - auto status = - TcpConnect(socket, client_data.node_manager_address, client_data.node_manager_port); + auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address, + client_data.node_manager_port); // A disconnected client has 2 entries in the client table (one for being // inserted and one for being removed). When a new raylet starts, ClientAdded // will be called with the disconnected client's first entry, which will cause @@ -357,15 +353,38 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { return; } - // The client is connected. - auto server_conn = TcpServerConnection::Create(std::move(socket)); - remote_server_connections_.emplace(client_id, std::move(server_conn)); - ResourceSet resources_total(client_data.resources_total_label, client_data.resources_total_capacity); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } +ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id, + const std::string &client_address, + int32_t client_port) { + // Establish a new NodeManager connection to this GCS client. + RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at " + << client_address << ":" << client_port; + + boost::asio::ip::tcp::socket socket(io_service_); + RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port)); + + // The client is connected, now send a connect message to remote node manager. + auto server_conn = TcpServerConnection::Create(std::move(socket)); + + // Prepare client connection info buffer + flatbuffers::FlatBufferBuilder fbb; + auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_)); + fbb.Finish(message); + // Send synchronously. + // TODO(swang): Make this a WriteMessageAsync. + RAY_RETURN_NOT_OK(server_conn->WriteMessage( + static_cast(protocol::MessageType::ConnectClient), fbb.GetSize(), + fbb.GetBufferPointer())); + + remote_server_connections_.emplace(client_id, std::move(server_conn)); + return ray::Status::OK(); +} + void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. @@ -1007,6 +1026,11 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl << protocol::EnumNameMessageType(message_type_value) << "(" << message_type << ") from node manager"; switch (message_type_value) { + case protocol::MessageType::ConnectClient: { + auto message = flatbuffers::GetRoot(message_data); + auto client_id = from_flatbuf(*message->client_id()); + node_manager_client.SetClientID(client_id); + } break; case protocol::MessageType::ForwardTaskRequest: { auto message = flatbuffers::GetRoot(message_data); TaskID task_id = from_flatbuf(*message->task_id()); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 4d9de0c63..1b4daade2 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -397,6 +397,18 @@ class NodeManager { void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); + /// connect to a remote node manager. + /// + /// \param client_id The client ID for the remote node manager. + /// \param client_address The IP address for the remote node manager. + /// \param client_port The listening port for the remote node manager. + /// \return True if the connect succeeds. + ray::Status ConnectRemoteNodeManager(const ClientID &client_id, + const std::string &client_address, + int32_t client_port); + + // GCS client ID for this node. + ClientID client_id_; boost::asio::io_service &io_service_; ObjectManager &object_manager_; /// A Plasma object store client. This is used exclusively for creating new diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 884fc1f4f..d931051bc 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -139,12 +139,12 @@ ray::Status RayletConnection::Disconnect() { ray::Status RayletConnection::ReadMessage(MessageType type, std::unique_ptr &message) { - int64_t version; + int64_t cookie; int64_t type_field; int64_t length; - int closed = read_bytes(conn_, (uint8_t *)&version, sizeof(version)); + int closed = read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); if (closed) goto disconnected; - RAY_CHECK(version == RayConfig::instance().ray_protocol_version()); + RAY_CHECK(cookie == RayConfig::instance().ray_cookie()); closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); if (closed) goto disconnected; closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length)); @@ -175,13 +175,13 @@ ray::Status RayletConnection::ReadMessage(MessageType type, ray::Status RayletConnection::WriteMessage(MessageType type, flatbuffers::FlatBufferBuilder *fbb) { std::unique_lock guard(write_mutex_); - int64_t version = RayConfig::instance().ray_protocol_version(); + int64_t cookie = RayConfig::instance().ray_cookie(); int64_t length = fbb ? fbb->GetSize() : 0; uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr; int64_t type_field = static_cast(type); auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly."); int closed; - closed = write_bytes(conn_, (uint8_t *)&version, sizeof(version)); + closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); if (closed) return io_error; closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); if (closed) return io_error; diff --git a/test/runtest.py b/test/runtest.py index 2e83ba423..f744e487c 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -9,6 +9,7 @@ import random import re import setproctitle import shutil +import socket import string import subprocess import sys @@ -2718,3 +2719,26 @@ def test_socket_dir_not_existing(shutdown_only): temp_raylet_socket_name = os.path.join(temp_raylet_socket_dir, "raylet_socket") ray.init(num_cpus=1, raylet_socket_name=temp_raylet_socket_name) + + +def test_raylet_is_robust_to_random_messages(shutdown_only): + + ray.init(num_cpus=1) + node_manager_address = None + node_manager_port = None + for client in ray.global_state.client_table(): + if "NodeManagerAddress" in client: + node_manager_address = client["NodeManagerAddress"] + node_manager_port = client["NodeManagerPort"] + assert node_manager_address + assert node_manager_port + # Try to bring down the node manager: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect((node_manager_address, node_manager_port)) + s.send(1000 * b'asdf') + + @ray.remote + def f(): + return 1 + + assert ray.get(f.remote()) == 1