diff --git a/zipline/gens/mongods.py b/zipline/gens/mongods.py index 417a1f70..303fd465 100644 --- a/zipline/gens/mongods.py +++ b/zipline/gens/mongods.py @@ -4,18 +4,20 @@ Generator-style DataSource that loads from MongoDB. import pytz import logbook +import pymongo from pymongo import ASCENDING +from datetime import datetime, timedelta from zipline import ndict -from zipline.gens.utils import stringify_args +from zipline.gens.utils import stringify_args, assert_datasource_protocol import zipline.protocol as zp -def create_pymongo_iterator(self, collection, filter, start_date, end_date): +def create_pymongo_iterator(collection, filter, start_date, end_date): """ See the comments on :py:class:`zipline.messaging.DataSource` for - expected content of self.filter. Spec must adhere to that definition. + expected content of filter. Spec must adhere to that definition. Returns an iterator that spits out raw objects loaded from MongoDB. """ log = logbook.Logger("MongoDBQuery") @@ -28,12 +30,12 @@ def create_pymongo_iterator(self, collection, filter, start_date, end_date): # Add the list of sids that we care about. if name == 'sid': - assert isinstance(value, sid) + assert isinstance(value, list) sid_range = {'sid':{'$in':value}} spec.update(sid_range) # limit the data to the date range [start, end], inclusive - date_range = {'dt':{'$gte':self.start, '$lte':self.end}} + date_range = {'dt':{'$gte': start_date, '$lte': end_date}} spec.update(date_range) fields = ['sid','price','volume','dt'] @@ -41,8 +43,8 @@ def create_pymongo_iterator(self, collection, filter, start_date, end_date): # In our collection, load all objects matching spec. Of those # objects, get only the fields matching fields, and return the # loaded objects sorted by dt from least to greatest. - - cursor = self.collection.find( + + cursor = collection.find( fields = fields, spec = spec, sort = [("dt",ASCENDING)], @@ -53,17 +55,46 @@ def create_pymongo_iterator(self, collection, filter, start_date, end_date): cursor = cursor.hint([('dt', ASCENDING),('sid', ASCENDING)]) # Set up the iterator - iterator = iter(self.cursor) + iterator = iter(cursor) log.info("MongoDataSource iterator ready") + return iterator def MongoTradeHistoryGen(collection, filter, start_date, end_date): + """A generator that takes a pymongo Collection object, a list of + filters, a start date and an end_date and yields ndicts containing + the results of a query to its collection with the given filter, + start, and end. The output is also packaged with a unique + source_id string for downstream sorting + """ + assert isinstance(collection, pymongo.collection.Collection) + assert isinstance(filter, dict) + assert isinstance(start_date, (datetime)) + assert isinstance(end_date, (datetime)) + + # Set up internal iterator. This outputs raw dictionaries. iterator = create_pymongo_iterator(collection, filter, start_date, end_date) - source_id = "MongoTradeHistoryGen" + stringify_args(collection, filter, start_date, end_date) + # Create unique identifier string that can be used to break + # sorting ties deterministically + argstring = stringify_args(collection, filter, start_date, end_date) + source_id = "MongoTradeHistoryGen" + argstring + + # All datasources for event in iterator: + # Construct a new event that fulfills the datasource protocol. + event['type'] = zp.DATASOURCE_TYPE.TRADE + + event['dt'] = event['dt'].replace(tzinfo=pytz.utc) + event['source_id'] = source_id + payload = ndict(event) + assert_datasource_protocol(payload) + yield payload + + + diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 0c2f468d..27454fb7 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -1,5 +1,40 @@ +import pytz +import numbers + +from hashlib import md5 +from datetime import datetime + +from zipline import ndict +from zipline.protocol import DATASOURCE_TYPE + + def stringify_args(*args, **kwargs): - """Define a unique string for any set of args.""" + """Define a unique string for any set of representable args.""" arg_string = '_'.join([str(arg) for arg in args]) - kwarg_string = '_'.join([str(key) + '=' + str(value) for key, value in kwargs]) + kwarg_string = '_'.join([str(key) + '=' + str(value) for key, value in kwargs.iteritems()]) combined = ':'.join([arg_string, kwarg_string]) + + hasher = md5() + hasher.update(combined) + return hasher.hexdigest() + + +def assert_datasource_protocol(event): + """Assert that an event meets the protocol for datasource outputs.""" + + assert isinstance(event, ndict) + assert isinstance(event.source_id, basestring) + assert isinstance(event.dt, datetime) + assert event.dt.tzinfo == pytz.utc + assert event.type in DATASOURCE_TYPE + +def assert_trade_protocol(event): + """Assert that an event meets the protocol for datasource TRADE outputs.""" + assert_datasource_protocol(event) + + assert isinstance(event, ndict) + assert event.type == DATASOURCE_TYPE.TRADE + assert isinstance(event.sid, int) + assert isinstance(event.price, numbers.Real) + assert isinstance(event.volume, numbers.Integral) + diff --git a/zipline/protocol.py b/zipline/protocol.py index 4822b658..98245981 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -231,7 +231,6 @@ def DATASOURCE_FRAME(event): - *payload* a msgpack string carrying the payload for the frame """ - assert isinstance(event.source_id, basestring) assert isinstance(event.type, int), 'Unexpected type %s' % (event.type)