diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 0705787f6..b9fe9a563 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -394,16 +394,16 @@ class GlobalState(object): return ip_filename_file - def task_profiles(self, start=None, end=None, num_tasks=None, fwd=True): + def task_profiles(self, num_tasks, start=None, end=None, fwd=True): """Fetch and return a list of task profiles. Args: + num_tasks: A limit on the number of tasks that task_profiles will + return. start: The start point of the time window that is queried for tasks. end: The end point in time of the time window that is queried for tasks. - num_tasks: A limit on the number of tasks that task_profiles will - return. fwd: If True, means that zrange will be used. If False, zrevrange. This argument is only meaningful in conjunction with the num_tasks argument. This controls whether the tasks returned @@ -424,12 +424,9 @@ class GlobalState(object): # function parameter num. The key is the start time of the "get_task" # component of each task. Calling heappop will result in the taks with # the earliest "get_task_start" to be removed from the heap. - - # Don't maintain the heap if we're not slicing some number - if num_tasks is not None: - heap = [] - heapq.heapify(heap) - heap_size = 0 + heap = [] + heapq.heapify(heap) + heap_size = 0 # Set up a param dict to pass the redis command params = {"withscores": True} @@ -443,12 +440,11 @@ class GlobalState(object): elif start is not None: params["max"] = time.time() - if num_tasks is not None: - if start is None and end is None: - params["end"] = num_tasks - 1 - else: - params["num"] = num_tasks - params["start"] = 0 + if start is None and end is None: + params["end"] = num_tasks - 1 + else: + params["num"] = num_tasks + params["start"] = 0 # Parse through event logs to determine task start and end points. for event_log_set in event_log_sets: @@ -481,9 +477,8 @@ class GlobalState(object): task_info[task_id]["score"] = score # Add task to (min/max) heap by its start point. # if fwd, we want to delete the largest elements, so -score - if num_tasks is not None: - heapq.heappush(heap, (-score if fwd else score, task_id)) - heap_size += 1 + heapq.heappush(heap, (-score if fwd else score, task_id)) + heap_size += 1 for event in event_dict: if event[1] == "ray:get_task" and event[2] == 1: @@ -518,7 +513,7 @@ class GlobalState(object): task_info[task_id]["function_name"] = ( event[3]["function_name"]) - if num_tasks is not None and heap_size > num_tasks: + if heap_size > num_tasks: min_task, task_id_hex = heapq.heappop(heap) del task_info[task_id_hex] heap_size -= 1 @@ -565,8 +560,10 @@ class GlobalState(object): def micros_rel(ts): return micros(ts - start_time) - task_profiles = self.task_profiles(start=0, end=time.time()) - task_table = self.task_table() + task_table = {} + # TODO(ekl) reduce the number of RPCs here with MGET + for task_id, _ in task_info.items(): + task_table[task_id] = self.task_table(task_id) seen_obj = {} full_trace = [] @@ -652,14 +649,16 @@ class GlobalState(object): if parent_info: parent_worker = workers[parent_info["worker_id"]] parent_times = self._get_times(parent_info) + parent_profile = task_info.get( + task_table[task_id]["TaskSpec"]["ParentTaskID"]) parent = { "cat": "submit_task", "pid": "Node " + parent_worker["node_ip_address"], "tid": parent_info["worker_id"], - "ts": micros_rel(task_profiles[task_table[task_id] - ["TaskSpec"] - ["ParentTaskID"]] - ["get_arguments_start"]), + "ts": micros_rel( + parent_profile and + parent_profile["get_arguments_start"] or + start_time), "ph": "s", "name": "SubmitTask", "args": {}, @@ -702,14 +701,16 @@ class GlobalState(object): if parent_info: parent_worker = workers[parent_info["worker_id"]] parent_times = self._get_times(parent_info) + parent_profile = task_info.get( + task_table[task_id]["TaskSpec"]["ParentTaskID"]) parent = { "cat": "submit_task", "pid": "Node " + parent_worker["node_ip_address"], "tid": parent_info["worker_id"], - "ts": micros_rel(task_profiles[task_table[task_id] - ["TaskSpec"] - ["ParentTaskID"]] - ["get_arguments_start"]), + "ts": micros_rel( + parent_profile and + parent_profile["get_arguments_start"] or + start_time), "ph": "s", "name": "SubmitTask", "args": {}, @@ -745,35 +746,36 @@ class GlobalState(object): seen_obj[arg] = 0 seen_obj[arg] += 1 owner_task = self._object_table(arg)["TaskID"] - owner_worker = (workers[ - task_profiles[owner_task]["worker_id"]]) - # Adding/subtracting 2 to the time associated with - # the beginning/ending of the flow event is - # necessary to make the flow events show up - # reliably. When these times are exact, this is - # presumably an edge case, and catapult doesn't - # recognize that there is a duration event at that - # exact point in time that the flow event should be - # bound to. This issue is solved by adding the 2 ms - # to the start/end time of the flow event, which - # guarantees overlap with the duration event that - # it's associated with, and the flow event - # therefore always gets drawn. - owner = { - "cat": "obj_dependency", - "pid": ("Node " + - owner_worker["node_ip_address"]), - "tid": task_profiles[owner_task]["worker_id"], - "ts": micros_rel(task_profiles[ - owner_task]["store_outputs_end"]) - 2, - "ph": "s", - "name": "ObjectDependency", - "args": {}, - "bp": "e", - "cname": "cq_build_attempt_failed", - "id": "obj" + str(arg) + str(seen_obj[arg]) - } - full_trace.append(owner) + if owner_task in task_info: + owner_worker = (workers[ + task_info[owner_task]["worker_id"]]) + # Adding/subtracting 2 to the time associated + # with the beginning/ending of the flow event + # is necessary to make the flow events show up + # reliably. When these times are exact, this is + # presumably an edge case, and catapult doesn't + # recognize that there is a duration event at + # that exact point in time that the flow event + # should be bound to. This issue is solved by + # adding the 2 ms to the start/end time of the + # flow event, which guarantees overlap with the + # duration event that it's associated with, and + # the flow event therefore always gets drawn. + owner = { + "cat": "obj_dependency", + "pid": ("Node " + + owner_worker["node_ip_address"]), + "tid": task_info[owner_task]["worker_id"], + "ts": micros_rel(task_info[ + owner_task]["store_outputs_end"]) - 2, + "ph": "s", + "name": "ObjectDependency", + "args": {}, + "bp": "e", + "cname": "cq_build_attempt_failed", + "id": "obj" + str(arg) + str(seen_obj[arg]) + } + full_trace.append(owner) dependent = { "cat": "obj_dependency", diff --git a/python/ray/experimental/ui.py b/python/ray/experimental/ui.py index d1da9702c..860f0ba4d 100644 --- a/python/ray/experimental/ui.py +++ b/python/ray/experimental/ui.py @@ -205,24 +205,23 @@ def get_sliders(update): # box values. # (Querying based on the % total amount of time.) if breakdown_opt.value == total_time_value: - tasks = ray.global_state.task_profiles(start=(smallest + - diff * low), - end=(smallest + - diff * high)) + tasks = _truncated_task_profiles(start=(smallest + + diff * low), + end=(smallest + + diff * high)) # (Querying based on % of total number of tasks that were # run.) elif breakdown_opt.value == total_tasks_value: if range_slider.value[0] == 0: - tasks = ray.global_state.task_profiles(num_tasks=(int( - num_tasks * - high)), - fwd=True) + tasks = _truncated_task_profiles(num_tasks=(int( + num_tasks * high)), + fwd=True) else: - tasks = ray.global_state.task_profiles(num_tasks=(int( - num_tasks * - (high - low))), - fwd=False) + tasks = _truncated_task_profiles(num_tasks=(int( + num_tasks * + (high - low))), + fwd=False) update(smallest, largest, num_tasks, tasks) @@ -277,6 +276,26 @@ def task_search_bar(): task_search.on_submit(handle_submit) +# Hard limit on the number of tasks to return to the UI client at once +MAX_TASKS_TO_VISUALIZE = 10000 + + +# Wrapper that enforces a limit on the number of tasks to visualize +def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True): + if num_tasks is None: + num_tasks = MAX_TASKS_TO_VISUALIZE + print( + "Warning: at most {} tasks will be fetched within this " + "time range.".format(MAX_TASKS_TO_VISUALIZE)) + elif num_tasks > MAX_TASKS_TO_VISUALIZE: + print( + "Warning: too many tasks to visualize, " + "fetching only the first {} of {}.".format( + MAX_TASKS_TO_VISUALIZE, num_tasks)) + num_tasks = MAX_TASKS_TO_VISUALIZE + return ray.global_state.task_profiles(num_tasks, start, end, fwd) + + # Helper function that guarantees unique and writeable temp files. # Prevents clashes in task trace files when multiple notebooks are running. def _get_temp_file_path(**kwargs): @@ -347,30 +366,30 @@ def task_timeline(): diff = largest - smallest if time_opt.value == total_time_value: - tasks = ray.global_state.task_profiles(start=smallest + diff * low, - end=smallest + diff * high) + tasks = _truncated_task_profiles(start=smallest + diff * low, + end=smallest + diff * high) elif time_opt.value == total_tasks_value: if range_slider.value[0] == 0: - tasks = ray.global_state.task_profiles(num_tasks=int( - num_tasks * high), - fwd=True) + tasks = _truncated_task_profiles(num_tasks=int( + num_tasks * high), + fwd=True) else: - tasks = ray.global_state.task_profiles(num_tasks=int( - num_tasks * - (high - low)), - fwd=False) + tasks = _truncated_task_profiles(num_tasks=int( + num_tasks * (high - low)), + fwd=False) else: raise ValueError("Unexpected time value '{}'".format( time_opt.value)) # Write trace to a JSON file - print("{} tasks to trace".format(len(tasks))) - print("Dumping task profiling data to " + json_tmp) + print("Collected profiles for {} tasks.".format(len(tasks))) + print( + "Dumping task profile data to {}, " + "this might take a while...".format(json_tmp)) ray.global_state.dump_catapult_trace(json_tmp, tasks, breakdowns=breakdown, obj_dep=obj_dep.value, task_dep=task_dep.value) - print("Opening html file in browser...") trace_viewer_path = os.path.join( diff --git a/test/runtest.py b/test/runtest.py index 673bd3a41..361416e60 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1661,10 +1661,10 @@ class GlobalStateAPI(unittest.TestCase): # Make sure the event log has the correct number of events. start_time = time.time() while time.time() - start_time < 10: - profiles = ray.global_state.task_profiles(start=0, end=time.time()) - limited_profiles = ray.global_state.task_profiles(start=0, - end=time.time(), - num_tasks=1) + profiles = ray.global_state.task_profiles( + 100, start=0, end=time.time()) + limited_profiles = ray.global_state.task_profiles(1, start=0, + end=time.time()) if len(profiles) == num_calls and len(limited_profiles) == 1: break time.sleep(0.1) @@ -1729,7 +1729,8 @@ class GlobalStateAPI(unittest.TestCase): ray.get([actor.method.remote() for actor in actors]) path = os.path.join("/tmp/ray_test_trace") - task_info = ray.global_state.task_profiles(start=0, end=time.time()) + task_info = ray.global_state.task_profiles( + 100, start=0, end=time.time()) ray.global_state.dump_catapult_trace(path, task_info) # TODO(rkn): This test is not perfect because it does not verify that