[Java] Replace binary rewrite with Remote Lambda Cache (SerdeLambda) (#2245)

* <feature> : serde lambda

* <feature>:fixed CR

with issue #2245

* <feature>: fixed CR
This commit is contained in:
mylinyuzhi
2018-06-14 03:58:07 +08:00
committed by Philipp Moritz
parent 62de86ff7a
commit fa0ade2bc5
89 changed files with 2633 additions and 7668 deletions
+10 -4
View File
@@ -22,19 +22,25 @@
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>1.2.17</version>
</dependency>
<dependency>
<groupId>quartz</groupId>
<artifactId>quartz</artifactId>
<version>1.5.2</version>
</dependency>
<dependency>
<groupId>org.ini4j</groupId>
<artifactId>ini4j</artifactId>
<version>0.5.2</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>
</project>
@@ -0,0 +1,31 @@
package org.ray.util;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
/**
* see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html.
*/
public final class LambdaUtils {
private LambdaUtils() {
}
public static SerializedLambda getSerializedLambda(Serializable lambda) {
// Note.
// the class of lambda which isAssignableFrom Serializable
// has an privte method:writeReplace
// This mechanism may be changed in the future
try {
Method m = lambda.getClass().getDeclaredMethod("writeReplace");
m.setAccessible(true);
return (SerializedLambda) m.invoke(lambda);
} catch (Exception e) {
throw new RuntimeException("failed to getSerializedLambda:" + lambda.getClass().getName(), e);
}
}
}
@@ -0,0 +1,212 @@
package org.ray.util;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.WeakHashMap;
import org.objectweb.asm.Type;
import org.ray.util.logger.RayLog;
/**
* An instance of RayFunc is a lambda.
* MethodId describe the information of the called function in lambda.<br/>
* e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:<br/>
* MethodId.className = Foo <br/>
* MethodId.methodName = foo <br/>
* MethodId.methodDesc = describe the types of args and return.
* see org.objectweb.asm.Type.getDescriptor.
*/
public final class MethodId {
/**
* use ThreadLocal to avoid lock.
* A cache from the lambda instances to MethodId.
* Note: the lambda instances are dynamically created per call site,
* we use WeakHashMap to avoid OOM.
*/
private static final ThreadLocal<WeakHashMap<Class<Serializable>, MethodId>>
CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>());
public final String className;
public final String methodName;
public final String methodDesc;
public final boolean isStatic;
/**
* encode the className,methodName,methodDesc,isStatic as an uniquel id.
*/
private final String encoding;
/**
* sha1 from the encoding, used as functionId.
*/
private final byte[] digest;
public MethodId(String className, String methodName, String methodDesc, boolean isStatic) {
this.className = className;
this.methodName = methodName;
this.methodDesc = methodDesc;
this.isStatic = isStatic;
this.encoding = encode(className, methodName, methodDesc, isStatic);
this.digest = getSha1Hash0();
}
private static String encode(String className, String methodName, String methodDesc,
boolean isStatic) {
StringBuilder sb = new StringBuilder(512);
sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&")
.append(isStatic);
return sb.toString();
}
public static MethodId fromMethod(Method method) {
final boolean isstatic = Modifier.isStatic(method.getModifiers());
final String className = method.getDeclaringClass().getName();
final String methodName = method.getName();
final Type type = Type.getType(method);
final String methodDesc = type.getDescriptor();
return new MethodId(className, methodName, methodDesc, isstatic);
}
public static MethodId fromSerializedLambda(Serializable serial) {
return fromSerializedLambda(serial, false);
}
public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) {
Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be "
+ "SerializedLambda");
Class<Serializable> clazz = (Class<Serializable>) serial.getClass();
WeakHashMap<Class<Serializable>, MethodId> map = CACHE.get();
MethodId id = map.get(clazz);
if (id == null || forceNew) {
final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial);
Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda "
+ "which is closure");
final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
final String className = lambda.getImplClass().replace('/', '.');
id = new MethodId(className, lambda.getImplMethodName(),
lambda.getImplMethodSignature(), isStatic);
if (!forceNew) {
map.put(clazz, id);
}
}
return id;
}
public Method load() {
return load(null);
}
public Method load(ClassLoader loader) {
Class<?> cls = null;
try {
RayLog.core.debug(
"load class " + className + " from class loader " + (loader == null ? this.getClass()
.getClassLoader() : loader)
+ " for method " + toString() + " with ID = " + toHexHashString()
);
cls = Class
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
} catch (Throwable e) {
RayLog.core.error("Cannot load class " + className, e);
return null;
}
Method[] ms = cls.getDeclaredMethods();
ArrayList<Method> methods = new ArrayList<>();
Type t = Type.getMethodType(this.methodDesc);
Type[] params = t.getArgumentTypes();
String rt = t.getReturnType().getDescriptor();
for (Method m : ms) {
if (m.getName().equals(methodName)) {
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
continue;
}
String mrt = Type.getDescriptor(m.getReturnType());
if (!rt.equals(mrt)) {
continue;
}
if (isStatic != Modifier.isStatic(m.getModifiers())) {
continue;
}
methods.add(m);
}
}
if (methods.size() != 1) {
RayLog.core.error(
"Load method " + toString() + " failed as there are " + methods.size() + " definitions");
return null;
}
return methods.get(0);
}
private byte[] getSha1Hash0() {
byte[] digests = Sha1Digestor.digest(encoding);
ByteBuffer bb = ByteBuffer.wrap(digests);
bb.order(ByteOrder.LITTLE_ENDIAN);
if (methodName.contains("createActorStage1")) {
bb.putLong(Long.BYTES, 1);
} else {
bb.putLong(Long.BYTES, 0);
}
return digests;
}
public byte[] getSha1Hash() {
return digest;
}
private String toHexHashString() {
byte[] id = this.getSha1Hash();
return StringUtil.toHexHashString(id);
}
public String toEncodingString() {
return encoding;
}
@Override
public int hashCode() {
return encoding.hashCode();
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MethodId other = (MethodId) obj;
return className.equals(other.className)
&& methodName.equals(other.methodName)
&& methodDesc.equals(other.methodDesc)
&& isStatic == other.isStatic;
}
@Override
public String toString() {
return encoding;
}
}
@@ -15,6 +15,7 @@ public class Sha1Digestor {
}
});
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
@@ -27,4 +28,14 @@ public class Sha1Digestor {
dg.update(longBuffer.get().putLong(addIndex).array());
return dg.digest();
}
}
public static byte[] digest(String str) {
return digest(str.getBytes(StringUtil.UTF8));
}
public static byte[] digest(byte[] src) {
MessageDigest dg = md.get();
dg.reset();
return dg.digest(src);
}
}
@@ -1,11 +1,16 @@
package org.ray.util;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Vector;
public class StringUtil {
public static final Charset UTF8 = Charset.forName("UTF-8");
private static final char[] HEX_CHARS = "0123456789abcdef".toCharArray();
/**
* split.
* @param s input string
@@ -117,6 +122,17 @@ public class StringUtil {
return objs.length == 0 ? "" : sb.substring(0, sb.length() - concatenator.length());
}
public static String toHexHashString(byte[] id) {
StringBuilder sb = new StringBuilder(20);
assert (id.length == 20);
for (int i = 0; i < 20; i++) {
int val = id[i] & 0xff;
sb.append(HEX_CHARS[val >> 4]);
sb.append(HEX_CHARS[val & 0xf]);
}
return sb.toString();
}
// Holds the start of an element and which brace started it.
private static class Start {
@@ -132,4 +148,3 @@ public class StringUtil {
}
}
}