diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 905958ddf..76c71cde4 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -61,8 +61,9 @@ public interface RayRuntime { * * @param objectIds The object ids to free. * @param localOnly Whether only free objects for local object store or not. + * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List objectIds, boolean localOnly); + void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); /** * Invoke a remote function. diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 4ae1ee606..e91d4df7b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -205,8 +205,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { } @Override - public void free(List objectIds, boolean localOnly) { - rayletClient.freePlasmaObjects(objectIds, localOnly); + public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); } private List> splitIntoBatches(List objectIds) { diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index fcbe91231..0925574bc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -213,4 +213,22 @@ public final class RayNativeRuntime extends AbstractRayRuntime { return false; } + + /** + * Query whether the raylet task exists in Gcs. + */ + public boolean rayletTaskExistsInGcs(UniqueId taskId) { + byte[] key = ArrayUtils.addAll("RAYLET_TASK".getBytes(), taskId.getBytes()); + + // TODO(qwang): refactor this with `GlobalState` after this issue + // getting finished. https://github.com/ray-project/ray/issues/3933 + for (RedisClient client : redisClients) { + if (client.exists(key)) { + return true; + } + } + + return false; + } + } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index b3a77f0dc..385431c70 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -191,7 +191,8 @@ public class MockRayletClient implements RayletClient { } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly) { + public void freePlasmaObjects(List objectIds, boolean localOnly, + boolean deleteCreatingTasks) { for (UniqueId id : objectIds) { store.free(id); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 618784c28..fc6fc75b0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -24,7 +24,7 @@ public interface RayletClient { WaitResult wait(List> waitFor, int numReturns, int timeoutMs, UniqueId currentTaskId); - void freePlasmaObjects(List objectIds, boolean localOnly); + void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); UniqueId prepareCheckpoint(UniqueId actorId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 4682de26e..0ed1f9c86 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -123,9 +123,10 @@ public class RayletClientImpl implements RayletClient { } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly) { + public void freePlasmaObjects(List objectIds, boolean localOnly, + boolean deleteCreatingTasks) { byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); - nativeFreePlasmaObjects(client, objectIdsArray, localOnly); + nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); } @Override @@ -350,7 +351,7 @@ public class RayletClientImpl implements RayletClient { int taskIndex); private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds, - boolean localOnly) throws RayException; + boolean localOnly, boolean deleteCreatingTasks) throws RayException; private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId); diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 9b1ea915b..9b3bbf233 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,15 +1,45 @@ package org.ray.api; +import java.util.function.Supplier; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; import org.testng.SkipException; public class TestUtils { + private static final int WAIT_INTERVAL_MS = 5; + public static void skipTestUnderSingleProcess() { AbstractRayRuntime runtime = (AbstractRayRuntime)Ray.internal(); if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { throw new SkipException("This test doesn't work under single-process mode."); } } + + /** + * Wait until the given condition is met. + * + * @param condition A function that predicts the condition. + * @param timeoutMs Timeout in milliseconds. + * @return True if the condition is met within the timeout, false otherwise. + */ + public static boolean waitForCondition(Supplier condition, int timeoutMs) { + int waitTime = 0; + while (true) { + if (condition.get()) { + return true; + } + + try { + java.util.concurrent.TimeUnit.MILLISECONDS.sleep(WAIT_INTERVAL_MS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + waitTime += WAIT_INTERVAL_MS; + if (waitTime > timeoutMs) { + break; + } + } + return false; + } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 876ab322d..c3dadd8f1 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -97,7 +97,7 @@ public class ActorTest extends BaseTest { RayObject value = Ray.call(Counter::getValue, counter); Assert.assertEquals(100, value.get()); // Delete the object from the object store. - Ray.internal().free(ImmutableList.of(value.getId()), false); + Ray.internal().free(ImmutableList.of(value.getId()), false, false); // Wait until the object is deleted, because the above free operation is async. while (true) { GetResult result = ((AbstractRayRuntime) diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 736150b8d..4737740d8 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -1,17 +1,16 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.WaitResult; +import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; -import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayNativeRuntime; +import org.ray.runtime.util.UniqueIdUtil; import org.testng.Assert; import org.testng.annotations.Test; - public class PlasmaFreeTest extends BaseTest { @RayRemote @@ -20,31 +19,27 @@ public class PlasmaFreeTest extends BaseTest { } @Test - public void test() { + public void testDeleteObjects() { RayObject helloId = Ray.call(PlasmaFreeTest::hello); String helloString = helloId.get(); Assert.assertEquals("hello", helloString); - List> waitFor = ImmutableList.of(helloId); - WaitResult waitResult = Ray.wait(waitFor, 1, 2 * 1000); - List> readyOnes = waitResult.getReady(); - List> unreadyOnes = waitResult.getUnready(); - Assert.assertEquals(1, readyOnes.size()); - Assert.assertEquals(0, unreadyOnes.size()); + Ray.internal().free(ImmutableList.of(helloId.getId()), true, false); - List freeList = new ArrayList<>(); - freeList.add(helloId.getId()); - Ray.internal().free(freeList, true); - // Flush: trigger the release function because Plasma Client has cache. - for (int i = 0; i < 128; i++) { - Ray.call(PlasmaFreeTest::hello).get(); - } - - // Check if the object has been evicted. Don't give ray.wait enough - // time to reconstruct the object. - waitResult = Ray.wait(waitFor, 1, 0); - readyOnes = waitResult.getReady(); - unreadyOnes = waitResult.getUnready(); - Assert.assertEquals(0, readyOnes.size()); - Assert.assertEquals(1, unreadyOnes.size()); + final boolean result = TestUtils.waitForCondition(() -> !((AbstractRayRuntime) Ray.internal()) + .getObjectStoreProxy().get(helloId.getId(), 0).exists, 50); + Assert.assertTrue(result); } + + @Test + public void testDeleteCreatingTasks() { + TestUtils.skipTestUnderSingleProcess(); + RayObject helloId = Ray.call(PlasmaFreeTest::hello); + Assert.assertEquals("hello", helloId.get()); + Ray.internal().free(ImmutableList.of(helloId.getId()), true, true); + + final boolean result = TestUtils.waitForCondition(() -> !((RayNativeRuntime) Ray.internal()) + .rayletTaskExistsInGcs(UniqueIdUtil.computeTaskId(helloId.getId())), 50); + Assert.assertTrue(result); + } + } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 453f66bb4..31937837c 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -349,9 +349,9 @@ cdef class RayletClient: check_status(self.client.get().PushProfileEvents(profile_info)) - def free_objects(self, object_ids, c_bool local_only): + def free_objects(self, object_ids, c_bool local_only, c_bool delete_creating_tasks): cdef c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids) - check_status(self.client.get().FreeObjects(free_ids, local_only)) + check_status(self.client.get().FreeObjects(free_ids, local_only, delete_creating_tasks)) def prepare_actor_checkpoint(self, ActorID actor_id): cdef CActorCheckpointID checkpoint_id diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index a496c5b83..be74b06e5 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -67,7 +67,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus PushProfileEvents( const GCSProfileTableDataT &profile_events) CRayStatus FreeObjects(const c_vector[CObjectID] &object_ids, - c_bool local_only) + c_bool local_only, c_bool delete_creating_tasks) CRayStatus PrepareActorCheckpoint(const CActorID &actor_id, CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 902728ec3..89499be5d 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -8,7 +8,7 @@ from ray import profiling __all__ = ["free"] -def free(object_ids, local_only=False): +def free(object_ids, local_only=False, delete_creating_tasks=False): """Free a list of IDs from object stores. This function is a low-level API which should be used in restricted @@ -25,6 +25,8 @@ def free(object_ids, local_only=False): object_ids (List[ObjectID]): List of object IDs to delete. local_only (bool): Whether only deleting the list of objects in local object store or all object stores. + delete_creating_tasks (bool): Whether also delete the object creating + tasks. """ worker = ray.worker.get_global_worker() @@ -46,4 +48,5 @@ def free(object_ids, local_only=False): if len(object_ids) == 0: return - worker.raylet_client.free_objects(object_ids, local_only) + worker.raylet_client.free_objects(object_ids, local_only, + delete_creating_tasks) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 38799d8ad..d535fa593 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1439,13 +1439,16 @@ def test_free_objects_multi_node(ray_start_cluster): assert len(l2) == 0 return (a, b, c) - def run_one_test(actors, local_only): + def run_one_test(actors, local_only, delete_creating_tasks): (a, b, c) = create(actors) # The three objects should be generated on different object stores. assert ray.get(a) != ray.get(b) assert ray.get(a) != ray.get(c) assert ray.get(c) != ray.get(b) - ray.internal.free([a, b, c], local_only=local_only) + ray.internal.free( + [a, b, c], + local_only=local_only, + delete_creating_tasks=delete_creating_tasks) # Wait for the objects to be deleted. time.sleep(0.1) return (a, b, c) @@ -1456,13 +1459,13 @@ def test_free_objects_multi_node(ray_start_cluster): ActorOnNode2.remote() ] # Case 1: run this local_only=False. All 3 objects will be deleted. - (a, b, c) = run_one_test(actors, False) + (a, b, c) = run_one_test(actors, False, False) (l1, l2) = ray.wait([a, b, c], timeout=0.01, num_returns=1) # All the objects are deleted. assert len(l1) == 0 assert len(l2) == 3 # Case 2: run this local_only=True. Only 1 object will be deleted. - (a, b, c) = run_one_test(actors, True) + (a, b, c) = run_one_test(actors, True, False) (l1, l2) = ray.wait([a, b, c], timeout=0.01, num_returns=3) # One object is deleted and 2 objects are not. assert len(l1) == 2 @@ -1472,6 +1475,17 @@ def test_free_objects_multi_node(ray_start_cluster): for object_id in l1: assert ray.get(object_id) != local_return + # Case3: These cases test the deleting creating tasks for the object. + (a, b, c) = run_one_test(actors, False, False) + task_table = ray.global_state.task_table() + for obj in [a, b, c]: + assert ray._raylet.compute_task_id(obj).hex() in task_table + + (a, b, c) = run_one_test(actors, False, True) + task_table = ray.global_state.task_table() + for obj in [a, b, c]: + assert ray._raylet.compute_task_id(obj).hex() not in task_table + def test_local_mode(shutdown_only): @ray.remote diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 9c6094b31..f673e2251 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -207,6 +207,8 @@ table FreeObjectsRequest { // Whether keep this request with local object store // or send it to all the object stores. local_only: bool; + // Whether also delete objects' creating tasks from GCS. + delete_creating_tasks: bool; // List of object ids we'll delete from object store. object_ids: [string]; } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index c55b2608b..c0fa6e105 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -247,11 +247,12 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFreePlasmaObjects - * Signature: ([[BZ)V + * Signature: (J[[BZZ)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( - JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly) { + JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly, + jboolean deleteCreatingTasks) { std::vector object_ids; auto len = env->GetArrayLength(objectIds); for (int i = 0; i < len; i++) { @@ -262,7 +263,7 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( env->DeleteLocalRef(object_id_bytes); } auto raylet_client = reinterpret_cast(client); - auto status = raylet_client->FreeObjects(object_ids, localOnly); + auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); ThrowRayExceptionIfNotOK(env, status); } diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index 8bf64e98c..c00c7c009 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -91,12 +91,12 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(JNIEnv *, jcla /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeFreePlasmaObjects - * Signature: (J[[BZ)V + * Signature: (J[[BZZ)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(JNIEnv *, jclass, jlong, jobjectArray, - jboolean); + jboolean, jboolean); /* * Class: org_ray_runtime_raylet_RayletClientImpl diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 8d7a04d64..7e68eafa1 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -711,7 +711,16 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); std::vector object_ids = from_flatbuf(*message->object_ids()); + // Clean up objects from the object store. object_manager_.FreeObjects(object_ids, message->local_only()); + if (message->delete_creating_tasks()) { + // Clean up their creating tasks from GCS. + std::vector creating_task_ids; + for (const auto &object_id : object_ids) { + creating_task_ids.push_back(ComputeTaskId(object_id)); + } + gcs_client_->raylet_task_table().Delete(JobID::nil(), creating_task_ids); + } } break; case protocol::MessageType::PrepareActorCheckpointRequest: { ProcessPrepareActorCheckpointRequest(client, message_data); diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 41a7dfeeb..09e9b5fed 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -349,10 +349,10 @@ ray::Status RayletClient::PushProfileEvents(const ProfileTableDataT &profile_eve } ray::Status RayletClient::FreeObjects(const std::vector &object_ids, - bool local_only) { + bool local_only, bool delete_creating_tasks) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest(fbb, local_only, - to_flatbuf(fbb, object_ids)); + auto message = ray::protocol::CreateFreeObjectsRequest( + fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids)); fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 2e07becfc..d9cd63121 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -145,8 +145,10 @@ class RayletClient { /// \param object_ids A list of ObjectsIDs to be deleted. /// \param local_only Whether keep this request with local object store /// or send it to all the object stores. + /// \param delete_creating_tasks Whether also delete objects' creating tasks from GCS. /// \return ray::Status. - ray::Status FreeObjects(const std::vector &object_ids, bool local_only); + ray::Status FreeObjects(const std::vector &object_ids, bool local_only, + bool deleteCreatingTasks); /// Request raylet backend to prepare a checkpoint for an actor. ///