send sockets have zero linger, sync object now serves as heartbeat, components die on error, all components error on heartbeat timeout

This commit is contained in:
fawce
2012-02-14 00:35:26 -05:00
parent 21afa42f3a
commit 2aecfc0010
7 changed files with 277 additions and 139 deletions
+141 -77
View File
@@ -6,6 +6,7 @@ import zmq
import json
import copy
import threading
import datetime
import qsim.util as qutil
import qsim.messaging as qmsg
@@ -15,7 +16,7 @@ class Simulator(object):
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
"""
def __init__(self, sources, transforms, client):
def __init__(self, sources, transforms, client, feed=None, merge=None):
"""
"""
self.sources = sources
@@ -25,7 +26,7 @@ class Simulator(object):
self.feed = None
self.context = None
self.sync_context = None
self.syncservice = None
self.sync_socket = None
self.sync_register = {}
self.sync_address = "tcp://127.0.0.1:{port}".format(port=10100)
self.data_address = "tcp://127.0.0.1:{port}".format(port=10101)
@@ -33,29 +34,45 @@ class Simulator(object):
self.merge_address = "tcp://127.0.0.1:{port}".format(port=10103)
self.result_address = "tcp://127.0.0.1:{port}".format(port=10104)
self.timeout = datetime.timedelta(seconds=1)
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
if(feed == None):
self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed"))
else:
self.feed = feed
if(merge == None):
#connect merge to feed, set expected transforms
self.merge = TransformsMerge(self.feed_address,
self.merge_address,
self.result_address,
qmsg.Sync(self,"TransformsMerge"),
self.transforms.keys())
else:
self.merge = merge
def simulate(self):
self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed"))
#launch the feed
self.launch_component("DataFeed", self.feed)
#launch the data sources
for name, data_source in self.sources.iteritems():
data_source.data_address = self.data_address
data_source.sync = qmsg.Sync(self, str(data_source.source_id))
self.launch_component(name, data_source)
qutil.LOGGER.info("datasources processes launched")
#connect all the transforms to the feed and merge
#connect all the transforms to the feed and merge, launch each
for name, transform in self.transforms.iteritems():
transform.feed_address = self.feed_address #connect transform to receive feed.
transform.merge_address = self.merge_address #connect transform to push results to merge
transform.sync = qmsg.Sync(self, name) #synchronize the transform against this simulation.
self.launch_component(name, transform) #start transforms
#connect merge to feed, set expected transforms
self.merge = TransformsMerge(self.feed_address,
self.merge_address,
self.result_address,
qmsg.Sync(self,"TransformsMerge"),
self.transforms.keys())
self.launch_component(name, transform) #start transforms
#launch merge
self.launch_component("transforms merge", self.merge)
qutil.LOGGER.info("transform processes launched")
@@ -66,7 +83,7 @@ class Simulator(object):
qutil.LOGGER.info("client process launched")
self.sync_components()
client_proc.join() #wait for client to complete processing
#client_proc.join() #wait for client to complete processing
def launch_component(self, name, component):
qutil.LOGGER.info("starting {name}".format(name=name))
@@ -81,32 +98,54 @@ class Simulator(object):
return proc
def register_sync(self, sync_id):
self.sync_register[sync_id] = "UNCONFIRMED"
def registration_complete(self):
for status in self.sync_register.values():
if status == "UNCONFIRMED":
return False
return True
self.sync_register[sync_id] = datetime.datetime.utcnow()
def unregister_sync(self, sync_id):
    """Remove ``sync_id`` from the heartbeat register.

    Ids that are not (or no longer) registered are ignored, so a duplicate
    or unknown "DONE" notification is harmless.

    :param sync_id: the id a Sync object registered with this simulator.
    """
    # pop() with a default instead of ``del``: sync_components() calls this
    # between receiving a message and sending the 'ack' reply, so a KeyError
    # raised here would skip the reply.
    self.sync_register.pop(sync_id, None)
def is_timed_out(self):
    """Report whether the simulation's heartbeat has lapsed.

    :returns: True when the sync register is empty, or when any registered
        id's last recorded timestamp is older than ``self.timeout``;
        False otherwise.
    """
    if not self.sync_register:
        qutil.LOGGER.info("**********Simulator sync register is empty.")
        return True

    cur_time = datetime.datetime.utcnow()
    # items() instead of iteritems(): available on both Python 2 and 3.
    for source, last_dt in self.sync_register.items():
        if (cur_time - last_dt) > self.timeout:
            # The first stale entry decides the result; later ones are not logged.
            qutil.LOGGER.info("Time out for {source}".format(source=source))
            return True

    return False
def sync_components(self):
    """Run the heartbeat loop that all components report to.

    Binds a REP socket on ``self.sync_address`` and, until
    ``is_timed_out()`` returns True, answers every received message with
    'ack'. Messages have the form "<sync_id>:<status>": a "DONE" status
    unregisters the sender, any other status refreshes the sender's
    timestamp in ``self.sync_register``. The socket is closed when the loop
    ends, including when it ends with an exception.
    """
    # Socket to receive signals
    self.context = zmq.Context()
    qutil.LOGGER.info("waiting for all datasources and clients to be ready")
    self.sync_socket = self.context.socket(zmq.REP)
    self.sync_socket.bind(self.sync_address)
    self.sync_socket.setsockopt(zmq.LINGER, 0)

    self.poller = zmq.Poller()
    self.poller.register(self.sync_socket, zmq.POLLIN)

    try:
        # NOTE(review): the poll below waits up to 2 seconds while
        # self.timeout is set to 1 second in __init__ -- confirm that a
        # quiet poll interval is meant to end the simulation.
        while not self.is_timed_out():
            socks = dict(self.poller.poll(2000))  # timeout after 2 seconds.
            if socks.get(self.sync_socket) != zmq.POLLIN:
                continue

            try:
                msg = self.sync_socket.recv()
            except zmq.ZMQError:
                qutil.LOGGER.exception("failed to receive heartbeat message")
                continue

            parts = msg.split(':')
            if len(parts) < 2:
                qutil.LOGGER.info("ignoring malformed heartbeat: {msg}".format(msg=msg))
            elif parts[1] == "DONE":
                try:
                    self.unregister_sync(parts[0])
                except KeyError:
                    # Unknown or repeated DONE: nothing left to remove.
                    qutil.LOGGER.info("DONE from unregistered id {id}".format(id=parts[0]))
            else:
                self.sync_register[parts[0]] = datetime.datetime.utcnow()

            # Reply to every message that was received, well-formed or not:
            # a REP socket has to answer a request before it can receive
            # the next one.
            try:
                self.sync_socket.send('ack')
            except zmq.ZMQError:
                qutil.LOGGER.exception("failed to acknowledge heartbeat")
    finally:
        self.sync_socket.close()
        qutil.LOGGER.info("simulator heartbeat stopped.")
@@ -123,30 +162,43 @@ class DataFeed(object):
self.feed_socket = None
self.data_socket = None
self.context = None
self.poller = None
def run(self):
def open(self):
# Prepare our context and sockets
try:
self.context = zmq.Context()
self.context = zmq.Context()
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
#see: http://zguide.zeromq.org/py:taskwork2
self.data_socket = self.context.socket(zmq.PULL)
self.data_socket.bind(self.data_address)
#create the feed
self.feed_socket = self.context.socket(zmq.PUB)
self.feed_socket.bind(self.feed_address)
self.feed_socket.setsockopt(zmq.LINGER,0)
ds_finished_counter = 0
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
#see: http://zguide.zeromq.org/py:taskwork2
self.data_socket = self.context.socket(zmq.PULL)
self.data_socket.bind(self.data_address)
self.data_buffer.out_socket = self.feed_socket
self.poller = zmq.Poller()
self.poller.register(self.data_socket, zmq.POLLIN)
#create the feed
self.feed_socket = self.context.socket(zmq.PUB)
self.feed_socket.bind(self.feed_address)
self.sync.open()
self.data_buffer.out_socket = self.feed_socket
def close(self):
    """Release the feed's sockets, its sync channel and the zmq context.

    Safe to call after open() failed partway: socket and context attributes
    that are still None are skipped, so cleanup does not raise
    AttributeError and hide the original failure.
    """
    for sock in (self.data_socket, self.feed_socket):
        if sock is not None:
            sock.close()
    self.sync.close()
    # Terminate the context last, after the sockets created from it are closed.
    if self.context is not None:
        self.context.term()
self.sync.confirm()
qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address))
while True:
def handle_all(self):
qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address))
ds_finished_counter = 0
while self.sync.confirm():
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
message = self.data_socket.recv()
event = json.loads(message)
if(event["type"] == "DONE"):
ds_finished_counter += 1
@@ -156,20 +208,23 @@ class DataFeed(object):
self.data_buffer.append(event[u's'], event)
self.data_buffer.send_next()
#drain any remaining messages in the buffer
self.data_buffer.drain()
#send the DONE message
self.feed_socket.send("DONE")
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count,
m=self.data_buffer.sent_count))
#drain any remaining messages in the buffer
self.data_buffer.drain()
#send the DONE message
self.feed_socket.send("DONE")
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count,
m=self.data_buffer.sent_count))
def run(self):
try:
self.open()
self.handle_all()
except:
qutil.LOGGER.exception("Exception in Feed, attempting to close.")
finally:
self.data_socket.close()
self.feed_socket.close()
self.context.term()
self.close()
@@ -204,7 +259,7 @@ class BaseTransform(object):
self.open()
self.process_all()
except:
qutil.LOGGER.exception("Exception during merge processing, attempting to close merge.")
qutil.LOGGER.exception("Exception during transform processing, attempting to close merge.")
finally:
self.close()
@@ -221,9 +276,15 @@ class BaseTransform(object):
self.feed_socket.connect(self.feed_address)
self.feed_socket.setsockopt(zmq.SUBSCRIBE,'')
self.poller = zmq.Poller()
self.poller.register(self.feed_socket, zmq.POLLIN)
#create the result PUSH
self.result_socket = self.context.socket(zmq.PUSH)
self.result_socket.connect(self.merge_address)
self.result_socket.setsockopt(zmq.LINGER,0)
self.sync.open()
def process_all(self):
"""
@@ -233,21 +294,22 @@ class BaseTransform(object):
- send the transformed event
"""
qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name']))
self.sync.confirm()
while True:
message = self.feed_socket.recv()
if(message == "DONE"):
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
self.result_socket.send("DONE")
break
self.received_count += 1
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['name'] = self.state['name']
self.result_socket.send(json.dumps(cur_state))
self.sent_count += 1
while self.sync.confirm():
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
message = self.feed_socket.recv()
if(message == "DONE"):
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
self.result_socket.send("DONE")
break
self.received_count += 1
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['name'] = self.state['name']
self.result_socket.send(json.dumps(cur_state))
self.sent_count += 1
def close(self):
"""
@@ -260,6 +322,7 @@ class BaseTransform(object):
self.feed_socket.close()
self.result_socket.close()
self.sync.close()
self.context.term()
def transform(self, event):
@@ -320,6 +383,7 @@ class TransformsMerge(object):
#create the result PUSH
self.result_socket = self.context.socket(zmq.PUSH)
self.result_socket.bind(self.result_address)
self.result_socket.setsockopt(zmq.LINGER,0)
#create the transform PULL.
self.transform_socket = self.context.socket(zmq.PULL)
@@ -331,7 +395,7 @@ class TransformsMerge(object):
self.poller.register(self.feed_socket, zmq.POLLIN)
self.poller.register(self.transform_socket, zmq.POLLIN)
self.sync.confirm()
self.sync.open()
def close(self):
"""
@@ -350,8 +414,8 @@ class TransformsMerge(object):
sent to the result socket.
"""
done_count = 0
while True:
socks = dict(self.poller.poll())
while self.sync.confirm():
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
message = self.feed_socket.recv()
+36 -14
View File
@@ -108,22 +108,44 @@ class Sync(object):
to delay the start of the host until initial setup is complete."""
def __init__(self, host, name):
self.host = host
self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1())
self.host = host
self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1())
self.context = None
self.sync_socket = None
self.poller = None
self.host.register_sync(self.sync_id)
#qutil.LOGGER.info("registered {id} with host".format(id=self.sync_id))
def open(self):
    """Create the zmq context, REQ socket and poller used to talk to the host.

    The REQ socket connects to ``self.host.sync_address`` with zero linger
    and is registered with the poller for incoming replies.
    """
    self.context = zmq.Context()

    # REQ end of the channel to the host
    request_socket = self.sync_socket = self.context.socket(zmq.REQ)
    request_socket.connect(self.host.sync_address)
    request_socket.setsockopt(zmq.LINGER, 0)

    reply_poller = self.poller = zmq.Poller()
    reply_poller.register(request_socket, zmq.POLLIN)
def confirm(self):
    """Send one heartbeat to the host and wait for its reply.

    Sends "<sync_id>:RUNNING" without blocking, then polls for up to
    2 seconds for the host's answer.

    :returns: True when a reply arrived in time; False when the poll timed
        out or any error occurred (errors are logged, not raised).
    """
    try:
        # send a synchronization request to the host
        self.sync_socket.send(self.sync_id + ":RUNNING", zmq.NOBLOCK)
        # wait for synchronization reply from the host
        socks = dict(self.poller.poll(2000))  # timeout after 2 seconds.
        if socks.get(self.sync_socket) == zmq.POLLIN:
            self.sync_socket.recv()  # the reply's content is not inspected
            return True
        qutil.LOGGER.info("no reply from host for {source} within 2 seconds.".format(source=self.sync_id))
    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id))

    return False
def close(self):
    """Tell the host this component is done and release the zmq resources.

    Best effort: never raises. The DONE notification and the socket/context
    teardown are attempted independently, so a failed send does not leave
    the socket and context open.
    """
    if self.sync_socket is None:
        # open() was never reached, so there is nothing to release.
        return

    try:
        self.sync_socket.send(self.sync_id + ":DONE", zmq.NOBLOCK)
    except Exception:
        # Presumably hit when an earlier heartbeat is still unanswered and
        # the REQ socket cannot send again -- the teardown below must still run.
        qutil.LOGGER.info("could not send DONE for {id}".format(id=self.sync_id))

    try:
        self.sync_socket.close()
        self.context.term()
    except Exception:
        # just don't want to error out on closing
        qutil.LOGGER.info("error while closing sync for {id}".format(id=self.sync_id))
+30 -11
View File
@@ -33,14 +33,19 @@ class DataSource(object):
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
self.data_socket = self.context.socket(zmq.PUSH)
self.data_socket.connect(self.data_address)
self.data_socket.setsockopt(zmq.LINGER,0)
self.sync.confirm()
self.sync.open()
def run(self):
    """Fully execute this datasource.

    Calls open() and send_all(), then always calls close(). An exception
    raised by open() or send_all() is logged with its traceback and not
    propagated.
    """
    try:
        self.open()
        self.send_all()
    except Exception:
        # LOGGER.exception keeps the traceback; LOGGER.info dropped it.
        qutil.LOGGER.exception("Exception running datasource.")
    finally:
        self.close()
def send_all(self):
"""Subclasses must implement this method."""
@@ -52,20 +57,35 @@ class DataSource(object):
sets source_id and type properties in the dict
sends to the data_socket.
"""
self.sync.confirm()
event['s'] = self.source_id
event['type'] = 'event'
self.data_socket.send(json.dumps(event))
self.data_socket.send(json.dumps(event), zmq.NOBLOCK)
def close(self):
"""
Close the zmq context and sockets.
"""
done_msg = {}
done_msg['type'] = 'DONE'
done_msg['s'] = self.source_id
self.data_socket.send(json.dumps(done_msg))
qutil.LOGGER.info("sending DONE message.")
try:
done_msg = {}
done_msg['type'] = 'DONE'
done_msg['s'] = self.source_id
self.data_socket.send(json.dumps(done_msg), zmq.NOBLOCK)
except:
qutil.LOGGER.exception("failed to send DONE message")
pass #continue with the closing.
qutil.LOGGER.info("closing data socket")
self.data_socket.close()
self.context.term()
qutil.LOGGER.info("closing sync")
self.sync.close()
qutil.LOGGER.info("closing context")
try:
self.context.term()
qutil.LOGGER.info("done")
except:
qutil.LOGGER.exception("error closing context")
qutil.LOGGER.info("finished processing data source")
class RandomEquityTrades(DataSource):
@@ -89,7 +109,6 @@ class RandomEquityTrades(DataSource):
'volume':random.randrange(100,10000,100)}
self.send(event)
+21 -15
View File
@@ -25,32 +25,38 @@ class TestClient(object):
qutil.LOGGER.info("connecting to {address}".format(address=self.address))
self.data_feed.connect(self.address)
self.sync.confirm()
self.sync.open()
self.poller = zmq.Poller()
self.poller.register(self.data_feed, zmq.POLLIN)
qutil.LOGGER.info("Starting the client loop")
prev_dt = None
while True:
msg = self.data_feed.recv()
if(msg == "DONE"):
qutil.LOGGER.info("DONE!")
break
self.received_count += 1
event = json.loads(msg)
if(prev_dt != None):
if(not event['dt'] >= prev_dt):
raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt))
while self.sync.confirm():
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.data_feed in socks and socks[self.data_feed] == zmq.POLLIN:
msg = self.data_feed.recv()
if(msg == "DONE"):
qutil.LOGGER.info("DONE!")
break
self.received_count += 1
event = json.loads(msg)
if(prev_dt != None):
if(not event['dt'] >= prev_dt):
raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt))
prev_dt = event['dt']
if(self.received_count % 100 == 0):
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
prev_dt = event['dt']
if(self.received_count % 100 == 0):
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
except:
self.error = True
qutil.LOGGER.exception("Error in test client.")
qutil.LOGGER.exception("**********************Error in test client.")
finally:
self.data_feed.close()
self.sync.close()
self.context.term()
self.utest.assertEqual(self.expected_msg_count, self.received_count,
+26 -1
View File
@@ -5,11 +5,13 @@ Test suite for the messaging infrastructure of QSim.
import unittest2 as unittest
import multiprocessing
import time
from qsim.core import Simulator
from qsim.core import Simulator, DataFeed
from qsim.transforms.technical import MovingAverage
from qsim.sources import RandomEquityTrades
import qsim.util as qutil
import qsim.messaging as qmsg
from qsim.test.client import TestClient
@@ -56,3 +58,26 @@ class MessagingTestCase(unittest.TestCase):
self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, "The feed should be drained of all messages.")
def test_zerror_in_feed(self):
    # Runs a full simulation whose feed is a DataFeedErr, i.e. a feed whose
    # handle_all() raises immediately. The client is built with
    # expected_msg_count=0.
    # NOTE(review): this method makes no assertions after simulate();
    # presumably the checks happen inside TestClient, which receives this
    # test case -- confirm.
    # NOTE(review): Simulator.__init__ appears to build and register a
    # default feed when none is passed; replacing sim.feed afterwards would
    # leave that registration in place -- confirm this is intended.
    trade_sources = dict(
        ret1=RandomEquityTrades(133, "ret1", 400),
        ret2=RandomEquityTrades(134, "ret2", 400),
    )
    moving_averages = dict(
        mavg1=MovingAverage("mavg1", 30),
        mavg2=MovingAverage("mavg2", 60),
    )
    test_client = TestClient(self, expected_msg_count=0)

    sim = Simulator(trade_sources, moving_averages, test_client)
    failing_feed_sync = qmsg.Sync(sim, "DataFeedErrorGenerator")
    sim.feed = DataFeedErr(trade_sources.keys(),
                           sim.data_address,
                           sim.feed_address,
                           failing_feed_sync)
    sim.simulate()
class DataFeedErr(DataFeed):
    """Test double for DataFeed: its feed loop fails instead of running."""

    def __init__(self, source_list, data_address, feed_address, sync):
        # No state of its own; construction is delegated to DataFeed unchanged.
        super(DataFeedErr, self).__init__(source_list, data_address, feed_address, sync)

    def handle_all(self):
        # Fail straight away without handling any message.
        raise Exception("simulated error in data feed from test helper")
+22 -20
View File
@@ -6,6 +6,7 @@ TODO: add trailing stop
"""
import datetime
from qsim.core import BaseTransform
import qsim.util as qutil
class MovingAverage(BaseTransform):
"""
@@ -15,31 +16,32 @@ class MovingAverage(BaseTransform):
def __init__(self, name, days):
BaseTransform.__init__(self, name)
self.events = []
self.window = datetime.timedelta(days = days)
self.events = []
self.current_total = 0
self.window = datetime.timedelta(days = days)
def transform(self, event):
    """Update the moving average with the latest data point.

    Adds ``event`` to the retained events, drops events whose 'dt' is more
    than ``self.window`` older than this event's 'dt', and stores the mean
    'price' of the remaining events in ``self.state['value']``.

    :param event: dict with at least a 'dt' date string (parsed with
        qutil.parse_date) and a numeric 'price'.
    :returns: ``self.state``, always the state dict.
    """
    self.events.append(event)
    self.current_total += event['price']
    event_date = qutil.parse_date(event['dt'])

    # Evict from the front while the oldest retained event is older than the
    # window. The age is compared against self.window; a bare timedelta is
    # truthy for any non-zero difference.
    # Assumes events arrive in non-decreasing 'dt' order -- TODO confirm.
    while self.events and (event_date - qutil.parse_date(self.events[0]['dt'])) > self.window:
        expired = self.events.pop(0)
        self.current_total -= expired['price']

    if self.events:
        # float() so integer prices are not truncated by Python 2 division.
        self.average = float(self.current_total) / len(self.events)
    else:
        self.average = 0.0

    self.state['value'] = self.average
    return self.state
+1 -1
View File
@@ -15,7 +15,7 @@ def parse_date(dt_str):
if(dt_str == None):
return None
parts = dt_str.split(".")
dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000"), tzinfo = pytz.utc)
dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000")).replace(tzinfo = pytz.utc)
return dt
def format_date(dt):