From 2aecfc00100c64a4759d443ae855e18731191b22 Mon Sep 17 00:00:00 2001 From: fawce Date: Tue, 14 Feb 2012 00:35:26 -0500 Subject: [PATCH] send sockets have zero linger, sync object now serves as heartbeat, components die on error, all components error on heartbeat timeout --- qsim/core.py | 218 ++++++++++++++++++++++------------- qsim/messaging.py | 50 +++++--- qsim/sources.py | 41 +++++-- qsim/test/client.py | 36 +++--- qsim/test/test_messaging.py | 27 ++++- qsim/transforms/technical.py | 42 +++---- qsim/util.py | 2 +- 7 files changed, 277 insertions(+), 139 deletions(-) diff --git a/qsim/core.py b/qsim/core.py index 452b27ed..c2b4d240 100644 --- a/qsim/core.py +++ b/qsim/core.py @@ -6,6 +6,7 @@ import zmq import json import copy import threading +import datetime import qsim.util as qutil import qsim.messaging as qmsg @@ -15,7 +16,7 @@ class Simulator(object): Simulator coordinates the launch and communication of source, feed, transform, and merge components. """ - def __init__(self, sources, transforms, client): + def __init__(self, sources, transforms, client, feed=None, merge=None): """ """ self.sources = sources @@ -25,7 +26,7 @@ class Simulator(object): self.feed = None self.context = None self.sync_context = None - self.syncservice = None + self.sync_socket = None self.sync_register = {} self.sync_address = "tcp://127.0.0.1:{port}".format(port=10100) self.data_address = "tcp://127.0.0.1:{port}".format(port=10101) @@ -33,29 +34,45 @@ class Simulator(object): self.merge_address = "tcp://127.0.0.1:{port}".format(port=10103) self.result_address = "tcp://127.0.0.1:{port}".format(port=10104) + self.timeout = datetime.timedelta(seconds=1) + + #workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108 + qutil.parse_date("2012/02/13-10:04:28.114") + + if(feed == None): + self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed")) + else: + self.feed = feed + + if(merge == None): + #connect merge to feed, set expected transforms + self.merge = TransformsMerge(self.feed_address, + self.merge_address, + self.result_address, + qmsg.Sync(self,"TransformsMerge"), + self.transforms.keys()) + else: + self.merge = merge + def simulate(self): - self.feed = DataFeed(self.sources.keys(), self.data_address, self.feed_address, qmsg.Sync(self,"DataFeed")) + #launch the feed self.launch_component("DataFeed", self.feed) + + #launch the data sources for name, data_source in self.sources.iteritems(): data_source.data_address = self.data_address data_source.sync = qmsg.Sync(self, str(data_source.source_id)) self.launch_component(name, data_source) qutil.LOGGER.info("datasources processes launched") - #connect all the transforms to the feed and merge + #connect all the transforms to the feed and merge, launch each for name, transform in self.transforms.iteritems(): transform.feed_address = self.feed_address #connect transform to receive feed. transform.merge_address = self.merge_address #connect transform to push results to merge transform.sync = qmsg.Sync(self, name) #synchronize the transform against this simulation. - self.launch_component(name, transform) #start transforms - - #connect merge to feed, set expected transforms - self.merge = TransformsMerge(self.feed_address, - self.merge_address, - self.result_address, - qmsg.Sync(self,"TransformsMerge"), - self.transforms.keys()) + self.launch_component(name, transform) #start transforms + #launch merge self.launch_component("transforms merge", self.merge) qutil.LOGGER.info("transform processes launched") @@ -66,7 +83,7 @@ class Simulator(object): qutil.LOGGER.info("client process launched") self.sync_components() - client_proc.join() #wait for client to complete processing + #client_proc.join() #wait for client to complete processing def launch_component(self, name, component): qutil.LOGGER.info("starting {name}".format(name=name)) @@ -81,32 +98,54 @@ class Simulator(object): return proc def register_sync(self, sync_id): - self.sync_register[sync_id] = "UNCONFIRMED" - - def registration_complete(self): - for status in self.sync_register.values(): - if status == "UNCONFIRMED": - return False - - return True + self.sync_register[sync_id] = datetime.datetime.utcnow() + + def unregister_sync(self, sync_id): + del(self.sync_register[sync_id]) + + def is_timed_out(self): + cur_time = datetime.datetime.utcnow() + if(len(self.sync_register) == 0): + qutil.LOGGER.info("**********Simulator sync register is empty.") + return True + for source, last_dt in self.sync_register.iteritems(): + if((cur_time - last_dt) > self.timeout): + qutil.LOGGER.info("Time out for {source}".format(source=source)) + return True + return False def sync_components(self): # Socket to receive signals self.context = zmq.Context() qutil.LOGGER.info("waiting for all datasources and clients to be ready") - self.syncservice = self.context.socket(zmq.REP) - self.syncservice.bind(self.sync_address) + self.sync_socket = self.context.socket(zmq.REP) + self.sync_socket.bind(self.sync_address) + self.sync_socket.setsockopt(zmq.LINGER,0) + self.poller = zmq.Poller() + self.poller.register(self.sync_socket, zmq.POLLIN) - while not self.registration_complete(): + while not self.is_timed_out(): # wait for synchronization request - msg = self.syncservice.recv() - self.sync_register[msg] = "CONFIRMED" - #qutil.LOGGER.info("confirmed {id}".format(id=msg)) - # send synchronization reply - self.syncservice.send('CONFIRMED') + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - self.syncservice.close() - qutil.LOGGER.info("sync'd all datasources and clients") + if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: + try: + msg = self.sync_socket.recv() + parts = msg.split(':') + sync_id = parts[0] + status = parts[1] + if(status == "DONE"): + self.unregister_sync(sync_id) + else: + self.sync_register[sync_id] = datetime.datetime.utcnow() + #qutil.LOGGER.info("confirmed {id}".format(id=msg)) + # send synchronization reply + self.sync_socket.send('ack') + except: + continue + + self.sync_socket.close() + qutil.LOGGER.info("simulator heartbeat stopped.") @@ -123,30 +162,43 @@ class DataFeed(object): self.feed_socket = None self.data_socket = None self.context = None + self.poller = None - def run(self): + def open(self): # Prepare our context and sockets - try: - self.context = zmq.Context() + self.context = zmq.Context() + #create the data sink. Based on http://zguide.zeromq.org/py:tasksink2 + #see: http://zguide.zeromq.org/py:taskwork2 + self.data_socket = self.context.socket(zmq.PULL) + self.data_socket.bind(self.data_address) + + #create the feed + self.feed_socket = self.context.socket(zmq.PUB) + self.feed_socket.bind(self.feed_address) + self.feed_socket.setsockopt(zmq.LINGER,0) - ds_finished_counter = 0 - - #create the data sink. Based on http://zguide.zeromq.org/py:tasksink2 - #see: http://zguide.zeromq.org/py:taskwork2 - self.data_socket = self.context.socket(zmq.PULL) - self.data_socket.bind(self.data_address) + self.data_buffer.out_socket = self.feed_socket + self.poller = zmq.Poller() + self.poller.register(self.data_socket, zmq.POLLIN) - #create the feed - self.feed_socket = self.context.socket(zmq.PUB) - self.feed_socket.bind(self.feed_address) + self.sync.open() - self.data_buffer.out_socket = self.feed_socket + def close(self): + self.data_socket.close() + self.feed_socket.close() + self.sync.close() + self.context.term() - self.sync.confirm() - qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address)) - - while True: + def handle_all(self): + qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address)) + ds_finished_counter = 0 + while self.sync.confirm(): + # wait for synchronization reply from the host + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + + if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: message = self.data_socket.recv() + event = json.loads(message) if(event["type"] == "DONE"): ds_finished_counter += 1 @@ -156,20 +208,23 @@ class DataFeed(object): self.data_buffer.append(event[u's'], event) self.data_buffer.send_next() - - #drain any remaining messages in the buffer - self.data_buffer.drain() - - #send the DONE message - self.feed_socket.send("DONE") - qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count, - m=self.data_buffer.sent_count)) + #drain any remaining messages in the buffer + self.data_buffer.drain() + + #send the DONE message + self.feed_socket.send("DONE") + qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count, + m=self.data_buffer.sent_count)) + + + def run(self): + try: + self.open() + self.handle_all() except: qutil.LOGGER.exception("Exception in Feed, attempting to close.") finally: - self.data_socket.close() - self.feed_socket.close() - self.context.term() + self.close() @@ -204,7 +259,7 @@ class BaseTransform(object): self.open() self.process_all() except: - qutil.LOGGER.exception("Exception during merge processing, attempting to close merge.") + qutil.LOGGER.exception("Exception during transform processing, attempting to close merge.") finally: self.close() @@ -221,9 +276,15 @@ class BaseTransform(object): self.feed_socket.connect(self.feed_address) self.feed_socket.setsockopt(zmq.SUBSCRIBE,'') + self.poller = zmq.Poller() + self.poller.register(self.feed_socket, zmq.POLLIN) + #create the result PUSH self.result_socket = self.context.socket(zmq.PUSH) self.result_socket.connect(self.merge_address) + self.result_socket.setsockopt(zmq.LINGER,0) + + self.sync.open() def process_all(self): """ @@ -233,21 +294,22 @@ class BaseTransform(object): - send the transformed event """ qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name'])) - self.sync.confirm() - while True: - message = self.feed_socket.recv() - if(message == "DONE"): - qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name'])) - self.result_socket.send("DONE") - break - self.received_count += 1 - event = json.loads(message) - cur_state = self.transform(event) - cur_state['dt'] = event['dt'] - cur_state['name'] = self.state['name'] - self.result_socket.send(json.dumps(cur_state)) - self.sent_count += 1 + while self.sync.confirm(): + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN: + message = self.feed_socket.recv() + if(message == "DONE"): + qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name'])) + self.result_socket.send("DONE") + break + self.received_count += 1 + event = json.loads(message) + cur_state = self.transform(event) + cur_state['dt'] = event['dt'] + cur_state['name'] = self.state['name'] + self.result_socket.send(json.dumps(cur_state)) + self.sent_count += 1 def close(self): """ @@ -260,6 +322,7 @@ class BaseTransform(object): self.feed_socket.close() self.result_socket.close() + self.sync.close() self.context.term() def transform(self, event): @@ -320,6 +383,7 @@ class TransformsMerge(object): #create the result PUSH self.result_socket = self.context.socket(zmq.PUSH) self.result_socket.bind(self.result_address) + self.result_socket.setsockopt(zmq.LINGER,0) #create the transform PULL. self.transform_socket = self.context.socket(zmq.PULL) @@ -331,7 +395,7 @@ class TransformsMerge(object): self.poller.register(self.feed_socket, zmq.POLLIN) self.poller.register(self.transform_socket, zmq.POLLIN) - self.sync.confirm() + self.sync.open() def close(self): """ @@ -350,8 +414,8 @@ class TransformsMerge(object): sent to the result socket. """ done_count = 0 - while True: - socks = dict(self.poller.poll()) + while self.sync.confirm(): + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN: message = self.feed_socket.recv() diff --git a/qsim/messaging.py b/qsim/messaging.py index aff0668c..f6051ece 100644 --- a/qsim/messaging.py +++ b/qsim/messaging.py @@ -108,22 +108,44 @@ class Sync(object): to delay the start of the host until initial setup is complete.""" def __init__(self, host, name): - self.host = host - self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1()) + self.host = host + self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1()) + self.context = None + self.sync_socket = None + self.poller = None self.host.register_sync(self.sync_id) + #qutil.LOGGER.info("registered {id} with host".format(id=self.sync_id)) + def open(self): + self.context = zmq.Context() + #synchronize with host + self.sync_socket = self.context.socket(zmq.REQ) + self.sync_socket.connect(self.host.sync_address) + self.sync_socket.setsockopt(zmq.LINGER,0) + self.poller = zmq.Poller() + self.poller.register(self.sync_socket, zmq.POLLIN) + def confirm(self): """Confirm readiness with the Host.""" - context = zmq.Context() - #synchronize with host - sync_socket = context.socket(zmq.REQ) - sync_socket.connect(self.host.sync_address) - # send a synchronization request to the host - sync_socket.send(self.sync_id) - # wait for synchronization reply from the host - sync_socket.recv() - sync_socket.close() - context.term() - qutil.LOGGER.info("sync'd host from {id}".format(id = self.sync_id)) - \ No newline at end of file + try: + # send a synchronization request to the host + self.sync_socket.send(self.sync_id + ":RUNNING", zmq.NOBLOCK) + # wait for synchronization reply from the host + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + + if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: + message = self.sync_socket.recv() + return True + except: + qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id)) + return False + + def close(self): + try: + self.sync_socket.send(self.sync_id + ":DONE", zmq.NOBLOCK) + self.sync_socket.close() + self.context.term() + except: + pass #just don't want to error out on closing + \ No newline at end of file diff --git a/qsim/sources.py b/qsim/sources.py index 7090bd77..d50414b7 100644 --- a/qsim/sources.py +++ b/qsim/sources.py @@ -33,14 +33,19 @@ class DataSource(object): #create the data sink. Based on http://zguide.zeromq.org/py:tasksink2 self.data_socket = self.context.socket(zmq.PUSH) self.data_socket.connect(self.data_address) + self.data_socket.setsockopt(zmq.LINGER,0) - self.sync.confirm() + self.sync.open() def run(self): """Fully execute this datasource.""" - self.open() - self.send_all() - self.close() + try: + self.open() + self.send_all() + except: + qutil.LOGGER.info("Exception running datasource.") + finally: + self.close() def send_all(self): """Subclasses must implement this method.""" @@ -52,20 +57,35 @@ class DataSource(object): sets source_id and type properties in the dict sends to the data_socket. """ + self.sync.confirm() event['s'] = self.source_id event['type'] = 'event' - self.data_socket.send(json.dumps(event)) + self.data_socket.send(json.dumps(event), zmq.NOBLOCK) def close(self): """ Close the zmq context and sockets. """ - done_msg = {} - done_msg['type'] = 'DONE' - done_msg['s'] = self.source_id - self.data_socket.send(json.dumps(done_msg)) + qutil.LOGGER.info("sending DONE message.") + try: + done_msg = {} + done_msg['type'] = 'DONE' + done_msg['s'] = self.source_id + self.data_socket.send(json.dumps(done_msg), zmq.NOBLOCK) + except: + qutil.LOGGER.exception("failed to send DONE message") + pass #continue with the closing. + + qutil.LOGGER.info("closing data socket") self.data_socket.close() - self.context.term() + qutil.LOGGER.info("closing sync") + self.sync.close() + qutil.LOGGER.info("closing context") + try: + self.context.term() + qutil.LOGGER.info("done") + except: + qutil.LOGGER.exception("error closing context") qutil.LOGGER.info("finished processing data source") class RandomEquityTrades(DataSource): @@ -89,7 +109,6 @@ class RandomEquityTrades(DataSource): 'volume':random.randrange(100,10000,100)} self.send(event) - diff --git a/qsim/test/client.py b/qsim/test/client.py index bd03e0d5..58d0cf7c 100644 --- a/qsim/test/client.py +++ b/qsim/test/client.py @@ -25,32 +25,38 @@ class TestClient(object): qutil.LOGGER.info("connecting to {address}".format(address=self.address)) self.data_feed.connect(self.address) - self.sync.confirm() + self.sync.open() + self.poller = zmq.Poller() + self.poller.register(self.data_feed, zmq.POLLIN) + qutil.LOGGER.info("Starting the client loop") prev_dt = None - while True: - msg = self.data_feed.recv() - if(msg == "DONE"): - qutil.LOGGER.info("DONE!") - break - self.received_count += 1 - event = json.loads(msg) - if(prev_dt != None): - if(not event['dt'] >= prev_dt): - raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt)) + while self.sync.confirm(): + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + if self.data_feed in socks and socks[self.data_feed] == zmq.POLLIN: + msg = self.data_feed.recv() + if(msg == "DONE"): + qutil.LOGGER.info("DONE!") + break + self.received_count += 1 + event = json.loads(msg) + if(prev_dt != None): + if(not event['dt'] >= prev_dt): + raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt)) - prev_dt = event['dt'] - if(self.received_count % 100 == 0): - qutil.LOGGER.info("received {n} messages".format(n=self.received_count)) + prev_dt = event['dt'] + if(self.received_count % 100 == 0): + qutil.LOGGER.info("received {n} messages".format(n=self.received_count)) qutil.LOGGER.info("received {n} messages".format(n=self.received_count)) except: self.error = True - qutil.LOGGER.exception("Error in test client.") + qutil.LOGGER.exception("**********************Error in test client.") finally: self.data_feed.close() + self.sync.close() self.context.term() self.utest.assertEqual(self.expected_msg_count, self.received_count, diff --git a/qsim/test/test_messaging.py b/qsim/test/test_messaging.py index 22ab6a6f..ca03619b 100644 --- a/qsim/test/test_messaging.py +++ b/qsim/test/test_messaging.py @@ -5,11 +5,13 @@ Test suite for the messaging infrastructure of QSim. import unittest2 as unittest import multiprocessing +import time -from qsim.core import Simulator +from qsim.core import Simulator, DataFeed from qsim.transforms.technical import MovingAverage from qsim.sources import RandomEquityTrades import qsim.util as qutil +import qsim.messaging as qmsg from qsim.test.client import TestClient @@ -56,3 +58,26 @@ class MessagingTestCase(unittest.TestCase): self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, "The feed should be drained of all messages.") + def test_zerror_in_feed(self): + ret1 = RandomEquityTrades(133, "ret1", 400) + ret2 = RandomEquityTrades(134, "ret2", 400) + sources = {"ret1":ret1, "ret2":ret2} + mavg1 = MovingAverage("mavg1", 30) + mavg2 = MovingAverage("mavg2", 60) + transforms = {"mavg1":mavg1, "mavg2":mavg2} + client = TestClient(self, expected_msg_count=0) + sim = Simulator(sources, transforms, client) + sim.feed = DataFeedErr(sources.keys(), sim.data_address, sim.feed_address, qmsg.Sync(sim, "DataFeedErrorGenerator")) + sim.simulate() + +class DataFeedErr(DataFeed): + """Helper class for testing, simulates exceptions inside the DataFeed""" + + def __init__(self, source_list, data_address, feed_address, sync): + DataFeed.__init__(self, source_list, data_address, feed_address, sync) + + def handle_all(self): + #time.sleep(1000) + raise Exception("simulated error in data feed from test helper") + + diff --git a/qsim/transforms/technical.py b/qsim/transforms/technical.py index f9166f38..63d7a164 100644 --- a/qsim/transforms/technical.py +++ b/qsim/transforms/technical.py @@ -6,6 +6,7 @@ TODO: add trailing stop """ import datetime from qsim.core import BaseTransform +import qsim.util as qutil class MovingAverage(BaseTransform): """ @@ -15,31 +16,32 @@ class MovingAverage(BaseTransform): def __init__(self, name, days): BaseTransform.__init__(self, name) - self.events = [] - - self.window = datetime.timedelta(days = days) - - - + self.events = [] + self.current_total = 0 + self.window = datetime.timedelta(days = days) def transform(self, event): """Update the moving average with the latest data point.""" - #self.events.append(event) + self.events.append(event) + self.current_total += event['price'] + event_date = qutil.parse_date(event['dt']) - #filter the event list to the window length. - #self.events = [x for x in self.events if (qutil.parse_date(x['dt']) - qutil.parse_date(event['dt'])) <= self.window] + index = 0 + for cur_event in self.events: + cur_date = qutil.parse_date(cur_event['dt']) + if(cur_date - event_date): + self.events.pop(index) + self.current_total -= cur_event['price'] + index += 1 + else: + break + + if(len(self.events) == 0): + return 0.0 + + self.average = self.current_total/len(self.events) - #if(len(self.events) == 0): - # return 0.0 - - #total = 0.0 - #for event in self.events: - # total += event['price'] - - #self.average = total/len(self.events) - - #self.state['value'] = self.average - self.state['value'] = 10 + self.state['value'] = self.average return self.state \ No newline at end of file diff --git a/qsim/util.py b/qsim/util.py index 9b94203f..fc796b56 100644 --- a/qsim/util.py +++ b/qsim/util.py @@ -15,7 +15,7 @@ def parse_date(dt_str): if(dt_str == None): return None parts = dt_str.split(".") - dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000"), tzinfo = pytz.utc) + dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(microsecond=int(parts[1]+"000")).replace(tzinfo = pytz.utc) return dt def format_date(dt):