diff --git a/tests/test_feed.py b/tests/test_feed.py index bfff6742..21e8afb3 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -1,30 +1,26 @@ -import os - -import uuid -import msgpack -import pytz - from unittest2 import TestCase -from pymongo import Connection, ASCENDING -from itertools import izip, izip_longest, permutations, cycle, chain +from itertools import cycle, chain from datetime import datetime, timedelta from collections import deque from zipline import ndict -from zipline.gens.sort import date_sort, ready, done, queue_is_ready,queue_is_done,\ - pop_oldest -from zipline.gens.utils import hash_args, assert_datasource_protocol,\ - assert_trade_protocol, alternate +from zipline.gens.sort import \ + date_sort, \ + ready, \ + done, \ + queue_is_ready,\ + queue_is_done +from zipline.gens.utils import hash_args, alternate from zipline.gens.tradegens import date_gen, SpecificEquityTrades from zipline.gens.composites import date_sorted_sources import zipline.protocol as zp class HelperTestCase(TestCase): - + def setUp(self): pass - + def tearDown(self): pass @@ -33,12 +29,12 @@ class HelperTestCase(TestCase): # Empty queues are neither done nor ready. assert not queue_is_ready(queue) assert not queue_is_done(queue) - + queue.append(to_dt('foo')) assert queue_is_ready(queue) assert not queue_is_done(queue) - + queue.appendleft(to_dt('DONE')) assert queue_is_ready(queue) @@ -48,13 +44,13 @@ class HelperTestCase(TestCase): queue.pop() assert queue_is_ready(queue) assert queue_is_done(queue) - + def test_pop_logic(self): sources = {} ids = ['a', 'b', 'c'] for id in ids: sources[id] = deque() - + assert not ready(sources) assert not done(sources) @@ -66,13 +62,13 @@ class HelperTestCase(TestCase): for id in ids: sources[id].append(to_dt("datetime")) - + assert ready(sources) assert not done(sources) for id in ids: sources[id].appendleft(to_dt("DONE")) - + # ["DONE", message] will trip an assert in queue_is_done. assert ready(sources) self.assertRaises(AssertionError, done, sources) @@ -82,12 +78,12 @@ class HelperTestCase(TestCase): assert ready(sources) assert done(sources) - + class DateSortTestCase(TestCase): - + def setUp(self): pass - + def tearDown(self): pass @@ -99,27 +95,27 @@ class DateSortTestCase(TestCase): sort_gen = date_sort(events, source_ids) l = list(sort_gen) assert l == expected - + def test_single_source(self): source_ids = ['a'] # 100 events, increasing by a minute at a time. type = zp.DATASOURCE_TYPE.TRADE dates = list(date_gen(count = 100)) dates.append("DONE") - + # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] event_args = zip(cycle(source_ids), iter(dates), cycle([type])) - + # Turn event_args into proper events. events = [mock_data_unframe(*args) for args in event_args] - + # We don't expected Feed to yield the last event. expected = events[:-1] event_gen = (e for e in events) - + self.run_date_sort(event_gen, expected, source_ids) - + def test_multi_source(self): source_ids = ['a', 'b'] type = zp.DATASOURCE_TYPE.TRADE @@ -130,8 +126,8 @@ class DateSortTestCase(TestCase): dates_a.append("DONE") events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) - events_a = [mock_data_unframe(*args) for args in events_a_args] - + events_a = [mock_data_unframe(*args) for args in events_a_args] + # Set up source 'b'. Outputs 10 events with 1 minute deltas. delta_b = timedelta(minutes = 1) dates_b = list(date_gen(delta = delta_b, count = 10)) @@ -139,7 +135,7 @@ class DateSortTestCase(TestCase): events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) events_b = [mock_data_unframe(*args) for args in events_b_args] - + # The expected output is all non-DONE events in both a and b, # sorted first by dt and then by source_id. non_dones = events_a[:-1] + events_b[:-1] @@ -155,7 +151,7 @@ class DateSortTestCase(TestCase): self.run_date_sort(sequential, expected, source_ids) def test_sorted_sources(self): - + filter = [1,2] #Set up source a. One hour between events. args_a = tuple() @@ -164,7 +160,7 @@ class DateSortTestCase(TestCase): 'delta' : timedelta(hours = 1), 'filter' : filter } - #Set up source b. One day between events. + #Set up source b. One day between events. args_b = tuple() kwargs_b = {'sids' : [1,2,3,4], 'start' : datetime(2012,6,6,0), @@ -186,19 +182,19 @@ class DateSortTestCase(TestCase): 'delta' : timedelta(minutes = 1), 'filter' : filter } - + sources = (SpecificEquityTrades,) * 4 source_args = (args_a, args_b, args_c, args_d) source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) - + # Generate our expected source_ids. zip_args = zip(source_args, source_kwargs) expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) for args, kwargs in zip_args] - + # Pipe our sources into sort. sort_out = date_sorted_sources(sources, source_args, source_kwargs) - + # Read all the values from sort and assert that they arrive in # the correct sorting with the expected hash values. to_list = list(sort_out) @@ -211,7 +207,7 @@ class DateSortTestCase(TestCase): expected = sorted(copy, compare_by_dt_source_id) assert to_list == expected - + def mock_data_unframe(source_id, dt, type): event = ndict() event.source_id = source_id @@ -227,11 +223,11 @@ def compare_by_dt_source_id(x,y): return -1 elif x.dt > y.dt: return 1 - + elif x.source_id < y.source_id: return -1 elif x.source_id > y.source_id: return 1 - + else: return 0 diff --git a/tests/test_mongods.py b/tests/test_mongods.py deleted file mode 100644 index d9b8dbe5..00000000 --- a/tests/test_mongods.py +++ /dev/null @@ -1,114 +0,0 @@ -import os - -import uuid -import msgpack -import pytz - -from unittest2 import TestCase -from pymongo import Connection, ASCENDING -from itertools import izip, izip_longest -from datetime import datetime, timedelta - -from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen -from zipline.gens.utils import hash_args, assert_datasource_protocol,\ - assert_trade_protocol, mock_raw_event - -import zipline.protocol as zp - -mongo_conn_args = { - 'mongodb_host': 'localhost', - 'mongodb_port': 27017, -} - -class TempMongo(object): - - def __enter__(self): - self.conn = Connection(mongo_conn_args['mongodb_host'], - mongo_conn_args['mongodb_port']) - - temp_id = 'qexec_test_id' - - self.db = self.conn[temp_id] - - return self - - def __exit__(self, type, value, traceback): - self.conn.drop_database(self.db.name) - -class TestMongoDataGenerator(TestCase): - - def setUp(self): - pass - def tearDown(self): - pass - - def test_create_pymongo_iterator(self): - - with TempMongo() as temp_mongo: - db = temp_mongo.db - coll = db.test - coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)]) - - for i in xrange(100): - # sid = 1, dt ranging from 0 to 99 - coll.insert(mock_raw_event(1, i)) - - start_date = 20 - end_date = 50 - filter = {'sid' : [1]} - args = (coll, filter, start_date, end_date) - - cursor = create_pymongo_iterator(*args) - # We filter to only get dt's between 20 and 50 - expected = (mock_raw_event(1, i) for i in xrange(20, 51)) - - # Assert that our iterator returns the expected values. - for cursor_event, expected_event in izip_longest(cursor, expected): - del cursor_event['_id'] - # Easiest way to convert unicode to strings. - cursor_event = msgpack.loads(msgpack.dumps(cursor_event)) - assert expected_event.keys() == cursor_event.keys() - assert expected_event.values() == cursor_event.values() - - def test_MongoTradeHistoryGen(self): - - with TempMongo() as temp_mongo: - db = temp_mongo.db - coll = db.test - coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)]) - - start_date = datetime(year = 2012,month=6,day=5,hour=0) - delta = timedelta(hours = 1) - - for i in xrange(100): - # sid = 1, dt's increasing an hour at a time from start - time = start_date + i * delta - coll.insert(mock_raw_event(1, time)) - - # Halfway through the events we added to db. - end_date = start_date + delta * 50 - - filter = {'sid' : [1]} - args = (coll, filter, start_date, end_date) - db_gen = MongoTradeHistoryGen(*args) - - expected_times = (start_date + i * delta for i in xrange(51)) - expected_events = (mock_raw_event(1, t) for t in expected_times) - - # DB events should match the expected events for price, dt, volume, - # and sid. They should also conform to the trade frame protocol. - - for db, expected in izip_longest(db_gen, expected_events): - expected['dt'] = expected['dt'].replace(tzinfo = pytz.utc) - # Check that our output meets the trade protocol. - assert_trade_protocol(db) - - # Check that our output matches expectations - for field in iter(['sid', 'dt', 'price', 'volume']): - assert db[field] == expected[field] - - # Expected output of hash_args: - assert db['source_id'] == \ - 'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc' - - diff --git a/zipline/gens/mongods.py b/zipline/gens/mongods.py deleted file mode 100644 index aa22d4c1..00000000 --- a/zipline/gens/mongods.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Generator-style DataSource that loads from MongoDB. -""" - -import pytz -import logbook -import pymongo - -from pymongo import ASCENDING -from datetime import datetime - -from zipline import ndict -from zipline.gens.utils import hash_args, \ - assert_trade_protocol - -import zipline.protocol as zp - -def MongoTradeHistoryGen(collection, filter, start_date, end_date): - """A generator that takes a pymongo Collection object, a list of - filters, a start date and an end_date and yields ndicts containing - the results of a query to its collection with the given filter, - start, and end. The output is also packaged with a unique - source_id string for downstream sorting - """ - - assert isinstance(collection, pymongo.collection.Collection) - assert isinstance(filter, dict) - assert isinstance(start_date, (datetime)) - assert isinstance(end_date, (datetime)) - - # Set up internal iterator. This outputs raw dictionaries. - iterator = create_pymongo_iterator(collection, filter, start_date, end_date) - - # Create unique identifier string that can be used to break - # sorting ties deterministically - argstring = hash_args(collection, filter, start_date, end_date) - source_id = "MongoTradeHistoryGen" + argstring - - # All datasources - for event in iterator: - # Construct a new event that fulfills the datasource protocol. - event['type'] = zp.DATASOURCE_TYPE.TRADE - event['dt'] = event['dt'].replace(tzinfo=pytz.utc) - event['source_id'] = source_id - - payload = ndict(event) - assert_trade_protocol(payload) - yield payload - -def create_pymongo_iterator(collection, filter, start_date, end_date): - """ - Returns an iterator that spits out raw objects loaded from a - MongoDB collection. - - See the comments on :py:class:`zipline.messaging.DataSource` - for expected content of filter. - """ - log = logbook.Logger("MongoDBQuery") - - # Object that will hold our database query. - spec = {} - - # add the filters from the algorithm. - for name, value in filter.iteritems(): - - # Add the list of sids that we care about. - if name == 'sid': - assert isinstance(value, list) - sid_range = {'sid':{'$in':value}} - spec.update(sid_range) - - # limit the data to the date range [start, end], inclusive - date_range = {'dt':{'$gte': start_date, '$lte': end_date}} - spec.update(date_range) - - fields = ['sid','price','volume','dt'] - - # In our collection, load all objects matching spec. Of those - # objects, get only the fields matching fields, and return the - # loaded objects sorted by dt from least to greatest. - - cursor = collection.find( - fields = fields, - spec = spec, - sort = [("dt",ASCENDING)], - slave_ok = True - ) - - # Optimize the cursor sort to query in 'dt' and 'sid' order. - cursor = cursor.hint([('dt', ASCENDING),('sid', ASCENDING)]) - - # Set up the iterator - iterator = iter(cursor) - log.info("MongoDataSource iterator ready") - - return iterator