Files
ray/python/ray/experimental/sgd/util.py
T
Eric Liang 3267676994 [Experimental] Add experimental distributed SGD API (#2858)
* check in sgd api

* idx

* foreach_worker foreach_model

* add feed_dict

* update

* yapf

* typo

* lint

* plasma op change

* fix plasma op

* still not working

* fix

* fix

* comments

* yapf

* silly flake8

* small test
2018-09-19 21:12:37 -07:00

125 lines
3.5 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import time
import tensorflow as tf
import ray
# Module-level logger used by run_timeline and Timeline.chrome_trace_format.
logger = logging.getLogger(__name__)
def fetch(oids):
    """Ask the local Ray object store backend about each object in ``oids``.

    With the raylet backend, each ID is passed to the local scheduler
    client's ``reconstruct_objects`` (with its second argument set to True —
    NOTE(review): presumably a "fetch only" flag, confirm against Ray).
    Otherwise each ID is handed to the plasma client's ``fetch``.

    Args:
        oids: Iterable of raw object IDs, in whatever form ``ray.ObjectID`` /
            ``plasma.ObjectID`` accept (presumably bytes — verify callers).
    """
    if not ray.global_state.use_raylet:
        # Legacy (non-raylet) path: talk to the plasma store directly, one
        # object at a time.
        for raw_id in oids:
            plasma_oid = ray.pyarrow.plasma.ObjectID(raw_id)
            ray.worker.global_worker.plasma_client.fetch([plasma_oid])
        return

    scheduler = ray.worker.global_worker.local_scheduler_client
    for raw_id in oids:
        object_id = ray.ObjectID(raw_id)
        scheduler.reconstruct_objects([object_id], True)
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
    """Run ``ops`` in ``sess``, optionally dumping a TensorFlow Chrome trace.

    Args:
        sess: TensorFlow session used to evaluate ``ops``.
        ops: Fetches passed straight through to ``sess.run``.
        feed_dict: Optional feed dict for ``sess.run``; a falsy value is
            replaced by ``{}``.
        write_timeline: If true, run with full tracing and write the trace as
            ``timeline-<name>-<pid>.json`` in the current working directory.
        name: Label embedded in the trace file name.

    Returns:
        Whatever ``sess.run`` returns for ``ops``.
    """
    feed_dict = feed_dict or {}
    if not write_timeline:
        return sess.run(ops, feed_dict=feed_dict)

    # This module's own ``Timeline`` class is Ray's event recorder and takes a
    # ``tid``, not ``step_stats``; TensorFlow's timeline helper is the one that
    # converts RunMetadata step stats into a Chrome trace.
    from tensorflow.python.client import timeline as tf_timeline

    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    fetches = sess.run(
        ops,
        options=run_options,
        run_metadata=run_metadata,
        feed_dict=feed_dict)
    trace = tf_timeline.Timeline(step_stats=run_metadata.step_stats)
    outf = "timeline-{}-{}.json".format(name, os.getpid())
    # Close the file deterministically so the trace is flushed to disk.
    with open(outf, "w") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    logger.info("wrote tf timeline to %s", os.path.abspath(outf))
    return fetches
class Timeline(object):
def __init__(self, tid):
self.events = []
self.offset = 0
self.start_time = self.time()
self.tid = tid
def patch_ray(self):
orig_log = ray.worker.log
def custom_log(event_type, kind, *args, **kwargs):
orig_log(event_type, kind, *args, **kwargs)
if kind == ray.worker.LOG_SPAN_START:
self.start(event_type)
elif kind == ray.worker.LOG_SPAN_END:
self.end(event_type)
elif kind == ray.worker.LOG_SPAN_POINT:
self.event(event_type)
ray.worker.log = custom_log
def time(self):
return time.time() + self.offset
def reset(self):
self.events = []
self.start_time = self.time()
def start(self, name):
self.events.append((self.tid, "B", name, self.time()))
def end(self, name):
self.events.append((self.tid, "E", name, self.time()))
def event(self, name):
now = self.time()
self.events.append((self.tid, "B", name, now))
self.events.append((self.tid, "E", name, now + .0001))
def merge(self, other):
if other.start_time < self.start_time:
self.start_time = other.start_time
self.events.extend(other.events)
self.events.sort(key=lambda e: e[3])
def chrome_trace_format(self, filename):
out = []
for tid, ph, name, t in self.events:
ts = int((t - self.start_time) * 1000000)
out.append({
"name": name,
"tid": tid,
"pid": tid,
"ph": ph,
"ts": ts,
})
with open(filename, "w") as f:
f.write(json.dumps(out))
logger.info("Wrote chrome timeline to", filename)
if __name__ == "__main__":
    # Smoke test: record two interleaved timelines, merge them, and dump the
    # combined trace to ./test.json for inspection in chrome://tracing.
    pause = .1
    outer = Timeline(1)
    inner = Timeline(2)

    outer.start("hi")
    time.sleep(pause)
    inner.start("bye")
    outer.start("hi3")
    time.sleep(pause)
    outer.end("hi3")
    inner.end("bye")
    time.sleep(pause)
    outer.end("hi")
    inner.start("b1")
    inner.end("b1")

    outer.merge(inner)
    outer.chrome_trace_format("test.json")