diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index 629908dbd..50b7ba421 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -15,6 +15,6 @@ if hasattr(ctypes, "windll"): import ray.experimental import ray.serialization -from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote +from ray.worker import register_class, error_info, init, connect, disconnect, get, put, wait, remote, log_event, log_span, flush_log from ray.worker import Reusable, reusables from ray.worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index a677035e0..e04d7aa78 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import json import hashlib import os import sys @@ -30,6 +31,10 @@ WORKER_MODE = 1 PYTHON_MODE = 2 SILENT_MODE = 3 +LOG_POINT = 0 +LOG_SPAN_START = 1 +LOG_SPAN_END = 2 + def random_string(): return np.random.bytes(20) @@ -476,30 +481,31 @@ class Worker(object): be object IDs or they can be values. If they are values, they must be serializable objecs. """ - check_main_thread() - # Put large or complex arguments that are passed by value in the object - # store first. - args_for_photon = [] - for arg in args: - if isinstance(arg, photon.ObjectID): - args_for_photon.append(arg) - elif photon.check_simple_value(arg): - args_for_photon.append(arg) - else: - args_for_photon.append(put(arg)) + with log_span("ray:submit_task", worker=self): + check_main_thread() + # Put large or complex arguments that are passed by value in the object + # store first. + args_for_photon = [] + for arg in args: + if isinstance(arg, photon.ObjectID): + args_for_photon.append(arg) + elif photon.check_simple_value(arg): + args_for_photon.append(arg) + else: + args_for_photon.append(put(arg)) - # Submit the task to Photon. - task = photon.Task(photon.ObjectID(function_id.id()), - args_for_photon, - self.num_return_vals[function_id.id()], - self.current_task_id, - self.task_index) - # Increment the worker's task index to track how many tasks have been - # submitted by the current task so far. - self.task_index += 1 - self.photon_client.submit(task) + # Submit the task to Photon. + task = photon.Task(photon.ObjectID(function_id.id()), + args_for_photon, + self.num_return_vals[function_id.id()], + self.current_task_id, + self.task_index) + # Increment the worker's task index to track how many tasks have been + # submitted by the current task so far. + self.task_index += 1 + self.photon_client.submit(task) - return task.returns() + return task.returns() def run_function_on_all_workers(self, function): """Run arbitrary code on all of the workers. @@ -1014,11 +1020,14 @@ def import_thread(worker): for i in range(worker.worker_import_counter, num_imports): key = worker.redis_client.lindex("Exports", i) if key.startswith(b"RemoteFunction"): - fetch_and_register_remote_function(key, worker=worker) + with log_span("ray:import_remote_function", worker=worker): + fetch_and_register_remote_function(key, worker=worker) elif key.startswith(b"ReusableVariables"): - fetch_and_register_reusable_variable(key, worker=worker) + with log_span("ray:import_reusable_variable", worker=worker): + fetch_and_register_reusable_variable(key, worker=worker) elif key.startswith(b"FunctionsToRun"): - fetch_and_execute_function_to_run(key, worker=worker) + with log_span("ray:import_function_to_run", worker=worker): + fetch_and_execute_function_to_run(key, worker=worker) else: raise Exception("This code should be unreachable.") worker.redis_client.hincrby(worker_info_key, "export_counter", 1) @@ -1044,6 +1053,10 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker): worker.worker_id = random_string() worker.connected = True worker.set_mode(mode) + # The worker.events field is used to aggregate logging information and display + # it in the web UI. Note that Python lists protected by the GIL, which is + # important because we will append to this field from multiple threads. + worker.events = [] # If running Ray in PYTHON_MODE, there is no need to create call create_worker # or to start the worker service. if mode == PYTHON_MODE: @@ -1061,9 +1074,9 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker): worker.photon_client = photon.PhotonClient(info["local_scheduler_socket_name"]) # Register the worker with Redis. if mode in [SCRIPT_MODE, SILENT_MODE]: - worker.redis_client.rpush("Drivers", worker.worker_id) + worker.redis_client.hmset(b"Drivers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address}) elif mode == WORKER_MODE: - worker.redis_client.rpush("Workers", worker.worker_id) + worker.redis_client.hmset(b"Workers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address}) else: raise Exception("This code should be unreachable.") # If this is a driver, set the current task ID and set the task index to 0. @@ -1175,6 +1188,74 @@ def register_class(cls, pickle=False, worker=global_worker): serialization.add_class_to_whitelist(cls, pickle=pickle) worker.run_function_on_all_workers(register_class_for_serialization) +class RayLogSpan(object): + """An object used to enable logging a span of events with a with statement. + + Attributes: + event_type (str): The type of the event being logged. + contents: Additional information to log. + """ + def __init__(self, event_type, contents=None, worker=global_worker): + """Initialize a RayLogSpan object.""" + self.event_type = event_type + self.contents = contents + self.worker = worker + + def __enter__(self): + """Log the beginning of a span event.""" + log(event_type=self.event_type, + contents=self.contents, + kind=LOG_SPAN_START, + worker=self.worker) + + def __exit__(self, type, value, tb): + """Log the end of a span event. Log any exception that occurred.""" + if type is None: + log(event_type=self.event_type, kind=LOG_SPAN_END, worker=self.worker) + else: + log(event_type=self.event_type, + contents={"type": str(type), + "value": value, + "traceback": traceback.format_exc()}, + kind=LOG_SPAN_END, + worker=self.worker) + +def log_span(event_type, contents=None, worker=global_worker): + return RayLogSpan(event_type, contents=contents, worker=worker) + +def log_event(event_type, contents=None, worker=global_worker): + log(event_type, kind=LOG_POINT, contents=contents, worker=worker) + +def log(event_type, kind, contents=None, worker=global_worker): + """Log an event to the global state store. + + This adds the event to a buffer of events locally. The buffer can be flushed + and written to the global state store by calling flush_log(). + + Args: + event_type (str): The type of the event. + contents: More general data to store with the event. + kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This is + LOG_POINT if the event being logged happens at a single point in time. It + is LOG_SPAN_START if we are starting to log a span of time, and it is + LOG_SPAN_END if we are finishing logging a span of time. + """ + # TODO(rkn): This code currently takes around half a microsecond. Since we + # call it tens of times per task, this adds up. We will need to redo the + # logging code, perhaps in C. + contents = {} if contents is None else contents + assert isinstance(contents, dict) + # Make sure all of the keys and values in the dictionary are strings. + contents = {str(k): str(v) for k, v in contents.items()} + worker.events.append((time.time(), event_type, kind, contents)) + +def flush_log(worker=global_worker): + """Send the logged worker events to the global state store.""" + event_log_key = b"event_log:" + worker.worker_id + b":" + worker.current_task_id.id() + event_log_value = json.dumps(worker.events) + worker.photon_client.log_event(event_log_key, event_log_value) + worker.events = [] + def get(objectid, worker=global_worker): """Get a remote object or a list of remote objects from the object store. @@ -1190,22 +1271,26 @@ def get(objectid, worker=global_worker): Returns: A Python object or a list of Python objects. """ - check_main_thread() - check_connected(worker) - if worker.mode == PYTHON_MODE: - return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) - if isinstance(objectid, list): - values = [worker.get_object(x) for x in objectid] - for i, value in enumerate(values): + with log_span("ray:get", worker=worker): + check_main_thread() + check_connected(worker) + + if worker.mode == PYTHON_MODE: + # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) + return objectid + if isinstance(objectid, list): + values = [worker.get_object(x) for x in objectid] + for i, value in enumerate(values): + if isinstance(value, RayTaskError): + raise RayGetError(objectid[i], value) + return values + else: + value = worker.get_object(objectid) if isinstance(value, RayTaskError): - raise RayGetError(objectid[i], value) - return values - value = worker.get_object(objectid) - if isinstance(value, RayTaskError): - # If the result is a RayTaskError, then the task that created this object - # failed, and we should propagate the error message here. - raise RayGetError(objectid, value) - return value + # If the result is a RayTaskError, then the task that created this object + # failed, and we should propagate the error message here. + raise RayGetError(objectid, value) + return value def put(value, worker=global_worker): """Store an object in the object store. @@ -1216,14 +1301,17 @@ def put(value, worker=global_worker): Returns: The object ID assigned to this value. """ - check_main_thread() - check_connected(worker) - if worker.mode == PYTHON_MODE: - return value # In PYTHON_MODE, ray.put is the identity operation - object_id = photon.compute_put_id(worker.current_task_id, worker.put_index) - worker.put_object(object_id, value) - worker.put_index += 1 - return object_id + with log_span("ray:put", worker=worker): + check_main_thread() + check_connected(worker) + + if worker.mode == PYTHON_MODE: + # In PYTHON_MODE, ray.put is the identity operation + return value + object_id = photon.compute_put_id(worker.current_task_id, worker.put_index) + worker.put_object(object_id, value) + worker.put_index += 1 + return object_id def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): """Return a list of IDs that are ready and a list of IDs that are not ready. @@ -1247,14 +1335,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): Returns: A list of object IDs that are ready and a list of the remaining object IDs. """ - check_main_thread() - check_connected(worker) - object_id_strs = [object_id.id() for object_id in object_ids] - timeout = timeout if timeout is not None else 2 ** 30 - ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, timeout, num_returns) - ready_ids = [photon.ObjectID(object_id) for object_id in ready_ids] - remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids] - return ready_ids, remaining_ids + with log_span("ray:wait", worker=worker): + check_main_thread() + check_connected(worker) + object_id_strs = [object_id.id() for object_id in object_ids] + timeout = timeout if timeout is not None else 2 ** 30 + ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, timeout, num_returns) + ready_ids = [photon.ObjectID(object_id) for object_id in ready_ids] + remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids] + return ready_ids, remaining_ids def wait_for_valid_import_counter(function_id, timeout=5, worker=global_worker): """Wait until this worker has imported enough to execute the function. @@ -1335,14 +1424,20 @@ def main_loop(worker=global_worker): args = task.arguments() return_object_ids = task.returns() function_name = worker.function_names[function_id.id()] + # Get task arguments from the object store. - arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker) + with log_span("ray:task:get_arguments", worker=worker): + arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker) + # Execute the task. - outputs = worker.functions[function_id.id()].executor(arguments) + with log_span("ray:task:execute", worker=worker): + outputs = worker.functions[function_id.id()].executor(arguments) + # Store the outputs in the local object store. - if len(return_object_ids) == 1: - outputs = (outputs,) - store_outputs_in_objstore(return_object_ids, outputs, worker) + with log_span("ray:task:store_outputs", worker=worker): + if len(return_object_ids) == 1: + outputs = (outputs,) + store_outputs_in_objstore(return_object_ids, outputs, worker) except Exception as e: # We determine whether the exception was caused by the call to # get_arguments_for_execution or by the execution of the remote function @@ -1370,7 +1465,8 @@ def main_loop(worker=global_worker): try: # Reinitialize the values of reusable variables that were used in the task # above so that changes made to their state do not affect other tasks. - reusables._reinitialize() + with log_span("ray:task:reinitialize_reusables", worker=worker): + reusables._reinitialize() except Exception as e: # The attempt to reinitialize the reusable variables threw an exception. # We record the traceback and notify the scheduler. @@ -1384,19 +1480,32 @@ def main_loop(worker=global_worker): check_main_thread() while True: - task = worker.photon_client.get_task() + with log_span("ray:get_task", worker=worker): + task = worker.photon_client.get_task() + function_id = task.function_id() # Check that the number of imports we have is at least as great as the # export counter for the task. If not, wait until we have imported enough. # We will push warnings to the user if we spend too long in this loop. - wait_for_valid_import_counter(function_id, worker=worker) + with log_span("ray:wait_for_import_counter", worker=worker): + wait_for_valid_import_counter(function_id, worker=worker) + # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a # warning to the user if we are waiting too long to acquire the lock because # that may indicate that the system is hanging, and it'd be good to know # where the system is hanging. + log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker) with worker.lock: - process_task(task) + log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker) + + contents = {"function_name": worker.function_names[function_id.id()], + "task_id": task.task_id().hex()} + with log_span("ray:task", contents=contents, worker=worker): + process_task(task) + + # Push all of the log events to the global state store. + flush_log() def push_warning_to_user(message, worker=global_worker): error_key = "GenericWarning:{}".format(random_string()) diff --git a/src/common/Makefile b/src/common/Makefile index 0d90afbcd..9087a6faf 100644 --- a/src/common/Makefile +++ b/src/common/Makefile @@ -4,7 +4,7 @@ BUILD = build all: hiredis redis redismodule $(BUILD)/libcommon.a -$(BUILD)/libcommon.a: event_loop.o common.o task.o io.o net.o state/redis.o state/table.o state/object_table.o state/task_table.o state/db_client_table.o state/local_scheduler_table.o thirdparty/ae/ae.o thirdparty/sha256.o +$(BUILD)/libcommon.a: event_loop.o common.o task.o io.o net.o logging.o state/redis.o state/table.o state/object_table.o state/task_table.o state/db_client_table.o state/local_scheduler_table.o thirdparty/ae/ae.o thirdparty/sha256.o ar rcs $@ $^ $(BUILD)/common_tests: test/common_tests.c $(BUILD)/libcommon.a diff --git a/src/common/lib/python/common_extension.c b/src/common/lib/python/common_extension.c index 0a0ddbc0c..7ea1941c8 100644 --- a/src/common/lib/python/common_extension.c +++ b/src/common/lib/python/common_extension.c @@ -74,6 +74,14 @@ static PyObject *PyObjectID_id(PyObject *self) { sizeof(s->object_id.id)); } +static PyObject *PyObjectID_hex(PyObject *self) { + PyObjectID *s = (PyObjectID *) self; + char hex_id[ID_STRING_SIZE]; + object_id_to_string(s->object_id, hex_id, ID_STRING_SIZE); + PyObject *result = PyUnicode_FromString(hex_id); + return result; +} + static PyObject *PyObjectID_richcompare(PyObjectID *self, PyObject *other, int op) { @@ -140,6 +148,8 @@ static PyObject *PyObjectID___reduce__(PyObjectID *self) { static PyMethodDef PyObjectID_methods[] = { {"id", (PyCFunction) PyObjectID_id, METH_NOARGS, "Return the hash associated with this ObjectID"}, + {"hex", (PyCFunction) PyObjectID_hex, METH_NOARGS, + "Return the object ID as a string in hex."}, {"__reduce__", (PyCFunction) PyObjectID___reduce__, METH_NOARGS, "Say how to pickle this ObjectID. This raises an exception to prevent" "object IDs from being serialized."}, diff --git a/src/common/logging.c b/src/common/logging.c index 45a9c8bb2..268517da7 100644 --- a/src/common/logging.c +++ b/src/common/logging.c @@ -82,3 +82,15 @@ void ray_log(ray_logger *logger, utstring_free(formatted_message); utstring_free(timestamp); } + +void ray_log_event(db_handle *db, + uint8_t *key, + int64_t key_length, + uint8_t *value, + int64_t value_length) { + int status = redisAsyncCommand(db->context, NULL, NULL, "RPUSH %b %b", key, + key_length, value, value_length); + if ((status == REDIS_ERR) || db->context->err) { + LOG_REDIS_DEBUG(db->context, "error while logging message to event log"); + } +} diff --git a/src/common/logging.h b/src/common/logging.h index 4ef7c8fcb..326e62b60 100644 --- a/src/common/logging.h +++ b/src/common/logging.h @@ -13,6 +13,8 @@ #define RAY_OBJECT "OBJECT" #define RAY_TASK "TASK" +#include "state/db.h" + typedef struct ray_logger_impl ray_logger; /* Initialize a Ray logger for the given client type and logging level. If the @@ -36,4 +38,20 @@ void ray_log(ray_logger *logger, const char *event_type, const char *message); -#endif +/** + * Log an event to the event log. + * + * @param db The database handle. + * @param key The key in Redis to store the event in. + * @param key_length The length of the key. + * @param value The value to log. + * @param value_length The length of the value. + * @return Void. + */ +void ray_log_event(db_handle *db, + uint8_t *key, + int64_t key_length, + uint8_t *value, + int64_t value_length); + +#endif /* LOGGING_H */ diff --git a/src/photon/photon.h b/src/photon/photon.h index 42b2b1b3e..9208a1cd8 100644 --- a/src/photon/photon.h +++ b/src/photon/photon.h @@ -25,6 +25,8 @@ enum photon_message_type { EXECUTE_TASK, /** Reconstruct a possibly lost object. */ RECONSTRUCT_OBJECT, + /** Log a message to the event table. */ + EVENT_LOG_MESSAGE, }; // clang-format off diff --git a/src/photon/photon_client.c b/src/photon/photon_client.c index 18b030d76..5452a299a 100644 --- a/src/photon/photon_client.c +++ b/src/photon/photon_client.c @@ -15,6 +15,28 @@ void photon_disconnect(photon_conn *conn) { free(conn); } +void photon_log_event(photon_conn *conn, + uint8_t *key, + int64_t key_length, + uint8_t *value, + int64_t value_length) { + int64_t message_length = + sizeof(key_length) + sizeof(value_length) + key_length + value_length; + uint8_t *message = malloc(message_length); + int64_t offset = 0; + memcpy(&message[offset], &key_length, sizeof(key_length)); + offset += sizeof(key_length); + memcpy(&message[offset], &value_length, sizeof(value_length)); + offset += sizeof(value_length); + memcpy(&message[offset], key, key_length); + offset += key_length; + memcpy(&message[offset], value, value_length); + offset += value_length; + CHECK(offset == message_length); + write_message(conn->conn, EVENT_LOG_MESSAGE, message_length, message); + free(message); +} + void photon_submit(photon_conn *conn, task_spec *task) { write_message(conn->conn, SUBMIT_TASK, task_spec_size(task), (uint8_t *) task); diff --git a/src/photon/photon_client.h b/src/photon/photon_client.h index af8433c06..e4e58397a 100644 --- a/src/photon/photon_client.h +++ b/src/photon/photon_client.h @@ -35,6 +35,25 @@ void photon_disconnect(photon_conn *conn); */ void photon_submit(photon_conn *conn, task_spec *task); +/** + * Log an event to the event log. This will call RPUSH key value. We use RPUSH + * instead of SET so that it is possible to flush the log multiple times with + * the same key (for example the key might be shared across logging calls in the + * same task on a worker). + * + * @param conn The connection information. + * @param key The key to store the event in. + * @param key_length The length of the key. + * @param value The value to store. + * @param value_length The length of the value. + * @return Void. + */ +void photon_log_event(photon_conn *conn, + uint8_t *key, + int64_t key_length, + uint8_t *value, + int64_t value_length); + /** * Get next task for this client. This will block until the scheduler assigns * a task to this worker. This allocates and returns a task, and so the task diff --git a/src/photon/photon_extension.c b/src/photon/photon_extension.c index 23b8b3a49..8ffe3ac6b 100644 --- a/src/photon/photon_extension.c +++ b/src/photon/photon_extension.c @@ -62,6 +62,21 @@ static PyObject *PyPhotonClient_reconstruct_object(PyObject *self, Py_RETURN_NONE; } +static PyObject *PyPhotonClient_log_event(PyObject *self, PyObject *args) { + const char *key; + int key_length; + const char *value; + int value_length; + if (!PyArg_ParseTuple(args, "s#s#", &key, &key_length, &value, + &value_length)) { + return NULL; + } + photon_log_event(((PyPhotonClient *) self)->photon_connection, + (uint8_t *) key, key_length, (uint8_t *) value, + value_length); + Py_RETURN_NONE; +} + static PyMethodDef PyPhotonClient_methods[] = { {"submit", (PyCFunction) PyPhotonClient_submit, METH_VARARGS, "Submit a task to the local scheduler."}, @@ -69,6 +84,8 @@ static PyMethodDef PyPhotonClient_methods[] = { "Get a task from the local scheduler."}, {"reconstruct_object", (PyCFunction) PyPhotonClient_reconstruct_object, METH_VARARGS, "Ask the local scheduler to reconstruct an object."}, + {"log_event", (PyCFunction) PyPhotonClient_log_event, METH_VARARGS, + "Log an event to the event log through the local scheduler."}, {NULL} /* Sentinel */ }; diff --git a/src/photon/photon_scheduler.c b/src/photon/photon_scheduler.c index da0aa71c0..67b9ba1c1 100644 --- a/src/photon/photon_scheduler.c +++ b/src/photon/photon_scheduler.c @@ -9,6 +9,7 @@ #include "common.h" #include "event_loop.h" #include "io.h" +#include "logging.h" #include "object_info.h" #include "photon.h" #include "photon_scheduler.h" @@ -221,7 +222,7 @@ void process_message(event_loop *loop, local_scheduler_state *state = context; int64_t type; - read_buffer(client_sock, &type, state->input_buffer); + int64_t length = read_buffer(client_sock, &type, state->input_buffer); LOG_DEBUG("New event of type %" PRId64, type); @@ -232,6 +233,30 @@ void process_message(event_loop *loop, } break; case TASK_DONE: { } break; + case EVENT_LOG_MESSAGE: { + /* Parse the message. TODO(rkn): Redo this using flatbuffers to serialize + * the message. */ + uint8_t *message = (uint8_t *) utarray_front(state->input_buffer); + int64_t offset = 0; + int64_t key_length; + memcpy(&key_length, &message[offset], sizeof(key_length)); + offset += sizeof(key_length); + int64_t value_length; + memcpy(&value_length, &message[offset], sizeof(value_length)); + offset += sizeof(value_length); + uint8_t *key = malloc(key_length); + memcpy(key, &message[offset], key_length); + offset += key_length; + uint8_t *value = malloc(value_length); + memcpy(value, &message[offset], value_length); + offset += value_length; + CHECK(offset == length); + if (state->db != NULL) { + ray_log_event(state->db, key, key_length, value, value_length); + } + free(key); + free(value); + } break; case GET_TASK: { worker_index *wi; HASH_FIND_INT(state->worker_index, &client_sock, wi); diff --git a/test/runtest.py b/test/runtest.py index bf69c02c5..fd87ccc63 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -497,6 +497,56 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testLoggingAPI(self): + ray.init(start_ray_local=True, num_workers=1) + + def events(): + # This is a hack for getting the event log. It is not part of the API. + keys = ray.worker.global_worker.redis_client.keys("event_log:*") + return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) for key in keys] + + def wait_for_num_events(num_events, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(events()) >= num_events: + return + time.sleep(0.1) + print("Timing out of wait.") + + @ray.remote + def test_log_event(): + ray.log_event("event_type1", contents={"key": "val"}) + + @ray.remote + def test_log_span(): + with ray.log_span("event_type2", contents={"key": "val"}): + pass + + # Make sure that we can call ray.log_event in a remote function. + ray.get(test_log_event.remote()) + # Wait for the event to appear in the event log. + wait_for_num_events(1) + self.assertEqual(len(events()), 1) + + # Make sure that we can call ray.log_span in a remote function. + ray.get(test_log_span.remote()) + # Wait for the events to appear in the event log. + wait_for_num_events(2) + self.assertEqual(len(events()), 2) + + @ray.remote + def test_log_span_exception(): + with ray.log_span("event_type2", contents={"key": "val"}): + raise Exception("This failed.") + + # Make sure that logging a span works if an exception is thrown. + test_log_span_exception.remote() + # Wait for the events to appear in the event log. + wait_for_num_events(3) + self.assertEqual(len(events()), 3) + + ray.worker.cleanup() + class PythonModeTest(unittest.TestCase): def testPythonMode(self):