[Java] New Java actor API (#7414)

This commit is contained in:
Hao Chen
2020-03-04 22:39:23 +08:00
committed by GitHub
parent 4198db5038
commit fe7820fec9
46 changed files with 1576 additions and 753 deletions
@@ -6,7 +6,6 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.msgpack.core.Preconditions;
import org.ray.api.annotation.RayRemote;
import org.ray.streaming.api.context.StreamingContext;
import org.ray.streaming.python.PythonFunction;
import org.ray.streaming.python.PythonPartition;
@@ -24,7 +23,6 @@ import org.slf4j.LoggerFactory;
* `streaming/python/runtime/gateway_client.py`
*/
@SuppressWarnings("unchecked")
@RayRemote
public class PythonGateway {
private static final Logger LOG = LoggerFactory.getLogger(PythonGateway.class);
private static final String REFERENCE_ID_PREFIX = "__gateway_reference_id__";
@@ -61,7 +61,7 @@ public class JobSchedulerImpl implements JobScheduler {
switch (executionNode.getLanguage()) {
case JAVA:
RayActor<JobWorker> jobWorker = (RayActor<JobWorker>) worker;
waits.add(Ray.call(JobWorker::init, jobWorker,
waits.add(jobWorker.call(JobWorker::init,
new WorkerContext(taskId, executionGraph, jobConfig)));
break;
case PYTHON:
@@ -4,7 +4,6 @@ import java.io.Serializable;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.annotation.RayRemote;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.streaming.runtime.core.graph.ExecutionGraph;
@@ -28,7 +27,6 @@ import org.slf4j.LoggerFactory;
/**
* The stream job worker, it is a ray actor.
*/
@RayRemote
public class JobWorker implements Serializable {
private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
@@ -101,11 +101,11 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
builder.createActorCreationOptions());
LOGGER.info("call getName on writerActor: {}",
Ray.call(WriterWorker::getName, writerActor).get());
writerActor.call(WriterWorker::getName).get());
LOGGER.info("call getName on readerActor: {}",
Ray.call(ReaderWorker::getName, readerActor).get());
readerActor.call(ReaderWorker::getName).get());
// LOGGER.info(Ray.call(WriterWorker::testCallReader, writerActor, readerActor).get());
// LOGGER.info(writerActor.call(WriterWorker::testCallReader, readerActor).get());
List<String> outputQueueList = new ArrayList<>();
List<String> inputQueueList = new ArrayList<>();
int queueNum = 2;
@@ -118,17 +118,17 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
}
final int msgCount = 100;
Ray.call(ReaderWorker::init, readerActor, inputQueueList, writerActor, msgCount);
readerActor.call(ReaderWorker::init, inputQueueList, writerActor, msgCount);
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
Ray.call(WriterWorker::init, writerActor, outputQueueList, readerActor, msgCount);
writerActor.call(WriterWorker::init, outputQueueList, readerActor, msgCount);
long time = 0;
while (time < 20000 &&
Ray.call(ReaderWorker::getTotalMsg, readerActor).get() < msgCount * queueNum) {
readerActor.call(ReaderWorker::getTotalMsg).get() < msgCount * queueNum) {
try {
Thread.sleep(1000);
time += 1000;
@@ -138,7 +138,7 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
}
Assert.assertEquals(
Ray.call(ReaderWorker::getTotalMsg, readerActor).get().intValue(),
readerActor.call(ReaderWorker::getTotalMsg).get().intValue(),
msgCount * queueNum);
}
@@ -9,7 +9,6 @@ import java.util.Map;
import java.util.Random;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.ActorId;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.actor.NativeRayActor;
@@ -59,7 +58,6 @@ public class Worker {
}
}
@RayRemote
class ReaderWorker extends Worker {
private static final Logger LOGGER = LoggerFactory.getLogger(ReaderWorker.class);
@@ -68,7 +66,7 @@ class ReaderWorker extends Worker {
private List<ActorId> inputActorIds = new ArrayList<>();
private DataReader dataReader = null;
private long handler = 0;
private RayActor peerActor = null;
private RayActor<WriterWorker> peerActor = null;
private int msgCount = 0;
private int totalMsg = 0;
@@ -90,7 +88,7 @@ class ReaderWorker extends Worker {
return "testRayCall";
}
public boolean init(List<String> inputQueueList, RayActor peer, int msgCount) {
public boolean init(List<String> inputQueueList, RayActor<WriterWorker> peer, int msgCount) {
this.inputQueueList = inputQueueList;
this.peerActor = peer;
@@ -176,7 +174,6 @@ class ReaderWorker extends Worker {
}
}
@RayRemote
class WriterWorker extends Worker {
private static final Logger LOGGER = LoggerFactory.getLogger(WriterWorker.class);
@@ -184,7 +181,7 @@ class WriterWorker extends Worker {
private List<String> outputQueueList = null;
private List<ActorId> outputActorIds = new ArrayList<>();
DataWriter dataWriter = null;
RayActor peerActor = null;
RayActor<ReaderWorker> peerActor = null;
int msgCount = 0;
public WriterWorker(String name) {
@@ -199,13 +196,13 @@ class WriterWorker extends Worker {
return name;
}
public String testCallReader(RayActor readerActor) {
String name = (String) Ray.call(ReaderWorker::getName, readerActor).get();
public String testCallReader(RayActor<ReaderWorker> readerActor) {
String name = readerActor.call(ReaderWorker::getName).get();
LOGGER.info("testCallReader: {}", name);
return name;
}
public boolean init(List<String> outputQueueList, RayActor peer, int msgCount) {
public boolean init(List<String> outputQueueList, RayActor<ReaderWorker> peer, int msgCount) {
this.outputQueueList = outputQueueList;
this.peerActor = peer;
@@ -221,7 +218,7 @@ class WriterWorker extends Worker {
LOGGER.info("Peer isDirectActorCall: {}", ((NativeRayActor) peer).isDirectCallActor());
int count = 3;
while (count-- != 0) {
Ray.call(ReaderWorker::testRayCall, peer).get();
peer.call(ReaderWorker::testRayCall).get();
}
try {
@@ -277,4 +274,4 @@ class WriterWorker extends Worker {
e.printStackTrace();
}
}
}
}