first draft of refactoring is done

This commit is contained in:
fawce
2012-02-15 22:35:00 -05:00
parent 4253844432
commit fc8b942e45
4 changed files with 342 additions and 539 deletions
+9 -453
View File
@@ -12,141 +12,26 @@ import atexit
import zipline.util as qutil
import zipline.messaging as qmsg
class SimulatorBase(object):
class SimulatorBase(ComponentHost):
"""
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
"""
def __init__(self, sources, transforms, client, feed=None, merge=None):
def __init__(self, addresses):
"""
"""
self.sources = sources
self.transforms = transforms
self.client = client
self.merge = None
self.feed = None
self.context = None
self.sync_context = None
self.sync_socket = None
self.sync_register = {}
self.sync_address = "tcp://127.0.0.1:{port}".format(port=10100)
self.data_address = "tcp://127.0.0.1:{port}".format(port=10101)
self.feed_address = "tcp://127.0.0.1:{port}".format(port=10102)
self.merge_address = "tcp://127.0.0.1:{port}".format(port=10103)
self.result_address = "tcp://127.0.0.1:{port}".format(port=10104)
self.performance_address = "tcp://127.0.0.1:{port}".format(port=10105)
self.timeout = datetime.timedelta(seconds=5)
"""
self.feed = ParallelBuffer()
self.merge = MergedParallelBuffer()
#workaround for defect in threaded use of strptime: http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
if(feed == None):
self.feed = DataFeed(self.sources.keys(),
self.data_address,
self.feed_address,
self.performance_address,
qmsg.Sync(self,"DataFeed"))
else:
self.feed = feed
if(merge == None):
#connect merge to feed, set expected transforms
self.merge = TransformsMerge(self.feed_address,
self.merge_address,
self.result_address,
qmsg.Sync(self,"TransformsMerge"),
self.transforms.keys())
else:
self.merge = merge
#register the feed and the merge
self.register_components([self.feed, self.merge])
def simulate(self):
#launch the feed
self.launch_component("DataFeed", self.feed)
self.launch_component(self)
#launch the data sources
for name, data_source in self.sources.iteritems():
data_source.data_address = self.data_address
data_source.sync = qmsg.Sync(self, str(data_source.source_id))
self.launch_component(name, data_source)
qutil.LOGGER.info("datasources processes launched")
#connect all the transforms to the feed and merge, launch each
for name, transform in self.transforms.iteritems():
transform.feed_address = self.feed_address #connect transform to receive feed.
transform.merge_address = self.merge_address #connect transform to push results to merge
transform.sync = qmsg.Sync(self, name) #synchronize the transform against this simulation.
self.launch_component(name, transform) #start transforms
#launch merge
self.launch_component("transforms merge", self.merge)
qutil.LOGGER.info("transform processes launched")
#connect client to merged feed
self.client.address = self.result_address
self.client.sync = qmsg.Sync(self,"Client")
client_proc = self.launch_component("client", self.client)
qutil.LOGGER.info("client process launched")
qutil.LOGGER.info("sync register starting with {count} members: {reg}".format(count=len(self.sync_register), reg=self.sync_register))
self.sync_components()
#client_proc.join() #wait for client to complete processing
def launch_component(self, name, component):
raise NotImplementedError
def register_sync(self, sync_id):
self.sync_register[sync_id] = datetime.datetime.utcnow()
def unregister_sync(self, sync_id):
qutil.LOGGER.info("unregistering {sync_id}".format(sync_id=sync_id))
del(self.sync_register[sync_id])
def is_timed_out(self):
cur_time = datetime.datetime.utcnow()
if(len(self.sync_register) == 0):
qutil.LOGGER.info("**********Simulator sync register is empty.")
return True
for source, last_dt in self.sync_register.iteritems():
if((cur_time - last_dt) > self.timeout):
qutil.LOGGER.info("Time out for {source}. Current registery: {reg}".format(source=source, reg=self.sync_register))
return True
return False
def sync_components(self):
# Socket to receive signals
self.context = zmq.Context()
qutil.LOGGER.info("waiting for all datasources and clients to be ready")
self.sync_socket = self.context.socket(zmq.REP)
self.sync_socket.bind(self.sync_address)
#self.sync_socket.setsockopt(zmq.LINGER,0)
self.poller = zmq.Poller()
self.poller.register(self.sync_socket, zmq.POLLIN)
while not self.is_timed_out():
# wait for synchronization request
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
try:
msg = self.sync_socket.recv()
parts = msg.split(':')
sync_id = parts[0]
status = parts[1]
if(status == "DONE"):
self.unregister_sync(sync_id)
else:
self.sync_register[sync_id] = datetime.datetime.utcnow()
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
# send synchronization reply
self.sync_socket.send('ack', zmq.NOBLOCK)
except:
qutil.LOGGER.exception("Exception in sync components loop")
self.sync_socket.close()
qutil.LOGGER.info("simulator heartbeat stopped.")
class ThreadSimulator(SimulatorBase):
def __init__(self, sources, transforms, client, feed=None, merge=None):
@@ -168,333 +53,4 @@ class ProcessSimulator(SimulatorBase):
proc = multiprocessing.Process(target=component.run)
proc.start()
return proc
class DataFeed(object):
def __init__(self, source_list, data_address, feed_address, performance_address, sync):
"""
:source_list: list of data source IDs
"""
self.feed_address = feed_address
self.performance_address = performance_address
self.data_address = data_address
self.data_buffer = qmsg.ParallelBuffer(source_list)
self.sync = sync
self.feed_socket = None
self.data_socket = None
self.perf_socket = None
self.context = None
self.poller = None
def open(self):
# Prepare our context and sockets
self.context = zmq.Context()
#create the data sink. Based on http://zguide.zeromq.org/py:tasksink2
#see: http://zguide.zeromq.org/py:taskwork2
self.data_socket = self.context.socket(zmq.PULL)
self.data_socket.bind(self.data_address)
#create the feed
self.feed_socket = self.context.socket(zmq.PUB)
self.feed_socket.bind(self.feed_address)
#self.feed_socket.setsockopt(zmq.LINGER,0)
self.data_buffer.out_socket = self.feed_socket
self.poller = zmq.Poller()
self.poller.register(self.data_socket, zmq.POLLIN)
#create the performance results push
self.perf_socket = self.context.socket(zmq.PUSH)
self.perf_socket.bind(self.performance_address)
#self.perf_socket.setsockopt(zmq.LINGER,0)
self.sync.open()
def close(self):
try:
self.data_socket.close()
self.feed_socket.close()
self.perf_socket.close()
self.sync.close()
except:
qutil.LOGGER.exception("Error closing DataFeed")
finally:
self.context.destroy()
def handle_all(self):
qutil.LOGGER.info("entering feed loop on {addr}".format(addr=self.data_address))
ds_finished_counter = 0
while self.sync.confirm():
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.data_socket in socks and socks[self.data_socket] == zmq.POLLIN:
message = self.data_socket.recv()
event = json.loads(message)
if(event["type"] == "DONE"):
ds_finished_counter += 1
if(len(self.data_buffer) == ds_finished_counter):
break
else:
self.data_buffer.append(event[u's'], event)
self.data_buffer.send_next()
#self.perf_socket.send(message, zmq.NOBLOCK)
#drain any remaining messages in the buffer
self.data_buffer.drain()
#send the DONE message
self.feed_socket.send("DONE", zmq.NOBLOCK)
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.data_buffer.received_count,
m=self.data_buffer.sent_count))
def run(self):
try:
self.open()
self.handle_all()
except:
qutil.LOGGER.exception("Exception in Feed, attempting to close.")
finally:
self.close()
class BaseTransform(object):
"""Parent class for feed transforms. Subclass and override transform
method to create a new derived value from the combined feed."""
def __init__(self, name):
"""
"""
self.feed_address = None
self.merge_address = None
self.state = {}
self.state['name'] = name
self.sync = None
self.received_count = 0
self.sent_count = 0
self.context = None
self.feed_socket = None
self.result_socket = None
def run(self):
"""Top level execution entry point for the transform::
- connects to the feed socket to subscribe to events
- connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms
- processes all messages received from feed, until DONE message received
- pushes all transforms
- sends DONE to result socket, closes all sockets and context"""
try:
self.open()
self.process_all()
except:
qutil.LOGGER.exception("Exception during transform processing, attempting to close merge.")
finally:
self.close()
def open(self):
"""
Establishes zmq connections.
"""
self.context = zmq.Context()
qutil.LOGGER.info("starting {name} transform".
format(name = self.state['name']))
#create the feed SUB.
self.feed_socket = self.context.socket(zmq.SUB)
self.feed_socket.connect(self.feed_address)
self.feed_socket.setsockopt(zmq.SUBSCRIBE,'')
self.poller = zmq.Poller()
self.poller.register(self.feed_socket, zmq.POLLIN)
#create the result PUSH
self.result_socket = self.context.socket(zmq.PUSH)
self.result_socket.connect(self.merge_address)
#self.result_socket.setsockopt(zmq.LINGER,0)
self.sync.open()
def process_all(self):
"""
Loops until feed's DONE message is received:
- receive an event from the data feed
- call transform (subclass' method) on event
- send the transformed event
"""
qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name']))
while self.sync.confirm():
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
message = self.feed_socket.recv()
if(message == "DONE"):
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
self.result_socket.send("DONE", zmq.NOBLOCK)
break
self.received_count += 1
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['name'] = self.state['name']
self.result_socket.send(json.dumps(cur_state), zmq.NOBLOCK)
self.sent_count += 1
def close(self):
"""
Shut down zmq resources.
"""
qutil.LOGGER.info("Transform {name} recieved {r} and sent {s}".format(
name=self.state['name'],
r=self.received_count,
s=self.sent_count))
try:
self.feed_socket.close()
self.result_socket.close()
self.sync.close()
except:
qutil.LOGGER.exception("Error closing Transforms")
finally:
self.context.destroy()
def transform(self, event):
""" Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
transformed value:
self.state['value'] = transformed_value
"""
return {}
class TransformsMerge(object):
""" Merge data feed and array of transform feeds into a single result vector.
PULL from feed
PULL from child transforms
PUSH merged message to client
"""
def __init__(self, feed_address, transform_address, result_address, sync, transform_list):
"""
"""
self.sync = sync
self.feed_address = feed_address
self.transform_address = transform_address
self.result_address = result_address
buffer_list = copy.copy(transform_list)
buffer_list.append("feed") #for the raw feed
self.data_buffer = qmsg.MergedParallelBuffer(buffer_list)
self.feed_socket = None
self.result_socket = None
self.poller = None
self.context = None
self.transform_socket = None
def run(self):
""""""
try:
self.open()
self.process_all()
except:
qutil.LOGGER.exception("Exception during merge processing, attempting to close merge.")
finally:
self.close()
def open(self):
"""Establish zmq context, feed socket, result socket for client, and transform
socket to receive transformed events. Create and launch transforms. Will confirm
ready with the DataFeed at the conclusion."""
self.context = zmq.Context()
qutil.LOGGER.info("starting transforms merge")
#create the feed SUB.
self.feed_socket = self.context.socket(zmq.SUB)
self.feed_socket.connect(self.feed_address)
self.feed_socket.setsockopt(zmq.SUBSCRIBE,'')
#create the result PUSH
self.result_socket = self.context.socket(zmq.PUSH)
self.result_socket.bind(self.result_address)
#self.result_socket.setsockopt(zmq.LINGER,0)
#create the transform PULL.
self.transform_socket = self.context.socket(zmq.PULL)
self.transform_socket.bind(self.transform_address)
self.data_buffer.out_socket = self.result_socket
# Initialize poll set
self.poller = zmq.Poller()
self.poller.register(self.feed_socket, zmq.POLLIN)
self.poller.register(self.transform_socket, zmq.POLLIN)
self.sync.open()
def close(self):
"""
Close all zmq sockets and context.
"""
try:
self.sync.close()
self.transform_socket.close()
self.feed_socket.close()
self.result_socket.close()
except:
qutil.LOGGER.exception("Error closing merge")
finally:
self.context.destroy()
def process_all(self):
"""
Uses a Poller to receive messages from all transforms and the feed.
All transforms corresponding to the same event are merged with each other
and the original feed event into a single message. That message is then
sent to the result socket.
"""
done_count = 0
while self.sync.confirm():
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
message = self.feed_socket.recv()
if(message == "DONE"):
qutil.LOGGER.info("finished receiving feed to merge")
done_count += 1
else:
event = json.loads(message)
self.data_buffer.append("feed", event)
if self.transform_socket in socks and socks[self.transform_socket] == zmq.POLLIN:
t_message = self.transform_socket.recv()
if(t_message == "DONE"):
qutil.LOGGER.info("finished receiving a transform to merge")
done_count += 1
else:
t_event = json.loads(t_message)
self.data_buffer.append(t_event['name'], t_event)
if(done_count >= len(self.data_buffer)):
break #done!
self.data_buffer.send_next()
qutil.LOGGER.info("about to drain {n} messages from merger's buffer".format(n=self.data_buffer.pending_messages()))
#drain any remaining messages in the buffer
self.data_buffer.drain()
#signal to client that we're done
self.result_socket.send("DONE", zmq.NOBLOCK)
+296 -64
View File
@@ -4,21 +4,243 @@ Commonly used messaging components.
import json
import uuid
from gevent_zeromq import zmq
import zipline.util as qutil
class ParallelBuffer(object):
""" holds several queues of events by key, allows retrieval in date order
or by merging"""
def __init__(self, key_list):
self.out_socket = None
self.sent_count = 0
self.received_count = 0
self.draining = False
self.data_buffer = {}
for key in key_list:
self.data_buffer[key] = []
class Component(object):
def __init__(self, addresses):
"""
:addresses: a dict of name_string -> zmq port address strings. Must have the following entries::
- sync_address: socket address used for synchronizing the start of all workers, heartbeating, and exit notification
will be used in REP/REQ sockets. Bind is always on the REP side.
- data_address: socket address used for data sources to stream their records.
will be used in PUSH/PULL sockets between data sources and a ParallelBuffer (aka the Feed). Bind
will always be on the PULL side (we always have N producers and 1 consumer)
- feed_address: socket address used to publish consolidated feed from serialization of data sources
will be used in PUB/SUB sockets between Feed and Transforms. Bind is always on the PUB side.
- merge_address: socket address used to publish transformed values.
will be used in PUSH/PULL from many transforms to one MergedParallelBuffer (aka the Merge). Bind
will always be on the PULL side (we always have N producers and 1 consumer)
- result_address: socket address used to publish merged data source feed and transforms to clients
will be used in PUB/SUB from one Merge to one or many clients. Bind is always on the PUB side.
Bind/Connect methods will return the correct socket type for each address. Any sockets on which recv is expected to be called
will also return a Poller.
"""
self.context = zmq.Context()
self.addresses = addresses
self.sockets = []
def get_id(self):
raise NotImplemented
def open(self):
raise NotImplemented
def do_work(self):
raise NotImplemented
def run(self):
try:
self.open()
self.connect_sync()
while self.confirm():
self.do_work()
#notify host we're done
self.sync_socket.send(self.sync_id + ":DONE")
#close all the sockets
for sock in self.sockets:
sock.close()
finally:
self.context.destroy()
def confirm(self):
try:
# send a synchronization request to the host
self.sync_socket.send(self.sync_id + ":RUNNING")
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
message = self.sync_socket.recv()
return True
except:
qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id))
return False
def bind_data(self):
return self.bind_pull_socket(self, self.addresses['data_address'])
def connect_data(self):
return self.connect_push_socket(self, self.addresses['data_address'])
def bind_feed(self):
return self.bind_pub_socket(self, self.addresses['feed_address'])
def connect_feed(self):
return self.bind_sub_socket(self, self.addresses['feed_address'])
def bind_merge(self):
return self.bind_pull_socket(self, self.addresses['merge_address'])
def connect_merge(self):
return self.connect_push_socket(self, self.addresses['merge_address'])
def bind_result(self):
return self.bind_pub_socket(self, self.addresses['result_address'])
def connect_result(self):
return self.bind_sub_socket(self, self.addresses['result_address'])
def bind_pull_socket(self, addr):
pull_socket = self.context.socket(zmq.PULL)
pull_socket.bind(addr)
poller = zmq.Poller()
poller.register(self.pull_socket, zmq.POLLIN)
self.sockets.append(pull_socket)
return pull_socket, poller
def connect_push_socket(self, addr):
push_socket = self.context.socket(zmq.PUSH)
push_socket.connect(self.merge_address)
#push_socket.setsockopt(zmq.LINGER,0)
self.sockets.append(push_socket)
return push_socket
def bind_pub_socket(self, addr):
pub_socket = self.context.socket(zmq.PUB)
pub_socket.bind(self.pub_address)
#pub_socket.setsockopt(zmq.LINGER,0)
poller = zmq.Poller()
poller.register(self.pub_socket, zmq.POLLIN)
self.sockets.append(pub_socket)
return pub_socket, poller
def connect_sub_socket(self, addr):
sub_socket = self.context.socket(zmq.SUB)
sub_socket.connect(self.feed_address)
sub_socket.setsockopt(zmq.SUBSCRIBE,'')
self.sockets.append(sub_socket)
return sub_socket
def bind_sync(self):
sync_socket = self.context.socket(zmq.REP)
sync_socket.bind(self.addresses['sync_address'])
poller = zmq.Poller()
poller.register(self.sync_socket, zmq.POLLIN)
self.sockets.append(sync_socket)
return sync_socket, poller
def connect_sync(self):
self.sync_socket = self.context.socket(zmq.REQ)
self.sync_socket.connect(self.addresses['sync_address'])
self.sync_socket.setsockopt(zmq.LINGER,0)
self.poller = zmq.Poller()
self.poller.register(self.sync_socket, zmq.POLLIN)
self.sockets.append(sync_socket)
class ComponentHost(Component):
"""Component that can launch multiple sub-components, synchronize their start, and then wait for all
components to be finished."""
def __init__(self, addresses):
Component.__init__(self, addresses)
self.components = {}
self.timeout = datetime.timedelta(seconds=5)
def register_components(self, component_list):
for component in component_list:
self.components[component.get_id()] = component
def unregister_component(self, component_id):
del(self.components[component_id])
def open(self):
self.sync_socket, self.poller = self.bind_sync()
for component in self.components.values():
self.launch_component(component)
def is_timed_out(self):
cur_time = datetime.datetime.utcnow()
if(len(self.sync_register) == 0):
qutil.LOGGER.info("Component register is empty.")
return True
for source, last_dt in self.sync_register.iteritems():
if((cur_time - last_dt) > self.timeout):
qutil.LOGGER.info("Time out for {source}. Current registery: {reg}".format(source=source, reg=self.sync_register))
return True
return False
def run(self):
while not self.is_timed_out():
# wait for synchronization request
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
try:
msg = self.sync_socket.recv()
parts = msg.split(':')
sync_id = parts[0]
status = parts[1]
if(status == "DONE"):
self.unregister_component(sync_id)
else:
self.sync_register[sync_id] = datetime.datetime.utcnow()
#qutil.LOGGER.info("confirmed {id}".format(id=msg))
# send synchronization reply
self.sync_socket.send('ack', zmq.NOBLOCK)
except:
qutil.LOGGER.exception("Exception in sync components loop")
def luanch_component(self, component):
raise NotImplemented
class ParallelBuffer(Component):
"""Connects to N PULL sockets, publishing all messages received to a PUB socket.
Published messages are guaranteed to be in chronological order based on message property dt.
Expects to be instantiated in one execution context (thread, process, etc) and run in another."""
def __init__(self):
self.sent_count = 0
self.received_count = 0
self.draining = False
self.data_buffer = {}
self.ds_finished_counter = 0
def get_id(self):
return "FEED"
def add_source(self, source_id):
self.data_buffer[source_id] = []
def open(self):
self.pull_socket, self.poller = self.bind_data()
self.feed_socket = self.bind_feed()
def do_work(self):
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
message = self.pull_socket.recv()
event = json.loads(message)
if(event["type"] == "DONE"):
ds_finished_counter += 1
if(len(self.data_buffer) == ds_finished_counter):
#drain any remaining messages in the buffer
self.drain()
#send the DONE message
self.feed_socket.send("DONE", zmq.NOBLOCK)
qutil.LOGGER.info("received {n} messages, sent {m} messages".format(n=self.received_count,
m=self.sent_count))
else:
self.append(event[u's'], event)
self.send_next()
def __len__(self):
"""buffer's length is same as internal map holding separate sorted arrays of events keyed by source id"""
return len(self.data_buffer)
@@ -72,7 +294,7 @@ class ParallelBuffer(object):
event = self.next()
if(event != None):
self.out_socket.send(json.dumps(event), zmq.NOBLOCK)
self.feed_socket.send(json.dumps(event), zmq.NOBLOCK)
self.sent_count += 1
@@ -81,10 +303,8 @@ class MergedParallelBuffer(ParallelBuffer):
Merges multiple streams of events into single messages.
"""
def __init__(self, keys):
ParallelBuffer.__init__(self, keys)
self.feed = []
self.data_buffer["feed"] = self.feed
def __init__(self):
ParallelBuffer.__init__(self)
def next(self):
"""Get the next merged message from the feed buffer."""
@@ -100,54 +320,66 @@ class MergedParallelBuffer(ParallelBuffer):
result[source] = cur['value']
return result
class Sync(object):
"""Sync instances register themselves with a Host. Once the Sync
is created, the Host is guaranteed to block until confirm is called on this
instance (and all others registered with the host). Components can use instances
to delay the start of the host until initial setup is complete."""
def __init__(self, host, name):
self.host = host
self.sync_id = "{name}-{id}".format(name=name, id=uuid.uuid1())
self.context = None
self.sync_socket = None
self.poller = None
self.host.register_sync(self.sync_id)
#qutil.LOGGER.info("registered {id} with host".format(id=self.sync_id))
def open(self):
self.context = zmq.Context()
#synchronize with host
self.sync_socket = self.context.socket(zmq.REQ)
self.sync_socket.connect(self.host.sync_address)
self.sync_socket.setsockopt(zmq.LINGER,0)
self.poller = zmq.Poller()
self.poller.register(self.sync_socket, zmq.POLLIN)
def confirm(self):
"""Confirm readiness with the Host."""
try:
# send a synchronization request to the host
self.sync_socket.send(self.sync_id + ":RUNNING")
# wait for synchronization reply from the host
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.sync_socket in socks and socks[self.sync_socket] == zmq.POLLIN:
message = self.sync_socket.recv()
class BaseTransform(Component):
"""Top level execution entry point for the transform::
- connects to the feed socket to subscribe to events
- connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms
- processes all messages received from feed, until DONE message received
- pushes all transforms
- sends DONE to result socket, closes all sockets and context
return True
except:
qutil.LOGGER.exception("exception in confirmation for {source}. Exiting.".format(source=self.sync_id))
return False
Parent class for feed transforms. Subclass and override transform
method to create a new derived value from the combined feed."""
def __init__(self, name):
self.state = {}
self.state['name'] = name
def get_id(self):
return self.state['name']
def open(self):
"""
Establishes zmq connections.
"""
#create the feed.
self.feed_socket, self.poller = self.connect_feed()
#create the result PUSH
self.result_socket = self.connect_merge()
def close(self):
try:
self.sync_socket.send(self.sync_id + ":DONE")
self.sync_socket.close()
except:
qutil.LOGGER.exception("Error closing Sync object")
finally:
self.context.destroy()
def do_work(self):
"""
Loops until feed's DONE message is received:
- receive an event from the data feed
- call transform (subclass' method) on event
- send the transformed event
"""
qutil.LOGGER.info("starting {name} event loop".format(name = self.state['name']))
socks = dict(self.poller.poll(2000)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == zmq.POLLIN:
message = self.feed_socket.recv()
if(message == "DONE"):
qutil.LOGGER.info("{name} received the Done message from the feed".format(name=self.state['name']))
self.result_socket.send("DONE", zmq.NOBLOCK)
return
self.received_count += 1
event = json.loads(message)
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['name'] = self.state['name']
self.result_socket.send(json.dumps(cur_state), zmq.NOBLOCK)
self.sent_count += 1
def transform(self, event):
""" Must return the transformed value as a map with {name:"name of new transform", value: "value of new field"}
Transforms run in parallel and results are merged into a single map, so transform names must be unique.
Best practice is to use the self.state object initialized from the transform configuration, and only set the
transformed value:
self.state['value'] = transformed_value
"""
return {'value'=event}
+24 -10
View File
@@ -7,26 +7,34 @@ import json
import random
import zipline.util as qutil
import zipline.messaging as qmsg
class DataSource(object):
class DataSource(qmsg.Component):
"""
Baseclass for data sources. Subclass and implement send_all - usually this
means looping through all records in a store, converting to a dict, and
calling send(map).
"""
def __init__(self, source_id):
self.source_id = source_id
self.id = source_id
self.host = host
self.data_address = None
self.sync = None
self.cur_event = None
self.context = None
self.data_socket = None
def get_id(self):
return self.id
def set_addresses(self, addresses):
self.data_address = addresses['data_address']
def open(self):
"""create zmq context and socket"""
qutil.LOGGER.info(
"starting data source:{source_id} on {addr}"
.format(source_id=self.source_id, addr=self.data_address))
"starting data source:{id} on {addr}"
.format(id=self.id, addr=self.data_address))
self.context = zmq.Context()
@@ -54,13 +62,16 @@ class DataSource(object):
def send(self, event):
"""
event is expected to be a dict
sets source_id and type properties in the dict
sets id and type properties in the dict
sends to the data_socket.
"""
event['s'] = self.source_id
event['type'] = 'event'
event['id'] = self.id
event['type'] = self.get_type()
self.data_socket.send(json.dumps(event), zmq.NOBLOCK)
def get_type(self):
raise NotImplemented
def close(self):
"""
Close the zmq context and sockets.
@@ -69,7 +80,7 @@ class DataSource(object):
try:
done_msg = {}
done_msg['type'] = 'DONE'
done_msg['s'] = self.source_id
done_msg['s'] = self.id
self.data_socket.send(json.dumps(done_msg), zmq.NOBLOCK)
qutil.LOGGER.info("closing data socket")
@@ -90,7 +101,10 @@ class RandomEquityTrades(DataSource):
DataSource.__init__(self, source_id)
self.count = count
self.sid = sid
def get_type(self):
return 'equity_trade'
def send_all(self):
trade_start = datetime.datetime.now()
minute = datetime.timedelta(minutes=1)
+13 -12
View File
@@ -22,19 +22,23 @@ class MessagingTestCase(unittest.TestCase):
def setUp(self):
"""generate some config objects for the datafeed, sources, and transforms."""
pass
self.addresses = {'sync_address' : "tcp://127.0.0.1:{port}".format(port=10100),
'data_address' : "tcp://127.0.0.1:{port}".format(port=10101),
'feed_address' : "tcp://127.0.0.1:{port}".format(port=10102),
'merge_address' : "tcp://127.0.0.1:{port}".format(port=10103),
'result_address' : "tcp://127.0.0.1:{port}".format(port=10104)
}
def get_simulator(self, sources, transforms, client, feed=None, merge=None):
return ProcessSimulator(sources, transforms, client, feed=feed, merge=merge)
def get_simulator(self):
return ProcessSimulator()
def dtest_sources_only(self):
"""streams events from two data sources, no transforms."""
sim = self.get_simulator()
ret1 = RandomEquityTrades(133, "ret1", 400)
ret2 = RandomEquityTrades(134, "ret2", 400)
sources = {"ret1":ret1, "ret2":ret2}
client = TestClient(self, expected_msg_count=800)
sim = self.get_simulator(sources, {}, client)
sim.register_components([ret1, ret2, client])
sim.simulate()
self.assertEqual(sim.feed.data_buffer.pending_messages(), 0,
@@ -47,21 +51,18 @@ class MessagingTestCase(unittest.TestCase):
2 datasources -> feed -> 2 moving average transforms -> transform merge -> testclient
verify message count at client.
"""
sim = self.get_simulator()
ret1 = RandomEquityTrades(133, "ret1", 5000)
ret2 = RandomEquityTrades(134, "ret2", 5000)
sources = {"ret1":ret1, "ret2":ret2}
mavg1 = MovingAverage("mavg1", 30)
mavg2 = MovingAverage("mavg2", 60)
transforms = {"mavg1":mavg1, "mavg2":mavg2}
client = TestClient(self, expected_msg_count=10000)
sim = self.get_simulator(sources, transforms, client)
sim.register_components[ret1, ret2, mavg1, mavg2, client]
sim.simulate()
self.assertEqual(sim.feed.data_buffer.pending_messages(), 0, "The feed should be drained of all messages.")
def test_error_in_feed(self):
def dtest_error_in_feed(self):
ret1 = RandomEquityTrades(133, "ret1", 400)
ret2 = RandomEquityTrades(134, "ret2", 400)
sources = {"ret1":ret1, "ret2":ret2}