generator ds and feed, various generator utils

This commit is contained in:
scottsanderson
2012-07-26 18:48:51 -04:00
parent 96d32fbc86
commit 97e4ba050a
8 changed files with 560 additions and 56 deletions
View File
+130
View File
@@ -0,0 +1,130 @@
"""
Generator version of Feed.
"""
import pytz
import logbook
import pymongo
import types
from pymongo import ASCENDING
from datetime import datetime, timedelta
from collections import deque, defaultdict
from zipline import ndict
from zipline.gens.utils import stringify_args, assert_datasource_protocol, \
assert_trade_protocol, assert_datasource_unframe_protocol
import zipline.protocol as zp
def FeedGen(stream_in, source_ids):
"""
A generator that takes a generator and a list of source_ids. We
maintain an internal queue for each id in source_ids. While we
have messages pending from all sources, we pull the earliest
message and yield it.
"""
assert isinstance(stream_in, types.GeneratorType)
assert isinstance(source_ids, list)
# Set up an internal queue for each expected source.
sources = {}
for id in source_ids:
assert isinstance(id, basestring), "Bad source_id %s" % source_id
sources[id] = deque()
namestring = "FeedGen" + stringify_args(source_ids)
# Process incoming streams.
for message in stream_in:
# Incoming messages should be the output of DATASOURCE_UNFRAME.
assert_datasource_unframe_protocol(message), \
"Bad message in FeedGen: %s" % message
# Only allow messages from sources we expect.
assert message.source_id in sources, "Unexpected source: %s" % message
sources[message.source_id].append(message)
# Only pop messages when we have a pending message from
# all datasources. Stop if all sources have signalled done.
while full(sources) and not done(sources):
message = pop_oldest(sources)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue
assert queue[0].dt == "DONE", \
"Bad last message in FeedGen on exit: %s" % queue
def full(sources):
"""
Feed is full when every internal queue has at least one message. Note that
this include DONE messages, so done(sources) is True only if full(sources).
"""
assert isinstance(sources, dict)
return all( (queue_is_full(source) for source in sources.itervalues()) )
def queue_is_full(queue):
assert isinstance(queue, deque)
return len(queue) > 0
def done(sources):
"""Feed is done when all internal queues have only a "DONE" message."""
assert isinstance(sources, dict)
return all( (queue_is_done(source) for source in sources.itervalues()) )
def queue_is_done(queue):
assert isinstance(queue, deque)
if len(queue) == 0:
return False
if queue[0].dt == "DONE":
assert len(queue) == 1, "Message after DONE in FeedGen: %s" % queue
return True
else:
return False
def pop_oldest(sources):
oldest_event = None
# Iterate over the dict, checking internal queues for the oldest
# pending event.
for queue in sources.itervalues():
current_event = queue[0]
# Skip queues that are done.
if current_event.dt == "DONE":
continue
# Any event is older than nothing.
elif oldest_event == None:
oldest_event = current_event
# Keep the older event. Break ties by source_id. This will
# trip an assert if we have duplicate sources.
else:
oldest_event = older(oldest_event, current_event)
# Pop the oldest event we found from its queue and return it.
return sources[oldest_event.source_id].pop()
# Return the event with the older timestamp. Break ties by source_id.
def older(oldest, current):
assert isinstance(oldest, ndict)
assert isinstance(oldest, ndict)
# Try to compare by dt.
if oldest.dt < current.dt:
return oldest
elif oldest.dt > current.dt:
return current
# Break ties by source_id.
elif oldest.source_id < current.source_id:
return oldest
elif oldest.source_id > current.source_id:
return current
else:
assert False, "Duplicate event"
+39 -37
View File
@@ -10,15 +10,50 @@ from pymongo import ASCENDING
from datetime import datetime, timedelta
from zipline import ndict
from zipline.gens.utils import stringify_args, assert_datasource_protocol
from zipline.gens.utils import stringify_args, assert_datasource_protocol, \
assert_trade_protocol
import zipline.protocol as zp
def MongoTradeHistoryGen(collection, filter, start_date, end_date):
"""A generator that takes a pymongo Collection object, a list of
filters, a start date and an end_date and yields ndicts containing
the results of a query to its collection with the given filter,
start, and end. The output is also packaged with a unique
source_id string for downstream sorting
"""
assert isinstance(collection, pymongo.collection.Collection)
assert isinstance(filter, dict)
assert isinstance(start_date, (datetime))
assert isinstance(end_date, (datetime))
# Set up internal iterator. This outputs raw dictionaries.
iterator = create_pymongo_iterator(collection, filter, start_date, end_date)
# Create unique identifier string that can be used to break
# sorting ties deterministically
argstring = stringify_args(collection, filter, start_date, end_date)
source_id = "MongoTradeHistoryGen" + argstring
# All datasources
for event in iterator:
# Construct a new event that fulfills the datasource protocol.
event['type'] = zp.DATASOURCE_TYPE.TRADE
event['dt'] = event['dt'].replace(tzinfo=pytz.utc)
event['source_id'] = source_id
payload = ndict(event)
assert_trade_protocol(payload)
yield payload
def create_pymongo_iterator(collection, filter, start_date, end_date):
"""
See the comments on :py:class:`zipline.messaging.DataSource` for
expected content of filter. Spec must adhere to that definition.
Returns an iterator that spits out raw objects loaded from MongoDB.
Returns an iterator that spits out raw objects loaded from a
MongoDB collection.
See the comments on :py:class:`zipline.messaging.DataSource`
for expected content of filter.
"""
log = logbook.Logger("MongoDBQuery")
@@ -59,40 +94,7 @@ def create_pymongo_iterator(collection, filter, start_date, end_date):
log.info("MongoDataSource iterator ready")
return iterator
def MongoTradeHistoryGen(collection, filter, start_date, end_date):
"""A generator that takes a pymongo Collection object, a list of
filters, a start date and an end_date and yields ndicts containing
the results of a query to its collection with the given filter,
start, and end. The output is also packaged with a unique
source_id string for downstream sorting
"""
assert isinstance(collection, pymongo.collection.Collection)
assert isinstance(filter, dict)
assert isinstance(start_date, (datetime))
assert isinstance(end_date, (datetime))
# Set up internal iterator. This outputs raw dictionaries.
iterator = create_pymongo_iterator(collection, filter, start_date, end_date)
# Create unique identifier string that can be used to break
# sorting ties deterministically
argstring = stringify_args(collection, filter, start_date, end_date)
source_id = "MongoTradeHistoryGen" + argstring
# All datasources
for event in iterator:
# Construct a new event that fulfills the datasource protocol.
event['type'] = zp.DATASOURCE_TYPE.TRADE
event['dt'] = event['dt'].replace(tzinfo=pytz.utc)
event['source_id'] = source_id
payload = ndict(event)
assert_datasource_protocol(payload)
yield payload
+32 -10
View File
@@ -1,14 +1,36 @@
class gen_wrapper(object):
# Inside Client
def __init__(self, addresses, ...):
self.pull_socket = ...
def __init__(self, val):
self.val = val
self.iterator = iter(xrange(self.val))
def reset_iter(self):
self.val
self.control_socket ...
def __iter__(self):
return self.iterator
def run(self):
for message in gen_from_pull(self.pull_socket):
#Do things with messages.
heartbeat()
signal_done()
sys.exit(0)
def next():
return self.iterator.next()
# Inside Merge
def __init__(self, addresses, source_ids ...):
self.poller = ... # Poller on multiple xforms, single socket.
self.processor = ... # Generator that
self.push_socket = ... # Outbound socket
def run(self):
incoming = gen_from_poll(self.poller)# Receives messages from all xforms.
processed = self.processor(incoming, source_ids) # Maintains internal queues and merges.
for message in self.processed:
heartbeat()
self.push_socket.send(message)
# Inside
+190
View File
@@ -0,0 +1,190 @@
import os
import uuid
import msgpack
import pytz
from unittest2 import TestCase
from pymongo import Connection, ASCENDING
from itertools import izip, izip_longest, permutations, cycle
from datetime import datetime, timedelta
from collections import deque
from zipline import ndict
from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\
pop_oldest
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
assert_trade_protocol, date_gen
import zipline.protocol as zp
class FeedHelperTestCase(TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_individual_queue_logic(self):
queue = deque()
# Empty queues are neither done nor full.
assert not queue_is_full(queue)
assert not queue_is_done(queue)
queue.append(to_dt('foo'))
assert queue_is_full(queue)
assert not queue_is_done(queue)
queue.appendleft(to_dt('DONE'))
assert queue_is_full(queue)
# Checking done when we have a message after done will trip an assert.
self.assertRaises(AssertionError, queue_is_done, queue)
queue.pop()
assert queue_is_full(queue)
assert queue_is_done(queue)
def test_sources_logic(self):
sources = {}
ids = ['a', 'b', 'c']
for id in ids:
sources[id] = deque()
assert not full(sources)
assert not done(sources)
# All sources must have a message to be full/done
sources['a'].append(to_dt("datetime"))
assert not full(sources)
assert not done(sources)
sources['a'].pop()
for id in ids:
sources[id].append(to_dt("datetime"))
assert full(sources)
assert not done(sources)
for id in ids:
sources[id].appendleft(to_dt("DONE"))
# ["DONE", message] will trip an assert in queue_is_done.
assert full(sources)
self.assertRaises(AssertionError, done, sources)
for id in ids:
sources[id].pop()
assert full(sources)
assert done(sources)
class FeedGenTestCase(TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def run_FeedGen(self, events, expected, source_ids):
"""
Take a list of events, their source_ids, and an expected sorting.
Assert that FeedGen's output agrees with expected.
"""
feed_gen = FeedGen(events, source_ids)
assert list(feed_gen) == expected
def test_single_source(self):
source_ids = ['a']
# 100 events, increasing by a minute at a time.
type = zp.DATASOURCE_TYPE.TRADE
dates = list(date_gen(n = 1))
dates.append("DONE")
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
event_args = zip(cycle(source_ids), iter(dates), cycle([type]))
# Turn event_args into proper events.
events = [mock_data_unframe(*args) for args in event_args]
# We don't expected Feed to yield the last event.
expected = events[:-1]
event_gen = (e for e in events)
self.run_FeedGen(event_gen, expected, source_ids)
def test_multi_source_interleaved(self):
source_ids = ['a', 'b']
type = zp.DATASOURCE_TYPE.TRADE
# Set up source 'a'. Outputs 3 events with 2 minute deltas.
delta_a = timedelta(minutes = 2)
dates_a = list(date_gen(delta = delta_a, n = 3))
dates_a.append("DONE")
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
events_a = [mock_data_unframe(*args) for args in events_a_args]
event_gen_a = (e for e in events_a)
# Set up source 'b'. Outputs 4 events with 1 minute deltas.
delta_b = timedelta(minutes = 1)
dates_b = list(date_gen(delta = delta_b, n = 4))
dates_b.append("DONE")
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
events_b = [mock_data_unframe(*args) for args in events_b_args]
event_gen_b = (e for e in events_b)
# The expected output is all non-DONE events in both a and b,
# sorted first by dt and then by source_id.
non_dones = events_a[:-1] + events_b[:-1]
expected = sorted(non_dones, compare_by_dt_source_id)
import nose.tools; nose.tools.set_trace()
self.run_FeedGen(event_gen, expected, source_ids)
# def test_FeedGen_consistency(self):
# source_ids = ['a', 'b']
# multiplied = source_ids * 5
# perms = itertools.permutations(multiplied, 10)
# self.type = zp.DATASOURCE_TYPE.TRADE
# self.events = (mock_data_unframe(id,
def mock_data_unframe(source_id, dt, type):
event = ndict()
event.source_id = source_id
event.dt = dt
event.type = type
return event
def to_dt(val):
return ndict({'dt': val})
def compare_by_dt_source_id(x,y):
if x.dt < y.dt:
return -1
elif x.dt > y.dt:
return 1
elif x.source_id < y.source_id:
return -1
elif x.source_id > y.source_id:
return 1
else:
return 0
+123
View File
@@ -0,0 +1,123 @@
import os
import uuid
import msgpack
import pytz
from unittest2 import TestCase
from pymongo import Connection, ASCENDING
from itertools import izip, izip_longest
from datetime import datetime, timedelta
from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen
from zipline.gens.utils import stringify_args, assert_datasource_protocol,\
assert_trade_protocol, mock_raw_event
import zipline.protocol as zp
def mock_raw_event(sid, dt):
event = {
'sid' : sid,
'dt' : dt,
'price' : 1.0,
'volume' : 1
}
return event
mongo_conn_args = {
'mongodb_host': 'localhost',
'mongodb_port': 27017,
}
class TempMongo(object):
def __enter__(self):
self.conn = Connection(mongo_conn_args['mongodb_host'],
mongo_conn_args['mongodb_port'])
temp_id = 'qexec_test_id'
self.db = self.conn[temp_id]
return self
def __exit__(self, type, value, traceback):
self.conn.drop_database(self.db.name)
class TestMongoDataGenerator(TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def test_create_pymongo_iterator(self):
with TempMongo() as temp_mongo:
db = temp_mongo.db
coll = db.test
coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])
for i in xrange(100):
# sid = 1, dt ranging from 0 to 99
coll.insert(mock_raw_event(1, i))
start_date = 20
end_date = 50
filter = {'sid' : [1]}
args = (coll, filter, start_date, end_date)
cursor = create_pymongo_iterator(*args)
# We filter to only get dt's between 20 and 50
expected = (mock_raw_event(1, i) for i in xrange(20, 51))
# Assert that our iterator returns the expected values.
for cursor_event, expected_event in izip_longest(cursor, expected):
del cursor_event['_id']
# Easiest way to convert unicode to strings.
cursor_event = msgpack.loads(msgpack.dumps(cursor_event))
assert expected_event.keys() == cursor_event.keys()
assert expected_event.values() == cursor_event.values()
def test_MongoTradeHistoryGen(self):
with TempMongo() as temp_mongo:
db = temp_mongo.db
coll = db.test
coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)])
start_date = datetime(year = 2012,month=6,day=5,hour=0)
delta = timedelta(hours = 1)
for i in xrange(100):
# sid = 1, dt's increasing an hour at a time from start
time = start_date + i * delta
coll.insert(mock_raw_event(1, time))
# Halfway through the events we added to db.
end_date = start_date + delta * 50
filter = {'sid' : [1]}
args = (coll, filter, start_date, end_date)
db_gen = MongoTradeHistoryGen(*args)
expected_times = (start_date + i * delta for i in xrange(51))
expected_events = (mock_raw_event(1, t) for t in expected_times)
# DB events should match the expected events for price, dt, volume,
# and sid. They should also conform to the trade frame protocol.
for db, expected in izip_longest(db_gen, expected_events):
expected['dt'] = expected['dt'].replace(tzinfo = pytz.utc)
# Check that our output meets the trade protocol.
assert_trade_protocol(db)
# Check that our output matches expectations
for field in iter(['sid', 'dt', 'price', 'volume']):
assert db[field] == expected[field]
# Expected output of stringify_args:
assert db['source_id'] == \
'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc'
+29 -4
View File
@@ -2,11 +2,22 @@ import pytz
import numbers
from hashlib import md5
from datetime import datetime
from datetime import datetime, timedelta
from zipline import ndict
from zipline.protocol import DATASOURCE_TYPE
def mock_raw_event(sid, dt):
event = {
'sid' : sid,
'dt' : dt,
'price' : 1.0,
'volume' : 1
}
return event
def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100):
return (start + i * delta for i in xrange(n))
def stringify_args(*args, **kwargs):
"""Define a unique string for any set of representable args."""
@@ -18,16 +29,19 @@ def stringify_args(*args, **kwargs):
hasher.update(combined)
return hasher.hexdigest()
def assert_datasource_protocol(event):
"""Assert that an event meets the protocol for datasource outputs."""
assert isinstance(event, ndict)
assert isinstance(event.source_id, basestring)
assert isinstance(event.dt, datetime)
assert event.dt.tzinfo == pytz.utc
assert event.type in DATASOURCE_TYPE
# Done packets have no dt.
if not event.type == DATASOURCE_TYPE.DONE:
assert isinstance(event.dt, datetime)
assert event.dt.tzinfo == pytz.utc
def assert_trade_protocol(event):
"""Assert that an event meets the protocol for datasource TRADE outputs."""
assert_datasource_protocol(event)
@@ -37,4 +51,15 @@ def assert_trade_protocol(event):
assert isinstance(event.sid, int)
assert isinstance(event.price, numbers.Real)
assert isinstance(event.volume, numbers.Integral)
assert isinstance(event.dt, datetime)
def assert_datasource_unframe_protocol(event):
"""Assert that an event is valid output of zp.DATASOURCE_UNFRAME."""
assert isinstance(event, ndict)
assert isinstance(event.source_id, basestring)
assert event.type in DATASOURCE_TYPE
assert event.has_key('dt')
def assert_feed_protocol(event):
pass
+17 -5
View File
@@ -227,6 +227,7 @@ def DATASOURCE_FRAME(event):
- *ds_type* a string denoting the datasource type. Must be on of:
- TRADE
- DONE
- (others to follow soon)
- *payload* a msgpack string carrying the payload for the frame
@@ -235,19 +236,26 @@ def DATASOURCE_FRAME(event):
assert isinstance(event.type, int), 'Unexpected type %s' % (event.type)
#datasources will send sometimes send empty msgs to feel gaps
if len(event.keys()) == 2:
if (event.type == DATASOURCE_TYPE.EMPTY):
return msgpack.dumps(tuple([
event.type,
event.source_id,
DATASOURCE_TYPE.EMPTY
"EMPTY"
]))
if(event.type == DATASOURCE_TYPE.TRADE):
elif(event.type == DATASOURCE_TYPE.TRADE):
return msgpack.dumps(tuple([
event.type,
event.source_id,
TRADE_FRAME(event)
]))
elif(event.type == DATASOURCE_TYPE.DONE):
return msgpack.dumps(tuple([
event.type,
event.source_id,
"DONE"
]))
else:
raise INVALID_DATASOURCE_FRAME(str(event))
@@ -259,8 +267,9 @@ def DATASOURCE_UNFRAME(msg):
Returns a dict containing at least:
- source_id
- type
- source_id: instance-unique string
- type: datasource type
- dt: None, 'DONE' or a datetime object
other properties are added based on the datasource type:
@@ -282,6 +291,8 @@ def DATASOURCE_UNFRAME(msg):
child_value = ndict({'dt':None})
elif(ds_type == DATASOURCE_TYPE.TRADE):
child_value = TRADE_UNFRAME(payload)
elif(ds_type == DATASOURCE_TYPE.DONE):
child_value = ndict({'dt' : 'DONE'})
else:
raise INVALID_DATASOURCE_FRAME(msg)
@@ -593,6 +604,7 @@ def tuple_to_date(date_tuple):
DATASOURCE_TYPE = Enum(
'TRADE',
'EMPTY',
'DONE'
)