[JavaWorker] Enable java worker support (#2094)

* Enable java worker support
--------------------------
This commit includes a tailored version of the Java worker implementation from Ant Financial.
The changes for build system, python module, src module and arrow are in other commits, this commit consists of the following modules:
 - java/api: Ray API definition
 - java/common: utilities
 - java/hook: binary rewrite of the Java byte-code for remote execution
 - java/runtime-common: common implementation of the runtime in worker
 - java/runtime-dev: a pure-java mock implementation of the runtime for fast development
 - java/runtime-native: a native implementation of the runtime
 - java/test: various tests

Contributors for this work:
 Guyang Song, Peng Cao, Senlin Zhu,Xiaoying Chu, Yiming Yu, Yujie Liu, Zhenyu Guo

* change the format of java help document from markdown to RST

* update the vesion of Arrow for java worker

* adapt the new version of plasma java client from arrow which use byte[] instead of custom type

* add java worker test to ci

* add the example module for better usage guide
This commit is contained in:
Yujie Liu
2018-05-27 05:38:50 +08:00
committed by Philipp Moritz
parent 74cca3b284
commit a8d3c057c1
193 changed files with 22675 additions and 5 deletions
@@ -0,0 +1,72 @@
package org.ray.api.benchmark;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.test.MyRunner;
@RunWith(MyRunner.class)
public class ActorPressTest extends RayBenchmarkTest {
@Test
public void singleLatencyTest() {
int times = 10;
RayActor<ActorPressTest.Adder> adder = Ray.create(ActorPressTest.Adder.class);
super.singleLatencyTest(times, adder);
}
@Test
public void maxTest() {
int clientNum = 2;
int totalNum = 20;
RayActor<ActorPressTest.Adder> adder = Ray.create(ActorPressTest.Adder.class);
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalNum(totalNum);
pressureTestParameter.setRayBenchmarkTest(this);
pressureTestParameter.setRayActor(adder);
super.maxPressureTest(pressureTestParameter);
}
@Test
public void rateLimiterTest() {
int clientNum = 2;
int totalQps = 2;
int duration = 3;
RayActor<ActorPressTest.Adder> adder = Ray.create(ActorPressTest.Adder.class);
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalQps(totalQps);
pressureTestParameter.setDuration(duration);
pressureTestParameter.setRayBenchmarkTest(this);
pressureTestParameter.setRayActor(adder);
super.rateLimiterPressureTest(pressureTestParameter);
}
@RayRemote
public static class Adder {
public RemoteResult<Integer> add(Integer n) {
RemoteResult<Integer> remoteResult = new RemoteResult<>();
remoteResult.setResult(sum += n);
remoteResult.setFinishTime(System.nanoTime());
return remoteResult;
}
private Integer sum = 0;
}
@Override
public RayObject<RemoteResult<Integer>> rayCall(RayActor rayActor) {
return Ray.call(Adder::add, (RayActor<Adder>) rayActor, 10);
}
@Override
public boolean checkResult(Object o) {
return true;
}
}
@@ -0,0 +1,48 @@
package org.ray.api.benchmark;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.test.MyRunner;
@RunWith(MyRunner.class)
public class MaxPressureTest extends RayBenchmarkTest {
private static final long serialVersionUID = -1684518885171395952L;
public static final int clientNum = 2;
public static final int totalNum = 10;
@Test
public void Test() {
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalNum(totalNum);
pressureTestParameter.setRayBenchmarkTest(this);
super.maxPressureTest(pressureTestParameter);
}
@RayRemote
public static RemoteResult<Integer> currentTime() {
RemoteResult<Integer> remoteResult = new RemoteResult<>();
remoteResult.setFinishTime(System.nanoTime());
remoteResult.setResult(0);
return remoteResult;
}
@Override
public boolean checkResult(Object o) {
return (int) o == 0;
}
@Override
public RayObject<RemoteResult<Integer>> rayCall(RayActor rayActor) {
return Ray.call(MaxPressureTest::currentTime);
}
}
@@ -0,0 +1,79 @@
package org.ray.api.benchmark;
import java.io.Serializable;
import org.ray.api.RayActor;
public class PressureTestParameter implements Serializable {
private static final long serialVersionUID = -52054601722982473L;
private Integer clientNum = 1; //number of test client
private PressureTestType pressureTestType = PressureTestType.RATE_LIMITER; //pressure test type
private Integer totalNum = 1; //total number of task under the mode of MAX
private Integer totalQps = 1; //total qps of task under the mode of RATE_LIMITER
private Integer duration = 1; //duration of the pressure test under the mode of RATE_LIMITER
private RayBenchmarkTest rayBenchmarkTest; //reference of current test case instance
private RayActor rayActor; // reference of the Actor, if only test remote funtion it could be null
public Integer getClientNum() {
return clientNum;
}
public void setClientNum(Integer clientNum) {
this.clientNum = clientNum;
}
public PressureTestType getPressureTestType() {
return pressureTestType;
}
public void setPressureTestType(PressureTestType pressureTestType) {
this.pressureTestType = pressureTestType;
}
public Integer getTotalNum() {
return totalNum;
}
public void setTotalNum(Integer totalNum) {
this.totalNum = totalNum;
}
public Integer getTotalQps() {
return totalQps;
}
public void setTotalQps(Integer totalQps) {
this.totalQps = totalQps;
}
public Integer getDuration() {
return duration;
}
public void setDuration(Integer duration) {
this.duration = duration;
}
public RayBenchmarkTest getRayBenchmarkTest() {
return rayBenchmarkTest;
}
public void setRayBenchmarkTest(RayBenchmarkTest rayBenchmarkTest) {
this.rayBenchmarkTest = rayBenchmarkTest;
}
public RayActor getRayActor() {
return rayActor;
}
public void setRayActor(RayActor rayActor) {
this.rayActor = rayActor;
}
}
@@ -0,0 +1,8 @@
package org.ray.api.benchmark;
public enum PressureTestType {
SINGLE_LATENCY,
RATE_LIMITER,
MAX
}
@@ -0,0 +1,50 @@
package org.ray.api.benchmark;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.test.MyRunner;
@RunWith(MyRunner.class)
public class RateLimiterPressureTest extends RayBenchmarkTest {
private static final long serialVersionUID = 6616958120966144235L;
public static final int clientNum = 2;
public static final int totalQps = 2;
public static final int duration = 10;
@Test
public void Test() {
PressureTestParameter pressureTestParameter = new PressureTestParameter();
pressureTestParameter.setClientNum(clientNum);
pressureTestParameter.setTotalQps(totalQps);
pressureTestParameter.setDuration(duration);
pressureTestParameter.setRayBenchmarkTest(this);
super.rateLimiterPressureTest(pressureTestParameter);
}
@RayRemote
public static RemoteResult<Integer> currentTime() {
RemoteResult<Integer> remoteResult = new RemoteResult<>();
remoteResult.setFinishTime(System.nanoTime());
remoteResult.setResult(0);
return remoteResult;
}
@Override
public boolean checkResult(Object o) {
return (int) o == 0;
}
@Override
public RayObject<RemoteResult<Integer>> rayCall(RayActor rayActor) {
return Ray.call(RateLimiterPressureTest::currentTime);
}
}
@@ -0,0 +1,152 @@
package org.ray.api.benchmark;
import com.google.common.util.concurrent.RateLimiter;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.junit.Assert;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.util.logger.RayLog;
public abstract class RayBenchmarkTest<T> implements Serializable {
private static final long serialVersionUID = 416045641835782523L;
//not thread safe ,but we only have one thread here
public static final DecimalFormat df = new DecimalFormat("00.00");
private static void printList(List<Long> list) {
int len = list.size();
int middle = len / 2;
int almostHundred = (int) (len * 0.9999);
int ninetyNine = (int) (len * 0.99);
int ninetyFive = (int) (len * 0.95);
int ninety = (int) (len * 0.9);
int fifty = (int) (len * 0.5);
RayLog.core.error("Final result of rt as below:");
RayLog.core.error("max: " + list.get(len - 1) + "μs");
RayLog.core.error("min: " + list.get(0) + "μs");
RayLog.core.error("median: " + list.get(middle) + "μs");
RayLog.core.error("99.99% data smaller than: " + list.get(almostHundred) + "μs");
RayLog.core.error("99% data smaller than: " + list.get(ninetyNine) + "μs");
RayLog.core.error("95% data smaller than: " + list.get(ninetyFive) + "μs");
RayLog.core.error("90% data smaller than: " + list.get(ninety) + "μs");
RayLog.core.error("50% data smaller than: " + list.get(fifty) + "μs");
}
public void singleLatencyTest(int times, RayActor rayActor) {
List<Long> counterList = new ArrayList<>();
for (int i = 0; i < times; i++) {
long startTime = System.nanoTime();
RayObject<RemoteResult<T>> rayObject = rayCall(rayActor);
RemoteResult<T> remoteResult = rayObject.get();
T t = remoteResult.getResult();
long endTime = System.nanoTime();
long costTime = endTime - startTime;
counterList.add(costTime / 1000);
RayLog.core.warn("SINGLE_LATENCY_cost_time: " + costTime + " us");
Assert.assertTrue(checkResult(t));
}
Collections.sort(counterList);
printList(counterList);
}
public void rateLimiterPressureTest(PressureTestParameter pressureTestParameter) {
pressureTestParameter.setPressureTestType(PressureTestType.RATE_LIMITER);
notSinglePressTest(pressureTestParameter);
}
public void maxPressureTest(PressureTestParameter pressureTestParameter) {
pressureTestParameter.setPressureTestType(PressureTestType.MAX);
notSinglePressTest(pressureTestParameter);
}
private void notSinglePressTest(PressureTestParameter pressureTestParameter) {
List<Long> counterList = new ArrayList<>();
int clientNum = pressureTestParameter.getClientNum();
RayObject<List<Long>>[] rayObjects = new RayObject[clientNum];
for (int i = 0; i < clientNum; i++) {
rayObjects[i] = Ray.call(RayBenchmarkTest::singleClient, pressureTestParameter);
}
for (int i = 0; i < clientNum; i++) {
List<Long> subCounterList = rayObjects[i].get();
Assert.assertNotNull(subCounterList);
counterList.addAll(subCounterList);
}
Collections.sort(counterList);
printList(counterList);
}
@RayRemote
private static List<Long> singleClient(PressureTestParameter pressureTestParameter) {
try {
List<Long> counterList = new ArrayList<>();
PressureTestType pressureTestType = pressureTestParameter.getPressureTestType();
RayBenchmarkTest rayBenchmarkTest = pressureTestParameter.getRayBenchmarkTest();
int clientNum = pressureTestParameter.getClientNum();
//SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
int len;
String logPrefix;
RateLimiter rateLimiter = null;
if (pressureTestType.equals(PressureTestType.MAX)) {
len = pressureTestParameter.getTotalNum() / clientNum;
logPrefix = "MAX";
} else {
int totalQps = pressureTestParameter.getTotalQps();
int duration = pressureTestParameter.getDuration();
int qps = totalQps / clientNum;
rateLimiter = RateLimiter.create(qps);
len = qps * duration;
logPrefix = "RATE_LIMITER";
}
RemoteResultWrapper[] remoteResultWrappers = new RemoteResultWrapper[len];
int i = 0;
while (i < len) {
if (rateLimiter != null) {
rateLimiter.acquire();
}
// Date currentTime = new Date();
// String dateString = formatter.format(currentTime);
// RayLog.core.info(logPrefix + "_startTime: " + dateString);
RemoteResultWrapper temp = new RemoteResultWrapper();
temp.setStartTime(System.nanoTime());
temp.setRayObject(rayBenchmarkTest.rayCall(pressureTestParameter.getRayActor()));
remoteResultWrappers[i++] = temp;
}
int j = 0;
while (j < len) {
RemoteResultWrapper temp = remoteResultWrappers[j++];
RemoteResult remoteResult = (RemoteResult) temp.getRayObject().get();
long endTime = remoteResult.getFinishTime();
long costTime = endTime - temp.getStartTime();
counterList.add(costTime / 1000);
RayLog.core.warn(logPrefix + "_cost_time:" + costTime + "ns");
Assert.assertTrue(rayBenchmarkTest.checkResult(remoteResult.getResult()));
}
return counterList;
} catch (Exception e) {
RayLog.core.error("singleClient", e);
return null;
}
}
abstract public RayObject<RemoteResult<T>> rayCall(RayActor rayActor);
abstract public boolean checkResult(T t);
}
@@ -0,0 +1,28 @@
package org.ray.api.benchmark;
import java.io.Serializable;
public class RemoteResult<T> implements Serializable {
private static final long serialVersionUID = -3825949468039358540L;
private long finishTime;
private T result;
public long getFinishTime() {
return finishTime;
}
public void setFinishTime(long finishTime) {
this.finishTime = finishTime;
}
public T getResult() {
return result;
}
public void setResult(T result) {
this.result = result;
}
}
@@ -0,0 +1,26 @@
package org.ray.api.benchmark;
import org.ray.api.RayObject;
public class RemoteResultWrapper<T> {
private long startTime;
private RayObject<RemoteResult<T>> rayObject;
public long getStartTime() {
return startTime;
}
public void setStartTime(long startTime) {
this.startTime = startTime;
}
public RayObject<RemoteResult<T>> getRayObject() {
return rayObject;
}
public void setRayObject(RayObject<RemoteResult<T>> rayObject) {
this.rayObject = rayObject;
}
}
@@ -0,0 +1,39 @@
package org.ray.api.benchmark;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.test.MyRunner;
@RunWith(MyRunner.class)
public class SingleLatencyTest extends RayBenchmarkTest {
private static final long serialVersionUID = 3559601273941694468L;
public static final int totalNum = 10;
@Test
public void Test() {
super.singleLatencyTest(totalNum, null);
}
@RayRemote
public static RemoteResult<Integer> doFunc() {
RemoteResult<Integer> remoteResult = new RemoteResult<>();
remoteResult.setResult(1);
return remoteResult;
}
@Override
public RayObject<RemoteResult<Integer>> rayCall(RayActor rayActor) {
return Ray.call(SingleLatencyTest::doFunc);
}
@Override
public boolean checkResult(Object o) {
return (int) o == 1;
}
}
@@ -0,0 +1,179 @@
package org.ray.api.experiment.mr;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
/**
* mimic the MapReduce interface atop of Ray API (in memory version)
*/
public class MemoryMapReduce<TInput, TMapKey, TMapValue, TReduceValue> {
public List<Pair<TMapKey, TMapValue>> Map(TInput input) throws Exception {
throw new Exception("not implemented");
}
public TReduceValue Reduce(TMapKey k, List<TMapValue> values) throws Exception {
throw new Exception("not implemented");
}
//
// main logic to execute this map-reduce with remote mappers and reducers
//
// @param inputs - given input file segments each containing List<TInput>
// @return output file segments each containing SortedMap<TMapKey, TReduceValue>
//
public SortedMap<TMapKey, TReduceValue> Run(List<List<TInput>> inputs, int mapperCount,
Integer reducerCount) {
// start all mappers
ArrayList<RayList<SortedMap<TMapKey, List<TMapValue>>>> mappers = new ArrayList<>();
int inputCountPerMap = inputs.size() / mapperCount;
int index = 0;
for (int i = 0; i < mapperCount; i++) {
if (index >= inputs.size()) {
break;
}
List<List<TInput>> perMapInputs = new ArrayList<>();
for (int j = 0; j < inputCountPerMap && index < inputs.size(); j++) {
perMapInputs.add(inputs.get(index++));
}
mappers.add(Ray.call_n(MemoryMapReduce::InternalMap, reducerCount, reducerCount, perMapInputs,
this.getClass().getName()));
}
// BSP barrier for all mappers to be finished
// this is unnecessary as later on we call map.get() to wait for mappers to be completed
// Ray.wait(mappers.toArray(new RayObject<?>[0]), mappers.size(), 0);
// start all reducers
ArrayList<RayObject<SortedMap<TMapKey, TReduceValue>>> reducers = new ArrayList<>();
for (int i = 0; i < reducerCount; i++) {
// collect states from mappers for this reducer
RayList<SortedMap<TMapKey, List<TMapValue>>> fromMappers = new RayList();
for (int j = 0; j < mapperCount; j++) {
assert (mappers.get(j).size() == reducerCount);
fromMappers.add(mappers.get(j).Get(i));
}
// start this reducer with given input
reducers.add(Ray.call(MemoryMapReduce::InternalReduce,
(List<SortedMap<TMapKey, List<TMapValue>>>) fromMappers, this.getClass().getName()));
}
// BSP barrier for all reducers to be finished
// this is unnecessary coz we will call reducer.get() to wait for their completion
// Ray.wait(reducers.toArray(new RayObject<?>[0]), reducers.size(), 0);
// collect outputs
TreeMap<TMapKey, TReduceValue> outputs = new TreeMap<>();
for (RayObject<SortedMap<TMapKey, TReduceValue>> r : reducers) {
r.get().forEach(outputs::put);
}
return outputs;
}
//
// given a set of input files, output another set of files for reducers
//
// @param inputs - for each input file, it contains List<TInput>
// @return %recuderCount% number of files for the reducers numbering from 0 to reducerCount - 1.
// for each output file, it contains SortedMap<TMapKey, List<TMapValue>>
//
@RayRemote
public static <TInput, TMapKey, TMapValue, TReduceValue> List<SortedMap<TMapKey, List<TMapValue>>> InternalMap(
Integer reducerCount,
List<List<TInput>> inputs,
String mrClassName) {
MemoryMapReduce<TInput, TMapKey, TMapValue, TReduceValue> mr;
try {
mr = (MemoryMapReduce<TInput, TMapKey, TMapValue, TReduceValue>) Class
.forName(mrClassName).getConstructors()[0].newInstance();
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException | SecurityException | ClassNotFoundException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
return null;
}
ArrayList<SortedMap<TMapKey, List<TMapValue>>> out = new ArrayList<>();
for (int i = 0; i < reducerCount; i++) {
out.add(new TreeMap<>());
}
for (List<TInput> inputSeg : inputs) {
for (TInput input : inputSeg) {
try {
List<Pair<TMapKey, TMapValue>> result = mr.Map(input);
for (Pair<TMapKey, TMapValue> pr : result) {
int reducerIndex = Math.abs(pr.getKey().hashCode()) % reducerCount;
out.get(reducerIndex).computeIfAbsent(pr.getKey(), k -> new ArrayList<>());
out.get(reducerIndex).get(pr.getKey()).add(pr.getValue());
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
return out;
}
//public static class Helper {
//
// given a set of input sets from all mappers, performance merge sort and apply reducer function
//
// @param inputs each file contains SortedMap<TMapKey, List<TMapValue>>
// @return an output file contains SortedMap<TMapKey, TReduceValue>
//
@RayRemote
public static <TInput, TMapKey, TMapValue, TReduceValue> SortedMap<TMapKey, TReduceValue> InternalReduce(
List<SortedMap<TMapKey, List<TMapValue>>> inputs,
String mrClassName) {
MemoryMapReduce<TInput, TMapKey, TMapValue, TReduceValue> mr;
try {
mr = (MemoryMapReduce<TInput, TMapKey, TMapValue, TReduceValue>) Class
.forName(mrClassName).getConstructors()[0].newInstance();
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException | SecurityException | ClassNotFoundException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
return null;
}
// merge inputs from many mappers
TreeMap<TMapKey, List<TMapValue>> minputs = new TreeMap<>();
for (SortedMap<TMapKey, List<TMapValue>> input : inputs) {
for (Map.Entry<TMapKey, List<TMapValue>> entry : input.entrySet()) {
if (!minputs.containsKey(entry.getKey())) {
minputs.put(entry.getKey(), new ArrayList<>());
}
minputs.get(entry.getKey()).addAll(entry.getValue());
}
}
// reduce
TreeMap<TMapKey, TReduceValue> out = new TreeMap<>();
for (Map.Entry<TMapKey, List<TMapValue>> entry : minputs.entrySet()) {
try {
out.put(entry.getKey(), mr.Reduce(entry.getKey(), entry.getValue()));
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
return out;
}
// }
}
@@ -0,0 +1,154 @@
package org.ray.api.test;
import java.util.ArrayList;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
@RunWith(MyRunner.class)
public class ActorTest {
@Test
public void Test() {
RayActor<ActorTest.Adder> adder = Ray.create(ActorTest.Adder.class);
Ray.call(Adder::set, adder, 10);
RayObject<Integer> result = Ray.call(Adder::add, adder, 1);
Assert.assertEquals(11, (int) result.get());
RayActor<Adder> secondAdder = Ray.create(Adder.class);
RayObject<Integer> result2 = Ray.call(Adder::add, secondAdder, 1);
Assert.assertEquals(1, (int) result2.get());
RayObject<Integer> result3 = Ray.call(Adder::add2, 1);
Assert.assertEquals(2, (int) result3.get());
RayObject<Integer> result4 = Ray.call(ActorTest::sayWorld, 2, adder);
Assert.assertEquals(14, (int) result4.get());
RayActor<Adder2> adder2 = Ray.create(Adder2.class);
Ray.call(Adder2::setAdder, adder2, adder);
RayObject<Integer> result5 = Ray.call(Adder2::increase, adder2);
Assert.assertEquals(1, (int) result5.get());
List list = new ArrayList<>();
list.add(adder);
Ray.call(Adder2::setAdderList, adder2, list);
RayObject<Integer> result7 = Ray.call(Adder2::testActorList, adder2);
Assert.assertEquals(14, (int) result7.get());
List tempList = new ArrayList<>();
tempList.add(result);
Ray.call(Adder::setObjectList, adder, tempList);
RayObject<Integer> result8 = Ray.call(Adder::testObjectList, adder);
Assert.assertEquals(11, (int) result8.get());
}
@RayRemote
public static class Adder {
private List<RayObject<Integer>> objectList;
public Integer set(Integer n) {
sum = n;
return sum;
}
public Integer increase() {
return (++sum);
}
public Integer add(Integer n) {
return (sum += n);
}
public static Integer add2(Integer n) {
return n + 1;
}
public Integer setObjectList(List<RayObject<Integer>> objectList) {
this.objectList = objectList;
return 1;
}
public Integer testObjectList() {
return ((RayObject<Integer>) objectList.get(0)).get();
}
private Integer sum = 0;
}
@RayRemote
public static Integer sayWorld(Integer n, RayActor<ActorTest.Adder> adder) {
RayObject<Integer> result = Ray.call(ActorTest.Adder::add, adder, 1);
return result.get() + n;
}
@RayRemote
public static class Adder2 {
private RayActor<Adder> adder;
private List<RayActor<Adder>> adderList;
private UniqueID id;
public Integer set(Integer n) {
sum = n;
return sum;
}
public Integer increase() {
RayObject<Integer> result = Ray.call(Adder::increase, adder);
Assert.assertEquals(13, (int) result.get());
return (++sum);
}
public Integer testActorList() {
RayActor<Adder> temp = adderList.get(0);
RayObject<Integer> result = Ray.call(Adder::increase, temp);
return result.get();
}
public Integer add(Integer n) {
return (sum += n);
}
public static Integer add2(Adder a, Integer n) {
return n + 1;
}
public RayActor<Adder> getAdder() {
return adder;
}
public Integer setAdder(RayActor<Adder> adder) {
this.adder = adder;
return 0;
}
public UniqueID getId() {
return id;
}
public Integer setId(UniqueID id) {
this.id = id;
adder = new RayActor<>(id);
return 0;
}
public Integer setAdderList(List<RayActor<Adder>> adderList) {
this.adderList = adderList;
return 0;
}
private Integer sum = 0;
}
}
@@ -0,0 +1,46 @@
package org.ray.api.test;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayRemote;
@RunWith(MyRunner.class)
public class EchoTest {
@Test
public void test() {
long startTime, endTime;
for (int i = 0; i < 100; i++) {
startTime = System.nanoTime();
String ret = echo("Ray++" + i);
endTime = System.nanoTime();
System.out.println("echo: " + ret + " , total time is " + (endTime - startTime));
}
}
public String echo(String who) {
return Ray.call(
EchoTest::recho,
Ray.call(EchoTest::Hi),
Ray.call(EchoTest::Who, who)
).get();
}
@RayRemote
public static String Hi() {
return "Hi";
}
@RayRemote
public static String Who(String who) {
return who;
}
@RayRemote
public static String recho(String pre, String who) {
return pre + ", " + who + "!";
}
}
@@ -0,0 +1,51 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.util.logger.RayLog;
/**
* Hello world
*/
@RunWith(MyRunner.class)
public class HelloWorldTest {
@Test
public void test() {
String helloWorld = sayHelloWorld();
RayLog.rapp.info(helloWorld);
Assert.assertEquals("hello,world!", helloWorld);
Assert.assertTrue(Ray.call(TypesTest::sayBool).get());
}
public String sayHelloWorld() {
RayObject<String> hello = Ray.call(HelloWorldTest::sayHello);
RayObject<String> world = Ray.call(HelloWorldTest::sayWorld);
return Ray.call(HelloWorldTest::merge, hello, world).get();
}
@RayRemote
public static String sayHello() {
String ret = "he";
ret += "llo";
RayLog.rapp.info("real say hello");
//throw new RuntimeException("+++++++++++++++++++++hello exception");
return ret;
}
@RayRemote
public static String sayWorld() {
String ret = "world";
ret += "!";
return ret;
}
@RayRemote
public static String merge(String hello, String world) {
return hello + "," + world;
}
}
@@ -0,0 +1,71 @@
package org.ray.api.test;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.experiment.mr.MemoryMapReduce;
/**
* test the MapReduce interface
*/
@RunWith(MyRunner.class)
public class MemoryWordCountTest {
public static class MemoryWordCount extends MemoryMapReduce<String, String, Integer, Integer> {
public List<Pair<String, Integer>> Map(String input) {
ArrayList<Pair<String, Integer>> counts = new ArrayList<>();
for (String s : input.split(" ")) {
counts.add(Pair.of(s, 1));
}
return counts;
}
public Integer Reduce(String k, List<Integer> values) {
return values.size();
}
}
@Test
public void test() {
List<List<String>> iinputs = new ArrayList<>();
List<String> inputs = new ArrayList<>();
inputs.add("1 3 5 7 9");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("1 3 5 7 9");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("1 3 5 7 9");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
MemoryWordCount wc = new MemoryWordCount();
SortedMap<String, Integer> result = wc.Run(iinputs, 2, 2);
Assert.assertEquals(6, (int) result.get("0"));
Assert.assertEquals(6, (int) result.get("1"));
Assert.assertEquals(6, (int) result.get("2"));
Assert.assertEquals(6, (int) result.get("3"));
Assert.assertEquals(6, (int) result.get("4"));
Assert.assertEquals(6, (int) result.get("5"));
Assert.assertEquals(6, (int) result.get("6"));
Assert.assertEquals(6, (int) result.get("7"));
Assert.assertEquals(6, (int) result.get("8"));
Assert.assertEquals(6, (int) result.get("9"));
}
}
@@ -0,0 +1,19 @@
package org.ray.api.test;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.InitializationError;
public class MyRunner extends BlockJUnit4ClassRunner {
public MyRunner(Class<?> klass) throws InitializationError {
super(klass);
}
@Override
public void run(RunNotifier notifier) {
notifier.addListener(new TestListener());
notifier.fireTestRunStarted(getDescription());
super.run(notifier);
}
}
@@ -0,0 +1,54 @@
package org.ray.api.test;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.util.logger.RayLog;
/**
* Integration test for Ray.*
*/
@RunWith(MyRunner.class)
public class RayMethodsTest {
@Test
public void test() {
RayObject<Integer> i1Id = Ray.put(1);
RayObject<Double> f1Id = Ray.put(3.14);
RayObject<String> s1Id = Ray.put(String.valueOf("Hello "));
RayObject<String> s2Id = Ray.put(String.valueOf("World!"));
RayObject<Object> n1Id = Ray.put(null);
RayList<String> wIds = new RayList<>();
wIds.add(s1Id);
wIds.add(s2Id);
WaitResult<String> readys = Ray.wait(wIds, 2);
List<String> ss = readys.getReadyOnes().get();
int i1 = i1Id.get();
double f1 = f1Id.get();
Object n1 = n1Id.get();
RayLog.rapp.info("Strings: " + ss.get(0) + ss.get(1) +
" int: " + i1 +
" double: " + f1 +
" null: " + n1);
Assert.assertEquals("Hello World!", ss.get(0) + ss.get(1));
Assert.assertEquals(1, i1);
Assert.assertEquals(3.14, f1, Double.MIN_NORMAL);
Assert.assertNull(n1);
// metadata test
RayObject<Integer> vId = Ray.put(643, "test metadata");
Integer v = vId.get();
String m = vId.getMeta();
Assert.assertEquals(643L, v.longValue());
Assert.assertEquals("test metadata", m);
}
}
@@ -0,0 +1,43 @@
package org.ray.api.test;
import org.apache.commons.lang3.SerializationUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.util.RemoteFunction;
@RunWith(MyRunner.class)
public class RemoteLambdaTest {
public static <T> String RemoteToString(T o) {
return o.toString();
}
@Test
public void test() {
RemoteFunction<String, String> f0 = RemoteLambdaTest::RemoteToString;
byte[] bytes = SerializationUtils.serialize(f0);
//System.out.println(new String(bytes));
//Object m = SerializationUtils.deserialize(bytes);
RemoteFunction<Integer, Integer> f = x ->
{
System.out.println("remote function " + x);
return x + 1;
};
RemoteFunction<Integer, Integer> f2 = SerializationUtils.clone(f);
Assert.assertEquals(101, (int) f2.apply(100));
Integer y = 100;
RemoteFunction<Integer, Integer> f3 = x ->
{
System.out.println("remote function " + x);
return x + y + 1;
};
RemoteFunction<Integer, Integer> f4 = SerializationUtils.clone(f3);
Assert.assertEquals(201, (int) f4.apply(100));
Assert.assertEquals(201, (int) f4.apply(100));
}
}
@@ -0,0 +1,13 @@
package org.ray.api.test;
import java.io.IOException;
import java.util.zip.DataFormatException;
import org.ray.hook.JarRewriter;
public class RewriteTest {
public static void main(String[] args) throws IOException, DataFormatException {
System.out.println(System.getProperty("user.dir"));
JarRewriter.rewrite("target", "target2");
}
}
@@ -0,0 +1,39 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayRemote;
@RunWith(MyRunner.class)
public class RpcTest {
@Test
public void test() {
Assert.assertEquals(0, (int) Ray.call(RpcTest::with0Params).get());
Assert.assertEquals(1, (int) Ray.call(RpcTest::with1Params, 1).get());
Assert.assertEquals(3, (int) Ray.call(RpcTest::with2Params, 1, 2).get());
Assert.assertEquals(6, (int) Ray.call(RpcTest::with3Params, 1, 2, 3).get());
}
@RayRemote
public static Integer with0Params() {
return 0;
}
@RayRemote
public static Integer with1Params(Integer x) {
return x;
}
@RayRemote
public static Integer with2Params(Integer x, Integer y) {
return x + y;
}
@RayRemote
public static Integer with3Params(Integer x, Integer y, Integer z) {
return x + y + z;
}
}
@@ -0,0 +1,20 @@
package org.ray.api.test;
import org.junit.runner.Description;
import org.junit.runner.Result;
import org.junit.runner.notification.RunListener;
import org.ray.api.Ray;
import org.ray.core.RayRuntime;
public class TestListener extends RunListener {
@Override
public void testRunStarted(Description description) {
Ray.init();
}
@Override
public void testRunFinished(Result result) {
RayRuntime.getInstance().cleanUp();
}
}
@@ -0,0 +1,20 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
/**
* local test in IDE, for class lazy load
*/
@RunWith(MyRunner.class)
public class TwoClassTest {
@Test
public void testLocal() {
Assert.assertTrue(
Ray.call(TypesTest::sayBool).get());//call function in other class which may not be loaded
}
}
@@ -0,0 +1,202 @@
package org.ray.api.test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.returns.MultipleReturns2;
import org.ray.api.returns.RayObjects2;
/**
* types test
*/
@RunWith(MyRunner.class)
public class TypesTest {
@Test
public void test() {
sayTypes();
}
public void sayTypes() {
Assert.assertEquals(1, (int) Ray.call(TypesTest::sayInt).get());
Assert.assertEquals(1, (byte) Ray.call(TypesTest::sayByte).get());
Assert.assertEquals(1, (short) Ray.call(TypesTest::sayShort).get());
Assert.assertEquals(1, (long) Ray.call(TypesTest::sayLong).get());
Assert.assertEquals(1.0, Ray.call(TypesTest::sayDouble).get(), 0.0);
Assert.assertEquals(1.0f, Ray.call(TypesTest::sayFloat).get(), 0.0);
Assert.assertEquals(true, Ray.call(TypesTest::sayBool).get());
Assert.assertEquals("object", Ray.call(TypesTest::sayReference).get());
RayObjects2<Integer, String> refs = Ray.call_2(TypesTest::sayReferences);
Assert.assertEquals(123, (int) refs.r0().get());
Assert.assertEquals("123", refs.r1().get());
RayMap<Integer, String> futureRs = Ray.call_n(TypesTest::sayReferencesN,
Arrays.asList(1, 2, 4, 3), "n_refs_", "_suffix");
for (Entry<Integer, RayObject<String>> fne : futureRs.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_refs_" + fne.getKey() + "_suffix");
}
RayMap<Integer, String> futureRs2 = Ray.call_n(TypesTest::sayReferencesN,
Arrays.asList(1), "n_refs_", "_suffix");
for (Entry<Integer, RayObject<String>> fne : futureRs2.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_refs_" + fne.getKey() + "_suffix");
}
RayObject<Integer> future = Ray.call(TypesTest::sayRayFuture);
Assert.assertEquals(123, (int) future.get());
RayObjects2<Integer, String> futures = Ray.call_2(TypesTest::sayRayFutures);
Assert.assertEquals(123, (int) futures.r0().get());
Assert.assertEquals("123", futures.r1().get());
RayMap<Integer, String> futureNs = Ray.call_n(TypesTest::sayRayFuturesN,
Arrays.asList(1, 2, 4, 3), "n_futures_");
for (Entry<Integer, RayObject<String>> fne : futureNs.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_futures_" + fne.getKey());
}
RayList<Integer> ns = Ray.call_n(TypesTest::sayArray, 10, 10);
for (int i = 0; i < 10; i++) {
Assert.assertEquals(i, (int) ns.Get(i).get());
}
RayList<Integer> ns2 = Ray.call_n(TypesTest::sayArray, 1, 1);
Assert.assertEquals(0, (int) ns2.Get(0).get());
RayObject<List<Integer>> ns3 = Ray.call(TypesTest::sayArray, 1);
Assert.assertEquals(0, (int) ns3.get().get(0));
RayList<Integer> ints = new RayList<>();
ints.add(Ray.call(TypesTest::sayInt));
ints.add(Ray.call(TypesTest::sayInt));
ints.add(Ray.call(TypesTest::sayInt));
// TODO: when RayParameters.use_remote_lambda is on, we have to explicitly
// cast RayList and RayMap to List and Map explicitly, so that the parameter
// types of the lambdas can be correctly deducted.
RayObject<Integer> collection = Ray.call(TypesTest::sayReadRayList, (List<Integer>) ints);
Assert.assertEquals(3, (int) collection.get());
RayMap<String, Integer> namedInts = new RayMap();
namedInts.put("a", Ray.call(TypesTest::sayInt));
namedInts.put("b", Ray.call(TypesTest::sayInt));
namedInts.put("c", Ray.call(TypesTest::sayInt));
RayObject<Integer> collection2 = Ray
.call(TypesTest::sayReadRayMap, (Map<String, Integer>) namedInts);
Assert.assertEquals(3, (int) collection2.get());
}
@RayRemote
public static int sayInt() {
return 1;
}
@RayRemote
public static byte sayByte() {
return 1;
}
@RayRemote
public static short sayShort() {
return 1;
}
@RayRemote
public static long sayLong() {
return 1;
}
@RayRemote
public static double sayDouble() {
return 1;
}
@RayRemote
public static float sayFloat() {
return 1;
}
@RayRemote
public static boolean sayBool() {
return true;
}
@RayRemote
public static Object sayReference() {
return "object";
}
@RayRemote
public static MultipleReturns2<Integer, String> sayReferences() {
return new MultipleReturns2<>(123, "123");
}
@RayRemote
public static Map<Integer, String> sayReferencesN(Collection<Integer> userReturnIds,
String prefix, String suffix) {
Map<Integer, String> ret = new HashMap<>();
for (Integer returnid : userReturnIds) {
ret.put(returnid, prefix + returnid + suffix);
}
return ret;
}
@RayRemote
public static List<Integer> sayArray(Integer returnCount) {
ArrayList<Integer> rets = new ArrayList<>();
for (int i = 0; i < returnCount; i++) {
rets.add(i);
}
return rets;
}
@RayRemote(externalIO = true)
public static Integer sayRayFuture() {
return 123;
}
@RayRemote(externalIO = true)
public static MultipleReturns2<Integer, String> sayRayFutures() {
return new MultipleReturns2<>(123, "123");
}
@RayRemote(externalIO = true)
public static Map<Integer, String> sayRayFuturesN(
Collection<Integer/*user's custom return_id*/> userReturnIds,
String prefix) {
Map<Integer, String> ret = new HashMap<>();
for (int id : userReturnIds) {
ret.put(id, prefix + id);
}
return ret;
}
@RayRemote
public static int sayReadRayList(List<Integer> ints) {
int sum = 0;
for (Integer i : ints) {
sum += i;
}
return sum;
}
@RayRemote
public static int sayReadRayMap(Map<String, Integer> ints) {
int sum = 0;
for (Integer i : ints.values()) {
sum += i;
}
return sum;
}
}
@@ -0,0 +1,38 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.core.RayRuntime;
import org.ray.core.UniqueIdHelper;
@RunWith(MyRunner.class)
public class UniqueIDTest {
@Test
public void test() {
UniqueID tid = UniqueIdHelper.nextTaskId(0xdeadbeefL);
UniqueIdHelper.setTest(tid, true);
System.out.println("Tested task id = " + tid);
RayFunc_1_1<Integer, String> f = UniqueIDTest::hi;
RayObject<String> result = new RayObject<>(
RayRuntime.getInstance().call(
tid,
RayFunc_1_1.class,
f,
1, 1
).getObjs()[0].getId()
);
System.out.println("Tested task return object id = " + result.getId());
Assert.assertEquals("hi1", result.get());
}
@RayRemote
public static String hi(Integer i) {
return "hi" + i;
}
}
@@ -0,0 +1,74 @@
package org.ray.api.test;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.util.FileUtil;
/**
* given a directory of document files on each "machine", we would like to count the appearance of
* some word
*/
public class WordCountTest {
//@Test
public void test() {
int sum = mapReduce();
Assert.assertEquals(sum, 143);
}
public int mapReduce() {
RayObject<List<String>> machines = Ray.call(WordCountTest::getMachineList);
RayObject<Integer> total = null;
for (String machine : machines.get()) {
RayObject<Integer> wordcount = Ray.call(WordCountTest::countWord, machine, "ray");
if (total == null) {
total = wordcount;
} else {
total = Ray.call(WordCountTest::sum, total, wordcount);
}
}
return total.get();
}
@RayRemote
public static List<String> getMachineList() {
return Arrays.asList("A", "B", "C");
}
@RayRemote
public static Integer countWord(String machine, String word) {
String log;
try {
log = FileUtil.readResourceFile("mapreduce/" + machine + ".log");
} catch (FileNotFoundException e) {
e.printStackTrace();
log = "";
}
log = log.toLowerCase();
int start = 0;
int count = 0;
while (true) {
if (start >= log.length()) {
break;
}
int index = log.indexOf(word, start);
if (index == -1) {
break;
}
start = index + word.length();
count++;
}
return count;
}
@RayRemote
public static Integer sum(Integer a, Integer/*TODO modify int to Integer in ASM hook*/ b) {
return a + b;
}
}