diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 65913faa8..636095535 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -345,21 +345,44 @@ class GlobalState(object): executions of that task. The second element is a list of profiling information for tasks where the events have no task ID. """ + task_info = dict() event_names = self.redis_client.keys("event_log*") - results = dict() - events = [] for i in range(len(event_names)): event_list = self.redis_client.lrange(event_names[i], 0, -1) for event in event_list: - event_dict = json.loads(event.decode("ascii")) + event_dict = json.loads(event) task_id = "" - for element in event_dict: - if "task_id" in element[3]: - task_id = element[3]["task_id"] - if task_id != "": - if task_id not in results: - results[task_id] = [] - results[task_id].append(event_dict) - else: - events.append(event_dict) - return results, events + for event in event_dict: + if "task_id" in event[3]: + task_id = event[3]["task_id"] + task_info[task_id] = dict() + for event in event_dict: + if event[1] == "ray:get_task" and event[2] == 1: + task_info[task_id]["get_task_start"] = event[0] + if event[1] == "ray:get_task" and event[2] == 2: + task_info[task_id]["get_task_end"] = event[0] + if event[1] == "ray:import_remote_function" and event[2] == 1: + task_info[task_id]["import_remote_start"] = event[0] + if event[1] == "ray:import_remote_function" and event[2] == 2: + task_info[task_id]["import_remote_end"] = event[0] + if event[1] == "ray:acquire_lock" and event[2] == 1: + task_info[task_id]["acquire_lock_start"] = event[0] + if event[1] == "ray:acquire_lock" and event[2] == 2: + task_info[task_id]["acquire_lock_end"] = event[0] + if event[1] == "ray:task:get_arguments" and event[2] == 1: + task_info[task_id]["get_arguments_start"] = event[0] + if event[1] == "ray:task:get_arguments" and event[2] == 2: + task_info[task_id]["get_arguments_end"] = event[0] + if event[1] == "ray:task:execute" and event[2] == 1: + task_info[task_id]["execute_start"] = event[0] + if event[1] == "ray:task:execute" and event[2] == 2: + task_info[task_id]["execute_end"] = event[0] + if event[1] == "ray:task:store_outputs" and event[2] == 1: + task_info[task_id]["store_outputs_start"] = event[0] + if event[1] == "ray:task:store_outputs" and event[2] == 2: + task_info[task_id]["store_outputs_end"] = event[0] + if "worker_id" in event[3]: + task_info[task_id]["worker_id"] = event[3]["worker_id"] + if "function_name" in event[3]: + task_info[task_id]["function_name"] = event[3]["function_name"] + return task_info diff --git a/test/runtest.py b/test/runtest.py index bf4223c0a..6efb0def6 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1581,33 +1581,20 @@ class GlobalStateAPI(unittest.TestCase): # Make sure the event log has the correct number of events. start_time = time.time() while time.time() - start_time < 10: - profiles, events = ray.global_state.task_profiles() + profiles = ray.global_state.task_profiles() if len(profiles) == num_calls: break time.sleep(0.1) self.assertEqual(len(profiles), num_calls) - self.assertEqual(len(events), 0) # Make sure that each entry is properly formatted. - for task_id in profiles: - events_list = profiles[task_id] - # Make sure that the task was not executed more than once. - self.assertEqual(len(events_list), 1) - events = events_list[0] - for event in events: - found_exec = False - found_store = False - found_get = False - for event in events: - if event[1] == "ray:task:execute": - found_exec = True - if event[1] == "ray:task:get_arguments": - found_get = True - if event[1] == "ray:task:store_outputs": - found_store = True - self.assertTrue(found_exec) - self.assertTrue(found_store) - self.assertTrue(found_get) + for task_id, data in profiles.items(): + self.assertIn("execute_start", data) + self.assertIn("execute_end", data) + self.assertIn("get_arguments_start", data) + self.assertIn("get_arguments_end", data) + self.assertIn("store_outputs_start", data) + self.assertIn("store_outputs_end", data) ray.worker.cleanup()