Unifying protocol and messaging.

This commit is contained in:
Stephen Diehl
2012-02-24 12:51:46 -05:00
parent d7a707f7b8
commit ffc2e34334
5 changed files with 183 additions and 77 deletions
+88 -43
View File
@@ -1,10 +1,13 @@
"""
Commonly used messaging components.
"""
import json
import os
import uuid
import datetime
import socket
import humanhash
import zipline.util as qutil
from zipline.protocol import CONTROL_PROTOCOL
class Component(object):
@@ -14,8 +17,6 @@ class Component(object):
- sync_address: socket address used for synchronizing the start of all workers, heartbeating, and exit notification
will be used in REP/REQ sockets. Bind is always on the REP side.
- control_address: socket address used for controlling and
monitoring the status of the simulation
- data_address: socket address used for data sources to stream their records.
will be used in PUSH/PULL sockets between data sources and a ParallelBuffer (aka the Feed). Bind
will always be on the PULL side (we always have N producers and 1 consumer)
@@ -31,15 +32,22 @@ class Component(object):
will also return a Poller.
"""
self.zmq = None
self.context = None
self.addresses = None
self.out_socket = None
self.gevent_needed = False
self.killed = False
self.zmq = None
self.context = None
self.addresses = None
self.out_socket = None
self.gevent_needed = False
self.killed = False
self.heartbeat_timeout = 2000
# TODO: could probably make this into a property instead of a
# method
self.guid = uuid.uuid4()
self.huid = humanhash.humanize(self.guid.hex)
# ------------
# Core Methods
# ------------
@property
def get_id(self):
raise NotImplementedError
@@ -62,17 +70,12 @@ class Component(object):
def do_work(self):
raise NotImplementedError
def run(self):
fail = None
#try:
#TODO: can't initialize these values in the __init__?
def _run(self):
self.done = False
self.sockets = []
if self.gevent_needed:
qutil.LOGGER.info("Loading gevent specific zmq for {id}".format(id=self.get_id()))
qutil.LOGGER.info("Loading gevent specific zmq for {id}".format(id=self.get_id))
import gevent_zeromq
self.zmq = gevent_zeromq.zmq
else:
@@ -89,50 +92,68 @@ class Component(object):
for sock in self.sockets:
sock.close()
#except Exception as e:
#qutil.LOGGER.exception("Unexpected error in run for {id}.".format(id=self.get_id()))
#fail = e
def run(self, catch_exceptions=False):
    """
    Execute the component through ``_run`` and destroy its zmq context.

    :param catch_exceptions: when true, any ``Exception`` raised by
        ``_run`` is logged through ``qutil.LOGGER``, the context is
        destroyed, and the exception is re-raised. When false (the
        default) the exception propagates untouched and the context is
        only destroyed after ``_run`` returns cleanly.
    :raises Exception: whatever ``_run`` raised.
    """
    if not catch_exceptions:
        self._run()
        if self.context is not None:
            self.context.destroy()
        return

    # Catching all exceptions makes this really hard to debug; use it
    # with care.
    try:
        self._run()
    except Exception:
        qutil.LOGGER.exception(
            "Unexpected error in run for {id}.".format(id=self.get_id)
        )
        # A bare raise re-raises with the original traceback intact.
        raise
    finally:
        if self.context is not None:
            self.context.destroy()
def loop(self):
    """
    Drive the component until ``self.done`` becomes true.

    Every pass calls ``confirm`` first and ``do_work`` second; the done
    flag is only checked at the top of a pass.
    """
    while True:
        if self.done:
            break
        self.confirm()
        self.do_work()
# -----------
# Messaging
# -----------
def signal_done(self):
#notify down stream components that we're done
if(self.out_socket != None):
self.out_socket.send("DONE")
self.out_socket.send(str(CONTROL_PROTOCOL.DONE))
#notify host we're done
self.sync_socket.send(self.get_id() + ":DONE")
# TODO: proper framing
self.sync_socket.send(self.get_id + ":" + str(CONTROL_PROTOCOL.DONE))
self.receive_sync_ack()
#notify internal work loop that we're done
self.done = True
# TODO: probably don't need a method here ... or move into
# higher level framing protocol
def is_done_message(self, message):
    """
    Report whether ``message`` is a completion notice.

    ``signal_done`` sends ``str(CONTROL_PROTOCOL.DONE)``, so that value is
    recognised here. The older literal ``"DONE"`` is still accepted so
    existing senders keep working.

    :param message: raw message string received from a socket.
    :returns: True when the message marks completion, False otherwise.
    """
    return message == "DONE" or message == str(CONTROL_PROTOCOL.DONE)
def confirm(self):
# send a synchronization request to the host
self.sync_socket.send(self.get_id() + ":RUN")
self.receive_sync_ack()
# TODO: proper framing
self.sync_socket.send(self.get_id + ":RUN")
self.receive_sync_ack() # blocking
def receive_sync_ack(self):
    """
    Wait for synchronization reply from the host.

    Polls ``self.sync_poller`` for at most ``self.heartbeat_timeout``
    milliseconds and reads the reply off ``self.sync_socket`` once it is
    readable.

    :raises Exception: if the sync socket is not readable before the
        timeout expires.
    """
    ready = dict(self.sync_poller.poll(self.heartbeat_timeout))
    if ready.get(self.sync_socket) != self.zmq.POLLIN:
        raise Exception(
            "Sync ack timed out on response for {id}".format(id=self.get_id)
        )

    # The reply content is not used, but it still has to be read off the
    # socket (presumably so the REQ socket can send again -- see zmq docs).
    self.sync_socket.recv()
def bind_data(self):
return self.bind_pull_socket(self.addresses['data_address'])
@@ -164,6 +185,7 @@ class Component(object):
poller = self.zmq.Poller()
poller.register(pull_socket, self.zmq.POLLIN)
self.sockets.append(pull_socket)
return pull_socket, poller
def connect_push_socket(self, addr):
@@ -172,6 +194,7 @@ class Component(object):
#push_socket.setsockopt(self.zmq.LINGER,0)
self.sockets.append(push_socket)
self.out_socket = push_socket
return push_socket
def bind_pub_socket(self, addr):
@@ -179,15 +202,19 @@ class Component(object):
pub_socket.bind(addr)
#pub_socket.setsockopt(self.zmq.LINGER,0)
self.out_socket = pub_socket
return pub_socket
def connect_sub_socket(self, addr):
sub_socket = self.context.socket(self.zmq.SUB)
sub_socket.connect(addr)
sub_socket.setsockopt(self.zmq.SUBSCRIBE,'')
self.sockets.append(sub_socket)
poller = self.zmq.Poller()
poller.register(sub_socket, self.zmq.POLLIN)
self.sockets.append(sub_socket)
# TODO: migrate tuple unpacking to be consistent
return sub_socket, poller
def setup_control(self):
@@ -196,10 +223,10 @@ class Component(object):
overall status of the simulation and to forcefully tear
down the simulation in case of a failure.
"""
pass
assert self.controller
def setup_sync(self):
qutil.LOGGER.debug("Connecting sync client for {id}".format(id=self.get_id()))
qutil.LOGGER.debug("Connecting sync client for {id}".format(id=self.get_id))
self.sync_socket = self.context.socket(self.zmq.REQ)
self.sync_socket.connect(self.addresses['sync_address'])
@@ -208,3 +235,21 @@ class Component(object):
self.sync_poller.register(self.sync_socket, self.zmq.POLLIN)
self.sockets.append(self.sync_socket)
def debug(self):
    """
    Return identifying details for this component as a 5-tuple:
    ``(get_id, huid, hostname, pid, hex(id(self)))``.
    """
    name = self.get_id
    human_id = self.huid
    host = socket.gethostname()
    pid = os.getpid()
    pointer = hex(id(self))
    return (name, human_id, host, pid, pointer)
def __repr__(self):
return "<{name} {uuid} at {host} {pid} {pointer}>".format(
name = self.get_id ,
uuid = self.huid ,
host = socket.gethostname() ,
pid = os.getpid() ,
pointer = hex(id(self)) ,
)
+32 -18
View File
@@ -7,17 +7,20 @@ import datetime
import zipline.util as qutil
from zipline.component import Component
from zipline.protocol import CONTROL_PROTOCOL
class ComponentHost(Component):
"""
Components that can launch multiple sub-components, synchronize their start, and then wait for all
components to be finished.
Components that can launch multiple sub-components, synchronize their
start, and then wait for all components to be finished.
"""
def __init__(self, addresses, gevent_needed=False):
Component.__init__(self)
self.addresses = addresses
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
# workaround for defect in threaded use of strptime:
# http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
self.components = {}
@@ -47,13 +50,13 @@ class ComponentHost(Component):
if self.controller:
component.controller = self.controller
self.components[component.get_id()] = component
self.sync_register[component.get_id()] = datetime.datetime.utcnow()
self.components[component.get_id] = component
self.sync_register[component.get_id] = datetime.datetime.utcnow()
if(isinstance(component, DataSource)):
self.feed.add_source(component.get_id())
self.feed.add_source(component.get_id)
if(isinstance(component, BaseTransform)):
self.merge.add_source(component.get_id())
self.merge.add_source(component.get_id)
def unregister_component(self, component_id):
del self.components[component_id]
@@ -97,15 +100,19 @@ class ComponentHost(Component):
if self.sync_socket in socks and socks[self.sync_socket] == self.zmq.POLLIN:
msg = self.sync_socket.recv()
parts = msg.split(':')
if(len(parts) < 2):
if len(parts) != 2:
qutil.LOGGER.info("got bad confirm: {msg}".format(msg=msg))
sync_id = parts[0]
status = parts[1]
if(self.is_done_message(status)):
continue
sync_id, status = parts
if status == str(CONTROL_PROTOCOL.DONE): # TODO: other way around
qutil.LOGGER.info("{id} is DONE".format(id=sync_id))
self.unregister_component(sync_id)
else:
self.sync_register[sync_id] = datetime.datetime.utcnow()
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
# send synchronization reply
self.sync_socket.send('ack', self.zmq.NOBLOCK)
@@ -119,9 +126,10 @@ class ComponentHost(Component):
class ParallelBuffer(Component):
"""
Connects to N PULL sockets, publishing all messages received to a PUB socket.
Published messages are guaranteed to be in chronological order based on message property dt.
Expects to be instantiated in one execution context (thread, process, etc) and run in another.
Connects to N PULL sockets, publishing all messages received to a PUB
socket. Published messages are guaranteed to be in chronological order
based on message property dt. Expects to be instantiated in one execution
context (thread, process, etc) and run in another.
"""
def __init__(self):
@@ -133,6 +141,7 @@ class ParallelBuffer(Component):
self.ds_finished_counter = 0
@property
def get_id(self):
return "FEED"
@@ -149,7 +158,7 @@ class ParallelBuffer(Component):
if self.pull_socket in socks and socks[self.pull_socket] == self.zmq.POLLIN:
message = self.pull_socket.recv()
if self.is_done_message(message):
if message == str(CONTROL_PROTOCOL.DONE):
self.ds_finished_counter += 1
if len(self.data_buffer) == self.ds_finished_counter:
#drain any remaining messages in the buffer
@@ -262,6 +271,7 @@ class MergedParallelBuffer(ParallelBuffer):
result[source] = cur['value']
return result
@property
def get_id(self):
return "MERGE"
@@ -283,6 +293,7 @@ class BaseTransform(Component):
self.state = {}
self.state['name'] = name
@property
def get_id(self):
return self.state['name']
@@ -305,7 +316,7 @@ class BaseTransform(Component):
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == self.zmq.POLLIN:
message = self.feed_socket.recv()
if self.is_done_message(message):
if message == str(CONTROL_PROTOCOL.DONE):
self.signal_done()
return
@@ -313,6 +324,7 @@ class BaseTransform(Component):
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['id'] = self.state['name']
self.result_socket.send(json.dumps(cur_state), self.zmq.NOBLOCK)
def transform(self, event):
@@ -321,8 +333,9 @@ class BaseTransform(Component):
{name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
Transforms run in parallel and results are merged into a single map, so
transform names must be unique. Best practice is to use the self.state
object initialized from the transform configuration, and only set the
transformed value::
self.state['value'] = transformed_value
@@ -350,6 +363,7 @@ class DataSource(Component):
self.id = source_id
self.cur_event = None
@property
def get_id(self):
return self.id
+29
View File
@@ -1,3 +1,32 @@
#import msgpack
#import ujson
#import ultrajson_numpy
from ctypes import Structure, c_ubyte
def Enum(*options):
    """
    Build a lightweight enumeration backed by a ctypes ``Structure``.

    Fast enums are very important when we want really tight zmq
    loops. These are probably going to evolve into pure C structs
    anyways so might as well get going on that.

    Each name in ``options`` becomes a ``c_ubyte`` field whose value is
    the name's position in the argument list (0, 1, 2, ...).

    :param options: field names, in value order.
    :returns: an instance of the generated structure; read members as
        attributes, e.g. ``Enum('A', 'B').B == 1``.
    :raises ValueError: if more than 256 names are given (an unsigned byte
        cannot hold a larger index and ctypes would wrap it silently) or
        if a name is repeated.
    """
    if len(options) > 256:
        raise ValueError(
            "Enum supports at most 256 options, got %d" % len(options)
        )
    if len(set(options)) != len(options):
        raise ValueError("Enum option names must be unique")

    class cstruct(Structure):
        _fields_ = [(name, c_ubyte) for name in options]

    return cstruct(*range(len(options)))
CONTROL_PROTOCOL = Enum(
'INIT' , # 0 - req
'INFO' , # 1 - req
'STATUS' , # 2 - req
'SHUTDOWN' , # 3 - req
'KILL' , # 4 - req
'OK' , # 5 - rep
'DONE' , # 6 - rep
'EXCEPTION' , # 7 - rep
)
HEARTBEAT_PROTOCOL = {
'REQ' : '\x01',
'REP' : '\x02',
}
+6 -1
View File
@@ -2,6 +2,8 @@ import json
import zipline.util as qutil
import zipline.messaging as qmsg
from zipline.protocol import CONTROL_PROTOCOL
class TestClient(qmsg.Component):
def __init__(self, utest, expected_msg_count=0):
@@ -11,6 +13,7 @@ class TestClient(qmsg.Component):
self.utest = utest
self.prev_dt = None
@property
def get_id(self):
return "TEST_CLIENT"
@@ -19,9 +22,11 @@ class TestClient(qmsg.Component):
def do_work(self):
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.data_feed in socks and socks[self.data_feed] == self.zmq.POLLIN:
msg = self.data_feed.recv()
if(self.is_done_message(msg)):
if msg == str(CONTROL_PROTOCOL.DONE):
qutil.LOGGER.info("Client is DONE!")
self.signal_done()
self.utest.assertEqual(self.expected_msg_count, self.received_count,
+28 -15
View File
@@ -7,32 +7,45 @@ import datetime
import pytz
import logging
LOGGER = logging.getLogger('QSimLogger')
def configure_logging(loglevel=logging.DEBUG):
    """
    Configure zipline.util.LOGGER to write a rotating file
    (10M per file, 5 backups) to ``/var/log/zipline/zipline.log``.

    :param loglevel: logging level applied to LOGGER.
    """
    # ``logging.handlers`` is a submodule that a plain ``import logging``
    # does not load; import it here so the lookup below cannot fail.
    import logging.handlers

    LOGGER.setLevel(loglevel)

    handler = logging.handlers.RotatingFileHandler(
        "/var/log/zipline/{lfn}.log".format(lfn="zipline"),
        maxBytes=10*1024*1024, backupCount=5
    )
    handler.setFormatter(logging.Formatter(
        "%(asctime)s %(levelname)s %(filename)s %(funcName)s - %(message)s",
        "%Y-%m-%d %H:%M:%S %Z")
    )

    LOGGER.addHandler(handler)
    LOGGER.info("logging started...")
def parse_date(dt_str):
    """
    Parse strings according to the same format as generated by
    format_date: ``YYYY/MM/DD-HH:MM:SS.mmm``, where the part after the
    dot is a whole number of milliseconds.

    :param dt_str: the string to parse, or None.
    :returns: a datetime with ``tzinfo`` set to ``pytz.utc``, or None when
        ``dt_str`` is None. A string with no millisecond part is parsed
        with zero milliseconds.
    """
    if dt_str is None:
        return None

    parts = dt_str.split(".")
    # Tolerate a missing ".mmm" suffix instead of raising IndexError.
    millis = parts[1] if len(parts) > 1 else "0"

    dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S')
    # Appending "000" turns the millisecond count into microseconds.
    return dt.replace(microsecond=int(millis + "000"), tzinfo=pytz.utc)
def format_date(dt):
    """
    Format the date into a string with millisecond resolution and
    string/alphabetical sorting that is equivalent to datetime sorting.

    The millisecond part is zero-padded to three digits; without the
    padding ``.5`` would sort after ``.114`` and break the ordering
    guarantee.

    :param dt: a datetime, or None.
    :returns: ``YYYY/MM/DD-HH:MM:SS.mmm``, or None when ``dt`` is None.
    """
    if dt is None:
        return None

    # Floor division keeps the value an integer on Python 3 as well.
    millis = dt.microsecond // 1000
    return dt.strftime('%Y/%m/%d-%H:%M:%S') + ".{0:03d}".format(millis)
def configure_logging(loglevel=logging.DEBUG):
"""configures zipline.util.LOGGER to write a rotating file (10M per file, 5 files) to /var/log/zipline.log"""
LOGGER.setLevel(loglevel)
handler = logging.handlers.RotatingFileHandler(
"/var/log/zipline/{lfn}.log".format(lfn="zipline"),
maxBytes=10*1024*1024, backupCount=5)
handler.setFormatter(logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s %(funcName)s - %(message)s","%Y-%m-%d %H:%M:%S %Z"))
LOGGER.addHandler(handler)
LOGGER.info("logging started...")