mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 02:30:34 +08:00
3267676994
* check in sgd api * idx * foreach_worker foreach_model * add feed_dict * update * yapf * typo * lint * plasma op change * fix plasma op * still not working * fix * fix * comments * yapf * silly flake8 * small test
125 lines
3.5 KiB
Python
125 lines
3.5 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
import tensorflow as tf
|
|
|
|
import ray
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def fetch(oids):
    """Issue a fetch request to the local Ray backend for every ID in `oids`.

    Each element of `oids` is wrapped in the object-ID type of the active
    backend before the request is made, one request per ID:

    * raylet mode (``ray.global_state.use_raylet`` is truthy): the ID is
      wrapped in ``ray.ObjectID`` and handed to the local scheduler client's
      ``reconstruct_objects`` with ``True`` as the second argument
      (presumably a "fetch only" flag -- confirm against the Ray API).
    * otherwise: the ID is wrapped in a plasma ``ObjectID`` and handed to
      the plasma client's ``fetch``.

    Args:
        oids: iterable of raw object IDs accepted by the ``ObjectID``
            constructors above.
    """
    if not ray.global_state.use_raylet:
        # Legacy path: talk to the plasma store directly.
        for raw_id in oids:
            object_id = ray.pyarrow.plasma.ObjectID(raw_id)
            ray.worker.global_worker.plasma_client.fetch([object_id])
        return

    # Raylet path: requests go through the local scheduler client.
    scheduler = ray.worker.global_worker.local_scheduler_client
    for raw_id in oids:
        scheduler.reconstruct_objects([ray.ObjectID(raw_id)], True)
|
|
|
|
|
|
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
    """Run `ops` in `sess`, optionally dumping a TensorFlow timeline trace.

    Args:
        sess: session object whose ``run`` method executes `ops`
            (presumably a ``tf.Session``).
        ops: the fetches passed straight through to ``sess.run``.
        feed_dict: optional feed dictionary; ``None`` (or any falsy value)
            is treated as an empty dict.
        write_timeline: when true, the run is executed with full tracing
            enabled and the resulting trace is written, in Chrome trace
            format, to ``timeline-<name>-<pid>.json`` relative to the
            current working directory.
        name: label embedded in the trace file name.

    Returns:
        Whatever ``sess.run`` returns for `ops`.
    """
    feed_dict = feed_dict or {}
    if not write_timeline:
        return sess.run(ops, feed_dict=feed_dict)

    # TensorFlow's timeline helper is imported locally and under an alias:
    # the `Timeline` class defined in this module is an unrelated event
    # recorder that neither accepts `step_stats` nor provides
    # `generate_chrome_trace_format`.
    from tensorflow.python.client import timeline as tf_timeline

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    fetches = sess.run(
        ops,
        options=run_options,
        run_metadata=run_metadata,
        feed_dict=feed_dict)

    trace = tf_timeline.Timeline(step_stats=run_metadata.step_stats)
    outf = "timeline-{}-{}.json".format(name, os.getpid())
    # The context manager guarantees the trace is flushed and the handle
    # closed even if serialization fails part-way.
    with open(outf, "w") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    # Log only after the file has actually been written.
    logger.info("wrote tf timeline to %s", os.path.abspath(outf))
    return fetches
|
|
|
|
|
|
class Timeline(object):
    """Records begin/end events for one track and exports a Chrome trace.

    Events are stored in ``self.events`` as ``(tid, phase, name, timestamp)``
    tuples, where `phase` is ``"B"`` (begin) or ``"E"`` (end) and `timestamp`
    is the value returned by :meth:`time`.
    """

    def __init__(self, tid):
        """Create an empty timeline.

        Args:
            tid: identifier stored with every event; it is emitted as both
                ``tid`` and ``pid`` in the exported trace.
        """
        self.events = []
        # Added to the wall clock by `time()`; it must be set before the
        # first `time()` call below. Presumably adjusted by callers to
        # align clocks across processes -- nothing in this class changes it.
        self.offset = 0
        self.start_time = self.time()
        self.tid = tid

    def patch_ray(self):
        """Replace ``ray.worker.log`` with a wrapper that feeds this timeline.

        The wrapper first calls the original ``ray.worker.log`` with the
        unchanged arguments, then records a start, end or point event
        depending on `kind`. Other kinds are forwarded but not recorded.
        """
        orig_log = ray.worker.log

        def custom_log(event_type, kind, *args, **kwargs):
            orig_log(event_type, kind, *args, **kwargs)
            if kind == ray.worker.LOG_SPAN_START:
                self.start(event_type)
            elif kind == ray.worker.LOG_SPAN_END:
                self.end(event_type)
            elif kind == ray.worker.LOG_SPAN_POINT:
                self.event(event_type)

        ray.worker.log = custom_log

    def time(self):
        """Return the current wall-clock time plus ``self.offset``."""
        return time.time() + self.offset

    def reset(self):
        """Drop all recorded events and restart the clock origin at now."""
        self.events = []
        self.start_time = self.time()

    def start(self, name):
        """Record a begin ("B") event called `name` at the current time."""
        self.events.append((self.tid, "B", name, self.time()))

    def end(self, name):
        """Record an end ("E") event called `name` at the current time."""
        self.events.append((self.tid, "E", name, self.time()))

    def event(self, name):
        """Record a point event as a begin/end pair 0.0001 s apart."""
        now = self.time()
        self.events.append((self.tid, "B", name, now))
        self.events.append((self.tid, "E", name, now + .0001))

    def merge(self, other):
        """Fold the events of `other` into this timeline.

        The earlier of the two start times becomes this timeline's start
        time, and the combined events are re-sorted by timestamp (the sort
        is stable, so events with equal timestamps keep their order).
        `other` itself is not modified.
        """
        if other.start_time < self.start_time:
            self.start_time = other.start_time
        self.events.extend(other.events)
        self.events.sort(key=lambda e: e[3])

    def chrome_trace_format(self, filename):
        """Write the recorded events to `filename` as a JSON list.

        Each event becomes a dict with ``name``, ``tid``, ``pid``, ``ph``
        and ``ts`` keys, where ``ts`` is the event time relative to
        ``self.start_time`` in whole microseconds.

        Args:
            filename: path of the file to create or overwrite.
        """
        out = []
        for tid, ph, name, t in self.events:
            # Seconds since the timeline origin, converted to microseconds.
            ts = int((t - self.start_time) * 1000000)
            out.append({
                "name": name,
                "tid": tid,
                "pid": tid,
                "ph": ph,
                "ts": ts,
            })
        with open(filename, "w") as f:
            json.dump(out, f)
        # A "%s" placeholder is required: logging applies `msg % args`, and
        # a message without one makes the handler report a formatting error
        # instead of emitting the line.
        logger.info("Wrote chrome timeline to %s", filename)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: record overlapping spans on two timelines, merge
    # them, and write the combined trace to "test.json".
    first = Timeline(1)
    second = Timeline(2)

    # "hi" opens on the first timeline and stays open across the whole run.
    first.start("hi")
    time.sleep(.1)

    # "bye" (second) and the nested "hi3" (first) overlap each other.
    second.start("bye")
    first.start("hi3")
    time.sleep(.1)
    first.end("hi3")
    second.end("bye")

    time.sleep(.1)
    first.end("hi")

    # A span that is opened and closed immediately.
    second.start("b1")
    second.end("b1")

    first.merge(second)
    first.chrome_trace_format("test.json")
|