mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:44:28 +08:00
[Java] Attach owner address for pass-by-reference task arguments (#9634)
This commit is contained in:
@@ -5,6 +5,7 @@ import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.generated.Common.TaskSpec;
|
||||
import io.ray.runtime.generated.Common.TaskType;
|
||||
import io.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
@@ -68,6 +69,11 @@ public class LocalModeWorkerContext implements WorkerContext {
|
||||
return TaskId.fromBytes(taskSpec.getTaskId().toByteArray());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getRpcAddress() {
|
||||
return Address.getDefaultInstance();
|
||||
}
|
||||
|
||||
public void setCurrentTask(TaskSpec taskSpec) {
|
||||
currentTask.set(taskSpec);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package io.ray.runtime.context;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.generated.Common.TaskType;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
@@ -51,6 +53,15 @@ public class NativeWorkerContext implements WorkerContext {
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getRpcAddress() {
|
||||
try {
|
||||
return Address.parseFrom(nativeGetRpcAddress());
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static native int nativeGetCurrentTaskType();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentTaskId();
|
||||
@@ -60,4 +71,6 @@ public class NativeWorkerContext implements WorkerContext {
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentActorId();
|
||||
|
||||
private static native byte[] nativeGetRpcAddress();
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
/**
|
||||
@@ -46,4 +47,6 @@ public interface WorkerContext {
|
||||
* ID of the current task.
|
||||
*/
|
||||
TaskId getCurrentTaskId();
|
||||
|
||||
Address getRpcAddress();
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -106,6 +107,11 @@ public class LocalModeObjectStore extends ObjectStore {
|
||||
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getOwnerAddress(ObjectId id) {
|
||||
return Address.getDefaultInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] promoteAndGetOwnershipInfo(ObjectId objectId) {
|
||||
return new byte[0];
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package io.ray.runtime.object;
|
||||
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -92,6 +94,15 @@ public class NativeObjectStore extends ObjectStore {
|
||||
return referenceCounts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Address getOwnerAddress(ObjectId id) {
|
||||
try {
|
||||
return Address.parseFrom(nativeGetOwnerAddress(id.getBytes()));
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
@@ -114,6 +125,8 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static native Map<byte[], long[]> nativeGetAllReferenceCounts();
|
||||
|
||||
private static native byte[] nativeGetOwnerAddress(byte[] objectId);
|
||||
|
||||
private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId);
|
||||
|
||||
private static native void nativeRegisterOwnershipInfoAndResolveFuture(byte[] objectId,
|
||||
|
||||
@@ -7,6 +7,7 @@ import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.context.WorkerContext;
|
||||
import io.ray.runtime.exception.RayException;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@@ -187,6 +188,8 @@ public abstract class ObjectStore {
|
||||
*/
|
||||
public abstract void removeLocalReference(UniqueId workerId, ObjectId objectId);
|
||||
|
||||
public abstract Address getOwnerAddress(ObjectId id);
|
||||
|
||||
/**
|
||||
* Promote the given object to the underlying object store, and get the ownership info.
|
||||
*
|
||||
|
||||
@@ -5,6 +5,7 @@ import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.RayRuntimeInternal;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
@@ -39,10 +40,12 @@ public class ArgumentsBuilder {
|
||||
List<FunctionArg> ret = new ArrayList<>();
|
||||
for (Object arg : args) {
|
||||
ObjectId id = null;
|
||||
Address address = null;
|
||||
NativeRayObject value = null;
|
||||
if (arg instanceof ObjectRef) {
|
||||
Preconditions.checkState(arg instanceof ObjectRefImpl);
|
||||
id = ((ObjectRefImpl<?>) arg).getId();
|
||||
address = ((RayRuntimeInternal) Ray.internal()).getObjectStore().getOwnerAddress(id);
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (language != Language.JAVA) {
|
||||
@@ -58,6 +61,7 @@ public class ArgumentsBuilder {
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
|
||||
.putRaw(value);
|
||||
address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress();
|
||||
value = null;
|
||||
}
|
||||
}
|
||||
@@ -65,7 +69,7 @@ public class ArgumentsBuilder {
|
||||
ret.add(FunctionArg.passByValue(PYTHON_DUMMY_TYPE));
|
||||
}
|
||||
if (id != null) {
|
||||
ret.add(FunctionArg.passByReference(id));
|
||||
ret.add(FunctionArg.passByReference(id, address));
|
||||
} else {
|
||||
ret.add(FunctionArg.passByValue(value));
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package io.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.object.NativeRayObject;
|
||||
|
||||
/**
|
||||
@@ -15,29 +16,44 @@ public class FunctionArg {
|
||||
* The id of this argument (passed by reference).
|
||||
*/
|
||||
public final ObjectId id;
|
||||
|
||||
/**
|
||||
* The owner address of this argument (passed by reference).
|
||||
*/
|
||||
public final Address ownerAddress;
|
||||
|
||||
/**
|
||||
* Serialized data of this argument (passed by value).
|
||||
*/
|
||||
public final NativeRayObject value;
|
||||
|
||||
private FunctionArg(ObjectId id, NativeRayObject value) {
|
||||
Preconditions.checkState((id == null) != (value == null));
|
||||
private FunctionArg(ObjectId id, Address ownerAddress) {
|
||||
Preconditions.checkNotNull(id);
|
||||
Preconditions.checkNotNull(ownerAddress);
|
||||
this.id = id;
|
||||
this.value = value;
|
||||
this.ownerAddress = ownerAddress;
|
||||
this.value = null;
|
||||
}
|
||||
|
||||
private FunctionArg(NativeRayObject nativeRayObject) {
|
||||
Preconditions.checkNotNull(nativeRayObject);
|
||||
this.id = null;
|
||||
this.ownerAddress = null;
|
||||
this.value = nativeRayObject;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a FunctionArg that will be passed by reference.
|
||||
*/
|
||||
public static FunctionArg passByReference(ObjectId id) {
|
||||
return new FunctionArg(id, null);
|
||||
public static FunctionArg passByReference(ObjectId id, Address ownerAddress) {
|
||||
return new FunctionArg(id, ownerAddress);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a FunctionArg that will be passed by value.
|
||||
*/
|
||||
public static FunctionArg passByValue(NativeRayObject value) {
|
||||
return new FunctionArg(null, value);
|
||||
return new FunctionArg(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -22,6 +22,7 @@ import io.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import io.ray.runtime.generated.Common;
|
||||
import io.ray.runtime.generated.Common.ActorCreationTaskSpec;
|
||||
import io.ray.runtime.generated.Common.ActorTaskSpec;
|
||||
import io.ray.runtime.generated.Common.Address;
|
||||
import io.ray.runtime.generated.Common.Language;
|
||||
import io.ray.runtime.generated.Common.ObjectReference;
|
||||
import io.ray.runtime.generated.Common.TaskArg;
|
||||
@@ -381,7 +382,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
TaskArg arg = taskSpec.getArgs(i);
|
||||
if (arg.getObjectRef().getObjectId() != ByteString.EMPTY) {
|
||||
functionArgs.add(FunctionArg
|
||||
.passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray())));
|
||||
.passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()),
|
||||
Address.getDefaultInstance()));
|
||||
} else {
|
||||
functionArgs.add(FunctionArg.passByValue(
|
||||
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
|
||||
|
||||
Reference in New Issue
Block a user