[Java] Single-process mode (#4245)

This commit is contained in:
bibabolynn
2019-03-05 13:50:20 +08:00
committed by Hao Chen
parent fa8c07dd19
commit c73d5086f3
17 changed files with 280 additions and 59 deletions
@@ -66,7 +66,8 @@ public abstract class AbstractRayRuntime implements RayRuntime {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.driverResourcePath);
worker = new Worker(this);
workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.driverId);
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.driverId, rayConfig.runMode);
}
/**
@@ -17,12 +17,12 @@ public class RayDevRuntime extends AbstractRayRuntime {
public void start() {
store = new MockObjectStore(this);
objectStoreProxy = new ObjectStoreProxy(this, null);
rayletClient = new MockRayletClient(this, store);
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
}
@Override
public void shutdown() {
// nothing to do
rayletClient.destroy();
}
public MockObjectStore getObjectStore() {
@@ -2,6 +2,7 @@ package org.ray.runtime;
import com.google.common.base.Preconditions;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
@@ -34,12 +35,17 @@ public class WorkerContext {
*/
private long mainThreadId;
/**
* The run-mode of this worker.
*/
private RunMode runMode;
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) {
mainThreadId = Thread.currentThread().getId();
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(UniqueId::randomId);
this.runMode = runMode;
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
workerId = driverId;
@@ -65,10 +71,12 @@ public class WorkerContext {
* be called from the main thread.
*/
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
Preconditions.checkState(
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
);
if (runMode == RunMode.CLUSTER) {
Preconditions.checkState(
Thread.currentThread().getId() == mainThreadId,
"This method should only be called from the main thread."
);
}
Preconditions.checkNotNull(task);
this.currentTaskId.set(task.taskId);
@@ -63,6 +63,11 @@ public class RayConfig {
public final String driverResourcePath;
public final String pythonWorkerCommand;
/**
* Number of threads that execute tasks.
*/
public final int numberExecThreadsForDevRuntime;
private void validate() {
if (workerMode == WorkerMode.WORKER) {
Preconditions.checkArgument(redisAddress != null,
@@ -196,6 +201,9 @@ public class RayConfig {
driverResourcePath = null;
}
// Number of threads that execute tasks.
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
// validate config
validate();
LOGGER.debug("Created config: {}", this);
@@ -6,11 +6,12 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.ObjectStoreLink.ObjectStoreData;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.raylet.MockRayletClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -20,13 +21,21 @@ import org.slf4j.LoggerFactory;
public class MockObjectStore implements ObjectStoreLink {
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectStore.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private final RayDevRuntime runtime;
private final Map<UniqueId, byte[]> data = new ConcurrentHashMap<>();
private final Map<UniqueId, byte[]> metadata = new ConcurrentHashMap<>();
private MockRayletClient scheduler = null;
private final List<Consumer<UniqueId>> objectPutCallbacks;
public MockObjectStore(RayDevRuntime runtime) {
this.runtime = runtime;
this.objectPutCallbacks = new ArrayList<>();
}
public void addObjectPutCallback(Consumer<UniqueId> callback) {
this.objectPutCallbacks.add(callback);
}
@Override
@@ -41,34 +50,56 @@ public class MockObjectStore implements ObjectStoreLink {
if (metadataValue != null) {
metadata.put(uniqueId, metadataValue);
}
if (scheduler != null) {
scheduler.onObjectPut(uniqueId);
UniqueId id = new UniqueId(objectId);
for (Consumer<UniqueId> callback : objectPutCallbacks) {
callback.accept(id);
}
}
@Override
public byte[] get(byte[] objectId, int timeoutMs, boolean isMetadata) {
return get(new byte[][] {objectId}, timeoutMs, isMetadata).get(0);
}
@Override
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
final Map<UniqueId, byte[]> dataMap = isMetadata ? metadata : data;
ArrayList<byte[]> rets = new ArrayList<>(objectIds.length);
for (byte[] objId : objectIds) {
UniqueId uniqueId = new UniqueId(objId);
LOGGER.info("{} is notified for objectid {}",logPrefix(), uniqueId);
rets.add(dataMap.get(uniqueId));
}
return rets;
return get(objectIds, timeoutMs)
.stream()
.map(data -> isMetadata ? data.data : data.metadata)
.collect(Collectors.toList());
}
@Override
public List<ObjectStoreData> get(byte[][] objectIds, int timeoutMs) {
int ready = 0;
int remainingTime = timeoutMs;
boolean firstCheck = true;
while (ready < objectIds.length && remainingTime > 0) {
if (!firstCheck) {
int sleepTime = Math.min(remainingTime, GET_CHECK_INTERVAL_MS);
try {
Thread.sleep(sleepTime);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while sleeping.");
}
remainingTime -= sleepTime;
}
ready = 0;
for (byte[] id : objectIds) {
if (data.containsKey(new UniqueId(id))) {
ready += 1;
}
}
firstCheck = false;
}
ArrayList<ObjectStoreData> rets = new ArrayList<>();
// TODO(yuhguo): make ObjectStoreData's constructor public.
for (byte[] objId : objectIds) {
UniqueId uniqueId = new UniqueId(objId);
for (byte[] id : objectIds) {
try {
Constructor<ObjectStoreData> constructor = ObjectStoreData.class.getConstructor(
byte[].class, byte[].class);
byte[].class, byte[].class);
constructor.setAccessible(true);
rets.add(constructor.newInstance(metadata.get(uniqueId), data.get(uniqueId)));
rets.add(constructor.newInstance(metadata.get(new UniqueId(id)),
data.get(new UniqueId(id))));
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -119,7 +150,8 @@ public class MockObjectStore implements ObjectStoreLink {
return data.containsKey(id);
}
public void registerScheduler(MockRayletClient s) {
scheduler = s;
public void free(UniqueId id) {
data.remove(id);
metadata.remove(id);
}
}
@@ -1,66 +1,143 @@
package org.ray.runtime.raylet;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.Worker;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A mock implementation of RayletClient, used in single process mode.
*/
public class MockRayletClient implements RayletClient {
private final Map<UniqueId, Map<UniqueId, TaskSpec>> waitTasks = new ConcurrentHashMap<>();
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
private final Map<UniqueId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private final RayDevRuntime runtime;
private final ExecutorService exec;
private final Deque<Worker> idleWorkers;
private final Map<UniqueId, Worker> actorWorkers;
public MockRayletClient(RayDevRuntime runtime, MockObjectStore store) {
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
this.runtime = runtime;
this.store = store;
store.registerScheduler(this);
this.store = runtime.getObjectStore();
store.addObjectPutCallback(this::onObjectPut);
// The thread pool that executes tasks in parallel.
exec = Executors.newFixedThreadPool(numberThreads);
idleWorkers = new LinkedList<>();
actorWorkers = new HashMap<>();
}
public void onObjectPut(UniqueId id) {
Map<UniqueId, TaskSpec> bucket = waitTasks.get(id);
if (bucket != null) {
waitTasks.remove(id);
for (TaskSpec ts : bucket.values()) {
submitTask(ts);
public synchronized void onObjectPut(UniqueId id) {
Set<TaskSpec> tasks = waitingTasks.get(id);
if (tasks != null) {
waitingTasks.remove(id);
for (TaskSpec taskSpec : tasks) {
submitTask(taskSpec);
}
}
}
/**
* Get a worker from the worker pool to run the given task.
*/
private Worker getWorker(TaskSpec task) {
if (task.isActorTask()) {
return actorWorkers.get(task.actorId);
}
Worker worker;
if (idleWorkers.size() > 0) {
worker = idleWorkers.pop();
} else {
worker = new Worker(runtime);
}
if (task.isActorCreationTask()) {
actorWorkers.put(task.actorCreationId, worker);
}
return worker;
}
/**
* Return the worker to the worker pool.
*/
private void returnWorker(Worker worker) {
idleWorkers.push(worker);
}
@Override
public void submitTask(TaskSpec task) {
UniqueId id = isTaskReady(task);
if (id == null) {
runtime.getWorker().execute(task);
public synchronized void submitTask(TaskSpec task) {
LOGGER.debug("Submitting task: {}.", task);
Set<UniqueId> unreadyObjects = getUnreadyObjects(task);
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
exec.submit(() -> {
Worker worker = getWorker(task);
try {
worker.execute(task);
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
if (task.isActorCreationTask() || task.isActorTask()) {
UniqueId[] returnIds = task.returnIds;
store.put(returnIds[returnIds.length - 1].getBytes(),
new byte[]{}, new byte[]{});
}
} finally {
if (!task.isActorCreationTask() && !task.isActorTask()) {
returnWorker(worker);
}
}
});
} else {
Map<UniqueId, TaskSpec> bucket = waitTasks
.computeIfAbsent(id, id_ -> new ConcurrentHashMap<>());
bucket.put(id, task);
// If some dependencies aren't ready yet, put this task in waiting list.
for (UniqueId id : unreadyObjects) {
waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task);
}
}
}
private UniqueId isTaskReady(TaskSpec spec) {
private Set<UniqueId> getUnreadyObjects(TaskSpec spec) {
Set<UniqueId> unreadyObjects = new HashSet<>();
// Check whether task arguments are ready.
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
if (!store.isObjectReady(arg.id)) {
return arg.id;
unreadyObjects.add(arg.id);
}
}
}
return null;
// Check whether task dependencies are ready.
for (UniqueId id : spec.getExecutionDependencies()) {
if (!store.isObjectReady(id)) {
unreadyObjects.add(id);
}
}
return unreadyObjects;
}
@Override
public TaskSpec getTask() {
throw new RuntimeException("invalid execution flow here");
@@ -84,18 +161,36 @@ public class MockRayletClient implements RayletClient {
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId) {
return new WaitResult<T>(
waitFor,
ImmutableList.of()
);
timeoutMs, UniqueId currentTaskId) {
if (waitFor == null || waitFor.isEmpty()) {
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
}
byte[][] ids = new byte[waitFor.size()][];
for (int i = 0; i < waitFor.size(); i++) {
ids[i] = waitFor.get(i).getId().getBytes();
}
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
List<byte[]> result = store.get(ids, timeoutMs, false);
for (int i = 0; i < waitFor.size(); i++) {
if (result.get(i) != null) {
readyList.add(waitFor.get(i));
} else {
unreadyList.add(waitFor.get(i));
}
}
return new WaitResult<>(readyList, unreadyList);
}
@Override
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
return;
for (UniqueId id : objectIds) {
store.free(id);
}
}
@Override
public UniqueId prepareCheckpoint(UniqueId actorId) {
throw new NotImplementedException("Not implemented.");
@@ -105,4 +200,9 @@ public class MockRayletClient implements RayletClient {
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
throw new NotImplementedException("Not implemented.");
}
@Override
public void destroy() {
exec.shutdown();
}
}
@@ -29,4 +29,6 @@ public interface RayletClient {
UniqueId prepareCheckpoint(UniqueId actorId);
void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId);
void destroy();
}
@@ -95,4 +95,12 @@ ray {
}
}
// ----------------------------
// configurations under SINGLE_PROCESS mode
// ----------------------------
dev-runtime {
// Number of threads that you process tasks
execution-parallelism: 5
}
}
+17 -3
View File
@@ -16,8 +16,22 @@ echo "${check_style}"
[[ ${check_style} =~ "BUILD FAILURE" ]] && exit 1
# test raylet
mvn_test=$(mvn test)
echo "${mvn_test}"
[[ ${mvn_test} =~ "BUILD SUCCESS" ]] || exit 1
mvn test | tee mvn_test
if [ `grep -c "BUILD FAILURE" mvn_test` -eq '0' ]; then
rm mvn_test
echo "Tests passed under CLUSTER mode!"
else
rm mvn_test
exit 1
fi
# test raylet under SINGLE_PROCESS mode
mvn test -Dray.run-mode=SINGLE_PROCESS | tee dev_mvn_test
if [ `grep -c "BUILD FAILURE" dev_mvn_test` -eq '0' ]; then
rm dev_mvn_test
echo "Tests passed under SINGLE_PROCESS mode!"
else
rm dev_mvn_test
exit 1
fi
popd
@@ -0,0 +1,15 @@
package org.ray.api;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.RunMode;
import org.testng.SkipException;
public class TestUtils {
public static void skipTestUnderSingleProcess() {
AbstractRayRuntime runtime = (AbstractRayRuntime)Ray.internal();
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
throw new SkipException("Skip case.");
}
}
}
@@ -9,6 +9,7 @@ import java.util.concurrent.TimeUnit;
import org.ray.api.Checkpointable;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.TestUtils;
import org.ray.api.annotation.RayRemote;
import org.ray.api.exception.RayActorException;
import org.ray.api.id.UniqueId;
@@ -33,6 +34,11 @@ public class ActorReconstructionTest extends BaseTest {
}
}
@Override
public void beforeEachCase() {
TestUtils.skipTestUnderSingleProcess();
}
@Test
public void testActorReconstruction() throws InterruptedException, IOException {
ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1);
@@ -12,6 +12,7 @@ public class BaseTest {
System.setProperty("ray.resources", "CPU:4,RES-A:4");
beforeInitRay();
Ray.init();
beforeEachCase();
}
@AfterMethod
@@ -37,4 +38,8 @@ public class BaseTest {
protected void afterShutdownRay() {
}
protected void beforeEachCase() {
}
}
@@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList;
import java.util.concurrent.TimeUnit;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayObjectImpl;
@@ -16,6 +17,11 @@ public class ClientExceptionTest extends BaseTest {
private static final Logger LOGGER = LoggerFactory.getLogger(ClientExceptionTest.class);
@Override
public void beforeEachCase() {
TestUtils.skipTestUnderSingleProcess();
}
@Test
public void testWaitAndCrash() {
UniqueId randomId = UniqueId.randomId();
@@ -3,6 +3,7 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.exception.RayActorException;
import org.ray.api.exception.RayTaskException;
import org.ray.api.exception.RayWorkerException;
@@ -54,6 +55,11 @@ public class FailureTest extends BaseTest {
}
}
@Override
public void beforeEachCase() {
TestUtils.skipTestUnderSingleProcess();
}
@Test
public void testNormalTaskFailure() {
assertTaskFailedWithRayTaskException(Ray.call(FailureTest::badFunc));
@@ -1,7 +1,6 @@
package org.ray.api.test;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -17,7 +16,6 @@ public class RayConfigTest {
Assert.assertEquals("/path/to/ray", rayConfig.rayHome);
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
Assert.assertEquals(RunMode.CLUSTER, rayConfig.runMode);
System.setProperty("ray.home", "");
rayConfig = RayConfig.create();
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
import org.ray.api.options.ActorCreationOptions;
@@ -29,6 +30,11 @@ public class ResourcesManagementTest extends BaseTest {
}
}
@Override
public void beforeEachCase() {
TestUtils.skipTestUnderSingleProcess();
}
@Test
public void testMethods() {
CallOptions callOptions1 = new CallOptions(ImmutableMap.of("CPU", 4.0, "GPU", 0.0));
@@ -6,6 +6,7 @@ import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.id.UniqueId;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -16,6 +17,11 @@ public class StressTest extends BaseTest {
return x;
}
@Override
public void beforeEachCase() {
TestUtils.skipTestUnderSingleProcess();
}
@Test
public void testSubmittingTasks() {
for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) {