diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 2ce07b305..4dfcb5210 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1148,6 +1148,25 @@ def ray_start_cluster(): cluster.shutdown() +def test_wait_cluster(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=1, resources={"RemoteResource": 1}) + cluster.add_node(num_cpus=1, resources={"RemoteResource": 1}) + ray.init(redis_address=cluster.redis_address) + + @ray.remote(resources={"RemoteResource": 1}) + def f(): + return + + # Submit some tasks that can only be executed on the remote nodes. + tasks = [f.remote() for _ in range(10)] + # Sleep for a bit to let the tasks finish. + time.sleep(1) + _, unready = ray.wait(tasks, num_returns=len(tasks), timeout=0) + # All remote tasks should have finished. + assert len(unready) == 0 + + def test_object_transfer_dump(ray_start_cluster): cluster = ray_start_cluster diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 7692ce505..51cb2600b 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -204,7 +204,20 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, const OnLocationsFound &callback) { ray::Status status; auto it = listeners_.find(object_id); - if (it == listeners_.end()) { + if (it != listeners_.end() && it->second.has_been_created) { + // If we have locations cached due to a concurrent SubscribeObjectLocations + // call, and we have received at least one notification from the GCS about + // the object's creation, then call the callback immediately with the + // cached locations. + auto &locations = it->second.current_object_locations; + bool has_been_created = it->second.has_been_created; + io_service_.post([callback, object_id, locations, has_been_created]() { + callback(object_id, locations, has_been_created); + }); + } else { + // We do not have any locations cached due to a concurrent + // SubscribeObjectLocations call, so look up the object's locations + // directly from the GCS. status = gcs_client_->object_table().Lookup( JobID::nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, @@ -218,14 +231,6 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, // in the GCS client's lookup callback stack. callback(object_id, client_ids, has_been_created); }); - } else { - // If we have locations cached due to a concurrent SubscribeObjectLocations - // call, call the callback immediately with the cached locations. - auto &locations = it->second.current_object_locations; - bool has_been_created = it->second.has_been_created; - io_service_.post([callback, object_id, locations, has_been_created]() { - callback(object_id, locations, has_been_created); - }); } return status; }