diff --git a/zipline/gens/__init__.py b/zipline/gens/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/zipline/gens/feed.py b/zipline/gens/feed.py new file mode 100644 index 00000000..cdc2f7b7 --- /dev/null +++ b/zipline/gens/feed.py @@ -0,0 +1,130 @@ +""" +Generator version of Feed. +""" + +import pytz +import logbook +import pymongo +import types + +from pymongo import ASCENDING +from datetime import datetime, timedelta +from collections import deque, defaultdict + +from zipline import ndict +from zipline.gens.utils import stringify_args, assert_datasource_protocol, \ + assert_trade_protocol, assert_datasource_unframe_protocol + +import zipline.protocol as zp + +def FeedGen(stream_in, source_ids): + """ + A generator that takes a generator and a list of source_ids. We + maintain an internal queue for each id in source_ids. While we + have messages pending from all sources, we pull the earliest + message and yield it. + """ + + assert isinstance(stream_in, types.GeneratorType) + assert isinstance(source_ids, list) + + # Set up an internal queue for each expected source. + sources = {} + for id in source_ids: + assert isinstance(id, basestring), "Bad source_id %s" % source_id + sources[id] = deque() + + namestring = "FeedGen" + stringify_args(source_ids) + + # Process incoming streams. + + for message in stream_in: + # Incoming messages should be the output of DATASOURCE_UNFRAME. + assert_datasource_unframe_protocol(message), \ + "Bad message in FeedGen: %s" % message + + # Only allow messages from sources we expect. + assert message.source_id in sources, "Unexpected source: %s" % message + + sources[message.source_id].append(message) + + # Only pop messages when we have a pending message from + # all datasources. Stop if all sources have signalled done. + + while full(sources) and not done(sources): + message = pop_oldest(sources) + yield message + + # We should have only a done message left in each queue. + for queue in sources.itervalues(): + assert len(queue) == 1, "Bad queue in FeedGen on exit: %s" % queue + assert queue[0].dt == "DONE", \ + "Bad last message in FeedGen on exit: %s" % queue + +def full(sources): + """ + Feed is full when every internal queue has at least one message. Note that + this include DONE messages, so done(sources) is True only if full(sources). + """ + assert isinstance(sources, dict) + return all( (queue_is_full(source) for source in sources.itervalues()) ) + +def queue_is_full(queue): + assert isinstance(queue, deque) + return len(queue) > 0 + +def done(sources): + """Feed is done when all internal queues have only a "DONE" message.""" + assert isinstance(sources, dict) + return all( (queue_is_done(source) for source in sources.itervalues()) ) + +def queue_is_done(queue): + assert isinstance(queue, deque) + if len(queue) == 0: + return False + if queue[0].dt == "DONE": + assert len(queue) == 1, "Message after DONE in FeedGen: %s" % queue + return True + else: + return False + +def pop_oldest(sources): + + oldest_event = None + + # Iterate over the dict, checking internal queues for the oldest + # pending event. + + for queue in sources.itervalues(): + current_event = queue[0] + # Skip queues that are done. + if current_event.dt == "DONE": + continue + # Any event is older than nothing. + elif oldest_event == None: + oldest_event = current_event + # Keep the older event. Break ties by source_id. This will + # trip an assert if we have duplicate sources. + else: + oldest_event = older(oldest_event, current_event) + + # Pop the oldest event we found from its queue and return it. + return sources[oldest_event.source_id].pop() + +# Return the event with the older timestamp. Break ties by source_id. +def older(oldest, current): + assert isinstance(oldest, ndict) + assert isinstance(oldest, ndict) + + # Try to compare by dt. + if oldest.dt < current.dt: + return oldest + elif oldest.dt > current.dt: + return current + # Break ties by source_id. + elif oldest.source_id < current.source_id: + return oldest + elif oldest.source_id > current.source_id: + return current + else: + assert False, "Duplicate event" diff --git a/zipline/gens/mongods.py b/zipline/gens/mongods.py index 303fd465..83ca5571 100644 --- a/zipline/gens/mongods.py +++ b/zipline/gens/mongods.py @@ -10,15 +10,50 @@ from pymongo import ASCENDING from datetime import datetime, timedelta from zipline import ndict -from zipline.gens.utils import stringify_args, assert_datasource_protocol +from zipline.gens.utils import stringify_args, assert_datasource_protocol, \ + assert_trade_protocol import zipline.protocol as zp +def MongoTradeHistoryGen(collection, filter, start_date, end_date): + """A generator that takes a pymongo Collection object, a list of + filters, a start date and an end_date and yields ndicts containing + the results of a query to its collection with the given filter, + start, and end. The output is also packaged with a unique + source_id string for downstream sorting + """ + + assert isinstance(collection, pymongo.collection.Collection) + assert isinstance(filter, dict) + assert isinstance(start_date, (datetime)) + assert isinstance(end_date, (datetime)) + + # Set up internal iterator. This outputs raw dictionaries. + iterator = create_pymongo_iterator(collection, filter, start_date, end_date) + + # Create unique identifier string that can be used to break + # sorting ties deterministically + argstring = stringify_args(collection, filter, start_date, end_date) + source_id = "MongoTradeHistoryGen" + argstring + + # All datasources + for event in iterator: + # Construct a new event that fulfills the datasource protocol. + event['type'] = zp.DATASOURCE_TYPE.TRADE + event['dt'] = event['dt'].replace(tzinfo=pytz.utc) + event['source_id'] = source_id + + payload = ndict(event) + assert_trade_protocol(payload) + yield payload + def create_pymongo_iterator(collection, filter, start_date, end_date): """ - See the comments on :py:class:`zipline.messaging.DataSource` for - expected content of filter. Spec must adhere to that definition. - Returns an iterator that spits out raw objects loaded from MongoDB. + Returns an iterator that spits out raw objects loaded from a + MongoDB collection. + + See the comments on :py:class:`zipline.messaging.DataSource` + for expected content of filter. """ log = logbook.Logger("MongoDBQuery") @@ -59,40 +94,7 @@ def create_pymongo_iterator(collection, filter, start_date, end_date): log.info("MongoDataSource iterator ready") return iterator - - -def MongoTradeHistoryGen(collection, filter, start_date, end_date): - """A generator that takes a pymongo Collection object, a list of - filters, a start date and an end_date and yields ndicts containing - the results of a query to its collection with the given filter, - start, and end. The output is also packaged with a unique - source_id string for downstream sorting - """ - - assert isinstance(collection, pymongo.collection.Collection) - assert isinstance(filter, dict) - assert isinstance(start_date, (datetime)) - assert isinstance(end_date, (datetime)) - - # Set up internal iterator. This outputs raw dictionaries. - iterator = create_pymongo_iterator(collection, filter, start_date, end_date) - - # Create unique identifier string that can be used to break - # sorting ties deterministically - argstring = stringify_args(collection, filter, start_date, end_date) - source_id = "MongoTradeHistoryGen" + argstring - - # All datasources - for event in iterator: - # Construct a new event that fulfills the datasource protocol. - event['type'] = zp.DATASOURCE_TYPE.TRADE - event['dt'] = event['dt'].replace(tzinfo=pytz.utc) - event['source_id'] = source_id - payload = ndict(event) - assert_datasource_protocol(payload) - yield payload - diff --git a/zipline/gens/scratch.py b/zipline/gens/scratch.py index 931ddec9..2de87e44 100644 --- a/zipline/gens/scratch.py +++ b/zipline/gens/scratch.py @@ -1,14 +1,36 @@ -class gen_wrapper(object): +# Inside Client +def __init__(self, addresses, ...): + self.pull_socket = ... - def __init__(self, val): - self.val = val - self.iterator = iter(xrange(self.val)) - def reset_iter(self): - self.val + self.control_socket ... - def __iter__(self): - return self.iterator +def run(self): + for message in gen_from_pull(self.pull_socket): + #Do things with messages. + heartbeat() + + signal_done() + sys.exit(0) - def next(): - return self.iterator.next() +# Inside Merge +def __init__(self, addresses, source_ids ...): + self.poller = ... # Poller on multiple xforms, single socket. + + self.processor = ... # Generator that + + self.push_socket = ... # Outbound socket + +def run(self): + + incoming = gen_from_poll(self.poller)# Receives messages from all xforms. + + processed = self.processor(incoming, source_ids) # Maintains internal queues and merges. + + for message in self.processed: + heartbeat() + self.push_socket.send(message) + +# Inside + + diff --git a/zipline/gens/test_feed.py b/zipline/gens/test_feed.py new file mode 100644 index 00000000..1605b6aa --- /dev/null +++ b/zipline/gens/test_feed.py @@ -0,0 +1,190 @@ +import os + +import uuid +import msgpack +import pytz + +from unittest2 import TestCase +from pymongo import Connection, ASCENDING +from itertools import izip, izip_longest, permutations, cycle +from datetime import datetime, timedelta +from collections import deque + +from zipline import ndict +from zipline.gens.feed import FeedGen, full, done, queue_is_full,queue_is_done,\ + pop_oldest +from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ + assert_trade_protocol, date_gen + +import zipline.protocol as zp + +class FeedHelperTestCase(TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_individual_queue_logic(self): + queue = deque() + # Empty queues are neither done nor full. + assert not queue_is_full(queue) + assert not queue_is_done(queue) + + queue.append(to_dt('foo')) + assert queue_is_full(queue) + assert not queue_is_done(queue) + + + queue.appendleft(to_dt('DONE')) + assert queue_is_full(queue) + + # Checking done when we have a message after done will trip an assert. + self.assertRaises(AssertionError, queue_is_done, queue) + + queue.pop() + assert queue_is_full(queue) + assert queue_is_done(queue) + + def test_sources_logic(self): + sources = {} + ids = ['a', 'b', 'c'] + for id in ids: + sources[id] = deque() + + assert not full(sources) + assert not done(sources) + + # All sources must have a message to be full/done + sources['a'].append(to_dt("datetime")) + assert not full(sources) + assert not done(sources) + sources['a'].pop() + + for id in ids: + sources[id].append(to_dt("datetime")) + + assert full(sources) + assert not done(sources) + + for id in ids: + sources[id].appendleft(to_dt("DONE")) + + # ["DONE", message] will trip an assert in queue_is_done. + assert full(sources) + self.assertRaises(AssertionError, done, sources) + + for id in ids: + sources[id].pop() + + assert full(sources) + assert done(sources) + +class FeedGenTestCase(TestCase): + + def setUp(self): + pass + + + def tearDown(self): + pass + + def run_FeedGen(self, events, expected, source_ids): + """ + Take a list of events, their source_ids, and an expected sorting. + Assert that FeedGen's output agrees with expected. + """ + feed_gen = FeedGen(events, source_ids) + assert list(feed_gen) == expected + + + def test_single_source(self): + source_ids = ['a'] + # 100 events, increasing by a minute at a time. + type = zp.DATASOURCE_TYPE.TRADE + dates = list(date_gen(n = 1)) + dates.append("DONE") + + # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] + event_args = zip(cycle(source_ids), iter(dates), cycle([type])) + + # Turn event_args into proper events. + events = [mock_data_unframe(*args) for args in event_args] + + # We don't expected Feed to yield the last event. + expected = events[:-1] + + event_gen = (e for e in events) + + self.run_FeedGen(event_gen, expected, source_ids) + + def test_multi_source_interleaved(self): + source_ids = ['a', 'b'] + type = zp.DATASOURCE_TYPE.TRADE + + # Set up source 'a'. Outputs 3 events with 2 minute deltas. + delta_a = timedelta(minutes = 2) + dates_a = list(date_gen(delta = delta_a, n = 3)) + dates_a.append("DONE") + + events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) + events_a = [mock_data_unframe(*args) for args in events_a_args] + event_gen_a = (e for e in events_a) + + # Set up source 'b'. Outputs 4 events with 1 minute deltas. + delta_b = timedelta(minutes = 1) + dates_b = list(date_gen(delta = delta_b, n = 4)) + dates_b.append("DONE") + + events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) + events_b = [mock_data_unframe(*args) for args in events_b_args] + event_gen_b = (e for e in events_b) + + + # The expected output is all non-DONE events in both a and b, + # sorted first by dt and then by source_id. + non_dones = events_a[:-1] + events_b[:-1] + expected = sorted(non_dones, compare_by_dt_source_id) + + import nose.tools; nose.tools.set_trace() + self.run_FeedGen(event_gen, expected, source_ids) + + +# def test_FeedGen_consistency(self): + +# source_ids = ['a', 'b'] +# multiplied = source_ids * 5 +# perms = itertools.permutations(multiplied, 10) +# self.type = zp.DATASOURCE_TYPE.TRADE + +# self.events = (mock_data_unframe(id, + + +def mock_data_unframe(source_id, dt, type): + event = ndict() + event.source_id = source_id + event.dt = dt + event.type = type + return event + +def to_dt(val): + return ndict({'dt': val}) + +def compare_by_dt_source_id(x,y): + if x.dt < y.dt: + return -1 + elif x.dt > y.dt: + return 1 + + elif x.source_id < y.source_id: + return -1 + elif x.source_id > y.source_id: + return 1 + + else: + return 0 + + + + diff --git a/zipline/gens/test_mongods.py b/zipline/gens/test_mongods.py new file mode 100644 index 00000000..7260fd7d --- /dev/null +++ b/zipline/gens/test_mongods.py @@ -0,0 +1,123 @@ +import os + +import uuid +import msgpack +import pytz + +from unittest2 import TestCase +from pymongo import Connection, ASCENDING +from itertools import izip, izip_longest +from datetime import datetime, timedelta + +from zipline.gens.mongods import create_pymongo_iterator, MongoTradeHistoryGen +from zipline.gens.utils import stringify_args, assert_datasource_protocol,\ + assert_trade_protocol, mock_raw_event + +import zipline.protocol as zp + +def mock_raw_event(sid, dt): + event = { + 'sid' : sid, + 'dt' : dt, + 'price' : 1.0, + 'volume' : 1 + } + return event + +mongo_conn_args = { + 'mongodb_host': 'localhost', + 'mongodb_port': 27017, +} + +class TempMongo(object): + + def __enter__(self): + self.conn = Connection(mongo_conn_args['mongodb_host'], + mongo_conn_args['mongodb_port']) + + temp_id = 'qexec_test_id' + + self.db = self.conn[temp_id] + + return self + + def __exit__(self, type, value, traceback): + self.conn.drop_database(self.db.name) + +class TestMongoDataGenerator(TestCase): + + def setUp(self): + pass + def tearDown(self): + pass + + def test_create_pymongo_iterator(self): + + with TempMongo() as temp_mongo: + db = temp_mongo.db + coll = db.test + coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)]) + + for i in xrange(100): + # sid = 1, dt ranging from 0 to 99 + coll.insert(mock_raw_event(1, i)) + + start_date = 20 + end_date = 50 + filter = {'sid' : [1]} + args = (coll, filter, start_date, end_date) + + cursor = create_pymongo_iterator(*args) + # We filter to only get dt's between 20 and 50 + expected = (mock_raw_event(1, i) for i in xrange(20, 51)) + + # Assert that our iterator returns the expected values. + for cursor_event, expected_event in izip_longest(cursor, expected): + del cursor_event['_id'] + # Easiest way to convert unicode to strings. + cursor_event = msgpack.loads(msgpack.dumps(cursor_event)) + assert expected_event.keys() == cursor_event.keys() + assert expected_event.values() == cursor_event.values() + + def test_MongoTradeHistoryGen(self): + + with TempMongo() as temp_mongo: + db = temp_mongo.db + coll = db.test + coll.ensure_index([('dt', ASCENDING), ('sid', ASCENDING)]) + + start_date = datetime(year = 2012,month=6,day=5,hour=0) + delta = timedelta(hours = 1) + + for i in xrange(100): + # sid = 1, dt's increasing an hour at a time from start + time = start_date + i * delta + coll.insert(mock_raw_event(1, time)) + + # Halfway through the events we added to db. + end_date = start_date + delta * 50 + + filter = {'sid' : [1]} + args = (coll, filter, start_date, end_date) + db_gen = MongoTradeHistoryGen(*args) + + expected_times = (start_date + i * delta for i in xrange(51)) + expected_events = (mock_raw_event(1, t) for t in expected_times) + + # DB events should match the expected events for price, dt, volume, + # and sid. They should also conform to the trade frame protocol. + + for db, expected in izip_longest(db_gen, expected_events): + expected['dt'] = expected['dt'].replace(tzinfo = pytz.utc) + # Check that our output meets the trade protocol. + assert_trade_protocol(db) + + # Check that our output matches expectations + for field in iter(['sid', 'dt', 'price', 'volume']): + assert db[field] == expected[field] + + # Expected output of stringify_args: + assert db['source_id'] == \ + 'MongoTradeHistoryGen983a27fd0710414239a5cde71ef5a8fc' + + diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 27454fb7..fd1c8464 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -2,11 +2,22 @@ import pytz import numbers from hashlib import md5 -from datetime import datetime +from datetime import datetime, timedelta from zipline import ndict from zipline.protocol import DATASOURCE_TYPE +def mock_raw_event(sid, dt): + event = { + 'sid' : sid, + 'dt' : dt, + 'price' : 1.0, + 'volume' : 1 + } + return event + +def date_gen(start = datetime(2012, 6, 6, 0), delta = timedelta(minutes = 1), n = 100): + return (start + i * delta for i in xrange(n)) def stringify_args(*args, **kwargs): """Define a unique string for any set of representable args.""" @@ -18,16 +29,19 @@ def stringify_args(*args, **kwargs): hasher.update(combined) return hasher.hexdigest() - def assert_datasource_protocol(event): """Assert that an event meets the protocol for datasource outputs.""" assert isinstance(event, ndict) assert isinstance(event.source_id, basestring) - assert isinstance(event.dt, datetime) - assert event.dt.tzinfo == pytz.utc assert event.type in DATASOURCE_TYPE + # Done packets have no dt. + if not event.type == DATASOURCE_TYPE.DONE: + assert isinstance(event.dt, datetime) + assert event.dt.tzinfo == pytz.utc + + def assert_trade_protocol(event): """Assert that an event meets the protocol for datasource TRADE outputs.""" assert_datasource_protocol(event) @@ -37,4 +51,15 @@ def assert_trade_protocol(event): assert isinstance(event.sid, int) assert isinstance(event.price, numbers.Real) assert isinstance(event.volume, numbers.Integral) + assert isinstance(event.dt, datetime) + +def assert_datasource_unframe_protocol(event): + """Assert that an event is valid output of zp.DATASOURCE_UNFRAME.""" + assert isinstance(event, ndict) + assert isinstance(event.source_id, basestring) + assert event.type in DATASOURCE_TYPE + assert event.has_key('dt') + +def assert_feed_protocol(event): + pass diff --git a/zipline/protocol.py b/zipline/protocol.py index 98245981..cf1aa8fd 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -227,6 +227,7 @@ def DATASOURCE_FRAME(event): - *ds_type* a string denoting the datasource type. Must be on of: - TRADE + - DONE - (others to follow soon) - *payload* a msgpack string carrying the payload for the frame @@ -235,19 +236,26 @@ def DATASOURCE_FRAME(event): assert isinstance(event.type, int), 'Unexpected type %s' % (event.type) #datasources will send sometimes send empty msgs to feel gaps - if len(event.keys()) == 2: + if (event.type == DATASOURCE_TYPE.EMPTY): return msgpack.dumps(tuple([ event.type, event.source_id, - DATASOURCE_TYPE.EMPTY + "EMPTY" ])) - if(event.type == DATASOURCE_TYPE.TRADE): + elif(event.type == DATASOURCE_TYPE.TRADE): return msgpack.dumps(tuple([ event.type, event.source_id, TRADE_FRAME(event) ])) + + elif(event.type == DATASOURCE_TYPE.DONE): + return msgpack.dumps(tuple([ + event.type, + event.source_id, + "DONE" + ])) else: raise INVALID_DATASOURCE_FRAME(str(event)) @@ -259,8 +267,9 @@ def DATASOURCE_UNFRAME(msg): Returns a dict containing at least: - - source_id - - type + - source_id: instance-unique string + - type: datasource type + - dt: None, 'DONE' or a datetime object other properties are added based on the datasource type: @@ -282,6 +291,8 @@ def DATASOURCE_UNFRAME(msg): child_value = ndict({'dt':None}) elif(ds_type == DATASOURCE_TYPE.TRADE): child_value = TRADE_UNFRAME(payload) + elif(ds_type == DATASOURCE_TYPE.DONE): + child_value = ndict({'dt' : 'DONE'}) else: raise INVALID_DATASOURCE_FRAME(msg) @@ -593,6 +604,7 @@ def tuple_to_date(date_tuple): DATASOURCE_TYPE = Enum( 'TRADE', 'EMPTY', + 'DONE' )