diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index d97f4d2a6..6e403c80a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -95,7 +95,6 @@ public class Worker { currentActorId = returnId; } } finally { - runtime.getWorkerContext().setCurrentTask(null, null); Thread.currentThread().setContextClassLoader(oldLoader); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index b97a08b52..8aa2eedea 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -40,14 +40,15 @@ public class WorkerContext { taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); + currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { workerId = driverId; currentTaskId.set(UniqueId.randomId()); currentDriverId = driverId; - currentClassLoader = null; } else { workerId = UniqueId.randomId(); - setCurrentTask(null, null); + this.currentTaskId.set(UniqueId.NIL); + this.currentDriverId = UniqueId.NIL; } } @@ -68,13 +69,10 @@ public class WorkerContext { Thread.currentThread().getId() == mainThreadId, "This method should only be called from the main thread." ); - if (task != null) { - currentTaskId.set(task.taskId); - currentDriverId = task.driverId; - } else { - currentTaskId.set(UniqueId.NIL); - currentDriverId = UniqueId.NIL; - } + + Preconditions.checkNotNull(task); + this.currentTaskId.set(task.taskId); + this.currentDriverId = task.driverId; taskIndex.set(0); putIndex.set(0); currentClassLoader = classLoader; diff --git a/python/ray/worker.py b/python/ray/worker.py index 1153ed448..ec5233618 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -619,6 +619,9 @@ class Worker(object): self.task_context.task_index += 1 # The parent task must be set for the submitted task. assert not self.current_task_id.is_nil() + # Current driver id must not be nil when submitting a task. + # Because every task must belong to a driver. + assert not self.task_driver_id.is_nil() # Submit the task to local scheduler. function_descriptor_list = ( function_descriptor.get_function_descriptor_list())