diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 390b1dea2..d76d07faf 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -2,9 +2,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import heapq import json import pickle import redis +import sys +import time import ray from ray.utils import (decode, binary_to_object_id, binary_to_hex, @@ -340,19 +343,44 @@ class GlobalState(object): return ip_filename_file - def task_profiles(self): + def task_profiles(self, start=None, end=None, num=None): """Fetch and return a list of task profiles. + Args: + start: The start point of the time window that is queried for tasks. + end: The end point in time of the time window that is queried for tasks. + num: A limit on the number of tasks that task_profiles will return. + Returns: A tuple of two elements. The first element is a dictionary mapping the task ID of a task to a list of the profiling information for all of the executions of that task. The second element is a list of profiling information for tasks where the events have no task ID. """ + if start is None: + start = 0 + if num is None: + num = sys.maxsize + task_info = dict() - event_names = self.redis_client.keys("event_log*") - for i in range(len(event_names)): - event_list = self.redis_client.lrange(event_names[i], 0, -1) + event_log_sets = self.redis_client.keys("event_log*") + + # The heap is used to maintain the set of x tasks that occurred the most + # recently across all of the workers, where x is defined as the function + # parameter num. The key is the start time of the "get_task" component of + # each task. Calling heappop will result in the task with the earliest + # "get_task_start" to be removed from the heap. + + heap = [] + heapq.heapify(heap) + heap_size = 0 + # Parse through event logs to determine task start and end points. 
+ for i in range(len(event_log_sets)): + event_list = self.redis_client.zrangebyscore(event_log_sets[i], + min=start, + max=end, + start=start, + num=num) for event in event_list: event_dict = json.loads(event) task_id = "" @@ -363,6 +391,10 @@ class GlobalState(object): for event in event_dict: if event[1] == "ray:get_task" and event[2] == 1: task_info[task_id]["get_task_start"] = event[0] + # Add task to min heap by its start point. + heapq.heappush(heap, + (task_info[task_id]["get_task_start"], task_id)) + heap_size += 1 if event[1] == "ray:get_task" and event[2] == 2: task_info[task_id]["get_task_end"] = event[0] if event[1] == "ray:import_remote_function" and event[2] == 1: @@ -389,9 +421,13 @@ class GlobalState(object): task_info[task_id]["worker_id"] = event[3]["worker_id"] if "function_name" in event[3]: task_info[task_id]["function_name"] = event[3]["function_name"] + if heap_size > num: + min_task, task_id_hex = heapq.heappop(heap) + del task_info[task_id_hex] + heap_size -= 1 return task_info - def dump_catapult_trace(self, path): + def dump_catapult_trace(self, path, start=None, end=None, num=None): """Dump task profiling information to a file. This information can be viewed as a timeline of profiling information by @@ -401,9 +437,10 @@ class GlobalState(object): Args: path: The filepath to dump the profiling information to. 
""" - task_info = self.task_profiles() + if end is None: + end = time.time() + task_info = self.task_profiles(start=start, end=end, num=num) workers = self.workers() - tasks = self.task_table() start_time = None for info in task_info.values(): task_start = min(self._get_times(info)) @@ -414,12 +451,12 @@ class GlobalState(object): return int(1e6 * (ts - start_time)) full_trace = [] - for task_id, info in task_info.items(): - parent_info = task_info.get(tasks[task_id]["TaskSpec"]["ParentTaskID"]) + task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) + task_data = self._task_table(task_id_hex) + parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"]) times = self._get_times(info) worker = workers[info["worker_id"]] - if parent_info: parent_worker = workers[parent_info["worker_id"]] parent_times = self._get_times(parent_info) @@ -473,7 +510,6 @@ class GlobalState(object): with open(path, "w") as outfile: json.dump(full_trace, outfile) - task_info def _get_times(self, data): """Extract the numerical times from a task profile. 
diff --git a/python/ray/worker.py b/python/ray/worker.py index 761829131..49e84dec8 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1559,10 +1559,11 @@ def log(event_type, kind, contents=None, worker=global_worker): def flush_log(worker=global_worker): """Send the logged worker events to the global state store.""" - event_log_key = (b"event_log:" + worker.worker_id + b":" + - worker.current_task_id.id()) + event_log_key = b"event_log:" + worker.worker_id event_log_value = json.dumps(worker.events) - worker.local_scheduler_client.log_event(event_log_key, event_log_value) + worker.local_scheduler_client.log_event(event_log_key, + event_log_value, + time.time()) worker.events = [] diff --git a/src/common/logging.cc b/src/common/logging.cc index 87ba0f20d..a4f16293c 100644 --- a/src/common/logging.cc +++ b/src/common/logging.cc @@ -9,6 +9,8 @@ #include "state/redis.h" #include "io.h" +#include +#include static const char *log_levels[5] = {"DEBUG", "INFO", "WARN", "ERROR", "FATAL"}; static const char *log_fmt = @@ -90,9 +92,12 @@ void RayLogger_log_event(DBHandle *db, uint8_t *key, int64_t key_length, uint8_t *value, - int64_t value_length) { - int status = redisAsyncCommand(db->context, NULL, NULL, "RPUSH %b %b", key, - key_length, value, value_length); + int64_t value_length, + double timestamp) { + std::string timestamp_string = std::to_string(timestamp); + int status = redisAsyncCommand(db->context, NULL, NULL, "ZADD %b %s %b", key, + key_length, timestamp_string.c_str(), value, + value_length); if ((status == REDIS_ERR) || db->context->err) { LOG_REDIS_DEBUG(db->context, "error while logging message to event log"); } diff --git a/src/common/logging.h b/src/common/logging.h index ae36e3ae3..b122621e6 100644 --- a/src/common/logging.h +++ b/src/common/logging.h @@ -52,6 +52,7 @@ void RayLogger_log_event(DBHandle *db, uint8_t *key, int64_t key_length, uint8_t *value, - int64_t value_length); + int64_t value_length, + double time); #endif /* LOGGING_H 
*/ diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs index c0cba6ceb..656235963 100644 --- a/src/local_scheduler/format/local_scheduler.fbs +++ b/src/local_scheduler/format/local_scheduler.fbs @@ -47,6 +47,7 @@ table GetTaskReply { table EventLogMessage { key: string; value: string; + timestamp: double; } // This struct is used to register a new worker with the local scheduler. diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 617ac3793..cc1187140 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -913,9 +913,10 @@ void process_message(event_loop *loop, /* Parse the message. */ auto message = flatbuffers::GetRoot<EventLogMessage>(input); if (state->db != NULL) { - RayLogger_log_event( - state->db, (uint8_t *) message->key()->data(), message->key()->size(), - (uint8_t *) message->value()->data(), message->value()->size()); + RayLogger_log_event(state->db, (uint8_t *) message->key()->data(), + message->key()->size(), + (uint8_t *) message->value()->data(), + message->value()->size(), message->timestamp()); } } break; case MessageType_RegisterClientRequest: { diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index 31fb2eed9..a83a3d6d1 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -73,11 +73,13 @@ void local_scheduler_log_event(LocalSchedulerConnection *conn, uint8_t *key, int64_t key_length, uint8_t *value, - int64_t value_length) { + int64_t value_length, + double timestamp) { flatbuffers::FlatBufferBuilder fbb; auto key_string = fbb.CreateString((char *) key, key_length); auto value_string = fbb.CreateString((char *) value, value_length); - auto message = CreateEventLogMessage(fbb, key_string, value_string); + auto message = + CreateEventLogMessage(fbb, key_string, value_string, timestamp); 
fbb.Finish(message); write_message(conn->conn, MessageType_EventLogMessage, fbb.GetSize(), fbb.GetBufferPointer()); diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index 9c218c469..5ac6f6af3 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -76,13 +76,15 @@ void local_scheduler_disconnect_client(LocalSchedulerConnection *conn); * @param key_length The length of the key. * @param value The value to store. * @param value_length The length of the value. + * @param timestamp The time that the event is logged. * @return Void. */ void local_scheduler_log_event(LocalSchedulerConnection *conn, uint8_t *key, int64_t key_length, uint8_t *value, - int64_t value_length); + int64_t value_length, + double timestamp); /** * Get next task for this client. This will block until the scheduler assigns diff --git a/src/local_scheduler/local_scheduler_extension.cc b/src/local_scheduler/local_scheduler_extension.cc index 4230d0e17..67bb1dc5e 100644 --- a/src/local_scheduler/local_scheduler_extension.cc +++ b/src/local_scheduler/local_scheduler_extension.cc @@ -89,13 +89,14 @@ static PyObject *PyLocalSchedulerClient_log_event(PyObject *self, int key_length; const char *value; int value_length; - if (!PyArg_ParseTuple(args, "s#s#", &key, &key_length, &value, - &value_length)) { + double timestamp; + if (!PyArg_ParseTuple(args, "s#s#d", &key, &key_length, &value, &value_length, + &timestamp)) { return NULL; } local_scheduler_log_event( ((PyLocalSchedulerClient *) self)->local_scheduler_connection, - (uint8_t *) key, key_length, (uint8_t *) value, value_length); + (uint8_t *) key, key_length, (uint8_t *) value, value_length, timestamp); Py_RETURN_NONE; } diff --git a/test/runtest.py b/test/runtest.py index a9acede0a..d496cc779 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -733,8 +733,10 @@ class APITest(unittest.TestCase): def events(): # This is a hack for getting the 
event log. It is not part of the API. keys = ray.worker.global_worker.redis_client.keys("event_log:*") - return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) - for key in keys] + res = [] + for key in keys: + res.extend(ray.worker.global_worker.redis_client.zrange(key, 0, -1)) + return res def wait_for_num_events(num_events, timeout=10): start_time = time.time() @@ -761,6 +763,7 @@ class APITest(unittest.TestCase): # Make sure that we can call ray.log_span in a remote function. ray.get(test_log_span.remote()) + # Wait for the events to appear in the event log. wait_for_num_events(2) self.assertEqual(len(events()), 2) @@ -1581,11 +1584,15 @@ class GlobalStateAPI(unittest.TestCase): # Make sure the event log has the correct number of events. start_time = time.time() while time.time() - start_time < 10: - profiles = ray.global_state.task_profiles() - if len(profiles) == num_calls: + profiles = ray.global_state.task_profiles(start=0, end=time.time()) + limited_profiles = ray.global_state.task_profiles(start=0, + end=time.time(), + num=1) + if len(profiles) == num_calls and len(limited_profiles) == 1: break time.sleep(0.1) self.assertEqual(len(profiles), num_calls) + self.assertEqual(len(limited_profiles), 1) # Make sure that each entry is properly formatted. for task_id, data in profiles.items():