[Java] Attach owner address for pass-by-reference task arguments (#9634)

This commit is contained in:
Kai Yang
2020-09-14 11:46:59 +08:00
committed by GitHub
parent 9795356ac0
commit a43817f34b
20 changed files with 173 additions and 22 deletions
@@ -5,6 +5,7 @@ import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.task.LocalModeTaskSubmitter;
@@ -68,6 +69,11 @@ public class LocalModeWorkerContext implements WorkerContext {
return TaskId.fromBytes(taskSpec.getTaskId().toByteArray());
}
@Override
public Address getRpcAddress() {
return Address.getDefaultInstance();
}
public void setCurrentTask(TaskSpec taskSpec) {
currentTask.set(taskSpec);
}
@@ -1,9 +1,11 @@
package io.ray.runtime.context;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskType;
import java.nio.ByteBuffer;
@@ -51,6 +53,15 @@ public class NativeWorkerContext implements WorkerContext {
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
}
@Override
public Address getRpcAddress() {
try {
return Address.parseFrom(nativeGetRpcAddress());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
private static native int nativeGetCurrentTaskType();
private static native ByteBuffer nativeGetCurrentTaskId();
@@ -60,4 +71,6 @@ public class NativeWorkerContext implements WorkerContext {
private static native ByteBuffer nativeGetCurrentWorkerId();
private static native ByteBuffer nativeGetCurrentActorId();
private static native byte[] nativeGetRpcAddress();
}
@@ -4,6 +4,7 @@ import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskType;
/**
@@ -46,4 +47,6 @@ public interface WorkerContext {
* ID of the current task.
*/
TaskId getCurrentTaskId();
Address getRpcAddress();
}
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.generated.Common.Address;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -106,6 +107,11 @@ public class LocalModeObjectStore extends ObjectStore {
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
}
@Override
public Address getOwnerAddress(ObjectId id) {
return Address.getDefaultInstance();
}
@Override
public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) {
return new byte[0];
@@ -1,9 +1,11 @@
package io.ray.runtime.object;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.BaseId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.generated.Common.Address;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -92,6 +94,15 @@ public class NativeObjectStore extends ObjectStore {
return referenceCounts;
}
@Override
public Address getOwnerAddress(ObjectId id) {
try {
return Address.parseFrom(nativeGetOwnerAddress(id.getBytes()));
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
}
@@ -114,6 +125,8 @@ public class NativeObjectStore extends ObjectStore {
private static native Map<byte[], long[]> nativeGetAllReferenceCounts();
private static native byte[] nativeGetOwnerAddress(byte[] objectId);
private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId);
private static native void nativeRegisterOwnershipInfoAndResolveFuture(byte[] objectId,
@@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId;
import io.ray.api.id.UniqueId;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.exception.RayException;
import io.ray.runtime.generated.Common.Address;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -187,6 +188,8 @@ public abstract class ObjectStore {
*/
public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId);
public abstract Address getOwnerAddress(ObjectId id);
/**
* Promote the given object to the underlying object store, and get the ownership info.
*
@@ -5,6 +5,7 @@ import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectRefImpl;
@@ -39,10 +40,12 @@ public class ArgumentsBuilder {
List<FunctionArg> ret = new ArrayList<>();
for (Object arg : args) {
ObjectId id = null;
Address address = null;
NativeRayObject value = null;
if (arg instanceof ObjectRef) {
Preconditions.checkState(arg instanceof ObjectRefImpl);
id = ((ObjectRefImpl<?>) arg).getId();
address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id);
} else {
value = ObjectSerializer.serialize(arg);
if (language != Language.JAVA) {
@@ -58,6 +61,7 @@ public class ArgumentsBuilder {
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
.putRaw(value);
address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress();
value = null;
}
}
@@ -65,7 +69,7 @@ public class ArgumentsBuilder {
ret.add(FunctionArg.passByValue(PYTHON_DUMMY_TYPE));
}
if (id != null) {
ret.add(FunctionArg.passByReference(id));
ret.add(FunctionArg.passByReference(id, address));
} else {
ret.add(FunctionArg.passByValue(value));
}
@@ -2,6 +2,7 @@ package io.ray.runtime.task;
import com.google.common.base.Preconditions;
import io.ray.api.id.ObjectId;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.object.NativeRayObject;
/**
@@ -15,29 +16,44 @@ public class FunctionArg {
* The id of this argument (passed by reference).
*/
public final ObjectId id;
/**
* The owner address of this argument (passed by reference).
*/
public final Address ownerAddress;
/**
* Serialized data of this argument (passed by value).
*/
public final NativeRayObject value;
private FunctionArg(ObjectId id, NativeRayObject value) {
Preconditions.checkState((id == null) != (value == null));
private FunctionArg(ObjectId id, Address ownerAddress) {
Preconditions.checkNotNull(id);
Preconditions.checkNotNull(ownerAddress);
this.id = id;
this.value = value;
this.ownerAddress = ownerAddress;
this.value = null;
}
private FunctionArg(NativeRayObject nativeRayObject) {
Preconditions.checkNotNull(nativeRayObject);
this.id = null;
this.ownerAddress = null;
this.value = nativeRayObject;
}
/**
* Create a FunctionArg that will be passed by reference.
*/
public static FunctionArg passByReference(ObjectId id) {
return new FunctionArg(id, null);
public static FunctionArg passByReference(ObjectId id, Address ownerAddress) {
return new FunctionArg(id, ownerAddress);
}
/**
* Create a FunctionArg that will be passed by value.
*/
public static FunctionArg passByValue(NativeRayObject value) {
return new FunctionArg(null, value);
return new FunctionArg(value);
}
@Override
@@ -22,6 +22,7 @@ import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
import io.ray.runtime.generated.Common;
import io.ray.runtime.generated.Common.ActorCreationTaskSpec;
import io.ray.runtime.generated.Common.ActorTaskSpec;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.generated.Common.ObjectReference;
import io.ray.runtime.generated.Common.TaskArg;
@@ -381,7 +382,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
TaskArg arg = taskSpec.getArgs(i);
if (arg.getObjectRef().getObjectId() != ByteString.EMPTY) {
functionArgs.add(FunctionArg
.passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray())));
.passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()),
Address.getDefaultInstance()));
} else {
functionArgs.add(FunctionArg.passByValue(
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));