diff --git a/python/ray/actor.py b/python/ray/actor.py index 7fa2bad41..4d0c6dd0a 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import copy import inspect import logging @@ -389,7 +390,7 @@ class ActorClass(object): if kwargs is None: kwargs = {} if is_direct_call is None: - is_direct_call = False + is_direct_call = bool(os.environ.get("RAY_FORCE_DIRECT")) if max_concurrency is None: max_concurrency = 1 diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 67d23901b..4cfaac6c1 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import logging from functools import wraps @@ -87,6 +88,7 @@ class RemoteFunction(object): return self._remote(args=args, kwargs=kwargs) self.remote = _remote_proxy + self.direct_call_enabled = bool(os.environ.get("RAY_FORCE_DIRECT")) def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. Instead " @@ -157,7 +159,7 @@ class RemoteFunction(object): if num_return_vals is None: num_return_vals = self._num_return_vals if is_direct_call is None: - is_direct_call = False + is_direct_call = self.direct_call_enabled resources = ray.utils.resources_from_resource_arguments( self._num_cpus, self._num_gpus, self._memory, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index df9007cfc..209c056d6 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -426,30 +426,38 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, plasma_store_provider_->Wait(plasma_object_ids, num_objects, /*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready)); } - if (memory_object_ids.size() > 0) { + RAY_CHECK(static_cast(ready.size()) <= num_objects); + if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { // TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should // consider waiting on them in plasma as well to ensure they are local. RAY_RETURN_NOT_OK(memory_store_provider_.Wait( - memory_object_ids, std::max(0, static_cast(ready.size()) - num_objects), + memory_object_ids, num_objects - static_cast(ready.size()), /*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready)); } + RAY_CHECK(static_cast(ready.size()) <= num_objects); + + if (timeout_ms != 0 && static_cast(ready.size()) < num_objects) { + // Clear the ready set and retry. We clear it so that we can compute the number of + // objects to fetch from the memory store easily below. + ready.clear(); - if (static_cast(ready.size()) < num_objects && timeout_ms != 0) { int64_t start_time = current_time_ms(); if (plasma_object_ids.size() > 0) { RAY_RETURN_NOT_OK( plasma_store_provider_->Wait(plasma_object_ids, num_objects, timeout_ms, worker_context_.GetCurrentTaskID(), &ready)); } + RAY_CHECK(static_cast(ready.size()) <= num_objects); if (timeout_ms > 0) { timeout_ms = std::max(0, static_cast(timeout_ms - (current_time_ms() - start_time))); } - if (memory_object_ids.size() > 0) { - RAY_RETURN_NOT_OK( - memory_store_provider_.Wait(memory_object_ids, num_objects, timeout_ms, - worker_context_.GetCurrentTaskID(), &ready)); + if (static_cast(ready.size()) < num_objects && memory_object_ids.size() > 0) { + RAY_RETURN_NOT_OK(memory_store_provider_.Wait( + memory_object_ids, num_objects - static_cast(ready.size()), timeout_ms, + worker_context_.GetCurrentTaskID(), &ready)); } + RAY_CHECK(static_cast(ready.size()) <= num_objects); } for (size_t i = 0; i < ids.size(); i++) { diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 10afc3e58..4b59451c4 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -87,6 +87,9 @@ void GetRequest::Wait() { void GetRequest::Set(const ObjectID &object_id, std::shared_ptr object) { std::unique_lock lock(mutex_); + if (is_ready_) { + return; // We have already hit the number of objects to return limit. + } objects_.emplace(object_id, object); if (objects_.size() == num_objects_) { is_ready_ = true; @@ -176,6 +179,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, (*results).resize(object_ids.size(), nullptr); std::shared_ptr get_request; + int count = 0; { absl::flat_hash_set remaining_ids; @@ -183,7 +187,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, absl::MutexLock lock(&mu_); // Check for existing objects and see if this get request can be fullfilled. - for (size_t i = 0; i < object_ids.size(); i++) { + for (size_t i = 0; i < object_ids.size() && count < num_objects; i++) { const auto &object_id = object_ids[i]; auto iter = objects_.find(object_id); if (iter != objects_.end()) { @@ -193,22 +197,19 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, // because `object_ids` might have duplicate ids. ids_to_remove.insert(object_id); } + count += 1; } else { remaining_ids.insert(object_id); } } + RAY_CHECK(count <= num_objects); for (const auto &object_id : ids_to_remove) { objects_.erase(object_id); } // Return if all the objects are obtained. - if (remaining_ids.empty()) { - return Status::OK(); - } - - if (object_ids.size() - remaining_ids.size() >= static_cast(num_objects)) { - // Already get enough objects. + if (remaining_ids.empty() || count >= num_objects) { return Status::OK(); }