mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 11:44:39 +08:00
[Java] Replace binary rewrite with Remote Lambda Cache (SerdeLambda) (#2245)
* <feature> : serde lambda * <feature>:fixed CR with issue #2245 * <feature>: fixed CR
This commit is contained in:
committed by
Philipp Moritz
parent
62de86ff7a
commit
fa0ade2bc5
@@ -103,13 +103,6 @@ public final class Ray extends Rpc {
|
||||
return impl;
|
||||
}
|
||||
|
||||
/**
|
||||
* whether to use remote lambda.
|
||||
*/
|
||||
public static boolean isRemoteLambda() {
|
||||
return impl.isRemoteLambda();
|
||||
}
|
||||
|
||||
/**
|
||||
* for ray's app's log.
|
||||
*/
|
||||
|
||||
@@ -10,7 +10,7 @@ import org.ray.util.Sha1Digestor;
|
||||
* Ray actor abstraction.
|
||||
*/
|
||||
public class RayActor<T> extends RayObject<T> implements Externalizable {
|
||||
|
||||
public static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
|
||||
private static final long serialVersionUID = 1877485807405645036L;
|
||||
|
||||
private int taskCounter = 0;
|
||||
@@ -84,4 +84,4 @@ public class RayActor<T> extends RayObject<T> implements Externalizable {
|
||||
this.actorHandleId = (UniqueID) in.readObject();
|
||||
this.taskCursor = (UniqueID) in.readObject();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package org.ray.api;
|
||||
import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import org.ray.api.internal.Callable;
|
||||
import org.ray.api.internal.RayFunc;
|
||||
import org.ray.util.exception.TaskExecutionException;
|
||||
|
||||
/**
|
||||
@@ -54,14 +54,13 @@ public interface RayApi {
|
||||
* submit a new task by invoking a remote function.
|
||||
*
|
||||
* @param taskId nil
|
||||
* @param funcRun the target running function with @RayRemote
|
||||
* @param funcCls the target running function's class
|
||||
* @param lambda the target running function
|
||||
* @param returnCount the number of to-be-returned objects from funcRun
|
||||
* @param args arguments to this funcRun, can be its original form or RayObject
|
||||
* @return a set of ray objects with their return ids
|
||||
*/
|
||||
RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args);
|
||||
|
||||
RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
|
||||
RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
|
||||
Object... args);
|
||||
|
||||
/**
|
||||
@@ -71,34 +70,27 @@ public interface RayApi {
|
||||
* outputs with a set of labels (usually with Integer or String).
|
||||
*
|
||||
* @param taskId nil
|
||||
* @param funcRun the target running function with @RayRemote
|
||||
* @param returnIds a set of labels to be used by the returned objects
|
||||
* @param funcCls the target running function's class
|
||||
* @param lambda the target running function
|
||||
* @param returnids a set of labels to be used by the returned objects
|
||||
* @param args arguments to this funcRun, can be its original form or
|
||||
* RayObject<original-type>
|
||||
* @return a set of ray objects with their labels and return ids
|
||||
*/
|
||||
<R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
|
||||
Collection<RIDT> returnIds, Object... args);
|
||||
|
||||
<R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Collection<RIDT> returnids,
|
||||
Object... args);
|
||||
RayFunc lambda, Collection<RIDT> returnids, Object... args);
|
||||
|
||||
/**
|
||||
* a special case for the above RID-based labeling as <0...returnCount - 1>.
|
||||
*
|
||||
* @param taskId nil
|
||||
* @param funcRun the target running function with @RayRemote
|
||||
* @param funcCls the target running function's class
|
||||
* @param lambda the target running function
|
||||
* @param returnCount the number of to-be-returned objects from funcRun
|
||||
* @param args arguments to this funcRun, can be its original form or
|
||||
* RayObject<original-type>
|
||||
* @return an array of returned objects with their Unique ids
|
||||
*/
|
||||
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun, Integer returnCount,
|
||||
Object... args);
|
||||
|
||||
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, Serializable lambda,
|
||||
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
|
||||
Integer returnCount, Object... args);
|
||||
|
||||
boolean isRemoteLambda();
|
||||
}
|
||||
|
||||
+1441
-5697
File diff suppressed because it is too large
Load Diff
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_1<R0> extends RayFunc {
|
||||
|
||||
static <R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_1<R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply();
|
||||
}
|
||||
|
||||
R0 apply() throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_2<R0, R1> extends RayFunc {
|
||||
|
||||
static <R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_2<R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply();
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply() throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_3<R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_3<R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply();
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply() throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_4<R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_4<R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply();
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply() throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_n<R, RIDT> extends RayFunc {
|
||||
|
||||
static <R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_n<R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((Collection<RIDT>) args[0]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_0_n_list<R> extends RayFunc {
|
||||
|
||||
static <R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_0_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_0_n_list<R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply();
|
||||
}
|
||||
|
||||
List<R> apply() throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_1<T0, R0> extends RayFunc {
|
||||
|
||||
static <T0, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_1<T0, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_2<T0, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_2<T0, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_3<T0, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_3<T0, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_4<T0, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_4<T0, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_n<T0, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_n<T0, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((Collection<RIDT>) args[0], (T0) args[1]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_1_n_list<T0, R> extends RayFunc {
|
||||
|
||||
static <T0, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_1_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_1_n_list<T0, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_1<T0, T1, R0> extends RayFunc {
|
||||
|
||||
static <T0, T1, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_1<T0, T1, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_2<T0, T1, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, T1, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_2<T0, T1, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_3<T0, T1, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, T1, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_3<T0, T1, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_4<T0, T1, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, T1, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_4<T0, T1, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_n<T0, T1, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, T1, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_n<T0, T1, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_2_n_list<T0, T1, R> extends RayFunc {
|
||||
|
||||
static <T0, T1, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_2_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_2_n_list<T0, T1, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0, T1 t1) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_1<T0, T1, T2, R0> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_1<T0, T1, T2, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_2<T0, T1, T2, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_2<T0, T1, T2, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_3<T0, T1, T2, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_3<T0, T1, T2, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_4<T0, T1, T2, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_4<T0, T1, T2, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,14 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_n<T0, T1, T2, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_n<T0, T1, T2, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_3_n_list<T0, T1, T2, R> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_3_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_3_n_list<T0, T1, T2, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_1<T0, T1, T2, T3, R0> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_1<T0, T1, T2, T3, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_2<T0, T1, T2, T3, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_2<T0, T1, T2, T3, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_3<T0, T1, T2, T3, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_3<T0, T1, T2, T3, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_4<T0, T1, T2, T3, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_4<T0, T1, T2, T3, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,15 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_n<T0, T1, T2, T3, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_n<T0, T1, T2, T3, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_4_n_list<T0, T1, T2, T3, R> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_4_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_4_n_list<T0, T1, T2, T3, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,14 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_1<T0, T1, T2, T3, T4, R0> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_1<T0, T1, T2, T3, T4, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_2<T0, T1, T2, T3, T4, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_2<T0, T1, T2, T3, T4, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_3<T0, T1, T2, T3, T4, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_3<T0, T1, T2, T3, T4, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_4<T0, T1, T2, T3, T4, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(
|
||||
Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_4<T0, T1, T2, T3, T4, R0, R1, R2, R3> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,16 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_n<T0, T1, T2, T3, T4, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_n<T0, T1, T2, T3, T4, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4],
|
||||
(T4) args[5]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4)
|
||||
throws Throwable;
|
||||
|
||||
|
||||
@@ -7,14 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_5_n_list<T0, T1, T2, T3, T4, R> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_5_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_5_n_list<T0, T1, T2, T3, T4, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,15 +6,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_1<T0, T1, T2, T3, T4, T5, R0> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R0> R0 execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_1.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_1<T0, T1, T2, T3, T4, T5, R0> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
|
||||
}
|
||||
|
||||
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,16 +7,6 @@ import org.ray.api.returns.MultipleReturns2;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_2<T0, T1, T2, T3, T4, T5, R0, R1> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R0, R1> MultipleReturns2<R0, R1> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_2.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_2<T0, T1, T2, T3, T4, T5, R0, R1> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
|
||||
}
|
||||
|
||||
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,16 +7,6 @@ import org.ray.api.returns.MultipleReturns3;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_3<T0, T1, T2, T3, T4, T5, R0, R1, R2> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R0, R1, R2> MultipleReturns3<R0, R1, R2> execute(Object[] args)
|
||||
throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_3.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_3<T0, T1, T2, T3, T4, T5, R0, R1, R2> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
|
||||
}
|
||||
|
||||
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -7,17 +7,6 @@ import org.ray.api.returns.MultipleReturns4;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_4<T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> MultipleReturns4<R0, R1, R2, R3> execute(
|
||||
Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_4.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_4<T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> f = SerializationUtils
|
||||
.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
|
||||
}
|
||||
|
||||
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -8,16 +8,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_n<T0, T1, T2, T3, T4, T5, R, RIDT> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R, RIDT> Map<RIDT, R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_n.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_n<T0, T1, T2, T3, T4, T5, R, RIDT> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((Collection<RIDT>) args[0], (T0) args[1], (T1) args[2], (T2) args[3], (T3) args[4],
|
||||
(T4) args[5], (T5) args[6]);
|
||||
}
|
||||
|
||||
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5)
|
||||
throws Throwable;
|
||||
|
||||
|
||||
@@ -7,15 +7,6 @@ import org.ray.api.internal.RayFunc;
|
||||
@FunctionalInterface
|
||||
public interface RayFunc_6_n_list<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
|
||||
|
||||
static <T0, T1, T2, T3, T4, T5, R> List<R> execute(Object[] args) throws Throwable {
|
||||
String name = (String) args[args.length - 2];
|
||||
assert (name.equals(RayFunc_6_n_list.class.getName()));
|
||||
byte[] funcBytes = (byte[]) args[args.length - 1];
|
||||
RayFunc_6_n_list<T0, T1, T2, T3, T4, T5, R> f = SerializationUtils.deserialize(funcBytes);
|
||||
return f
|
||||
.apply((T0) args[0], (T1) args[1], (T2) args[2], (T3) args[3], (T4) args[4], (T5) args[5]);
|
||||
}
|
||||
|
||||
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package org.ray.api.internal;
|
||||
|
||||
/**
|
||||
* hold the remote call.
|
||||
*/
|
||||
public interface Callable {
|
||||
|
||||
void run() throws Throwable;
|
||||
}
|
||||
+10
-4
@@ -22,19 +22,25 @@
|
||||
<dependency>
|
||||
<groupId>log4j</groupId>
|
||||
<artifactId>log4j</artifactId>
|
||||
<version>1.2.17</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>quartz</groupId>
|
||||
<artifactId>quartz</artifactId>
|
||||
<version>1.5.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ini4j</groupId>
|
||||
<artifactId>ini4j</artifactId>
|
||||
<version>0.5.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
</project>
|
||||
</project>
|
||||
@@ -0,0 +1,31 @@
|
||||
package org.ray.util;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Method;
|
||||
|
||||
/**
|
||||
* see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html.
|
||||
*/
|
||||
public final class LambdaUtils {
|
||||
|
||||
private LambdaUtils() {
|
||||
}
|
||||
|
||||
|
||||
public static SerializedLambda getSerializedLambda(Serializable lambda) {
|
||||
// Note.
|
||||
// the class of lambda which isAssignableFrom Serializable
|
||||
// has an privte method:writeReplace
|
||||
// This mechanism may be changed in the future
|
||||
try {
|
||||
Method m = lambda.getClass().getDeclaredMethod("writeReplace");
|
||||
m.setAccessible(true);
|
||||
return (SerializedLambda) m.invoke(lambda);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to getSerializedLambda:" + lambda.getClass().getName(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,212 @@
|
||||
package org.ray.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.lang.invoke.MethodHandleInfo;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.WeakHashMap;
|
||||
import org.objectweb.asm.Type;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
|
||||
/**
|
||||
* An instance of RayFunc is a lambda.
|
||||
* MethodId describe the information of the called function in lambda.<br/>
|
||||
* e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:<br/>
|
||||
* MethodId.className = Foo <br/>
|
||||
* MethodId.methodName = foo <br/>
|
||||
* MethodId.methodDesc = describe the types of args and return.
|
||||
* see org.objectweb.asm.Type.getDescriptor.
|
||||
*/
|
||||
public final class MethodId {
|
||||
|
||||
/**
|
||||
* use ThreadLocal to avoid lock.
|
||||
* A cache from the lambda instances to MethodId.
|
||||
* Note: the lambda instances are dynamically created per call site,
|
||||
* we use WeakHashMap to avoid OOM.
|
||||
*/
|
||||
private static final ThreadLocal<WeakHashMap<Class<Serializable>, MethodId>>
|
||||
CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>());
|
||||
|
||||
public final String className;
|
||||
public final String methodName;
|
||||
|
||||
public final String methodDesc;
|
||||
public final boolean isStatic;
|
||||
/**
|
||||
* encode the className,methodName,methodDesc,isStatic as an uniquel id.
|
||||
*/
|
||||
private final String encoding;
|
||||
|
||||
/**
|
||||
* sha1 from the encoding, used as functionId.
|
||||
*/
|
||||
private final byte[] digest;
|
||||
|
||||
public MethodId(String className, String methodName, String methodDesc, boolean isStatic) {
|
||||
this.className = className;
|
||||
this.methodName = methodName;
|
||||
this.methodDesc = methodDesc;
|
||||
this.isStatic = isStatic;
|
||||
this.encoding = encode(className, methodName, methodDesc, isStatic);
|
||||
this.digest = getSha1Hash0();
|
||||
}
|
||||
|
||||
private static String encode(String className, String methodName, String methodDesc,
|
||||
boolean isStatic) {
|
||||
StringBuilder sb = new StringBuilder(512);
|
||||
sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&")
|
||||
.append(isStatic);
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public static MethodId fromMethod(Method method) {
|
||||
final boolean isstatic = Modifier.isStatic(method.getModifiers());
|
||||
final String className = method.getDeclaringClass().getName();
|
||||
final String methodName = method.getName();
|
||||
final Type type = Type.getType(method);
|
||||
final String methodDesc = type.getDescriptor();
|
||||
return new MethodId(className, methodName, methodDesc, isstatic);
|
||||
}
|
||||
|
||||
public static MethodId fromSerializedLambda(Serializable serial) {
|
||||
return fromSerializedLambda(serial, false);
|
||||
}
|
||||
|
||||
public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) {
|
||||
Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be "
|
||||
+ "SerializedLambda");
|
||||
Class<Serializable> clazz = (Class<Serializable>) serial.getClass();
|
||||
WeakHashMap<Class<Serializable>, MethodId> map = CACHE.get();
|
||||
MethodId id = map.get(clazz);
|
||||
if (id == null || forceNew) {
|
||||
final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial);
|
||||
Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda "
|
||||
+ "which is closure");
|
||||
final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
|
||||
final String className = lambda.getImplClass().replace('/', '.');
|
||||
id = new MethodId(className, lambda.getImplMethodName(),
|
||||
lambda.getImplMethodSignature(), isStatic);
|
||||
if (!forceNew) {
|
||||
map.put(clazz, id);
|
||||
}
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
public Method load() {
|
||||
return load(null);
|
||||
}
|
||||
|
||||
public Method load(ClassLoader loader) {
|
||||
Class<?> cls = null;
|
||||
try {
|
||||
RayLog.core.debug(
|
||||
"load class " + className + " from class loader " + (loader == null ? this.getClass()
|
||||
.getClassLoader() : loader)
|
||||
+ " for method " + toString() + " with ID = " + toHexHashString()
|
||||
);
|
||||
cls = Class
|
||||
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
|
||||
} catch (Throwable e) {
|
||||
RayLog.core.error("Cannot load class " + className, e);
|
||||
return null;
|
||||
}
|
||||
|
||||
Method[] ms = cls.getDeclaredMethods();
|
||||
ArrayList<Method> methods = new ArrayList<>();
|
||||
Type t = Type.getMethodType(this.methodDesc);
|
||||
Type[] params = t.getArgumentTypes();
|
||||
String rt = t.getReturnType().getDescriptor();
|
||||
|
||||
for (Method m : ms) {
|
||||
if (m.getName().equals(methodName)) {
|
||||
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String mrt = Type.getDescriptor(m.getReturnType());
|
||||
if (!rt.equals(mrt)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isStatic != Modifier.isStatic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
methods.add(m);
|
||||
}
|
||||
}
|
||||
|
||||
if (methods.size() != 1) {
|
||||
RayLog.core.error(
|
||||
"Load method " + toString() + " failed as there are " + methods.size() + " definitions");
|
||||
return null;
|
||||
}
|
||||
|
||||
return methods.get(0);
|
||||
}
|
||||
|
||||
private byte[] getSha1Hash0() {
|
||||
byte[] digests = Sha1Digestor.digest(encoding);
|
||||
ByteBuffer bb = ByteBuffer.wrap(digests);
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
if (methodName.contains("createActorStage1")) {
|
||||
bb.putLong(Long.BYTES, 1);
|
||||
} else {
|
||||
bb.putLong(Long.BYTES, 0);
|
||||
}
|
||||
return digests;
|
||||
}
|
||||
|
||||
public byte[] getSha1Hash() {
|
||||
return digest;
|
||||
}
|
||||
|
||||
private String toHexHashString() {
|
||||
byte[] id = this.getSha1Hash();
|
||||
return StringUtil.toHexHashString(id);
|
||||
}
|
||||
|
||||
public String toEncodingString() {
|
||||
return encoding;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return encoding.hashCode();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MethodId other = (MethodId) obj;
|
||||
return className.equals(other.className)
|
||||
&& methodName.equals(other.methodName)
|
||||
&& methodDesc.equals(other.methodDesc)
|
||||
&& isStatic == other.isStatic;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return encoding;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15,6 +15,7 @@ public class Sha1Digestor {
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
|
||||
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
|
||||
|
||||
@@ -27,4 +28,14 @@ public class Sha1Digestor {
|
||||
dg.update(longBuffer.get().putLong(addIndex).array());
|
||||
return dg.digest();
|
||||
}
|
||||
}
|
||||
|
||||
public static byte[] digest(String str) {
|
||||
return digest(str.getBytes(StringUtil.UTF8));
|
||||
}
|
||||
|
||||
public static byte[] digest(byte[] src) {
|
||||
MessageDigest dg = md.get();
|
||||
dg.reset();
|
||||
return dg.digest(src);
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
package org.ray.util;
|
||||
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Vector;
|
||||
|
||||
public class StringUtil {
|
||||
|
||||
public static final Charset UTF8 = Charset.forName("UTF-8");
|
||||
|
||||
private static final char[] HEX_CHARS = "0123456789abcdef".toCharArray();
|
||||
|
||||
/**
|
||||
* split.
|
||||
* @param s input string
|
||||
@@ -117,6 +122,17 @@ public class StringUtil {
|
||||
return objs.length == 0 ? "" : sb.substring(0, sb.length() - concatenator.length());
|
||||
}
|
||||
|
||||
public static String toHexHashString(byte[] id) {
|
||||
StringBuilder sb = new StringBuilder(20);
|
||||
assert (id.length == 20);
|
||||
for (int i = 0; i < 20; i++) {
|
||||
int val = id[i] & 0xff;
|
||||
sb.append(HEX_CHARS[val >> 4]);
|
||||
sb.append(HEX_CHARS[val & 0xf]);
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
// Holds the start of an element and which brace started it.
|
||||
private static class Start {
|
||||
|
||||
@@ -132,4 +148,3 @@ public class StringUtil {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -50,11 +50,6 @@
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-hook</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
@@ -99,4 +94,4 @@
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
||||
</project>
|
||||
@@ -1,88 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<parent>
|
||||
<groupId>org.ray.parent</groupId>
|
||||
<artifactId>ray-superpom</artifactId>
|
||||
<version>1.0</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-hook</artifactId>
|
||||
<name>java api hook for ray</name>
|
||||
<description>java api hook for ray</description>
|
||||
<url></url>
|
||||
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/org.ow2.asm/asm -->
|
||||
<dependency>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
<version>6.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-codec</groupId>
|
||||
<artifactId>commons-codec</artifactId>
|
||||
<version>1.4</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
<version>2.5</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-common</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<version>2.5</version>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifestEntries>
|
||||
<Premain-Class>org.ray.hook.Agent</Premain-Class>
|
||||
<Can-Retransform-Classes>true</Can-Retransform-Classes>
|
||||
</manifestEntries>
|
||||
</archive>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>3.1.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<relocations>
|
||||
<relocation>
|
||||
<pattern>org.objectweb.asm</pattern>
|
||||
<shadedPattern>agent.org.objectweb.asm</shadedPattern>
|
||||
</relocation>
|
||||
</relocations>
|
||||
<createDependencyReducedPom>false</createDependencyReducedPom>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
@@ -1,44 +0,0 @@
|
||||
package org.ray.hook;
|
||||
|
||||
import java.util.Set;
|
||||
import org.objectweb.asm.ClassReader;
|
||||
import org.objectweb.asm.ClassWriter;
|
||||
|
||||
public class ClassAdapter {
|
||||
|
||||
public static Result hookClass(ClassLoader loader, String className, byte[] classfileBuffer) {
|
||||
// we have to comment out this quick filter as this is not accurate
|
||||
// e.g., org/ray/api/test/ActorTest$Adder.class is skipped!!!
|
||||
// even worse, this is non-deterministic...
|
||||
|
||||
ClassReader reader = new ClassReader(classfileBuffer);
|
||||
ClassWriter writer = new ClassWriter(reader, 0);
|
||||
ClassDetectVisitor pre = new ClassDetectVisitor(loader, writer, className);
|
||||
byte[] result;
|
||||
|
||||
reader.accept(pre, ClassReader.SKIP_FRAMES);
|
||||
if (pre.detectedMethods().isEmpty() && pre.actorCalls() == 0) {
|
||||
result = classfileBuffer;
|
||||
} else {
|
||||
if (pre.actorCalls() > 0) {
|
||||
reader = new ClassReader(writer.toByteArray());
|
||||
}
|
||||
|
||||
writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
|
||||
reader.accept(new ClassOverrideVisitor(writer, className, pre.detectedMethods()),
|
||||
ClassReader.SKIP_FRAMES);
|
||||
result = writer.toByteArray();
|
||||
}
|
||||
|
||||
Result rr = new Result();
|
||||
rr.changedMethods = pre.detectedMethods();
|
||||
rr.classBuffer = result;
|
||||
return rr;
|
||||
}
|
||||
|
||||
public static class Result {
|
||||
|
||||
public byte[] classBuffer;
|
||||
public Set<MethodId> changedMethods;
|
||||
}
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
package org.ray.hook;
|
||||
|
||||
import java.security.InvalidParameterException;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import org.objectweb.asm.AnnotationVisitor;
|
||||
import org.objectweb.asm.ClassVisitor;
|
||||
import org.objectweb.asm.Handle;
|
||||
import org.objectweb.asm.MethodVisitor;
|
||||
import org.objectweb.asm.Opcodes;
|
||||
import org.objectweb.asm.Type;
|
||||
|
||||
/**
|
||||
* rewrite phase 1.
|
||||
*/
|
||||
public class ClassDetectVisitor extends ClassVisitor {
|
||||
|
||||
static int count = 0;
|
||||
final String className;
|
||||
final Set<MethodId> rayMethods = new HashSet<>();
|
||||
final ClassLoader loader;
|
||||
boolean isActor = false;
|
||||
int actorCalls = 0;
|
||||
|
||||
public ClassDetectVisitor(ClassLoader loader, ClassVisitor origin, String className) {
|
||||
super(Opcodes.ASM6, origin);
|
||||
this.className = className;
|
||||
this.loader = loader;
|
||||
}
|
||||
|
||||
public int actorCalls() {
|
||||
return actorCalls;
|
||||
}
|
||||
|
||||
public Set<MethodId> detectedMethods() {
|
||||
return rayMethods;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
|
||||
if (desc.contains("Lorg/ray/api/RayRemote;")) {
|
||||
isActor = true;
|
||||
}
|
||||
return super.visitAnnotation(desc, visible);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitInnerClass(String name, String outerName,
|
||||
String innerName, int access) {
|
||||
// System.err.println("visist inner class " + outerName + "$" + innerName);
|
||||
super.visitInnerClass(name, outerName, innerName, access);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodVisitor visitMethod(int access, String name, String mdesc, String signature,
|
||||
String[] exceptions) {
|
||||
//System.out.println("Visit " + className + "." + name);
|
||||
if (isActor && (access & Opcodes.ACC_PUBLIC) != 0) {
|
||||
visitRayMethod(access, name, mdesc);
|
||||
}
|
||||
|
||||
MethodVisitor origin = super.visitMethod(access, name, mdesc, signature, exceptions);
|
||||
return new MethodVisitor(this.api, origin) {
|
||||
@Override
|
||||
public AnnotationVisitor visitAnnotation(String adesc, boolean visible) {
|
||||
//handle rayRemote annotation
|
||||
if (adesc.contains("Lorg/ray/api/RayRemote;")) {
|
||||
visitRayMethod(access, name, mdesc);
|
||||
}
|
||||
return super.visitAnnotation(adesc, visible);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitInvokeDynamicInsn(String name, String desc, Handle bsm,
|
||||
Object... bsmArgs) {
|
||||
|
||||
// fix all actor calls from InvokeVirtual to InvokeStatic
|
||||
if (desc.contains("org/ray/api/funcs/RayFunc_")) {
|
||||
int count = bsmArgs.length;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
Object arg = bsmArgs[i];
|
||||
// System.err.println(arg.getClass().getName() + " " + arg.toString());
|
||||
if (arg.getClass().equals(Handle.class)) {
|
||||
Handle h = (Handle) arg;
|
||||
if (h.getTag() == Opcodes.H_INVOKEVIRTUAL) {
|
||||
String dsptr = h.getDesc();
|
||||
|
||||
Type[] argTypes = Type.getArgumentTypes(dsptr);
|
||||
for (Type argt : argTypes) {
|
||||
if (!isValidCallParameterOrReturnType(argt)) {
|
||||
throw new InvalidParameterException(
|
||||
"cannot use primitive parameter type '" + argt.getClassName()
|
||||
+ "' in method " + h.getOwner() + "." + h.getName());
|
||||
}
|
||||
}
|
||||
Type retType = Type.getReturnType(dsptr);
|
||||
if (!isValidCallParameterOrReturnType(retType)) {
|
||||
throw new InvalidParameterException(
|
||||
"cannot use primitive return type '" + retType.getClassName() + "' in method "
|
||||
+ h.getOwner() + "." + h.getName());
|
||||
}
|
||||
|
||||
dsptr = "(L" + h.getOwner() + ";" + dsptr.substring(1);
|
||||
Handle newh = new Handle(
|
||||
Opcodes.H_INVOKESTATIC,
|
||||
h.getOwner(),
|
||||
h.getName() + MethodId.getFunctionIdPostfix,
|
||||
dsptr,
|
||||
h.isInterface());
|
||||
bsmArgs[i] = newh;
|
||||
//System.err.println("Change ray.call from " + h + " -> " + newh + ", isInterface
|
||||
// = " + h.isInterface());
|
||||
++actorCalls;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
super.visitInvokeDynamicInsn(name, desc, bsm, bsmArgs);
|
||||
}
|
||||
|
||||
private boolean isValidCallParameterOrReturnType(Type t) {
|
||||
if (t.equals(Type.VOID_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.BOOLEAN_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.CHAR_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.BYTE_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.SHORT_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.INT_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.FLOAT_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.LONG_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
if (t.equals(Type.DOUBLE_TYPE)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void visitRayMethod(int access, String name, String mdesc) {
|
||||
if (name.equals("<init>")) {
|
||||
return;
|
||||
}
|
||||
|
||||
MethodId m = new MethodId(className, name, mdesc, (access & Opcodes.ACC_STATIC) != 0, loader);
|
||||
rayMethods.add(m);
|
||||
//System.err.println("Visit " + m.toString());
|
||||
count++;
|
||||
}
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package org.ray.hook;
|
||||
|
||||
import java.util.Set;
|
||||
import org.objectweb.asm.ClassVisitor;
|
||||
import org.objectweb.asm.Label;
|
||||
import org.objectweb.asm.MethodVisitor;
|
||||
import org.objectweb.asm.Opcodes;
|
||||
|
||||
/**
|
||||
* rewrite phase 2.
|
||||
*/
|
||||
public class ClassOverrideVisitor extends ClassVisitor {
|
||||
|
||||
final String className;
|
||||
final Set<MethodId> rayRemoteMethods;
|
||||
MethodVisitor clinitVisitor;
|
||||
|
||||
public ClassOverrideVisitor(ClassVisitor origin, String className,
|
||||
Set<MethodId> rayRemoteMethods) {
|
||||
super(Opcodes.ASM6, origin);
|
||||
this.className = className;
|
||||
this.rayRemoteMethods = rayRemoteMethods;
|
||||
this.clinitVisitor = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodVisitor visitMethod(int access, String name, String desc, String signature,
|
||||
String[] exceptions) {
|
||||
if ("<clinit>".equals(name) && clinitVisitor == null) {
|
||||
MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
|
||||
clinitVisitor = new StaticBlockVisitor(mv);
|
||||
return clinitVisitor;// dispatch the ASM modifications (assign values to the preComputedxxx
|
||||
// static field) to the clinitVisitor
|
||||
}
|
||||
|
||||
ClassVisitor current = this;
|
||||
MethodId m = new MethodId(className, name, desc, (access & Opcodes.ACC_STATIC) != 0, null);
|
||||
if (rayRemoteMethods.contains(m)) {
|
||||
if (m.isStaticMethod()) {
|
||||
return new MethodVisitor(api,
|
||||
super.visitMethod(access, name, desc, signature, exceptions)) {
|
||||
@Override
|
||||
public void visitCode() {
|
||||
// step 1: add a field for the function id of this method
|
||||
System.out.println("add field: " + m.getStaticHashValueFieldName());
|
||||
String fieldName = m.getStaticHashValueFieldName();
|
||||
current.visitField(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, fieldName, "[B",
|
||||
null, null);
|
||||
|
||||
// step 2: rewrite current method so if MethodSwitcher returns true, returns the
|
||||
// added function id directly
|
||||
// else call the original method
|
||||
mv.visitFieldInsn(Opcodes.GETSTATIC, className, fieldName, "[B");
|
||||
mv.visitMethodInsn(Opcodes.INVOKESTATIC, "org/ray/hook/runtime/MethodSwitcher",
|
||||
"execute",
|
||||
"([B)Z", false);
|
||||
Label dorealwork = new Label();
|
||||
mv.visitJumpInsn(Opcodes.IFEQ, dorealwork);
|
||||
properReturn(sayReturnType(desc), mv);// proper return on
|
||||
// different types
|
||||
mv.visitLabel(dorealwork);
|
||||
mv.visitCode();// real work
|
||||
}
|
||||
};
|
||||
} else { // non-static
|
||||
return super.visitMethod(access, name, desc, signature, exceptions);
|
||||
}
|
||||
} else {
|
||||
return super.visitMethod(access, name, desc, signature, exceptions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitEnd() {
|
||||
if (clinitVisitor == null) { // works fine
|
||||
// Create an empty static block and let our method
|
||||
// visitor modify it the same way it modifies an
|
||||
// existing static block
|
||||
{
|
||||
MethodVisitor mv = super.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
|
||||
mv = new StaticBlockVisitor(mv);
|
||||
mv.visitCode();
|
||||
mv.visitInsn(Opcodes.RETURN);
|
||||
mv.visitMaxs(0, 0);
|
||||
mv.visitEnd();
|
||||
}
|
||||
}
|
||||
|
||||
// for each non-static method, create a method for returning hash
|
||||
for (MethodId mid : this.rayRemoteMethods) {
|
||||
if (!mid.isStaticMethod()) {
|
||||
|
||||
// step 1: create a new method called method_name_function_id()
|
||||
System.out.println("add method: " + mid.getIdMethodName());
|
||||
MethodVisitor mv = super.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC,
|
||||
mid.getIdMethodName(), mid.getIdMethodDesc(), null, null);
|
||||
mv.visitCode();
|
||||
Label l0 = new Label();
|
||||
mv.visitLabel(l0);
|
||||
|
||||
// step 2: add a new static field as the function id of this method
|
||||
System.out.println("add field: " + mid.getStaticHashValueFieldName());
|
||||
String fieldName = mid.getStaticHashValueFieldName();
|
||||
this.visitField(Opcodes.ACC_PUBLIC | Opcodes.ACC_STATIC, fieldName, "[B",
|
||||
null, null);
|
||||
mv.visitFieldInsn(Opcodes.GETSTATIC, className, fieldName, "[B");
|
||||
|
||||
// step 3: call method switcher, and returns the function id when the mode is rewritten
|
||||
mv.visitMethodInsn(Opcodes.INVOKESTATIC, "org/ray/hook/runtime/MethodSwitcher", "execute",
|
||||
"([B)Z", false);
|
||||
mv.visitInsn(Opcodes.POP);
|
||||
Label l1 = new Label();
|
||||
mv.visitLabel(l1);
|
||||
properReturn(sayReturnType(mid.getMethodDesc()), mv);
|
||||
|
||||
Label l2 = new Label();
|
||||
mv.visitLabel(l2);
|
||||
|
||||
org.objectweb.asm.Type[] args = org.objectweb.asm.Type
|
||||
.getArgumentTypes(mid.getIdMethodDesc());
|
||||
int argCount = args.length;
|
||||
mv.visitMaxs(2, argCount);
|
||||
mv.visitEnd();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void properReturn(String returnType, MethodVisitor mv) {
|
||||
int returnCode;
|
||||
Object returnValue = null;
|
||||
switch (returnType) {
|
||||
case "V":
|
||||
mv.visitInsn(Opcodes.RETURN);
|
||||
return;
|
||||
case "I":
|
||||
case "B":
|
||||
case "S":
|
||||
case "Z": // int byte short boolean
|
||||
returnCode = Opcodes.IRETURN;
|
||||
returnValue = 0;
|
||||
break;
|
||||
case "J": // long
|
||||
returnCode = Opcodes.LRETURN;
|
||||
returnValue = 0L;
|
||||
break;
|
||||
case "D": // double
|
||||
returnCode = Opcodes.DRETURN;
|
||||
returnValue = 0D;
|
||||
break;
|
||||
case "F": // float
|
||||
returnCode = Opcodes.FRETURN;
|
||||
returnValue = 0F;
|
||||
break;
|
||||
default: // reference
|
||||
returnCode = Opcodes.ARETURN;
|
||||
break;
|
||||
}
|
||||
if (returnValue != null) {
|
||||
mv.visitLdcInsn(returnValue);
|
||||
} else {
|
||||
mv.visitInsn(Opcodes.ACONST_NULL);
|
||||
}
|
||||
mv.visitInsn(returnCode);
|
||||
}
|
||||
|
||||
private String sayReturnType(String str) {
|
||||
int left = str.lastIndexOf(")") + 1;
|
||||
return str.substring(left);
|
||||
}
|
||||
|
||||
// init the static added field in <clinit>
|
||||
// static {
|
||||
// assign value to _hashOf_XXX
|
||||
// }
|
||||
class StaticBlockVisitor extends MethodVisitor {
|
||||
|
||||
StaticBlockVisitor(MethodVisitor mv) {
|
||||
super(Opcodes.ASM6, mv);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitCode() {
|
||||
super.visitCode();
|
||||
|
||||
// assign value for added hash fields within <clinit>
|
||||
for (MethodId m : rayRemoteMethods) {
|
||||
byte[] hash = m.getSha1Hash();
|
||||
insertByteArray(hash);
|
||||
mv.visitFieldInsn(Opcodes.PUTSTATIC, className, m.getStaticHashValueFieldName(), "[B");
|
||||
|
||||
System.out.println("assign field: " + m.getStaticHashValueFieldName() + " = " + MethodId
|
||||
.toHexHashString(hash));
|
||||
}
|
||||
}
|
||||
|
||||
private void insertByteArray(byte[] bytes) {
|
||||
int length = bytes.length;
|
||||
assert (length < Short.MAX_VALUE);
|
||||
mv.visitIntInsn(Opcodes.SIPUSH, length);
|
||||
mv.visitIntInsn(Opcodes.NEWARRAY, Opcodes.T_BYTE);
|
||||
mv.visitInsn(Opcodes.DUP);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
mv.visitIntInsn(Opcodes.BIPUSH, i);
|
||||
mv.visitIntInsn(Opcodes.BIPUSH, bytes[i]);
|
||||
mv.visitInsn(Opcodes.BASTORE);
|
||||
if (i < (length - 1)) {
|
||||
mv.visitInsn(Opcodes.DUP);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
package org.ray.hook;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.PrintWriter;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Enumeration;
|
||||
import java.util.List;
|
||||
import java.util.Scanner;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.jar.JarEntry;
|
||||
import java.util.jar.JarFile;
|
||||
import java.util.jar.JarOutputStream;
|
||||
import java.util.zip.DataFormatException;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.filefilter.DirectoryFileFilter;
|
||||
import org.apache.commons.io.filefilter.RegexFileFilter;
|
||||
import org.ray.hook.runtime.JarLoader;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* rewrite jars to new jars with methods marked using Ray annotations.
|
||||
*/
|
||||
public class JarRewriter {
|
||||
|
||||
private static final String FUNCTIONS_FILE = "ray.functions.txt";
|
||||
|
||||
public static void main(String[] args)
|
||||
throws IOException, SecurityException, DataFormatException {
|
||||
if (args.length == 1) {
|
||||
LoadedFunctions funcs = load(args[0], null);
|
||||
for (MethodId mi : funcs.functions) {
|
||||
System.err.println(mi.getIdMethodDesc());
|
||||
Method m = mi.load();
|
||||
String logInfo = "load: " + m.getDeclaringClass().getName() + "." + m.getName();
|
||||
RayLog.core.info(logInfo);
|
||||
}
|
||||
return;
|
||||
} else if (args.length < 2) {
|
||||
System.err.println("org.ray.hook.JarRewriter source-jar-dir dest-jar-dir");
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
rewrite(args[0], args[1]);
|
||||
}
|
||||
|
||||
public static LoadedFunctions load(String dir, String baseDir)
|
||||
throws FileNotFoundException, SecurityException {
|
||||
List<String> functions = JarRewriter.getRewrittenFunctions(dir);
|
||||
LoadedFunctions efuncs = new LoadedFunctions();
|
||||
efuncs.loader = JarLoader.loadJars(dir, false);
|
||||
|
||||
for (String func : functions) {
|
||||
MethodId mid = new MethodId(func, efuncs.loader);
|
||||
efuncs.functions.add(mid);
|
||||
}
|
||||
|
||||
if (baseDir != null && !baseDir.equals("")) {
|
||||
List<String> baseFunctions = JarRewriter.getRewrittenFunctions(baseDir);
|
||||
for (String func : baseFunctions) {
|
||||
MethodId mid = new MethodId(func, efuncs.loader);
|
||||
efuncs.functions.add(mid);
|
||||
}
|
||||
}
|
||||
|
||||
return efuncs;
|
||||
}
|
||||
|
||||
public static void rewrite(String fromDir, String toDir) throws IOException, DataFormatException {
|
||||
File fromDirFile = new File(fromDir);
|
||||
File toDirFileTmp = new File(toDir + ".tmp");
|
||||
final File toDirFile = new File(toDir);
|
||||
|
||||
File[] topFiles = fromDirFile.listFiles();
|
||||
if (topFiles.length != 1 || !topFiles[0].isDirectory()) {
|
||||
throw new DataFormatException("There should be a top dir in the Ray app zip file.");
|
||||
}
|
||||
String topDir = topFiles[0].getName();
|
||||
|
||||
if (toDirFileTmp.exists()) {
|
||||
FileUtils.deleteDirectory(toDirFileTmp);
|
||||
}
|
||||
//toDirFileTmp.mkdir();
|
||||
FileUtils.copyDirectory(fromDirFile, toDirFileTmp);
|
||||
|
||||
PrintWriter functionCollector = new PrintWriter(toDir + ".tmp/" + FUNCTIONS_FILE, "UTF-8");
|
||||
|
||||
// get all jars
|
||||
Collection<File> files = FileUtils.listFiles(
|
||||
fromDirFile,
|
||||
new RegexFileFilter(".*\\.jar"),
|
||||
DirectoryFileFilter.DIRECTORY
|
||||
);
|
||||
|
||||
// load and rewrite
|
||||
int prefixLength = fromDirFile.getAbsolutePath().length() + topDir.length() + 2;
|
||||
for (File appJar : files) {
|
||||
String fromPath = appJar.getAbsolutePath();
|
||||
if (fromPath.substring(prefixLength).contains("/")) {
|
||||
functionCollector.close();
|
||||
throw new DataFormatException("There should not be any subdir"
|
||||
+ " containing jar file in the top dir of the Ray app zip file.");
|
||||
}
|
||||
JarFile jar = new JarFile(appJar.getAbsolutePath());
|
||||
String to = fromPath
|
||||
.replaceFirst(fromDirFile.getAbsolutePath(), toDirFileTmp.getAbsolutePath());
|
||||
rewrite(jar, to, (l, m) -> functionCollector.println(m.toEncodingString()));
|
||||
jar.close();
|
||||
}
|
||||
|
||||
// rename the whole dir
|
||||
functionCollector.close();
|
||||
|
||||
if (toDirFile.exists()) {
|
||||
FileUtils.deleteDirectory(toDirFile);
|
||||
}
|
||||
|
||||
FileUtils.moveDirectory(toDirFileTmp, toDirFile);
|
||||
}
|
||||
|
||||
public static void rewrite(JarFile from, String to, BiConsumer<ClassLoader, MethodId> consumer)
|
||||
throws IOException {
|
||||
|
||||
FileOutputStream ofStream = new FileOutputStream(to);
|
||||
JarOutputStream ojStream = new JarOutputStream(ofStream);
|
||||
Enumeration<JarEntry> e = from.entries();
|
||||
String className;
|
||||
|
||||
while (e.hasMoreElements()) {
|
||||
JarEntry je = e.nextElement();
|
||||
byte[] jeBytes = IOUtils.toByteArray(from.getInputStream(je));
|
||||
|
||||
//System.err.println("XXXXXX " + from.getName() + " :: " + je.getName());
|
||||
if (!je.isDirectory() && je.getName().endsWith(".class")) {
|
||||
className = je.getName().substring(0, je.getName().length() - ".class".length());
|
||||
|
||||
//System.err.println("XXXXXX " + from.getName() + " :: " + je.getName() + " - " +
|
||||
// className);
|
||||
ClassAdapter.Result result = ClassAdapter.hookClass(null, className, jeBytes);
|
||||
if (result.classBuffer != jeBytes) {
|
||||
String logInfo = "Rewrite class " + className + " from " + jeBytes.length + " bytes to "
|
||||
+ result.classBuffer.length + " bytes ";
|
||||
RayLog.core.info(logInfo);
|
||||
}
|
||||
|
||||
if (result.changedMethods != null) {
|
||||
for (MethodId m : result.changedMethods) {
|
||||
consumer.accept(null, m);
|
||||
}
|
||||
}
|
||||
|
||||
je = new JarEntry(je.getName());
|
||||
je.setTime(System.currentTimeMillis());
|
||||
je.setSize(result.classBuffer.length);
|
||||
jeBytes = result.classBuffer;
|
||||
}
|
||||
|
||||
ojStream.putNextEntry(je);
|
||||
ojStream.write(jeBytes);
|
||||
//ojStream.closeEntry();
|
||||
}
|
||||
|
||||
ojStream.close();
|
||||
ofStream.close();
|
||||
}
|
||||
|
||||
public static List<String> getRewrittenFunctions(String rewrittenDir)
|
||||
throws FileNotFoundException {
|
||||
ArrayList<String> functions = new ArrayList<>();
|
||||
Scanner s = new Scanner(new File(rewrittenDir + "/" + FUNCTIONS_FILE));
|
||||
while (s.hasNext()) {
|
||||
String f = s.next();
|
||||
if (!f.startsWith("(")) {
|
||||
functions.add(f);
|
||||
}
|
||||
}
|
||||
s.close();
|
||||
|
||||
return functions;
|
||||
}
|
||||
|
||||
public static LoadedFunctions loadBase(String baseDir)
|
||||
throws FileNotFoundException, SecurityException {
|
||||
List<String> functions = JarRewriter.getRewrittenFunctions(baseDir);
|
||||
LoadedFunctions efuncs = new LoadedFunctions();
|
||||
efuncs.loader = null;
|
||||
|
||||
for (String func : functions) {
|
||||
MethodId mid = new MethodId(func, efuncs.loader);
|
||||
efuncs.functions.add(mid);
|
||||
}
|
||||
|
||||
return efuncs;
|
||||
}
|
||||
}
|
||||
@@ -1,226 +0,0 @@
|
||||
package org.ray.hook;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import org.apache.commons.codec.digest.DigestUtils;
|
||||
import org.objectweb.asm.Type;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* Represent a Method in a Class.
|
||||
*/
|
||||
public class MethodId {
|
||||
|
||||
static final String getFunctionIdPostfix = "_function_id";
|
||||
String className;
|
||||
String methodName;
|
||||
String methodDesc;
|
||||
boolean isStatic;
|
||||
ClassLoader loader;
|
||||
|
||||
public MethodId(String cls, String method, String mdesc, boolean isstatic, ClassLoader loader) {
|
||||
className = cls;
|
||||
methodName = method;
|
||||
methodDesc = mdesc;
|
||||
isStatic = isstatic;
|
||||
this.loader = loader;
|
||||
}
|
||||
|
||||
public MethodId(String encodedString, ClassLoader loader) {
|
||||
// className + "." + methodName + "::" + methodDesc + "&&" + isStatic;
|
||||
int lastPos3 = encodedString.lastIndexOf("&&");
|
||||
int lastPos2 = encodedString.lastIndexOf("::");
|
||||
int lastPos1 = encodedString.lastIndexOf(".");
|
||||
if (lastPos1 == -1 || lastPos2 == -1 || lastPos3 == -1) {
|
||||
throw new RuntimeException("invalid given method id " + encodedString
|
||||
+ " - it must be className.methodName::methodDesc&&isStatic");
|
||||
}
|
||||
|
||||
className = encodedString.substring(0, lastPos1);
|
||||
methodName = encodedString.substring(lastPos1 + ".".length(), lastPos2);
|
||||
methodDesc = encodedString.substring(lastPos2 + "::".length(), lastPos3);
|
||||
isStatic = Boolean.parseBoolean(encodedString.substring(lastPos3 + "&&".length()));
|
||||
this.loader = loader;
|
||||
}
|
||||
|
||||
public static String toHexHashString(byte[] id) {
|
||||
String s = "";
|
||||
String hex = "0123456789abcdef";
|
||||
assert (id.length == 20);
|
||||
for (int i = 0; i < 20; i++) {
|
||||
int val = id[i] & 0xff;
|
||||
s += hex.charAt(val >> 4);
|
||||
s += hex.charAt(val & 0xf);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
private String toHexHashString() {
|
||||
byte[] id = this.getSha1Hash();
|
||||
return toHexHashString(id);
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public String getMethodName() {
|
||||
return methodName;
|
||||
}
|
||||
|
||||
public String getMethodDesc() {
|
||||
return methodDesc;
|
||||
}
|
||||
|
||||
public ClassLoader getLoader() {
|
||||
return loader;
|
||||
}
|
||||
|
||||
public Boolean isStaticMethod() {
|
||||
return isStatic;
|
||||
}
|
||||
|
||||
public String getIdMethodName() {
|
||||
return this.methodName + getFunctionIdPostfix;
|
||||
}
|
||||
|
||||
public String getIdMethodDesc() {
|
||||
return "(L" + this.className + ";" + this.methodDesc.substring(1);
|
||||
}
|
||||
|
||||
public Method load() {
|
||||
String loadClsName = className.replace('/', '.');
|
||||
Class<?> cls;
|
||||
try {
|
||||
RayLog.core.debug(
|
||||
"load class " + loadClsName + " from class loader " + (loader == null ? this.getClass()
|
||||
.getClassLoader() : loader)
|
||||
+ " for method " + toString() + " with ID = " + toHexHashString()
|
||||
);
|
||||
cls = Class
|
||||
.forName(loadClsName, true, loader == null ? this.getClass().getClassLoader() : loader);
|
||||
} catch (Throwable e) {
|
||||
RayLog.core.error("Cannot load class " + loadClsName, e);
|
||||
return null;
|
||||
}
|
||||
|
||||
Method[] ms = cls.getDeclaredMethods();
|
||||
ArrayList<Method> methods = new ArrayList<>();
|
||||
Type t = Type.getMethodType(this.methodDesc);
|
||||
Type[] params = t.getArgumentTypes();
|
||||
String rt = t.getReturnType().getDescriptor();
|
||||
|
||||
for (Method m : ms) {
|
||||
if (m.getName().equals(methodName)) {
|
||||
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String mrt = Type.getDescriptor(m.getReturnType());
|
||||
if (!rt.equals(mrt)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
methods.add(m);
|
||||
}
|
||||
}
|
||||
|
||||
if (methods.size() != 1) {
|
||||
RayLog.core.error(
|
||||
"Load method " + toString() + " failed as there are " + methods.size() + " definitions");
|
||||
return null;
|
||||
}
|
||||
|
||||
Method m = methods.get(0);
|
||||
try {
|
||||
Field fld = cls.getField(getStaticHashValueFieldName());
|
||||
Object hashValue = fld.get(null);
|
||||
if (hashValue instanceof byte[] && Arrays.equals((byte[]) hashValue, this.getSha1Hash())) {
|
||||
RayLog.core.debug("Method " + toString() + " hash: " + toHexHashString((byte[]) hashValue));
|
||||
} else {
|
||||
if (hashValue instanceof byte[]) {
|
||||
RayLog.core.error(
|
||||
"Method " + toString() + " hash-field: " + toHexHashString((byte[]) hashValue)
|
||||
+ " vs id-hash: " + toHexHashString());
|
||||
} else {
|
||||
RayLog.core.error(
|
||||
"Method " + toString() + " hash-field: " + (hashValue != null ? hashValue.toString()
|
||||
: "<nil>") + " vs id-hash: " + toHexHashString());
|
||||
}
|
||||
}
|
||||
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException
|
||||
| IllegalAccessException e) {
|
||||
RayLog.core.error("load method hash field failed for " + toString(), e);
|
||||
}
|
||||
return m;
|
||||
}
|
||||
|
||||
public String toEncodingString() {
|
||||
return className + "." + methodName + "::" + methodDesc + "&&" + isStatic;
|
||||
}
|
||||
|
||||
public byte[] getSha1Hash() {
|
||||
byte[] digests = DigestUtils.sha(toEncodingString());
|
||||
ByteBuffer bb = ByteBuffer.wrap(digests);
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
if (methodName.contains("createActorStage1")) {
|
||||
bb.putLong(Long.BYTES, 1);
|
||||
} else {
|
||||
bb.putLong(Long.BYTES, 0);
|
||||
}
|
||||
return digests;
|
||||
}
|
||||
|
||||
public String getStaticHashValueFieldName() {
|
||||
// _hashOf<init>_([Ljava/lang/String;)V
|
||||
String r = "_hashOf" + methodName + "_" + methodDesc;
|
||||
r = r.replace("<", "_")
|
||||
.replace(">", "_")
|
||||
.replace("(", "_")
|
||||
.replace("[", "_")
|
||||
.replace("/", "_")
|
||||
.replace(";", "_")
|
||||
.replace(")", "_")
|
||||
;
|
||||
// System.err.println(r);
|
||||
return r;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + className.hashCode();
|
||||
result = prime * result + methodName.hashCode();
|
||||
result = prime * result + methodDesc.hashCode();
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MethodId other = (MethodId) obj;
|
||||
return className.equals(other.className)
|
||||
&& methodName.equals(other.methodName)
|
||||
&& methodDesc.equals(other.methodDesc)
|
||||
&& isStatic == other.isStatic
|
||||
;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return toEncodingString();
|
||||
}
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
package org.ray.hook.runtime;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Method;
|
||||
import java.net.URL;
|
||||
import java.net.URLClassLoader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Enumeration;
|
||||
import java.util.List;
|
||||
import java.util.jar.JarEntry;
|
||||
import java.util.jar.JarFile;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.filefilter.DirectoryFileFilter;
|
||||
import org.apache.commons.io.filefilter.RegexFileFilter;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* load and unload jars from a dir.
|
||||
*/
|
||||
public class JarLoader {
|
||||
|
||||
private static Method AddUrl = initAddUrl();
|
||||
|
||||
private static Method initAddUrl() {
|
||||
try {
|
||||
Method m = URLClassLoader.class.getDeclaredMethod("addURL", URL.class);
|
||||
m.setAccessible(true);
|
||||
return m;
|
||||
} catch (NoSuchMethodException | SecurityException e) {
|
||||
// TODO Auto-generated catch block
|
||||
e.printStackTrace();
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static URLClassLoader loadJars(String dir, boolean explicitLoadForHook) {
|
||||
// get all jars
|
||||
Collection<File> jars = FileUtils.listFiles(
|
||||
new File(dir),
|
||||
new RegexFileFilter(".*\\.jar"),
|
||||
DirectoryFileFilter.DIRECTORY
|
||||
);
|
||||
return loadJar(jars, explicitLoadForHook);
|
||||
}
|
||||
|
||||
public static ClassLoader loadJars(String[] appJars, boolean explicitLoadForHook) {
|
||||
List<File> jars = new ArrayList<>();
|
||||
|
||||
for (String jar : appJars) {
|
||||
if (jar.endsWith(".jar")) {
|
||||
jars.add(new File(jar));
|
||||
} else {
|
||||
loadJarDir(jar, jars);
|
||||
}
|
||||
}
|
||||
|
||||
return loadJar(jars, explicitLoadForHook);
|
||||
}
|
||||
|
||||
private static URLClassLoader loadJar(Collection<File> appJars, boolean explicitLoadForHook) {
|
||||
List<JarFile> jars = new ArrayList<>();
|
||||
List<URL> urls = new ArrayList<>();
|
||||
|
||||
for (File appJar : appJars) {
|
||||
try {
|
||||
RayLog.core.info("load jar " + appJar.getAbsolutePath() + " in ray hook");
|
||||
JarFile jar = new JarFile(appJar.getAbsolutePath());
|
||||
jars.add(jar);
|
||||
urls.add(appJar.toURI().toURL());
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
System.err.println(
|
||||
"invalid app jar path: " + appJar.getAbsolutePath() + ", load failed with exception "
|
||||
+ e.getMessage());
|
||||
System.exit(1);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
URLClassLoader cl = URLClassLoader.newInstance(urls.toArray(new URL[0]));
|
||||
|
||||
if (!explicitLoadForHook) {
|
||||
return cl;
|
||||
}
|
||||
|
||||
for (JarFile jar : jars) {
|
||||
Enumeration<JarEntry> e = jar.entries();
|
||||
while (e.hasMoreElements()) {
|
||||
JarEntry je = e.nextElement();
|
||||
//System.err.println("check " + je.getName());
|
||||
if (je.isDirectory() || !je.getName().endsWith(".class")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// -6 because of .class
|
||||
String className = je.getName().substring(0, je.getName().length() - 6);
|
||||
className = className.replace('/', '.');
|
||||
try {
|
||||
Class.forName(className, true, cl);
|
||||
//System.err.println("load class " + className + " OK");
|
||||
} catch (ClassNotFoundException e1) {
|
||||
e1.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
jar.close();
|
||||
} catch (IOException e1) {
|
||||
// TODO Auto-generated catch block
|
||||
e1.printStackTrace();
|
||||
}
|
||||
}
|
||||
return cl;
|
||||
}
|
||||
|
||||
private static void loadJarDir(String jarDir, List<File> jars) {
|
||||
Collection<File> files = FileUtils.listFiles(
|
||||
new File(jarDir),
|
||||
new RegexFileFilter(".*\\.jar"),
|
||||
DirectoryFileFilter.DIRECTORY
|
||||
);
|
||||
|
||||
jars.addAll(files);
|
||||
}
|
||||
|
||||
public static void unloadJars(ClassLoader loader) {
|
||||
// TODO:
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package org.ray.hook.runtime;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import org.ray.hook.MethodId;
|
||||
|
||||
public class LoadedFunctions {
|
||||
|
||||
public final Set<MethodId> functions = Collections.synchronizedSet(new HashSet<>());
|
||||
public ClassLoader loader = null;
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package org.ray.hook.runtime;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class MethodHash {
|
||||
|
||||
private final byte[] hash;
|
||||
|
||||
public MethodHash(byte[] hash) {
|
||||
this.hash = hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Arrays.hashCode(getHash());
|
||||
}
|
||||
|
||||
public byte[] getHash() {
|
||||
return hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MethodHash other = (MethodHash) obj;
|
||||
return Arrays.equals(this.getHash(), other.getHash());
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package org.ray.hook.runtime;
|
||||
|
||||
/**
|
||||
* method mode switch at runtime.
|
||||
*/
|
||||
public class MethodSwitcher {
|
||||
|
||||
public static final ThreadLocal<Boolean> IsRemoteCall = new ThreadLocal<>();
|
||||
|
||||
public static final ThreadLocal<byte[]> MethodId = new ThreadLocal<>();
|
||||
|
||||
public static boolean execute(byte[] id) {
|
||||
Boolean hooking = IsRemoteCall.get();
|
||||
if (Boolean.TRUE.equals(hooking)) {
|
||||
MethodId.set(id);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
+96
-3
@@ -5,13 +5,13 @@
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
|
||||
<groupId>org.ray.parent</groupId>
|
||||
<artifactId>ray-superpom</artifactId>
|
||||
<version>1.0</version>
|
||||
<modules>
|
||||
<module>api</module>
|
||||
<module>common</module>
|
||||
<module>hook</module>
|
||||
<module>runtime-common</module>
|
||||
<module>runtime-native</module>
|
||||
<module>runtime-dev</module>
|
||||
@@ -24,6 +24,7 @@
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
@@ -31,7 +32,98 @@
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
<version>0.10.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>de.ruedigermoeller</groupId>
|
||||
<artifactId>fst</artifactId>
|
||||
<version>2.47</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>log4j</groupId>
|
||||
<artifactId>log4j</artifactId>
|
||||
<version>1.2.17</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>quartz</groupId>
|
||||
<artifactId>quartz</artifactId>
|
||||
<version>1.5.2</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ini4j</groupId>
|
||||
<artifactId>ini4j</artifactId>
|
||||
<version>0.5.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
<version>6.0</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
|
||||
<dependency>
|
||||
<groupId>com.github.davidmoten</groupId>
|
||||
<artifactId>flatbuffers-java</artifactId>
|
||||
<version>1.7.0.1</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
|
||||
<dependency>
|
||||
<groupId>redis.clients</groupId>
|
||||
<artifactId>jedis</artifactId>
|
||||
<version>2.8.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
<version>2.5</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-cli</groupId>
|
||||
<artifactId>commons-cli</artifactId>
|
||||
<version>1.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 -->
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>3.4</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-codec</groupId>
|
||||
<artifactId>commons-codec</artifactId>
|
||||
<version>1.4</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>net.lingala.zip4j</groupId>
|
||||
<artifactId>zip4j</artifactId>
|
||||
<version>1.3.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
<version>19.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-collections</groupId>
|
||||
<artifactId>commons-collections</artifactId>
|
||||
<version>3.2.2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
</dependencyManagement>
|
||||
|
||||
<build>
|
||||
@@ -145,7 +237,8 @@
|
||||
<violationSeverity>warning</violationSeverity>
|
||||
<format>xml</format>
|
||||
<format>html</format>
|
||||
<outputFile>${project.build.directory}/test/checkstyle-errors.xml</outputFile>
|
||||
<outputFile>${project.build.directory}/test/checkstyle-errors.xml
|
||||
</outputFile>
|
||||
<linkXRef>false</linkXRef>
|
||||
</configuration>
|
||||
</plugin>
|
||||
@@ -153,4 +246,4 @@
|
||||
</pluginManagement>
|
||||
</build>
|
||||
|
||||
</project>
|
||||
</project>
|
||||
@@ -1,5 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
@@ -24,35 +23,35 @@
|
||||
<artifactId>ray-api</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-hook</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>de.ruedigermoeller</groupId>
|
||||
<artifactId>fst</artifactId>
|
||||
<version>2.47</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/com.github.davidmoten/flatbuffers-java -->
|
||||
<dependency>
|
||||
<groupId>com.github.davidmoten</groupId>
|
||||
<artifactId>flatbuffers-java</artifactId>
|
||||
<version>1.7.0.1</version>
|
||||
</dependency>
|
||||
|
||||
<!-- https://mvnrepository.com/artifact/redis.clients/jedis -->
|
||||
<dependency>
|
||||
<groupId>redis.clients</groupId>
|
||||
<artifactId>jedis</artifactId>
|
||||
<version>2.8.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.arrow</groupId>
|
||||
<artifactId>arrow-plasma</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
</project>
|
||||
</project>
|
||||
@@ -4,6 +4,7 @@ import java.io.Serializable;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -83,7 +84,9 @@ public class ArgumentsBuilder {
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
public static Pair<Object, Object[]> unwrap(TaskSpec task, Method m, ClassLoader classLoader)
|
||||
throws TaskExecutionException {
|
||||
FunctionArg[] fargs = task.args;
|
||||
// the last arg is className
|
||||
|
||||
FunctionArg[] fargs = Arrays.copyOf(task.args, task.args.length - 1);
|
||||
Object current = null;
|
||||
Object[] realArgs;
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ public class InvocationExecutor {
|
||||
}
|
||||
|
||||
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
|
||||
String taskdesc)
|
||||
String taskdesc)
|
||||
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
|
||||
Method m = pr.getRight().invokable;
|
||||
Map<?, UniqueID> userRayReturnIdMap = null;
|
||||
@@ -86,11 +86,12 @@ public class InvocationExecutor {
|
||||
}
|
||||
|
||||
// execute
|
||||
Object result;
|
||||
if (!UniqueIdHelper.isLambdaFunction(task.functionId)) {
|
||||
Object result = null;
|
||||
try {
|
||||
result = m.invoke(realArgs.getLeft(), realArgs.getRight());
|
||||
} else {
|
||||
result = m.invoke(realArgs.getLeft(), new Object[] {realArgs.getRight()});
|
||||
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
|
||||
RayLog.core.error("invoke failed:" + m);
|
||||
throw e;
|
||||
}
|
||||
|
||||
if (task.returnIds == null || task.returnIds.length == 0) {
|
||||
@@ -147,4 +148,4 @@ public class InvocationExecutor {
|
||||
private static void safePut(UniqueID objectId, Object obj) {
|
||||
RayRuntime.getInstance().putRaw(objectId, obj);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,14 @@
|
||||
package org.ray.core;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.MethodId;
|
||||
import org.ray.hook.runtime.JarLoader;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.spi.RemoteFunctionManager;
|
||||
import org.ray.spi.model.FunctionArg;
|
||||
import org.ray.spi.model.RayActorMethods;
|
||||
import org.ray.spi.model.RayMethod;
|
||||
import org.ray.spi.model.RayTaskMethods;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
@@ -18,6 +17,7 @@ import org.ray.util.logger.RayLog;
|
||||
public class LocalFunctionManager {
|
||||
|
||||
private final RemoteFunctionManager remoteLoader;
|
||||
|
||||
private final ConcurrentHashMap<UniqueID, FunctionTable> functionTables
|
||||
= new ConcurrentHashMap<>();
|
||||
|
||||
@@ -29,74 +29,44 @@ public class LocalFunctionManager {
|
||||
this.remoteLoader = remoteLoader;
|
||||
}
|
||||
|
||||
private FunctionTable loadDriverFunctions(UniqueID driverId) {
|
||||
FunctionTable functionTable = functionTables.get(driverId);
|
||||
if (functionTable == null) {
|
||||
RayLog.core.info("DriverId " + driverId + " Try to load functions");
|
||||
ClassLoader classLoader = remoteLoader.loadResource(driverId);
|
||||
if (classLoader == null) {
|
||||
throw new RuntimeException(
|
||||
"Cannot find resource' classLoader for app " + driverId.toString());
|
||||
}
|
||||
functionTable = new FunctionTable(classLoader);
|
||||
functionTables.put(driverId, functionTable);
|
||||
}
|
||||
return functionTable;
|
||||
}
|
||||
|
||||
Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
|
||||
UniqueID methodId, String className) {
|
||||
// assert the driver's resource is load.
|
||||
FunctionTable functionTable = loadDriverFunctions(driverId);
|
||||
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
|
||||
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
|
||||
: functionTable.getActorMethod(methodId, className);
|
||||
Preconditions
|
||||
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
|
||||
methodId, driverId);
|
||||
return Pair.of(functionTable.classLoader, method);
|
||||
}
|
||||
|
||||
/**
|
||||
* get local method for executing, which pulls information from remote repo on-demand, therefore
|
||||
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
|
||||
*/
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID methodId,
|
||||
FunctionArg[] args) throws NoSuchMethodException,
|
||||
SecurityException, ClassNotFoundException {
|
||||
FunctionTable funcs = loadDriverFunctions(driverId);
|
||||
RayMethod m;
|
||||
|
||||
// hooked methods
|
||||
if (!UniqueIdHelper.isLambdaFunction(methodId)) {
|
||||
m = funcs.functions.get(methodId);
|
||||
if (null == m) {
|
||||
throw new RuntimeException(
|
||||
"DriverId " + driverId + " load remote function methodId:" + methodId + " failed");
|
||||
}
|
||||
} else { // remote lambda
|
||||
assert args.length >= 2;
|
||||
String fname = Serializer.decode(args[args.length - 2].data);
|
||||
Method fm = Class.forName(fname).getMethod("execute", Object[].class);
|
||||
m = new RayMethod(fm);
|
||||
}
|
||||
|
||||
return Pair.of(funcs.linkedFunctions.loader, m);
|
||||
}
|
||||
|
||||
private synchronized FunctionTable loadDriverFunctions(UniqueID driverId) {
|
||||
FunctionTable funcs = functionTables.get(driverId);
|
||||
if (null == funcs) {
|
||||
RayLog.core.debug("DriverId " + driverId + " Try to load functions");
|
||||
LoadedFunctions funcs2 = remoteLoader.loadFunctions(driverId);
|
||||
if (funcs2 == null) {
|
||||
throw new RuntimeException("Cannot find resource for app " + driverId.toString());
|
||||
}
|
||||
funcs = new FunctionTable();
|
||||
funcs.linkedFunctions = funcs2;
|
||||
for (MethodId mid : funcs.linkedFunctions.functions) {
|
||||
Method m = mid.load();
|
||||
assert (m != null);
|
||||
RayMethod v = new RayMethod(m);
|
||||
v.check();
|
||||
UniqueID k = new UniqueID(mid.getSha1Hash());
|
||||
String logInfo =
|
||||
"DriverId" + driverId + " load remote function " + m.getName() + ", hash = " + k
|
||||
.toString();
|
||||
RayLog.core.debug(logInfo);
|
||||
System.err.println(logInfo);
|
||||
funcs.functions.put(k, v);
|
||||
}
|
||||
|
||||
functionTables.put(driverId, funcs);
|
||||
} else { // reSync automatically
|
||||
// more functions are loaded
|
||||
if (funcs.linkedFunctions.functions.size() > funcs.functions.size()) {
|
||||
for (MethodId mid : funcs.linkedFunctions.functions) {
|
||||
UniqueID k = new UniqueID(mid.getSha1Hash());
|
||||
if (!funcs.functions.containsKey(k)) {
|
||||
Method m = mid.load();
|
||||
assert (m != null);
|
||||
RayMethod v = new RayMethod(m);
|
||||
v.check();
|
||||
funcs.functions.put(k, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return funcs;
|
||||
public Pair<ClassLoader, RayMethod> getMethod(UniqueID driverId, UniqueID actorId,
|
||||
UniqueID methodId,
|
||||
FunctionArg[] args) throws NoSuchMethodException, SecurityException, ClassNotFoundException {
|
||||
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
|
||||
String className = (String) Serializer.decode(args[args.length - 1].data);
|
||||
return getMethod(driverId, actorId, methodId, className);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -106,13 +76,47 @@ public class LocalFunctionManager {
|
||||
FunctionTable funcs = functionTables.get(driverId);
|
||||
if (funcs != null) {
|
||||
functionTables.remove(driverId);
|
||||
JarLoader.unloadJars(funcs.linkedFunctions.loader);
|
||||
remoteLoader.unloadFunctions(driverId);
|
||||
}
|
||||
}
|
||||
|
||||
static class FunctionTable {
|
||||
private static class FunctionTable {
|
||||
|
||||
public final ConcurrentHashMap<UniqueID, RayMethod> functions = new ConcurrentHashMap<>();
|
||||
public LoadedFunctions linkedFunctions;
|
||||
final ClassLoader classLoader;
|
||||
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
|
||||
final ConcurrentHashMap<String, RayActorMethods> actors = new ConcurrentHashMap<>();
|
||||
|
||||
FunctionTable(ClassLoader classLoader) {
|
||||
this.classLoader = classLoader;
|
||||
}
|
||||
|
||||
RayMethod getTaskMethod(UniqueID methodId, String className) {
|
||||
RayTaskMethods tasks = taskMethods.get(className);
|
||||
if (tasks == null) {
|
||||
tasks = RayTaskMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayTaskMethods:" + tasks);
|
||||
taskMethods.put(className, tasks);
|
||||
}
|
||||
RayMethod m = tasks.functions.get(methodId);
|
||||
if (m != null) {
|
||||
return m;
|
||||
}
|
||||
// it is a actor static func.
|
||||
return getActorMethod(methodId, className, true);
|
||||
}
|
||||
|
||||
RayMethod getActorMethod(UniqueID methodId, String className) {
|
||||
return getActorMethod(methodId, className, false);
|
||||
}
|
||||
|
||||
private RayMethod getActorMethod(UniqueID methodId, String className, boolean isStatic) {
|
||||
RayActorMethods actor = actors.get(className);
|
||||
if (actor == null) {
|
||||
actor = RayActorMethods.fromClass(className, classLoader);
|
||||
RayLog.core.info("create RayActorMethods:" + actor);
|
||||
actors.put(className, actor);
|
||||
}
|
||||
return isStatic ? actor.staticFunctions.get(methodId) : actor.functions.get(methodId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.RayObjects;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.internal.Callable;
|
||||
import org.ray.api.internal.RayFunc;
|
||||
import org.ray.core.model.RayParameters;
|
||||
import org.ray.spi.LocalSchedulerLink;
|
||||
import org.ray.spi.LocalSchedulerProxy;
|
||||
@@ -261,49 +261,23 @@ public abstract class RayRuntime implements RayApi {
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObjects call(UniqueID taskId, Callable funcRun, int returnCount, Object... args) {
|
||||
return worker.rpc(taskId, funcRun, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObjects call(UniqueID taskId, Class<?> funcCls, Serializable lambda, int returnCount,
|
||||
Object... args) {
|
||||
return worker.rpc(taskId, UniqueID.nil, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Callable funcRun,
|
||||
Collection<RIDT> returnIds,
|
||||
Object... args) {
|
||||
return worker.rpcWithReturnLabels(taskId, funcRun, returnIds, args);
|
||||
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
|
||||
Object... args) {
|
||||
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Collection<RIDT>
|
||||
returnids,
|
||||
Object... args) {
|
||||
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
|
||||
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Callable funcRun,
|
||||
Integer returnCount, Object... args) {
|
||||
return worker.rpcWithReturnIndices(taskId, funcRun, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Integer returnCount, Object...
|
||||
args) {
|
||||
RayFunc lambda, Integer returnCount, Object... args) {
|
||||
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRemoteLambda() {
|
||||
return params.run_mode.isRemoteLambda();
|
||||
}
|
||||
|
||||
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
|
||||
throws TaskExecutionException {
|
||||
boolean wasBlocked = false;
|
||||
@@ -465,4 +439,4 @@ public abstract class RayRuntime implements RayApi {
|
||||
public RemoteFunctionManager getRemoteFunctionManager() {
|
||||
return remoteFunctionManager;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,12 +249,6 @@ public class UniqueIdHelper {
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public static boolean isLambdaFunction(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
return wbb.getLong() == 0xffffffffffffffffL;
|
||||
}
|
||||
|
||||
public static void markCreateActorStage1Function(UniqueID functionId) {
|
||||
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
@@ -279,4 +273,4 @@ public class UniqueIdHelper {
|
||||
TASK,
|
||||
ACTOR,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package org.ray.core;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Serializable;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.SerializationUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayList;
|
||||
@@ -12,12 +13,13 @@ import org.ray.api.RayMap;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayObjects;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.api.internal.Callable;
|
||||
import org.ray.hook.runtime.MethodSwitcher;
|
||||
import org.ray.api.internal.RayFunc;
|
||||
import org.ray.spi.LocalSchedulerProxy;
|
||||
import org.ray.spi.model.RayInvocation;
|
||||
import org.ray.spi.model.RayMethod;
|
||||
import org.ray.spi.model.TaskSpec;
|
||||
import org.ray.util.LambdaUtils;
|
||||
import org.ray.util.MethodId;
|
||||
import org.ray.util.exception.TaskExecutionException;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
@@ -47,13 +49,13 @@ public class Worker {
|
||||
RayLog.core.info("Task " + task.taskId + " start execute");
|
||||
Throwable ex = null;
|
||||
|
||||
|
||||
if (!task.actorId.isNil() || (task.createActorId != null && !task.createActorId.isNil())) {
|
||||
task.returnIds = ArrayUtils.subarray(task.returnIds, 0, task.returnIds.length - 1);
|
||||
}
|
||||
|
||||
try {
|
||||
Pair<ClassLoader, RayMethod> pr = funcs.getMethod(task.driverId, task.functionId, task.args);
|
||||
Pair<ClassLoader, RayMethod> pr = funcs
|
||||
.getMethod(task.driverId, task.actorId, task.functionId, task.args);
|
||||
WorkerContext.prepare(task, pr.getLeft());
|
||||
InvocationExecutor.execute(task, pr);
|
||||
} catch (NoSuchMethodException | SecurityException | ClassNotFoundException e) {
|
||||
@@ -72,161 +74,75 @@ public class Worker {
|
||||
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, Callable funcRun, int returnCount, Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
return submit(taskId, fid, returnCount, false, args);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, UniqueID functionId, Class<?> funcCls, Serializable lambda,
|
||||
int returnCount, Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return submit(taskId, fid, returnCount, false, ls);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, RayActor<?> actor, Callable funcRun, int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
return actorTaskSubmit(taskId, fid, returnCount, false, args, actor);
|
||||
}
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, UniqueID functionId, RayActor<?> actor, Class<?> funcCls,
|
||||
Serializable lambda, int returnCount, Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return actorTaskSubmit(taskId, fid, returnCount, false, ls, actor);
|
||||
}
|
||||
|
||||
private byte[] fidFromHook(Callable funcRun) {
|
||||
MethodSwitcher.IsRemoteCall.set(true);
|
||||
try {
|
||||
funcRun.run();
|
||||
} catch (Throwable e) {
|
||||
RayLog.core.error(
|
||||
"make sure you are using code rewritten using the rewrite tool, see JarRewriter for"
|
||||
+ " options", e);
|
||||
throw new RuntimeException("make sure you are using code rewritten using the rewrite tool,"
|
||||
+ "see JarRewriter for options");
|
||||
}
|
||||
byte[] fid = MethodSwitcher.MethodId.get();//get the identity of function from hook
|
||||
MethodSwitcher.IsRemoteCall.set(false);
|
||||
return fid;
|
||||
}
|
||||
|
||||
private RayObjects submit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
|
||||
return actorTaskSubmit(taskId, fid, returnCount, multiReturn, args, (RayActor<?>) args[0]);
|
||||
} else {
|
||||
return taskSubmit(taskId, fid, returnCount, multiReturn, args);
|
||||
}
|
||||
private RayObjects taskSubmit(UniqueID taskId,
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, ri, returnCount, multiReturn);
|
||||
}
|
||||
|
||||
private RayObjects actorTaskSubmit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args,
|
||||
RayActor<?> actor) {
|
||||
RayInvocation ri = new RayInvocation(fid, args, actor);
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args,
|
||||
RayActor<?> actor) {
|
||||
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
|
||||
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
|
||||
actor.setTaskCursor(returnObjs.pop().getId());
|
||||
return returnObjs;
|
||||
}
|
||||
|
||||
private RayObjects taskSubmit(UniqueID taskId,
|
||||
byte[] fid,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
RayInvocation ri = new RayInvocation(fid, args);
|
||||
return scheduler.submit(taskId, ri, returnCount, multiReturn);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, Callable funcRun,
|
||||
int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
RayInvocation ri = new RayInvocation(fid, new Object[] {});
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId, UniqueID functionId,
|
||||
Class<?> funcCls, Serializable lambda, int returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = functionId.getBytes();
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
RayInvocation ri = new RayInvocation(fid, ls);
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Callable funcRun,
|
||||
Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
private RayObjects submit(UniqueID taskId,
|
||||
MethodId methodId,
|
||||
int returnCount,
|
||||
boolean multiReturn,
|
||||
Object[] args) {
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
return scheduler.submit(taskId, new RayInvocation(fid, args), returnids);
|
||||
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
|
||||
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
|
||||
(RayActor<?>) args[0]);
|
||||
} else {
|
||||
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
|
||||
int returnCount, Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
return submit(taskId, mid, returnCount, false, args);
|
||||
}
|
||||
|
||||
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
|
||||
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
|
||||
Preconditions.checkNotNull(taskId);
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
|
||||
}
|
||||
|
||||
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda,
|
||||
Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
final byte[] fid = UniqueID.nil.getBytes();
|
||||
RayFunc lambda, Collection<RIDT> returnids,
|
||||
Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
if (taskId == null) {
|
||||
taskId = UniqueIdHelper.nextTaskId(-1);
|
||||
}
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
return scheduler.submit(taskId, new RayInvocation(fid, ls), returnids);
|
||||
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
|
||||
return scheduler.submit(taskId, ri, returnids);
|
||||
}
|
||||
|
||||
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Callable funcRun,
|
||||
Integer returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = fidFromHook(funcRun);
|
||||
RayObjects objs = submit(taskId, fid, returnCount, true, args);
|
||||
RayList<R> rets = new RayList<>();
|
||||
for (RayObject obj : objs.getObjs()) {
|
||||
rets.add(obj);
|
||||
}
|
||||
return rets;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
|
||||
Serializable lambda, Integer returnCount,
|
||||
Object[] args) {
|
||||
byte[] fid = UniqueID.nil.getBytes();
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 2);
|
||||
ls[args.length] = funcCls.getName();
|
||||
ls[args.length + 1] = SerializationUtils.serialize(lambda);
|
||||
|
||||
RayObjects objs = submit(taskId, fid, returnCount, true, ls);
|
||||
|
||||
RayFunc lambda, Integer returnCount,
|
||||
Object[] args) {
|
||||
MethodId mid = methodIdOf(lambda);
|
||||
RayObjects objs = submit(taskId, mid, returnCount, true, args);
|
||||
RayList<R> rets = new RayList<>();
|
||||
for (RayObject obj : objs.getObjs()) {
|
||||
rets.add(obj);
|
||||
@@ -234,6 +150,27 @@ public class Worker {
|
||||
return rets;
|
||||
}
|
||||
|
||||
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
|
||||
RayActor<?> actor) {
|
||||
UniqueID driverId = WorkerContext.currentTask().driverId;
|
||||
|
||||
Object[] ls = Arrays.copyOf(args, args.length + 1);
|
||||
ls[args.length] = methodId.className;
|
||||
|
||||
RayMethod method = functions
|
||||
.getMethod(driverId, actor.getId(), new UniqueID(methodId.getSha1Hash()),
|
||||
methodId.className).getRight();
|
||||
|
||||
RayInvocation ri = new RayInvocation(methodId.className, method.getFuncId(),
|
||||
ls, method.remoteAnnotation, actor);
|
||||
return ri;
|
||||
}
|
||||
|
||||
private MethodId methodIdOf(RayFunc serialLambda) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
|
||||
return mid;
|
||||
}
|
||||
|
||||
public UniqueID getCurrentTaskId() {
|
||||
return WorkerContext.currentTask().taskId;
|
||||
}
|
||||
@@ -246,4 +183,4 @@ public class Worker {
|
||||
public UniqueID[] getCurrentTaskReturnIDs() {
|
||||
return WorkerContext.currentTask().returnIds;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,40 +1,23 @@
|
||||
package org.ray.core.model;
|
||||
|
||||
public enum RunMode {
|
||||
SINGLE_PROCESS(true, false, true, false), // remote lambda, dev path, dev runtime
|
||||
SINGLE_BOX(true, false, true, true), // remote lambda, dev path, native runtime
|
||||
CLUSTER(false, true, false, true); // static rewrite, deploy path, naive runtime
|
||||
SINGLE_PROCESS(true, false), // dev path, dev runtime
|
||||
SINGLE_BOX(true, true), // dev path, native runtime
|
||||
CLUSTER(false, true); // deploy path, naive runtime
|
||||
|
||||
private final boolean remoteLambda;
|
||||
private final boolean staticRewrite;
|
||||
private final boolean devPathManager;
|
||||
private final boolean nativeRuntime;
|
||||
|
||||
RunMode(boolean remoteLambda, boolean staticRewrite, boolean devPathManager,
|
||||
boolean nativeRuntime) {
|
||||
this.remoteLambda = remoteLambda;
|
||||
this.staticRewrite = staticRewrite;
|
||||
RunMode(boolean devPathManager,
|
||||
boolean nativeRuntime) {
|
||||
this.devPathManager = devPathManager;
|
||||
this.nativeRuntime = nativeRuntime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>remoteLambda</tt>.
|
||||
*
|
||||
* @return property value of remoteLambda
|
||||
* the jar has add to java -cp, no need to load jar after started.
|
||||
*/
|
||||
public boolean isRemoteLambda() {
|
||||
return remoteLambda;
|
||||
}
|
||||
private final boolean devPathManager;
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>staticRewrite</tt>.
|
||||
*
|
||||
* @return property value of staticRewrite
|
||||
*/
|
||||
public boolean isStaticRewrite() {
|
||||
return staticRewrite;
|
||||
}
|
||||
private final boolean nativeRuntime;
|
||||
|
||||
/**
|
||||
* Getter method for property <tt>devPathManager</tt>.
|
||||
@@ -53,4 +36,4 @@ public enum RunMode {
|
||||
public boolean isNativeRuntime() {
|
||||
return nativeRuntime;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ public class LocalSchedulerProxy {
|
||||
|
||||
task.args = ArgumentsBuilder.wrap(invocation);
|
||||
task.driverId = current.driverId;
|
||||
task.functionId = new UniqueID(invocation.getId());
|
||||
task.functionId = invocation.getId();
|
||||
task.parentCounter = -1; // TODO: this field is not used in core or python logically yet
|
||||
task.parentTaskId = current.taskId;
|
||||
task.actorHandleId = invocation.getActor().getActorHandleId();
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
package org.ray.spi;
|
||||
|
||||
import java.util.Set;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.MethodId;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* mock version of remote function manager using local loaded jars + runtime hook.
|
||||
* mock version of remote function manager using local loaded jars.
|
||||
*/
|
||||
public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
private final LoadedFunctions loadedFunctions = new LoadedFunctions();
|
||||
|
||||
public NopRemoteFunctionManager(UniqueID driverId) {
|
||||
//onLoad(driverId, Agent.hookedMethods);
|
||||
//Agent.consumers.add(m -> { this.onLoad(m); });
|
||||
@@ -51,14 +45,9 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
}
|
||||
|
||||
@Override
|
||||
public LoadedFunctions loadFunctions(UniqueID driverId) {
|
||||
public ClassLoader loadResource(UniqueID driverId) {
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
if (loadedFunctions == null) {
|
||||
RayLog.rapp.error("cannot find functions for " + driverId);
|
||||
return null;
|
||||
} else {
|
||||
return loadedFunctions;
|
||||
}
|
||||
return this.getClass().getClassLoader();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -66,15 +55,4 @@ public class NopRemoteFunctionManager implements RemoteFunctionManager {
|
||||
// never
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
}
|
||||
|
||||
private void onLoad(UniqueID driverId, Set<MethodId> methods) {
|
||||
//assert (startupDriverId().equals(driverId));
|
||||
for (MethodId mid : methods) {
|
||||
onLoad(mid);
|
||||
}
|
||||
}
|
||||
|
||||
private void onLoad(MethodId mid) {
|
||||
loadedFunctions.functions.add(mid);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package org.ray.spi;
|
||||
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
|
||||
/**
|
||||
* register and load functions from function table.
|
||||
@@ -53,14 +52,13 @@ public interface RemoteFunctionManager {
|
||||
void unregisterApp(UniqueID driverId);
|
||||
|
||||
/**
|
||||
* load resource and functions for this driver this function is used by the workers on demand when
|
||||
* a required function is not found in {@code LocalFunctionManager}.
|
||||
* load resource.
|
||||
*/
|
||||
LoadedFunctions loadFunctions(UniqueID driverId);
|
||||
ClassLoader loadResource(UniqueID driverId);
|
||||
|
||||
/**
|
||||
* unload functions for this driver
|
||||
* this function is used by the workers on demand when a driver is dead.
|
||||
*/
|
||||
void unloadFunctions(UniqueID driverId);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
|
||||
public final class RayActorMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final RayRemote remoteAnnotation;
|
||||
public final Map<UniqueID, RayMethod> functions;
|
||||
/**
|
||||
* the static function in Actor, call as task.
|
||||
*/
|
||||
public final Map<UniqueID, RayMethod> staticFunctions;
|
||||
|
||||
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
|
||||
Map<UniqueID, RayMethod> functions, Map<UniqueID, RayMethod> staticFunctions) {
|
||||
this.clazz = clazz;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
|
||||
}
|
||||
|
||||
public static RayActorMethods fromClass(String clazzName, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(clazzName, true, classLoader);
|
||||
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
|
||||
Preconditions
|
||||
.checkNotNull(remoteAnnotation, "%s must declare @RayRemote", clazzName);
|
||||
Method[] methods = clazz.getDeclaredMethods();
|
||||
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
|
||||
Map<UniqueID, RayMethod> staticFunctions = new HashMap<>(methods.length * 2);
|
||||
|
||||
for (Method m : methods) {
|
||||
if (!Modifier.isPublic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
RayMethod rayMethod = RayMethod.from(m, remoteAnnotation);
|
||||
if (Modifier.isStatic(m.getModifiers())) {
|
||||
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
|
||||
} else {
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
}
|
||||
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayActorMethods from " + clazzName, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
|
||||
functions.values(),
|
||||
staticFunctions.size(), staticFunctions.values());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
/**
|
||||
@@ -9,30 +10,30 @@ import org.ray.api.UniqueID;
|
||||
public class RayInvocation {
|
||||
|
||||
private static final RayActor<?> nil = new RayActor<>(UniqueID.nil, UniqueID.nil);
|
||||
public final String className;
|
||||
/**
|
||||
* unique id for a method.
|
||||
*
|
||||
* @see UniqueID
|
||||
*/
|
||||
private final byte[] id;
|
||||
private final RayActor<?> actor;
|
||||
private final UniqueID id;
|
||||
private final RayRemote remoteAnnotation;
|
||||
/**
|
||||
* function arguments.
|
||||
*/
|
||||
private Object[] args;
|
||||
|
||||
public RayInvocation(byte[] id, Object[] args) {
|
||||
this(id, args, nil);
|
||||
}
|
||||
|
||||
public RayInvocation(byte[] id, Object[] args, RayActor<?> actor) {
|
||||
super();
|
||||
private RayActor<?> actor;
|
||||
|
||||
public RayInvocation(String className, UniqueID id, Object[] args, RayRemote remoteAnnotation,
|
||||
RayActor<?> actor) {
|
||||
this.className = className;
|
||||
this.id = id;
|
||||
this.args = args;
|
||||
this.actor = actor;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
}
|
||||
|
||||
public byte[] getId() {
|
||||
public UniqueID getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -48,4 +49,8 @@ public class RayInvocation {
|
||||
return actor;
|
||||
}
|
||||
|
||||
}
|
||||
public RayRemote getRemoteAnnotation() {
|
||||
return remoteAnnotation;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.util.MethodId;
|
||||
|
||||
/**
|
||||
* method info.
|
||||
@@ -9,12 +12,8 @@ public class RayMethod {
|
||||
|
||||
public final Method invokable;
|
||||
public final String fullName;
|
||||
// TODO: other annotated information
|
||||
|
||||
public RayMethod(Method m) {
|
||||
invokable = m;
|
||||
fullName = m.getDeclaringClass().getName() + "." + m.getName();
|
||||
}
|
||||
public final RayRemote remoteAnnotation;
|
||||
private final UniqueID funcId;
|
||||
|
||||
public void check() {
|
||||
for (Class<?> paramCls : invokable.getParameterTypes()) {
|
||||
@@ -24,4 +23,32 @@ public class RayMethod {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private RayMethod(Method m, RayRemote remoteAnnotation, UniqueID funcId) {
|
||||
this.invokable = m;
|
||||
this.remoteAnnotation = remoteAnnotation;
|
||||
this.funcId = funcId;
|
||||
fullName = m.getDeclaringClass().getName() + "." + m.getName();
|
||||
}
|
||||
|
||||
public static RayMethod from(Method m, RayRemote parentRemoteAnnotation) {
|
||||
Class<?> clazz = m.getDeclaringClass();
|
||||
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
|
||||
MethodId mid = MethodId.fromMethod(m);
|
||||
UniqueID funcId = new UniqueID(mid.getSha1Hash());
|
||||
RayMethod method = new RayMethod(m,
|
||||
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
|
||||
funcId);
|
||||
method.check();
|
||||
return method;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return fullName;
|
||||
}
|
||||
|
||||
public UniqueID getFuncId() {
|
||||
return funcId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ray.spi.model;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayRemote;
|
||||
import org.ray.api.UniqueID;
|
||||
|
||||
|
||||
public final class RayTaskMethods {
|
||||
|
||||
public final Class clazz;
|
||||
public final Map<UniqueID, RayMethod> functions;
|
||||
|
||||
public RayTaskMethods(Class clazz,
|
||||
Map<UniqueID, RayMethod> functions) {
|
||||
this.clazz = clazz;
|
||||
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
|
||||
}
|
||||
|
||||
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
|
||||
try {
|
||||
Class clazz = Class.forName(clazzName, true, classLoader);
|
||||
Method[] methods = clazz.getDeclaredMethods();
|
||||
Map<UniqueID, RayMethod> functions = new HashMap<>(methods.length * 2);
|
||||
|
||||
for (Method m : methods) {
|
||||
if (!Modifier.isStatic(m.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
//task method only for static.
|
||||
RayRemote remoteAnnotation = m.getAnnotation(RayRemote.class);
|
||||
if (remoteAnnotation == null) {
|
||||
continue;
|
||||
}
|
||||
m.setAccessible(true);
|
||||
RayMethod rayMethod = RayMethod.from(m, null);
|
||||
functions.put(rayMethod.getFuncId(), rayMethod);
|
||||
}
|
||||
return new RayTaskMethods(clazz, functions);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String
|
||||
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -74,9 +74,8 @@ public class RayNativeRuntime extends RayRuntime {
|
||||
}
|
||||
|
||||
// initialize remote function manager
|
||||
RemoteFunctionManager funcMgr = params.run_mode.isStaticRewrite()
|
||||
? new NativeRemoteFunctionManager(kvStore) :
|
||||
new NopRemoteFunctionManager(params.driver_id);
|
||||
RemoteFunctionManager funcMgr = params.run_mode.isDevPathManager()
|
||||
? new NopRemoteFunctionManager(params.driver_id) : new NativeRemoteFunctionManager(kvStore);
|
||||
|
||||
// initialize worker context
|
||||
if (params.worker_mode == WorkerMode.DRIVER) {
|
||||
@@ -101,6 +100,7 @@ public class RayNativeRuntime extends RayRuntime {
|
||||
int releaseDelay = RayRuntime.configReader
|
||||
.getIntegerValue("ray", "plasma_default_release_delay", 0,
|
||||
"how many release requests should be delayed in plasma client");
|
||||
|
||||
ObjectStoreLink plink = new PlasmaClient(params.object_store_name, params
|
||||
.object_store_manager_name, releaseDelay);
|
||||
|
||||
@@ -163,7 +163,7 @@ public class RayNativeRuntime extends RayRuntime {
|
||||
}
|
||||
|
||||
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
|
||||
String managerName, String schedulerName) {
|
||||
String managerName, String schedulerName) {
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
|
||||
if (!isWorker) {
|
||||
@@ -193,26 +193,16 @@ public class RayNativeRuntime extends RayRuntime {
|
||||
UniqueID actorId = UniqueIdHelper.taskComputeReturnId(createTaskId, 0, false);
|
||||
RayActor<T> actor = new RayActor<>(actorId);
|
||||
UniqueID cursorId;
|
||||
if (params.run_mode.isRemoteLambda()) {
|
||||
RayFunc_2_1<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
|
||||
cursorId = worker.rpcCreateActor(
|
||||
createTaskId,
|
||||
actorId,
|
||||
UniqueID.nil,
|
||||
RayFunc_2_1.class,
|
||||
createActorLambda,
|
||||
1,
|
||||
new Object[] {actorId.getBytes(), cls.getName()}
|
||||
).getObjs()[0].getId();
|
||||
} else {
|
||||
cursorId = worker.rpcCreateActor(
|
||||
createTaskId,
|
||||
actorId,
|
||||
() -> RayNativeRuntime.createActorInActor(null, null),
|
||||
1,
|
||||
new Object[] {actorId.getBytes(), cls.getName()}
|
||||
).getObjs()[0].getId();
|
||||
}
|
||||
|
||||
RayFunc_2_1<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
|
||||
cursorId = worker.rpcCreateActor(
|
||||
createTaskId,
|
||||
actorId,
|
||||
RayFunc_2_1.class,
|
||||
createActorLambda,
|
||||
1,
|
||||
new Object[]{actorId.getBytes(), cls.getName()}
|
||||
).getObjs()[0].getId();
|
||||
actor.setTaskCursor(cursorId);
|
||||
return actor;
|
||||
}
|
||||
@@ -247,4 +237,4 @@ public class RayNativeRuntime extends RayRuntime {
|
||||
throw new TaskExecutionException(log, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package org.ray.spi.impl;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.URL;
|
||||
import java.net.URLClassLoader;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Enumeration;
|
||||
import java.util.List;
|
||||
import java.util.jar.JarEntry;
|
||||
import java.util.jar.JarFile;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.io.filefilter.DirectoryFileFilter;
|
||||
import org.apache.commons.io.filefilter.RegexFileFilter;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
/**
|
||||
* load and unload jars from a dir.
|
||||
*/
|
||||
public class JarLoader {
|
||||
|
||||
public static URLClassLoader loadJars(String dir, boolean explicitLoad) {
|
||||
// get all jars
|
||||
Collection<File> jars = FileUtils.listFiles(
|
||||
new File(dir),
|
||||
new RegexFileFilter(".*\\.jar"),
|
||||
DirectoryFileFilter.DIRECTORY
|
||||
);
|
||||
return loadJar(jars, explicitLoad);
|
||||
}
|
||||
|
||||
public static void unloadJars(ClassLoader loader) {
|
||||
// now do nothing, if no ref to the loader and loader's class.
|
||||
// they would be gc.
|
||||
}
|
||||
|
||||
private static URLClassLoader loadJar(Collection<File> appJars, boolean explicitLoad) {
|
||||
List<JarFile> jars = new ArrayList<>();
|
||||
List<URL> urls = new ArrayList<>();
|
||||
|
||||
for (File appJar : appJars) {
|
||||
try {
|
||||
RayLog.core.info("load jar " + appJar.getAbsolutePath());
|
||||
JarFile jar = new JarFile(appJar.getAbsolutePath());
|
||||
jars.add(jar);
|
||||
urls.add(appJar.toURI().toURL());
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(
|
||||
"invalid app jar path: " + appJar.getAbsolutePath() + ", load failed with exception",
|
||||
e);
|
||||
}
|
||||
}
|
||||
|
||||
URLClassLoader cl = URLClassLoader.newInstance(urls.toArray(new URL[urls.size()]));
|
||||
|
||||
if (!explicitLoad) {
|
||||
return cl;
|
||||
}
|
||||
for (JarFile jar : jars) {
|
||||
try {
|
||||
Enumeration<JarEntry> e = jar.entries();
|
||||
while (e.hasMoreElements()) {
|
||||
JarEntry je = e.nextElement();
|
||||
if (je.isDirectory() || !je.getName().endsWith(".class")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
String className = classNameOfJarEntry(je);
|
||||
className = className.replace('/', '.');
|
||||
try {
|
||||
Class.forName(className, true, cl);
|
||||
} catch (ClassNotFoundException e1) {
|
||||
e1.printStackTrace();
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
IOUtils.closeQuietly(jar);
|
||||
}
|
||||
}
|
||||
return cl;
|
||||
}
|
||||
|
||||
private static String classNameOfJarEntry(JarEntry je) {
|
||||
return je.getName().substring(0, je.getName().length() - ".class".length());
|
||||
}
|
||||
|
||||
}
|
||||
+28
-32
@@ -6,13 +6,10 @@ import java.security.NoSuchAlgorithmException;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import net.lingala.zip4j.core.ZipFile;
|
||||
import org.ray.api.UniqueID;
|
||||
import org.ray.core.RayRuntime;
|
||||
import org.ray.hook.JarRewriter;
|
||||
import org.ray.hook.runtime.JarLoader;
|
||||
import org.ray.hook.runtime.LoadedFunctions;
|
||||
import org.ray.spi.KeyValueStoreLink;
|
||||
import org.ray.spi.RemoteFunctionManager;
|
||||
import org.ray.util.FileUtil;
|
||||
import org.ray.util.Sha1Digestor;
|
||||
import org.ray.util.SystemUtil;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
@@ -21,10 +18,11 @@ import org.ray.util.logger.RayLog;
|
||||
*/
|
||||
public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
private ConcurrentHashMap<UniqueID, LoadedFunctions> loadedApps = new ConcurrentHashMap<>();
|
||||
private final ConcurrentHashMap<UniqueID, ClassLoader> loadedApps = new ConcurrentHashMap<>();
|
||||
private MessageDigest md;
|
||||
private String appDir = System.getProperty("user.dir") + "/apps";
|
||||
private KeyValueStoreLink kvStore;
|
||||
private final String appDir = System.getProperty("user.dir") + "/apps";
|
||||
private final KeyValueStoreLink kvStore;
|
||||
|
||||
|
||||
public NativeRemoteFunctionManager(KeyValueStoreLink kvStore) throws NoSuchAlgorithmException {
|
||||
this.kvStore = kvStore;
|
||||
@@ -38,24 +36,20 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
@Override
|
||||
public UniqueID registerResource(byte[] resourceZip) {
|
||||
byte[] digest = md.digest(resourceZip);
|
||||
byte[] digest = Sha1Digestor.digest(resourceZip);
|
||||
assert (digest.length == UniqueID.LENGTH);
|
||||
|
||||
UniqueID resourceId = new UniqueID(digest);
|
||||
|
||||
// TODO: resources must be saved in persistent store
|
||||
// instead of cache
|
||||
//if (!Ray.exist(resourceId)) {
|
||||
//Ray.put(resourceId, resourceZip);
|
||||
kvStore.set(resourceId.getBytes(), resourceZip, null);
|
||||
//}
|
||||
|
||||
return resourceId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getResource(UniqueID resourceId) {
|
||||
return kvStore.get(resourceId.getBytes(), null);
|
||||
//return (byte[])Ray.get(resourceId);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -65,7 +59,6 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
@Override
|
||||
public void registerApp(UniqueID driverId, UniqueID resourceId) {
|
||||
//Ray.put(driverId, resourceId);
|
||||
kvStore.set("App2ResMap", resourceId.toString(), driverId.toString());
|
||||
}
|
||||
|
||||
@@ -80,27 +73,31 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
}
|
||||
|
||||
@Override
|
||||
public LoadedFunctions loadFunctions(UniqueID driverId) {
|
||||
LoadedFunctions rf = loadedApps.get(driverId);
|
||||
if (rf == null) {
|
||||
rf = initLoadedApps(driverId);
|
||||
public ClassLoader loadResource(UniqueID driverId) {
|
||||
ClassLoader classLoader = loadedApps.get(driverId);
|
||||
if (classLoader == null) {
|
||||
synchronized (this) {
|
||||
classLoader = loadedApps.get(driverId);
|
||||
if (classLoader == null) {
|
||||
classLoader = initLoadedApps(driverId);
|
||||
}
|
||||
}
|
||||
}
|
||||
return rf;
|
||||
return classLoader;
|
||||
}
|
||||
|
||||
private synchronized LoadedFunctions initLoadedApps(UniqueID driverId) {
|
||||
private ClassLoader initLoadedApps(UniqueID driverId) {
|
||||
try {
|
||||
RayLog.core.info("initLoadedApps" + driverId.toString());
|
||||
LoadedFunctions rf = loadedApps.get(driverId);
|
||||
if (rf == null) {
|
||||
UniqueID resId = new UniqueID(kvStore.get("App2ResMap", driverId.toString()));
|
||||
//UniqueID resId = Ray.get(driverId);
|
||||
|
||||
ClassLoader cl = loadedApps.get(driverId);
|
||||
if (cl == null) {
|
||||
UniqueID resId = new UniqueID(kvStore.get("App2ResMap", driverId.toString()));
|
||||
byte[] res = getResource(resId);
|
||||
if (res == null) {
|
||||
throw new RuntimeException("get resource null, the resId " + resId.toString());
|
||||
}
|
||||
RayLog.core.info("ger resource of " + resId.toString() + ", result len " + res.length);
|
||||
RayLog.core.info("get resource of " + resId.toString() + ", result len " + res.length);
|
||||
String resPath =
|
||||
appDir + "/" + driverId.toString() + "/" + String.valueOf(SystemUtil.pid());
|
||||
File dir = new File(resPath);
|
||||
@@ -112,11 +109,10 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
FileUtil.bytesToFile(res, zipPath);
|
||||
ZipFile zipFile = new ZipFile(zipPath);
|
||||
zipFile.extractAll(resPath);
|
||||
rf = JarRewriter
|
||||
.load(resPath, RayRuntime.getInstance().getPaths().java_runtime_rewritten_jars_dir);
|
||||
loadedApps.put(driverId, rf);
|
||||
cl = JarLoader.loadJars(resPath, false);
|
||||
loadedApps.put(driverId, cl);
|
||||
}
|
||||
return rf;
|
||||
return cl;
|
||||
} catch (Exception e) {
|
||||
RayLog.rapp.error("load function for " + driverId + " failed, ex = " + e.getMessage(), e);
|
||||
return null;
|
||||
@@ -125,11 +121,11 @@ public class NativeRemoteFunctionManager implements RemoteFunctionManager {
|
||||
|
||||
@Override
|
||||
public synchronized void unloadFunctions(UniqueID driverId) {
|
||||
LoadedFunctions rf = loadedApps.get(driverId);
|
||||
ClassLoader cl = loadedApps.get(driverId);
|
||||
try {
|
||||
JarLoader.unloadJars(rf.loader);
|
||||
JarLoader.unloadJars(cl);
|
||||
} catch (Exception e) {
|
||||
RayLog.rapp.error("unload function for " + driverId + " failed, ex = " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -50,12 +50,6 @@
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.ray</groupId>
|
||||
<artifactId>ray-hook</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.funcs.RayFunc_0_1;
|
||||
import org.ray.api.funcs.RayFunc_1_1;
|
||||
import org.ray.api.funcs.RayFunc_3_1;
|
||||
import org.ray.util.MethodId;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
public class LambdaUtilsTest {
|
||||
|
||||
static final String CLASS_NAME = LambdaUtilsTest.class.getName();
|
||||
static final Method CALL0;
|
||||
static final Method CALL1;
|
||||
static final Method CALL2;
|
||||
static final Method CALL3;
|
||||
|
||||
static {
|
||||
try {
|
||||
CALL0 = LambdaUtilsTest.class.getDeclaredMethod("call0", new Class[0]);
|
||||
CALL1 = LambdaUtilsTest.class.getDeclaredMethod("call1", new Class[]{Long.class});
|
||||
CALL2 = LambdaUtilsTest.class.getDeclaredMethod("call2", new Class[0]);
|
||||
CALL3 = LambdaUtilsTest.class
|
||||
.getDeclaredMethod("call3", new Class[]{Long.class, String.class});
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc_3_1<T0, T1, T2, R0> f, int n,
|
||||
boolean forceNew, boolean debug)
|
||||
throws Exception {
|
||||
if (debug) {
|
||||
RayLog.core.info("parse#" + f.getClass().getName());
|
||||
}
|
||||
long start = System.nanoTime();
|
||||
for (int i = 0; i < n; i++) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f, forceNew);
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info(String.format("remoteLambdaParse(new=%s):total=%sms, one=%s", forceNew,
|
||||
TimeUnit.NANOSECONDS.toMillis(end - start),
|
||||
(end - start) / n));
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc_3_1<T0, T1, T2, R0> f, int n,
|
||||
boolean de, boolean debug)
|
||||
throws Exception {
|
||||
if (debug) {
|
||||
RayLog.core.info("se#" + f.getClass().getName());
|
||||
}
|
||||
long start = System.nanoTime();
|
||||
for (int i = 0; i < n; i++) {
|
||||
ByteArrayOutputStream bytes = new ByteArrayOutputStream(1024);
|
||||
ObjectOutputStream out = new ObjectOutputStream(bytes);
|
||||
out.writeObject(f);
|
||||
out.close();
|
||||
if (de) {
|
||||
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()));
|
||||
RayFunc_3_1 def = (RayFunc_3_1) in.readObject();
|
||||
in.close();
|
||||
if (debug) {
|
||||
RayLog.core.info("de#" + def.getClass().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info(
|
||||
String.format("remoteLambdaSer(de=%s):total=%sms,one=%s", de,
|
||||
TimeUnit.NANOSECONDS.toMillis(end - start),
|
||||
(end - start) / n));
|
||||
}
|
||||
|
||||
public static void testCall0(RayFunc_0_1 f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL0);
|
||||
Assert.assertTrue(mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T, R> void testCall1(RayFunc_1_1<T, R> f, T t) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL1);
|
||||
Assert.assertTrue(mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T, R> void testCall2(RayFunc_1_1<T, R> f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL2);
|
||||
Assert.assertTrue(!mid.isStatic);
|
||||
}
|
||||
|
||||
public static <T0, T1, T2, R0> void testCall3(RayFunc_3_1<T0, T1, T2, R0> f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f);
|
||||
RayLog.core.info(mid.toString());
|
||||
Assert.assertEquals(mid.load(), CALL3);
|
||||
Assert.assertTrue(!mid.isStatic);
|
||||
}
|
||||
|
||||
public static String call0() {
|
||||
long t = System.currentTimeMillis();
|
||||
RayLog.core.info("call0:" + t);
|
||||
return String.valueOf(t);
|
||||
}
|
||||
|
||||
public static String call1(Long v) {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
v += i;
|
||||
}
|
||||
RayLog.core.info("call1:" + v);
|
||||
return String.valueOf(v);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLambdaSer() throws Exception {
|
||||
testCall0(LambdaUtilsTest::call0);
|
||||
testCall1(LambdaUtilsTest::call1, Long.valueOf(System.currentTimeMillis()));
|
||||
testCall2(LambdaUtilsTest::call2);
|
||||
testCall3(LambdaUtilsTest::call3);
|
||||
}
|
||||
|
||||
/**
|
||||
* to test the serdeLambda's perf.
|
||||
*/
|
||||
public void testBenchmark() throws Exception {
|
||||
//test serde
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 2, true, true);
|
||||
//warmup
|
||||
|
||||
RayLog.core.info("warmup:serde################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
RayLog.core.info("benchmark:serde################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
RayLog.core.info("benchmark:ser################");
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
long start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
long end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, false, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one sertime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test serde one new call's time, no class cache
|
||||
start = System.nanoTime();
|
||||
testRemoteLambdaSerde(LambdaUtilsTest::call3, 1, true, false);
|
||||
end = System.nanoTime();
|
||||
RayLog.core.info("one serdetime:" + (end - start));
|
||||
|
||||
//test lambda
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, true, true);
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 2, false, true);
|
||||
//warmup
|
||||
RayLog.core.info("warmup:parse################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
RayLog.core.info("benchmark:parseNew################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, true, false);
|
||||
RayLog.core.info("benchmark:parseCache################");
|
||||
testRemoteLambdaParse(LambdaUtilsTest::call3, 1000000, false, false);
|
||||
}
|
||||
|
||||
public String call2() {
|
||||
long t = System.currentTimeMillis();
|
||||
RayLog.core.info("call2:" + t);
|
||||
return "call2:" + t;
|
||||
}
|
||||
|
||||
public String call3(Long v, String s) {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
v += i;
|
||||
}
|
||||
RayLog.core.info("call3:" + v);
|
||||
return String.valueOf(v);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
|
||||
import java.lang.invoke.SerializedLambda;
|
||||
import java.lang.reflect.Method;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.api.funcs.RayFunc_3_1;
|
||||
import org.ray.util.LambdaUtils;
|
||||
import org.ray.util.MethodId;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
public class MethodIdTest {
|
||||
|
||||
public static <T0, T1, T2, R0> MethodId fromLambda(RayFunc_3_1<T0, T1, T2, R0> f) {
|
||||
MethodId mid = MethodId.fromSerializedLambda(f, true);
|
||||
return mid;
|
||||
}
|
||||
|
||||
public static MethodId fromClass(Method method) {
|
||||
return MethodId.fromMethod(method);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMethodId2From() throws Exception {
|
||||
MethodId m1 = fromLambda(MethodIdTest::call);
|
||||
Method m = MethodIdTest.class.getDeclaredMethod("call", new Class[]{long.class, String.class});
|
||||
MethodId m2 = fromClass(m);
|
||||
RayLog.core.info(m1.toString());
|
||||
Assert.assertEquals(m1, m2);
|
||||
}
|
||||
|
||||
public String call(long v, String s) {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
v += i;
|
||||
}
|
||||
RayLog.core.info("call:" + v);
|
||||
return String.valueOf(v);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.spi.model.RayActorMethods;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
public class RayActorMethodsTest {
|
||||
|
||||
@Test
|
||||
public void testActor() throws Exception {
|
||||
RayActorMethods methods = RayActorMethods
|
||||
.fromClass(ActorTest.Adder.class.getName(), RayActorMethodsTest.class.getClassLoader());
|
||||
RayLog.core.info(methods.toString());
|
||||
Assert.assertEquals(methods.functions.size(), 5);
|
||||
Assert.assertEquals(methods.staticFunctions.size(), 1);
|
||||
|
||||
RayActorMethods methods2 = RayActorMethods
|
||||
.fromClass(ActorTest.Adder2.class.getName(), RayActorMethodsTest.class.getClassLoader());
|
||||
RayLog.core.info(methods2.toString());
|
||||
|
||||
Assert.assertEquals(methods2.functions.size(), 9);
|
||||
Assert.assertEquals(methods2.staticFunctions.size(), 1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.ray.spi.model.RayTaskMethods;
|
||||
import org.ray.util.logger.RayLog;
|
||||
|
||||
|
||||
public class RayTaskMethodsTest {
|
||||
|
||||
@Test
|
||||
public void testTask() throws Exception {
|
||||
RayTaskMethods methods = RayTaskMethods
|
||||
.fromClass(EchoTest.class.getName(), RayTaskMethodsTest.class.getClassLoader());
|
||||
RayLog.core.info(methods.toString());
|
||||
Assert.assertEquals(methods.functions.size(), 3);
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.zip.DataFormatException;
|
||||
import org.ray.hook.JarRewriter;
|
||||
|
||||
public class RewriteTest {
|
||||
|
||||
public static void main(String[] args) throws IOException, DataFormatException {
|
||||
System.out.println(System.getProperty("user.dir"));
|
||||
JarRewriter.rewrite("target", "target2");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user