[java] Pass large args by reference (#3504)

This commit is contained in:
bibabolynn
2018-12-14 23:32:35 +08:00
committed by Hao Chen
parent de3fdeb5b5
commit 7fd24e384b
5 changed files with 59 additions and 15 deletions
@@ -27,13 +27,16 @@ import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.UniqueIdUtil;
import org.ray.runtime.util.logger.RayLog;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class AbstractRayRuntime implements RayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
private static final int GET_TIMEOUT_MS = 1000;
private static final int FETCH_BATCH_SIZE = 1000;
@@ -75,10 +78,26 @@ public abstract class AbstractRayRuntime implements RayRuntime {
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = workerContext.getCurrentTask().taskId;
RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId);
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
}
/**
* Store a serialized object in the object store.
*
* @param obj The serialized Java object to be stored.
* @return A RayObject instance that represents the in-store object.
*/
public RayObject<Object> putSerialized(byte[] obj) {
UniqueId objectId = UniqueIdUtil.computePutId(
workerContext.getCurrentTask().taskId, workerContext.nextPutIndex());
UniqueId taskId = workerContext.getCurrentTask().taskId;
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
objectStoreProxy.putSerialized(objectId, obj, null);
return new RayObjectImpl<>(objectId);
}
@Override
public <T> T get(UniqueId objectId) throws RayException {
List<T> ret = get(ImmutableList.of(objectId));
@@ -142,8 +161,9 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
}
RayLog.core
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId);
}
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
@@ -152,8 +172,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return finalRet;
} catch (RayException e) {
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
+ " get with Exception", e);
LOGGER.error("Failed to get objects for task {}.", taskId, e);
throw e;
} finally {
// If there were objects that we weren't able to get locally, let the local
@@ -34,8 +34,9 @@ public class MockObjectStore implements ObjectStoreLink {
}
UniqueId uniqueId = new UniqueId(objectId);
data.put(uniqueId, value);
metadata.put(uniqueId, metadataValue);
if (metadataValue != null) {
metadata.put(uniqueId, metadataValue);
}
if (scheduler != null) {
scheduler.onObjectPut(uniqueId);
}
@@ -75,6 +75,10 @@ public class ObjectStoreProxy {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
store.put(id.getBytes(), obj, metadata);
}
public enum GetStatus {
SUCCESS, FAILED
}
@@ -6,14 +6,17 @@ import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.util.Serializer;
public class ArgumentsBuilder {
private static boolean checkSimpleValue(Object o) {
// TODO(raulchen): implement this.
return true;
}
/**
* If the the size of an argument's serialized data is smaller than this number,
* the argument will be passed by value. Otherwise it'll be passed by reference.
*/
private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
/**
* Convert real function arguments to task spec arguments.
@@ -30,10 +33,13 @@ public class ArgumentsBuilder {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (checkSimpleValue(arg)) {
data = Serializer.encode(arg);
} else {
id = Ray.put(arg).getId();
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId();
} else {
data = serialized;
}
}
if (id != null) {
ret[i] = FunctionArg.passByReference(id);