From fc8b942e450d4195728909c413a3dcfd2f9eb7f2 Mon Sep 17 00:00:00 2001 From: fawce Date: Wed, 15 Feb 2012 22:35:00 -0500 Subject: [PATCH] first draft of refactoring is done --- zipline/core.py | 462 +-------------------------------- zipline/messaging.py | 360 ++++++++++++++++++++----- zipline/sources.py | 34 ++- zipline/test/test_messaging.py | 25 +- 4 files changed, 342 insertions(+), 539 deletions(-) diff --git a/zipline/core.py b/zipline/core.py index 4e035e97..7cd87030 100644 --- a/zipline/core.py +++ b/zipline/core.py @@ -12,141 +12,26 @@ import atexit import zipline.util as qutil import zipline.messaging as qmsg -class SimulatorBase(object): +class SimulatorBase(ComponentHost): """ Simulator coordinates the launch and communication of source, feed, transform, and merge components. """ - def __init__(self, sources, transforms, client, feed=None, merge=None): + def __init__(self, addresses): """ - """ - self.sources = sources - self.transforms = transforms - self.client = client - self.merge = None - self.feed = None - self.context = None - self.sync_context = None - self.sync_socket = None - self.sync_register = {} - self.sync_address = "tcp://127.0.0.1:{port}".format(port=10100) - self.data_address = "tcp://127.0.0.1:{port}".format(port=10101) - self.feed_address = "tcp://127.0.0.1:{port}".format(port=10102) - self.merge_address = "tcp://127.0.0.1:{port}".format(port=10103) - self.result_address = "tcp://127.0.0.1:{port}".format(port=10104) - - self.performance_address = "tcp://127.0.0.1:{port}".format(port=10105) - - self.timeout = datetime.timedelta(seconds=5) + """ + self.feed = ParallelBuffer() + self.merge = MergedParallelBuffer() #workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108 qutil.parse_date("2012/02/13-10:04:28.114") - if(feed == None): - self.feed = DataFeed(self.sources.keys(), - self.data_address, - self.feed_address, - self.performance_address, - qmsg.Sync(self,"DataFeed")) - else: - self.feed = feed - - if(merge == None): - #connect merge to feed, set expected transforms - self.merge = TransformsMerge(self.feed_address, - self.merge_address, - self.result_address, - qmsg.Sync(self,"TransformsMerge"), - self.transforms.keys()) - else: - self.merge = merge - + #register the feed and the merge + self.register_components([self.feed, self.merge]) + def simulate(self): - #launch the feed - self.launch_component("DataFeed", self.feed) + self.launch_component(self) - #launch the data sources - for name, data_source in self.sources.iteritems(): - data_source.data_address = self.data_address - data_source.sync = qmsg.Sync(self, str(data_source.source_id)) - self.launch_component(name, data_source) - qutil.LOGGER.info("datasources processes launched") - - #connect all the transforms to the feed and merge, launch each - for name, transform in self.transforms.iteritems(): - transform.feed_address = self.feed_address #connect transform to receive feed. - transform.merge_address = self.merge_address #connect transform to push results to merge - transform.sync = qmsg.Sync(self, name) #synchronize the transform against this simulation. - self.launch_component(name, transform) #start transforms - - #launch merge - self.launch_component("transforms merge", self.merge) - qutil.LOGGER.info("transform processes launched") - - #connect client to merged feed - self.client.address = self.result_address - self.client.sync = qmsg.Sync(self,"Client") - client_proc = self.launch_component("client", self.client) - qutil.LOGGER.info("client process launched") - - qutil.LOGGER.info("sync register starting with {count} members: {reg}".format(count=len(self.sync_register), reg=self.sync_register)) - self.sync_components() - #client_proc.join() #wait for client to complete processing - - def launch_component(self, name, component): - raise NotImplementedError - - def register_sync(self, sync_id): - self.sync_register[sync_id] = datetime.datetime.utcnow() - - def unregister_sync(self, sync_id): - qutil.LOGGER.info("unregistering {sync_id}".format(sync_id=sync_id)) - del(self.sync_register[sync_id]) - - def is_timed_out(self): - cur_time = datetime.datetime.utcnow() - if(len(self.sync_register) == 0): - qutil.LOGGER.info("**********Simulator sync register is empty.") - return True - for source, last_dt in self.sync_register.iteritems(): - if((cur_time - last_dt) > self.timeout): - qutil.LOGGER.info("Time out for {source}. Current registery: {reg}".format(source=source, reg=self.sync_register)) - return True - return False - - def sync_components(self): - # Socket to receive signals - self.context = zmq.Context() - qutil.LOGGER.info("waiting for all datasources and clients to be ready") - self.sync_socket = self.context.socket(zmq.REP) - self.sync_socket.bind(self.sync_address) - #self.sync_socket.setsockopt(zmq.LINGER,0) - self.poller = zmq.Poller() - self.poller.register(self.sync_socket, zmq.POLLIN) - - while not self.is_timed_out(): - # wait for synchronization request - socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - - if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: - try: - msg = self.sync_socket.recv() - parts = msg.split(':') - sync_id = parts[0] - status = parts[1] - if(status == "DONE"): - self.unregister_sync(sync_id) - else: - self.sync_register[sync_id] = datetime.datetime.utcnow() - #qutil.LOGGER.info("confirmed {id}".format(id=msg)) - # send synchronization reply - self.sync_socket.send('ack', zmq.NOBLOCK) - except: - qutil.LOGGER.exception("Exception in sync components loop") - - self.sync_socket.close() - qutil.LOGGER.info("simulator heartbeat stopped.") - class ThreadSimulator(SimulatorBase): def __init__(self, sources, transforms, client, feed=None, merge=None): @@ -168,333 +53,4 @@ class ProcessSimulator(SimulatorBase): proc = multiprocessing.Process(target=component.run) proc.start() return proc - - -class DataFeed(object): - - def __init__(self, source_list, data_address, feed_address, performance_address, sync): - """ - :source_list: list of data source IDs - """ - self.feed_address = feed_address - self.performance_address = performance_address - self.data_address = data_address - self.data_buffer = qmsg.ParallelBuffer(source_list) - self.sync = sync - self.feed_socket = None - self.data_socket = None - self.perf_socket = None - self.context = None - self.poller = None - - - def open(self): - # Prepare our context and sockets - self.context = zmq.Context() - #create the data sink. Based on http://zguide.zeromq.org/py:tasksink2 - #see: http://zguide.zeromq.org/py:taskwork2 - self.data_socket = self.context.socket(zmq.PULL) - self.data_socket.bind(self.data_address) - - #create the feed - self.feed_socket = self.context.socket(zmq.PUB) - self.feed_socket.bind(self.feed_address) - #self.feed_socket.setsockopt(zmq.LINGER,0) - - self.data_buffer.out_socket = self.feed_socket - self.poller = zmq.Poller() - self.poller.register(self.data_socket, zmq.POLLIN) - - #create the performance results push - self.perf_socket = self.context.socket(zmq.PUSH) - self.perf_socket.bind(self.performance_address) - #self.perf_socket.setsockopt(zmq.LINGER,0) - - self.sync.open() - - def close(self): - try: - self.data_socket.close() - self.feed_socket.close() - self.perf_socket.close() - self.sync.close() - except: - qutil.LOGGER.exception("Error closing DataFeed") - finally: - self.context.destroy() - - def handle_all(self): - qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address)) - ds_finished_counter = 0 - while self.sync.confirm(): - # wait for synchronization reply from the host - socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - - if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN: - message = self.data_socket.recv() - - event = json.loads(message) - if(event["type"] == "DONE"): - ds_finished_counter += 1 - if(len(self.data_buffer) == ds_finished_counter): - break - else: - self.data_buffer.append(event[u's'], event) - self.data_buffer.send_next() - - #self.perf_socket.send(message, zmq.NOBLOCK) - - #drain any remaining messages in the buffer - self.data_buffer.drain() - - #send the DONE message - self.feed_socket.send("DONE", zmq.NOBLOCK) - qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count, - m=self.data_buffer.sent_count)) - - - def run(self): - try: - self.open() - self.handle_all() - except: - qutil.LOGGER.exception("Exception in Feed, attempting to close.") - finally: - self.close() - - - -class BaseTransform(object): - """Parent class for feed transforms. Subclass and override transform - method to create a new derived value from the combined feed.""" - - def __init__(self, name): - """ - """ - - self.feed_address = None - self.merge_address = None - self.state = {} - self.state['name'] = name - self.sync = None - self.received_count = 0 - self.sent_count = 0 - self.context = None - self.feed_socket = None - self.result_socket = None - - def run(self): - """Top level execution entry point for the transform:: - - - connects to the feed socket to subscribe to events - - connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms - - processes all messages received from feed, until DONE message received - - pushes all transforms - - sends DONE to result socket, closes all sockets and context""" - try: - self.open() - self.process_all() - except: - qutil.LOGGER.exception("Exception during transform processing, attempting to close merge.") - finally: - self.close() - - def open(self): - """ - Establishes zmq connections. - """ - self.context = zmq.Context() - - qutil.LOGGER.info("starting {name} transform". - format(name = self.state['name'])) - #create the feed SUB. - self.feed_socket = self.context.socket(zmq.SUB) - self.feed_socket.connect(self.feed_address) - self.feed_socket.setsockopt(zmq.SUBSCRIBE,'') - - self.poller = zmq.Poller() - self.poller.register(self.feed_socket, zmq.POLLIN) - - #create the result PUSH - self.result_socket = self.context.socket(zmq.PUSH) - self.result_socket.connect(self.merge_address) - #self.result_socket.setsockopt(zmq.LINGER,0) - - self.sync.open() - - def process_all(self): - """ - Loops until feed's DONE message is received: - - receive an event from the data feed - - call transform (subclass' method) on event - - send the transformed event - """ - qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name'])) - - while self.sync.confirm(): - socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN: - message = self.feed_socket.recv() - if(message == "DONE"): - qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name'])) - self.result_socket.send("DONE", zmq.NOBLOCK) - break - self.received_count += 1 - event = json.loads(message) - cur_state = self.transform(event) - cur_state['dt'] = event['dt'] - cur_state['name'] = self.state['name'] - self.result_socket.send(json.dumps(cur_state), zmq.NOBLOCK) - self.sent_count += 1 - - def close(self): - """ - Shut down zmq resources. - """ - qutil.LOGGER.info("Transform {name} recieved {r} and sent {s}".format( - name=self.state['name'], - r=self.received_count, - s=self.sent_count)) - - try: - self.feed_socket.close() - self.result_socket.close() - self.sync.close() - except: - qutil.LOGGER.exception("Error closing Transforms") - finally: - self.context.destroy() - - def transform(self, event): - """ Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"} - Transforms run in parallel and results are merged into a single map, so transform names must be unique. - Best practice is to use the self.state object initialized from the transform configuration, and only set the - transformed value: - self.state['value'] = transformed_value - """ - return {} - -class TransformsMerge(object): - """ Merge data feed and array of transform feeds into a single result vector. - PULL from feed - PULL from child transforms - PUSH merged message to client - - """ - - def __init__(self, feed_address, transform_address, result_address, sync, transform_list): - """ - """ - self.sync = sync - self.feed_address = feed_address - self.transform_address = transform_address - self.result_address = result_address - buffer_list = copy.copy(transform_list) - buffer_list.append("feed") #for the raw feed - self.data_buffer = qmsg.MergedParallelBuffer(buffer_list) - self.feed_socket = None - self.result_socket = None - self.poller = None - self.context = None - self.transform_socket = None - - def run(self): - """""" - try: - self.open() - self.process_all() - except: - qutil.LOGGER.exception("Exception during merge processing, attempting to close merge.") - finally: - self.close() - - def open(self): - """Establish zmq context, feed socket, result socket for client, and transform - socket to receive transformed events. Create and launch transforms. Will confirm - ready with the DataFeed at the conclusion.""" - self.context = zmq.Context() - - qutil.LOGGER.info("starting transforms merge") - #create the feed SUB. - self.feed_socket = self.context.socket(zmq.SUB) - self.feed_socket.connect(self.feed_address) - self.feed_socket.setsockopt(zmq.SUBSCRIBE,'') - - #create the result PUSH - self.result_socket = self.context.socket(zmq.PUSH) - self.result_socket.bind(self.result_address) - #self.result_socket.setsockopt(zmq.LINGER,0) - - #create the transform PULL. - self.transform_socket = self.context.socket(zmq.PULL) - self.transform_socket.bind(self.transform_address) - self.data_buffer.out_socket = self.result_socket - - # Initialize poll set - self.poller = zmq.Poller() - self.poller.register(self.feed_socket, zmq.POLLIN) - self.poller.register(self.transform_socket, zmq.POLLIN) - - self.sync.open() - - def close(self): - """ - Close all zmq sockets and context. - """ - try: - self.sync.close() - self.transform_socket.close() - self.feed_socket.close() - self.result_socket.close() - except: - qutil.LOGGER.exception("Error closing merge") - finally: - self.context.destroy() - - def process_all(self): - """ - Uses a Poller to receive messages from all transforms and the feed. - All transforms corresponding to the same event are merged with each other - and the original feed event into a single message. That message is then - sent to the result socket. - """ - done_count = 0 - while self.sync.confirm(): - socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - - if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN: - message = self.feed_socket.recv() - if(message == "DONE"): - qutil.LOGGER.info("finished receiving feed to merge") - done_count += 1 - else: - event = json.loads(message) - self.data_buffer.append("feed", event) - - if self.transform_socket in socks and socks[self.transform_socket] == zmq.POLLIN: - t_message = self.transform_socket.recv() - if(t_message == "DONE"): - qutil.LOGGER.info("finished receiving a transform to merge") - done_count += 1 - else: - t_event = json.loads(t_message) - self.data_buffer.append(t_event['name'], t_event) - - if(done_count >= len(self.data_buffer)): - break #done! - - self.data_buffer.send_next() - - qutil.LOGGER.info("about to drain {n} messages from merger's buffer".format(n=self.data_buffer.pending_messages())) - - #drain any remaining messages in the buffer - self.data_buffer.drain() - - #signal to client that we're done - self.result_socket.send("DONE", zmq.NOBLOCK) - - - - - diff --git a/zipline/messaging.py b/zipline/messaging.py index 7d9ab7d7..1e3fd3b2 100644 --- a/zipline/messaging.py +++ b/zipline/messaging.py @@ -4,21 +4,243 @@ Commonly used messaging components. import json import uuid from gevent_zeromq import zmq - import zipline.util as qutil -class ParallelBuffer(object): - """ holds several queues of events by key, allows retrieval in date order - or by merging""" - def __init__(self, key_list): - self.out_socket = None - self.sent_count = 0 - self.received_count = 0 - self.draining = False - self.data_buffer = {} - for key in key_list: - self.data_buffer[key] = [] +class Component(object): + def __init__(self, addresses): + """ + :addresses: a dict of name_string -> zmq port address strings. Must have the following entries:: + + - sync_address: socket address used for synchronizing the start of all workers, heartbeating, and exit notification + will be used in REP/REQ sockets. Bind is always on the REP side. + - data_address: socket address used for data sources to stream their records. + will be used in PUSH/PULL sockets between data sources and a ParallelBuffer (aka the Feed). Bind + will always be on the PULL side (we always have N producers and 1 consumer) + - feed_address: socket address used to publish consolidated feed from serialization of data sources + will be used in PUB/SUB sockets between Feed and Transforms. Bind is always on the PUB side. + - merge_address: socket address used to publish transformed values. + will be used in PUSH/PULL from many transforms to one MergedParallelBuffer (aka the Merge). Bind + will always be on the PULL side (we always have N producers and 1 consumer) + - result_address: socket address used to publish merged data source feed and transforms to clients + will be used in PUB/SUB from one Merge to one or many clients. Bind is always on the PUB side. + + Bind/Connect methods will return the correct socket type for each address. Any sockets on which recv is expected to be called + will also return a Poller. + + """ + self.context = zmq.Context() + self.addresses = addresses + self.sockets = [] + + def get_id(self): + raise NotImplemented + + def open(self): + raise NotImplemented + + def do_work(self): + raise NotImplemented + + def run(self): + try: + self.open() + self.connect_sync() + while self.confirm(): + self.do_work() + #notify host we're done + self.sync_socket.send(self.sync_id + ":DONE") + #close all the sockets + for sock in self.sockets: + sock.close() + finally: + self.context.destroy() + + def confirm(self): + try: + # send a synchronization request to the host + self.sync_socket.send(self.sync_id + ":RUNNING") + # wait for synchronization reply from the host + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + + if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: + message = self.sync_socket.recv() + + return True + except: + qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id)) + return False + + def bind_data(self): + return self.bind_pull_socket(self, self.addresses['data_address']) + + def connect_data(self): + return self.connect_push_socket(self, self.addresses['data_address']) + + def bind_feed(self): + return self.bind_pub_socket(self, self.addresses['feed_address']) + + def connect_feed(self): + return self.bind_sub_socket(self, self.addresses['feed_address']) + + def bind_merge(self): + return self.bind_pull_socket(self, self.addresses['merge_address']) + + def connect_merge(self): + return self.connect_push_socket(self, self.addresses['merge_address']) + + def bind_result(self): + return self.bind_pub_socket(self, self.addresses['result_address']) + + def connect_result(self): + return self.bind_sub_socket(self, self.addresses['result_address']) + + def bind_pull_socket(self, addr): + pull_socket = self.context.socket(zmq.PULL) + pull_socket.bind(addr) + poller = zmq.Poller() + poller.register(self.pull_socket, zmq.POLLIN) + self.sockets.append(pull_socket) + return pull_socket, poller + + def connect_push_socket(self, addr): + push_socket = self.context.socket(zmq.PUSH) + push_socket.connect(self.merge_address) + #push_socket.setsockopt(zmq.LINGER,0) + self.sockets.append(push_socket) + return push_socket + + def bind_pub_socket(self, addr): + pub_socket = self.context.socket(zmq.PUB) + pub_socket.bind(self.pub_address) + #pub_socket.setsockopt(zmq.LINGER,0) + poller = zmq.Poller() + poller.register(self.pub_socket, zmq.POLLIN) + self.sockets.append(pub_socket) + return pub_socket, poller + + def connect_sub_socket(self, addr): + sub_socket = self.context.socket(zmq.SUB) + sub_socket.connect(self.feed_address) + sub_socket.setsockopt(zmq.SUBSCRIBE,'') + self.sockets.append(sub_socket) + return sub_socket + + def bind_sync(self): + sync_socket = self.context.socket(zmq.REP) + sync_socket.bind(self.addresses['sync_address']) + poller = zmq.Poller() + poller.register(self.sync_socket, zmq.POLLIN) + self.sockets.append(sync_socket) + return sync_socket, poller + + def connect_sync(self): + self.sync_socket = self.context.socket(zmq.REQ) + self.sync_socket.connect(self.addresses['sync_address']) + self.sync_socket.setsockopt(zmq.LINGER,0) + self.poller = zmq.Poller() + self.poller.register(self.sync_socket, zmq.POLLIN) + self.sockets.append(sync_socket) + +class ComponentHost(Component): + """Component that can launch multiple sub-components, synchronize their start, and then wait for all + components to be finished.""" + def __init__(self, addresses): + Component.__init__(self, addresses) + self.components = {} + self.timeout = datetime.timedelta(seconds=5) + + def register_components(self, component_list): + for component in component_list: + self.components[component.get_id()] = component + + def unregister_component(self, component_id): + del(self.components[component_id]) + + def open(self): + self.sync_socket, self.poller = self.bind_sync() + for component in self.components.values(): + self.launch_component(component) + + def is_timed_out(self): + cur_time = datetime.datetime.utcnow() + if(len(self.sync_register) == 0): + qutil.LOGGER.info("Component register is empty.") + return True + for source, last_dt in self.sync_register.iteritems(): + if((cur_time - last_dt) > self.timeout): + qutil.LOGGER.info("Time out for {source}. Current registery: {reg}".format(source=source, reg=self.sync_register)) + return True + return False + + def run(self): + while not self.is_timed_out(): + # wait for synchronization request + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + + if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: + try: + msg = self.sync_socket.recv() + parts = msg.split(':') + sync_id = parts[0] + status = parts[1] + if(status == "DONE"): + self.unregister_component(sync_id) + else: + self.sync_register[sync_id] = datetime.datetime.utcnow() + #qutil.LOGGER.info("confirmed {id}".format(id=msg)) + # send synchronization reply + self.sync_socket.send('ack', zmq.NOBLOCK) + except: + qutil.LOGGER.exception("Exception in sync components loop") + + def luanch_component(self, component): + raise NotImplemented + +class ParallelBuffer(Component): + """Connects to N PULL sockets, publishing all messages received to a PUB socket. + Published messages are guaranteed to be in chronological order based on message property dt. + Expects to be instantiated in one execution context (thread, process, etc) and run in another.""" + + def __init__(self): + self.sent_count = 0 + self.received_count = 0 + self.draining = False + self.data_buffer = {} + self.ds_finished_counter = 0 + + + def get_id(self): + return "FEED" + + def add_source(self, source_id): + self.data_buffer[source_id] = [] + + def open(self): + self.pull_socket, self.poller = self.bind_data() + self.feed_socket = self.bind_feed() + + def do_work(self): + # wait for synchronization reply from the host + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + + if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN: + message = self.pull_socket.recv() + + event = json.loads(message) + if(event["type"] == "DONE"): + ds_finished_counter += 1 + if(len(self.data_buffer) == ds_finished_counter): + #drain any remaining messages in the buffer + self.drain() + #send the DONE message + self.feed_socket.send("DONE", zmq.NOBLOCK) + qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.received_count, + m=self.sent_count)) + else: + self.append(event[u's'], event) + self.send_next() + def __len__(self): """buffer's length is same as internal map holding separate sorted arrays of events keyed by source id""" return len(self.data_buffer) @@ -72,7 +294,7 @@ class ParallelBuffer(object): event = self.next() if(event != None): - self.out_socket.send(json.dumps(event), zmq.NOBLOCK) + self.feed_socket.send(json.dumps(event), zmq.NOBLOCK) self.sent_count += 1 @@ -81,10 +303,8 @@ class MergedParallelBuffer(ParallelBuffer): Merges multiple streams of events into single messages. """ - def __init__(self, keys): - ParallelBuffer.__init__(self, keys) - self.feed = [] - self.data_buffer["feed"] = self.feed + def __init__(self): + ParallelBuffer.__init__(self) def next(self): """Get the next merged message from the feed buffer.""" @@ -100,54 +320,66 @@ class MergedParallelBuffer(ParallelBuffer): result[source] = cur['value'] return result - -class Sync(object): - """Sync instances register themselves with a Host. Once the Sync - is created, the Host is guaranteed to block until confirm is called on this - instance (and all others registered with the host). Components can use instances - to delay the start of the host until initial setup is complete.""" - - def __init__(self, host, name): - self.host = host - self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1()) - self.context = None - self.sync_socket = None - self.poller = None - self.host.register_sync(self.sync_id) - - #qutil.LOGGER.info("registered {id} with host".format(id=self.sync_id)) - - def open(self): - self.context = zmq.Context() - #synchronize with host - self.sync_socket = self.context.socket(zmq.REQ) - self.sync_socket.connect(self.host.sync_address) - self.sync_socket.setsockopt(zmq.LINGER,0) - self.poller = zmq.Poller() - self.poller.register(self.sync_socket, zmq.POLLIN) - - def confirm(self): - """Confirm readiness with the Host.""" - try: - # send a synchronization request to the host - self.sync_socket.send(self.sync_id + ":RUNNING") - # wait for synchronization reply from the host - socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. - if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN: - message = self.sync_socket.recv() +class BaseTransform(Component): + """Top level execution entry point for the transform:: + + - connects to the feed socket to subscribe to events + - connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms + - processes all messages received from feed, until DONE message received + - pushes all transforms + - sends DONE to result socket, closes all sockets and context - return True - except: - qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id)) - return False + Parent class for feed transforms. Subclass and override transform + method to create a new derived value from the combined feed.""" + + def __init__(self, name): + self.state = {} + self.state['name'] = name + + def get_id(self): + return self.state['name'] + + def open(self): + """ + Establishes zmq connections. + """ + #create the feed. + self.feed_socket, self.poller = self.connect_feed() + + #create the result PUSH + self.result_socket = self.connect_merge() - def close(self): - try: - self.sync_socket.send(self.sync_id + ":DONE") - self.sync_socket.close() - except: - qutil.LOGGER.exception("Error closing Sync object") - finally: - self.context.destroy() + + def do_work(self): + """ + Loops until feed's DONE message is received: + - receive an event from the data feed + - call transform (subclass' method) on event + - send the transformed event + """ + qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name'])) + socks = dict(self.poller.poll(2000)) #timeout after 2 seconds. + if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN: + message = self.feed_socket.recv() + if(message == "DONE"): + qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name'])) + self.result_socket.send("DONE", zmq.NOBLOCK) + return + self.received_count += 1 + event = json.loads(message) + cur_state = self.transform(event) + cur_state['dt'] = event['dt'] + cur_state['name'] = self.state['name'] + self.result_socket.send(json.dumps(cur_state), zmq.NOBLOCK) + self.sent_count += 1 + + def transform(self, event): + """ Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"} + Transforms run in parallel and results are merged into a single map, so transform names must be unique. + Best practice is to use the self.state object initialized from the transform configuration, and only set the + transformed value: + self.state['value'] = transformed_value + """ + return {'value'=event} \ No newline at end of file diff --git a/zipline/sources.py b/zipline/sources.py index 145682d0..5666db00 100644 --- a/zipline/sources.py +++ b/zipline/sources.py @@ -7,26 +7,34 @@ import json import random import zipline.util as qutil +import zipline.messaging as qmsg -class DataSource(object): +class DataSource(qmsg.Component): """ Baseclass for data sources. Subclass and implement send_all - usually this means looping through all records in a store, converting to a dict, and calling send(map). """ def __init__(self, source_id): - self.source_id = source_id + self.id = source_id + self.host = host self.data_address = None self.sync = None self.cur_event = None self.context = None self.data_socket = None + + def get_id(self): + return self.id + def set_addresses(self, addresses): + self.data_address = addresses['data_address'] + def open(self): """create zmq context and socket""" qutil.LOGGER.info( - "starting data source:{source_id} on {addr}" - .format(source_id=self.source_id, addr=self.data_address)) + "starting data source:{id} on {addr}" + .format(id=self.id, addr=self.data_address)) self.context = zmq.Context() @@ -54,13 +62,16 @@ class DataSource(object): def send(self, event): """ event is expected to be a dict - sets source_id and type properties in the dict + sets id and type properties in the dict sends to the data_socket. """ - event['s'] = self.source_id - event['type'] = 'event' + event['id'] = self.id + event['type'] = self.get_type() self.data_socket.send(json.dumps(event), zmq.NOBLOCK) - + + def get_type(self): + raise NotImplemented + def close(self): """ Close the zmq context and sockets. @@ -69,7 +80,7 @@ class DataSource(object): try: done_msg = {} done_msg['type'] = 'DONE' - done_msg['s'] = self.source_id + done_msg['s'] = self.id self.data_socket.send(json.dumps(done_msg), zmq.NOBLOCK) qutil.LOGGER.info("closing data socket") @@ -90,7 +101,10 @@ class RandomEquityTrades(DataSource): DataSource.__init__(self, source_id) self.count = count self.sid = sid - + + def get_type(self): + return 'equity_trade' + def send_all(self): trade_start = datetime.datetime.now() minute = datetime.timedelta(minutes=1) diff --git a/zipline/test/test_messaging.py b/zipline/test/test_messaging.py index a9565ce2..f0450655 100644 --- a/zipline/test/test_messaging.py +++ b/zipline/test/test_messaging.py @@ -22,19 +22,23 @@ class MessagingTestCase(unittest.TestCase): def setUp(self): """generate some config objects for the datafeed, sources, and transforms.""" - pass + self.addresses = {'sync_address' : "tcp://127.0.0.1:{port}".format(port=10100), + 'data_address' : "tcp://127.0.0.1:{port}".format(port=10101), + 'feed_address' : "tcp://127.0.0.1:{port}".format(port=10102), + 'merge_address' : "tcp://127.0.0.1:{port}".format(port=10103), + 'result_address' : "tcp://127.0.0.1:{port}".format(port=10104) + } - def get_simulator(self, sources, transforms, client, feed=None, merge=None): - return ProcessSimulator(sources, transforms, client, feed=feed, merge=merge) + def get_simulator(self): + return ProcessSimulator() def dtest_sources_only(self): """streams events from two data sources, no transforms.""" - + sim = self.get_simulator() ret1 = RandomEquityTrades(133, "ret1", 400) ret2 = RandomEquityTrades(134, "ret2", 400) - sources = {"ret1":ret1, "ret2":ret2} client = TestClient(self, expected_msg_count=800) - sim = self.get_simulator(sources, {}, client) + sim.register_components([ret1, ret2, client]) sim.simulate() self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, @@ -47,21 +51,18 @@ class MessagingTestCase(unittest.TestCase): 2 datasources -> feed -> 2 moving average transforms -> transform merge -> testclient verify message count at client. """ - + sim = self.get_simulator() ret1 = RandomEquityTrades(133, "ret1", 5000) ret2 = RandomEquityTrades(134, "ret2", 5000) - sources = {"ret1":ret1, "ret2":ret2} mavg1 = MovingAverage("mavg1", 30) mavg2 = MovingAverage("mavg2", 60) - transforms = {"mavg1":mavg1, "mavg2":mavg2} client = TestClient(self, expected_msg_count=10000) - sim = self.get_simulator(sources, transforms, client) + sim.register_components[ret1, ret2, mavg1, mavg2, client] sim.simulate() - self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, "The feed should be drained of all messages.") - def test_error_in_feed(self): + def dtest_error_in_feed(self): ret1 = RandomEquityTrades(133, "ret1", 400) ret2 = RandomEquityTrades(134, "ret2", 400) sources = {"ret1":ret1, "ret2":ret2}