[Java] Replace binary rewrite with Remote Lambda Cache (SerdeLambda) (#2245)

* <feature> : serde lambda

* <feature>:fixed CR

with issue #2245

* <feature>: fixed CR
This commit is contained in:
mylinyuzhi
2018-06-14 03:58:07 +08:00
committed by Philipp Moritz
parent 62de86ff7a
commit fa0ade2bc5
89 changed files with 2633 additions and 7668 deletions
@@ -103,13 +103,6 @@ public final class Ray extends Rpc {
return impl;
}
/**
* whether to use remote lambda.
*/
public static boolean isRemoteLambda() {
return impl.isRemoteLambda();
}
/**
* for ray's app's log.
*/
@@ -10,7 +10,7 @@ import org.ray.util.Sha1Digestor;
* Ray actor abstraction.
*/
public class RayActor<T> extends RayObject<T> implements Externalizable {
public static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
private static final long serialVersionUID = 1877485807405645036L;
private int taskCounter = 0;
@@ -84,4 +84,4 @@ public class RayActor<T> extends RayObject<T> implements Externalizable {
this.actorHandleId = (UniqueID) in.readObject();
this.taskCursor = (UniqueID) in.readObject();
}
}
}
+11 -19
View File
@@ -3,7 +3,7 @@ package org.ray.api;
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import org.ray.api.internal.Callable;
import org.ray.api.internal.RayFunc;
import org.ray.util.exception.TaskExecutionException;
/**
@@ -54,14 +54,13 @@ public interface RayApi {
* submit a new task by invoking a remote function.
*
* @param taskId nil
* @param funcRun the target running function with @RayRemote
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnCount the number of to-be-returned objects from funcRun
* @param args arguments to this funcRun, can be its original form or RayObject
* @return a set of ray objects with their return ids
*/
RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args);
RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args);
/**
@@ -71,34 +70,27 @@ public interface RayApi {
* outputs with a set of labels (usually with Integer or String).
*
* @param taskId nil
* @param funcRun the target running function with @RayRemote
* @param returnIds a set of labels to be used by the returned objects
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnids a set of labels to be used by the returned objects
* @param args arguments to this funcRun, can be its original form or
* RayObject<original-type>
* @return a set of ray objects with their labels and return ids
*/
<R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RIDT> returnIds, Object... args);
<R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Collection<RIDT> returnids,
Object... args);
RayFunc lambda, Collection<RIDT> returnids, Object... args);
/**
* a special case for the above RID-based labeling as <0...returnCount - 1>.
*
* @param taskId nil
* @param funcRun the target running function with @RayRemote
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnCount the number of to-be-returned objects from funcRun
* @param args arguments to this funcRun, can be its original form or
* RayObject<original-type>
* @return an array of returned objects with their Unique ids
*/
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun, Integer returnCount,
Object... args);
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, Serializable lambda,
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
Integer returnCount, Object... args);
boolean isRemoteLambda();
}
File diff suppressed because it is too large Load Diff
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_1<R0> extends RayFunc {
static <R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_1<R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply();
}
R0 apply() throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_0_2<R0, R1> extends RayFunc {
static <R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_2<R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply();
}
MultipleReturns2<R0, R1> apply() throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_0_3<R0, R1, R2> extends RayFunc {
static <R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_3<R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply();
}
MultipleReturns3<R0, R1, R2> apply() throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_0_4<R0, R1, R2, R3> extends RayFunc {
static <R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_4<R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply();
}
MultipleReturns4<R0, R1, R2, R3> apply() throws Throwable;
}
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_n<R, RIDT> extends RayFunc {
static <R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_n<R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f.apply((Collection<RIDT>) args[0]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_n_list<R> extends RayFunc {
static <R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_0_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_0_n_list<R> f = SerializationUtils.deserialize(funcBytes);
return f.apply();
}
List<R> apply() throws Throwable;
}
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_1<T0, R0> extends RayFunc {
static <T0, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_1<T0, R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0]);
}
R0 apply(T0 t0) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_1_2<T0, R0, R1> extends RayFunc {
static <T0, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_2<T0, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0]);
}
MultipleReturns2<R0, R1> apply(T0 t0) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_1_3<T0, R0, R1, R2> extends RayFunc {
static <T0, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_3<T0, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_1_4<T0, R0, R1, R2, R3> extends RayFunc {
static <T0, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_4<T0, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0) throws Throwable;
}
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_n<T0, R, RIDT> extends RayFunc {
static <T0, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_n<T0, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f.apply((Collection<RIDT>) args[0], (T0) args[1]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_n_list<T0, R> extends RayFunc {
static <T0, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_1_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_1_n_list<T0, R> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0]);
}
List<R> apply(T0 t0) throws Throwable;
}
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_1<T0, T1, R0> extends RayFunc {
static <T0, T1, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_1<T0, T1, R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1]);
}
R0 apply(T0 t0, T1 t1) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_2_2<T0, T1, R0, R1> extends RayFunc {
static <T0, T1, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_2<T0, T1, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1]);
}
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_2_3<T0, T1, R0, R1, R2> extends RayFunc {
static <T0, T1, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_3<T0, T1, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_2_4<T0, T1, R0, R1, R2, R3> extends RayFunc {
static <T0, T1, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_4<T0, T1, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_n<T0, T1, R, RIDT> extends RayFunc {
static <T0, T1, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_n<T0, T1, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_n_list<T0, T1, R> extends RayFunc {
static <T0, T1, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_2_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_2_n_list<T0, T1, R> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1]);
}
List<R> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_1<T0, T1, T2, R0> extends RayFunc {
static <T0, T1, T2, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_1<T0, T1, T2, R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
}
R0 apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_3_2<T0, T1, T2, R0, R1> extends RayFunc {
static <T0, T1, T2, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_2<T0, T1, T2, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
}
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_3_3<T0, T1, T2, R0, R1, R2> extends RayFunc {
static <T0, T1, T2, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_3<T0, T1, T2, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_3_4<T0, T1, T2, R0, R1, R2, R3> extends RayFunc {
static <T0, T1, T2, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_4<T0, T1, T2, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_n<T0, T1, T2, R, RIDT> extends RayFunc {
static <T0, T1, T2, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_n<T0, T1, T2, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_n_list<T0, T1, T2, R> extends RayFunc {
static <T0, T1, T2, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_3_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_3_n_list<T0, T1, T2, R> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
}
List<R> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_1<T0, T1, T2, T3, R0> extends RayFunc {
static <T0, T1, T2, T3, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_1<T0, T1, T2, T3, R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
}
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_4_2<T0, T1, T2, T3, R0, R1> extends RayFunc {
static <T0, T1, T2, T3, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_2<T0, T1, T2, T3, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
}
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_4_3<T0, T1, T2, T3, R0, R1, R2> extends RayFunc {
static <T0, T1, T2, T3, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_3<T0, T1, T2, T3, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_4_4<T0, T1, T2, T3, R0, R1, R2, R3> extends RayFunc {
static <T0, T1, T2, T3, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_4<T0, T1, T2, T3, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -8,15 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_n<T0, T1, T2, T3, R, RIDT> extends RayFunc {
static <T0, T1, T2, T3, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_n<T0, T1, T2, T3, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_n_list<T0, T1, T2, T3, R> extends RayFunc {
static <T0, T1, T2, T3, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_4_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_4_n_list<T0, T1, T2, T3, R> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
}
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_1<T0, T1, T2, T3, T4, R0> extends RayFunc {
static <T0, T1, T2, T3, T4, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_1<T0, T1, T2, T3, T4, R0> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
}
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_5_2<T0, T1, T2, T3, T4, R0, R1> extends RayFunc {
static <T0, T1, T2, T3, T4, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_2<T0, T1, T2, T3, T4, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
}
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_5_3<T0, T1, T2, T3, T4, R0, R1, R2> extends RayFunc {
static <T0, T1, T2, T3, T4, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_3<T0, T1, T2, T3, T4, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_5_4<T0, T1, T2, T3, T4, R0, R1, R2, R3> extends RayFunc {
static <T0, T1, T2, T3, T4, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(
Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_4<T0, T1, T2, T3, T4, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -8,16 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_n<T0, T1, T2, T3, T4, R, RIDT> extends RayFunc {
static <T0, T1, T2, T3, T4, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_n<T0, T1, T2, T3, T4, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4],
(T4) args[5]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4)
throws Throwable;
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_n_list<T0, T1, T2, T3, T4, R> extends RayFunc {
static <T0, T1, T2, T3, T4, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_5_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_5_n_list<T0, T1, T2, T3, T4, R> f = SerializationUtils.deserialize(funcBytes);
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
}
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -6,15 +6,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_1<T0, T1, T2, T3, T4, T5, R0> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R0> R0 execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_1.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_1<T0, T1, T2, T3, T4, T5, R0> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
}
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -7,16 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_6_2<T0, T1, T2, T3, T4, T5, R0, R1> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_2.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_2<T0, T1, T2, T3, T4, T5, R0, R1> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
}
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -7,16 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_6_3<T0, T1, T2, T3, T4, T5, R0, R1, R2> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_3.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_3<T0, T1, T2, T3, T4, T5, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
}
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -7,17 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_6_4<T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(
Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_4.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_4<T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> f = SerializationUtils
.deserialize(funcBytes);
return f
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
}
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -8,16 +8,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_n<T0, T1, T2, T3, T4, T5, R, RIDT> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_n.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_n<T0, T1, T2, T3, T4, T5, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4],
(T4) args[5], (T5) args[6]);
}
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5)
throws Throwable;
@@ -7,15 +7,6 @@ import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_n_list<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
static <T0, T1, T2, T3, T4, T5, R> List<R> execute(Object[] args) throws Throwable {
String name = (String) args[args.length - 2];
assert (name.equals(RayFunc_6_n_list.class.getName()));
byte[] funcBytes = (byte[]) args[args.length - 1];
RayFunc_6_n_list<T0, T1, T2, T3, T4, T5, R> f = SerializationUtils.deserialize(funcBytes);
return f
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
}
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,9 +0,0 @@
package org.ray.api.internal;
/**
* hold the remote call.
*/
public interface Callable {
void run() throws Throwable;
}
+10 -4
View File
@@ -22,19 +22,25 @@
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<dependency>
<groupId>quartz</groupId>
<artifactId>quartz</artifactId>
<version>1.5.2</version>
</dependency>
<dependency>
<groupId>org.ini4j</groupId>
<artifactId>ini4j</artifactId>
<version>0.5.2</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>
</project>
@@ -0,0 +1,31 @@
package org.ray.util;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
/**
* see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html.
*/
public final class LambdaUtils {
private LambdaUtils() {
}
public static SerializedLambda getSerializedLambda(Serializable lambda) {
// Note.
// the class of lambda which isAssignableFrom Serializable
// has an privte method:writeReplace
// This mechanism may be changed in the future
try {
Method m = lambda.getClass().getDeclaredMethod("writeReplace");
m.setAccessible(true);
return (SerializedLambda) m.invoke(lambda);
} catch (Exception e) {
throw new RuntimeException("failed to getSerializedLambda:" + lambda.getClass().getName(), e);
}
}
}
@@ -0,0 +1,212 @@
package org.ray.util;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.WeakHashMap;
import org.objectweb.asm.Type;
import org.ray.util.logger.RayLog;
/**
* An instance of RayFunc is a lambda.
* MethodId describe the information of the called function in lambda.<br/>
* e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:<br/>
* MethodId.className = Foo <br/>
* MethodId.methodName = foo <br/>
* MethodId.methodDesc = describe the types of args and return.
* see org.objectweb.asm.Type.getDescriptor.
*/
public final class MethodId {
/**
* use ThreadLocal to avoid lock.
* A cache from the lambda instances to MethodId.
* Note: the lambda instances are dynamically created per call site,
* we use WeakHashMap to avoid OOM.
*/
private static final ThreadLocal<WeakHashMap<Class<Serializable>, MethodId>>
CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>());
public final String className;
public final String methodName;
public final String methodDesc;
public final boolean isStatic;
/**
* encode the className,methodName,methodDesc,isStatic as an uniquel id.
*/
private final String encoding;
/**
* sha1 from the encoding, used as functionId.
*/
private final byte[] digest;
public MethodId(String className, String methodName, String methodDesc, boolean isStatic) {
this.className = className;
this.methodName = methodName;
this.methodDesc = methodDesc;
this.isStatic = isStatic;
this.encoding = encode(className, methodName, methodDesc, isStatic);
this.digest = getSha1Hash0();
}
private static String encode(String className, String methodName, String methodDesc,
boolean isStatic) {
StringBuilder sb = new StringBuilder(512);
sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&")
.append(isStatic);
return sb.toString();
}
public static MethodId fromMethod(Method method) {
final boolean isstatic = Modifier.isStatic(method.getModifiers());
final String className = method.getDeclaringClass().getName();
final String methodName = method.getName();
final Type type = Type.getType(method);
final String methodDesc = type.getDescriptor();
return new MethodId(className, methodName, methodDesc, isstatic);
}
public static MethodId fromSerializedLambda(Serializable serial) {
return fromSerializedLambda(serial, false);
}
public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) {
Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be "
+ "SerializedLambda");
Class<Serializable> clazz = (Class<Serializable>) serial.getClass();
WeakHashMap<Class<Serializable>, MethodId> map = CACHE.get();
MethodId id = map.get(clazz);
if (id == null || forceNew) {
final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial);
Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda "
+ "which is closure");
final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
final String className = lambda.getImplClass().replace('/', '.');
id = new MethodId(className, lambda.getImplMethodName(),
lambda.getImplMethodSignature(), isStatic);
if (!forceNew) {
map.put(clazz, id);
}
}
return id;
}
public Method load() {
return load(null);
}
public Method load(ClassLoader loader) {
Class<?> cls = null;
try {
RayLog.core.debug(
"load class " + className + " from class loader " + (loader == null ? this.getClass()
.getClassLoader() : loader)
+ " for method " + toString() + " with ID = " + toHexHashString()
);
cls = Class
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
} catch (Throwable e) {
RayLog.core.error("Cannot load class " + className, e);
return null;
}
Method[] ms = cls.getDeclaredMethods();
ArrayList<Method> methods = new ArrayList<>();
Type t = Type.getMethodType(this.methodDesc);
Type[] params = t.getArgumentTypes();
String rt = t.getReturnType().getDescriptor();
for (Method m : ms) {
if (m.getName().equals(methodName)) {
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
continue;
}
String mrt = Type.getDescriptor(m.getReturnType());
if (!rt.equals(mrt)) {
continue;
}
if (isStatic != Modifier.isStatic(m.getModifiers())) {
continue;
}
methods.add(m);
}
}
if (methods.size() != 1) {
RayLog.core.error(
"Load method " + toString() + " failed as there are " + methods.size() + " definitions");
return null;
}
return methods.get(0);
}
private byte[] getSha1Hash0() {
byte[] digests = Sha1Digestor.digest(encoding);
ByteBuffer bb = ByteBuffer.wrap(digests);
bb.order(ByteOrder.LITTLE_ENDIAN);
if (methodName.contains("createActorStage1")) {
bb.putLong(Long.BYTES, 1);
} else {
bb.putLong(Long.BYTES, 0);
}
return digests;
}
public byte[] getSha1Hash() {
return digest;
}
private String toHexHashString() {
byte[] id = this.getSha1Hash();
return StringUtil.toHexHashString(id);
}
public String toEncodingString() {
return encoding;
}
@Override
public int hashCode() {
return encoding.hashCode();
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MethodId other = (MethodId) obj;
return className.equals(other.className)
&& methodName.equals(other.methodName)
&& methodDesc.equals(other.methodDesc)
&& isStatic == other.isStatic;
}
@Override
public String toString() {
return encoding;
}
}
@@ -15,6 +15,7 @@ public class Sha1Digestor {
}
});
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
@@ -27,4 +28,14 @@ public class Sha1Digestor {
dg.update(longBuffer.get().putLong(addIndex).array());
return dg.digest();
}
}
public static byte[] digest(String str) {
return digest(str.getBytes(StringUtil.UTF8));
}
public static byte[] digest(byte[] src) {
MessageDigest dg = md.get();
dg.reset();
return dg.digest(src);
}
}
@@ -1,11 +1,16 @@
package org.ray.util;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Vector;
public class StringUtil {
public static final Charset UTF8 = Charset.forName("UTF-8");
private static final char[] HEX_CHARS = "0123456789abcdef".toCharArray();
/**
* split.
* @param s input string
@@ -117,6 +122,17 @@ public class StringUtil {
return objs.length == 0 ? "" : sb.substring(0, sb.length() - concatenator.length());
}
public static String toHexHashString(byte[] id) {
StringBuilder sb = new StringBuilder(20);
assert (id.length == 20);
for (int i = 0; i < 20; i++) {
int val = id[i] & 0xff;
sb.append(HEX_CHARS[val >> 4]);
sb.append(HEX_CHARS[val & 0xf]);
}
return sb.toString();
}
// Holds the start of an element and which brace started it.
private static class Start {
@@ -132,4 +148,3 @@ public class StringUtil {
}
}
}
+1 -6
View File
@@ -50,11 +50,6 @@
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<version>1.0</version>
</dependency>
</dependencies>
<build>
<plugins>
@@ -99,4 +94,4 @@
</plugins>
</build>
</project>
</project>
-88
View File
@@ -1,88 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<groupId>org.ray.parent</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<name>java api hook for ray</name>
<description>java api hook for ray</description>
<url></url>
<packaging>jar</packaging>
<dependencies>
<!-- https://mvnrepository.com/artifact/org.ow2.asm/asm -->
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>6.0</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-common</artifactId>
<version>1.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-jar-plugin</artifactId>
<version>2.5</version>
<configuration>
<archive>
<manifestEntries>
<Premain-Class>org.ray.hook.Agent</Premain-Class>
<Can-Retransform-Classes>true</Can-Retransform-Classes>
</manifestEntries>
</archive>
</configuration>
</plugin>
<plugin>
<artifactId>maven-shade-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<relocations>
<relocation>
<pattern>org.objectweb.asm</pattern>
<shadedPattern>agent.org.objectweb.asm</shadedPattern>
</relocation>
</relocations>
<createDependencyReducedPom>false</createDependencyReducedPom>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
@@ -1,44 +0,0 @@
package org.ray.hook;
import java.util.Set;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
public class ClassAdapter {
public static Result hookClass(ClassLoader loader, String className, byte[] classfileBuffer) {
// we have to comment out this quick filter as this is not accurate
// e.g., org/ray/api/test/ActorTest$Adder.class is skipped!!!
// even worse, this is non-deterministic...
ClassReader reader = new ClassReader(classfileBuffer);
ClassWriter writer = new ClassWriter(reader, 0);
ClassDetectVisitor pre = new ClassDetectVisitor(loader, writer, className);
byte[] result;
reader.accept(pre, ClassReader.SKIP_FRAMES);
if (pre.detectedMethods().isEmpty() && pre.actorCalls() == 0) {
result = classfileBuffer;
} else {
if (pre.actorCalls() > 0) {
reader = new ClassReader(writer.toByteArray());
}
writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
reader.accept(new ClassOverrideVisitor(writer, className, pre.detectedMethods()),
ClassReader.SKIP_FRAMES);
result = writer.toByteArray();
}
Result rr = new Result();
rr.changedMethods = pre.detectedMethods();
rr.classBuffer = result;
return rr;
}
public static class Result {
public byte[] classBuffer;
public Set<MethodId> changedMethods;
}
}
@@ -1,165 +0,0 @@
package org.ray.hook;
import java.security.InvalidParameterException;
import java.util.HashSet;
import java.util.Set;
import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
/**
* rewrite phase 1.
*/
public class ClassDetectVisitor extends ClassVisitor {
static int count = 0;
final String className;
final Set<MethodId> rayMethods = new HashSet<>();
final ClassLoader loader;
boolean isActor = false;
int actorCalls = 0;
public ClassDetectVisitor(ClassLoader loader, ClassVisitor origin, String className) {
super(Opcodes.ASM6, origin);
this.className = className;
this.loader = loader;
}
public int actorCalls() {
return actorCalls;
}
public Set<MethodId> detectedMethods() {
return rayMethods;
}
@Override
public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
if (desc.contains("Lorg/ray/api/RayRemote;")) {
isActor = true;
}
return super.visitAnnotation(desc, visible);
}
@Override
public void visitInnerClass(String name, String outerName,
String innerName, int access) {
// System.err.println("visist inner class " + outerName + "$" + innerName);
super.visitInnerClass(name, outerName, innerName, access);
}
@Override
public MethodVisitor visitMethod(int access, String name, String mdesc, String signature,
String[] exceptions) {
//System.out.println("Visit " + className + "." + name);
if (isActor && (access & Opcodes.ACC_PUBLIC) != 0) {
visitRayMethod(access, name, mdesc);
}
MethodVisitor origin = super.visitMethod(access, name, mdesc, signature, exceptions);
return new MethodVisitor(this.api, origin) {
@Override
public AnnotationVisitor visitAnnotation(String adesc, boolean visible) {
//handle rayRemote annotation
if (adesc.contains("Lorg/ray/api/RayRemote;")) {
visitRayMethod(access, name, mdesc);
}
return super.visitAnnotation(adesc, visible);
}
@Override
public void visitInvokeDynamicInsn(String name, String desc, Handle bsm,
Object... bsmArgs) {
// fix all actor calls from InvokeVirtual to InvokeStatic
if (desc.contains("org/ray/api/funcs/RayFunc_")) {
int count = bsmArgs.length;
for (int i = 0; i < count; ++i) {
Object arg = bsmArgs[i];
// System.err.println(arg.getClass().getName() + " " + arg.toString());
if (arg.getClass().equals(Handle.class)) {
Handle h = (Handle) arg;
if (h.getTag() == Opcodes.H_INVOKEVIRTUAL) {
String dsptr = h.getDesc();
Type[] argTypes = Type.getArgumentTypes(dsptr);
for (Type argt : argTypes) {
if (!isValidCallParameterOrReturnType(argt)) {
throw new InvalidParameterException(
"cannot use primitive parameter type '" + argt.getClassName()
+ "' in method " + h.getOwner() + "." + h.getName());
}
}
Type retType = Type.getReturnType(dsptr);
if (!isValidCallParameterOrReturnType(retType)) {
throw new InvalidParameterException(
"cannot use primitive return type '" + retType.getClassName() + "' in method "
+ h.getOwner() + "." + h.getName());
}
dsptr = "(L" + h.getOwner() + ";" + dsptr.substring(1);
Handle newh = new Handle(
Opcodes.H_INVOKESTATIC,
h.getOwner(),
h.getName() + MethodId.getFunctionIdPostfix,
dsptr,
h.isInterface());
bsmArgs[i] = newh;
//System.err.println("Change ray.call from " + h + " -> " + newh + ", isInterface
// = " + h.isInterface());
++actorCalls;
}
}
}
}
super.visitInvokeDynamicInsn(name, desc, bsm, bsmArgs);
}
private boolean isValidCallParameterOrReturnType(Type t) {
if (t.equals(Type.VOID_TYPE)) {
return false;
}
if (t.equals(Type.BOOLEAN_TYPE)) {
return false;
}
if (t.equals(Type.CHAR_TYPE)) {
return false;
}
if (t.equals(Type.BYTE_TYPE)) {
return false;
}
if (t.equals(Type.SHORT_TYPE)) {
return false;
}
if (t.equals(Type.INT_TYPE)) {
return false;
}
if (t.equals(Type.FLOAT_TYPE)) {
return false;
}
if (t.equals(Type.LONG_TYPE)) {
return false;
}
if (t.equals(Type.DOUBLE_TYPE)) {
return false;
}
return true;
}
};
}
private void visitRayMethod(int access, String name, String mdesc) {
if (name.equals("<init>")) {
return;
}
MethodId m = new MethodId(className, name, mdesc, (access & Opcodes.ACC_STATIC) != 0, loader);
rayMethods.add(m);
//System.err.println("Visit " + m.toString());
count++;
}
}
@@ -1,214 +0,0 @@
package org.ray.hook;
import java.util.Set;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
/**
* rewrite phase 2.
*/
public class ClassOverrideVisitor extends ClassVisitor {
final String className;
final Set<MethodId> rayRemoteMethods;
MethodVisitor clinitVisitor;
public ClassOverrideVisitor(ClassVisitor origin, String className,
Set<MethodId> rayRemoteMethods) {
super(Opcodes.ASM6, origin);
this.className = className;
this.rayRemoteMethods = rayRemoteMethods;
this.clinitVisitor = null;
}
@Override
public MethodVisitor visitMethod(int access, String name, String desc, String signature,
String[] exceptions) {
if ("<clinit>".equals(name) && clinitVisitor == null) {
MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
clinitVisitor = new StaticBlockVisitor(mv);
return clinitVisitor;// dispatch the ASM modifications (assign values to the preComputedxxx
// static field) to the clinitVisitor
}
ClassVisitor current = this;
MethodId m = new MethodId(className, name, desc, (access & Opcodes.ACC_STATIC) != 0, null);
if (rayRemoteMethods.contains(m)) {
if (m.isStaticMethod()) {
return new MethodVisitor(api,
super.visitMethod(access, name, desc, signature, exceptions)) {
@Override
public void visitCode() {
// step 1: add a field for the function id of this method
System.out.println("add field: " + m.getStaticHashValueFieldName());
String fieldName = m.getStaticHashValueFieldName();
current.visitField(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, fieldName, "[B",
null, null);
// step 2: rewrite current method so if MethodSwitcher returns true, returns the
// added function id directly
// else call the original method
mv.visitFieldInsn(Opcodes.GETSTATIC, className, fieldName, "[B");
mv.visitMethodInsn(Opcodes.INVOKESTATIC, "org/ray/hook/runtime/MethodSwitcher",
"execute",
"([B)Z", false);
Label dorealwork = new Label();
mv.visitJumpInsn(Opcodes.IFEQ, dorealwork);
properReturn(sayReturnType(desc), mv);// proper return on
// different types
mv.visitLabel(dorealwork);
mv.visitCode();// real work
}
};
} else { // non-static
return super.visitMethod(access, name, desc, signature, exceptions);
}
} else {
return super.visitMethod(access, name, desc, signature, exceptions);
}
}
@Override
public void visitEnd() {
if (clinitVisitor == null) { // works fine
// Create an empty static block and let our method
// visitor modify it the same way it modifies an
// existing static block
{
MethodVisitor mv = super.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
mv = new StaticBlockVisitor(mv);
mv.visitCode();
mv.visitInsn(Opcodes.RETURN);
mv.visitMaxs(0, 0);
mv.visitEnd();
}
}
// for each non-static method, create a method for returning hash
for (MethodId mid : this.rayRemoteMethods) {
if (!mid.isStaticMethod()) {
// step 1: create a new method called method_name_function_id()
System.out.println("add method: " + mid.getIdMethodName());
MethodVisitor mv = super.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC,
mid.getIdMethodName(), mid.getIdMethodDesc(), null, null);
mv.visitCode();
Label l0 = new Label();
mv.visitLabel(l0);
// step 2: add a new static field as the function id of this method
System.out.println("add field: " + mid.getStaticHashValueFieldName());
String fieldName = mid.getStaticHashValueFieldName();
this.visitField(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, fieldName, "[B",
null, null);
mv.visitFieldInsn(Opcodes.GETSTATIC, className, fieldName, "[B");
// step 3: call method switcher, and returns the function id when the mode is rewritten
mv.visitMethodInsn(Opcodes.INVOKESTATIC, "org/ray/hook/runtime/MethodSwitcher", "execute",
"([B)Z", false);
mv.visitInsn(Opcodes.POP);
Label l1 = new Label();
mv.visitLabel(l1);
properReturn(sayReturnType(mid.getMethodDesc()), mv);
Label l2 = new Label();
mv.visitLabel(l2);
org.objectweb.asm.Type[] args = org.objectweb.asm.Type
.getArgumentTypes(mid.getIdMethodDesc());
int argCount = args.length;
mv.visitMaxs(2, argCount);
mv.visitEnd();
}
}
}
private void properReturn(String returnType, MethodVisitor mv) {
int returnCode;
Object returnValue = null;
switch (returnType) {
case "V":
mv.visitInsn(Opcodes.RETURN);
return;
case "I":
case "B":
case "S":
case "Z": // int byte short boolean
returnCode = Opcodes.IRETURN;
returnValue = 0;
break;
case "J": // long
returnCode = Opcodes.LRETURN;
returnValue = 0L;
break;
case "D": // double
returnCode = Opcodes.DRETURN;
returnValue = 0D;
break;
case "F": // float
returnCode = Opcodes.FRETURN;
returnValue = 0F;
break;
default: // reference
returnCode = Opcodes.ARETURN;
break;
}
if (returnValue != null) {
mv.visitLdcInsn(returnValue);
} else {
mv.visitInsn(Opcodes.ACONST_NULL);
}
mv.visitInsn(returnCode);
}
private String sayReturnType(String str) {
int left = str.lastIndexOf(")") + 1;
return str.substring(left);
}
// init the static added field in <clinit>
// static {
// assign value to _hashOf_XXX
// }
class StaticBlockVisitor extends MethodVisitor {
StaticBlockVisitor(MethodVisitor mv) {
super(Opcodes.ASM6, mv);
}
@Override
public void visitCode() {
super.visitCode();
// assign value for added hash fields within <clinit>
for (MethodId m : rayRemoteMethods) {
byte[] hash = m.getSha1Hash();
insertByteArray(hash);
mv.visitFieldInsn(Opcodes.PUTSTATIC, className, m.getStaticHashValueFieldName(), "[B");
System.out.println("assign field: " + m.getStaticHashValueFieldName() + " = " + MethodId
.toHexHashString(hash));
}
}
private void insertByteArray(byte[] bytes) {
int length = bytes.length;
assert (length < Short.MAX_VALUE);
mv.visitIntInsn(Opcodes.SIPUSH, length);
mv.visitIntInsn(Opcodes.NEWARRAY, Opcodes.T_BYTE);
mv.visitInsn(Opcodes.DUP);
for (int i = 0; i < length; ++i) {
mv.visitIntInsn(Opcodes.BIPUSH, i);
mv.visitIntInsn(Opcodes.BIPUSH, bytes[i]);
mv.visitInsn(Opcodes.BASTORE);
if (i < (length - 1)) {
mv.visitInsn(Opcodes.DUP);
}
}
}
}
}
@@ -1,201 +0,0 @@
package org.ray.hook;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.Scanner;
import java.util.function.BiConsumer;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.jar.JarOutputStream;
import java.util.zip.DataFormatException;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.ray.hook.runtime.JarLoader;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.util.logger.RayLog;
/**
* rewrite jars to new jars with methods marked using Ray annotations.
*/
public class JarRewriter {
private static final String FUNCTIONS_FILE = "ray.functions.txt";
public static void main(String[] args)
throws IOException, SecurityException, DataFormatException {
if (args.length == 1) {
LoadedFunctions funcs = load(args[0], null);
for (MethodId mi : funcs.functions) {
System.err.println(mi.getIdMethodDesc());
Method m = mi.load();
String logInfo = "load: " + m.getDeclaringClass().getName() + "." + m.getName();
RayLog.core.info(logInfo);
}
return;
} else if (args.length < 2) {
System.err.println("org.ray.hook.JarRewriter source-jar-dir dest-jar-dir");
System.exit(1);
}
rewrite(args[0], args[1]);
}
public static LoadedFunctions load(String dir, String baseDir)
throws FileNotFoundException, SecurityException {
List<String> functions = JarRewriter.getRewrittenFunctions(dir);
LoadedFunctions efuncs = new LoadedFunctions();
efuncs.loader = JarLoader.loadJars(dir, false);
for (String func : functions) {
MethodId mid = new MethodId(func, efuncs.loader);
efuncs.functions.add(mid);
}
if (baseDir != null && !baseDir.equals("")) {
List<String> baseFunctions = JarRewriter.getRewrittenFunctions(baseDir);
for (String func : baseFunctions) {
MethodId mid = new MethodId(func, efuncs.loader);
efuncs.functions.add(mid);
}
}
return efuncs;
}
public static void rewrite(String fromDir, String toDir) throws IOException, DataFormatException {
File fromDirFile = new File(fromDir);
File toDirFileTmp = new File(toDir + ".tmp");
final File toDirFile = new File(toDir);
File[] topFiles = fromDirFile.listFiles();
if (topFiles.length != 1 || !topFiles[0].isDirectory()) {
throw new DataFormatException("There should be a top dir in the Ray app zip file.");
}
String topDir = topFiles[0].getName();
if (toDirFileTmp.exists()) {
FileUtils.deleteDirectory(toDirFileTmp);
}
//toDirFileTmp.mkdir();
FileUtils.copyDirectory(fromDirFile, toDirFileTmp);
PrintWriter functionCollector = new PrintWriter(toDir + ".tmp/" + FUNCTIONS_FILE, "UTF-8");
// get all jars
Collection<File> files = FileUtils.listFiles(
fromDirFile,
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY
);
// load and rewrite
int prefixLength = fromDirFile.getAbsolutePath().length() + topDir.length() + 2;
for (File appJar : files) {
String fromPath = appJar.getAbsolutePath();
if (fromPath.substring(prefixLength).contains("/")) {
functionCollector.close();
throw new DataFormatException("There should not be any subdir"
+ " containing jar file in the top dir of the Ray app zip file.");
}
JarFile jar = new JarFile(appJar.getAbsolutePath());
String to = fromPath
.replaceFirst(fromDirFile.getAbsolutePath(), toDirFileTmp.getAbsolutePath());
rewrite(jar, to, (l, m) -> functionCollector.println(m.toEncodingString()));
jar.close();
}
// rename the whole dir
functionCollector.close();
if (toDirFile.exists()) {
FileUtils.deleteDirectory(toDirFile);
}
FileUtils.moveDirectory(toDirFileTmp, toDirFile);
}
public static void rewrite(JarFile from, String to, BiConsumer<ClassLoader, MethodId> consumer)
throws IOException {
FileOutputStream ofStream = new FileOutputStream(to);
JarOutputStream ojStream = new JarOutputStream(ofStream);
Enumeration<JarEntry> e = from.entries();
String className;
while (e.hasMoreElements()) {
JarEntry je = e.nextElement();
byte[] jeBytes = IOUtils.toByteArray(from.getInputStream(je));
//System.err.println("XXXXXX " + from.getName() + " :: " + je.getName());
if (!je.isDirectory() && je.getName().endsWith(".class")) {
className = je.getName().substring(0, je.getName().length() - ".class".length());
//System.err.println("XXXXXX " + from.getName() + " :: " + je.getName() + " - " +
// className);
ClassAdapter.Result result = ClassAdapter.hookClass(null, className, jeBytes);
if (result.classBuffer != jeBytes) {
String logInfo = "Rewrite class " + className + " from " + jeBytes.length + " bytes to "
+ result.classBuffer.length + " bytes ";
RayLog.core.info(logInfo);
}
if (result.changedMethods != null) {
for (MethodId m : result.changedMethods) {
consumer.accept(null, m);
}
}
je = new JarEntry(je.getName());
je.setTime(System.currentTimeMillis());
je.setSize(result.classBuffer.length);
jeBytes = result.classBuffer;
}
ojStream.putNextEntry(je);
ojStream.write(jeBytes);
//ojStream.closeEntry();
}
ojStream.close();
ofStream.close();
}
public static List<String> getRewrittenFunctions(String rewrittenDir)
throws FileNotFoundException {
ArrayList<String> functions = new ArrayList<>();
Scanner s = new Scanner(new File(rewrittenDir + "/" + FUNCTIONS_FILE));
while (s.hasNext()) {
String f = s.next();
if (!f.startsWith("(")) {
functions.add(f);
}
}
s.close();
return functions;
}
public static LoadedFunctions loadBase(String baseDir)
throws FileNotFoundException, SecurityException {
List<String> functions = JarRewriter.getRewrittenFunctions(baseDir);
LoadedFunctions efuncs = new LoadedFunctions();
efuncs.loader = null;
for (String func : functions) {
MethodId mid = new MethodId(func, efuncs.loader);
efuncs.functions.add(mid);
}
return efuncs;
}
}
@@ -1,226 +0,0 @@
package org.ray.hook;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.codec.digest.DigestUtils;
import org.objectweb.asm.Type;
import org.ray.util.logger.RayLog;
/**
* Represent a Method in a Class.
*/
public class MethodId {
static final String getFunctionIdPostfix = "_function_id";
String className;
String methodName;
String methodDesc;
boolean isStatic;
ClassLoader loader;
public MethodId(String cls, String method, String mdesc, boolean isstatic, ClassLoader loader) {
className = cls;
methodName = method;
methodDesc = mdesc;
isStatic = isstatic;
this.loader = loader;
}
public MethodId(String encodedString, ClassLoader loader) {
// className + "." + methodName + "::" + methodDesc + "&&" + isStatic;
int lastPos3 = encodedString.lastIndexOf("&&");
int lastPos2 = encodedString.lastIndexOf("::");
int lastPos1 = encodedString.lastIndexOf(".");
if (lastPos1 == -1 || lastPos2 == -1 || lastPos3 == -1) {
throw new RuntimeException("invalid given method id " + encodedString
+ " - it must be className.methodName::methodDesc&&isStatic");
}
className = encodedString.substring(0, lastPos1);
methodName = encodedString.substring(lastPos1 + ".".length(), lastPos2);
methodDesc = encodedString.substring(lastPos2 + "::".length(), lastPos3);
isStatic = Boolean.parseBoolean(encodedString.substring(lastPos3 + "&&".length()));
this.loader = loader;
}
public static String toHexHashString(byte[] id) {
String s = "";
String hex = "0123456789abcdef";
assert (id.length == 20);
for (int i = 0; i < 20; i++) {
int val = id[i] & 0xff;
s += hex.charAt(val >> 4);
s += hex.charAt(val & 0xf);
}
return s;
}
private String toHexHashString() {
byte[] id = this.getSha1Hash();
return toHexHashString(id);
}
public String getClassName() {
return className;
}
public String getMethodName() {
return methodName;
}
public String getMethodDesc() {
return methodDesc;
}
public ClassLoader getLoader() {
return loader;
}
public Boolean isStaticMethod() {
return isStatic;
}
public String getIdMethodName() {
return this.methodName + getFunctionIdPostfix;
}
public String getIdMethodDesc() {
return "(L" + this.className + ";" + this.methodDesc.substring(1);
}
public Method load() {
String loadClsName = className.replace('/', '.');
Class<?> cls;
try {
RayLog.core.debug(
"load class " + loadClsName + " from class loader " + (loader == null ? this.getClass()
.getClassLoader() : loader)
+ " for method " + toString() + " with ID = " + toHexHashString()
);
cls = Class
.forName(loadClsName, true, loader == null ? this.getClass().getClassLoader() : loader);
} catch (Throwable e) {
RayLog.core.error("Cannot load class " + loadClsName, e);
return null;
}
Method[] ms = cls.getDeclaredMethods();
ArrayList<Method> methods = new ArrayList<>();
Type t = Type.getMethodType(this.methodDesc);
Type[] params = t.getArgumentTypes();
String rt = t.getReturnType().getDescriptor();
for (Method m : ms) {
if (m.getName().equals(methodName)) {
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
continue;
}
String mrt = Type.getDescriptor(m.getReturnType());
if (!rt.equals(mrt)) {
continue;
}
methods.add(m);
}
}
if (methods.size() != 1) {
RayLog.core.error(
"Load method " + toString() + " failed as there are " + methods.size() + " definitions");
return null;
}
Method m = methods.get(0);
try {
Field fld = cls.getField(getStaticHashValueFieldName());
Object hashValue = fld.get(null);
if (hashValue instanceof byte[] && Arrays.equals((byte[]) hashValue, this.getSha1Hash())) {
RayLog.core.debug("Method " + toString() + " hash: " + toHexHashString((byte[]) hashValue));
} else {
if (hashValue instanceof byte[]) {
RayLog.core.error(
"Method " + toString() + " hash-field: " + toHexHashString((byte[]) hashValue)
+ " vs id-hash: " + toHexHashString());
} else {
RayLog.core.error(
"Method " + toString() + " hash-field: " + (hashValue != null ? hashValue.toString()
: "<nil>") + " vs id-hash: " + toHexHashString());
}
}
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException
| IllegalAccessException e) {
RayLog.core.error("load method hash field failed for " + toString(), e);
}
return m;
}
public String toEncodingString() {
return className + "." + methodName + "::" + methodDesc + "&&" + isStatic;
}
public byte[] getSha1Hash() {
byte[] digests = DigestUtils.sha(toEncodingString());
ByteBuffer bb = ByteBuffer.wrap(digests);
bb.order(ByteOrder.LITTLE_ENDIAN);
if (methodName.contains("createActorStage1")) {
bb.putLong(Long.BYTES, 1);
} else {
bb.putLong(Long.BYTES, 0);
}
return digests;
}
public String getStaticHashValueFieldName() {
// _hashOf<init>_([Ljava/lang/String;)V
String r = "_hashOf" + methodName + "_" + methodDesc;
r = r.replace("<", "_")
.replace(">", "_")
.replace("(", "_")
.replace("[", "_")
.replace("/", "_")
.replace(";", "_")
.replace(")", "_")
;
// System.err.println(r);
return r;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + className.hashCode();
result = prime * result + methodName.hashCode();
result = prime * result + methodDesc.hashCode();
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MethodId other = (MethodId) obj;
return className.equals(other.className)
&& methodName.equals(other.methodName)
&& methodDesc.equals(other.methodDesc)
&& isStatic == other.isStatic
;
}
@Override
public String toString() {
return toEncodingString();
}
}
@@ -1,131 +0,0 @@
package org.ray.hook.runtime;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.ray.util.logger.RayLog;
/**
* load and unload jars from a dir.
*/
public class JarLoader {
private static Method AddUrl = initAddUrl();
private static Method initAddUrl() {
try {
Method m = URLClassLoader.class.getDeclaredMethod("addURL", URL.class);
m.setAccessible(true);
return m;
} catch (NoSuchMethodException | SecurityException e) {
// TODO Auto-generated catch block
e.printStackTrace();
return null;
}
}
public static URLClassLoader loadJars(String dir, boolean explicitLoadForHook) {
// get all jars
Collection<File> jars = FileUtils.listFiles(
new File(dir),
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY
);
return loadJar(jars, explicitLoadForHook);
}
public static ClassLoader loadJars(String[] appJars, boolean explicitLoadForHook) {
List<File> jars = new ArrayList<>();
for (String jar : appJars) {
if (jar.endsWith(".jar")) {
jars.add(new File(jar));
} else {
loadJarDir(jar, jars);
}
}
return loadJar(jars, explicitLoadForHook);
}
private static URLClassLoader loadJar(Collection<File> appJars, boolean explicitLoadForHook) {
List<JarFile> jars = new ArrayList<>();
List<URL> urls = new ArrayList<>();
for (File appJar : appJars) {
try {
RayLog.core.info("load jar " + appJar.getAbsolutePath() + " in ray hook");
JarFile jar = new JarFile(appJar.getAbsolutePath());
jars.add(jar);
urls.add(appJar.toURI().toURL());
} catch (IOException e) {
e.printStackTrace();
System.err.println(
"invalid app jar path: " + appJar.getAbsolutePath() + ", load failed with exception "
+ e.getMessage());
System.exit(1);
return null;
}
}
URLClassLoader cl = URLClassLoader.newInstance(urls.toArray(new URL[0]));
if (!explicitLoadForHook) {
return cl;
}
for (JarFile jar : jars) {
Enumeration<JarEntry> e = jar.entries();
while (e.hasMoreElements()) {
JarEntry je = e.nextElement();
//System.err.println("check " + je.getName());
if (je.isDirectory() || !je.getName().endsWith(".class")) {
continue;
}
// -6 because of .class
String className = je.getName().substring(0, je.getName().length() - 6);
className = className.replace('/', '.');
try {
Class.forName(className, true, cl);
//System.err.println("load class " + className + " OK");
} catch (ClassNotFoundException e1) {
e1.printStackTrace();
}
}
try {
jar.close();
} catch (IOException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
}
}
return cl;
}
private static void loadJarDir(String jarDir, List<File> jars) {
Collection<File> files = FileUtils.listFiles(
new File(jarDir),
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY
);
jars.addAll(files);
}
public static void unloadJars(ClassLoader loader) {
// TODO:
}
}
@@ -1,12 +0,0 @@
package org.ray.hook.runtime;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.ray.hook.MethodId;
public class LoadedFunctions {
public final Set<MethodId> functions = Collections.synchronizedSet(new HashSet<>());
public ClassLoader loader = null;
}
@@ -1,36 +0,0 @@
package org.ray.hook.runtime;
import java.util.Arrays;
public class MethodHash {
private final byte[] hash;
public MethodHash(byte[] hash) {
this.hash = hash;
}
@Override
public int hashCode() {
return Arrays.hashCode(getHash());
}
public byte[] getHash() {
return hash;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MethodHash other = (MethodHash) obj;
return Arrays.equals(this.getHash(), other.getHash());
}
}
@@ -1,21 +0,0 @@
package org.ray.hook.runtime;
/**
* method mode switch at runtime.
*/
public class MethodSwitcher {
public static final ThreadLocal<Boolean> IsRemoteCall = new ThreadLocal<>();
public static final ThreadLocal<byte[]> MethodId = new ThreadLocal<>();
public static boolean execute(byte[] id) {
Boolean hooking = IsRemoteCall.get();
if (Boolean.TRUE.equals(hooking)) {
MethodId.set(id);
return true;
} else {
return false;
}
}
}
+96 -3
View File
@@ -5,13 +5,13 @@
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<groupId>org.ray.parent</groupId>
<artifactId>ray-superpom</artifactId>
<version>1.0</version>
<modules>
<module>api</module>
<module>common</module>
<module>hook</module>
<module>runtime-common</module>
<module>runtime-native</module>
<module>runtime-dev</module>
@@ -24,6 +24,7 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
@@ -31,7 +32,98 @@
<artifactId>arrow-plasma</artifactId>
<version>0.10.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>de.ruedigermoeller</groupId>
<artifactId>fst</artifactId>
<version>2.47</version>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<dependency>
<groupId>quartz</groupId>
<artifactId>quartz</artifactId>
<version>1.5.2</version>
</dependency>
<dependency>
<groupId>org.ini4j</groupId>
<artifactId>ini4j</artifactId>
<version>0.5.2</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>6.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.7.0.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.8.0</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>1.2</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.4</version>
</dependency>
<dependency>
<groupId>net.lingala.zip4j</groupId>
<artifactId>zip4j</artifactId>
<version>1.3.2</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>19.0</version>
</dependency>
<dependency>
<groupId>commons-collections</groupId>
<artifactId>commons-collections</artifactId>
<version>3.2.2</version>
</dependency>
</dependencies>
</dependencyManagement>
<build>
@@ -145,7 +237,8 @@
<violationSeverity>warning</violationSeverity>
<format>xml</format>
<format>html</format>
<outputFile>${project.build.directory}/test/checkstyle-errors.xml</outputFile>
<outputFile>${project.build.directory}/test/checkstyle-errors.xml
</outputFile>
<linkXRef>false</linkXRef>
</configuration>
</plugin>
@@ -153,4 +246,4 @@
</pluginManagement>
</build>
</project>
</project>
+9 -10
View File
@@ -1,5 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
@@ -24,35 +23,35 @@
<artifactId>ray-api</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>de.ruedigermoeller</groupId>
<artifactId>fst</artifactId>
<version>2.47</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
<dependency>
<groupId>com.github.davidmoten</groupId>
<artifactId>flatbuffers-java</artifactId>
<version>1.7.0.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.8.0</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-plasma</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>
</project>
@@ -4,6 +4,7 @@ import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -83,7 +84,9 @@ public class ArgumentsBuilder {
@SuppressWarnings({"rawtypes", "unchecked"})
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
throws TaskExecutionException {
FunctionArg[] fargs = task.args;
// the last arg is className
FunctionArg[] fargs = Arrays.copyOf(task.args, task.args.length - 1);
Object current = null;
Object[] realArgs;
@@ -68,7 +68,7 @@ public class InvocationExecutor {
}
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
String taskdesc)
String taskdesc)
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
Method m = pr.getRight().invokable;
Map<?, UniqueID> userRayReturnIdMap = null;
@@ -86,11 +86,12 @@ public class InvocationExecutor {
}
// execute
Object result;
if (!UniqueIdHelper.isLambdaFunction(task.functionId)) {
Object result = null;
try {
result = m.invoke(realArgs.getLeft(), realArgs.getRight());
} else {
result = m.invoke(realArgs.getLeft(), new Object[] {realArgs.getRight()});
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
RayLog.core.error("invoke failed:" + m);
throw e;
}
if (task.returnIds == null || task.returnIds.length == 0) {
@@ -147,4 +148,4 @@ public class InvocationExecutor {
private static void safePut(UniqueID objectId, Object obj) {
RayRuntime.getInstance().putRaw(objectId, obj);
}
}
}
@@ -1,15 +1,14 @@
package org.ray.core;
import java.lang.reflect.Method;
import com.google.common.base.Preconditions;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.JarLoader;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.model.FunctionArg;
import org.ray.spi.model.RayActorMethods;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.RayTaskMethods;
import org.ray.util.logger.RayLog;
/**
@@ -18,6 +17,7 @@ import org.ray.util.logger.RayLog;
public class LocalFunctionManager {
private final RemoteFunctionManager remoteLoader;
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables
= new ConcurrentHashMap<>();
@@ -29,74 +29,44 @@ public class LocalFunctionManager {
this.remoteLoader = remoteLoader;
}
private FunctionTable loadDriverFunctions(UniqueID driverId) {
FunctionTable functionTable = functionTables.get(driverId);
if (functionTable == null) {
RayLog.core.info("DriverId " + driverId + " Try to load functions");
ClassLoader classLoader = remoteLoader.loadResource(driverId);
if (classLoader == null) {
throw new RuntimeException(
"Cannot find resource' classLoader for app " + driverId.toString());
}
functionTable = new FunctionTable(classLoader);
functionTables.put(driverId, functionTable);
}
return functionTable;
}
Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId, String className) {
// assert the driver's resource is load.
FunctionTable functionTable = loadDriverFunctions(driverId);
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
: functionTable.getActorMethod(methodId, className);
Preconditions
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
methodId, driverId);
return Pair.of(functionTable.classLoader, method);
}
/**
* get local method for executing, which pulls information from remote repo on-demand, therefore
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
*/
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID methodId,
FunctionArg[] args) throws NoSuchMethodException,
SecurityException, ClassNotFoundException {
FunctionTable funcs = loadDriverFunctions(driverId);
RayMethod m;
// hooked methods
if (!UniqueIdHelper.isLambdaFunction(methodId)) {
m = funcs.functions.get(methodId);
if (null == m) {
throw new RuntimeException(
"DriverId " + driverId + " load remote function methodId:" + methodId + " failed");
}
} else { // remote lambda
assert args.length >= 2;
String fname = Serializer.decode(args[args.length - 2].data);
Method fm = Class.forName(fname).getMethod("execute", Object[].class);
m = new RayMethod(fm);
}
return Pair.of(funcs.linkedFunctions.loader, m);
}
private synchronized FunctionTable loadDriverFunctions(UniqueID driverId) {
FunctionTable funcs = functionTables.get(driverId);
if (null == funcs) {
RayLog.core.debug("DriverId " + driverId + " Try to load functions");
LoadedFunctions funcs2 = remoteLoader.loadFunctions(driverId);
if (funcs2 == null) {
throw new RuntimeException("Cannot find resource for app " + driverId.toString());
}
funcs = new FunctionTable();
funcs.linkedFunctions = funcs2;
for (MethodId mid : funcs.linkedFunctions.functions) {
Method m = mid.load();
assert (m != null);
RayMethod v = new RayMethod(m);
v.check();
UniqueID k = new UniqueID(mid.getSha1Hash());
String logInfo =
"DriverId" + driverId + " load remote function " + m.getName() + ", hash = " + k
.toString();
RayLog.core.debug(logInfo);
System.err.println(logInfo);
funcs.functions.put(k, v);
}
functionTables.put(driverId, funcs);
} else { // reSync automatically
// more functions are loaded
if (funcs.linkedFunctions.functions.size() > funcs.functions.size()) {
for (MethodId mid : funcs.linkedFunctions.functions) {
UniqueID k = new UniqueID(mid.getSha1Hash());
if (!funcs.functions.containsKey(k)) {
Method m = mid.load();
assert (m != null);
RayMethod v = new RayMethod(m);
v.check();
funcs.functions.put(k, v);
}
}
}
}
return funcs;
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
UniqueID methodId,
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
String className = (String) Serializer.decode(args[args.length - 1].data);
return getMethod(driverId, actorId, methodId, className);
}
/**
@@ -106,13 +76,47 @@ public class LocalFunctionManager {
FunctionTable funcs = functionTables.get(driverId);
if (funcs != null) {
functionTables.remove(driverId);
JarLoader.unloadJars(funcs.linkedFunctions.loader);
remoteLoader.unloadFunctions(driverId);
}
}
static class FunctionTable {
private static class FunctionTable {
public final ConcurrentHashMap<UniqueID, RayMethod> functions = new ConcurrentHashMap<>();
public LoadedFunctions linkedFunctions;
final ClassLoader classLoader;
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
final ConcurrentHashMap<String, RayActorMethods> actors = new ConcurrentHashMap<>();
FunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
}
RayMethod getTaskMethod(UniqueID methodId, String className) {
RayTaskMethods tasks = taskMethods.get(className);
if (tasks == null) {
tasks = RayTaskMethods.fromClass(className, classLoader);
RayLog.core.info("create RayTaskMethods:" + tasks);
taskMethods.put(className, tasks);
}
RayMethod m = tasks.functions.get(methodId);
if (m != null) {
return m;
}
// it is a actor static func.
return getActorMethod(methodId, className, true);
}
RayMethod getActorMethod(UniqueID methodId, String className) {
return getActorMethod(methodId, className, false);
}
private RayMethod getActorMethod(UniqueID methodId, String className, boolean isStatic) {
RayActorMethods actor = actors.get(className);
if (actor == null) {
actor = RayActorMethods.fromClass(className, classLoader);
RayLog.core.info("create RayActorMethods:" + actor);
actors.put(className, actor);
}
return isStatic ? actor.staticFunctions.get(methodId) : actor.functions.get(methodId);
}
}
}
}
@@ -19,7 +19,7 @@ import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.internal.Callable;
import org.ray.api.internal.RayFunc;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
@@ -261,49 +261,23 @@ public abstract class RayRuntime implements RayApi {
}
@Override
public RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args) {
return worker.rpc(taskId, funcRun, returnCount, args);
}
@Override
public RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, UniqueID.nil, funcCls, lambda, returnCount, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RIDT> returnIds,
Object... args) {
return worker.rpcWithReturnLabels(taskId, funcRun, returnIds, args);
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Collection<RIDT>
returnids,
Object... args) {
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun,
Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcRun, returnCount, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Integer returnCount, Object...
args) {
RayFunc lambda, Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
@Override
public boolean isRemoteLambda() {
return params.run_mode.isRemoteLambda();
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
@@ -465,4 +439,4 @@ public abstract class RayRuntime implements RayApi {
public RemoteFunctionManager getRemoteFunctionManager() {
return remoteFunctionManager;
}
}
}
@@ -249,12 +249,6 @@ public class UniqueIdHelper {
return taskId;
}
public static boolean isLambdaFunction(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return wbb.getLong() == 0xffffffffffffffffL;
}
public static void markCreateActorStage1Function(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
@@ -279,4 +273,4 @@ public class UniqueIdHelper {
TASK,
ACTOR,
}
}
}
@@ -1,10 +1,11 @@
package org.ray.core;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
@@ -12,12 +13,13 @@ import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.internal.Callable;
import org.ray.hook.runtime.MethodSwitcher;
import org.ray.api.internal.RayFunc;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.LambdaUtils;
import org.ray.util.MethodId;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -47,13 +49,13 @@ public class Worker {
RayLog.core.info("Task " + task.taskId + " start execute");
Throwable ex = null;
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
}
try {
Pair<ClassLoader, RayMethod> pr = funcs.getMethod(task.driverId, task.functionId, task.args);
Pair<ClassLoader, RayMethod> pr = funcs
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
WorkerContext.prepare(task, pr.getLeft());
InvocationExecutor.execute(task, pr);
} catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
@@ -72,161 +74,75 @@ public class Worker {
}
public RayObjects rpc(UniqueID taskId, Callable funcRun, int returnCount, Object[] args) {
byte[] fid = fidFromHook(funcRun);
return submit(taskId, fid, returnCount, false, args);
}
public RayObjects rpc(UniqueID taskId, UniqueID functionId, Class<?> funcCls, Serializable lambda,
int returnCount, Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return submit(taskId, fid, returnCount, false, ls);
}
public RayObjects rpc(UniqueID taskId, RayActor<?> actor, Callable funcRun, int returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
return actorTaskSubmit(taskId, fid, returnCount, false, args, actor);
}
public RayObjects rpc(UniqueID taskId, UniqueID functionId, RayActor<?> actor, Class<?> funcCls,
Serializable lambda, int returnCount, Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return actorTaskSubmit(taskId, fid, returnCount, false, ls, actor);
}
private byte[] fidFromHook(Callable funcRun) {
MethodSwitcher.IsRemoteCall.set(true);
try {
funcRun.run();
} catch (Throwable e) {
RayLog.core.error(
"make sure you are using code rewritten using the rewrite tool, see JarRewriter for"
+ " options", e);
throw new RuntimeException("make sure you are using code rewritten using the rewrite tool,"
+ "see JarRewriter for options");
}
byte[] fid = MethodSwitcher.MethodId.get();//get the identity of function from hook
MethodSwitcher.IsRemoteCall.set(false);
return fid;
}
private RayObjects submit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, fid, returnCount, multiReturn, args, (RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, fid, returnCount, multiReturn, args);
}
private RayObjects taskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
}
private RayObjects actorTaskSubmit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args,
RayActor<?> actor) {
RayInvocation ri = new RayInvocation(fid, args, actor);
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args,
RayActor<?> actor) {
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
actor.setTaskCursor(returnObjs.pop().getId());
return returnObjs;
}
private RayObjects taskSubmit(UniqueID taskId,
byte[] fid,
int returnCount,
boolean multiReturn,
Object[] args) {
RayInvocation ri = new RayInvocation(fid, args);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, Callable funcRun,
int returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
RayInvocation ri = new RayInvocation(fid, new Object[] {});
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, UniqueID functionId,
Class<?> funcCls, Serializable lambda, int returnCount,
Object[] args) {
byte[] fid = functionId.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
RayInvocation ri = new RayInvocation(fid, ls);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Callable funcRun,
Collection<RIDT> returnids,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
private RayObjects submit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
return scheduler.submit(taskId, new RayInvocation(fid, args), returnids);
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
(RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
}
}
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
int returnCount, Object[] args) {
MethodId mid = methodIdOf(lambda);
return submit(taskId, mid, returnCount, false, args);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
Preconditions.checkNotNull(taskId);
MethodId mid = methodIdOf(lambda);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
Serializable lambda,
Collection<RIDT> returnids,
Object[] args) {
final byte[] fid = UniqueID.nil.getBytes();
RayFunc lambda, Collection<RIDT> returnids,
Object[] args) {
MethodId mid = methodIdOf(lambda);
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
return scheduler.submit(taskId, new RayInvocation(fid, ls), returnids);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnids);
}
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Callable funcRun,
Integer returnCount,
Object[] args) {
byte[] fid = fidFromHook(funcRun);
RayObjects objs = submit(taskId, fid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
}
return rets;
}
@SuppressWarnings("unchecked")
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
Serializable lambda, Integer returnCount,
Object[] args) {
byte[] fid = UniqueID.nil.getBytes();
Object[] ls = Arrays.copyOf(args, args.length + 2);
ls[args.length] = funcCls.getName();
ls[args.length + 1] = SerializationUtils.serialize(lambda);
RayObjects objs = submit(taskId, fid, returnCount, true, ls);
RayFunc lambda, Integer returnCount,
Object[] args) {
MethodId mid = methodIdOf(lambda);
RayObjects objs = submit(taskId, mid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
@@ -234,6 +150,27 @@ public class Worker {
return rets;
}
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
RayActor<?> actor) {
UniqueID driverId = WorkerContext.currentTask().driverId;
Object[] ls = Arrays.copyOf(args, args.length + 1);
ls[args.length] = methodId.className;
RayMethod method = functions
.getMethod(driverId, actor.getId(), new UniqueID(methodId.getSha1Hash()),
methodId.className).getRight();
RayInvocation ri = new RayInvocation(methodId.className, method.getFuncId(),
ls, method.remoteAnnotation, actor);
return ri;
}
private MethodId methodIdOf(RayFunc serialLambda) {
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
return mid;
}
public UniqueID getCurrentTaskId() {
return WorkerContext.currentTask().taskId;
}
@@ -246,4 +183,4 @@ public class Worker {
public UniqueID[] getCurrentTaskReturnIDs() {
return WorkerContext.currentTask().returnIds;
}
}
}
@@ -1,40 +1,23 @@
package org.ray.core.model;
public enum RunMode {
SINGLE_PROCESS(true, false, true, false), // remote lambda, dev path, dev runtime
SINGLE_BOX(true, false, true, true), // remote lambda, dev path, native runtime
CLUSTER(false, true, false, true); // static rewrite, deploy path, naive runtime
SINGLE_PROCESS(true, false), // dev path, dev runtime
SINGLE_BOX(true, true), // dev path, native runtime
CLUSTER(false, true); // deploy path, naive runtime
private final boolean remoteLambda;
private final boolean staticRewrite;
private final boolean devPathManager;
private final boolean nativeRuntime;
RunMode(boolean remoteLambda, boolean staticRewrite, boolean devPathManager,
boolean nativeRuntime) {
this.remoteLambda = remoteLambda;
this.staticRewrite = staticRewrite;
RunMode(boolean devPathManager,
boolean nativeRuntime) {
this.devPathManager = devPathManager;
this.nativeRuntime = nativeRuntime;
}
/**
* Getter method for property <tt>remoteLambda</tt>.
*
* @return property value of remoteLambda
* the jar has add to java -cp, no need to load jar after started.
*/
public boolean isRemoteLambda() {
return remoteLambda;
}
private final boolean devPathManager;
/**
* Getter method for property <tt>staticRewrite</tt>.
*
* @return property value of staticRewrite
*/
public boolean isStaticRewrite() {
return staticRewrite;
}
private final boolean nativeRuntime;
/**
* Getter method for property <tt>devPathManager</tt>.
@@ -53,4 +36,4 @@ public enum RunMode {
public boolean isNativeRuntime() {
return nativeRuntime;
}
}
}
@@ -97,7 +97,7 @@ public class LocalSchedulerProxy {
task.args = ArgumentsBuilder.wrap(invocation);
task.driverId = current.driverId;
task.functionId = new UniqueID(invocation.getId());
task.functionId = invocation.getId();
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
task.parentTaskId = current.taskId;
task.actorHandleId = invocation.getActor().getActorHandleId();
@@ -1,18 +1,12 @@
package org.ray.spi;
import java.util.Set;
import org.ray.api.UniqueID;
import org.ray.hook.MethodId;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.util.logger.RayLog;
/**
* mock version of remote function manager using local loaded jars + runtime hook.
* mock version of remote function manager using local loaded jars.
*/
public class NopRemoteFunctionManager implements RemoteFunctionManager {
private final LoadedFunctions loadedFunctions = new LoadedFunctions();
public NopRemoteFunctionManager(UniqueID driverId) {
//onLoad(driverId, Agent.hookedMethods);
//Agent.consumers.add(m -> { this.onLoad(m); });
@@ -51,14 +45,9 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
}
@Override
public LoadedFunctions loadFunctions(UniqueID driverId) {
public ClassLoader loadResource(UniqueID driverId) {
//assert (startupDriverId().equals(driverId));
if (loadedFunctions == null) {
RayLog.rapp.error("cannot find functions for " + driverId);
return null;
} else {
return loadedFunctions;
}
return this.getClass().getClassLoader();
}
@Override
@@ -66,15 +55,4 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
// never
//assert (startupDriverId().equals(driverId));
}
private void onLoad(UniqueID driverId, Set<MethodId> methods) {
//assert (startupDriverId().equals(driverId));
for (MethodId mid : methods) {
onLoad(mid);
}
}
private void onLoad(MethodId mid) {
loadedFunctions.functions.add(mid);
}
}
}
@@ -1,7 +1,6 @@
package org.ray.spi;
import org.ray.api.UniqueID;
import org.ray.hook.runtime.LoadedFunctions;
/**
* register and load functions from function table.
@@ -53,14 +52,13 @@ public interface RemoteFunctionManager {
void unregisterApp(UniqueID driverId);
/**
* load resource and functions for this driver this function is used by the workers on demand when
* a required function is not found in {@code LocalFunctionManager}.
* load resource.
*/
LoadedFunctions loadFunctions(UniqueID driverId);
ClassLoader loadResource(UniqueID driverId);
/**
* unload functions for this driver
* this function is used by the workers on demand when a driver is dead.
*/
void unloadFunctions(UniqueID driverId);
}
}
@@ -0,0 +1,66 @@
package org.ray.spi.model;
import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
public final class RayActorMethods {
public final Class clazz;
public final RayRemote remoteAnnotation;
public final Map<UniqueID, RayMethod> functions;
/**
* the static function in Actor, call as task.
*/
public final Map<UniqueID, RayMethod> staticFunctions;
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
Map<UniqueID, RayMethod> functions, Map<UniqueID, RayMethod> staticFunctions) {
this.clazz = clazz;
this.remoteAnnotation = remoteAnnotation;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
}
public static RayActorMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
Preconditions
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
Map<UniqueID, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isPublic(m.getModifiers())) {
continue;
}
RayMethod rayMethod = RayMethod.from(m, remoteAnnotation);
if (Modifier.isStatic(m.getModifiers())) {
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
} else {
functions.put(rayMethod.getFuncId(), rayMethod);
}
}
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayActorMethods from " + clazzName, e);
}
}
@Override
public String toString() {
return String
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
functions.values(),
staticFunctions.size(), staticFunctions.values());
}
}
@@ -1,6 +1,7 @@
package org.ray.spi.model;
import org.ray.api.RayActor;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
/**
@@ -9,30 +10,30 @@ import org.ray.api.UniqueID;
public class RayInvocation {
private static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
public final String className;
/**
* unique id for a method.
*
* @see UniqueID
*/
private final byte[] id;
private final RayActor<?> actor;
private final UniqueID id;
private final RayRemote remoteAnnotation;
/**
* function arguments.
*/
private Object[] args;
public RayInvocation(byte[] id, Object[] args) {
this(id, args, nil);
}
public RayInvocation(byte[] id, Object[] args, RayActor<?> actor) {
super();
private RayActor<?> actor;
public RayInvocation(String className, UniqueID id, Object[] args, RayRemote remoteAnnotation,
RayActor<?> actor) {
this.className = className;
this.id = id;
this.args = args;
this.actor = actor;
this.remoteAnnotation = remoteAnnotation;
}
public byte[] getId() {
public UniqueID getId() {
return id;
}
@@ -48,4 +49,8 @@ public class RayInvocation {
return actor;
}
}
public RayRemote getRemoteAnnotation() {
return remoteAnnotation;
}
}
@@ -1,6 +1,9 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.util.MethodId;
/**
* method info.
@@ -9,12 +12,8 @@ public class RayMethod {
public final Method invokable;
public final String fullName;
// TODO: other annotated information
public RayMethod(Method m) {
invokable = m;
fullName = m.getDeclaringClass().getName() + "." + m.getName();
}
public final RayRemote remoteAnnotation;
private final UniqueID funcId;
public void check() {
for (Class<?> paramCls : invokable.getParameterTypes()) {
@@ -24,4 +23,32 @@ public class RayMethod {
}
}
}
}
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueID funcId) {
this.invokable = m;
this.remoteAnnotation = remoteAnnotation;
this.funcId = funcId;
fullName = m.getDeclaringClass().getName() + "." + m.getName();
}
public static RayMethod from(Method m, RayRemote parentRemoteAnnotation) {
Class<?> clazz = m.getDeclaringClass();
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromMethod(m);
UniqueID funcId = new UniqueID(mid.getSha1Hash());
RayMethod method = new RayMethod(m,
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
funcId);
method.check();
return method;
}
@Override
public String toString() {
return fullName;
}
public UniqueID getFuncId() {
return funcId;
}
}
@@ -0,0 +1,54 @@
package org.ray.spi.model;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
public final class RayTaskMethods {
public final Class clazz;
public final Map<UniqueID, RayMethod> functions;
public RayTaskMethods(Class clazz,
Map<UniqueID, RayMethod> functions) {
this.clazz = clazz;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
}
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
Method[] methods = clazz.getDeclaredMethods();
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
for (Method m : methods) {
if (!Modifier.isStatic(m.getModifiers())) {
continue;
}
//task method only for static.
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
if (remoteAnnotation == null) {
continue;
}
m.setAccessible(true);
RayMethod rayMethod = RayMethod.from(m, null);
functions.put(rayMethod.getFuncId(), rayMethod);
}
return new RayTaskMethods(clazz, functions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
}
}
@Override
public String toString() {
return String
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
}
}
@@ -74,9 +74,8 @@ public class RayNativeRuntime extends RayRuntime {
}
// initialize remote function manager
RemoteFunctionManager funcMgr = params.run_mode.isStaticRewrite()
? new NativeRemoteFunctionManager(kvStore) :
new NopRemoteFunctionManager(params.driver_id);
RemoteFunctionManager funcMgr = params.run_mode.isDevPathManager()
? new NopRemoteFunctionManager(params.driver_id) : new NativeRemoteFunctionManager(kvStore);
// initialize worker context
if (params.worker_mode == WorkerMode.DRIVER) {
@@ -101,6 +100,7 @@ public class RayNativeRuntime extends RayRuntime {
int releaseDelay = RayRuntime.configReader
.getIntegerValue("ray", "plasma_default_release_delay", 0,
"how many release requests should be delayed in plasma client");
ObjectStoreLink plink = new PlasmaClient(params.object_store_name, params
.object_store_manager_name, releaseDelay);
@@ -163,7 +163,7 @@ public class RayNativeRuntime extends RayRuntime {
}
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
String managerName, String schedulerName) {
String managerName, String schedulerName) {
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
if (!isWorker) {
@@ -193,26 +193,16 @@ public class RayNativeRuntime extends RayRuntime {
UniqueID actorId = UniqueIdHelper.taskComputeReturnId(createTaskId, 0, false);
RayActor<T> actor = new RayActor<>(actorId);
UniqueID cursorId;
if (params.run_mode.isRemoteLambda()) {
RayFunc_2_1<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
cursorId = worker.rpcCreateActor(
createTaskId,
actorId,
UniqueID.nil,
RayFunc_2_1.class,
createActorLambda,
1,
new Object[] {actorId.getBytes(), cls.getName()}
).getObjs()[0].getId();
} else {
cursorId = worker.rpcCreateActor(
createTaskId,
actorId,
() -> RayNativeRuntime.createActorInActor(null, null),
1,
new Object[] {actorId.getBytes(), cls.getName()}
).getObjs()[0].getId();
}
RayFunc_2_1<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
cursorId = worker.rpcCreateActor(
createTaskId,
actorId,
RayFunc_2_1.class,
createActorLambda,
1,
new Object[]{actorId.getBytes(), cls.getName()}
).getObjs()[0].getId();
actor.setTaskCursor(cursorId);
return actor;
}
@@ -247,4 +237,4 @@ public class RayNativeRuntime extends RayRuntime {
throw new TaskExecutionException(log, e);
}
}
}
}
@@ -0,0 +1,89 @@
package org.ray.spi.impl;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.ray.util.logger.RayLog;
/**
* load and unload jars from a dir.
*/
public class JarLoader {
public static URLClassLoader loadJars(String dir, boolean explicitLoad) {
// get all jars
Collection<File> jars = FileUtils.listFiles(
new File(dir),
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY
);
return loadJar(jars, explicitLoad);
}
public static void unloadJars(ClassLoader loader) {
// now do nothing, if no ref to the loader and loader's class.
// they would be gc.
}
private static URLClassLoader loadJar(Collection<File> appJars, boolean explicitLoad) {
List<JarFile> jars = new ArrayList<>();
List<URL> urls = new ArrayList<>();
for (File appJar : appJars) {
try {
RayLog.core.info("load jar " + appJar.getAbsolutePath());
JarFile jar = new JarFile(appJar.getAbsolutePath());
jars.add(jar);
urls.add(appJar.toURI().toURL());
} catch (IOException e) {
throw new RuntimeException(
"invalid app jar path: " + appJar.getAbsolutePath() + ", load failed with exception",
e);
}
}
URLClassLoader cl = URLClassLoader.newInstance(urls.toArray(new URL[urls.size()]));
if (!explicitLoad) {
return cl;
}
for (JarFile jar : jars) {
try {
Enumeration<JarEntry> e = jar.entries();
while (e.hasMoreElements()) {
JarEntry je = e.nextElement();
if (je.isDirectory() || !je.getName().endsWith(".class")) {
continue;
}
String className = classNameOfJarEntry(je);
className = className.replace('/', '.');
try {
Class.forName(className, true, cl);
} catch (ClassNotFoundException e1) {
e1.printStackTrace();
}
}
} finally {
IOUtils.closeQuietly(jar);
}
}
return cl;
}
private static String classNameOfJarEntry(JarEntry je) {
return je.getName().substring(0, je.getName().length() - ".class".length());
}
}
@@ -6,13 +6,10 @@ import java.security.NoSuchAlgorithmException;
import java.util.concurrent.ConcurrentHashMap;
import net.lingala.zip4j.core.ZipFile;
import org.ray.api.UniqueID;
import org.ray.core.RayRuntime;
import org.ray.hook.JarRewriter;
import org.ray.hook.runtime.JarLoader;
import org.ray.hook.runtime.LoadedFunctions;
import org.ray.spi.KeyValueStoreLink;
import org.ray.spi.RemoteFunctionManager;
import org.ray.util.FileUtil;
import org.ray.util.Sha1Digestor;
import org.ray.util.SystemUtil;
import org.ray.util.logger.RayLog;
@@ -21,10 +18,11 @@ import org.ray.util.logger.RayLog;
*/
public class NativeRemoteFunctionManager implements RemoteFunctionManager {
private ConcurrentHashMap<UniqueID, LoadedFunctions> loadedApps = new ConcurrentHashMap<>();
private final ConcurrentHashMap<UniqueID, ClassLoader> loadedApps = new ConcurrentHashMap<>();
private MessageDigest md;
private String appDir = System.getProperty("user.dir") + "/apps";
private KeyValueStoreLink kvStore;
private final String appDir = System.getProperty("user.dir") + "/apps";
private final KeyValueStoreLink kvStore;
public NativeRemoteFunctionManager(KeyValueStoreLink kvStore) throws NoSuchAlgorithmException {
this.kvStore = kvStore;
@@ -38,24 +36,20 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
@Override
public UniqueID registerResource(byte[] resourceZip) {
byte[] digest = md.digest(resourceZip);
byte[] digest = Sha1Digestor.digest(resourceZip);
assert (digest.length == UniqueID.LENGTH);
UniqueID resourceId = new UniqueID(digest);
// TODO: resources must be saved in persistent store
// instead of cache
//if (!Ray.exist(resourceId)) {
//Ray.put(resourceId, resourceZip);
kvStore.set(resourceId.getBytes(), resourceZip, null);
//}
return resourceId;
}
@Override
public byte[] getResource(UniqueID resourceId) {
return kvStore.get(resourceId.getBytes(), null);
//return (byte[])Ray.get(resourceId);
}
@Override
@@ -65,7 +59,6 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
@Override
public void registerApp(UniqueID driverId, UniqueID resourceId) {
//Ray.put(driverId, resourceId);
kvStore.set("App2ResMap", resourceId.toString(), driverId.toString());
}
@@ -80,27 +73,31 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
}
@Override
public LoadedFunctions loadFunctions(UniqueID driverId) {
LoadedFunctions rf = loadedApps.get(driverId);
if (rf == null) {
rf = initLoadedApps(driverId);
public ClassLoader loadResource(UniqueID driverId) {
ClassLoader classLoader = loadedApps.get(driverId);
if (classLoader == null) {
synchronized (this) {
classLoader = loadedApps.get(driverId);
if (classLoader == null) {
classLoader = initLoadedApps(driverId);
}
}
}
return rf;
return classLoader;
}
private synchronized LoadedFunctions initLoadedApps(UniqueID driverId) {
private ClassLoader initLoadedApps(UniqueID driverId) {
try {
RayLog.core.info("initLoadedApps" + driverId.toString());
LoadedFunctions rf = loadedApps.get(driverId);
if (rf == null) {
UniqueID resId = new UniqueID(kvStore.get("App2ResMap", driverId.toString()));
//UniqueID resId = Ray.get(driverId);
ClassLoader cl = loadedApps.get(driverId);
if (cl == null) {
UniqueID resId = new UniqueID(kvStore.get("App2ResMap", driverId.toString()));
byte[] res = getResource(resId);
if (res == null) {
throw new RuntimeException("get resource null, the resId " + resId.toString());
}
RayLog.core.info("ger resource of " + resId.toString() + ", result len " + res.length);
RayLog.core.info("get resource of " + resId.toString() + ", result len " + res.length);
String resPath =
appDir + "/" + driverId.toString() + "/" + String.valueOf(SystemUtil.pid());
File dir = new File(resPath);
@@ -112,11 +109,10 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
FileUtil.bytesToFile(res, zipPath);
ZipFile zipFile = new ZipFile(zipPath);
zipFile.extractAll(resPath);
rf = JarRewriter
.load(resPath, RayRuntime.getInstance().getPaths().java_runtime_rewritten_jars_dir);
loadedApps.put(driverId, rf);
cl = JarLoader.loadJars(resPath, false);
loadedApps.put(driverId, cl);
}
return rf;
return cl;
} catch (Exception e) {
RayLog.rapp.error("load function for " + driverId + " failed, ex = " + e.getMessage(), e);
return null;
@@ -125,11 +121,11 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
@Override
public synchronized void unloadFunctions(UniqueID driverId) {
LoadedFunctions rf = loadedApps.get(driverId);
ClassLoader cl = loadedApps.get(driverId);
try {
JarLoader.unloadJars(rf.loader);
JarLoader.unloadJars(cl);
} catch (Exception e) {
RayLog.rapp.error("unload function for " + driverId + " failed, ex = " + e.getMessage(), e);
}
}
}
}
-6
View File
@@ -50,12 +50,6 @@
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.ray</groupId>
<artifactId>ray-hook</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
@@ -0,0 +1,212 @@
package org.ray.api.test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.funcs.RayFunc_0_1;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.util.MethodId;
import org.ray.util.logger.RayLog;
public class LambdaUtilsTest {
static final String CLASS_NAME = LambdaUtilsTest.class.getName();
static final Method CALL0;
static final Method CALL1;
static final Method CALL2;
static final Method CALL3;
static {
try {
CALL0 = LambdaUtilsTest.class.getDeclaredMethod("call0", new Class[0]);
CALL1 = LambdaUtilsTest.class.getDeclaredMethod("call1", new Class[]{Long.class});
CALL2 = LambdaUtilsTest.class.getDeclaredMethod("call2", new Class[0]);
CALL3 = LambdaUtilsTest.class
.getDeclaredMethod("call3", new Class[]{Long.class, String.class});
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc_3_1<T0, T1, T2, R0> f, int n,
boolean forceNew, boolean debug)
throws Exception {
if (debug) {
RayLog.core.info("parse#" + f.getClass().getName());
}
long start = System.nanoTime();
for (int i = 0; i < n; i++) {
MethodId mid = MethodId.fromSerializedLambda(f, forceNew);
}
long end = System.nanoTime();
RayLog.core.info(String.format("remoteLambdaParse(new=%s):total=%sms, one=%s", forceNew,
TimeUnit.NANOSECONDS.toMillis(end - start),
(end - start) / n));
}
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc_3_1<T0, T1, T2, R0> f, int n,
boolean de, boolean debug)
throws Exception {
if (debug) {
RayLog.core.info("se#" + f.getClass().getName());
}
long start = System.nanoTime();
for (int i = 0; i < n; i++) {
ByteArrayOutputStream bytes = new ByteArrayOutputStream(1024);
ObjectOutputStream out = new ObjectOutputStream(bytes);
out.writeObject(f);
out.close();
if (de) {
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()));
RayFunc_3_1 def = (RayFunc_3_1) in.readObject();
in.close();
if (debug) {
RayLog.core.info("de#" + def.getClass().getName());
}
}
}
long end = System.nanoTime();
RayLog.core.info(
String.format("remoteLambdaSer(de=%s):total=%sms,one=%s", de,
TimeUnit.NANOSECONDS.toMillis(end - start),
(end - start) / n));
}
public static void testCall0(RayFunc_0_1 f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL0);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall1(RayFunc_1_1<T, R> f, T t) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL1);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall2(RayFunc_1_1<T, R> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL2);
Assert.assertTrue(!mid.isStatic);
}
public static <T0, T1, T2, R0> void testCall3(RayFunc_3_1<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL3);
Assert.assertTrue(!mid.isStatic);
}
public static String call0() {
long t = System.currentTimeMillis();
RayLog.core.info("call0:" + t);
return String.valueOf(t);
}
public static String call1(Long v) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call1:" + v);
return String.valueOf(v);
}
@Test
public void testLambdaSer() throws Exception {
testCall0(LambdaUtilsTest::call0);
testCall1(LambdaUtilsTest::call1, Long.valueOf(System.currentTimeMillis()));
testCall2(LambdaUtilsTest::call2);
testCall3(LambdaUtilsTest::call3);
}
/**
* to test the serdeLambda's perf.
*/
public void testBenchmark() throws Exception {
//test serde
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
//warmup
RayLog.core.info("warmup:serde################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
RayLog.core.info("benchmark:serde################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
RayLog.core.info("benchmark:ser################");
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
//test serde one new call's time, no class cache
long start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
long end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
end = System.nanoTime();
RayLog.core.info("one sertime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test serde one new call's time, no class cache
start = System.nanoTime();
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
end = System.nanoTime();
RayLog.core.info("one serdetime:" + (end - start));
//test lambda
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, true, true);
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, false, true);
//warmup
RayLog.core.info("warmup:parse################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
RayLog.core.info("benchmark:parseNew################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
RayLog.core.info("benchmark:parseCache################");
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
}
public String call2() {
long t = System.currentTimeMillis();
RayLog.core.info("call2:" + t);
return "call2:" + t;
}
public String call3(Long v, String s) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call3:" + v);
return String.valueOf(v);
}
}
@@ -0,0 +1,40 @@
package org.ray.api.test;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.util.LambdaUtils;
import org.ray.util.MethodId;
import org.ray.util.logger.RayLog;
public class MethodIdTest {
public static <T0, T1, T2, R0> MethodId fromLambda(RayFunc_3_1<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f, true);
return mid;
}
public static MethodId fromClass(Method method) {
return MethodId.fromMethod(method);
}
@Test
public void testMethodId2From() throws Exception {
MethodId m1 = fromLambda(MethodIdTest::call);
Method m = MethodIdTest.class.getDeclaredMethod("call", new Class[]{long.class, String.class});
MethodId m2 = fromClass(m);
RayLog.core.info(m1.toString());
Assert.assertEquals(m1, m2);
}
public String call(long v, String s) {
for (int i = 0; i < 100; i++) {
v += i;
}
RayLog.core.info("call:" + v);
return String.valueOf(v);
}
}
@@ -0,0 +1,25 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.ray.spi.model.RayActorMethods;
import org.ray.util.logger.RayLog;
public class RayActorMethodsTest {
@Test
public void testActor() throws Exception {
RayActorMethods methods = RayActorMethods
.fromClass(ActorTest.Adder.class.getName(), RayActorMethodsTest.class.getClassLoader());
RayLog.core.info(methods.toString());
Assert.assertEquals(methods.functions.size(), 5);
Assert.assertEquals(methods.staticFunctions.size(), 1);
RayActorMethods methods2 = RayActorMethods
.fromClass(ActorTest.Adder2.class.getName(), RayActorMethodsTest.class.getClassLoader());
RayLog.core.info(methods2.toString());
Assert.assertEquals(methods2.functions.size(), 9);
Assert.assertEquals(methods2.staticFunctions.size(), 1);
}
}
@@ -0,0 +1,18 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.ray.spi.model.RayTaskMethods;
import org.ray.util.logger.RayLog;
public class RayTaskMethodsTest {
@Test
public void testTask() throws Exception {
RayTaskMethods methods = RayTaskMethods
.fromClass(EchoTest.class.getName(), RayTaskMethodsTest.class.getClassLoader());
RayLog.core.info(methods.toString());
Assert.assertEquals(methods.functions.size(), 3);
}
}
@@ -1,13 +0,0 @@
package org.ray.api.test;
import java.io.IOException;
import java.util.zip.DataFormatException;
import org.ray.hook.JarRewriter;
public class RewriteTest {
public static void main(String[] args) throws IOException, DataFormatException {
System.out.println(System.getProperty("user.dir"));
JarRewriter.rewrite("target", "target2");
}
}