mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 20:00:28 +08:00
moved mongods out
This commit is contained in:
+37
-41
@@ -1,30 +1,26 @@
|
||||
import os
|
||||
|
||||
import uuid
|
||||
import msgpack
|
||||
import pytz
|
||||
|
||||
from unittest2 import TestCase
|
||||
from pymongo import Connection, ASCENDING
|
||||
from itertools import izip, izip_longest, permutations, cycle, chain
|
||||
from itertools import cycle, chain
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.sort import date_sort, ready, done, queue_is_ready,queue_is_done,\
|
||||
pop_oldest
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
|
||||
assert_trade_protocol, alternate
|
||||
from zipline.gens.sort import \
|
||||
date_sort, \
|
||||
ready, \
|
||||
done, \
|
||||
queue_is_ready,\
|
||||
queue_is_done
|
||||
from zipline.gens.utils import hash_args, alternate
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
class HelperTestCase(TestCase):
|
||||
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@@ -33,12 +29,12 @@ class HelperTestCase(TestCase):
|
||||
# Empty queues are neither done nor ready.
|
||||
assert not queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
queue.append(to_dt('foo'))
|
||||
assert queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
|
||||
queue.appendleft(to_dt('DONE'))
|
||||
assert queue_is_ready(queue)
|
||||
|
||||
@@ -48,13 +44,13 @@ class HelperTestCase(TestCase):
|
||||
queue.pop()
|
||||
assert queue_is_ready(queue)
|
||||
assert queue_is_done(queue)
|
||||
|
||||
|
||||
def test_pop_logic(self):
|
||||
sources = {}
|
||||
ids = ['a', 'b', 'c']
|
||||
for id in ids:
|
||||
sources[id] = deque()
|
||||
|
||||
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
@@ -66,13 +62,13 @@ class HelperTestCase(TestCase):
|
||||
|
||||
for id in ids:
|
||||
sources[id].append(to_dt("datetime"))
|
||||
|
||||
|
||||
assert ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].appendleft(to_dt("DONE"))
|
||||
|
||||
|
||||
# ["DONE", message] will trip an assert in queue_is_done.
|
||||
assert ready(sources)
|
||||
self.assertRaises(AssertionError, done, sources)
|
||||
@@ -82,12 +78,12 @@ class HelperTestCase(TestCase):
|
||||
|
||||
assert ready(sources)
|
||||
assert done(sources)
|
||||
|
||||
|
||||
class DateSortTestCase(TestCase):
|
||||
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@@ -99,27 +95,27 @@ class DateSortTestCase(TestCase):
|
||||
sort_gen = date_sort(events, source_ids)
|
||||
l = list(sort_gen)
|
||||
assert l == expected
|
||||
|
||||
|
||||
def test_single_source(self):
|
||||
source_ids = ['a']
|
||||
# 100 events, increasing by a minute at a time.
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
dates = list(date_gen(count = 100))
|
||||
dates.append("DONE")
|
||||
|
||||
|
||||
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
|
||||
event_args = zip(cycle(source_ids), iter(dates), cycle([type]))
|
||||
|
||||
|
||||
# Turn event_args into proper events.
|
||||
events = [mock_data_unframe(*args) for args in event_args]
|
||||
|
||||
|
||||
# We don't expected Feed to yield the last event.
|
||||
expected = events[:-1]
|
||||
|
||||
event_gen = (e for e in events)
|
||||
|
||||
|
||||
self.run_date_sort(event_gen, expected, source_ids)
|
||||
|
||||
|
||||
def test_multi_source(self):
|
||||
source_ids = ['a', 'b']
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
@@ -130,8 +126,8 @@ class DateSortTestCase(TestCase):
|
||||
dates_a.append("DONE")
|
||||
|
||||
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
|
||||
events_a = [mock_data_unframe(*args) for args in events_a_args]
|
||||
|
||||
events_a = [mock_data_unframe(*args) for args in events_a_args]
|
||||
|
||||
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
|
||||
delta_b = timedelta(minutes = 1)
|
||||
dates_b = list(date_gen(delta = delta_b, count = 10))
|
||||
@@ -139,7 +135,7 @@ class DateSortTestCase(TestCase):
|
||||
|
||||
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
|
||||
events_b = [mock_data_unframe(*args) for args in events_b_args]
|
||||
|
||||
|
||||
# The expected output is all non-DONE events in both a and b,
|
||||
# sorted first by dt and then by source_id.
|
||||
non_dones = events_a[:-1] + events_b[:-1]
|
||||
@@ -155,7 +151,7 @@ class DateSortTestCase(TestCase):
|
||||
self.run_date_sort(sequential, expected, source_ids)
|
||||
|
||||
def test_sorted_sources(self):
|
||||
|
||||
|
||||
filter = [1,2]
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
@@ -164,7 +160,7 @@ class DateSortTestCase(TestCase):
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b. One day between events.
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
@@ -186,19 +182,19 @@ class DateSortTestCase(TestCase):
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
|
||||
|
||||
sources = (SpecificEquityTrades,) * 4
|
||||
source_args = (args_a, args_b, args_c, args_d)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
|
||||
|
||||
|
||||
# Generate our expected source_ids.
|
||||
zip_args = zip(source_args, source_kwargs)
|
||||
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
|
||||
for args, kwargs in zip_args]
|
||||
|
||||
|
||||
# Pipe our sources into sort.
|
||||
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
|
||||
|
||||
|
||||
# Read all the values from sort and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(sort_out)
|
||||
@@ -211,7 +207,7 @@ class DateSortTestCase(TestCase):
|
||||
|
||||
expected = sorted(copy, compare_by_dt_source_id)
|
||||
assert to_list == expected
|
||||
|
||||
|
||||
def mock_data_unframe(source_id, dt, type):
|
||||
event = ndict()
|
||||
event.source_id = source_id
|
||||
@@ -227,11 +223,11 @@ def compare_by_dt_source_id(x,y):
|
||||
return -1
|
||||
elif x.dt > y.dt:
|
||||
return 1
|
||||
|
||||
|
||||
elif x.source_id < y.source_id:
|
||||
return -1
|
||||
elif x.source_id > y.source_id:
|
||||
return 1
|
||||
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import os
|
||||
|
||||
import uuid
|
||||
import msgpack
|
||||
import pytz
|
||||
|
||||
from unittest2 import TestCase
|
||||
from pymongo import Connection, ASCENDING
|
||||
from itertools import izip, izip_longest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen
|
||||
from zipline.gens.utils import hash_args, assert_datasource_protocol,\
|
||||
assert_trade_protocol, mock_raw_event
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
mongo_conn_args = {
|
||||
'mongodb_host': 'localhost',
|
||||
'mongodb_port': 27017,
|
||||
}
|
||||
|
||||
class TempMongo(object):
|
||||
|
||||
def __enter__(self):
|
||||
self.conn = Connection(mongo_conn_args['mongodb_host'],
|
||||
mongo_conn_args['mongodb_port'])
|
||||
|
||||
temp_id = 'qexec_test_id'
|
||||
|
||||
self.db = self.conn[temp_id]
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.conn.drop_database(self.db.name)
|
||||
|
||||
class TestMongoDataGenerator(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_create_pymongo_iterator(self):
|
||||
|
||||
with TempMongo() as temp_mongo:
|
||||
db = temp_mongo.db
|
||||
coll = db.test
|
||||
coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])
|
||||
|
||||
for i in xrange(100):
|
||||
# sid = 1, dt ranging from 0 to 99
|
||||
coll.insert(mock_raw_event(1, i))
|
||||
|
||||
start_date = 20
|
||||
end_date = 50
|
||||
filter = {'sid' : [1]}
|
||||
args = (coll, filter, start_date, end_date)
|
||||
|
||||
cursor = create_pymongo_iterator(*args)
|
||||
# We filter to only get dt's between 20 and 50
|
||||
expected = (mock_raw_event(1, i) for i in xrange(20, 51))
|
||||
|
||||
# Assert that our iterator returns the expected values.
|
||||
for cursor_event, expected_event in izip_longest(cursor, expected):
|
||||
del cursor_event['_id']
|
||||
# Easiest way to convert unicode to strings.
|
||||
cursor_event = msgpack.loads(msgpack.dumps(cursor_event))
|
||||
assert expected_event.keys() == cursor_event.keys()
|
||||
assert expected_event.values() == cursor_event.values()
|
||||
|
||||
def test_MongoTradeHistoryGen(self):
|
||||
|
||||
with TempMongo() as temp_mongo:
|
||||
db = temp_mongo.db
|
||||
coll = db.test
|
||||
coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])
|
||||
|
||||
start_date = datetime(year = 2012,month=6,day=5,hour=0)
|
||||
delta = timedelta(hours = 1)
|
||||
|
||||
for i in xrange(100):
|
||||
# sid = 1, dt's increasing an hour at a time from start
|
||||
time = start_date + i * delta
|
||||
coll.insert(mock_raw_event(1, time))
|
||||
|
||||
# Halfway through the events we added to db.
|
||||
end_date = start_date + delta * 50
|
||||
|
||||
filter = {'sid' : [1]}
|
||||
args = (coll, filter, start_date, end_date)
|
||||
db_gen = MongoTradeHistoryGen(*args)
|
||||
|
||||
expected_times = (start_date + i * delta for i in xrange(51))
|
||||
expected_events = (mock_raw_event(1, t) for t in expected_times)
|
||||
|
||||
# DB events should match the expected events for price, dt, volume,
|
||||
# and sid. They should also conform to the trade frame protocol.
|
||||
|
||||
for db, expected in izip_longest(db_gen, expected_events):
|
||||
expected['dt'] = expected['dt'].replace(tzinfo = pytz.utc)
|
||||
# Check that our output meets the trade protocol.
|
||||
assert_trade_protocol(db)
|
||||
|
||||
# Check that our output matches expectations
|
||||
for field in iter(['sid', 'dt', 'price', 'volume']):
|
||||
assert db[field] == expected[field]
|
||||
|
||||
# Expected output of hash_args:
|
||||
assert db['source_id'] == \
|
||||
'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc'
|
||||
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
"""
|
||||
Generator-style DataSource that loads from MongoDB.
|
||||
"""
|
||||
|
||||
import pytz
|
||||
import logbook
|
||||
import pymongo
|
||||
|
||||
from pymongo import ASCENDING
|
||||
from datetime import datetime
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, \
|
||||
assert_trade_protocol
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def MongoTradeHistoryGen(collection, filter, start_date, end_date):
|
||||
"""A generator that takes a pymongo Collection object, a list of
|
||||
filters, a start date and an end_date and yields ndicts containing
|
||||
the results of a query to its collection with the given filter,
|
||||
start, and end. The output is also packaged with a unique
|
||||
source_id string for downstream sorting
|
||||
"""
|
||||
|
||||
assert isinstance(collection, pymongo.collection.Collection)
|
||||
assert isinstance(filter, dict)
|
||||
assert isinstance(start_date, (datetime))
|
||||
assert isinstance(end_date, (datetime))
|
||||
|
||||
# Set up internal iterator. This outputs raw dictionaries.
|
||||
iterator = create_pymongo_iterator(collection, filter, start_date, end_date)
|
||||
|
||||
# Create unique identifier string that can be used to break
|
||||
# sorting ties deterministically
|
||||
argstring = hash_args(collection, filter, start_date, end_date)
|
||||
source_id = "MongoTradeHistoryGen" + argstring
|
||||
|
||||
# All datasources
|
||||
for event in iterator:
|
||||
# Construct a new event that fulfills the datasource protocol.
|
||||
event['type'] = zp.DATASOURCE_TYPE.TRADE
|
||||
event['dt'] = event['dt'].replace(tzinfo=pytz.utc)
|
||||
event['source_id'] = source_id
|
||||
|
||||
payload = ndict(event)
|
||||
assert_trade_protocol(payload)
|
||||
yield payload
|
||||
|
||||
def create_pymongo_iterator(collection, filter, start_date, end_date):
|
||||
"""
|
||||
Returns an iterator that spits out raw objects loaded from a
|
||||
MongoDB collection.
|
||||
|
||||
See the comments on :py:class:`zipline.messaging.DataSource`
|
||||
for expected content of filter.
|
||||
"""
|
||||
log = logbook.Logger("MongoDBQuery")
|
||||
|
||||
# Object that will hold our database query.
|
||||
spec = {}
|
||||
|
||||
# add the filters from the algorithm.
|
||||
for name, value in filter.iteritems():
|
||||
|
||||
# Add the list of sids that we care about.
|
||||
if name == 'sid':
|
||||
assert isinstance(value, list)
|
||||
sid_range = {'sid':{'$in':value}}
|
||||
spec.update(sid_range)
|
||||
|
||||
# limit the data to the date range [start, end], inclusive
|
||||
date_range = {'dt':{'$gte': start_date, '$lte': end_date}}
|
||||
spec.update(date_range)
|
||||
|
||||
fields = ['sid','price','volume','dt']
|
||||
|
||||
# In our collection, load all objects matching spec. Of those
|
||||
# objects, get only the fields matching fields, and return the
|
||||
# loaded objects sorted by dt from least to greatest.
|
||||
|
||||
cursor = collection.find(
|
||||
fields = fields,
|
||||
spec = spec,
|
||||
sort = [("dt",ASCENDING)],
|
||||
slave_ok = True
|
||||
)
|
||||
|
||||
# Optimize the cursor sort to query in 'dt' and 'sid' order.
|
||||
cursor = cursor.hint([('dt', ASCENDING),('sid', ASCENDING)])
|
||||
|
||||
# Set up the iterator
|
||||
iterator = iter(cursor)
|
||||
log.info("MongoDataSource iterator ready")
|
||||
|
||||
return iterator
|
||||
Reference in New Issue
Block a user