diff --git a/java/test.sh b/java/test.sh index 3c7da10bf..b2c99af29 100755 --- a/java/test.sh +++ b/java/test.sh @@ -1,4 +1,11 @@ #!/usr/bin/env bash + +# Cause the script to exit if a single command fails. +set -e + +# Show explicitly which commands are currently running. +set -x + ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) $ROOT_DIR/../build.sh -l java diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index f0cbddae0..0bdd38d63 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -58,7 +58,9 @@ table ObjectTableData { } table TaskReconstructionData { - num_executions: int; + // The number of times this task has been reconstructed so far. + num_reconstructions: int; + // The node manager that is trying to reconstruct the task. node_manager_id: string; } diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index c8825e266..124d39096 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -301,6 +301,24 @@ Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, return Status::OK(); } +Status RedisContext::RunArgvAsync(const std::vector &args) { + // Build the arguments. + std::vector argv; + std::vector argc; + for (size_t i = 0; i < args.size(); ++i) { + argv.push_back(args[i].data()); + argc.push_back(args[i].size()); + } + // Run the Redis command. + int status; + status = redisAsyncCommandArgv(async_context_, nullptr, nullptr, args.size(), + argv.data(), argc.data()); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + return Status::OK(); +} + Status RedisContext::SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, const RedisCallback &redisCallback, diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index a9f988afe..fba91d068 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -73,6 +73,12 @@ class RedisContext { const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); + /// Run an arbitrary Redis command without a callback. + /// + /// \param args The vector of command args to pass to Redis. + /// \return Status. + Status RunArgvAsync(const std::vector &args); + /// Subscribe to a specific Pub-Sub channel. /// /// \param client_id The client ID that subscribe this message. diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index d5ccc7a99..528bcff4f 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -50,6 +50,20 @@ class PubsubInterface { virtual ~PubsubInterface(){}; }; +template +class LogInterface { + public: + using DataT = typename Data::NativeTableType; + using WriteCallback = + std::function; + virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; + virtual Status AppendAt(const JobID &job_id, const ID &task_id, + std::shared_ptr &data, const WriteCallback &done, + const WriteCallback &failure, int log_length) = 0; + virtual ~LogInterface(){}; +}; + /// \class Log /// /// A GCS table where every entry is an append-only log. This class is not @@ -63,14 +77,13 @@ class PubsubInterface { /// ClientTable: Stores a log of which GCS clients have been added or deleted /// from the system. template -class Log : virtual public PubsubInterface { +class Log : public LogInterface, virtual public PubsubInterface { public: using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; /// The callback to call when a write to a key succeeds. - using WriteCallback = - std::function; + using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to /// request and receive notifications. using SubscriptionCallback = std::function; @@ -371,6 +384,21 @@ class TaskLeaseTable : public Table { pubsub_channel_ = TablePubsub::TASK_LEASE; prefix_ = TablePrefix::TASK_LEASE; } + + Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr &data, + const WriteCallback &done) override { + RAY_RETURN_NOT_OK((Table::Add(job_id, id, data, done))); + // Mark the entry for expiration in Redis. It's okay if this command fails + // since the lease entry itself contains the expiration period. In the + // worst case, if the command fails, then a client that looks up the lease + // entry will overestimate the expiration time. + // TODO(swang): Use a common helper function to format the key instead of + // hardcoding it to match the Redis module. + std::vector args = {"PEXPIRE", + EnumNameTablePrefix(prefix_) + id.binary(), + std::to_string(data->timeout)}; + return context_->RunArgvAsync(args); + } }; namespace raylet { diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 07193d5f9..0b49adb88 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -459,6 +459,19 @@ void LineageCache::HandleEntryCommitted(const TaskID &task_id) { } } +const Task &LineageCache::GetTask(const TaskID &task_id) const { + const auto &entries = lineage_.GetEntries(); + auto it = entries.find(task_id); + RAY_CHECK(it != entries.end()); + return it->second.TaskData(); +} + +bool LineageCache::ContainsTask(const TaskID &task_id) const { + const auto &entries = lineage_.GetEntries(); + auto it = entries.find(task_id); + return it != entries.end(); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 402d49a67..97ee6dd61 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -228,6 +228,18 @@ class LineageCache { /// \param task_id The ID of the task entry that was committed. void HandleEntryCommitted(const TaskID &task_id); + /// Get a task. The task must be in the lineage cache. + /// + /// \param task_id The ID of the task to get. + /// \return A const reference to the task data. + const Task &GetTask(const TaskID &task_id) const; + + /// Get whether the lineage cache contains the task. + /// + /// \param task_id The ID of the task to get. + /// \return Whether the task is in the lineage cache. + bool ContainsTask(const TaskID &task_id) const; + private: /// Try to flush a task that is in UNCOMMITTED_READY state. If the task has /// parents that are not committed yet, then the child will be flushed once diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index bdfbacf7b..f2238b2fc 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -91,10 +91,12 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, local_queues_(SchedulingQueue()), scheduling_policy_(local_queues_), reconstruction_policy_( - io_service_, [this](const TaskID &task_id) { ResubmitTask(task_id); }, + io_service_, + [this](const TaskID &task_id) { HandleTaskReconstruction(task_id); }, RayConfig::instance().initial_reconstruction_timeout_milliseconds(), gcs_client_->client_table().GetLocalClientId(), gcs_client->task_lease_table(), - std::make_shared(gcs_client)), + std::make_shared(gcs_client), + gcs_client_->task_reconstruction_log()), task_dependency_manager_( object_manager, reconstruction_policy_, io_service, gcs_client_->client_table().GetLocalClientId(), @@ -1140,8 +1142,68 @@ void NodeManager::FinishAssignedTask(Worker &worker) { worker.AssignTaskId(TaskID::nil()); } -void NodeManager::ResubmitTask(const TaskID &task_id) { - RAY_LOG(WARNING) << "Task re-execution is not currently implemented"; +void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { + // Retrieve the task spec in order to re-execute the task. + RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup( + JobID::nil(), task_id, + /*success_callback=*/ + [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + const ray::protocol::TaskT &task_data) { + // The task was in the GCS task table. Use the stored task spec to + // re-execute the task. + const Task task(task_data); + ResubmitTask(task); + }, + /*failure_callback=*/ + [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id) { + // The task was not in the GCS task table. It must therefore be in the + // lineage cache. + if (!lineage_cache_.ContainsTask(task_id)) { + // The task was not in the lineage cache. + // TODO(swang): This should not ever happen, but Java TaskIDs are + // currently computed differently from Python TaskIDs, so + // reconstruction is currently broken for Java. Once the TaskID + // generation code matches for both frontends, we should be able to + // remove this warning and make it a fatal check. + RAY_LOG(WARNING) << "Task " << task_id << " to reconstruct was not found in " + "the GCS or the lineage cache. This " + "job may hang."; + } else { + // Use a copy of the cached task spec to re-execute the task. + const Task task = lineage_cache_.GetTask(task_id); + ResubmitTask(task); + } + })); +} + +void NodeManager::ResubmitTask(const Task &task) { + // Actor reconstruction is turned off by default right now. If this is an + // actor task, treat the task as failed and do not resubmit it. + if (task.GetTaskSpecification().IsActorTask()) { + TreatTaskAsFailed(task.GetTaskSpecification()); + return; + } + + // Driver tasks cannot be reconstructed. If this is a driver task, push an + // error to the driver and do not resubmit it. + if (task.GetTaskSpecification().IsDriverTask()) { + // TODO(rkn): Define this constant somewhere else. + std::string type = "put_reconstruction"; + std::ostringstream error_message; + error_message << "The task with ID " << task.GetTaskSpecification().TaskId() + << " is a driver task and so the object created by ray.put " + << "could not be reconstructed."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + task.GetTaskSpecification().DriverId(), type, error_message.str(), + current_time_ms())); + return; + } + + // The task may be reconstructed. Submit it with an empty lineage, since any + // uncommitted lineage must already be in the lineage cache. At this point, + // the task should not yet exist in the local scheduling queue. If it does, + // then this is a spurious reconstruction. + SubmitTask(task, Lineage()); } void NodeManager::HandleObjectLocal(const ObjectID &object_id) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 171aad929..49cadbc81 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -92,8 +92,11 @@ class NodeManager { void FinishAssignedTask(Worker &worker); /// Perform a placement decision on placeable tasks. void ScheduleTasks(); - /// Resubmit a task whose return value needs to be reconstructed. - void ResubmitTask(const TaskID &task_id); + /// Handle a task whose return value(s) must be reconstructed. + void HandleTaskReconstruction(const TaskID &task_id); + /// Resubmit a task for execution. This is a task that was previously already + /// submitted to a raylet but which must now be re-executed. + void ResubmitTask(const Task &task); /// Attempt to forward a task to a remote different node manager. If this /// fails, the task will be resubmit locally. /// diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 89632d546..c14dd34a0 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -9,13 +9,15 @@ ReconstructionPolicy::ReconstructionPolicy( std::function reconstruction_handler, int64_t initial_reconstruction_timeout_ms, const ClientID &client_id, gcs::PubsubInterface &task_lease_pubsub, - std::shared_ptr object_directory) + std::shared_ptr object_directory, + gcs::LogInterface &task_reconstruction_log) : io_service_(io_service), reconstruction_handler_(reconstruction_handler), initial_reconstruction_timeout_ms_(initial_reconstruction_timeout_ms), client_id_(client_id), task_lease_pubsub_(task_lease_pubsub), - object_directory_(std::move(object_directory)) {} + object_directory_(std::move(object_directory)), + task_reconstruction_log_(task_reconstruction_log) {} void ReconstructionPolicy::SetTaskTimeout( std::unordered_map::iterator task_it, @@ -59,6 +61,23 @@ void ReconstructionPolicy::SetTaskTimeout( }); } +void ReconstructionPolicy::HandleReconstructionLogAppend(const TaskID &task_id, + bool success) { + auto it = listening_tasks_.find(task_id); + if (it == listening_tasks_.end()) { + return; + } + + // Reset the timer to wait for task lease notifications again. NOTE(swang): + // The timer should already be set here, but we extend it to give some time + // for the reconstructed task to propagate notifications. + SetTaskTimeout(it, initial_reconstruction_timeout_ms_); + + if (success) { + reconstruction_handler_(task_id); + } +} + void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, const ObjectID &required_object_id, int reconstruction_attempt) { @@ -81,17 +100,32 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // reconstruction_attempt many times. return; } - // Increment the number of times reconstruction has been attempted. This is - // used to suppress duplicate reconstructions of the same task. - it->second.reconstruction_attempt++; - // Reset the timer to wait for task lease notifications again. NOTE(swang): - // The timer should already be set here, but we extend it to give some time - // for the reconstructed task to propagate notifications. - SetTaskTimeout(it, initial_reconstruction_timeout_ms_); - // TODO(swang): Suppress simultaneous attempts to reconstruct the task using - // the task reconstruction log. - reconstruction_handler_(task_id); + // Attempt to reconstruct the task by inserting an entry into the task + // reconstruction log. This will fail if another node has already inserted + // an entry for this reconstruction. + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->num_reconstructions = reconstruction_attempt; + reconstruction_entry->node_manager_id = client_id_.binary(); + RAY_CHECK_OK(task_reconstruction_log_.AppendAt( + JobID::nil(), task_id, reconstruction_entry, + /*success_callback=*/ + [this](gcs::AsyncGcsClient *client, const TaskID &task_id, + const TaskReconstructionDataT &data) { + HandleReconstructionLogAppend(task_id, /*success=*/true); + }, + /*failure_callback=*/ + [this](gcs::AsyncGcsClient *client, const TaskID &task_id, + const TaskReconstructionDataT &data) { + HandleReconstructionLogAppend(task_id, /*success=*/false); + }, + reconstruction_attempt)); + + // Increment the number of times reconstruction has been attempted. This is + // used to suppress duplicate reconstructions of the same task. If + // reconstruction is attempted again, the next attempt will try to insert a + // task reconstruction entry at the next index in the log. + it->second.reconstruction_attempt++; } void ReconstructionPolicy::HandleTaskLeaseExpired(const TaskID &task_id) { diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cccdc7d24..dfa69ebf5 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -38,12 +38,13 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { /// the GCS. /// \param task_lease_pubsub The GCS pub-sub storage system to request task /// lease notifications from. - ReconstructionPolicy(boost::asio::io_service &io_service, - std::function reconstruction_handler, - int64_t initial_reconstruction_timeout_ms, - const ClientID &client_id, - gcs::PubsubInterface &task_lease_pubsub, - std::shared_ptr object_directory); + ReconstructionPolicy( + boost::asio::io_service &io_service, + std::function reconstruction_handler, + int64_t initial_reconstruction_timeout_ms, const ClientID &client_id, + gcs::PubsubInterface &task_lease_pubsub, + std::shared_ptr object_directory, + gcs::LogInterface &task_reconstruction_log); /// Listen for task lease notifications about an object that may require /// reconstruction. If no notifications are received within the initial @@ -114,6 +115,10 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { /// Handle expiration of a task lease. void HandleTaskLeaseExpired(const TaskID &task_id); + /// Handle the response for an attempt at adding an entry to the task + /// reconstruction log. + void HandleReconstructionLogAppend(const TaskID &task_id, bool success); + /// The event loop. boost::asio::io_service &io_service_; /// The handler to call for tasks that require reconstruction. @@ -127,6 +132,7 @@ class ReconstructionPolicy : public ReconstructionPolicyInterface { gcs::PubsubInterface &task_lease_pubsub_; /// The object directory used to lookup object locations. std::shared_ptr object_directory_; + gcs::LogInterface &task_reconstruction_log_; /// The tasks that we are currently subscribed to in the GCS. std::unordered_map listening_tasks_; }; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 828c6d4cb..a0a4334ad 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -53,7 +53,8 @@ class MockObjectDirectory : public ObjectDirectoryInterface { std::unordered_map> locations_; }; -class MockGcs : public gcs::PubsubInterface { +class MockGcs : public gcs::PubsubInterface, + public ray::gcs::LogInterface { public: MockGcs() : notification_callback_(nullptr), failure_callback_(nullptr){}; @@ -89,11 +90,40 @@ class MockGcs : public gcs::PubsubInterface { return ray::Status::OK(); } + Status AppendAt( + const JobID &job_id, const TaskID &task_id, + std::shared_ptr &task_data, + const ray::gcs::LogInterface::WriteCallback + &success_callback, + const ray::gcs::LogInterface::WriteCallback + &failure_callback, + int log_index) { + if (task_reconstruction_log_[task_id].size() == static_cast(log_index)) { + task_reconstruction_log_[task_id].push_back(*task_data); + if (success_callback != nullptr) { + success_callback(nullptr, task_id, *task_data); + } + } else { + if (failure_callback != nullptr) { + failure_callback(nullptr, task_id, *task_data); + } + } + return Status::OK(); + } + + MOCK_METHOD4( + Append, + ray::Status( + const JobID &, const TaskID &, std::shared_ptr &, + const ray::gcs::LogInterface::WriteCallback &)); + private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; + std::unordered_map> + task_reconstruction_log_; }; class ReconstructionPolicyTest : public ::testing::Test { @@ -107,7 +137,7 @@ class ReconstructionPolicyTest : public ::testing::Test { io_service_, [this](const TaskID &task_id) { TriggerReconstruction(task_id); }, reconstruction_timeout_ms_, ClientID::from_random(), mock_gcs_, - mock_object_directory_)), + mock_object_directory_, mock_gcs_)), timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, @@ -247,7 +277,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { // Acquire the task lease for a period longer than the test period. auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::from_random().hex(); + task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = 2 * test_period; mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); @@ -275,7 +305,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { // Send the reconstruction manager heartbeats about the object. SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::from_random().hex(); + task_lease_data->node_manager_id = ClientID::from_random().binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = reconstruction_timeout_ms_; mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); @@ -320,6 +350,41 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { ASSERT_EQ(reconstructed_tasks_[task_id], 1); } +TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { + TaskID task_id = TaskID::from_random(); + task_id = FinishTaskId(task_id); + ObjectID object_id = ComputeReturnId(task_id, 1); + + // Log a reconstruction attempt to simulate a different node attempting the + // reconstruction first. This should suppress this node's first attempt at + // reconstruction. + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->node_manager_id = ClientID::from_random().binary(); + task_reconstruction_data->num_reconstructions = 0; + RAY_CHECK_OK( + mock_gcs_.AppendAt(DriverID::nil(), task_id, task_reconstruction_data, nullptr, + /*failure_callback=*/ + [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + /*log_index=*/0)); + + // Listen for an object. + reconstruction_policy_->ListenAndMaybeReconstruct(object_id); + // Run the test for longer than the reconstruction timeout. + Run(reconstruction_timeout_ms_ * 1.1); + // Check that reconstruction is suppressed by the reconstruction attempt + // logged by the other node. + ASSERT_TRUE(reconstructed_tasks_.empty()); + + // Run the test for longer than the reconstruction timeout again. + Run(reconstruction_timeout_ms_ * 1.1); + // Check that this time, reconstruction is triggered, since we did not + // receive a task lease notification from the other node yet and our next + // attempt to reconstruct adds an entry at the next index in the + // TaskReconstructionLog. + ASSERT_EQ(reconstructed_tasks_[task_id], 1); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index 7897d0200..472d23296 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -37,6 +37,13 @@ class Task { : task_execution_spec_(*task_flatbuffer.task_execution_spec()), task_spec_(*task_flatbuffer.task_specification()) {} + /// Create a task from a flatbuffer object. + /// + /// \param task_data The task flatbuffer object. + Task(const protocol::TaskT &task_data) + : task_execution_spec_(*task_data.task_execution_spec), + task_spec_(task_data.task_specification) {} + /// Destroy the task. virtual ~Task() {} diff --git a/src/ray/raylet/task_execution_spec.h b/src/ray/raylet/task_execution_spec.h index 6005a40e8..717160d2b 100644 --- a/src/ray/raylet/task_execution_spec.h +++ b/src/ray/raylet/task_execution_spec.h @@ -17,6 +17,9 @@ namespace raylet { /// TaskSpecification that is determined at submission time. class TaskExecutionSpecification { public: + TaskExecutionSpecification(const protocol::TaskExecutionSpecificationT &execution_spec) + : execution_spec_(execution_spec) {} + /// Create a task execution specification. /// /// \param dependencies The task's dependencies, determined at execution diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 40ba37f8f..456a8fc4a 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -126,7 +126,8 @@ int64_t TaskSpecification::ParentCounter() const { throw std::runtime_error("Method not implemented"); } FunctionID TaskSpecification::FunctionId() const { - throw std::runtime_error("Method not implemented"); + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->function_id()); } int64_t TaskSpecification::NumArgs() const { @@ -173,6 +174,11 @@ const ResourceSet TaskSpecification::GetRequiredResources() const { return ResourceSet(required_resources); } +bool TaskSpecification::IsDriverTask() const { + // Driver tasks are empty tasks that have no function ID set. + return FunctionId().is_nil(); +} + bool TaskSpecification::IsActorCreationTask() const { return !ActorCreationId().is_nil(); } diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 7214b4aa0..d21137e58 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -140,6 +140,7 @@ class TaskSpecification { size_t ArgValLength(int64_t arg_index) const; double GetRequiredResource(const std::string &resource_name) const; const ResourceSet GetRequiredResources() const; + bool IsDriverTask() const; // Methods specific to actor tasks. bool IsActorCreationTask() const; diff --git a/test/stress_tests.py b/test/stress_tests.py index 37ee465c6..59c155eb2 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -152,7 +152,9 @@ def ray_start_reconstruction(request): # Start the Redis global state store. node_ip_address = "127.0.0.1" - redis_address, redis_shards = ray.services.start_redis(node_ip_address) + use_raylet = os.environ.get("RAY_USE_XRAY") == "1" + redis_address, redis_shards = ray.services.start_redis( + node_ip_address, use_raylet=use_raylet) redis_ip_address = ray.services.get_ip_address(redis_address) redis_port = ray.services.get_port(redis_address) time.sleep(0.1) @@ -221,9 +223,6 @@ def ray_start_reconstruction(request): ray.shutdown() -@pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", - reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Failing with new GCS API on Linux.") @@ -232,7 +231,7 @@ def test_simple(ray_start_reconstruction): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' # combined allotted memory. - num_objects = 1000 + num_objects = 100 size = int(plasma_store_memory * 1.5 / (num_objects * 8)) # Define a remote task with no dependencies, which returns a numpy @@ -265,9 +264,6 @@ def test_simple(ray_start_reconstruction): del values -@pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", - reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Failing with new GCS API on Linux.") @@ -276,7 +272,7 @@ def test_recursive(ray_start_reconstruction): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' # combined allotted memory. - num_objects = 1000 + num_objects = 100 size = int(plasma_store_memory * 1.5 / (num_objects * 8)) # Define a root task with no dependencies, which returns a numpy array @@ -324,9 +320,6 @@ def test_recursive(ray_start_reconstruction): del values -@pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", - reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Failing with new GCS API on Linux.") @@ -335,7 +328,7 @@ def test_multiple_recursive(ray_start_reconstruction): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' # combined allotted memory. - num_objects = 1000 + num_objects = 100 size = plasma_store_memory * 2 // (num_objects * 8) # Define a root task with no dependencies, which returns a numpy array @@ -466,9 +459,6 @@ def test_nondeterministic_task(ray_start_reconstruction): for error in errors) -@pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", - reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="Failing with new GCS API on Linux.") @@ -477,7 +467,7 @@ def test_driver_put_errors(ray_start_reconstruction): # Define the size of one task's return argument so that the combined # sum of all objects' sizes is at least twice the plasma stores' # combined allotted memory. - num_objects = 1000 + num_objects = 100 size = plasma_store_memory * 2 // (num_objects * 8) # Define a task with a single dependency, a numpy array, that returns