[java] Remove multi-return API (#2724)

This commit is contained in:
Hao Chen
2018-08-26 15:04:54 +08:00
committed by Robert Nishihara
parent dbba7f2a53
commit 4f4bea086a
98 changed files with 615 additions and 7637 deletions
+4 -9
View File
@@ -1,5 +1,6 @@
package org.ray.api;
import com.google.common.collect.ImmutableList;
import java.util.List;
import org.ray.api.internal.RayConnector;
import org.ray.util.exception.TaskExecutionException;
@@ -58,25 +59,19 @@ public final class Ray extends Rpc {
* @param numReturns how many of ready is enough
* @param timeoutMilliseconds in millisecond
*/
public static <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns,
public static <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns,
int timeoutMilliseconds) {
return impl.wait(waitfor, numReturns, timeoutMilliseconds);
}
public static <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns) {
public static <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns) {
return impl.wait(waitfor, numReturns, Integer.MAX_VALUE);
}
public static <T> WaitResult<T> wait(RayList<T> waitfor) {
public static <T> WaitResult<T> wait(List<RayObject<T>> waitfor) {
return impl.wait(waitfor, waitfor.size(), Integer.MAX_VALUE);
}
public static <T> WaitResult<T> wait(RayObject<T> waitfor, int timeoutMilliseconds) {
RayList<T> waits = new RayList<>();
waits.add(waitfor);
return impl.wait(waits, 1, timeoutMilliseconds);
}
/**
* create actor object.
*/
+8 -44
View File
@@ -1,9 +1,7 @@
package org.ray.api;
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import org.ray.api.internal.RayFunc;
import org.ray.api.funcs.RayFunc;
import org.ray.util.exception.TaskExecutionException;
/**
@@ -39,11 +37,11 @@ public interface RayApi {
/**
* wait until timeout or enough RayObjects are ready.
*
* @param waitfor wait for who
* @param waitfor wait for who
* @param numReturns how many of ready is enough
* @param timeout in millisecond
* @param timeout in millisecond
*/
<T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout);
<T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout);
/**
* create remote actor.
@@ -51,46 +49,12 @@ public interface RayApi {
<T> RayActor<T> create(Class<T> cls);
/**
* submit a new task by invoking a remote function.
* Invoke a remote function.
*
* @param taskId nil
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnCount the number of to-be-returned objects from funcRun
* @param args arguments to this funcRun, can be its original form or RayObject
* @param func the target running function
* @param args arguments to this funcRun, can be its original form or RayObject
* @return a set of ray objects with their return ids
*/
RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args);
RayObject call(RayFunc func, Object... args);
/**
* In some cases, we would like the return value of a remote function to be splitted into multiple
* parts so that they are consumed by multiple further functions separately (potentially on
* different machines). We therefore introduce this API so that developers can annotate the
* outputs with a set of labels (usually with Integer or String).
*
* @param taskId nil
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnids a set of labels to be used by the returned objects
* @param args arguments to this funcRun, can be its original form or
* RayObject<original-type>
* @return a set of ray objects with their labels and return ids
*/
<R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Collection<RIDT> returnids, Object... args);
/**
* a special case for the above RID-based labeling as <0...returnCount - 1>.
*
* @param taskId nil
* @param funcCls the target running function's class
* @param lambda the target running function
* @param returnCount the number of to-be-returned objects from funcRun
* @param args arguments to this funcRun, can be its original form or
* RayObject<original-type>
* @return an array of returned objects with their Unique ids
*/
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
Integer returnCount, Object... args);
}
@@ -1,213 +0,0 @@
package org.ray.api;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
/**
* A RayList&lt;E&gt; holds a list of RayObject&lt;E&gt;,
* and can serves as parameters and/or return values of Ray calls.
*/
public class RayList<E> extends ArrayList<E> {
private static final long serialVersionUID = 2129403593610953658L;
private final ArrayList<RayObject<E>> ids = new ArrayList<>();
public List<RayObject<E>> Objects() {
return ids;
}
@Override
public int size() {
// throw new UnsupportedOperationException();
return ids.size();
}
@Override
public boolean isEmpty() {
// throw new UnsupportedOperationException();
return ids.isEmpty();
}
@Override
public boolean contains(Object o) {
// throw new UnsupportedOperationException();
return ids.contains(o);
}
@Override
public int indexOf(Object o) {
//throw new UnsupportedOperationException();
return ids.indexOf(o);
}
@Override
public int lastIndexOf(Object o) {
//throw new UnsupportedOperationException();
return ids.lastIndexOf(o);
}
@Override
public Object[] toArray() {
//throw new UnsupportedOperationException();
return ids.toArray();
}
@Override
public <T> T[] toArray(T[] a) {
//throw new UnsupportedOperationException();
return ids.toArray(a);
}
@Override
public E get(int index) {
return ids.get(index).get();
}
public List<E> get() {
List<UniqueID> objectIds = new ArrayList<>();
for (RayObject<E> id : ids) {
objectIds.add(id.getId());
}
return Ray.get(objectIds);
}
@RayDisabled
@Deprecated
@Override
public E set(int index, E element) {
throw new UnsupportedOperationException();
}
public RayObject<E> set(int index, RayObject<E> element) {
return ids.set(index, element);
}
@RayDisabled
@Deprecated
@Override
public boolean add(E e) {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public void add(int index, E element) {
throw new UnsupportedOperationException();
}
public boolean add(RayObject<E> e) {
return ids.add(e);
}
public void add(int index, RayObject<E> element) {
ids.add(index, element);
}
@RayDisabled
@Deprecated
@Override
public E remove(int index) {
throw new UnsupportedOperationException();
}
@Override
public boolean remove(Object o) {
//throw new UnsupportedOperationException();
return ids.remove(o);
}
@Override
public void clear() {
//throw new UnsupportedOperationException();
ids.clear();
}
@RayDisabled
@Deprecated
@Override
public boolean addAll(Collection<? extends E> c) {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public boolean addAll(int index, Collection<? extends E> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean removeAll(Collection<?> c) {
//throw new UnsupportedOperationException();
return ids.removeAll(c);
}
@Override
public boolean retainAll(Collection<?> c) {
//throw new UnsupportedOperationException();
return ids.retainAll(c);
}
@RayDisabled
@Deprecated
@Override
public ListIterator<E> listIterator(int index) {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public ListIterator<E> listIterator() {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public Iterator<E> iterator() {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public List<E> subList(int fromIndex, int toIndex) {
throw new UnsupportedOperationException();
}
public Iterator<RayObject<E>> Iterator() {
return ids.iterator();
}
@Override
public boolean containsAll(Collection<?> c) {
//throw new UnsupportedOperationException();
return ids.containsAll(c);
}
public <T> List<T> getMeta() {
List<UniqueID> objectIds = new ArrayList<>();
for (RayObject<E> id : ids) {
objectIds.add(id.getId());
}
return Ray.getMeta(objectIds);
}
public <TMT> TMT getMeta(int index) {
return ids.get(index).getMeta();
}
public RayObject<E> Get(int index) {
return ids.get(index);
}
public RayObject<E> Remove(int index) {
return ids.remove(index);
}
}
@@ -1,138 +0,0 @@
package org.ray.api;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/**
* A RayMap&lt;K&gt; maintains a map from K to RayObject&lt;V&gt;,
* and serves as parameters and/or return values of Ray calls.
*/
public class RayMap<K, V> extends HashMap<K, V> {
private static final long serialVersionUID = 7296072498584721265L;
private final HashMap<K, RayObject<V>> ids = new HashMap<>();
public HashMap<K, RayObject<V>> Objects() {
return ids;
}
@Override
public int size() {
// throw new UnsupportedOperationException();
return ids.size();
}
@Override
public boolean isEmpty() {
//throw new UnsupportedOperationException();
return ids.isEmpty();
}
@Override
public V get(Object key) {
return ids.get(key).get();
}
// TODO: try to use multiple get
public Map<K, V> get() {
Map<K, V> objs = new HashMap<>();
for (Map.Entry<K, RayObject<V>> id : ids.entrySet()) {
objs.put(id.getKey(), id.getValue().get());
}
return objs;
}
@Override
public boolean containsKey(Object key) {
//throw new UnsupportedOperationException();
return ids.containsKey(key);
}
@RayDisabled
@Deprecated
@Override
public V put(K key, V value) {
throw new UnsupportedOperationException();
}
public RayObject<V> put(K key, RayObject<V> value) {
return ids.put(key, value);
}
@RayDisabled
@Deprecated
@Override
public void putAll(Map<? extends K, ? extends V> m) {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public V remove(Object key) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
//throw new UnsupportedOperationException();
ids.clear();
}
@Override
public boolean containsValue(Object value) {
//throw new UnsupportedOperationException();
return ids.containsValue(value);
}
@Override
public Set<K> keySet() {
return ids.keySet();
}
@RayDisabled
@Deprecated
@Override
public Collection<V> values() {
throw new UnsupportedOperationException();
}
@RayDisabled
@Deprecated
@Override
public Set<java.util.Map.Entry<K, V>> entrySet() {
throw new UnsupportedOperationException();
}
public <TMT> Map<K, TMT> getMeta() {
Map<K, TMT> metas = new HashMap<>();
for (Map.Entry<K, RayObject<V>> id : ids.entrySet()) {
TMT meta = id.getValue().getMeta();
metas.put(id.getKey(), meta);
}
return metas;
}
public <TMT> TMT getMeta(K key) {
return ids.get(key).getMeta();
}
public RayObject<V> Get(K key) {
return ids.get(key);
}
public RayObject<V> Remove(K key) {
return ids.remove(key);
}
public Collection<RayObject<V>> Values() {
return ids.values();
}
public Set<java.util.Map.Entry<K, RayObject<V>>> EntrySet() {
return ids.entrySet();
}
}
@@ -1,32 +0,0 @@
package org.ray.api;
import org.apache.commons.lang3.ArrayUtils;
/**
* Real object or ray future proxy for multiple returns.
*/
public class RayObjects {
protected RayObject[] objs;
public RayObjects(UniqueID[] ids) {
this.objs = new RayObject[ids.length];
for (int k = 0; k < ids.length; k++) {
this.objs[k] = new RayObject<>(ids[k]);
}
}
public RayObjects(RayObject[] objs) {
this.objs = objs;
}
public RayObject pop() {
RayObject lastObj = objs[objs.length - 1];
objs = ArrayUtils.subarray(objs, 0, objs.length - 1);
return lastObj;
}
public RayObject[] getObjs() {
return objs;
}
}
File diff suppressed because it is too large Load Diff
@@ -1,24 +1,26 @@
package org.ray.api;
import java.util.List;
/**
* The result of Ray.wait() distinguish the ready ones and the remain ones
*/
public class WaitResult<T> {
private final RayList<T> readyOnes;
private final RayList<T> remainOnes;
private final List<RayObject<T>> readyOnes;
private final List<RayObject<T>> remainOnes;
public WaitResult(RayList<T> readyOnes, RayList<T> remainOnes) {
public WaitResult(List<RayObject<T>> readyOnes, List<RayObject<T>> remainOnes) {
this.readyOnes = readyOnes;
this.remainOnes = remainOnes;
}
public RayList<T> getReadyOnes() {
public List<RayObject<T>> getReadyOnes() {
return readyOnes;
}
public RayList<T> getRemainOnes() {
public List<RayObject<T>> getRemainOnes() {
return remainOnes;
}
@@ -1,9 +1,9 @@
package org.ray.api.internal;
package org.ray.api.funcs;
import java.io.Serializable;
/**
* Base of the ray remote function.
* Interface of all Ray remote functions.
*/
public interface RayFunc extends Serializable {
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc0<R> extends RayFunc {
R apply();
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc1<T0, R> extends RayFunc {
R apply(T0 t0);
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc2<T0, T1, R> extends RayFunc {
R apply(T0 t0, T1 t1);
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc3<T0, T1, T2, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2);
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc4<T0, T1, T2, T3, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3);
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc5<T0, T1, T2, T3, T4, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4);
}
@@ -0,0 +1,8 @@
// generated automatically, do not modify.
package org.ray.api.funcs;
@FunctionalInterface
public interface RayFunc6<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
R apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5);
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_1<R0> extends RayFunc {
R0 apply() throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_0_2<R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply() throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_0_3<R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply() throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_0_4<R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply() throws Throwable;
}
@@ -1,13 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_n<R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_0_n_list<R> extends RayFunc {
List<R> apply() throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_1<T0, R0> extends RayFunc {
R0 apply(T0 t0) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_1_2<T0, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_1_3<T0, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_1_4<T0, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0) throws Throwable;
}
@@ -1,13 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_n<T0, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_1_n_list<T0, R> extends RayFunc {
List<R> apply(T0 t0) throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_1<T0, T1, R0> extends RayFunc {
R0 apply(T0 t0, T1 t1) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_2_2<T0, T1, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_2_3<T0, T1, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_2_4<T0, T1, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -1,13 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_n<T0, T1, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_2_n_list<T0, T1, R> extends RayFunc {
List<R> apply(T0 t0, T1 t1) throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_1<T0, T1, T2, R0> extends RayFunc {
R0 apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_3_2<T0, T1, T2, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_3_3<T0, T1, T2, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_3_4<T0, T1, T2, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,13 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_n<T0, T1, T2, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_3_n_list<T0, T1, T2, R> extends RayFunc {
List<R> apply(T0 t0, T1 t1, T2 t2) throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_1<T0, T1, T2, T3, R0> extends RayFunc {
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_4_2<T0, T1, T2, T3, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_4_3<T0, T1, T2, T3, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_4_4<T0, T1, T2, T3, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,13 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_n<T0, T1, T2, T3, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_4_n_list<T0, T1, T2, T3, R> extends RayFunc {
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3) throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_1<T0, T1, T2, T3, T4, R0> extends RayFunc {
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_5_2<T0, T1, T2, T3, T4, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_5_3<T0, T1, T2, T3, T4, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_5_4<T0, T1, T2, T3, T4, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -1,14 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_n<T0, T1, T2, T3, T4, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4)
throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_5_n_list<T0, T1, T2, T3, T4, R> extends RayFunc {
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) throws Throwable;
}
@@ -1,11 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_1<T0, T1, T2, T3, T4, T5, R0> extends RayFunc {
R0 apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns2;
@FunctionalInterface
public interface RayFunc_6_2<T0, T1, T2, T3, T4, T5, R0, R1> extends RayFunc {
MultipleReturns2<R0, R1> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns3;
@FunctionalInterface
public interface RayFunc_6_3<T0, T1, T2, T3, T4, T5, R0, R1, R2> extends RayFunc {
MultipleReturns3<R0, R1, R2> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
import org.ray.api.returns.MultipleReturns4;
@FunctionalInterface
public interface RayFunc_6_4<T0, T1, T2, T3, T4, T5, R0, R1, R2, R3> extends RayFunc {
MultipleReturns4<R0, R1, R2, R3> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,14 +0,0 @@
package org.ray.api.funcs;
import java.util.Collection;
import java.util.Map;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_n<T0, T1, T2, T3, T4, T5, R, RIDT> extends RayFunc {
Map<RIDT, R> apply(Collection<RIDT> returnids, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5)
throws Throwable;
}
@@ -1,12 +0,0 @@
package org.ray.api.funcs;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.ray.api.internal.RayFunc;
@FunctionalInterface
public interface RayFunc_6_n_list<T0, T1, T2, T3, T4, T5, R> extends RayFunc {
List<R> apply(T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Throwable;
}
@@ -1,18 +0,0 @@
package org.ray.api.returns;
/**
* Multiple return objects for user's method.
*/
public class MultipleReturns {
protected final Object[] values;
public MultipleReturns(Object[] values) {
this.values = values;
}
public Object[] getValues() {
return values;
}
}
@@ -1,17 +0,0 @@
package org.ray.api.returns;
@SuppressWarnings("unchecked")
public class MultipleReturns2<R0, R1> extends MultipleReturns {
public MultipleReturns2(R0 r0, R1 r1) {
super(new Object[] {r0, r1});
}
public R0 get0() {
return (R0) this.values[0];
}
public R1 get1() {
return (R1) this.values[1];
}
}
@@ -1,21 +0,0 @@
package org.ray.api.returns;
@SuppressWarnings("unchecked")
public class MultipleReturns3<R0, R1, R2> extends MultipleReturns {
public MultipleReturns3(R0 r0, R1 r1, R2 r2) {
super(new Object[] {r0, r1, r2});
}
public R0 get0() {
return (R0) this.values[0];
}
public R1 get1() {
return (R1) this.values[1];
}
public R2 get2() {
return (R2) this.values[2];
}
}
@@ -1,25 +0,0 @@
package org.ray.api.returns;
@SuppressWarnings("unchecked")
public class MultipleReturns4<R0, R1, R2, R3> extends MultipleReturns {
public MultipleReturns4(R0 r0, R1 r1, R2 r2, R3 r3) {
super(new Object[] {r0, r1, r2, r3});
}
public R0 get0() {
return (R0) this.values[0];
}
public R1 get1() {
return (R1) this.values[1];
}
public R2 get2() {
return (R2) this.values[2];
}
public R3 get3() {
return (R3) this.values[3];
}
}
@@ -1,25 +0,0 @@
package org.ray.api.returns;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
@SuppressWarnings({"rawtypes", "unchecked"})
public class RayObjects2<R0, R1> extends RayObjects {
public RayObjects2(UniqueID[] ids) {
super(ids);
}
public RayObjects2(RayObject[] objs) {
super(objs);
}
public RayObject<R0> r0() {
return objs[0];
}
public RayObject<R1> r1() {
return objs[1];
}
}
@@ -1,29 +0,0 @@
package org.ray.api.returns;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
@SuppressWarnings({"rawtypes", "unchecked"})
public class RayObjects3<R0, R1, R2> extends RayObjects {
public RayObjects3(UniqueID[] ids) {
super(ids);
}
public RayObjects3(RayObject[] objs) {
super(objs);
}
public RayObject<R0> r0() {
return objs[0];
}
public RayObject<R1> r1() {
return objs[1];
}
public RayObject<R2> r2() {
return objs[2];
}
}
@@ -1,33 +0,0 @@
package org.ray.api.returns;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
@SuppressWarnings({"rawtypes", "unchecked"})
public class RayObjects4<R0, R1, R2, R3> extends RayObjects {
public RayObjects4(UniqueID[] ids) {
super(ids);
}
public RayObjects4(RayObject[] objs) {
super(objs);
}
public RayObject<R0> r0() {
return objs[0];
}
public RayObject<R1> r1() {
return objs[1];
}
public RayObject<R2> r2() {
return objs[2];
}
public RayObject<R3> r3() {
return objs[3];
}
}
+1 -4
View File
@@ -9,12 +9,9 @@
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>
<suppress checks="MethodNameCheck" files="ResourcePair.java"/>
<suppress checks="ParameterNameCheck" files="ResourcePair.java"/>
<suppress checks="MethodNameCheck" files="RayMap.java"/>
<suppress checks="MethodNameCheck" files="RayList.java"/>
<suppress checks="OverloadMethodsDeclarationOrderCheck" files="Rpc.java"/>
<suppress checks="MethodTypeParameterNameCheck" files="Rpc.java"/>
<suppress checks="AbbreviationAsWordInNameCheck" files="UniqueID.java"/>
<suppress checks="TypeNameCheck" files="RayFunc_[0-9]_.+\.java"/>
<suppress checks=".*" files="Rpc.java"/>
<!-- suppress check for flatbuffer-generated files. -->
<!-- TODO(raulchen): move these files to a directory, so this rule can be simplier. -->
<suppress checks=".*" files="(Arg|ResourcePair|TaskLanguage|TaskInfo|ClientTableData).java" />
@@ -0,0 +1,19 @@
package org.ray.util.generator;
public abstract class BaseGenerator {
protected static final int MAX_PARAMETERS = 6;
protected StringBuilder sb;
protected void newLine(String line) {
sb.append(line).append("\n");
}
protected void indents(int numIndents) {
for (int i = 0; i < numIndents; i++) {
sb.append(" ");
}
}
}
@@ -1,38 +0,0 @@
package org.ray.util.generator;
import java.util.ArrayList;
import java.util.List;
/**
* Calculate all compositions for Parameter's count + Return's count, this class is used by code
* generators.
*/
public class Composition {
public static List<Tr> calculate(int maxT, int maxR) {
List<Tr> ret = new ArrayList<>();
for (int t = 0; t <= maxT; t++) {
// <= 0 for dynamic return count
// 0 for call_n returns RayMap<RID, x>
// -1 for call_n returns RayObject<>[N]
for (int r = -1; r <= maxR; r++) {
ret.add(new Tr(t, r));
}
}
return ret;
}
public static class Tr {
public final int tcount;
public final int rcount;
public Tr(int tcount, int rcount) {
super();
this.tcount = tcount;
this.rcount = rcount;
}
}
}
@@ -2,87 +2,47 @@ package org.ray.util.generator;
import java.io.IOException;
import org.ray.util.FileUtil;
import org.ray.util.generator.Composition.Tr;
/**
* Generate all classes in org.ray.api.funcs
* A util class that generates all the classes under org.ray.api.funcs package.
*/
public class FuncsGenerator {
public class FuncsGenerator extends BaseGenerator {
public static void main(String[] args) throws IOException {
String rootdir = System.getProperty("user.dir") + "/../api/src/main/java";
rootdir += "/org/ray/api/funcs";
generate(rootdir);
}
private String generate(int numParameters) {
sb = new StringBuilder();
private static void generate(String rootdir) throws IOException {
for (Tr tr : Composition.calculate(Share.MAX_T, Share.MAX_R)) {
String str = build(tr.tcount, tr.rcount);
String file = rootdir + "/RayFunc_" + tr.tcount + "_"
+ (tr.rcount <= 0 ? (tr.rcount == 0 ? "n" : "n_list") : tr.rcount) + ".java";
FileUtil.overrideFile(file, str);
System.err.println("override " + file);
String genericTypes = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
genericTypes += "T" + i + ", ";
if (i > 0) {
paramList += ", ";
}
paramList += String.format("T%d t%d", i, i);
}
}
/*
* package org.ray.api.funcs;
*
* @FunctionalInterface
* public interface RayFunc_4_1<T0, T1, T2, T3, R0> extends RayFunc { R0
* apply();
*
* public static <R0> R0 execute(Object[] args) throws Throwable { String name =
* (String)args[args.length - 2]; assert (name.equals(RayFunc_0_1.class.getName())); byte[]
* funcBytes = (byte[])args[args.length - 1]; RayFunc_0_1<R0> f = (RayFunc_0_1<R0>)
* SerializationUtils.deserialize(funcBytes);
* return f.apply(); } }
*/
private static String build(int tcount, int rcount) {
StringBuilder sb = new StringBuilder();
String tname =
"Ray" + "Func_" + tcount + "_" + (rcount <= 0 ? (rcount == 0 ? "n" : "n_list") : rcount);
final String gname = tname + "<" + Share.buildClassDeclare(tcount, rcount) + ">";
newLine("// generated automatically, do not modify.");
newLine("");
newLine("package org.ray.api.funcs;");
newLine("");
newLine("@FunctionalInterface");
newLine(String.format("public interface RayFunc%d<%sR> extends RayFunc {",
numParameters, genericTypes));
indents(1);
newLine(String.format("R apply(%s);", paramList));
newLine("}");
sb.append("package org.ray.api.funcs;").append("\n");
if (rcount > 1) {
sb.append("import org.ray.api.returns.*;").append("\n");
}
if (rcount <= 0) {
sb.append("import java.util.Collection;").append("\n");
sb.append("import java.util.List;").append("\n");
sb.append("import java.util.Map;").append("\n");
}
sb.append("import org.ray.api.*;").append("\n");
sb.append("import org.ray.api.internal.*;").append("\n");
sb.append("import org.apache.commons.lang3.SerializationUtils;").append("\n");
sb.append("\n");
sb.append("@FunctionalInterface").append("\n");
sb.append("public interface ").append(gname).append(" extends RayFunc {")
.append("\n");
sb.append("\t").append(Share.buildFuncReturn(rcount)).append(" apply(")
.append(rcount == 0 ? ("Collection<RID> returnids" + (tcount > 0 ? ", " : "")) : "")
.append(Share.buildParameter(tcount, "T", null)).append(") throws Throwable;")
.append("\n");
sb.append("\t\n");
sb.append("\tpublic static " + "<").append(Share.buildClassDeclare(tcount, rcount))
.append(">").append(" ").append(Share.buildFuncReturn(rcount))
.append(" execute(Object[] args) throws Throwable {").append("\n");
sb.append("\t\tString name = (String)args[args.length - 2];").append("\n");
sb.append("\t\tassert (name.equals(").append(tname).append(".class.getName()));").append("\n");
sb.append("\t\tbyte[] funcBytes = (byte[])args[args.length - 1];").append("\n");
sb.append("\t\t").append(gname).append(" f = SerializationUtils.deserialize(funcBytes);")
.append("\n");
sb.append("\t\treturn f.apply(")
.append(rcount == 0 ? ("(Collection<RID>)args[0]" + (tcount > 0 ? ", " : "")) : "")
.append(Share.buildParameterUse2(tcount, rcount == 0 ? 1 : 0, "T", "args[", "]"))
.append(");").append("\n");
sb.append("\t}").append("\n");
sb.append("\t\n");
sb.append("}").append("\n");
return sb.toString();
}
public static void main(String[] args) throws IOException {
String root = System.getProperty("user.dir")
+ "/api/src/main/java/org/ray/api/funcs/";
FuncsGenerator generator = new FuncsGenerator();
for (int i = 0; i <= MAX_PARAMETERS; i++) {
String content = generator.generate(i);
FileUtil.overrideFile(root + "RayFunc" + i + ".java", content);
}
}
}
@@ -1,58 +0,0 @@
package org.ray.util.generator;
import java.io.IOException;
import org.ray.util.FileUtil;
/**
* Generate all classes in org.ray.api.returns.MultipleReturnsX
*/
public class MultipleReturnGenerator {
public static void main(String[] args) throws IOException {
String rootdir = System.getProperty("user.dir") + "/../api/src/main/java";
rootdir += "/org/ray/api/returns";
for (int r = 2; r <= Share.MAX_R; r++) {
String str = build(r);
String file = rootdir + "/MultipleReturns" + r + ".java";
FileUtil.overrideFile(file, str);
System.err.println("override " + file);
}
}
/**
* package org.ray.api.returns.
*/
private static String build(int rcount) {
StringBuilder sb = new StringBuilder();
sb.append("package org.ray.api.returns;").append("\n");
sb.append("import org.ray.api.*;").append("\n");
sb.append("@SuppressWarnings(\"unchecked\")");
sb.append("public class MultipleReturns").append(rcount).append("<")
.append(Share.buildClassDeclare(0, rcount)).append("> extends MultipleReturns {")
.append("\n");
sb.append("\tpublic MultipleReturns").append(rcount).append("(")
.append(Share.buildParameter(rcount, "R", null)).append(") {")
.append("\n");
sb.append("\t\tsuper(new Object[] { ").append(Share.buildParameterUse(rcount, "R"))
.append(" });")
.append("\n");
sb.append("\t}").append("\n");
for (int k = 0; k < rcount; k++) {
sb.append(buildGetter(k));
}
sb.append("}").append("\n");
return sb.toString();
}
/*
* @SuppressWarnings("unchecked") public R1 get1() { return (R1) this.values[1]; }
*/
private static String buildGetter(int index) {
return ("\tpublic R" + index + " get" + index + "() {\n")
+ "\t\treturn (R" + index + ") this.values[" + index + "];\n"
+ "\t}\n";
}
}
@@ -1,68 +0,0 @@
package org.ray.util.generator;
import java.io.IOException;
import org.ray.util.FileUtil;
/**
* Generate all classes in org.ray.api.returns.RayObjectsX
*/
public class RayObjectsGenerator {
public static void main(String[] args) throws IOException {
String rootdir = System.getProperty("user.dir") + "/../api/src/main/java";
rootdir += "/org/ray/api/returns";
for (int r = 2; r <= Share.MAX_R; r++) {
String str = build(r);
String file = rootdir + "/RayObjects" + r + ".java";
FileUtil.overrideFile(file, str);
System.err.println("override " + file);
}
}
/*
* package org.ray.api.returns;
*
* import org.ray.api.RayObject; import org.ray.spi.model.UniqueID;
*
* @SuppressWarnings({"rawtypes", "unchecked"}) public class RayObjects2<R0, R1> extends
* RayObjects {
*
* public RayObjects2(UniqueID[] ids) { super(ids); }
*
* public RayObjects2(RayObject objs[]) { super(objs); }
*
* public RayObject<R0> r0() { return objs[0]; }
*
* public RayObject<R1> r1() { return objs[1]; } }
*/
private static String build(int rcount) {
StringBuilder sb = new StringBuilder();
sb.append("package org.ray.api.returns;\n");
sb.append("import org.ray.api.*;\n");
sb.append("import org.ray.spi.model.UniqueID;\n");
sb.append("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
sb.append("public class RayObjects").append(rcount).append("<")
.append(Share.buildClassDeclare(0, rcount)).append("> extends RayObjects {")
.append("\n");
sb.append("\tpublic RayObjects").append(rcount).append("(UniqueID[] ids) {").append("\n");
sb.append("\t\tsuper(ids);").append("\n");
sb.append("\t}").append("\n");
sb.append("\tpublic RayObjects").append(rcount).append("(RayObject objs[]) {").append("\n");
sb.append("\t\tsuper(objs);").append("\n");
sb.append("\t}").append("\n");
for (int k = 0; k < rcount; k++) {
sb.append(buildGetter(k));
}
sb.append("}").append("\n");
return sb.toString();
}
private static String buildGetter(int index) {
return "\tpublic RayObject<R" + index + "> r" + index + "() {\n"
+ "\t\treturn objs[" + index + "];\n"
+ "\t}\n";
}
}
@@ -1,160 +1,77 @@
package org.ray.util.generator;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.ArrayList;
import java.util.List;
import org.ray.util.FileUtil;
import org.ray.util.generator.Composition.Tr;
/**
* Generate Rpc.java
* A util class that generates Rpc.java
*/
public class RpcGenerator {
public class RpcGenerator extends BaseGenerator {
private String build() {
sb = new StringBuilder();
newLine("// generated automatically, do not modify.");
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.funcs.*;");
newLine("");
newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
newLine("class Rpc {");
for (int i = 0; i <= 6; i++) {
buildCalls(i);
}
newLine("}");
return sb.toString();
}
private void buildCalls(int numParameters) {
String funcClass = "RayFunc" + numParameters;
String genericTypes = "";
String callList = "";
for (int i = 1; i <= numParameters; i++) {
genericTypes += "T" + i + ", ";
callList += ", t" + i;
}
String body = String.format("return Ray.internal().call(f%s);", callList);
for (String param : generateParameters(numParameters)) {
indents(1);
newLine(String.format(
"public static <%sR> RayObject<R> call(%s<%sR> f%s) {",
genericTypes, funcClass, genericTypes, param
));
indents(2);
newLine(body);
indents(1);
newLine("}");
}
}
private List<String> generateParameters(int numParameters) {
List<String> res = new ArrayList<>();
dfs(1, numParameters, "", res);
return res;
}
private void dfs(int pos, int max, String cur, List<String> res) {
if (pos > max) {
res.add(cur);
return;
}
cur += ", ";
String nextParameter = String.format("T%d t%d", pos, pos);
dfs(pos + 1, max, cur + nextParameter, res);
nextParameter = String.format("RayObject<T%d> t%d", pos, pos);
dfs(pos + 1, max, cur + nextParameter, res);
}
public static void main(String[] args) throws IOException {
String rootdir = System.getProperty("user.dir") + "/../api/src/main/java";
rootdir += "/org/ray/api/Rpc.java";
FileUtil.overrideFile(rootdir, build());
String path = System.getProperty("user.dir") + "/api/src/main/java";
path += "/org/ray/api/Rpc.java";
FileUtil.overrideFile(path, new RpcGenerator().build());
}
private static String build() {
StringBuilder sb = new StringBuilder();
sb.append("package org.ray.api.internal;\n");
sb.append("import org.ray.api.funcs.*;\n");
sb.append("import org.ray.api.returns.*;\n");
sb.append("import org.ray.api.*;\n");
sb.append("import java.util.Collection;\n");
sb.append("import java.util.Map;\n");
sb.append("@SuppressWarnings({\"rawtypes\", \"unchecked\"})\n");
sb.append("class Rpc {\n");
for (Tr tr : Composition.calculate(Share.MAX_T, Share.MAX_R)) {
buildCall(sb, tr.tcount, tr.rcount);
}
sb.append("}\n");
return sb.toString();
}
private static void buildCall(StringBuilder sb, int tcount, int rcount) {
for (Set<Integer> whichTisFuture : whichTisFutureComposition(tcount)) {
sb.append(buildCall(tcount, rcount, whichTisFuture));
}
}
private static String buildCall(int tcount, int rcount, Set<Integer> whichTisFuture) {
StringBuilder sb = new StringBuilder();
String parameter = (tcount == 0 ? ""
: ", " + Share.buildParameter(tcount, "T", whichTisFuture));
sb.append("\tpublic static <").append(Share.buildClassDeclare(tcount, rcount)).append("> ")
.append(Share.buildRpcReturn(rcount)).append(" call")
.append(rcount == 1 ? "" : (rcount <= 0 ? "_n" : ("_" + rcount))).append("(RayFunc_")
.append(tcount).append("_")
.append(rcount <= 0 ? (rcount == 0 ? "n" : "n_list") : rcount).append("<")
.append(Share.buildClassDeclare(tcount, rcount)).append("> f").append(
rcount <= 0 ? (rcount == 0 ? ", Collection<RID> returnids" : ", Integer returnCount")
: "").append(parameter).append(") {\n");
String nulls = Share.buildRepeat("null",
tcount + (rcount == 0 ? 1/*for first arg map*/ : 0));
String parameterUse = (tcount == 0 ? "" : (", " + Share.buildParameterUse(tcount, "T")));
String labmdaUse = "RayFunc_"
+ tcount + "_" + (rcount <= 0 ? (rcount == 0 ? "n" : "n_list") : rcount)
+ ".class, f";
sb.append("\t\tif (Ray.Parameters().remoteLambda()) {\n");
if (rcount == 1) {
sb.append("\t\t\treturn Ray.internal().call(null, ").append(labmdaUse).append(", 1")
.append(parameterUse).append(").objs[0];")
.append("\n");
} else if (rcount == 0) {
sb.append("\t\t\treturn Ray.internal().callWithReturnLabels(null, ")
.append(labmdaUse).append(", returnids").append(parameterUse).append(");")
.append("\n");
} else if (rcount < 0) {
sb.append("\t\t\treturn Ray.internal().callWithReturnIndices(null, ")
.append(labmdaUse).append(", returnCount").append(parameterUse).append(");")
.append("\n");
} else {
sb.append("\t\t\treturn new RayObjects").append(rcount)
.append("(Ray.internal().call(null, ").append(labmdaUse).append(", ").append(rcount)
.append(parameterUse).append(").objs);")
.append("\n");
}
sb.append("\t\t} else {\n");
if (rcount == 1) {
sb.append("\t\t\treturn Ray.internal().call(null, () -> f.apply(").append(nulls)
.append("), 1").append(parameterUse).append(").objs[0];")
.append("\n");
} else if (rcount == 0) {
sb.append("\t\t\treturn Ray.internal().callWithReturnLabels(null, () -> f.apply(")
.append(nulls).append("), returnids").append(parameterUse).append(");")
.append("\n");
} else if (rcount < 0) {
sb.append("\t\t\treturn Ray.internal().callWithReturnIndices(null, () -> f.apply(")
.append(nulls).append("), returnCount").append(parameterUse).append(");")
.append("\n");
} else {
sb.append("\t\t\treturn new RayObjects").append(rcount)
.append("(Ray.internal().call(null, () -> f.apply(").append(nulls).append("), ")
.append(rcount).append(parameterUse).append(").objs);")
.append("\n");
}
sb.append("\t\t}\n");
sb.append("\t}\n");
return sb.toString();
}
private static Set<Set<Integer>> whichTisFutureComposition(int tcount) {
Set<Set<Integer>> ret = new HashSet<>();
Set<Integer> n = new HashSet<>();
for (int k = 0; k < tcount; k++) {
n.add(k);
}
for (int k = 0; k <= tcount; k++) {
ret.addAll(cnn(n, k));
}
return ret;
}
//pick n numbers in N
private static Set<Set<Integer>> cnn(Set<Integer> bigN, int n) {
C c = new C();
for (int k = 0; k < n; k++) {
c.mul(bigN);
}
return c.vc;
}
static class C {
Set<Set<Integer>> vc;
public C() {
vc = new HashSet<>();
vc.add(new HashSet<>());
}
void mul(Set<Integer> ns) {
Set<Set<Integer>> ret = new HashSet<>();
for (int n : ns) {
ret.addAll(mul(n));
}
this.vc = ret;
}
Set<Set<Integer>> mul(int n) {
Set<Set<Integer>> ret = new HashSet<>();
for (Set<Integer> s : vc) {
if (s.contains(n)) {
continue;
}
Set<Integer> ns = new HashSet<>(s);
ns.add(n);
ret.add(ns);
}
return ret;
}
}
}
@@ -1,134 +0,0 @@
package org.ray.util.generator;
import java.util.Set;
/**
* Share util for generators.
*/
public class Share {
public static final int MAX_T = 6;
public static final int MAX_R = 4;
/**
* T0, T1, T2, T3, R.
*/
public static String buildClassDeclare(int tcount, int rcount) {
StringBuilder sb = new StringBuilder();
for (int k = 0; k < tcount; k++) {
sb.append("T").append(k).append(", ");
}
if (rcount == 0) {
sb.append("R, RID");
} else if (rcount < 0) {
assert (rcount == -1);
sb.append("R");
} else {
for (int k = 0; k < rcount; k++) {
sb.append("R").append(k).append(", ");
}
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
}
return sb.toString();
}
/**
* T0 t0, T1 t1, T2 t2, T3 t3.
*/
public static String buildParameter(int tcount, String tr, Set<Integer> whichTisFuture) {
StringBuilder sb = new StringBuilder();
for (int k = 0; k < tcount; k++) {
sb.append(whichTisFuture != null && whichTisFuture.contains(k)
? "RayObject<" + (tr + k) + ">" : (tr + k)).append(" ").append(tr.toLowerCase())
.append(k).append(", ");
}
if (tcount > 0) {
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
}
return sb.toString();
}
/**
* t0, t1, t2.
*/
public static String buildParameterUse(int tcount, String tr) {
StringBuilder sb = new StringBuilder();
for (int k = 0; k < tcount; k++) {
sb.append(tr.toLowerCase()).append(k).append(", ");
}
if (tcount > 0) {
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
}
return sb.toString();
}
public static String buildParameterUse2(int tcount, int startIndex, String tr, String pre,
String post) {
StringBuilder sb = new StringBuilder();
for (int k = 0; k < tcount; k++) {
sb.append("(").append(tr).append(k).append(")").append(pre).append(k + startIndex)
.append(post).append(", ");
}
if (tcount > 0) {
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
}
return sb.toString();
}
public static String buildFuncReturn(int rcount) {
if (rcount == 0) {
return "Map<RID, R>";
} else if (rcount < 0) {
assert (-1 == rcount);
return "List<R>";
}
if (rcount == 1) {
return "R0";
}
StringBuilder sb = new StringBuilder();
for (int k = 0; k < rcount; k++) {
sb.append("R").append(k).append(", ");
}
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
return "MultipleReturns" + rcount + "<" + sb.toString() + ">";
}
public static String buildRpcReturn(int rcount) {
if (rcount == 0) {
return "RayMap<RID, R>";
} else if (rcount < 0) {
assert (rcount == -1);
return "RayList<R>";
}
if (rcount == 1) {
return "RayObject<R0>";
}
StringBuilder sb = new StringBuilder();
for (int k = 0; k < rcount; k++) {
sb.append("R").append(k).append(", ");
}
sb.deleteCharAt(sb.length() - 1);
sb.deleteCharAt(sb.length() - 1);
return "RayObjects" + rcount + "<" + sb.toString() + ">";
}
public static String buildRepeat(String toRepeat, int count) {
StringBuilder ret = new StringBuilder();
for (int k = 0; k < count; k++) {
ret.append(toRepeat).append(", ");
}
if (count > 0) {
ret.deleteCharAt(ret.length() - 1);
ret.deleteCharAt(ret.length() - 1);
}
return ret.toString();
}
}
@@ -1,5 +1,5 @@
# define default properties here
logging.level=INFO
logging.level=WARN
logging.path=./run/logs
logging.file.name=core
logging.max.log.file.num=10
+4 -90
View File
@@ -79,99 +79,13 @@ locally available.
``Ray.wait`` is used to wait for a list of ``RayObject``\s to be locally available.
It will block the current thread until ``numReturns`` objects are ready or
``timeoutMilliseconds`` has passed. See multi-value support for ``RayList``.
``timeoutMilliseconds`` has passed.
.. code:: java
public static WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeoutMilliseconds);
public static WaitResult<T> wait(RayList<T> waitfor, int numReturns);
public static WaitResult<T> wait(RayList<T> waitfor);
Multi-value API
---------------
Multi-value Types
~~~~~~~~~~~~~~~~~
Multiple ``RayObject``\s can be placed in a single data
structure as a return value or as a ``Ray.call`` parameter through the
following container types.
``MultipleReturnsX<R0, R1, ...>``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This consists of multiple heterogeneous values, with types ``R0``,
``R1``,... respectively. Currently this container type is only
supported as the return type of ``Ray.call``. Therefore you cannot use it
as the type of an input parameter.
``RayList<T>``
^^^^^^^^^^^^^^
This is a list of ``RayObject<T>``\s, which inherits from ``List<T>`` in Java. It
can be used as the type for both a return value and a parameter value.
``RayMap<L, T>``
^^^^^^^^^^^^^^^^
A map of ``RayObject<T>``\s with each indexed using a label with type
``L``, inherited from ``Map<L, T>``. It can be used as the type for both
a return value and a parameter value.
Multiple heterogeneous return values
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
To return multiple heterogeneous values in a remote functions, you can
define your original method's return type as ``MultipleReturnsX`` and
then invoke it with ``Ray.call_X``. Note: ``X`` is the number of return
values, at most 4 values are supported.
Here's an `example <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise05.java>`_.
Return with ``RayList``
~~~~~~~~~~~~~~~~~~~~~~~
To return a list of ``RayObject``\s, you can invoke your method with ``Ray.call_n``.
``Ray.call_n`` is similar to ``Ray.call`` except that it has an additional parameter
``returnCount``, which specifies the number of return values in the list.
Here's an `example <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise06.java>`_.
Return with ``RayMap``
~~~~~~~~~~~~~~~~~~~~~~
This is similar to ``RayList`` case, except that now each returned
``RayObject<R>`` in ``RayMap<L,R>`` has a given label when
``Ray.call_n`` is called.
Here's an `example <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise07.java>`_.
Use ``RayList`` and ``RayMap`` as parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code:: java
public class ListTExample {
public static void main(String[] args) {
Ray.init();
RayList<Integer> ints = new RayList<>();
ints.add(Ray.put(new Integer(1)));
ints.add(Ray.put(new Integer(1)));
ints.add(Ray.put(new Integer(1)));
RayObject<Integer> obj = Ray.call(ListTExample::sum(List<Integer>)ints);
Assert.assertTrue(obj.get().equals(3));
}
@RayRemote
public static int sum(List<Integer> ints) {
int sum = 0;
for (Integer i : ints) {
sum += i;
}
return sum;
}
}
public static WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeoutMilliseconds);
public static WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns);
public static WaitResult<T> wait(List<RayObject<T>> waitfor);
Actor Support
-------------
@@ -11,8 +11,6 @@ import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.spi.model.FunctionArg;
@@ -43,28 +41,9 @@ public class ArgumentsBuilder {
} else { // serialize actor handle
fargs[k].data = Serializer.encode(oarg);
}
} else if (oarg.getClass().equals(RayObject.class)) {
fargs[k].ids = new ArrayList<>();
fargs[k].ids.add(((RayObject) oarg).getId());
} else if (oarg instanceof RayMap) {
fargs[k].ids = new ArrayList<>();
RayMap<?, ?> rm = (RayMap<?, ?>) oarg;
RayMapArg narg = new RayMapArg();
for (Entry e : rm.EntrySet()) {
narg.put(e.getKey(), ((RayObject) e.getValue()).getId());
fargs[k].ids.add(((RayObject) e.getValue()).getId());
}
fargs[k].data = Serializer.encode(narg);
} else if (oarg instanceof RayList) {
fargs[k].ids = new ArrayList<>();
RayList<?> rl = (RayList<?>) oarg;
RayListArg narg = new RayListArg();
for (RayObject e : rl.Objects()) {
// narg.add(e.getId()); // we don't really need to use the ids
fargs[k].ids.add(e.getId());
}
fargs[k].data = Serializer.encode(narg);
} else if (checkSimpleValue(oarg)) {
fargs[k].data = Serializer.encode(oarg);
} else {
@@ -120,42 +99,11 @@ public class ArgumentsBuilder {
} else if (farg.data == null) { // only ids, big data or single object id
assert (farg.ids.size() == 1);
realArgs[raIndex] = RayRuntime.getInstance().get(farg.ids.get(0));
} else { // both id and data, could be RayList or RayMap only
Object idBag = Serializer.decode(farg.data, classLoader);
if (idBag instanceof RayMapArg) {
Map newMap = new HashMap<>();
RayMapArg<?> oldmap = (RayMapArg<?>) idBag;
assert (farg.ids.size() == oldmap.size());
for (Entry<?, UniqueID> e : oldmap.entrySet()) {
newMap.put(e.getKey(), RayRuntime.getInstance().get(e.getValue()));
}
realArgs[raIndex] = newMap;
} else {
List newlist = new ArrayList<>();
for (UniqueID old : farg.ids) {
newlist.add(RayRuntime.getInstance().get(old));
}
realArgs[raIndex] = newlist;
}
}
}
return Pair.of(current, realArgs);
}
//for recognition
public static class RayMapArg<K> extends HashMap<K, UniqueID> {
private static final long serialVersionUID = 8529310038241410256L;
}
//for recognition
public static class RayListArg<K> extends ArrayList<K> {
private static final long serialVersionUID = 8529310038241410256L;
}
public static class RayActorId implements Serializable {
private static final long serialVersionUID = 3993646395842605166L;
@@ -2,12 +2,9 @@ package org.ray.core;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.UniqueID;
import org.ray.api.returns.MultipleReturns;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.exception.TaskExecutionException;
@@ -34,8 +31,7 @@ public class InvocationExecutor {
// execute
try {
//RayLog.core.debug(task.toString());
executeInternal(task, pr, taskdesc);
executeInternal(task, pr);
} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
if (!task.actorId.isNil() && RayRuntime.getInstance().getLocalActor(task.actorId) == null) {
ex = new TaskExecutionException("Task " + taskdesc + " execution on actor " + task.actorId
@@ -67,23 +63,10 @@ public class InvocationExecutor {
}
}
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr,
String taskdesc)
private static void executeInternal(TaskSpec task, Pair<ClassLoader, RayMethod> pr)
throws IllegalAccessException, IllegalArgumentException, InvocationTargetException {
Method m = pr.getRight().invokable;
Map<?, UniqueID> userRayReturnIdMap = null;
Class<?> returnType = m.getReturnType(); // TODO: not ready for multiple return etc.
boolean hasMultiReturn = false;
if (task.returnIds != null && task.returnIds.length > 0) {
hasMultiReturn = UniqueIdHelper.hasMultipleReturnOrNotFromReturnObjectId(task.returnIds[0]);
}
Pair<Object, Object[]> realArgs = ArgumentsBuilder.unwrap(task, m, pr.getLeft());
if (hasMultiReturn && returnType.equals(Map.class)) {
//first arg is Map<user_return_id,ray_return_id>
userRayReturnIdMap = (Map<?, UniqueID>) realArgs.getRight()[0];
realArgs.getRight()[0] = userRayReturnIdMap.keySet();
}
// execute
Object result = null;
@@ -97,47 +80,7 @@ public class InvocationExecutor {
if (task.returnIds == null || task.returnIds.length == 0) {
return;
}
// set result into storage
if (MultipleReturns.class.isAssignableFrom(returnType)) {
MultipleReturns returns = (MultipleReturns) result;
if (task.returnIds.length != returns.getValues().length) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.getValues().length);
}
for (int k = 0; k < returns.getValues().length; k++) {
RayRuntime.getInstance().putRaw(task.returnIds[k], returns.getValues()[k]);
}
} else if (hasMultiReturn && returnType.equals(Map.class)) {
Map<?, ?> returns = (Map<?, ?>) result;
if (task.returnIds.length != returns.size()) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.size());
}
for (Entry<?, ?> e : returns.entrySet()) {
Object userReturnId = e.getKey();
Object value = e.getValue();
UniqueID returnId = userRayReturnIdMap.get(userReturnId);
RayRuntime.getInstance().putRaw(returnId, value);
}
} else if (hasMultiReturn && returnType.equals(List.class)) {
List returns = (List) result;
if (task.returnIds.length != returns.size()) {
throw new RuntimeException("Mismatched return object count for task " + taskdesc
+ " " + task.returnIds.length + " vs "
+ returns.size());
}
for (int k = 0; k < returns.size(); k++) {
RayRuntime.getInstance().putRaw(task.returnIds[k], returns.get(k));
}
} else {
RayRuntime.getInstance().putRaw(task.returnIds[0], result);
}
RayRuntime.getInstance().putRaw(task.returnIds[0], result);
}
private static String formatTaskExecutionExceptionMsg(TaskSpec task, String funcName) {
@@ -5,7 +5,6 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -13,13 +12,10 @@ import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayApi;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.internal.RayFunc;
import org.ray.api.funcs.RayFunc;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
import org.ray.spi.LocalSchedulerProxy;
@@ -261,29 +257,15 @@ public abstract class RayRuntime implements RayApi {
}
@Override
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
return objectStoreProxy.wait(waitfor, numReturns, timeout);
}
@Override
public RayObjects call(UniqueID taskId, Class<?> funcCls, RayFunc lambda, int returnCount,
Object... args) {
return worker.rpc(taskId, funcCls, lambda, returnCount, args);
public RayObject call(RayFunc func, Object... args) {
return worker.submit(func, args);
}
@Override
public <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Collection<RIDT> returnids, Object... args) {
return worker.rpcWithReturnLabels(taskId, funcCls, lambda, returnids, args);
}
@Override
public <R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Integer returnCount, Object... args) {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
@@ -1,24 +1,17 @@
package org.ray.core;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayActor;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.internal.RayFunc;
import org.ray.api.funcs.RayFunc;
import org.ray.spi.LocalSchedulerProxy;
import org.ray.spi.model.RayInvocation;
import org.ray.spi.model.RayMethod;
import org.ray.spi.model.TaskSpec;
import org.ray.util.LambdaUtils;
import org.ray.util.MethodId;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -75,79 +68,35 @@ public class Worker {
}
private RayObjects taskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
private RayObject taskSubmit(UniqueID taskId, MethodId methodId, Object[] args) {
RayInvocation ri = createRemoteInvocation(methodId, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnCount, multiReturn);
return scheduler.submit(taskId, ri);
}
private RayObjects actorTaskSubmit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args,
private RayObject actorTaskSubmit(UniqueID taskId, MethodId methodId, Object[] args,
RayActor<?> actor) {
RayInvocation ri = createRemoteInvocation(methodId, args, actor);
RayObjects returnObjs = scheduler.submit(taskId, ri, returnCount + 1, multiReturn);
actor.setTaskCursor(returnObjs.pop().getId());
return returnObjs;
RayObject ret = scheduler.submitActorTask(taskId, ri);
actor.setTaskCursor(ret.getId());
return ret;
}
private RayObjects submit(UniqueID taskId,
MethodId methodId,
int returnCount,
boolean multiReturn,
Object[] args) {
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
public RayObject submit(RayFunc func, Object[] args) {
MethodId methodId = methodIdOf(func);
UniqueID taskId = UniqueIdHelper.nextTaskId(-1);
if (args.length > 0 && args[0].getClass().equals(RayActor.class)) {
return actorTaskSubmit(taskId, methodId, returnCount, multiReturn, args,
(RayActor<?>) args[0]);
return actorTaskSubmit(taskId, methodId, args, (RayActor<?>) args[0]);
} else {
return taskSubmit(taskId, methodId, returnCount, multiReturn, args);
return taskSubmit(taskId, methodId, args);
}
}
public RayObjects rpc(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
int returnCount, Object[] args) {
MethodId mid = methodIdOf(lambda);
return submit(taskId, mid, returnCount, false, args);
}
public RayObjects rpcCreateActor(UniqueID taskId, UniqueID createActorId,
Class<?> funcCls, RayFunc lambda, int returnCount, Object[] args) {
public RayObject createActor(UniqueID taskId, UniqueID createActorId,
RayFunc func, Object[] args) {
Preconditions.checkNotNull(taskId);
MethodId mid = methodIdOf(lambda);
MethodId mid = methodIdOf(func);
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, createActorId, ri, returnCount, false);
}
public <R, RIDT> RayMap<RIDT, R> rpcWithReturnLabels(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Collection<RIDT> returnids,
Object[] args) {
MethodId mid = methodIdOf(lambda);
if (taskId == null) {
taskId = UniqueIdHelper.nextTaskId(-1);
}
RayInvocation ri = createRemoteInvocation(mid, args, RayActor.nil);
return scheduler.submit(taskId, ri, returnids);
}
public <R> RayList<R> rpcWithReturnIndices(UniqueID taskId, Class<?> funcCls,
RayFunc lambda, Integer returnCount,
Object[] args) {
MethodId mid = methodIdOf(lambda);
RayObjects objs = submit(taskId, mid, returnCount, true, args);
RayList<R> rets = new RayList<>();
for (RayObject obj : objs.getObjs()) {
rets.add(obj);
}
return rets;
return scheduler.submitActorCreationTask(taskId, createActorId, ri);
}
private RayInvocation createRemoteInvocation(MethodId methodId, Object[] args,
@@ -167,8 +116,7 @@ public class Worker {
}
private MethodId methodIdOf(RayFunc serialLambda) {
MethodId mid = MethodId.fromSerializedLambda(serialLambda);
return mid;
return MethodId.fromSerializedLambda(serialLambda);
}
public UniqueID getCurrentTaskId() {
@@ -1,16 +1,8 @@
package org.ray.spi;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.ArgumentsBuilder;
@@ -34,66 +26,37 @@ public class LocalSchedulerProxy {
this.scheduler = scheduler;
}
public RayObjects submit(UniqueID taskId, RayInvocation invocation, int returnCount,
boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
public RayObject submit(UniqueID taskId, RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return new RayObjects(returnIds);
return new RayObject(returnIds[0]);
}
public RayObjects submit(UniqueID taskId, UniqueID createActorId, RayInvocation invocation,
int returnCount, boolean multiReturn) {
UniqueID[] returnIds = buildReturnIds(taskId, returnCount, multiReturn);
public RayObject submitActorTask(UniqueID taskId, RayInvocation invocation) {
// add one for the dummy return ID
UniqueID[] returnIds = genReturnIds(taskId, 2);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return new RayObject(returnIds[0]);
}
public RayObject submitActorCreationTask(UniqueID taskId, UniqueID createActorId,
RayInvocation invocation) {
UniqueID[] returnIds = genReturnIds(taskId, 1);
this.doSubmit(invocation, taskId, returnIds, createActorId);
return new RayObjects(returnIds);
return new RayObject(returnIds[0]);
}
public <R, RIDT> RayMap<RIDT, R> submit(UniqueID taskId, RayInvocation invocation,
Collection<RIDT> userReturnIds) {
UniqueID[] returnIds = buildReturnIds(taskId, userReturnIds.size(), true);
RayMap<RIDT, R> ret = new RayMap<>();
Map<RIDT, UniqueID> returnidmapArg = new HashMap<>();
int index = 0;
for (RIDT userReturnId : userReturnIds) {
if (returnidmapArg.containsKey(userReturnId)) {
RayLog.core.error("TaskId " + taskId + " userReturnId is duplicate " + userReturnId);
continue;
}
returnidmapArg.put(userReturnId, returnIds[index]);
ret.put(userReturnId, new RayObject<>(returnIds[index]));
index++;
// generate the return ids of a task.
private UniqueID[] genReturnIds(UniqueID taskId, int numReturns) {
UniqueID[] ret = new UniqueID[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdHelper.taskComputeReturnId(taskId, i, false);
}
if (index < returnIds.length) {
UniqueID[] newReturnIds = new UniqueID[index];
System.arraycopy(returnIds, 0, newReturnIds, 0, index);
returnIds = newReturnIds;
}
Object[] args = invocation.getArgs();
Object[] newargs;
if (args == null) {
newargs = new Object[] {returnidmapArg};
} else {
newargs = new Object[args.length + 1];
newargs[0] = returnidmapArg;
System.arraycopy(args, 0, newargs, 1, args.length);
}
invocation.setArgs(newargs);
this.doSubmit(invocation, taskId, returnIds, UniqueID.nil);
return ret;
}
// build Object IDs of return values.
private UniqueID[] buildReturnIds(UniqueID taskId, int returnCount, boolean multiReturn) {
UniqueID[] returnIds = new UniqueID[returnCount];
for (int k = 0; k < returnCount; k++) {
returnIds[k] = UniqueIdHelper.taskComputeReturnId(taskId, k, multiReturn);
}
return returnIds;
}
private void doSubmit(RayInvocation invocation, UniqueID taskId,
UniqueID[] returnIds, UniqueID createActorId) {
private void doSubmit(RayInvocation invocation, UniqueID taskId, UniqueID[] returnIds,
UniqueID createActorId) {
final TaskSpec current = WorkerContext.currentTask();
TaskSpec task = new TaskSpec();
@@ -110,15 +73,9 @@ public class LocalSchedulerProxy {
task.taskId = taskId;
task.returnIds = returnIds;
task.cursorId = invocation.getActor() != null ? invocation.getActor().getTaskCursor() : null;
task.resources = ResourceUtil
.getResourcesMapFromArray(invocation.getRemoteAnnotation().resources());
task.resources = ResourceUtil.getResourcesMapFromArray(
invocation.getRemoteAnnotation().resources());
//WorkerContext.onSubmitTask();
RayLog.core.info(
"Task " + taskId + " submitted, functionId = " + task.functionId + " actorId = "
+ task.actorId + ", driverId = " + task.driverId + ", return_ids = " + Arrays
.toString(returnIds) + ", currentTask " + WorkerContext.currentTask().taskId
+ " cursorId = " + task.cursorId);
scheduler.submitTask(task);
}
@@ -153,16 +110,16 @@ public class LocalSchedulerProxy {
return ids;
}
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
List<byte[]> readys = scheduler.wait(getIdBytes(ids), timeout, numReturns);
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {
@@ -4,7 +4,6 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
@@ -91,9 +90,9 @@ public class ObjectStoreProxy {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
public <T> WaitResult<T> wait(List<RayObject<T>> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
for (RayObject<T> obj : waitfor) {
ids.add(obj.getId());
}
List<byte[]> readys;
@@ -103,9 +102,9 @@ public class ObjectStoreProxy {
readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns);
}
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
List<RayObject<T>> readyObjs = new ArrayList<>();
List<RayObject<T>> remainObjs = new ArrayList<>();
for (RayObject<T> obj : waitfor) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {
@@ -10,7 +10,7 @@ import org.apache.arrow.plasma.PlasmaClient;
import org.ray.api.RayActor;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.funcs.RayFunc_2_1;
import org.ray.api.funcs.RayFunc2;
import org.ray.core.RayRuntime;
import org.ray.core.UniqueIdHelper;
import org.ray.core.WorkerContext;
@@ -248,15 +248,13 @@ public class RayNativeRuntime extends RayRuntime {
RayActor<T> actor = new RayActor<>(actorId);
UniqueID cursorId;
RayFunc_2_1<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
cursorId = worker.rpcCreateActor(
RayFunc2<byte[], String, byte[]> createActorLambda = RayNativeRuntime::createActorInActor;
cursorId = worker.createActor(
createTaskId,
actorId,
RayFunc_2_1.class,
createActorLambda,
1,
new Object[]{actorId.getBytes(), cls.getName()}
).getObjs()[0].getId();
).getId();
actor.setTaskCursor(cursorId);
return actor;
}
@@ -1,181 +0,0 @@
package org.ray.api.experiment.mr;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
/**
* mimic the MapReduce interface atop of Ray API (in memory version).
*/
public class MemoryMapReduce<TInputT, TMapKeyT, TMapValueT, TReduceValueT> {
public List<Pair<TMapKeyT, TMapValueT>> map(TInputT input) throws Exception {
throw new Exception("not implemented");
}
public TReduceValueT reduce(TMapKeyT k, List<TMapValueT> values) throws Exception {
throw new Exception("not implemented");
}
//
// main logic to execute this map-reduce with remote mappers and reducers
//
// @param inputs - given input file segments each containing List<TInput>
// @return output file segments each containing SortedMap<TMapKey, TReduceValue>
//
public SortedMap<TMapKeyT, TReduceValueT> run(List<List<TInputT>> inputs, int mapperCount,
Integer reducerCount) {
// start all mappers
ArrayList<RayList<SortedMap<TMapKeyT, List<TMapValueT>>>> mappers = new ArrayList<>();
int inputCountPerMap = inputs.size() / mapperCount;
int index = 0;
for (int i = 0; i < mapperCount; i++) {
if (index >= inputs.size()) {
break;
}
List<List<TInputT>> perMapInputs = new ArrayList<>();
for (int j = 0; j < inputCountPerMap && index < inputs.size(); j++) {
perMapInputs.add(inputs.get(index++));
}
mappers.add(Ray.call_n(MemoryMapReduce::internalMap, reducerCount, reducerCount, perMapInputs,
this.getClass().getName()));
}
// BSP barrier for all mappers to be finished
// this is unnecessary as later on we call map.get() to wait for mappers to be completed
// Ray.wait(mappers.toArray(new RayObject<?>[0]), mappers.size(), 0);
// start all reducers
ArrayList<RayObject<SortedMap<TMapKeyT, TReduceValueT>>> reducers = new ArrayList<>();
for (int i = 0; i < reducerCount; i++) {
// collect states from mappers for this reducer
RayList<SortedMap<TMapKeyT, List<TMapValueT>>> fromMappers = new RayList();
for (int j = 0; j < mapperCount; j++) {
assert (mappers.get(j).size() == reducerCount);
fromMappers.add(mappers.get(j).Get(i));
}
// start this reducer with given input
reducers.add(Ray.call(MemoryMapReduce::internalReduce,
(List<SortedMap<TMapKeyT, List<TMapValueT>>>) fromMappers, this.getClass().getName()));
}
// BSP barrier for all reducers to be finished
// this is unnecessary coz we will call reducer.get() to wait for their completion
// Ray.wait(reducers.toArray(new RayObject<?>[0]), reducers.size(), 0);
// collect outputs
TreeMap<TMapKeyT, TReduceValueT> outputs = new TreeMap<>();
for (RayObject<SortedMap<TMapKeyT, TReduceValueT>> r : reducers) {
r.get().forEach(outputs::put);
}
return outputs;
}
//
// given a set of input files, output another set of files for reducers
//
// @param inputs - for each input file, it contains List<TInput>
// @return %recuderCount% number of files for the reducers numbering from 0 to reducerCount - 1.
// for each output file, it contains SortedMap<TMapKey, List<TMapValue>>
//
@RayRemote
public static <TInputT, TMapKeyT, TMapValueT, TReduceValueT> List<SortedMap<TMapKeyT,
List<TMapValueT>>> internalMap(
Integer reducerCount,
List<List<TInputT>> inputs,
String mrClassName) {
MemoryMapReduce<TInputT, TMapKeyT, TMapValueT, TReduceValueT> mr;
try {
mr = (MemoryMapReduce<TInputT, TMapKeyT, TMapValueT, TReduceValueT>) Class
.forName(mrClassName).getConstructors()[0].newInstance();
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException | SecurityException | ClassNotFoundException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
return null;
}
ArrayList<SortedMap<TMapKeyT, List<TMapValueT>>> out = new ArrayList<>();
for (int i = 0; i < reducerCount; i++) {
out.add(new TreeMap<>());
}
for (List<TInputT> inputSeg : inputs) {
for (TInputT input : inputSeg) {
try {
List<Pair<TMapKeyT, TMapValueT>> result = mr.map(input);
for (Pair<TMapKeyT, TMapValueT> pr : result) {
int reducerIndex = Math.abs(pr.getKey().hashCode()) % reducerCount;
out.get(reducerIndex).computeIfAbsent(pr.getKey(), k -> new ArrayList<>());
out.get(reducerIndex).get(pr.getKey()).add(pr.getValue());
}
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
return out;
}
//public static class Helper {
//
// given a set of input sets from all mappers, performance merge sort and apply reducer function
//
// @param inputs each file contains SortedMap<TMapKey, List<TMapValue>>
// @return an output file contains SortedMap<TMapKey, TReduceValue>
//
@RayRemote
public static <TInputT, TMapKeyT, TMapValueT, TReduceValueT> SortedMap<TMapKeyT, TReduceValueT>
internalReduce(
List<SortedMap<TMapKeyT, List<TMapValueT>>> inputs,
String mrClassName) {
MemoryMapReduce<TInputT, TMapKeyT, TMapValueT, TReduceValueT> mr;
try {
mr = (MemoryMapReduce<TInputT, TMapKeyT, TMapValueT, TReduceValueT>) Class
.forName(mrClassName).getConstructors()[0].newInstance();
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException
| InvocationTargetException | SecurityException | ClassNotFoundException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
return null;
}
// merge inputs from many mappers
TreeMap<TMapKeyT, List<TMapValueT>> minputs = new TreeMap<>();
for (SortedMap<TMapKeyT, List<TMapValueT>> input : inputs) {
for (Map.Entry<TMapKeyT, List<TMapValueT>> entry : input.entrySet()) {
if (!minputs.containsKey(entry.getKey())) {
minputs.put(entry.getKey(), new ArrayList<>());
}
minputs.get(entry.getKey()).addAll(entry.getValue());
}
}
// reduce
TreeMap<TMapKeyT, TReduceValueT> out = new TreeMap<>();
for (Map.Entry<TMapKeyT, List<TMapValueT>> entry : minputs.entrySet()) {
try {
out.put(entry.getKey(), mr.reduce(entry.getKey(), entry.getValue()));
} catch (Exception e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
return out;
}
// }
}
@@ -6,7 +6,6 @@ import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.util.logger.RayLog;
/**
* Hello world.
@@ -15,37 +14,25 @@ import org.ray.util.logger.RayLog;
public class HelloWorldTest {
@RayRemote
public static String sayHello() {
String ret = "he";
ret += "llo";
RayLog.rapp.info("real say hello");
//throw new RuntimeException("+++++++++++++++++++++hello exception");
return ret;
private static String hello() {
return "hello";
}
@RayRemote
public static String sayWorld() {
String ret = "world";
ret += "!";
return ret;
private static String world() {
return "world!";
}
@RayRemote
public static String merge(String hello, String world) {
private static String merge(String hello, String world) {
return hello + "," + world;
}
@Test
public void test() {
String helloWorld = sayHelloWorld();
RayLog.rapp.info(helloWorld);
public void testHelloWorld() {
RayObject<String> hello = Ray.call(HelloWorldTest::hello);
RayObject<String> world = Ray.call(HelloWorldTest::world);
String helloWorld = Ray.call(HelloWorldTest::merge, hello, world).get();
Assert.assertEquals("hello,world!", helloWorld);
Assert.assertTrue(Ray.call(TypesTest::sayBool).get());
}
public String sayHelloWorld() {
RayObject<String> hello = Ray.call(HelloWorldTest::sayHello);
RayObject<String> world = Ray.call(HelloWorldTest::sayWorld);
return Ray.call(HelloWorldTest::merge, hello, world).get();
}
}
@@ -4,14 +4,13 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.funcs.RayFunc_0_1;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.api.funcs.RayFunc0;
import org.ray.api.funcs.RayFunc1;
import org.ray.api.funcs.RayFunc3;
import org.ray.util.MethodId;
import org.ray.util.logger.RayLog;
@@ -36,7 +35,7 @@ public class LambdaUtilsTest {
}
}
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc_3_1<T0, T1, T2, R0> f, int n,
public static <T0, T1, T2, R0> void testRemoteLambdaParse(RayFunc3<T0, T1, T2, R0> f, int n,
boolean forceNew, boolean debug)
throws Exception {
if (debug) {
@@ -52,7 +51,7 @@ public class LambdaUtilsTest {
(end - start) / n));
}
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc_3_1<T0, T1, T2, R0> f, int n,
public static <T0, T1, T2, R0> void testRemoteLambdaSerde(RayFunc3<T0, T1, T2, R0> f, int n,
boolean de, boolean debug)
throws Exception {
if (debug) {
@@ -66,7 +65,7 @@ public class LambdaUtilsTest {
out.close();
if (de) {
ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()));
RayFunc_3_1 def = (RayFunc_3_1) in.readObject();
RayFunc3 def = (RayFunc3) in.readObject();
in.close();
if (debug) {
RayLog.core.info("de#" + def.getClass().getName());
@@ -80,28 +79,28 @@ public class LambdaUtilsTest {
(end - start) / n));
}
public static void testCall0(RayFunc_0_1 f) {
public static void testCall0(RayFunc0 f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL0);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall1(RayFunc_1_1<T, R> f, T t) {
public static <T, R> void testCall1(RayFunc1<T, R> f, T t) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL1);
Assert.assertTrue(mid.isStatic);
}
public static <T, R> void testCall2(RayFunc_1_1<T, R> f) {
public static <T, R> void testCall2(RayFunc1<T, R> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL2);
Assert.assertTrue(!mid.isStatic);
}
public static <T0, T1, T2, R0> void testCall3(RayFunc_3_1<T0, T1, T2, R0> f) {
public static <T0, T1, T2, R0> void testCall3(RayFunc3<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f);
RayLog.core.info(mid.toString());
Assert.assertEquals(mid.load(), CALL3);
@@ -1,71 +0,0 @@
package org.ray.api.test;
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.experiment.mr.MemoryMapReduce;
/**
* test the MapReduce interface.
*/
@RunWith(MyRunner.class)
public class MemoryWordCountTest {
@Test
public void test() {
List<List<String>> iinputs = new ArrayList<>();
List<String> inputs = new ArrayList<>();
inputs.add("1 3 5 7 9");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("1 3 5 7 9");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
inputs = new ArrayList<>();
inputs.add("1 2 3 4 5 6 7 8 9 0");
inputs.add("1 3 5 7 9");
inputs.add("0 2 4 6 8");
iinputs.add(inputs);
MemoryWordCount wc = new MemoryWordCount();
SortedMap<String, Integer> result = wc.run(iinputs, 2, 2);
Assert.assertEquals(6, (int) result.get("0"));
Assert.assertEquals(6, (int) result.get("1"));
Assert.assertEquals(6, (int) result.get("2"));
Assert.assertEquals(6, (int) result.get("3"));
Assert.assertEquals(6, (int) result.get("4"));
Assert.assertEquals(6, (int) result.get("5"));
Assert.assertEquals(6, (int) result.get("6"));
Assert.assertEquals(6, (int) result.get("7"));
Assert.assertEquals(6, (int) result.get("8"));
Assert.assertEquals(6, (int) result.get("9"));
}
public static class MemoryWordCount extends MemoryMapReduce<String, String, Integer, Integer> {
public List<Pair<String, Integer>> map(String input) {
ArrayList<Pair<String, Integer>> counts = new ArrayList<>();
for (String s : input.split(" ")) {
counts.add(Pair.of(s, 1));
}
return counts;
}
public Integer reduce(String k, List<Integer> values) {
return values.size();
}
}
}
@@ -1,18 +1,16 @@
package org.ray.api.test;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.util.LambdaUtils;
import org.ray.api.funcs.RayFunc3;
import org.ray.util.MethodId;
import org.ray.util.logger.RayLog;
public class MethodIdTest {
public static <T0, T1, T2, R0> MethodId fromLambda(RayFunc_3_1<T0, T1, T2, R0> f) {
public static <T0, T1, T2, R0> MethodId fromLambda(RayFunc3<T0, T1, T2, R0> f) {
MethodId mid = MethodId.fromSerializedLambda(f, true);
return mid;
}
@@ -1,11 +1,13 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.util.logger.RayLog;
@@ -24,12 +26,9 @@ public class RayMethodsTest {
RayObject<String> s2Id = Ray.put(String.valueOf("World!"));
RayObject<Object> n1Id = Ray.put(null);
RayList<String> waitIds = new RayList<>();
waitIds.add(s1Id);
waitIds.add(s2Id);
WaitResult<String> readys = Ray.wait(waitIds, 2);
WaitResult<String> res = Ray.wait(ImmutableList.of(s1Id, s2Id), 2);
List<String> ss = readys.getReadyOnes().get();
List<String> ss = res.getReadyOnes().stream().map(RayObject::get).collect(Collectors.toList());
int i1 = i1Id.get();
double f1 = f1Id.get();
Object n1 = n1Id.get();
@@ -1,5 +1,6 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
@@ -55,7 +56,7 @@ public class ResourcesManagementTest {
// This is a case that can't satisfy required resources.
final RayObject<Integer> result2 = Ray.call(ResourcesManagementTest::echo2, 200);
WaitResult<Integer> waitResult = Ray.wait(result2, 1000);
WaitResult<Integer> waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000);
Assert.assertEquals(0, waitResult.getReadyOnes().size());
Assert.assertEquals(1, waitResult.getRemainOnes().size());
@@ -72,7 +73,7 @@ public class ResourcesManagementTest {
// This is a case that can't satisfy required resources.
RayActor<ResourcesManagementTest.Echo2> echo2 = Ray.create(Echo2.class);
final RayObject<Integer> result2 = Ray.call(Echo2::echo, echo2, 100);
WaitResult<Integer> waitResult = Ray.wait(result2, 1000);
WaitResult<Integer> waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000);
Assert.assertEquals(0, waitResult.getReadyOnes().size());
Assert.assertEquals(1, waitResult.getRemainOnes().size());
@@ -1,20 +0,0 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
/**
* local test in IDE, for class lazy load.
*/
@RunWith(MyRunner.class)
public class TwoClassTest {
@Test
public void testLocal() {
Assert.assertTrue(
Ray.call(TypesTest::sayBool).get());//call function in other class which may not be loaded
}
}
@@ -1,202 +1,82 @@
package org.ray.api.test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.returns.MultipleReturns2;
import org.ray.api.returns.RayObjects2;
/**
* types test.
* Test returning different types.
*/
@RunWith(MyRunner.class)
public class TypesTest {
@RayRemote
public static int sayInt() {
private static int testInt() {
return 1;
}
@RayRemote
public static byte sayByte() {
private static byte testByte() {
return 1;
}
@RayRemote
public static short sayShort() {
private static short testShort() {
return 1;
}
@RayRemote
public static long sayLong() {
private static long testLong() {
return 1;
}
@RayRemote
public static double sayDouble() {
private static double testDouble() {
return 1;
}
@RayRemote
public static float sayFloat() {
private static float testFloat() {
return 1;
}
@RayRemote
public static boolean sayBool() {
private static boolean testBool() {
return true;
}
@RayRemote
public static Object sayReference() {
return "object";
private static String testString() {
return "foo";
}
@RayRemote
public static MultipleReturns2<Integer, String> sayReferences() {
return new MultipleReturns2<>(123, "123");
private static List<Integer> testList() {
return ImmutableList.of(1, 2, 3);
}
@RayRemote
public static Map<Integer, String> sayReferencesN(Collection<Integer> userReturnIds,
String prefix, String suffix) {
Map<Integer, String> ret = new HashMap<>();
for (Integer returnid : userReturnIds) {
ret.put(returnid, prefix + returnid + suffix);
}
return ret;
}
@RayRemote
public static List<Integer> sayArray(Integer returnCount) {
ArrayList<Integer> rets = new ArrayList<>();
for (int i = 0; i < returnCount; i++) {
rets.add(i);
}
return rets;
}
@RayRemote
public static Integer sayRayFuture() {
return 123;
}
@RayRemote
public static MultipleReturns2<Integer, String> sayRayFutures() {
return new MultipleReturns2<>(123, "123");
}
@RayRemote
public static Map<Integer, String> sayRayFuturesN(
Collection<Integer/*user's custom return_id*/> userReturnIds,
String prefix) {
Map<Integer, String> ret = new HashMap<>();
for (int id : userReturnIds) {
ret.put(id, prefix + id);
}
return ret;
}
@RayRemote
public static int sayReadRayList(List<Integer> ints) {
int sum = 0;
for (Integer i : ints) {
sum += i;
}
return sum;
}
@RayRemote
public static int sayReadRayMap(Map<String, Integer> ints) {
int sum = 0;
for (Integer i : ints.values()) {
sum += i;
}
return sum;
private static Map<String, Integer> testMap() {
return ImmutableMap.of("1", 1, "2", 2);
}
@Test
public void test() {
sayTypes();
}
public void sayTypes() {
Assert.assertEquals(1, (int) Ray.call(TypesTest::sayInt).get());
Assert.assertEquals(1, (byte) Ray.call(TypesTest::sayByte).get());
Assert.assertEquals(1, (short) Ray.call(TypesTest::sayShort).get());
Assert.assertEquals(1, (long) Ray.call(TypesTest::sayLong).get());
Assert.assertEquals(1.0, Ray.call(TypesTest::sayDouble).get(), 0.0);
Assert.assertEquals(1.0f, Ray.call(TypesTest::sayFloat).get(), 0.0);
Assert.assertEquals(true, Ray.call(TypesTest::sayBool).get());
Assert.assertEquals("object", Ray.call(TypesTest::sayReference).get());
RayObjects2<Integer, String> refs = Ray.call_2(TypesTest::sayReferences);
Assert.assertEquals(123, (int) refs.r0().get());
Assert.assertEquals("123", refs.r1().get());
RayMap<Integer, String> futureRs = Ray.call_n(TypesTest::sayReferencesN,
Arrays.asList(1, 2, 4, 3), "n_refs_", "_suffix");
for (Entry<Integer, RayObject<String>> fne : futureRs.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_refs_" + fne.getKey() + "_suffix");
}
RayMap<Integer, String> futureRs2 = Ray.call_n(TypesTest::sayReferencesN,
Arrays.asList(1), "n_refs_", "_suffix");
for (Entry<Integer, RayObject<String>> fne : futureRs2.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_refs_" + fne.getKey() + "_suffix");
}
RayObject<Integer> future = Ray.call(TypesTest::sayRayFuture);
Assert.assertEquals(123, (int) future.get());
RayObjects2<Integer, String> futures = Ray.call_2(TypesTest::sayRayFutures);
Assert.assertEquals(123, (int) futures.r0().get());
Assert.assertEquals("123", futures.r1().get());
RayMap<Integer, String> futureNs = Ray.call_n(TypesTest::sayRayFuturesN,
Arrays.asList(1, 2, 4, 3), "n_futures_");
for (Entry<Integer, RayObject<String>> fne : futureNs.EntrySet()) {
Assert.assertEquals(fne.getValue().get(), "n_futures_" + fne.getKey());
}
RayList<Integer> ns = Ray.call_n(TypesTest::sayArray, 10, 10);
for (int i = 0; i < 10; i++) {
Assert.assertEquals(i, (int) ns.Get(i).get());
}
RayList<Integer> ns2 = Ray.call_n(TypesTest::sayArray, 1, 1);
Assert.assertEquals(0, (int) ns2.Get(0).get());
RayObject<List<Integer>> ns3 = Ray.call(TypesTest::sayArray, 1);
Assert.assertEquals(0, (int) ns3.get().get(0));
RayList<Integer> ints = new RayList<>();
ints.add(Ray.call(TypesTest::sayInt));
ints.add(Ray.call(TypesTest::sayInt));
ints.add(Ray.call(TypesTest::sayInt));
// TODO: when RayParameters.use_remote_lambda is on, we have to explicitly
// cast RayList and RayMap to List and map explicitly, so that the parameter
// types of the lambdas can be correctly deducted.
RayObject<Integer> collection = Ray.call(TypesTest::sayReadRayList, (List<Integer>) ints);
Assert.assertEquals(3, (int) collection.get());
RayMap<String, Integer> namedInts = new RayMap();
namedInts.put("a", Ray.call(TypesTest::sayInt));
namedInts.put("b", Ray.call(TypesTest::sayInt));
namedInts.put("c", Ray.call(TypesTest::sayInt));
RayObject<Integer> collection2 = Ray
.call(TypesTest::sayReadRayMap, (Map<String, Integer>) namedInts);
Assert.assertEquals(3, (int) collection2.get());
Assert.assertEquals(1, (int) Ray.call(TypesTest::testInt).get());
Assert.assertEquals(1, (byte) Ray.call(TypesTest::testByte).get());
Assert.assertEquals(1, (short) Ray.call(TypesTest::testShort).get());
Assert.assertEquals(1, (long) Ray.call(TypesTest::testLong).get());
Assert.assertEquals(1.0, Ray.call(TypesTest::testDouble).get(), 0.0);
Assert.assertEquals(1.0f, Ray.call(TypesTest::testFloat).get(), 0.0);
Assert.assertEquals(true, Ray.call(TypesTest::testBool).get());
Assert.assertEquals("foo", Ray.call(TypesTest::testString).get());
Assert.assertEquals(ImmutableList.of(1, 2, 3), Ray.call(TypesTest::testList).get());
Assert.assertEquals(ImmutableMap.of("1", 1, "2", 2), Ray.call(TypesTest::testMap).get());
}
}
@@ -1,39 +0,0 @@
package org.ray.api.test;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.UniqueID;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.core.RayRuntime;
import org.ray.core.UniqueIdHelper;
@RunWith(MyRunner.class)
public class UIdTest {
@RayRemote
public static String hi(Integer i) {
return "hi" + i;
}
@Test
public void test() {
UniqueID tid = UniqueIdHelper.nextTaskId(0xdeadbeefL);
UniqueIdHelper.setTest(tid, true);
System.out.println("Tested task id = " + tid);
RayFunc_1_1<Integer, String> f = UIdTest::hi;
RayObject<String> result = new RayObject<>(
RayRuntime.getInstance().call(
tid,
RayFunc_1_1.class,
f,
1, 1
).getObjs()[0].getId()
);
System.out.println("Tested task return object id = " + result.getId());
Assert.assertEquals("hi1", result.get());
}
}
@@ -1,10 +1,11 @@
package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.WaitResult;
@@ -33,16 +34,14 @@ public class WaitTest {
RayObject<String> obj1 = Ray.call(WaitTest::hi);
RayObject<String> obj2 = Ray.call(WaitTest::delayHi);
RayList<String> waitfor = new RayList<>();
waitfor.add(obj1);
waitfor.add(obj2);
List<RayObject<String>> waitfor = ImmutableList.of(obj1, obj2);
WaitResult<String> waitResult = Ray.wait(waitfor, 2, 2 * 1000);
RayList<String> readys = waitResult.getReadyOnes();
List<RayObject<String>> readys = waitResult.getReadyOnes();
if (!readys.isEmpty()) {
Assert.assertEquals(1, waitResult.getReadyOnes().size());
Assert.assertEquals(1, waitResult.getRemainOnes().size());
Assert.assertEquals("hi", readys.get(0));
Assert.assertEquals("hi", readys.get(0).get());
} else {
Assert.assertEquals(0, waitResult.getReadyOnes().size());
Assert.assertEquals(2, waitResult.getRemainOnes().size());
+1 -7
View File
@@ -22,10 +22,4 @@ To run a exercise case, set the ``RAY_CONFIG`` env variable and run the followin
`Exercise 4 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise04.java>`_: Use ``Ray.wait`` to ignore stragglers.
`Exercise 5 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise05.java>`_: Use multiple heterogeneous return values.
`Exercise 6 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise06.java>`_: Usage of ``RayList<T>``.
`Exercise 7 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise07.java>`_: Usage of ``RayMap<L, T>``.
`Exercise 8 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise08.java>`_: Actor Support of create Actor and call Actor method.
`Exercise 5 <https://github.com/ray-project/ray/tree/master/java/tutorial/src/main/java/org/ray/exercise/Exercise08.java>`_: Actor Support of create Actor and call Actor method.
@@ -1,7 +1,8 @@
package org.ray.exercise;
import com.google.common.collect.ImmutableList;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.WaitResult;
@@ -44,26 +45,24 @@ public class Exercise04 {
public static void main(String[] args) throws Exception {
try {
Ray.init();
RayObject<String> o1 = Ray.call(Exercise04::f1);
RayObject<String> o2 = Ray.call(Exercise04::f2);
RayObject<String> o3 = Ray.call(Exercise04::f3);
RayList<String> rayList = new RayList<>();
rayList.add(o1);
rayList.add(o2);
rayList.add(o3);
List<RayObject<String>> waitList = ImmutableList.of(
Ray.call(Exercise04::f1),
Ray.call(Exercise04::f2),
Ray.call(Exercise04::f3)
);
// Ray.wait will block until specified number of results are ready
// or specified timeout have passed.
// In this case, the result of f3 will be ignored.
WaitResult<String> waitResult = Ray.wait(rayList, 2, 3000);
RayList<String> readyOnes = waitResult.getReadyOnes();
RayList<String> remainOnes = waitResult.getRemainOnes();
WaitResult<String> waitResult = Ray.wait(waitList, 2, 3000);
List<RayObject<String>> readyOnes = waitResult.getReadyOnes();
List<RayObject<String>> remainOnes = waitResult.getRemainOnes();
System.out.println("Number of readyOnes: " + readyOnes.size());
for (int i = 0; i < readyOnes.size(); i++) {
System.out.println("The value of readyOnes " + i + " is " + readyOnes.get(i));
System.out.println("The value of readyOnes " + i + " is " + readyOnes.get(i).get());
}
System.out.println("Number of remainOnes: " + remainOnes.size());
for (int i = 0; i < remainOnes.size(); i++) {
System.out.println("The value of remainOnes " + i + " is " + remainOnes.get(i));
System.out.println("The value of remainOnes " + i + " is " + remainOnes.get(i).get());
}
} catch (Throwable t) {
t.printStackTrace();
@@ -1,26 +1,26 @@
package org.ray.exercise;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.api.returns.MultipleReturns2;
import org.ray.api.returns.RayObjects2;
import org.ray.core.RayRuntime;
/**
* Use multiple heterogeneous return values
* Java worker support at most four heterogeneous return values,
* To call such remote functions, use {@code Ray.call_X} as follows.
* Show usage of actors.
*/
public class Exercise05 {
public static void main(String[] args) {
try {
Ray.init();
RayObjects2<Integer, String> refs = Ray.call_2(Exercise05::sayMultiRet);
Integer obj1 = refs.r0().get();
String obj2 = refs.r1().get();
System.out.println(obj1);
System.out.println(obj2);
// `Ray.create` creates an actor instance.
RayActor<Adder> adder = Ray.create(Adder.class);
// Use `Ray.call(actor, parameters)` to call an actor method.
RayObject<Integer> result1 = Ray.call(Adder::add, adder, 1);
System.out.println(result1.get());
RayObject<Integer> result2 = Ray.call(Adder::add, adder, 10);
System.out.println(result2.get());
} catch (Throwable t) {
t.printStackTrace();
} finally {
@@ -29,10 +29,20 @@ public class Exercise05 {
}
/**
* A remote function that returns multiple heterogeneous values.
* An example actor.
*/
// `@RayRemote` annotation also converts a normal class to an actor.
@RayRemote
public static MultipleReturns2<Integer, String> sayMultiRet() {
return new MultipleReturns2<Integer, String>(123, "123");
public static class Adder {
public Adder() {
sum = 0;
}
public int add(int n) {
return sum += n;
}
private int sum;
}
}
@@ -1,46 +0,0 @@
package org.ray.exercise;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayList;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.core.RayRuntime;
/**
* Show usage of RayList.
* RayList is a list of {@code RayObject}s, inherited from {@code List}.
* It can be used as the type for both return values and parameters.
*
*/
public class Exercise06 {
public static void main(String[] args) {
try {
Ray.init();
// The result is a `RayList`.
RayList<Integer> ns = Ray.call_n(Exercise06::sayList, 10, 10);
for (int i = 0; i < 10; i++) {
RayObject<Integer> obj = ns.Get(i);
System.out.println(obj.get());
}
} catch (Throwable t) {
t.printStackTrace();
} finally {
RayRuntime.getInstance().cleanUp();
}
}
/**
* A remote function that returns a list.
*/
@RayRemote
public static List<Integer> sayList(Integer count) {
ArrayList<Integer> rets = new ArrayList<>();
for (int i = 0; i < count; i++) {
rets.add(i);
}
return rets;
}
}
@@ -1,48 +0,0 @@
package org.ray.exercise;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.core.RayRuntime;
/**
* Show usage of RayMap.
* {@code RayMap} is a map of {@code RayObject}s, inherited from {@code Map}.
* It can be used as the type for both return values and parameters.
*/
public class Exercise07 {
public static void main(String[] args) {
try {
Ray.init();
RayMap<Integer, String> ns = Ray.call_n(Exercise07::sayMap,
Arrays.asList(1, 2, 4, 3), "n_futures_");
for (Map.Entry<Integer, RayObject<String>> ne : ns.EntrySet()) {
Integer key = ne.getKey();
RayObject<String> obj = ne.getValue();
System.out.println(obj.get());
}
} catch (Throwable t) {
t.printStackTrace();
} finally {
RayRuntime.getInstance().cleanUp();
}
}
/**
* A remote function that returns a map.
*/
@RayRemote()
public static Map<Integer, String> sayMap(Collection<Integer> ids, String prefix) {
Map<Integer, String> ret = new HashMap<>();
for (int id : ids) {
ret.put(id, prefix + id);
}
return ret;
}
}
@@ -1,48 +0,0 @@
package org.ray.exercise;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayRemote;
import org.ray.core.RayRuntime;
/**
* Show usage of actors.
*/
public class Exercise08 {
public static void main(String[] args) {
try {
Ray.init();
// `Ray.create` creates an actor instance.
RayActor<Adder> adder = Ray.create(Adder.class);
// Use `Ray.call(actor, parameters)` to call an actor method.
RayObject<Integer> result1 = Ray.call(Adder::add, adder, 1);
System.out.println(result1.get());
RayObject<Integer> result2 = Ray.call(Adder::add, adder, 10);
System.out.println(result2.get());
} catch (Throwable t) {
t.printStackTrace();
} finally {
RayRuntime.getInstance().cleanUp();
}
}
/**
* An example actor.
*/
// `@RayRemote` annotation also converts a normal class to an actor.
@RayRemote
public static class Adder {
public Adder() {
sum = 0;
}
public int add(int n) {
return sum += n;
}
private int sum;
}
}