diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index e667a3229..bfbf64d81 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -71,14 +71,14 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public RayObject put(T obj) { UniqueId objectId = UniqueIdUtil.computePutId( - workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); + workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); put(objectId, obj); return new RayObjectImpl<>(objectId); } public void put(UniqueId objectId, T obj) { - UniqueId taskId = workerContext.getCurrentTask().taskId; + UniqueId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj, null); } @@ -92,8 +92,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { */ public RayObject putSerialized(byte[] obj) { UniqueId objectId = UniqueIdUtil.computePutId( - workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); - UniqueId taskId = workerContext.getCurrentTask().taskId; + workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); + UniqueId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); objectStoreProxy.putSerialized(objectId, obj, null); return new RayObjectImpl<>(objectId); @@ -108,7 +108,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public List get(List objectIds) { boolean wasBlocked = false; - UniqueId taskId = workerContext.getCurrentThreadTaskId(); try { int numObjectIds = objectIds.size(); @@ -117,7 +116,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { List> fetchBatches = splitIntoBatches(objectIds, FETCH_BATCH_SIZE); for (List batch : 
fetchBatches) { - rayletClient.fetchOrReconstruct(batch, true, taskId); + rayletClient.fetchOrReconstruct(batch, true, workerContext.getCurrentTaskId()); } // Get the objects. We initially try to get the objects immediately. @@ -144,7 +143,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { splitIntoBatches(unreadyList, FETCH_BATCH_SIZE); for (List batch : reconstructBatches) { - rayletClient.fetchOrReconstruct(batch, false, taskId); + rayletClient.fetchOrReconstruct(batch, false, workerContext.getCurrentTaskId()); } List> results = objectStoreProxy @@ -171,7 +170,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { } if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), taskId); + LOGGER.debug("Got objects {} for task {}.", Arrays.toString(objectIds.toArray()), + workerContext.getCurrentTaskId()); } List finalRet = new ArrayList<>(); @@ -182,13 +182,13 @@ public abstract class AbstractRayRuntime implements RayRuntime { return finalRet; } catch (RayException e) { - LOGGER.error("Failed to get objects for task {}.", taskId, e); + LOGGER.error("Failed to get objects for task {}.", workerContext.getCurrentTaskId(), e); throw e; } finally { // If there were objects that we weren't able to get locally, let the local // scheduler know that we're now unblocked. 
if (wasBlocked) { - rayletClient.notifyUnblocked(taskId); + rayletClient.notifyUnblocked(workerContext.getCurrentTaskId()); } } } @@ -217,7 +217,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { @Override public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { return rayletClient.wait(waitList, numReturns, - timeoutMs, workerContext.getCurrentThreadTaskId()); + timeoutMs, workerContext.getCurrentTaskId()); } @Override @@ -277,9 +277,8 @@ public abstract class AbstractRayRuntime implements RayRuntime { */ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, boolean isActorCreationTask, BaseTaskOptions taskOptions) { - final TaskSpec current = workerContext.getCurrentTask(); - UniqueId taskId = rayletClient.generateTaskId(current.driverId, - current.taskId, workerContext.nextCallIndex()); + UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), + workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 
1 : 2; UniqueId[] returnIds = genReturnIds(taskId, numReturns); @@ -304,11 +303,11 @@ public abstract class AbstractRayRuntime implements RayRuntime { if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; } - RayFunction rayFunction = functionManager.getFunction(current.driverId, func); + RayFunction rayFunction = functionManager.getFunction(workerContext.getCurrentDriverId(), func); return new TaskSpec( - current.driverId, + workerContext.getCurrentDriverId(), taskId, - current.taskId, + workerContext.getCurrentTaskId(), -1, actorCreationId, maxActorReconstruction, diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index f47993246..139abdf63 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -39,7 +39,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { path += ":"; } - path += rayConfig.libraryPath.stream().collect(Collectors.joining(":")); + path += String.join(":", rayConfig.libraryPath); // This is a hack to reset library path at runtime, // see https://stackoverflow.com/questions/15409223/. 
@@ -80,7 +80,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.rayletSocketName, workerContext.getCurrentWorkerId(), rayConfig.workerMode == WorkerMode.WORKER, - workerContext.getCurrentTask().taskId + workerContext.getCurrentDriverId() ); // register diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 3531b7ed8..929f343be 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -43,8 +43,7 @@ public class Worker { RayFunction rayFunction = runtime.getFunctionManager() .getFunction(spec.driverId, spec.functionDescriptor); // Set context - runtime.getWorkerContext().setCurrentTask(spec); - runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); + runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); // Get local actor object and arguments. Object actor = spec.isActorTask() ? 
runtime.localActors.get(spec.actorId) : null; @@ -67,6 +66,7 @@ public class Worker { LOGGER.error("Error executing task " + spec, e); runtime.put(returnId, new RayException("Error executing task " + spec, e)); } finally { + runtime.getWorkerContext().setCurrentTask(null, null); Thread.currentThread().setContextClassLoader(oldLoader); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 785581086..b97a08b52 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,9 +1,6 @@ package org.ray.runtime; import com.google.common.base.Preconditions; -import java.util.HashMap; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.id.UniqueId; import org.ray.runtime.config.WorkerMode; import org.ray.runtime.task.TaskSpec; @@ -11,123 +8,114 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class WorkerContext { + private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class); - /** - * Worker id. - */ private UniqueId workerId; - /** - * Current task. - */ - private TaskSpec currentTask; + private ThreadLocal currentTaskId; /** - * Current class loader. + * Number of objects that have been put from current task. */ + private ThreadLocal putIndex; + + /** + * Number of tasks that have been submitted from current task. + */ + private ThreadLocal taskIndex; + + private UniqueId currentDriverId; + private ClassLoader currentClassLoader; - /** - * How many puts have been done by current task. - */ - private AtomicInteger currentTaskPutCount; - - /** - * How many calls have been done by current task. - */ - private AtomicInteger currentTaskCallCount; - /** * The ID of main thread which created the worker context. 
*/ private long mainThreadId; - /** - * If the multi-threading warning message has been logged. - */ - private AtomicBoolean multiThreadingWarned; + public WorkerContext(WorkerMode workerMode, UniqueId driverId) { - workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId(); - currentTaskPutCount = new AtomicInteger(0); - currentTaskCallCount = new AtomicInteger(0); - currentClassLoader = null; - currentTask = createDummyTask(workerMode, driverId); mainThreadId = Thread.currentThread().getId(); - multiThreadingWarned = new AtomicBoolean(false); + taskIndex = ThreadLocal.withInitial(() -> 0); + putIndex = ThreadLocal.withInitial(() -> 0); + currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); + if (workerMode == WorkerMode.DRIVER) { + workerId = driverId; + currentTaskId.set(UniqueId.randomId()); + currentDriverId = driverId; + currentClassLoader = null; + } else { + workerId = UniqueId.randomId(); + setCurrentTask(null, null); + } } /** - * Get the current thread's task ID. - * This returns the assigned task ID if called on the main thread, else a - * random task ID. + * @return For the main thread, this method returns the ID of this worker's current running task; + * for other threads, this method returns a random ID. */ - public UniqueId getCurrentThreadTaskId() { - UniqueId taskId; - if (Thread.currentThread().getId() == mainThreadId) { - taskId = currentTask.taskId; + public UniqueId getCurrentTaskId() { + return currentTaskId.get(); + } + + /** + * Set the current task which is being executed by the current worker. Note, this method can only + * be called from the main thread. + */ + public void setCurrentTask(TaskSpec task, ClassLoader classLoader) { + Preconditions.checkState( + Thread.currentThread().getId() == mainThreadId, + "This method should only be called from the main thread." 
+ ); + if (task != null) { + currentTaskId.set(task.taskId); + currentDriverId = task.driverId; } else { - taskId = UniqueId.randomId(); - if (multiThreadingWarned.compareAndSet(false, true)) { - LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " + - "may lead to deadlock if the main thread blocks on this " + - "thread and there are not enough resources to execute " + - "more tasks"); - } + currentTaskId.set(UniqueId.NIL); + currentDriverId = UniqueId.NIL; } - - Preconditions.checkState(!taskId.isNil()); - return taskId; - } - - public void setWorkerId(UniqueId workerId) { - this.workerId = workerId; - } - - public TaskSpec getCurrentTask() { - return currentTask; + taskIndex.set(0); + putIndex.set(0); + currentClassLoader = classLoader; } + /** + * Increment the put index and return the new value. + */ public int nextPutIndex() { - return currentTaskPutCount.incrementAndGet(); + putIndex.set(putIndex.get() + 1); + return putIndex.get(); } - public int nextCallIndex() { - return currentTaskCallCount.incrementAndGet(); + /** + * Increment the task index and return the new value. + */ + public int nextTaskIndex() { + taskIndex.set(taskIndex.get() + 1); + return taskIndex.get(); } + /** + * @return The ID of the current worker. + */ public UniqueId getCurrentWorkerId() { return workerId; } + /** + * @return If this worker is a driver, this method returns the driver ID; Otherwise, it returns + * the driver ID of the current running task. + */ + public UniqueId getCurrentDriverId() { + return currentDriverId; + } + + /** + * @return The class loader which is associated with the current driver. 
+ */ public ClassLoader getCurrentClassLoader() { return currentClassLoader; } - public void setCurrentTask(TaskSpec currentTask) { - this.currentTask = currentTask; - currentTaskCallCount.set(0); - currentTaskPutCount.set(0); - } - - public void setCurrentClassLoader(ClassLoader currentClassLoader) { - this.currentClassLoader = currentClassLoader; - } - - private TaskSpec createDummyTask(WorkerMode workerMode, UniqueId driverId) { - return new TaskSpec( - driverId, - workerMode == WorkerMode.DRIVER ? UniqueId.randomId() : UniqueId.NIL, - UniqueId.NIL, - 0, - UniqueId.NIL, - 0, - UniqueId.NIL, - UniqueId.NIL, - 0, - null, - null, - new HashMap<>(), - null); - } } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 0e3c70ed9..466233e9b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -95,7 +95,7 @@ public class MockObjectStore implements ObjectStoreLink { } private String logPrefix() { - return runtime.getWorkerContext().getCurrentTask().taskId + "-" + getUserTrace() + " -> "; + return runtime.getWorkerContext().getCurrentTaskId() + "-" + getUserTrace() + " -> "; } private String getUserTrace() { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 511b24b22..b3df8d1cb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -79,6 +79,9 @@ public class RayletClientImpl implements RayletClient { @Override public void submitTask(TaskSpec spec) { LOGGER.debug("Submitting task: {}", spec); + Preconditions.checkState(!spec.parentTaskId.isNil()); + Preconditions.checkState(!spec.driverId.isNil()); + ByteBuffer 
info = convertTaskSpecToFlatbuffer(spec); byte[] cursorId = null; if (!spec.getExecutionDependencies().isEmpty()) { diff --git a/python/ray/actor.py b/python/ray/actor.py index 54baeddb0..5d3d67baa 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -7,6 +7,7 @@ import hashlib import inspect import logging import sys +import threading import traceback import ray.cloudpickle as pickle @@ -225,8 +226,7 @@ class ActorMethod(object): self._method_name, args=args, kwargs=kwargs, - num_return_vals=num_return_vals, - dependency=self._actor._ray_actor_cursor) + num_return_vals=num_return_vals) class ActorClass(object): @@ -525,13 +525,13 @@ class ActorHandle(object): self._ray_actor_method_cpus = actor_method_cpus self._ray_actor_driver_id = actor_driver_id self._ray_new_actor_handles = [] + self._ray_actor_lock = threading.Lock() def _actor_method_call(self, method_name, args=None, kwargs=None, - num_return_vals=None, - dependency=None): + num_return_vals=None): """Method execution stub for an actor handle. This is the function that executes when @@ -570,41 +570,34 @@ class ActorHandle(object): return getattr(worker.actors[self._ray_actor_id], method_name)(*copy.deepcopy(args)) - # Add the execution dependency. - if dependency is None: - execution_dependencies = [] - else: - execution_dependencies = [dependency] - is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") function_descriptor = FunctionDescriptor( self._ray_module_name, method_name, self._ray_class_name) - object_ids = worker.submit_task( - function_descriptor, - args, - actor_id=self._ray_actor_id, - actor_handle_id=self._ray_actor_handle_id, - actor_counter=self._ray_actor_counter, - is_actor_checkpoint_method=is_actor_checkpoint_method, - actor_creation_dummy_object_id=( - self._ray_actor_creation_dummy_object_id), - execution_dependencies=execution_dependencies, - new_actor_handles=self._ray_new_actor_handles, - # We add one for the dummy return ID. 
- num_return_vals=num_return_vals + 1, - resources={"CPU": self._ray_actor_method_cpus}, - placement_resources={}, - driver_id=self._ray_actor_driver_id) - # Update the actor counter and cursor to reflect the most recent - # invocation. - self._ray_actor_counter += 1 - # The last object returned is the dummy object that should be - # passed in to the next actor method. Do not return it to the user. - self._ray_actor_cursor = object_ids.pop() - # We have notified the backend of the new actor handles to expect since - # the last task was submitted, so clear the list. - self._ray_new_actor_handles = [] + with self._ray_actor_lock: + object_ids = worker.submit_task( + function_descriptor, + args, + actor_id=self._ray_actor_id, + actor_handle_id=self._ray_actor_handle_id, + actor_counter=self._ray_actor_counter, + is_actor_checkpoint_method=is_actor_checkpoint_method, + actor_creation_dummy_object_id=( + self._ray_actor_creation_dummy_object_id), + execution_dependencies=[self._ray_actor_cursor], + new_actor_handles=self._ray_new_actor_handles, + # We add one for the dummy return ID. + num_return_vals=num_return_vals + 1, + resources={"CPU": self._ray_actor_method_cpus}, + placement_resources={}, + driver_id=self._ray_actor_driver_id, + ) + # Update the actor counter and cursor to reflect the most recent + # invocation. + self._ray_actor_counter += 1 + # The last object returned is the dummy object that should be + # passed in to the next actor method. Do not return it to the user. + self._ray_actor_cursor = object_ids.pop() if len(object_ids) == 1: object_ids = object_ids[0] diff --git a/python/ray/worker.py b/python/ray/worker.py index f43f12206..1576265eb 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -143,13 +143,6 @@ class Worker(object): cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. profiler: the profiler used to aggregate profiling information. 
- state_lock (Lock): - Used to lock worker's non-thread-safe internal states: - 1) task_index increment: make sure we generate unique task ids; - 2) Object reconstruction: because the node manager will - recycle/return the worker's resources before/after reconstruction, - it's unsafe for multiple threads to call object - reconstruction simultaneously. """ def __init__(self): @@ -169,42 +162,56 @@ class Worker(object): self.original_gpu_ids = ray.utils.get_cuda_visible_devices() self.profiler = None self.memory_monitor = memory_monitor.MemoryMonitor() - self.state_lock = threading.Lock() # A dictionary that maps from driver id to SerializationContext # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} self.function_actor_manager = FunctionActorManager(self) - # Reads/writes to the following fields must be protected by - # self.state_lock. # Identity of the driver that this worker is processing. self.task_driver_id = ray.ObjectID(NIL_ID) - self.current_task_id = ray.ObjectID(NIL_ID) - self.task_index = 0 - self.put_index = 1 + self._task_context = threading.local() - def get_current_thread_task_id(self): - """Get the current thread's task ID. + @property + def task_context(self): + """A thread-local that contains the following attributes. - This returns the assigned task ID if called on the main thread, else a - random task ID. This method is not thread-safe and must be called with - self.state_lock acquired. + current_task_id: For the main thread, this field is the ID of this + worker's current running task; for other threads, this field is a + fake random ID. + task_index: The number of tasks that have been submitted from the + current task. + put_index: The number of objects that have been put from the current + task. """ - current_task_id = self.current_task_id - if not ray.utils.is_main_thread(): - # If this is running on a separate thread, then the mapping - # to the current task ID may not be correct. 
Generate a - # random task ID so that the backend can differentiate - # between different threads. - current_task_id = ray.ObjectID(random_string()) - if not self.multithreading_warned: - logger.warning( - "Calling ray.get or ray.wait in a separate thread " - "may lead to deadlock if the main thread blocks on this " - "thread and there are not enough resources to execute " - "more tasks") - self.multithreading_warned = True - assert not current_task_id.is_nil() - return current_task_id + if not hasattr(self._task_context, 'initialized'): + # Initialize task_context for the current thread. + if ray.utils.is_main_thread(): + # If this is running on the main thread, initialize it to + # NIL. The actual value will set when the worker receives + # a task from raylet backend. + self._task_context.current_task_id = ray.ObjectID(NIL_ID) + else: + # If this is running on a separate thread, then the mapping + # to the current task ID may not be correct. Generate a + # random task ID so that the backend can differentiate + # between different threads. + self._task_context.current_task_id = ray.ObjectID( + random_string()) + if getattr(self, '_multithreading_warned', False) is not True: + logger.warning( + "Calling ray.get or ray.wait in a separate thread " + "may lead to deadlock if the main thread blocks on " + "this thread and there are not enough resources to " + "execute more tasks") + self._multithreading_warned = True + + self._task_context.task_index = 0 + self._task_context.put_index = 1 + self._task_context.initialized = True + return self._task_context + + @property + def current_task_id(self): + return self.task_context.current_task_id def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -467,48 +474,45 @@ class Worker(object): } if len(unready_ids) > 0: - with self.state_lock: - # Get the task ID, to notify the backend which task is blocked. 
- current_task_id = self.get_current_thread_task_id() + # Try reconstructing any objects we haven't gotten yet. Try to + # get them until at least get_timeout_milliseconds + # milliseconds passes, then repeat. + while len(unready_ids) > 0: + object_ids_to_fetch = [ + plasma.ObjectID(unready_id) + for unready_id in unready_ids.keys() + ] + ray_object_ids_to_fetch = [ + ray.ObjectID(unready_id) + for unready_id in unready_ids.keys() + ] + fetch_request_size = ray._config.worker_fetch_request_size() + for i in range(0, len(object_ids_to_fetch), + fetch_request_size): + self.raylet_client.fetch_or_reconstruct( + ray_object_ids_to_fetch[i:(i + fetch_request_size)], + False, + self.current_task_id, + ) + results = self.retrieve_and_deserialize( + object_ids_to_fetch, + max([ + ray._config.get_timeout_milliseconds(), + int(0.01 * len(unready_ids)), + ]), + ) + # Remove any entries for objects we received during this + # iteration so we don't retrieve the same object twice. + for i, val in enumerate(results): + if val is not plasma.ObjectNotAvailable: + object_id = object_ids_to_fetch[i].binary() + index = unready_ids[object_id] + final_results[index] = val + unready_ids.pop(object_id) - # Try reconstructing any objects we haven't gotten yet. Try to - # get them until at least get_timeout_milliseconds - # milliseconds passes, then repeat. 
- while len(unready_ids) > 0: - object_ids_to_fetch = [ - plasma.ObjectID(unready_id) - for unready_id in unready_ids.keys() - ] - ray_object_ids_to_fetch = [ - ray.ObjectID(unready_id) - for unready_id in unready_ids.keys() - ] - fetch_request_size = ( - ray._config.worker_fetch_request_size()) - for i in range(0, len(object_ids_to_fetch), - fetch_request_size): - self.raylet_client.fetch_or_reconstruct( - ray_object_ids_to_fetch[i:( - i + fetch_request_size)], False, - current_task_id) - results = self.retrieve_and_deserialize( - object_ids_to_fetch, - max([ - ray._config.get_timeout_milliseconds(), - int(0.01 * len(unready_ids)) - ])) - # Remove any entries for objects we received during this - # iteration so we don't retrieve the same object twice. - for i, val in enumerate(results): - if val is not plasma.ObjectNotAvailable: - object_id = object_ids_to_fetch[i].binary() - index = unready_ids[object_id] - final_results[index] = val - unready_ids.pop(object_id) - - # If there were objects that we weren't able to get locally, - # let the local scheduler know that we're now unblocked. - self.raylet_client.notify_unblocked(current_task_id) + # If there were objects that we weren't able to get locally, + # let the local scheduler know that we're now unblocked. + self.raylet_client.notify_unblocked(self.current_task_id) assert len(final_results) == len(object_ids) return final_results @@ -616,24 +620,32 @@ class Worker(object): if placement_resources is None: placement_resources = {} - with self.state_lock: - # Increment the worker's task index to track how many tasks - # have been submitted by the current task so far. - task_index = self.task_index - self.task_index += 1 - # The parent task must be set for the submitted task. - if self.actor_id == NIL_ACTOR_ID: - assert not self.current_task_id.is_nil() + # Increment the worker's task index to track how many tasks + # have been submitted by the current task so far. 
+ self.task_context.task_index += 1 + # The parent task must be set for the submitted task. + assert not self.current_task_id.is_nil() # Submit the task to local scheduler. function_descriptor_list = ( function_descriptor.get_function_descriptor_list()) task = ray.raylet.Task( - driver_id, function_descriptor_list, args_for_local_scheduler, - num_return_vals, self.current_task_id, task_index, - actor_creation_id, actor_creation_dummy_object_id, - max_actor_reconstructions, actor_id, actor_handle_id, - actor_counter, new_actor_handles, execution_dependencies, - resources, placement_resources) + driver_id, + function_descriptor_list, + args_for_local_scheduler, + num_return_vals, + self.current_task_id, + self.task_context.task_index, + actor_creation_id, + actor_creation_dummy_object_id, + max_actor_reconstructions, + actor_id, + actor_handle_id, + actor_counter, + new_actor_handles, + execution_dependencies, + resources, + placement_resources, + ) self.raylet_client.submit_task(task) return task.returns() @@ -770,24 +782,23 @@ class Worker(object): (these will be retrieved by calls to get or by subsequent tasks that use the outputs of this task). """ - with self.state_lock: - assert self.current_task_id.is_nil() - assert self.task_index == 0 - assert self.put_index == 1 - if task.actor_id().is_nil(): - # If this worker is not an actor, check that `task_driver_id` - # was reset when the worker finished the previous task. - assert self.task_driver_id.is_nil() - # Set the driver ID of the current running task. This is - # needed so that if the task throws an exception, we propagate - # the error message to the correct driver. - self.task_driver_id = task.driver_id() - else: - # If this worker is an actor, task_driver_id wasn't reset. - # Check that current task's driver ID equals the previous one. 
- assert self.task_driver_id == task.driver_id() + assert self.current_task_id.is_nil() + assert self.task_context.task_index == 0 + assert self.task_context.put_index == 1 + if task.actor_id().is_nil(): + # If this worker is not an actor, check that `task_driver_id` + # was reset when the worker finished the previous task. + assert self.task_driver_id.is_nil() + # Set the driver ID of the current running task. This is + # needed so that if the task throws an exception, we propagate + # the error message to the correct driver. + self.task_driver_id = task.driver_id() + else: + # If this worker is an actor, task_driver_id wasn't reset. + # Check that current task's driver ID equals the previous one. + assert self.task_driver_id == task.driver_id() - self.current_task_id = task.task_id() + self.task_context.current_task_id = task.task_id() function_descriptor = FunctionDescriptor.from_bytes_list( task.function_descriptor_list()) @@ -931,13 +942,14 @@ class Worker(object): with _changeproctitle(title, next_title): self._process_task(task, execution_info) # Reset the state fields so the next task can run. - with self.state_lock: - if self.actor_id == NIL_ACTOR_ID: - # We will keep task_driver_id unchanged for actor. - self.task_driver_id = ray.ObjectID(NIL_ID) - self.current_task_id = ray.ObjectID(NIL_ID) - self.task_index = 0 - self.put_index = 1 + self.task_context.current_task_id = ray.ObjectID(NIL_ID) + self.task_context.task_index = 0 + self.task_context.put_index = 1 + if self.actor_id == NIL_ACTOR_ID: + # Don't need to reset task_driver_id if the worker is an + # actor. Because the following tasks should all have the + # same driver id. + self.task_driver_id = ray.ObjectID(NIL_ID) # Increase the task execution counter. self.function_actor_manager.increase_task_counter( @@ -1925,13 +1937,8 @@ def connect(ray_params, else: # Try to use true randomness. 
np.random.seed(None) - worker.current_task_id = ray.ObjectID( - np.random.bytes(ray_constants.ID_SIZE)) # Reset the state of the numpy random number generator. np.random.set_state(numpy_state) - # Set other fields needed for computing task IDs. - worker.task_index = 0 - worker.put_index = 1 # Create an entry for the driver task in the task table. This task is # added immediately with status RUNNING. This allows us to push errors @@ -1944,11 +1951,22 @@ def connect(ray_params, function_descriptor = FunctionDescriptor.for_driver_task() driver_task = ray.raylet.Task( worker.task_driver_id, - function_descriptor.get_function_descriptor_list(), [], 0, - worker.current_task_id, worker.task_index, - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), 0, - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, [], [], {"CPU": 0}, {}) + function_descriptor.get_function_descriptor_list(), + [], # arguments. + 0, # num_returns. + ray.ObjectID(random_string()), # parent_task_id. + 0, # parent_counter. + ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id. + ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id. + 0, # max_actor_reconstructions. + ray.ObjectID(NIL_ACTOR_ID), # actor_id. + ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id. + nil_actor_counter, # actor_counter. + [], # new_actor_handles. + [], # execution_dependencies. + {"CPU": 0}, # resource_map. + {}, # placement_resource_map. + ) # Add the driver task to the task table. global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", @@ -1959,16 +1977,14 @@ def connect(ray_params, # Set the driver's current task ID to the task ID assigned to the # driver task. - worker.current_task_id = driver_task.task_id() - else: - # A non-driver worker begins without an assigned task. - worker.current_task_id = ray.ObjectID(NIL_ID) - # A flag for making sure that we only print one warning message about - # multithreading per worker. 
- worker.multithreading_warned = False + worker.task_context.current_task_id = driver_task.task_id() worker.raylet_client = ray.raylet.RayletClient( - raylet_socket, worker.worker_id, is_worker, worker.current_task_id) + raylet_socket, + worker.worker_id, + is_worker, + worker.current_task_id, + ) # Start the import thread import_thread.ImportThread(worker, mode).start() @@ -2254,9 +2270,11 @@ def put(value, worker=global_worker): # In LOCAL_MODE, ray.put is the identity operation. return value object_id = worker.raylet_client.compute_put_id( - worker.current_task_id, worker.put_index) + worker.current_task_id, + worker.task_context.put_index, + ) worker.put_object(object_id, value) - worker.put_index += 1 + worker.task_context.put_index += 1 return object_id @@ -2342,15 +2360,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): raise Exception("num_returns cannot be greater than the number " "of objects provided to ray.wait.") - # Get the task ID, to notify the backend which task is blocked. 
- with worker.state_lock: - current_task_id = worker.get_current_thread_task_id() - timeout = timeout if timeout is not None else 10**6 timeout_milliseconds = int(timeout * 1000) ready_ids, remaining_ids = worker.raylet_client.wait( - object_ids, num_returns, timeout_milliseconds, False, - current_task_id) + object_ids, + num_returns, + timeout_milliseconds, + False, + worker.current_task_id, + ) return ready_ids, remaining_ids diff --git a/test/runtest.py b/test/runtest.py index 9ec7c1854..080fef4b1 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -5,6 +5,7 @@ from __future__ import print_function import json import logging import os +import random import re import setproctitle import string @@ -13,6 +14,7 @@ import sys import threading import time from collections import defaultdict, namedtuple, OrderedDict +from concurrent.futures import ThreadPoolExecutor import numpy as np import pytest @@ -1176,59 +1178,137 @@ def test_multithreading(shutdown_only): # relase resources when joining the threads. 
ray.init(num_cpus=2) + def run_test_in_multi_threads(test_case, num_threads=20, num_repeats=50): + """A helper function that runs test cases in multiple threads.""" + + def wrapper(): + for _ in range(num_repeats): + test_case() + time.sleep(random.randint(0, 10) / 1000.0) + return "ok" + + executor = ThreadPoolExecutor(max_workers=num_threads) + futures = [executor.submit(wrapper) for _ in range(num_threads)] + for future in futures: + assert future.result() == "ok" + @ray.remote - def f(): - pass + def echo(value, delay_ms=0): + if delay_ms > 0: + time.sleep(delay_ms / 1000.0) + return value - def g(n): - for _ in range(1000 // n): - ray.get([f.remote() for _ in range(n)]) - res = [ray.put(i) for i in range(1000 // n)] - ray.wait(res, len(res)) + @ray.remote + class Echo(object): + def echo(self, value): + return value - def test_multi_threading(): - threads = [ - threading.Thread(target=g, args=(n, )) - for n in [1, 5, 10, 100, 1000] + def test_api_in_multi_threads(): + """Test using Ray api in multiple threads.""" + + # Test calling remote functions in multiple threads. + def test_remote_call(): + value = random.randint(0, 1000000) + result = ray.get(echo.remote(value)) + assert value == result + + run_test_in_multi_threads(test_remote_call) + + # Test multiple threads calling one actor. + actor = Echo.remote() + + def test_call_actor(): + value = random.randint(0, 1000000) + result = ray.get(actor.echo.remote(value)) + assert value == result + + run_test_in_multi_threads(test_call_actor) + + # Test put and get. + def test_put_and_get(): + value = random.randint(0, 1000000) + result = ray.get(ray.put(value)) + assert value == result + + run_test_in_multi_threads(test_put_and_get) + + # Test multiple threads waiting for objects. 
+ num_wait_objects = 10 + objects = [ + echo.remote(i, delay_ms=10) for i in range(num_wait_objects) ] - [thread.start() for thread in threads] - [thread.join() for thread in threads] + def test_wait(): + ready, _ = ray.wait( + objects, + num_returns=len(objects), + timeout=1000, + ) + assert len(ready) == num_wait_objects + assert ray.get(ready) == list(range(num_wait_objects)) + run_test_in_multi_threads(test_wait, num_repeats=1) + + # Run tests in a driver. + test_api_in_multi_threads() + + # Run tests in a worker. @ray.remote - def test_multi_threading_in_worker(): - test_multi_threading() + def run_tests_in_worker(): + test_api_in_multi_threads() + return "ok" - def block(args, n): - ray.wait(args, num_returns=n) - ray.get(args[:n]) + assert ray.get(run_tests_in_worker.remote()) == "ok" + # Test actor that runs background threads. @ray.remote class MultithreadedActor(object): def __init__(self): - pass + self.lock = threading.Lock() + self.thread_results = [] + + def background_thread(self, wait_objects): + try: + # Test wait + ready, _ = ray.wait( + wait_objects, + num_returns=len(wait_objects), + timeout=1000, + ) + assert len(ready) == len(wait_objects) + for _ in range(50): + num = 20 + # Test remote call + results = [echo.remote(i) for i in range(num)] + assert ray.get(results) == list(range(num)) + # Test put and get + objects = [ray.put(i) for i in range(num)] + assert ray.get(objects) == list(range(num)) + time.sleep(random.randint(0, 10) / 1000.0) + except Exception as e: + with self.lock: + self.thread_results.append(e) + else: + with self.lock: + self.thread_results.append("ok") def spawn(self): - objects = [f.remote() for _ in range(1000)] + wait_objects = [echo.remote(i, delay_ms=10) for i in range(20)] self.threads = [ - threading.Thread(target=block, args=(objects, n)) - for n in [1, 5, 10, 100, 1000] + threading.Thread( + target=self.background_thread, args=(wait_objects, )) + for _ in range(20) ] - [thread.start() for thread in self.threads] 
def join(self): [thread.join() for thread in self.threads] + assert self.thread_results == ["ok"] * len(self.threads) + return "ok" - # test multi-threading in the driver - test_multi_threading() - # test multi-threading in the worker - ray.get(test_multi_threading_in_worker.remote()) - - # test multi-threading in the actor - a = MultithreadedActor.remote() - ray.get(a.spawn.remote()) - ray.get(a.join.remote()) + actor = MultithreadedActor.remote() + actor.spawn.remote() + assert ray.get(actor.join.remote()) == "ok" def test_free_objects_multi_node(ray_start_cluster):