diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 2a178a3d9..b1c29b513 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -27,13 +27,16 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.UniqueIdUtil; -import org.ray.runtime.util.logger.RayLog; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Core functionality to implement Ray APIs. */ public abstract class AbstractRayRuntime implements RayRuntime { + private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); + private static final int GET_TIMEOUT_MS = 1000; private static final int FETCH_BATCH_SIZE = 1000; @@ -75,10 +78,26 @@ public abstract class AbstractRayRuntime implements RayRuntime { public void put(UniqueId objectId, T obj) { UniqueId taskId = workerContext.getCurrentTask().taskId; - RayLog.core.debug("Putting object {}, for task {} ", objectId, taskId); + LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj, null); } + + /** + * Store a serialized object in the object store. + * + * @param obj The serialized Java object to be stored. + * @return A RayObject instance that represents the in-store object. + */ + public RayObject putSerialized(byte[] obj) { + UniqueId objectId = UniqueIdUtil.computePutId( + workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); + UniqueId taskId = workerContext.getCurrentTask().taskId; + LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); + objectStoreProxy.putSerialized(objectId, obj, null); + return new RayObjectImpl<>(objectId); + } + @Override public T get(UniqueId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); @@ -142,8 +161,9 @@ public abstract class AbstractRayRuntime implements RayRuntime { } } - RayLog.core - .debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get"); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId); + } List finalRet = new ArrayList<>(); for (Pair value : ret) { @@ -152,8 +172,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { return finalRet; } catch (RayException e) { - RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) - + " get with Exception", e); + LOGGER.error("Failed to get objects for task {}.", taskId, e); throw e; } finally { // If there were objects that we weren't able to get locally, let the local diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 3dbe7b614..92d8c4ce7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -34,8 +34,9 @@ public class MockObjectStore implements ObjectStoreLink { } UniqueId uniqueId = new UniqueId(objectId); data.put(uniqueId, value); - metadata.put(uniqueId, metadataValue); - + if (metadataValue != null) { + metadata.put(uniqueId, metadataValue); + } if (scheduler != null) { scheduler.onObjectPut(uniqueId); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 5f8221ff6..be33150c7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -75,6 +75,10 @@ public class ObjectStoreProxy { store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata)); } + public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) { + store.put(id.getBytes(), obj, metadata); + } + public enum GetStatus { SUCCESS, FAILED } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 58dc3d803..83714a6de 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -6,14 +6,17 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.id.UniqueId; +import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; public class ArgumentsBuilder { - private static boolean checkSimpleValue(Object o) { - // TODO(raulchen): implement this. - return true; - } + /** + * If the the size of an argument's serialized data is smaller than this number, + * the argument will be passed by value. Otherwise it'll be passed by reference. + */ + private static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024; + /** * Convert real function arguments to task spec arguments. @@ -30,10 +33,13 @@ public class ArgumentsBuilder { data = Serializer.encode(arg); } else if (arg instanceof RayObject) { id = ((RayObject) arg).getId(); - } else if (checkSimpleValue(arg)) { - data = Serializer.encode(arg); } else { - id = Ray.put(arg).getId(); + byte[] serialized = Serializer.encode(arg); + if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { + id = ((AbstractRayRuntime)Ray.internal()).putSerialized(serialized).getId(); + } else { + data = serialized; + } } if (id != null) { ret[i] = FunctionArg.passByReference(id); diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index ae1fc4483..99be04c1b 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -2,6 +2,8 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; + +import java.io.Serializable; import java.util.List; import java.util.Map; import org.junit.Assert; @@ -66,6 +68,15 @@ public class RayCallTest { return val; } + public static class LargeObject implements Serializable { + private byte[] data = new byte[1024 * 1024]; + } + + @RayRemote + private static LargeObject testLargeObject(LargeObject largeObject) { + return largeObject; + } + /** * Test calling and returning different types. */ @@ -83,6 +94,8 @@ public class RayCallTest { Assert.assertEquals(list, Ray.call(RayCallTest::testList, list).get()); Map map = ImmutableMap.of("1", 1, "2", 2); Assert.assertEquals(map, Ray.call(RayCallTest::testMap, map).get()); + LargeObject largeObject = new LargeObject(); + Assert.assertNotNull(Ray.call(RayCallTest::testLargeObject, largeObject).get()); } @RayRemote @@ -130,4 +143,5 @@ public class RayCallTest { Assert.assertEquals(5, (int) Ray.call(RayCallTest::testFiveParams, 1, 1, 1, 1, 1).get()); Assert.assertEquals(6, (int) Ray.call(RayCallTest::testSixParams, 1, 1, 1, 1, 1, 1).get()); } + }