[Java] Make both RayActor and RayPyActor inheriting from BaseActor (#7462)

This commit is contained in:
Kai Yang
2020-03-17 21:45:56 +08:00
committed by GitHub
parent dfa5d9b8e9
commit 6b888b0247
21 changed files with 83 additions and 69 deletions
@@ -0,0 +1,26 @@
package org.ray.api;
import org.ray.api.id.ActorId;
/**
* A handle to an actor. <p>
*
* A handle can be used to invoke a remote actor method.
*/
public interface BaseActor {
/**
* @return The id of this actor.
*/
ActorId getId();
/**
* Kill the actor immediately. This will cause any outstanding tasks submitted to the actor to
* fail and the actor to exit in the same way as if it crashed.
*
* @param noReconstruction If set to true, the killed actor will not be reconstructed anymore.
*/
default void kill(boolean noReconstruction) {
Ray.internal().killActor(this, noReconstruction);
}
}
@@ -1,9 +1,7 @@
package org.ray.api;
import org.ray.api.id.ActorId;
/**
* A handle to an actor. <p>
* A handle to a Java actor. <p>
*
* A handle can be used to invoke a remote actor method, with the {@code "call"} method. For
* example:
@@ -14,7 +12,7 @@ import org.ray.api.id.ActorId;
* }
* }
* // Create an actor, and get a handle.
* RayActor<MyActor> myActor = Ray.createActor(RayActor::new);
* RayActor<MyActor> myActor = Ray.createActor(MyActor::new);
* // Call the `echo` method remotely.
* RayObject<Integer> result = myActor.call(MyActor::echo, 1);
* // Get the result of the remote `echo` method.
@@ -26,20 +24,6 @@ import org.ray.api.id.ActorId;
*
* @param <A> The type of the concrete actor class.
*/
public interface RayActor<A> extends ActorCall<A> {
public interface RayActor<A> extends BaseActor, ActorCall<A> {
/**
* @return The id of this actor.
*/
ActorId getId();
/**
* Kill the actor immediately. This will cause any outstanding tasks submitted to the actor to
* fail and the actor to exit in the same way as if it crashed.
*
* @param noReconstruction If set to true, the killed actor will not be reconstructed anymore.
*/
default void kill(boolean noReconstruction) {
Ray.internal().killActor(this, noReconstruction);
}
}
@@ -3,7 +3,7 @@ package org.ray.api;
/**
* Handle of a Python actor.
*/
public interface RayPyActor extends RayActor, PyActorCall {
public interface RayPyActor extends BaseActor, PyActorCall {
/**
* @return Module name of the Python actor class.
@@ -2,6 +2,7 @@ package org.ray.api.runtime;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -82,7 +83,7 @@ public interface RayRuntime {
* @param actor The actor to be killed.
* @param noReconstruction If set to true, the killed actor will not be reconstructed anymore.
*/
void killActor(RayActor<?> actor, boolean noReconstruction);
void killActor(BaseActor actor, boolean noReconstruction);
/**
* Invoke a remote function.
@@ -5,6 +5,7 @@ import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -176,7 +177,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
}
private RayObject callActorFunction(RayActor rayActor,
private RayObject callActorFunction(BaseActor rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
@@ -190,14 +191,14 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
}
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
private BaseActor createActorImpl(FunctionDescriptor functionDescriptor,
Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
RayActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
BaseActor actor = taskSubmitter.createActor(functionDescriptor, functionArgs, options);
return actor;
}
@@ -1,7 +1,7 @@
package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
@@ -49,7 +49,7 @@ public class RayDevRuntime extends AbstractRayRuntime {
}
@Override
public void killActor(RayActor<?> actor, boolean noReconstruction) {
public void killActor(BaseActor actor, boolean noReconstruction) {
throw new UnsupportedOperationException();
}
@@ -3,6 +3,7 @@ package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -139,7 +140,7 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
}
@Override
public void killActor(RayActor<?> actor, boolean noReconstruction) {
public void killActor(BaseActor actor, boolean noReconstruction) {
getCurrentRuntime().killActor(actor, noReconstruction);
}
@@ -6,7 +6,7 @@ import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
@@ -135,7 +135,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
@Override
public void killActor(RayActor<?> actor, boolean noReconstruction) {
public void killActor(BaseActor actor, boolean noReconstruction) {
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
}
@@ -10,7 +10,7 @@ import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
/**
* RayActor implementation for local mode.
* Implementation of actor handle for local mode.
*/
public class LocalModeRayActor implements RayActor, Externalizable {
@@ -6,8 +6,8 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
@@ -15,10 +15,10 @@ import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor abstract language-independent implementation for cluster mode. This is a wrapper class
* for C++ ActorHandle.
* Abstract and language-independent implementation of actor handle for cluster mode. This is a
* wrapper class for C++ ActorHandle.
*/
public abstract class NativeRayActor implements RayActor, Externalizable {
public abstract class NativeRayActor implements BaseActor, Externalizable {
/**
* Address of core worker.
@@ -3,12 +3,13 @@ package org.ray.runtime.actor;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.io.ObjectInput;
import org.ray.api.RayActor;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor Java implementation for cluster mode.
* Java implementation of actor handle for cluster mode.
*/
public class NativeRayJavaActor extends NativeRayActor {
public class NativeRayJavaActor extends NativeRayActor implements RayActor {
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId);
@@ -7,7 +7,7 @@ import org.ray.api.RayPyActor;
import org.ray.runtime.generated.Common.Language;
/**
* RayActor Python implementation for cluster mode.
* Python actor handle implementation for cluster mode.
*/
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
@@ -17,7 +17,7 @@ import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
@@ -188,7 +188,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
@Override
public RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
ActorId actorId = ActorId.fromRandom();
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
@@ -203,7 +203,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitActorTask(
RayActor actor, FunctionDescriptor functionDescriptor,
BaseActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(numReturns <= 1);
TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args);
@@ -3,7 +3,7 @@ package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
@@ -33,7 +33,7 @@ public class NativeTaskSubmitter implements TaskSubmitter {
}
@Override
public RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
options);
@@ -43,7 +43,7 @@ public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitActorTask(
RayActor actor, FunctionDescriptor functionDescriptor,
BaseActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(actor instanceof NativeRayActor);
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
@@ -1,7 +1,7 @@
package org.ray.runtime.task;
import java.util.List;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
@@ -30,7 +30,7 @@ public interface TaskSubmitter {
* @param options Options for this actor creation task.
* @return Handle to the actor.
*/
RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options);
/**
@@ -42,6 +42,6 @@ public interface TaskSubmitter {
* @param options Options for this task.
* @return Ids of the return objects.
*/
List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
List<ObjectId> submitActorTask(BaseActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options);
}
@@ -4,11 +4,10 @@ import java.io.File;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.apache.commons.io.FileUtils;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
@@ -141,7 +140,7 @@ public class ClassLoaderTest extends BaseTest {
private <T> RayObject<T> callActorFunction(RayActor<?> rayActor,
FunctionDescriptor functionDescriptor, Object[] args, int numReturns) throws Exception {
Method callActorFunctionMethod = AbstractRayRuntime.class.getDeclaredMethod("callActorFunction",
RayActor.class, FunctionDescriptor.class, Object[].class, int.class);
BaseActor.class, FunctionDescriptor.class, Object[].class, int.class);
callActorFunctionMethod.setAccessible(true);
return (RayObject<T>) callActorFunctionMethod
.invoke(TestUtils.getRuntime(), rayActor, functionDescriptor, args, numReturns);
@@ -6,7 +6,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
/**
* Physical execution graph.
@@ -18,19 +18,19 @@ import org.ray.api.RayActor;
public class ExecutionGraph implements Serializable {
private long buildTime;
private List<ExecutionNode> executionNodeList;
private List<RayActor> sourceWorkers = new ArrayList<>();
private List<RayActor> sinkWorkers = new ArrayList<>();
private List<BaseActor> sourceWorkers = new ArrayList<>();
private List<BaseActor> sinkWorkers = new ArrayList<>();
public ExecutionGraph(List<ExecutionNode> executionNodes) {
this.executionNodeList = executionNodes;
for (ExecutionNode executionNode : executionNodeList) {
if (executionNode.getNodeType() == ExecutionNode.NodeType.SOURCE) {
List<RayActor> actors = executionNode.getExecutionTasks().stream()
List<BaseActor> actors = executionNode.getExecutionTasks().stream()
.map(ExecutionTask::getWorker).collect(Collectors.toList());
sourceWorkers.addAll(actors);
}
if (executionNode.getNodeType() == ExecutionNode.NodeType.SINK) {
List<RayActor> actors = executionNode.getExecutionTasks().stream()
List<BaseActor> actors = executionNode.getExecutionTasks().stream()
.map(ExecutionTask::getWorker).collect(Collectors.toList());
sinkWorkers.addAll(actors);
}
@@ -38,11 +38,11 @@ public class ExecutionGraph implements Serializable {
buildTime = System.currentTimeMillis();
}
public List<RayActor> getSourceWorkers() {
public List<BaseActor> getSourceWorkers() {
return sourceWorkers;
}
public List<RayActor> getSinkWorkers() {
public List<BaseActor> getSinkWorkers() {
return sinkWorkers;
}
@@ -81,10 +81,10 @@ public class ExecutionGraph implements Serializable {
throw new RuntimeException("Task " + taskId + " does not exist!");
}
public Map<Integer, RayActor> getTaskId2WorkerByNodeId(int nodeId) {
public Map<Integer, BaseActor> getTaskId2WorkerByNodeId(int nodeId) {
for (ExecutionNode executionNode : executionNodeList) {
if (executionNode.getNodeId() == nodeId) {
Map<Integer, RayActor> taskId2Worker = new HashMap<>();
Map<Integer, BaseActor> taskId2Worker = new HashMap<>();
for (ExecutionTask executionTask : executionNode.getExecutionTasks()) {
taskId2Worker.put(executionTask.getTaskId(), executionTask.getWorker());
}
@@ -1,7 +1,7 @@
package org.ray.streaming.runtime.core.graph;
import java.io.Serializable;
import org.ray.api.RayActor;
import org.ray.api.BaseActor;
/**
* ExecutionTask is minimal execution unit.
@@ -11,9 +11,9 @@ import org.ray.api.RayActor;
public class ExecutionTask implements Serializable {
private int taskId;
private int taskIndex;
private RayActor worker;
private BaseActor worker;
public ExecutionTask(int taskId, int taskIndex, RayActor worker) {
public ExecutionTask(int taskId, int taskIndex, BaseActor worker) {
this.taskId = taskId;
this.taskIndex = taskIndex;
this.worker = worker;
@@ -35,11 +35,11 @@ public class ExecutionTask implements Serializable {
this.taskIndex = taskIndex;
}
public RayActor getWorker() {
public BaseActor getWorker() {
return worker;
}
public void setWorker(RayActor worker) {
public void setWorker(BaseActor worker) {
this.worker = worker;
}
}
@@ -3,6 +3,7 @@ package org.ray.streaming.runtime.schedule;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
@@ -57,7 +58,7 @@ public class JobSchedulerImpl implements JobScheduler {
List<ExecutionTask> executionTasks = executionNode.getExecutionTasks();
for (ExecutionTask executionTask : executionTasks) {
int taskId = executionTask.getTaskId();
RayActor worker = executionTask.getWorker();
BaseActor worker = executionTask.getWorker();
switch (executionNode.getLanguage()) {
case JAVA:
RayActor<JobWorker> jobWorker = (RayActor<JobWorker>) worker;
@@ -4,8 +4,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.streaming.jobgraph.JobEdge;
import org.ray.streaming.jobgraph.JobGraph;
import org.ray.streaming.jobgraph.JobVertex;
@@ -58,7 +58,7 @@ public class TaskAssignerImpl implements TaskAssigner {
return new ExecutionGraph(executionNodes);
}
private RayActor createWorker(JobVertex jobVertex) {
private BaseActor createWorker(JobVertex jobVertex) {
switch (jobVertex.getLanguage()) {
case PYTHON:
return Ray.createPyActor(
@@ -4,8 +4,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.id.ActorId;
import org.ray.streaming.api.collector.Collector;
import org.ray.streaming.api.context.RuntimeContext;
@@ -65,7 +65,7 @@ public abstract class StreamTask implements Runnable {
List<Collector> collectors = new ArrayList<>();
for (ExecutionEdge edge : outputEdges) {
Map<String, ActorId> outputActorIds = new HashMap<>();
Map<Integer, RayActor> taskId2Worker = executionGraph
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getTargetNodeId());
taskId2Worker.forEach((targetTaskId, targetActor) -> {
String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
@@ -91,7 +91,7 @@ public abstract class StreamTask implements Runnable {
List<ExecutionEdge> inputEdges = executionNode.getInputsEdges();
Map<String, ActorId> inputActorIds = new HashMap<>();
for (ExecutionEdge edge : inputEdges) {
Map<Integer, RayActor> taskId2Worker = executionGraph
Map<Integer, BaseActor> taskId2Worker = executionGraph
.getTaskId2WorkerByNodeId(edge.getSrcNodeId());
taskId2Worker.forEach((srcTaskId, srcActor) -> {
String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());