Push an error to the driver when the workload hangs on ray.put reconstruction (#382)

* Fix worker blocked bug

* tmp

* Push an error to the driver on ray.put for non-driver tasks

* Fix result table tests

* Fix test, logging

* Address comments

* Fix suppression bug

* Fix redis module test

* Edit error message

* Get values in chunks during reconstruction

* Test case for driver ray.put errors

* Error for evicting ray.put objects from the driver

* Fix tests

* Reduce verbosity

* Documentation
This commit is contained in:
Stephanie Wang
2017-03-21 00:16:48 -07:00
committed by Robert Nishihara
parent 4618fd45b1
commit 083e7a28ad
21 changed files with 528 additions and 132 deletions
+30 -20
View File
@@ -14,6 +14,7 @@ import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
@@ -197,6 +198,11 @@ class TestGlobalStateStore(unittest.TestCase):
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, 0)
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
# Try looking up something in the result table before anything is added.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.assertIsNone(response)
@@ -206,17 +212,17 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertIsNone(response)
# Add the result to the result table. The lookup now returns the task ID.
task_id = b"task_id1"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id)
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, 0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.assertEqual(response, task_id)
check_result_table_entry(response, task_id, False)
# Doing it again should still work.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1")
self.assertEqual(response, task_id)
check_result_table_entry(response, task_id, False)
# Try another result table lookup. This should succeed.
task_id = b"task_id2"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id)
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, 1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2")
self.assertEqual(response, task_id)
check_result_table_entry(response, task_id, True)
def testInvalidTaskTableAdd(self):
# Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with
@@ -241,12 +247,13 @@ class TestGlobalStateStore(unittest.TestCase):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
def check_task_reply(message, task_args):
def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
@@ -266,7 +273,7 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:])
check_task_reply(response, task_args[1:], updated=True)
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(get_response, task_args[1:])
@@ -277,43 +284,46 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:])
check_task_reply(response, task_args[1:], updated=True)
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(get_response, task_args[1:])
# If the current value is no longer the same as the test value, the
# response is nil.
task_args[1] = TASK_STATUS_WAITING
# response is the same task as before the test-and-set.
new_task_args = task_args[:]
new_task_args[1] = TASK_STATUS_WAITING
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, None)
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response2, get_response)
self.assertNotEqual(get_response2, task_args[1:])
# If the test value is a bitmask that matches the current value, the update
# happens.
task_args = new_task_args
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:])
check_task_reply(response, task_args[1:], updated=True)
# If the test value is a bitmask that does not match the current value, the
# update does not happen.
task_args[1] = TASK_STATUS_SCHEDULED
# update does not happen, and the response is the same task as before the
# test-and-set.
new_task_args = task_args[:]
new_task_args[0] = TASK_STATUS_SCHEDULED
old_response = response
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, None)
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response, old_response)
self.assertNotEqual(get_response, task_args[1:])
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])
def testTaskTableSubscribe(self):
scheduling_state = 1
+49 -5
View File
@@ -43,7 +43,10 @@ DRIVER_ID_LENGTH = 20
ERROR_ID_LENGTH = 20
# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ACTOR_ID = 20 * b"\xff"
NIL_ID = 20 * b"\xff"
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_FUNCTION_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID
# When performing ray.get, wait 1 second before attemping to reconstruct and
# fetch the object again.
@@ -52,6 +55,10 @@ GET_TIMEOUT_MILLISECONDS = 1000
# This must be kept in sync with the `error_types` array in
# common/state/error_table.h.
OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch"
PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction"
# This must be kept in sync with the `scheduling_state` enum in common/task.h.
TASK_STATUS_RUNNING = 8
def random_string():
return np.random.bytes(20)
@@ -696,9 +703,14 @@ def error_info(worker=global_worker):
error_contents = worker.redis_client.hgetall(error_key)
# If the error is an object hash mismatch, look up the function name for
# the nondeterministic task.
if error_contents[b"type"] == OBJECT_HASH_MISMATCH_ERROR_TYPE:
error_type = error_contents[b"type"]
if (error_type == OBJECT_HASH_MISMATCH_ERROR_TYPE or error_type ==
PUT_RECONSTRUCTION_ERROR_TYPE):
function_id = error_contents[b"data"]
function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name")
if function_id == NIL_FUNCTION_ID:
function_name = b"Driver"
else:
function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name")
error_contents[b"data"] = function_name
errors.append(error_contents)
@@ -1238,6 +1250,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a
redis_ip_address, redis_port = info["redis_address"].split(":")
worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
worker.lock = threading.Lock()
# Register the worker with Redis.
if mode in [SCRIPT_MODE, SILENT_MODE]:
# The concept of a driver is the same as the concept of a "job". Register
@@ -1266,7 +1279,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a
# Create an object store client.
worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"])
# Create the local scheduler client.
worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(info["local_scheduler_socket_name"], worker.actor_id, is_worker)
worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
info["local_scheduler_socket_name"],
worker.actor_id,
is_worker)
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
@@ -1292,12 +1308,39 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a
# Set other fields needed for computing task IDs.
worker.task_index = 0
worker.put_index = 0
# Create an entry for the driver task in the task table. This task is added
# immediately with status RUNNING. This allows us to push errors related to
# this driver task back to the driver. For example, if the driver creates
# an object that is later evicted, we should notify the user that we're
# unable to reconstruct the object, since we cannot rerun the driver.
driver_task = ray.local_scheduler.Task(
worker.task_driver_id,
ray.local_scheduler.ObjectID(NIL_FUNCTION_ID),
[],
0,
worker.current_task_id,
worker.task_index,
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
worker.actor_counters[actor_id],
[0, 0])
worker.redis_client.execute_command(
"RAY.TASK_TABLE_ADD",
driver_task.task_id().id(),
TASK_STATUS_RUNNING,
NIL_LOCAL_SCHEDULER_ID,
ray.local_scheduler.task_to_string(driver_task))
# Set the driver's current task ID to the task ID assigned to the driver
# task.
worker.current_task_id = driver_task.task_id()
# If this is a worker, then start a thread to import exports from the driver.
if mode == WORKER_MODE:
t = threading.Thread(target=import_thread, args=(worker,))
# Making the thread a daemon causes it to exit when the main thread exits.
t.daemon = True
t.start()
# If this is a driver running in SCRIPT_MODE, start a thread to print error
# messages asynchronously in the background. Ideally the scheduler would push
# messages to the driver's worker service, but we ran into bugs when trying to
@@ -1503,7 +1546,8 @@ def put(value, worker=global_worker):
if worker.mode == PYTHON_MODE:
# In PYTHON_MODE, ray.put is the identity operation
return value
object_id = ray.local_scheduler.compute_put_id(worker.current_task_id, worker.put_index)
object_id = worker.local_scheduler_client.compute_put_id(
worker.current_task_id, worker.put_index)
worker.put_object(object_id, value)
worker.put_index += 1
return object_id