mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 07:08:37 +08:00
added tests for the protocol for trades and for end to end including transforms and merge.
This commit is contained in:
+55
-34
@@ -69,10 +69,12 @@ def FrameExceptionFactory(name):
|
||||
def __init__(self, got):
|
||||
self.got = got
|
||||
def __str__(self):
|
||||
return "Invalid {framcls} Frame: {got}".format(
|
||||
return "Invalid {framecls} Frame: {got}".format(
|
||||
framecls = name,
|
||||
got = self.got,
|
||||
)
|
||||
|
||||
return InvalidFrame
|
||||
|
||||
class namedict(object):
|
||||
"""
|
||||
@@ -97,8 +99,14 @@ class namedict(object):
|
||||
self.__dict__[key] = value
|
||||
|
||||
def merge(self, other_nd):
|
||||
assert isinstance(namedict, other_nd)
|
||||
assert isinstance(other_nd, namedict)
|
||||
self.__dict__.update(other_nd.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return "namedict: " + str(self.__dict__)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
|
||||
# ================
|
||||
# Control Protocol
|
||||
@@ -166,9 +174,9 @@ COMPONENT_STATE = Enum(
|
||||
# Datasource Protocol
|
||||
# ==================
|
||||
|
||||
INVALID_DATASOURCE_FRAME = FrameExceptionFactory('ORDER')
|
||||
INVALID_DATASOURCE_FRAME = FrameExceptionFactory('DATASOURCE')
|
||||
|
||||
def DATASOURCE_FRAME(ds_id, ds_type, payload):
|
||||
def DATASOURCE_FRAME(event):
|
||||
"""
|
||||
wraps any datasource payload with id and type, so that unpacking may choose the write
|
||||
UNFRAME for the payload.
|
||||
@@ -178,11 +186,13 @@ def DATASOURCE_FRAME(ds_id, ds_type, payload):
|
||||
(others to follow soon)
|
||||
::payload:: a msgpack string carrying the payload for the frame
|
||||
"""
|
||||
assert isinstance(ds_id, basestring)
|
||||
assert isinstance(ds_type, basestring)
|
||||
assert isinstance(payload, basestring)
|
||||
return msgpack.dumps(tuple([ds_id, ds_type, payload]))
|
||||
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert isinstance(event.type, basestring)
|
||||
if(event.type == "TRADE"):
|
||||
return msgpack.dumps(tuple([event.type, TRADE_FRAME(event)]))
|
||||
else:
|
||||
raise INVALID_DATASOURCE_FRAME(str(event))
|
||||
|
||||
def DATASOURCE_UNFRAME(msg):
|
||||
"""
|
||||
extracts payload, and calls correct UNFRAME method based on the datasource type passed along
|
||||
@@ -197,11 +207,12 @@ def DATASOURCE_UNFRAME(msg):
|
||||
- dt - a datetime object
|
||||
"""
|
||||
try:
|
||||
ds_id, ds_type, payload = msgpack.loads(msg)
|
||||
qutil.LOGGER.info("unpacking {msg}".format(msg = msg))
|
||||
ds_type, payload = msgpack.loads(msg)
|
||||
qutil.LOGGER.info("unpacked a datasource frame! {ds_type} - {payload}".format(ds_type=ds_type, payload=payload) )
|
||||
assert isinstance(ds_type, basestring)
|
||||
if(ds_type == "TRADE"):
|
||||
result = {'source_id' : ds_id, 'type' : ds_type}
|
||||
result.update(TRADE_UNFRAME(payload))
|
||||
return namedict(result)
|
||||
return TRADE_UNFRAME(payload)
|
||||
else:
|
||||
raise INVALID_DATASOURCE_FRAME(msg)
|
||||
|
||||
@@ -225,7 +236,6 @@ def FEED_FRAME(event):
|
||||
source_id = event.source_id
|
||||
ds_type = event.type
|
||||
pack_date(event)
|
||||
del(event.__dict__['dt'])
|
||||
payload = event.__dict__
|
||||
return msgpack.dumps(payload)
|
||||
|
||||
@@ -236,11 +246,11 @@ def FEED_UNFRAME(msg):
|
||||
assert isinstance(payload, dict)
|
||||
rval = namedict(payload)
|
||||
unpack_date(rval)
|
||||
return namedict(rval)
|
||||
return rval
|
||||
except TypeError:
|
||||
raise INVALID_TRADE_FRAME(msg)
|
||||
raise INVALID_FEED_FRAME(msg)
|
||||
except ValueError:
|
||||
raise INVALID_TRADE_FRAME(msg)
|
||||
raise INVALID_FEED_FRAME(msg)
|
||||
|
||||
# ==================
|
||||
# Transform Protocol
|
||||
@@ -262,7 +272,8 @@ def TRANSFORM_UNFRAME(msg):
|
||||
name, value = msgpack.loads(msg)
|
||||
#TODO: anything we can do to assert more about the content of the dict?
|
||||
assert isinstance(name, basestring)
|
||||
assert payload.has_key('value')
|
||||
if(name == "PASSTHROUGH"):
|
||||
value = FEED_UNFRAME(value)
|
||||
return namedict({name : value})
|
||||
except TypeError:
|
||||
raise INVALID_TRANSFORM_FRAME(msg)
|
||||
@@ -282,8 +293,8 @@ def MERGE_FRAME(event):
|
||||
- type
|
||||
"""
|
||||
assert isinstance(event, namedict)
|
||||
source_id = event.source_id
|
||||
ds_type = event.type
|
||||
assert isinstance(event.dt, datetime.datetime)
|
||||
pack_date(event)
|
||||
payload = event.__dict__
|
||||
return msgpack.dumps(payload)
|
||||
|
||||
@@ -292,11 +303,15 @@ def MERGE_UNFRAME(msg):
|
||||
payload = msgpack.loads(msg)
|
||||
#TODO: anything we can do to assert more about the content of the dict?
|
||||
assert isinstance(payload, dict)
|
||||
return namedict(payload)
|
||||
payload = namedict(payload)
|
||||
assert isinstance(payload.epoch, numbers.Integral)
|
||||
assert isinstance(payload.micros, numbers.Integral)
|
||||
unpack_date(payload)
|
||||
return payload
|
||||
except TypeError:
|
||||
raise INVALID_TRADE_FRAME(msg)
|
||||
raise INVALID_MERGE_FRAME(msg)
|
||||
except ValueError:
|
||||
raise INVALID_TRADE_FRAME(msg)
|
||||
raise INVALID_MERGE_FRAME(msg)
|
||||
|
||||
|
||||
# ==================
|
||||
@@ -316,26 +331,32 @@ def TRADE_FRAME(event):
|
||||
- sid -- the security id
|
||||
- price -- float of the price printed for the trade
|
||||
- volume -- int for shares in the trade
|
||||
- dt -- datetime for the trade
|
||||
|
||||
"""
|
||||
assert isinstance(event, namedict)
|
||||
assert isinstance(event.source_id, basestring)
|
||||
assert event.type == "TRADE"
|
||||
assert isinstance(event.sid, int)
|
||||
assert isinstance(event.price, float)
|
||||
assert isinstance(event.volume, int)
|
||||
pack_date(event)
|
||||
payload = msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros]))
|
||||
return DATASOURCE_FRAME(ds_id, "TRADE", payload)
|
||||
qutil.LOGGER.info("event is: {event}".format(event=event.__dict__))
|
||||
return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id]))
|
||||
|
||||
def TRADE_UNFRAME(msg):
|
||||
try:
|
||||
sid, price, volume, epoch, micros = msgpack.loads(msg)
|
||||
qutil.LOGGER.info("about to unpack TRADE: {trade}".format(trade=msg))
|
||||
sid, price, volume, epoch, micros, source_type, source_id = msgpack.loads(msg)
|
||||
|
||||
assert isinstance(sid, int)
|
||||
assert isinstance(price, float)
|
||||
assert isinstance(volume, int)
|
||||
assert isinstance(epoch, numbers.Integral)
|
||||
assert isinstance(micros, numbers.Integral)
|
||||
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'dt' : dt, 'epoch' : epoch, 'micros' : micros})
|
||||
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id})
|
||||
unpack_date(rval)
|
||||
qutil.LOGGER.info("unpacked Trade: {trade}".format(trade=rval.__dict__))
|
||||
return rval
|
||||
except TypeError:
|
||||
raise INVALID_TRADE_FRAME(msg)
|
||||
@@ -371,18 +392,18 @@ def ORDER_UNFRAME(msg):
|
||||
def pack_date(event):
|
||||
assert isinstance(event.dt, datetime.datetime)
|
||||
assert event.dt.tzinfo == pytz.utc #utc only please
|
||||
epoch = long(dt.strftime('%s'))
|
||||
epoch = long(event.dt.strftime('%s'))
|
||||
event['epoch'] = epoch
|
||||
event['micros'] = event.dt.microsecond
|
||||
del(event.__dict__['dt'])
|
||||
return event
|
||||
|
||||
def unpack_date(payload):
|
||||
assert isinstance(payload['epoch'], numbers.Integral)
|
||||
assert isinstance(payload['micros'], numbers.Integral)
|
||||
dt = datetime.datetime.fromtimestamp(payload['epoch'])
|
||||
dt.replace(microsecond = payload['micros'], tzinfo = pytz.utc)
|
||||
del(payload['epoch'])
|
||||
del(payload['micros'])
|
||||
assert isinstance(payload.epoch, numbers.Integral)
|
||||
assert isinstance(payload.micros, numbers.Integral)
|
||||
dt = datetime.datetime.fromtimestamp(payload.epoch)
|
||||
dt = dt.replace(microsecond = payload.micros, tzinfo = pytz.utc)
|
||||
del(payload.__dict__['epoch'])
|
||||
del(payload.__dict__['micros'])
|
||||
payload['dt'] = dt
|
||||
return payload
|
||||
+2
-2
@@ -10,11 +10,11 @@ import zipline.protocol as zp
|
||||
|
||||
class TradeDataSource(zm.DataSource):
|
||||
|
||||
def send(event):
|
||||
def send(self, event):
|
||||
""" :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME`
|
||||
:rtype: None
|
||||
"""
|
||||
message = zp.TRADE_FRAME(self.get_id, sid, price, volume, dt)
|
||||
message = zp.TRADE_FRAME(self.get_id, event)
|
||||
self.data_socket.send(message)
|
||||
|
||||
class RandomEquityTrades(TradeDataSource):
|
||||
|
||||
@@ -61,6 +61,8 @@ def getCodeFromFile(filename):
|
||||
|
||||
def create_trade(sid, price, amount, datetime):
|
||||
row = {}
|
||||
row['source_id'] = "test_factory"
|
||||
row['type'] = "TRADE"
|
||||
row['sid'] = sid
|
||||
row['dt'] = datetime
|
||||
row['price'] = price
|
||||
|
||||
@@ -8,6 +8,7 @@ import zipline.test.factory as factory
|
||||
import zipline.util as qutil
|
||||
import zipline.db as db
|
||||
import zipline.finance.risk as risk
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.test.client import TestTradingClient
|
||||
from zipline.test.dummy import ThreadPoolExecutorMixin
|
||||
@@ -16,6 +17,50 @@ from zipline.sources import SpecificEquityTrades
|
||||
|
||||
class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
|
||||
|
||||
def test_trade_protocol(self):
|
||||
trades = factory.create_trade_history(133,
|
||||
[10.0,10.0,10.0,10.0],
|
||||
[100,100,100,100],
|
||||
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
|
||||
datetime.timedelta(days=1))
|
||||
for trade in trades:
|
||||
msg = zp.TRADE_FRAME("fake_source", zp.namedict(trade))
|
||||
recovered_trade = zp.DATASOURCE_UNFRAME(msg)
|
||||
self.assertTrue(recovered_trade.type == "TRADE")
|
||||
self.assertTrue(recovered_trade.source_id == "fake_source")
|
||||
del(recovered_trade.__dict__['type'])
|
||||
del(recovered_trade.__dict__['source_id'])
|
||||
self.assertEqual(zp.namedict(trade), recovered_trade)
|
||||
|
||||
def test_trade_feed_protocol(self):
|
||||
trades = factory.create_trade_history(133,
|
||||
[10.0,10.0,10.0,10.0],
|
||||
[100,100,100,100],
|
||||
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
|
||||
datetime.timedelta(days=1))
|
||||
for trade in trades:
|
||||
#simulate data source sending frame
|
||||
msg = zp.DATASOURCE_FRAME(zp.namedict(trade))
|
||||
#feed unpacking frame
|
||||
recovered_trade = zp.DATASOURCE_UNFRAME(msg)
|
||||
#feed sending frame
|
||||
feed_msg = zp.FEED_FRAME(recovered_trade)
|
||||
#transform unframing
|
||||
recovered_feed = zp.FEED_UNFRAME(feed_msg)
|
||||
#do a transform
|
||||
trans_msg = zp.TRANSFORM_FRAME('helloworld', 2345.6)
|
||||
#simulate passthrough transform -- passthrough shouldn't even unpack the msg, just resend.
|
||||
passthrough_msg = zp.TRANSFORM_FRAME('PASSTHROUGH', feed_msg)
|
||||
#merge unframes transform and passthrough
|
||||
trans_recovered = zp.TRANSFORM_UNFRAME(trans_msg)
|
||||
pt_recovered = zp.TRANSFORM_UNFRAME(passthrough_msg)
|
||||
#simulated merge
|
||||
pt_recovered.PASSTHROUGH.merge(trans_recovered)
|
||||
#frame the merged event
|
||||
merged_msg = zp.MERGE_FRAME(pt_recovered.PASSTHROUGH)
|
||||
#unframe the merge and validate values
|
||||
event = zp.MERGE_UNFRAME(merged_msg)
|
||||
|
||||
def test_trading_calendar(self):
|
||||
known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y")
|
||||
known_holiday = datetime.datetime.strptime("02/20/2012", "%m/%d/%Y") #president's day
|
||||
|
||||
Reference in New Issue
Block a user