Files
catalyst/backtest/util.py
T
2012-01-19 15:59:57 -05:00

162 lines
4.9 KiB
Python

"""
Small classes to assist with db access, timezone calculations, and so on.
"""
import datetime
import pytz
import json
import logging
import uuid
import zmq
class DocWrap(object):
    """
    Provides attribute access style on top of dictionary results from pymongo.
    Allows you to access result['field'] as result.field.
    Aliases result['_id'] to result.id.
    """

    def __init__(self, store=None):
        """Wrap a shallow copy of ``store`` (or an empty dict when omitted).

        The caller's mapping is never modified; in the copy, an ``'_id'``
        key is renamed to ``'id'``.
        """
        if store is None:
            self.store = {}
        else:
            self.store = store.copy()
        if '_id' in self.store:
            self.store['id'] = self.store.pop('_id')

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; ``'_id'`` is stored as ``'id'``."""
        if key == '_id':
            key = 'id'
        self.store[key] = value

    def __getitem__(self, key):
        """Return the value stored under ``key``, or None when it is absent.

        Unlike a plain dict this does not raise KeyError for missing keys.
        """
        if key in self.store:
            return self.store[key]
        return None

    def __getattr__(self, attrname):
        """Resolve ``doc.field`` from the wrapped dictionary.

        Python only calls this after normal attribute lookup has failed.

        Raises:
            AttributeError: if the wrapped dictionary has no such key.
        """
        # Go through __dict__ rather than ``self.store``: on an instance that
        # has no ``store`` yet (created via __new__, or being rebuilt by
        # copy/pickle) ``self.store`` would re-enter __getattr__ and recurse
        # until the interpreter's recursion limit is hit.
        store = self.__dict__.get('store')
        if store is not None and attrname in store:
            return store[attrname]
        raise AttributeError("No attribute named {name}".format(name=attrname))
def parse_date(dt_str):
    """Parse a string in the format generated by format_date.

    The expected layout is ``YYYY/MM/DD-HH:MM:SS.mmm`` where the part after
    the first ``.`` is an integer number of milliseconds (leading zeros are
    optional). A string with no fractional part, or an empty one, is read as
    zero milliseconds.

    Returns:
        A datetime whose tzinfo is set to ``pytz.utc``, or None when
        ``dt_str`` is None.

    Raises:
        ValueError: if the timestamp or the millisecond field is malformed.
    """
    if dt_str is None:
        return None
    parts = dt_str.split(".")
    # Tolerate a missing ".mmm" suffix instead of failing with IndexError.
    if len(parts) > 1 and parts[1]:
        millis = int(parts[1])
    else:
        millis = 0
    parsed = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S')
    # The value is labelled as UTC, not converted to it.
    return parsed.replace(microsecond=millis * 1000, tzinfo=pytz.utc)
def format_date(dt):
    """Format a datetime as ``YYYY/MM/DD-HH:MM:SS.mmm`` (millisecond resolution).

    The millisecond field is zero-padded to three digits so that sorting the
    resulting strings alphabetically is equivalent to sorting the datetimes;
    without the padding ".5" (5 ms) would sort after ".123" (123 ms).
    Sub-millisecond precision is truncated.

    No timezone conversion is performed: ``dt`` is formatted as-is.
    NOTE(review): parse_date labels its result as UTC, so callers presumably
    pass UTC datetimes here -- confirm.

    Returns:
        The formatted string, or None when ``dt`` is None.
    """
    if dt is None:
        return None
    # Floor division keeps the field an integer under true division as well.
    millis = dt.microsecond // 1000
    return dt.strftime('%Y/%m/%d-%H:%M:%S') + ".%03d" % millis
class ParallelBuffer(object):
    """Holds several queues of events by key, allows retrieval in date order or by merging.

    Each key passed to the constructor gets its own FIFO list in
    ``data_buffer``. Events are released only while every queue has at least
    one event, or unconditionally once ``draining`` is set.
    """

    def __init__(self, key_list):
        """Create one empty queue per key in ``key_list``."""
        # Must be assigned by the owner before send_next()/drain() emit
        # anything; only its ``send`` method is used here.
        self.out_socket = None
        self.sent_count = 0
        self.received_count = 0
        self.draining = False
        self.data_buffer = {}
        for key in key_list:
            self.data_buffer[key] = []

    def __len__(self):
        """Return the number of queues (keys), not the number of buffered events."""
        return len(self.data_buffer)

    def append(self, key, value):
        """Queue ``value`` under ``key`` and count it as received.

        Raises:
            KeyError: if ``key`` has no queue.
        """
        self.data_buffer[key].append(value)
        self.received_count += 1

    def next(self):
        """Pop and return the buffered event with the smallest ``'dt'``.

        Returns None when the buffer is neither full nor draining, or when
        nothing is buffered. Ties go to the queue visited last, because the
        comparison is ``<=``.
        """
        if not (self.is_full() or self.draining):
            return None
        earliest = None
        for events in self.data_buffer.values():
            if len(events) == 0:
                continue
            # Only the head of each queue is compared.
            if earliest is None or events[0]['dt'] <= earliest[0]['dt']:
                earliest = events
        if earliest is None:
            return None
        return earliest.pop(0)

    def is_full(self):
        """Return True when every queue holds at least one event."""
        return all(len(events) > 0 for events in self.data_buffer.values())

    def pending_messages(self):
        """Return the total number of events buffered across all queues."""
        return sum(len(events) for events in self.data_buffer.values())

    def drain(self):
        """Switch to draining mode and send until no events remain buffered."""
        self.draining = True
        while self.pending_messages() > 0:
            self.send_next()

    def send_next(self):
        """Send the next available event as JSON on ``out_socket``.

        Does nothing when next() yields no event; next() already applies the
        full-or-draining gate.
        """
        event = self.next()
        if event is not None:
            self.out_socket.send(json.dumps(event))
            self.sent_count += 1
class MergedParallelBuffer(ParallelBuffer):
    """ParallelBuffer variant that merges the other sources into each feed event.

    In addition to the caller's keys there is a ``"feed"`` queue. Each call
    to next() pops one feed event and copies the ``'value'`` of the head
    event of every other source into it, keyed by the source name.
    """

    def __init__(self, keys):
        """Create the per-key queues plus the extra ``"feed"`` queue."""
        ParallelBuffer.__init__(self, keys)
        self.feed = []
        self.data_buffer["feed"] = self.feed

    def next(self):
        """Pop the next feed event, merged with one value from each other source.

        Returns None when the buffer is neither full nor draining. While
        draining, sources that have run dry are simply left out of the
        merged event. If the feed itself has run dry, the values still
        buffered for the other sources have no feed event to attach to; they
        are discarded (with a logged warning) and None is returned, rather
        than raising IndexError from popping an empty feed.
        """
        if not (self.is_full() or self.draining):
            return None
        if len(self.feed) == 0:
            # Only reachable while draining: is_full() guarantees a feed event
            # otherwise. Emptying the leftovers also lets drain() terminate,
            # since it loops until nothing is pending.
            orphaned = 0
            for events in self.data_buffer.values():
                orphaned += len(events)
                del events[:]
            if orphaned:
                logging.getLogger().warning(
                    "discarded %d buffered value(s) with no feed event to merge into",
                    orphaned)
            return None
        result = self.feed.pop(0)
        for source, events in self.data_buffer.items():
            if source == "feed":
                continue
            if len(events) > 0:
                cur = events.pop(0)
                result[source] = cur['value']
        return result
class FeedSync(object):
    """Registers a named participant with a feed and confirms synchronization.

    The ``feed`` object must provide ``register_sync(id)`` and a
    ``sync_address`` attribute; both are used as-is.
    """

    def __init__(self, feed, name):
        """Build a unique id from ``name`` and register it with ``feed``."""
        self.feed = feed
        # uuid1 suffix keeps ids distinct when several participants share a name.
        self.id = "{name}-{id}".format(name=name, id=uuid.uuid1())
        self.feed.register_sync(self.id)
        self.logger = logging.getLogger()

    def confirm(self):
        """Send this participant's id to the feed and block until it replies.

        A private zmq context and REQ socket are created for the exchange
        and are always released, including when connect/send/recv raises.
        """
        context = zmq.Context()
        # synchronize with feed
        sync_socket = context.socket(zmq.REQ)
        try:
            sync_socket.connect(self.feed.sync_address)
            # send a synchronization request to the feed
            sync_socket.send(self.id)
            # wait for synchronization reply from the feed
            sync_socket.recv()
        except BaseException:
            # Drop any request that was never delivered; otherwise
            # context.term() below could block waiting to flush it.
            sync_socket.setsockopt(zmq.LINGER, 0)
            raise
        finally:
            sync_socket.close()
            context.term()
        self.logger.info("sync'd feed from {id}".format(id = self.id))