diff --git a/zipline/protocol.py b/zipline/protocol.py index 01b71c59..2b60fbc9 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -69,10 +69,12 @@ def FrameExceptionFactory(name): def __init__(self, got): self.got = got def __str__(self): - return "Invalid {framcls} Frame: {got}".format( + return "Invalid {framecls} Frame: {got}".format( framecls = name, got = self.got, ) + + return InvalidFrame class namedict(object): """ @@ -97,8 +99,14 @@ class namedict(object): self.__dict__[key] = value def merge(self, other_nd): - assert isinstance(namedict, other_nd) + assert isinstance(other_nd, namedict) self.__dict__.update(other_nd.__dict__) + + def __repr__(self): + return "namedict: " + str(self.__dict__) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ # ================ # Control Protocol @@ -166,9 +174,9 @@ COMPONENT_STATE = Enum( # Datasource Protocol # ================== -INVALID_DATASOURCE_FRAME = FrameExceptionFactory('ORDER') +INVALID_DATASOURCE_FRAME = FrameExceptionFactory('DATASOURCE') -def DATASOURCE_FRAME(ds_id, ds_type, payload): +def DATASOURCE_FRAME(event): """ wraps any datasource payload with id and type, so that unpacking may choose the write UNFRAME for the payload. @@ -178,11 +186,13 @@ def DATASOURCE_FRAME(ds_id, ds_type, payload): (others to follow soon) ::payload:: a msgpack string carrying the payload for the frame """ - assert isinstance(ds_id, basestring) - assert isinstance(ds_type, basestring) - assert isinstance(payload, basestring) - return msgpack.dumps(tuple([ds_id, ds_type, payload])) - + assert isinstance(event.source_id, basestring) + assert isinstance(event.type, basestring) + if(event.type == "TRADE"): + return msgpack.dumps(tuple([event.type, TRADE_FRAME(event)])) + else: + raise INVALID_DATASOURCE_FRAME(str(event)) + def DATASOURCE_UNFRAME(msg): """ extracts payload, and calls correct UNFRAME method based on the datasource type passed along @@ -197,11 +207,12 @@ def DATASOURCE_UNFRAME(msg): - dt - a datetime object """ try: - ds_id, ds_type, payload = msgpack.loads(msg) + qutil.LOGGER.info("unpacking {msg}".format(msg = msg)) + ds_type, payload = msgpack.loads(msg) + qutil.LOGGER.info("unpacked a datasource frame! {ds_type} - {payload}".format(ds_type=ds_type, payload=payload) ) + assert isinstance(ds_type, basestring) if(ds_type == "TRADE"): - result = {'source_id' : ds_id, 'type' : ds_type} - result.update(TRADE_UNFRAME(payload)) - return namedict(result) + return TRADE_UNFRAME(payload) else: raise INVALID_DATASOURCE_FRAME(msg) @@ -225,7 +236,6 @@ def FEED_FRAME(event): source_id = event.source_id ds_type = event.type pack_date(event) - del(event.__dict__['dt']) payload = event.__dict__ return msgpack.dumps(payload) @@ -236,11 +246,11 @@ def FEED_UNFRAME(msg): assert isinstance(payload, dict) rval = namedict(payload) unpack_date(rval) - return namedict(rval) + return rval except TypeError: - raise INVALID_TRADE_FRAME(msg) + raise INVALID_FEED_FRAME(msg) except ValueError: - raise INVALID_TRADE_FRAME(msg) + raise INVALID_FEED_FRAME(msg) # ================== # Transform Protocol @@ -262,7 +272,8 @@ def TRANSFORM_UNFRAME(msg): name, value = msgpack.loads(msg) #TODO: anything we can do to assert more about the content of the dict? assert isinstance(name, basestring) - assert payload.has_key('value') + if(name == "PASSTHROUGH"): + value = FEED_UNFRAME(value) return namedict({name : value}) except TypeError: raise INVALID_TRANSFORM_FRAME(msg) @@ -282,8 +293,8 @@ def MERGE_FRAME(event): - type """ assert isinstance(event, namedict) - source_id = event.source_id - ds_type = event.type + assert isinstance(event.dt, datetime.datetime) + pack_date(event) payload = event.__dict__ return msgpack.dumps(payload) @@ -292,11 +303,15 @@ def MERGE_UNFRAME(msg): payload = msgpack.loads(msg) #TODO: anything we can do to assert more about the content of the dict? assert isinstance(payload, dict) - return namedict(payload) + payload = namedict(payload) + assert isinstance(payload.epoch, numbers.Integral) + assert isinstance(payload.micros, numbers.Integral) + unpack_date(payload) + return payload except TypeError: - raise INVALID_TRADE_FRAME(msg) + raise INVALID_MERGE_FRAME(msg) except ValueError: - raise INVALID_TRADE_FRAME(msg) + raise INVALID_MERGE_FRAME(msg) # ================== @@ -316,26 +331,32 @@ def TRADE_FRAME(event): - sid -- the security id - price -- float of the price printed for the trade - volume -- int for shares in the trade + - dt -- datetime for the trade """ assert isinstance(event, namedict) + assert isinstance(event.source_id, basestring) + assert event.type == "TRADE" assert isinstance(event.sid, int) assert isinstance(event.price, float) assert isinstance(event.volume, int) pack_date(event) - payload = msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros])) - return DATASOURCE_FRAME(ds_id, "TRADE", payload) + qutil.LOGGER.info("event is: {event}".format(event=event.__dict__)) + return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id])) def TRADE_UNFRAME(msg): try: - sid, price, volume, epoch, micros = msgpack.loads(msg) + qutil.LOGGER.info("about to unpack TRADE: {trade}".format(trade=msg)) + sid, price, volume, epoch, micros, source_type, source_id = msgpack.loads(msg) + assert isinstance(sid, int) assert isinstance(price, float) assert isinstance(volume, int) assert isinstance(epoch, numbers.Integral) assert isinstance(micros, numbers.Integral) - rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'dt' : dt, 'epoch' : epoch, 'micros' : micros}) + rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id}) unpack_date(rval) + qutil.LOGGER.info("unpacked Trade: {trade}".format(trade=rval.__dict__)) return rval except TypeError: raise INVALID_TRADE_FRAME(msg) @@ -371,18 +392,18 @@ def ORDER_UNFRAME(msg): def pack_date(event): assert isinstance(event.dt, datetime.datetime) assert event.dt.tzinfo == pytz.utc #utc only please - epoch = long(dt.strftime('%s')) + epoch = long(event.dt.strftime('%s')) event['epoch'] = epoch event['micros'] = event.dt.microsecond del(event.__dict__['dt']) return event def unpack_date(payload): - assert isinstance(payload['epoch'], numbers.Integral) - assert isinstance(payload['micros'], numbers.Integral) - dt = datetime.datetime.fromtimestamp(payload['epoch']) - dt.replace(microsecond = payload['micros'], tzinfo = pytz.utc) - del(payload['epoch']) - del(payload['micros']) + assert isinstance(payload.epoch, numbers.Integral) + assert isinstance(payload.micros, numbers.Integral) + dt = datetime.datetime.fromtimestamp(payload.epoch) + dt = dt.replace(microsecond = payload.micros, tzinfo = pytz.utc) + del(payload.__dict__['epoch']) + del(payload.__dict__['micros']) payload['dt'] = dt return payload \ No newline at end of file diff --git a/zipline/sources.py b/zipline/sources.py index 55587b3a..7c8e7a59 100644 --- a/zipline/sources.py +++ b/zipline/sources.py @@ -10,11 +10,11 @@ import zipline.protocol as zp class TradeDataSource(zm.DataSource): - def send(event): + def send(self, event): """ :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME` :rtype: None """ - message = zp.TRADE_FRAME(self.get_id, sid, price, volume, dt) + message = zp.TRADE_FRAME(self.get_id, event) self.data_socket.send(message) class RandomEquityTrades(TradeDataSource): diff --git a/zipline/test/factory.py b/zipline/test/factory.py index 691f909b..a30979e5 100644 --- a/zipline/test/factory.py +++ b/zipline/test/factory.py @@ -61,6 +61,8 @@ def getCodeFromFile(filename): def create_trade(sid, price, amount, datetime): row = {} + row['source_id'] = "test_factory" + row['type'] = "TRADE" row['sid'] = sid row['dt'] = datetime row['price'] = price diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index f156d68f..e0ee46c4 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -8,6 +8,7 @@ import zipline.test.factory as factory import zipline.util as qutil import zipline.db as db import zipline.finance.risk as risk +import zipline.protocol as zp from zipline.test.client import TestTradingClient from zipline.test.dummy import ThreadPoolExecutorMixin @@ -16,6 +17,50 @@ from zipline.sources import SpecificEquityTrades class FinanceTestCase(ThreadPoolExecutorMixin, TestCase): + def test_trade_protocol(self): + trades = factory.create_trade_history(133, + [10.0,10.0,10.0,10.0], + [100,100,100,100], + datetime.datetime.strptime("02/15/2012","%m/%d/%Y"), + datetime.timedelta(days=1)) + for trade in trades: + msg = zp.TRADE_FRAME("fake_source", zp.namedict(trade)) + recovered_trade = zp.DATASOURCE_UNFRAME(msg) + self.assertTrue(recovered_trade.type == "TRADE") + self.assertTrue(recovered_trade.source_id == "fake_source") + del(recovered_trade.__dict__['type']) + del(recovered_trade.__dict__['source_id']) + self.assertEqual(zp.namedict(trade), recovered_trade) + + def test_trade_feed_protocol(self): + trades = factory.create_trade_history(133, + [10.0,10.0,10.0,10.0], + [100,100,100,100], + datetime.datetime.strptime("02/15/2012","%m/%d/%Y"), + datetime.timedelta(days=1)) + for trade in trades: + #simulate data source sending frame + msg = zp.DATASOURCE_FRAME(zp.namedict(trade)) + #feed unpacking frame + recovered_trade = zp.DATASOURCE_UNFRAME(msg) + #feed sending frame + feed_msg = zp.FEED_FRAME(recovered_trade) + #transform unframing + recovered_feed = zp.FEED_UNFRAME(feed_msg) + #do a transform + trans_msg = zp.TRANSFORM_FRAME('helloworld', 2345.6) + #simulate passthrough transform -- passthrough shouldn't even unpack the msg, just resend. + passthrough_msg = zp.TRANSFORM_FRAME('PASSTHROUGH', feed_msg) + #merge unframes transform and passthrough + trans_recovered = zp.TRANSFORM_UNFRAME(trans_msg) + pt_recovered = zp.TRANSFORM_UNFRAME(passthrough_msg) + #simulated merge + pt_recovered.PASSTHROUGH.merge(trans_recovered) + #frame the merged event + merged_msg = zp.MERGE_FRAME(pt_recovered.PASSTHROUGH) + #unframe the merge and validate values + event = zp.MERGE_UNFRAME(merged_msg) + def test_trading_calendar(self): known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y") known_holiday = datetime.datetime.strptime("02/20/2012", "%m/%d/%Y") #president's day