mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:50:55 +08:00
Visualize computation graph
This commit is contained in:
@@ -8,6 +8,6 @@ PYTHON_MODE = 3
|
||||
|
||||
import libraylib as lib
|
||||
import serialization
|
||||
from worker import scheduler_info, dump_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, connect, disconnect, get, put, remote, kill_workers
|
||||
from libraylib import ObjRef
|
||||
import internal
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# Utilities to deal with computation graphs
|
||||
|
||||
import graphviz
|
||||
|
||||
def graph_to_graphviz(computation_graph):
    """
    Build a graphviz digraph from a computation graph.

    Every operation becomes a box node named "op<index>". Task operations are
    labelled with the last dotted component of the task name and get an edge to
    each of their results; put operations get an edge to the object reference
    they store; get operations only get a node. Each operation is also linked
    by a dotted edge from the operation that created it, and every task
    argument that is not passed as an "obj" field is drawn as an object
    reference node feeding into the operation.

    Args:
      computation_graph [graph_pb2.CompGraph]: protocol buffer description of
        the computation graph

    Returns:
      A graphviz.Digraph (pdf format) describing the computation graph
    """
    dot = graphviz.Digraph(format="pdf")
    # Stand-in node used as the creator whenever creator_operationid is
    # 2 ** 64 - 1 (presumably the "no creator" sentinel -- confirm against the
    # scheduler).
    dot.node("op-root", shape="box")
    for (index, operation) in enumerate(computation_graph.operation):
        node_name = "op" + str(index)
        if operation.HasField("task"):
            short_name = operation.task.name.split(".")[-1]
            dot.node(node_name, shape="box", label=str(index) + "\n" + short_name)
            for result in operation.task.result:
                dot.edge(node_name, str(result))
        elif operation.HasField("put"):
            dot.node(node_name, shape="box", label=str(index) + "\n" + "put")
            dot.edge(node_name, str(operation.put.objref))
        elif operation.HasField("get"):
            dot.node(node_name, shape="box", label=str(index) + "\n" + "get")

        # Dotted, non-constraining edge from the creating operation.
        creator = operation.creator_operationid
        if creator == 2 ** 64 - 1:
            creator = "-root"
        dot.edge("op" + str(creator), node_name, style="dotted", constraint="false")

        # Arguments passed by reference become explicit object nodes.
        for argument in operation.task.arg:
            if argument.HasField("obj"):
                continue
            ref_name = str(argument.ref)
            dot.node(ref_name)
            dot.edge(ref_name, node_name)
    return dot
|
||||
@@ -12,6 +12,8 @@ import copy
|
||||
import ray
|
||||
from ray.config import LOG_DIRECTORY, LOG_TIMESTAMP
|
||||
import serialization
|
||||
import ray.internal.graph_pb2
|
||||
import ray.graph
|
||||
|
||||
class RayFailedObject(object):
|
||||
"""If a task throws an exception during execution, a RayFailedObject is stored in the object store for each of the tasks outputs."""
|
||||
@@ -143,8 +145,42 @@ def print_task_info(task_data, mode):
|
||||
def scheduler_info(worker=global_worker):
    """Return whatever ray.lib.scheduler_info reports for the given worker's handle."""
    handle = worker.handle
    return ray.lib.scheduler_info(handle)
|
||||
|
||||
def dump_computation_graph(file_name, worker=global_worker):
    """Forward the worker's handle and file_name to ray.lib.dump_computation_graph."""
    handle = worker.handle
    ray.lib.dump_computation_graph(handle, file_name)
|
||||
def visualize_computation_graph(file_path=None, view=False, worker=global_worker):
    """
    Write the computation graph to a pdf file.

    Three files are produced next to each other: the protocol buffer dump
    (<base>.binaryproto), the graphviz dot description (<base>) and the
    rendered graph (<base>.pdf), where <base> is file_path without its
    extension.

    Args:
      file_path: A .pdf file that the rendered computation graph will be
        written to. If None, a timestamped "-computation-graph.pdf" file in
        ray.config.LOG_DIRECTORY is used.

      view: If true, the python graphviz package will try to open the
        rendered result in a viewer.

      worker: The worker whose handle is used to dump the computation graph.

    Raises:
      Exception: If file_path does not have a ".pdf" extension.

    Example:
      In ray/scripts, call "python shell.py" and paste in the following code.

      x = da.zeros([20, 20])
      y = da.zeros([20, 20])
      z = da.dot(x, y)

      ray.visualize_computation_graph("computation_graph.pdf")
    """

    if file_path is None:
        file_name = (ray.config.LOG_TIMESTAMP + "-computation-graph.pdf").format(datetime.datetime.now())
        file_path = os.path.join(ray.config.LOG_DIRECTORY, file_name)

    base_path, extension = os.path.splitext(file_path)
    if extension != ".pdf":
        raise Exception("File path must be a .pdf file")
    proto_path = base_path + ".binaryproto"

    ray.lib.dump_computation_graph(worker.handle, proto_path)
    graph = ray.internal.graph_pb2.CompGraph()
    # The dump is a serialized protocol buffer: read it as bytes (text mode
    # can corrupt it on some platforms) and make sure the file gets closed.
    with open(proto_path, "rb") as proto_file:
        graph.ParseFromString(proto_file.read())
    ray.graph.graph_to_graphviz(graph).render(base_path, view=view)

    # Single-argument call form prints identically under the Python 2 print
    # statement and the print function.
    print("Wrote graph dot description to file {}".format(base_path))
    print("Wrote graph protocol buffer description to file {}".format(proto_path))
    print("Wrote computation graph to file {}.pdf".format(base_path))
|
||||
|
||||
def task_info(worker=global_worker):
|
||||
"""Tell the scheduler to return task information. Currently includes a list of all failed tasks since the start of the cluster."""
|
||||
|
||||
Reference in New Issue
Block a user