moved mongods out

This commit is contained in:
fawce
2012-07-30 17:53:58 -04:00
parent 3cff533ff9
commit ef1ca0388d
3 changed files with 37 additions and 251 deletions
+37 -41
View File
@@ -1,30 +1,26 @@
import os
import uuid
import msgpack
import pytz
from unittest2 import TestCase
from pymongo import Connection, ASCENDING
from itertools import izip, izip_longest, permutations, cycle, chain
from itertools import cycle, chain
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.sort import date_sort, ready, done, queue_is_ready,queue_is_done,\
pop_oldest
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
assert_trade_protocol, alternate
from zipline.gens.sort import \
date_sort, \
ready, \
done, \
queue_is_ready,\
queue_is_done
from zipline.gens.utils import hash_args, alternate
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
from zipline.gens.composites import date_sorted_sources
import zipline.protocol as zp
class HelperTestCase(TestCase):
def setUp(self):
pass
def tearDown(self):
pass
@@ -33,12 +29,12 @@ class HelperTestCase(TestCase):
# Empty queues are neither done nor ready.
assert not queue_is_ready(queue)
assert not queue_is_done(queue)
queue.append(to_dt('foo'))
assert queue_is_ready(queue)
assert not queue_is_done(queue)
queue.appendleft(to_dt('DONE'))
assert queue_is_ready(queue)
@@ -48,13 +44,13 @@ class HelperTestCase(TestCase):
queue.pop()
assert queue_is_ready(queue)
assert queue_is_done(queue)
def test_pop_logic(self):
sources = {}
ids = ['a', 'b', 'c']
for id in ids:
sources[id] = deque()
assert not ready(sources)
assert not done(sources)
@@ -66,13 +62,13 @@ class HelperTestCase(TestCase):
for id in ids:
sources[id].append(to_dt("datetime"))
assert ready(sources)
assert not done(sources)
for id in ids:
sources[id].appendleft(to_dt("DONE"))
# ["DONE", message] will trip an assert in queue_is_done.
assert ready(sources)
self.assertRaises(AssertionError, done, sources)
@@ -82,12 +78,12 @@ class HelperTestCase(TestCase):
assert ready(sources)
assert done(sources)
class DateSortTestCase(TestCase):
def setUp(self):
pass
def tearDown(self):
pass
@@ -99,27 +95,27 @@ class DateSortTestCase(TestCase):
sort_gen = date_sort(events, source_ids)
l = list(sort_gen)
assert l == expected
def test_single_source(self):
    """Run a lone source through date_sort and check what comes out.

    The stream is 100 dated trade events from source 'a' followed by
    a trailing "DONE" event; the sorter is expected to emit every
    event except that final "DONE" one.
    """
    source_ids = ['a']
    trade_type = zp.DATASOURCE_TYPE.TRADE

    # 100 generated dates, terminated by the "DONE" sentinel.
    dates = list(date_gen(count = 100)) + ["DONE"]

    # One event per date: ('a', date1), ('a', date2), ... ('a', "DONE"),
    # each tagged with the trade type.
    events = [
        mock_data_unframe(source_id, dt, trade_type)
        for source_id, dt in zip(cycle(source_ids), dates)
    ]

    # The trailing "DONE" event must not appear in the sorted output.
    expected = events[:-1]

    # Hand the sorter a one-shot generator rather than the list itself.
    event_gen = (event for event in events)
    self.run_date_sort(event_gen, expected, source_ids)
def test_multi_source(self):
source_ids = ['a', 'b']
type = zp.DATASOURCE_TYPE.TRADE
@@ -130,8 +126,8 @@ class DateSortTestCase(TestCase):
dates_a.append("DONE")
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
events_a = [mock_data_unframe(*args) for args in events_a_args]
events_a = [mock_data_unframe(*args) for args in events_a_args]
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
delta_b = timedelta(minutes = 1)
dates_b = list(date_gen(delta = delta_b, count = 10))
@@ -139,7 +135,7 @@ class DateSortTestCase(TestCase):
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
events_b = [mock_data_unframe(*args) for args in events_b_args]
# The expected output is all non-DONE events in both a and b,
# sorted first by dt and then by source_id.
non_dones = events_a[:-1] + events_b[:-1]
@@ -155,7 +151,7 @@ class DateSortTestCase(TestCase):
self.run_date_sort(sequential, expected, source_ids)
def test_sorted_sources(self):
filter = [1,2]
#Set up source a. One hour between events.
args_a = tuple()
@@ -164,7 +160,7 @@ class DateSortTestCase(TestCase):
'delta' : timedelta(hours = 1),
'filter' : filter
}
#Set up source b. One day between events.
#Set up source b. One day between events.
args_b = tuple()
kwargs_b = {'sids' : [1,2,3,4],
'start' : datetime(2012,6,6,0),
@@ -186,19 +182,19 @@ class DateSortTestCase(TestCase):
'delta' : timedelta(minutes = 1),
'filter' : filter
}
sources = (SpecificEquityTrades,) * 4
source_args = (args_a, args_b, args_c, args_d)
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
# Generate our expected source_ids.
zip_args = zip(source_args, source_kwargs)
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
for args, kwargs in zip_args]
# Pipe our sources into sort.
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
# Read all the values from sort and assert that they arrive in
# the correct sorting with the expected hash values.
to_list = list(sort_out)
@@ -211,7 +207,7 @@ class DateSortTestCase(TestCase):
expected = sorted(copy, compare_by_dt_source_id)
assert to_list == expected
def mock_data_unframe(source_id, dt, type):
event = ndict()
event.source_id = source_id
@@ -227,11 +223,11 @@ def compare_by_dt_source_id(x,y):
return -1
elif x.dt > y.dt:
return 1
elif x.source_id < y.source_id:
return -1
elif x.source_id > y.source_id:
return 1
else:
return 0
-114
View File
@@ -1,114 +0,0 @@
import os
import uuid
import msgpack
import pytz
from unittest2 import TestCase
from pymongo import Connection, ASCENDING
from itertools import izip, izip_longest
from datetime import datetime, timedelta
from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
assert_trade_protocol, mock_raw_event
import zipline.protocol as zp
mongo_conn_args = {
'mongodb_host': 'localhost',
'mongodb_port': 27017,
}
class TempMongo(object):
    """Context manager that provides a scratch MongoDB database for tests.

    On entry it connects to the host/port given in ``mongo_conn_args``
    and exposes the connection as ``self.conn`` and the temporary
    database as ``self.db``. On exit the temporary database is dropped
    and the connection is closed, even if dropping the database fails.
    """

    def __enter__(self):
        self.conn = Connection(mongo_conn_args['mongodb_host'],
                               mongo_conn_args['mongodb_port'])
        temp_id = 'qexec_test_id'
        self.db = self.conn[temp_id]
        return self

    def __exit__(self, type, value, traceback):
        # Returning None (implicitly) lets any exception raised inside the
        # ``with`` body propagate to the caller.
        try:
            self.conn.drop_database(self.db.name)
        finally:
            # Release the sockets held by the connection; previously the
            # connection was left open after every test.
            self.conn.disconnect()
class TestMongoDataGenerator(TestCase):
    """Tests for the MongoDB-backed generators in zipline.gens.mongods.

    Each test inserts 100 mock events into a temporary database (see
    TempMongo), queries a sub-range, and compares the results against
    events built directly with mock_raw_event.
    """

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_create_pymongo_iterator(self):
        """The raw iterator yields exactly the events with 20 <= dt <= 50."""
        with TempMongo() as temp_mongo:
            coll = temp_mongo.db.test
            coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])

            for i in xrange(100):
                # sid = 1, dt ranging from 0 to 99
                coll.insert(mock_raw_event(1, i))

            start_date = 20
            end_date = 50
            sid_filter = {'sid' : [1]}
            cursor = create_pymongo_iterator(
                coll, sid_filter, start_date, end_date
            )

            # We filter to only get dt's between 20 and 50, inclusive.
            expected = (mock_raw_event(1, i) for i in xrange(20, 51))

            # izip_longest pads the shorter stream with None, so a length
            # mismatch is reported explicitly instead of surfacing as a
            # TypeError on None.
            for cursor_event, expected_event in izip_longest(cursor, expected):
                assert cursor_event is not None, \
                    "iterator returned fewer events than expected"
                assert expected_event is not None, \
                    "iterator returned more events than expected"

                del cursor_event['_id']
                # Easiest way to convert unicode to strings.
                cursor_event = msgpack.loads(msgpack.dumps(cursor_event))

                # Compare the mappings directly; comparing keys() and
                # values() lists separately depends on dict ordering.
                assert expected_event == cursor_event

    def test_MongoTradeHistoryGen(self):
        """The generator yields protocol-conforming trades for the range."""
        with TempMongo() as temp_mongo:
            coll = temp_mongo.db.test
            coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])

            start_date = datetime(year = 2012,month=6,day=5,hour=0)
            delta = timedelta(hours = 1)

            for i in xrange(100):
                # sid = 1, dt's increasing an hour at a time from start
                time = start_date + i * delta
                coll.insert(mock_raw_event(1, time))

            # Halfway through the events we added to db.
            end_date = start_date + delta * 50
            sid_filter = {'sid' : [1]}
            db_gen = MongoTradeHistoryGen(
                coll, sid_filter, start_date, end_date
            )

            # The end date is inclusive, hence 51 expected events.
            expected_times = (start_date + i * delta for i in xrange(51))
            expected_events = (mock_raw_event(1, t) for t in expected_times)

            # DB events should match the expected events for price, dt,
            # volume, and sid. They should also conform to the trade frame
            # protocol.
            for event, expected in izip_longest(db_gen, expected_events):
                assert event is not None, \
                    "generator returned fewer events than expected"
                assert expected is not None, \
                    "generator returned more events than expected"

                expected['dt'] = expected['dt'].replace(tzinfo = pytz.utc)

                # Check that our output meets the trade protocol.
                assert_trade_protocol(event)

                # Check that our output matches expectations
                for field in ('sid', 'dt', 'price', 'volume'):
                    assert event[field] == expected[field]

                # Expected output of hash_args:
                assert event['source_id'] == \
                    'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc'
-96
View File
@@ -1,96 +0,0 @@
"""
Generator-style DataSource that loads from MongoDB.
"""
import pytz
import logbook
import pymongo
from pymongo import ASCENDING
from datetime import datetime
from zipline import ndict
from zipline.gens.utils import hash_args, \
assert_trade_protocol
import zipline.protocol as zp
def MongoTradeHistoryGen(collection, filter, start_date, end_date):
    """Yield trade events queried from a MongoDB collection as ndicts.

    :param collection: a pymongo Collection to query.
    :param filter: dict of query filters, passed through to
        create_pymongo_iterator.
    :param start_date: datetime lower bound of the query.
    :param end_date: datetime upper bound of the query.

    Every yielded event is tagged with the TRADE datasource type, has
    its 'dt' marked as UTC, and carries a source_id string built from
    the arguments so downstream sorting can break ties
    deterministically.

    Because this is a generator, the argument assertions below only run
    once iteration starts, not when the function is called.
    """
    assert isinstance(collection, pymongo.collection.Collection)
    assert isinstance(filter, dict)
    for bound in (start_date, end_date):
        assert isinstance(bound, datetime)

    query_args = (collection, filter, start_date, end_date)

    # Source of raw dictionaries straight from the database.
    raw_events = create_pymongo_iterator(*query_args)

    # Identifier derived from the arguments of this generator.
    source_id = "MongoTradeHistoryGen" + hash_args(*query_args)

    for raw in raw_events:
        # Decorate the raw record so it fulfills the datasource protocol.
        raw['type'] = zp.DATASOURCE_TYPE.TRADE
        raw['dt'] = raw['dt'].replace(tzinfo=pytz.utc)
        raw['source_id'] = source_id

        trade = ndict(raw)
        assert_trade_protocol(trade)
        yield trade
def create_pymongo_iterator(collection, filter, start_date, end_date):
    """
    Build an iterator over raw documents loaded from a MongoDB
    collection.

    Only the 'sid' entry of ``filter`` is used: it must be a list and
    restricts results to those sids. Results are limited to documents
    whose 'dt' lies in [start_date, end_date], inclusive, are projected
    down to sid/price/volume/dt, and are sorted by ascending 'dt'.

    See the comments on :py:class:`zipline.messaging.DataSource`
    for expected content of filter.
    """
    log = logbook.Logger("MongoDBQuery")

    # Query document, assembled piece by piece.
    spec = {}

    for name, value in filter.iteritems():
        # 'sid' is the only filter key that affects the query.
        if name != 'sid':
            continue
        assert isinstance(value, list)
        spec['sid'] = {'$in': value}

    # Inclusive date window.
    spec['dt'] = {'$gte': start_date, '$lte': end_date}

    by_dt = [("dt", ASCENDING)]
    index_hint = [('dt', ASCENDING), ('sid', ASCENDING)]

    # Load the matching documents, keeping only the listed fields, in
    # ascending 'dt' order, and hint the query towards the (dt, sid)
    # index.
    cursor = collection.find(
        spec = spec,
        fields = ['sid', 'price', 'volume', 'dt'],
        sort = by_dt,
        slave_ok = True
    ).hint(index_hint)

    iterator = iter(cursor)
    log.info("MongoDataSource iterator ready")
    return iterator