diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 636095535..06f31a97f 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -386,3 +386,130 @@ class GlobalState(object): if "function_name" in event[3]: task_info[task_id]["function_name"] = event[3]["function_name"] return task_info + + def dump_catapult_trace(self, path): + """Dump task profiling information to a file. + + This information can be viewed as a timeline of profiling information by + going to chrome://tracing in the chrome web browser and loading the + appropriate file. + + Args: + path: The filepath to dump the profiling information to. + """ + task_info = self.task_profiles() + workers = self.workers() + tasks = self.task_table() + start_time = None + for info in task_info.values(): + task_start = min(self._get_times(info)) + if not start_time or task_start < start_time: + start_time = task_start + + def micros(ts): + return int(1e6 * (ts - start_time)) + + full_trace = [] + + for task_id, info in task_info.items(): + parent_info = task_info.get(tasks[task_id]["TaskSpec"]["ParentTaskID"]) + times = self._get_times(info) + worker = workers[info["worker_id"]] + + if parent_info: + parent_worker = workers[parent_info["worker_id"]] + parent_times = self._get_times(parent_info) + parent_trace = { + "cat": "submit_task", + "pid": "Node " + str(parent_worker["node_ip_address"]), + "tid": parent_info["worker_id"], + "ts": micros(min(parent_times)), + "ph": "s", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(parent_trace) + + parent = { + "cat": "submit_task", + "pid": "Node " + str(parent_worker["node_ip_address"]), + "tid": parent_info["worker_id"], + "ts": micros(min(parent_times)), + "ph": "s", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(parent) + + task_trace = { + "cat": "submit_task", + "pid": "Node " + str(worker["node_ip_address"]), + "tid": info["worker_id"], + "ts": micros(min(times)), + "ph": "f", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(task_trace) + + task = { + "name": info["function_name"], + "cat": "ray_task", + "ph": "X", + "ts": micros(min(times)), + "dur": micros(max(times)) - micros(min(times)), + "pid": "Node " + str(worker["node_ip_address"]), + "tid": info["worker_id"], + "args": info + } + full_trace.append(task) + + with open(path, "w") as outfile: + json.dump(full_trace, outfile) + task_info + + def _get_times(self, data): + """Extract the numerical times from a task profile. + + This is a helper method for dump_catapult_trace. + + Args: + data: This must be a value in the dictionary returned by the + task_profiles function. + """ + all_times = [] + all_times.append(data["acquire_lock_start"]) + all_times.append(data["acquire_lock_end"]) + all_times.append(data["get_arguments_start"]) + all_times.append(data["get_arguments_end"]) + all_times.append(data["execute_start"]) + all_times.append(data["execute_end"]) + all_times.append(data["store_outputs_start"]) + all_times.append(data["store_outputs_end"]) + return all_times + + def workers(self): + """Get a dictionary mapping worker ID to worker information.""" + worker_keys = self.redis_client.keys("Worker*") + workers_data = dict() + + for worker_key in worker_keys: + worker_info = self.redis_client.hgetall(worker_key) + worker_id = binary_to_hex(worker_key[len("Workers:"):]) + + workers_data[worker_id] = { + "local_scheduler_socket": (worker_info[b"local_scheduler_socket"] + .decode("ascii")), + "node_ip_address": (worker_info[b"node_ip_address"] + .decode("ascii")), + "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] + .decode("ascii")), + "plasma_store_socket": (worker_info[b"plasma_store_socket"] + .decode("ascii")), + "stderr_file": worker_info[b"stderr_file"].decode("ascii"), + "stdout_file": worker_info[b"stdout_file"].decode("ascii") + } + return workers_data diff --git a/test/runtest.py b/test/runtest.py index 6efb0def6..a9acede0a 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1598,6 +1598,61 @@ class GlobalStateAPI(unittest.TestCase): ray.worker.cleanup() + def testWorkers(self): + num_workers = 3 + ray.init(redirect_output=True, num_cpus=num_workers, + num_workers=num_workers) + + @ray.remote + def f(): + return id(ray.worker.global_worker) + + # Wait until all of the workers have started. + worker_ids = set() + while len(worker_ids) != num_workers: + worker_ids = set(ray.get([f.remote() for _ in range(10)])) + + worker_info = ray.global_state.workers() + self.assertEqual(len(worker_info), num_workers) + for worker_id, info in worker_info.items(): + self.assertEqual(info["node_ip_address"], "127.0.0.1") + self.assertIn("local_scheduler_socket", info) + self.assertIn("plasma_manager_socket", info) + self.assertIn("plasma_store_socket", info) + self.assertIn("stderr_file", info) + self.assertIn("stdout_file", info) + + ray.worker.cleanup() + + def testDumpTraceFile(self): + ray.init(redirect_output=True) + + @ray.remote + def f(): + return 1 + + @ray.remote + class Foo(object): + def __init__(self): + pass + + def method(self): + pass + + ray.get([f.remote() for _ in range(10)]) + actors = [Foo.remote() for _ in range(5)] + ray.get([actor.method.remote() for actor in actors]) + ray.get([actor.method.remote() for actor in actors]) + + path = os.path.join("/tmp/ray_test_trace") + ray.global_state.dump_catapult_trace(path) + + # TODO(rkn): This test is not perfect because it does not verify that the + # visualization actually renders (e.g., the context of the dumped trace + # could be malformed). + + ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2)