added tests for the protocol for trades and for end to end including transforms and merge.

This commit is contained in:
fawce
2012-02-28 16:26:22 -05:00
parent f1a7e41d1f
commit eb4071be9b
4 changed files with 104 additions and 36 deletions
+55 -34
View File
@@ -69,10 +69,12 @@ def FrameExceptionFactory(name):
def __init__(self, got):
self.got = got
def __str__(self):
return "Invalid {framcls} Frame: {got}".format(
return "Invalid {framecls} Frame: {got}".format(
framecls = name,
got = self.got,
)
return InvalidFrame
class namedict(object):
"""
@@ -97,8 +99,14 @@ class namedict(object):
self.__dict__[key] = value
def merge(self, other_nd):
assert isinstance(namedict, other_nd)
assert isinstance(other_nd, namedict)
self.__dict__.update(other_nd.__dict__)
def __repr__(self):
return "namedict: " + str(self.__dict__)
def __eq__(self, other):
return self.__dict__ == other.__dict__
# ================
# Control Protocol
@@ -166,9 +174,9 @@ COMPONENT_STATE = Enum(
# Datasource Protocol
# ==================
INVALID_DATASOURCE_FRAME = FrameExceptionFactory('ORDER')
INVALID_DATASOURCE_FRAME = FrameExceptionFactory('DATASOURCE')
def DATASOURCE_FRAME(ds_id, ds_type, payload):
def DATASOURCE_FRAME(event):
"""
wraps any datasource payload with id and type, so that unpacking may choose the write
UNFRAME for the payload.
@@ -178,11 +186,13 @@ def DATASOURCE_FRAME(ds_id, ds_type, payload):
(others to follow soon)
::payload:: a msgpack string carrying the payload for the frame
"""
assert isinstance(ds_id, basestring)
assert isinstance(ds_type, basestring)
assert isinstance(payload, basestring)
return msgpack.dumps(tuple([ds_id, ds_type, payload]))
assert isinstance(event.source_id, basestring)
assert isinstance(event.type, basestring)
if(event.type == "TRADE"):
return msgpack.dumps(tuple([event.type, TRADE_FRAME(event)]))
else:
raise INVALID_DATASOURCE_FRAME(str(event))
def DATASOURCE_UNFRAME(msg):
"""
extracts payload, and calls correct UNFRAME method based on the datasource type passed along
@@ -197,11 +207,12 @@ def DATASOURCE_UNFRAME(msg):
- dt - a datetime object
"""
try:
ds_id, ds_type, payload = msgpack.loads(msg)
qutil.LOGGER.info("unpacking {msg}".format(msg = msg))
ds_type, payload = msgpack.loads(msg)
qutil.LOGGER.info("unpacked a datasource frame! {ds_type} - {payload}".format(ds_type=ds_type, payload=payload) )
assert isinstance(ds_type, basestring)
if(ds_type == "TRADE"):
result = {'source_id' : ds_id, 'type' : ds_type}
result.update(TRADE_UNFRAME(payload))
return namedict(result)
return TRADE_UNFRAME(payload)
else:
raise INVALID_DATASOURCE_FRAME(msg)
@@ -225,7 +236,6 @@ def FEED_FRAME(event):
source_id = event.source_id
ds_type = event.type
pack_date(event)
del(event.__dict__['dt'])
payload = event.__dict__
return msgpack.dumps(payload)
@@ -236,11 +246,11 @@ def FEED_UNFRAME(msg):
assert isinstance(payload, dict)
rval = namedict(payload)
unpack_date(rval)
return namedict(rval)
return rval
except TypeError:
raise INVALID_TRADE_FRAME(msg)
raise INVALID_FEED_FRAME(msg)
except ValueError:
raise INVALID_TRADE_FRAME(msg)
raise INVALID_FEED_FRAME(msg)
# ==================
# Transform Protocol
@@ -262,7 +272,8 @@ def TRANSFORM_UNFRAME(msg):
name, value = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(name, basestring)
assert payload.has_key('value')
if(name == "PASSTHROUGH"):
value = FEED_UNFRAME(value)
return namedict({name : value})
except TypeError:
raise INVALID_TRANSFORM_FRAME(msg)
@@ -282,8 +293,8 @@ def MERGE_FRAME(event):
- type
"""
assert isinstance(event, namedict)
source_id = event.source_id
ds_type = event.type
assert isinstance(event.dt, datetime.datetime)
pack_date(event)
payload = event.__dict__
return msgpack.dumps(payload)
@@ -292,11 +303,15 @@ def MERGE_UNFRAME(msg):
payload = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
return namedict(payload)
payload = namedict(payload)
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
unpack_date(payload)
return payload
except TypeError:
raise INVALID_TRADE_FRAME(msg)
raise INVALID_MERGE_FRAME(msg)
except ValueError:
raise INVALID_TRADE_FRAME(msg)
raise INVALID_MERGE_FRAME(msg)
# ==================
@@ -316,26 +331,32 @@ def TRADE_FRAME(event):
- sid -- the security id
- price -- float of the price printed for the trade
- volume -- int for shares in the trade
- dt -- datetime for the trade
"""
assert isinstance(event, namedict)
assert isinstance(event.source_id, basestring)
assert event.type == "TRADE"
assert isinstance(event.sid, int)
assert isinstance(event.price, float)
assert isinstance(event.volume, int)
pack_date(event)
payload = msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros]))
return DATASOURCE_FRAME(ds_id, "TRADE", payload)
qutil.LOGGER.info("event is: {event}".format(event=event.__dict__))
return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id]))
def TRADE_UNFRAME(msg):
try:
sid, price, volume, epoch, micros = msgpack.loads(msg)
qutil.LOGGER.info("about to unpack TRADE: {trade}".format(trade=msg))
sid, price, volume, epoch, micros, source_type, source_id = msgpack.loads(msg)
assert isinstance(sid, int)
assert isinstance(price, float)
assert isinstance(volume, int)
assert isinstance(epoch, numbers.Integral)
assert isinstance(micros, numbers.Integral)
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'dt' : dt, 'epoch' : epoch, 'micros' : micros})
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id})
unpack_date(rval)
qutil.LOGGER.info("unpacked Trade: {trade}".format(trade=rval.__dict__))
return rval
except TypeError:
raise INVALID_TRADE_FRAME(msg)
@@ -371,18 +392,18 @@ def ORDER_UNFRAME(msg):
def pack_date(event):
assert isinstance(event.dt, datetime.datetime)
assert event.dt.tzinfo == pytz.utc #utc only please
epoch = long(dt.strftime('%s'))
epoch = long(event.dt.strftime('%s'))
event['epoch'] = epoch
event['micros'] = event.dt.microsecond
del(event.__dict__['dt'])
return event
def unpack_date(payload):
assert isinstance(payload['epoch'], numbers.Integral)
assert isinstance(payload['micros'], numbers.Integral)
dt = datetime.datetime.fromtimestamp(payload['epoch'])
dt.replace(microsecond = payload['micros'], tzinfo = pytz.utc)
del(payload['epoch'])
del(payload['micros'])
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
dt = datetime.datetime.fromtimestamp(payload.epoch)
dt = dt.replace(microsecond = payload.micros, tzinfo = pytz.utc)
del(payload.__dict__['epoch'])
del(payload.__dict__['micros'])
payload['dt'] = dt
return payload
+2 -2
View File
@@ -10,11 +10,11 @@ import zipline.protocol as zp
class TradeDataSource(zm.DataSource):
def send(event):
def send(self, event):
""" :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME`
:rtype: None
"""
message = zp.TRADE_FRAME(self.get_id, sid, price, volume, dt)
message = zp.TRADE_FRAME(self.get_id, event)
self.data_socket.send(message)
class RandomEquityTrades(TradeDataSource):
+2
View File
@@ -61,6 +61,8 @@ def getCodeFromFile(filename):
def create_trade(sid, price, amount, datetime):
row = {}
row['source_id'] = "test_factory"
row['type'] = "TRADE"
row['sid'] = sid
row['dt'] = datetime
row['price'] = price
+45
View File
@@ -8,6 +8,7 @@ import zipline.test.factory as factory
import zipline.util as qutil
import zipline.db as db
import zipline.finance.risk as risk
import zipline.protocol as zp
from zipline.test.client import TestTradingClient
from zipline.test.dummy import ThreadPoolExecutorMixin
@@ -16,6 +17,50 @@ from zipline.sources import SpecificEquityTrades
class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
def test_trade_protocol(self):
trades = factory.create_trade_history(133,
[10.0,10.0,10.0,10.0],
[100,100,100,100],
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
datetime.timedelta(days=1))
for trade in trades:
msg = zp.TRADE_FRAME("fake_source", zp.namedict(trade))
recovered_trade = zp.DATASOURCE_UNFRAME(msg)
self.assertTrue(recovered_trade.type == "TRADE")
self.assertTrue(recovered_trade.source_id == "fake_source")
del(recovered_trade.__dict__['type'])
del(recovered_trade.__dict__['source_id'])
self.assertEqual(zp.namedict(trade), recovered_trade)
def test_trade_feed_protocol(self):
trades = factory.create_trade_history(133,
[10.0,10.0,10.0,10.0],
[100,100,100,100],
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
datetime.timedelta(days=1))
for trade in trades:
#simulate data source sending frame
msg = zp.DATASOURCE_FRAME(zp.namedict(trade))
#feed unpacking frame
recovered_trade = zp.DATASOURCE_UNFRAME(msg)
#feed sending frame
feed_msg = zp.FEED_FRAME(recovered_trade)
#transform unframing
recovered_feed = zp.FEED_UNFRAME(feed_msg)
#do a transform
trans_msg = zp.TRANSFORM_FRAME('helloworld', 2345.6)
#simulate passthrough transform -- passthrough shouldn't even unpack the msg, just resend.
passthrough_msg = zp.TRANSFORM_FRAME('PASSTHROUGH', feed_msg)
#merge unframes transform and passthrough
trans_recovered = zp.TRANSFORM_UNFRAME(trans_msg)
pt_recovered = zp.TRANSFORM_UNFRAME(passthrough_msg)
#simulated merge
pt_recovered.PASSTHROUGH.merge(trans_recovered)
#frame the merged event
merged_msg = zp.MERGE_FRAME(pt_recovered.PASSTHROUGH)
#unframe the merge and validate values
event = zp.MERGE_UNFRAME(merged_msg)
def test_trading_calendar(self):
known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y")
known_holiday = datetime.datetime.strptime("02/20/2012", "%m/%d/%Y") #president's day