Attempt to fix protocol assertion error.

This commit is contained in:
Stephen Diehl
2012-03-03 21:30:45 -05:00
parent 6c581bef39
commit 58eba2d100
2 changed files with 46 additions and 27 deletions
+40 -23
View File
@@ -5,23 +5,28 @@ import datetime
import random
import pytz
import zipline.util as qutil
import zipline.messaging as zm
import zipline.protocol as zp
class TradeDataSource(zm.DataSource):
def send(self, event):
""" :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME`
:rtype: None
"""
:param dict event: is a trade event with data as per
:py:func: `zipline.protocol.TRADE_FRAME`
:rtype: None
"""
event.source_id = self.get_id
message = zp.DATASOURCE_FRAME(event)
self.data_socket.send(message)
class RandomEquityTrades(TradeDataSource):
"""Generates a random stream of trades for testing."""
"""
Generates a random stream of trades for testing.
"""
def __init__(self, sid, source_id, count):
zm.DataSource.__init__(self, source_id)
self.count = count
@@ -30,32 +35,44 @@ class RandomEquityTrades(TradeDataSource):
self.trade_start = datetime.datetime.now().replace(tzinfo=pytz.utc)
self.minute = datetime.timedelta(minutes=1)
self.price = random.uniform(5.0, 50.0)
def get_type(self):
return 'EQUITY_TRADE'
zp.COMPONENT_TYPE.SOURCE
def do_work(self):
if(self.incr == self.count):
self.signal_done()
return
self.price = self.price + random.uniform(-0.05, 0.05)
self._send(self.sid, self.price, random.randrange(100,10000,100), self.trade_start + (self.minute * self.incr))
self.incr += 1
def _send(self, sid, price, volume, dt):
event = zp.namedict({'source_id': self.get_id, "type" : "TRADE", "sid":sid, "price":price, "volume":volume, "dt":dt})
self.send(event)
self.price = self.price + random.uniform(-0.05, 0.05)
volume = random.randrange(100,10000,100)
event = zp.namedict({
'source_id' : self.get_id,
"type" : zp.DATASOURCE_TYPE.ORDER,
"sid" : self.sid,
"price" : self.price,
"volume" : volume,
"dt" : self.trade_start + (self.minute * self.incr),
})
message = zp.DATASOURCE_FRAME(event)
self.send(message)
self.incr += 1
class SpecificEquityTrades(TradeDataSource):
"""Generates a random stream of trades for testing."""
"""
Generates a random stream of trades for testing.
"""
def __init__(self, source_id, event_list):
"""
:event_list: should be a chronologically ordered list of dictionaries in the following form:
:event_list: should be a chronologically ordered list of dictionaries
in the following form:
event = {
'sid' : an integer for security id,
'dt' : datetime object,
@@ -67,14 +84,14 @@ class SpecificEquityTrades(TradeDataSource):
self.event_list = event_list
def get_type(self):
return 'EQUITY_TRADE'
zp.COMPONENT_TYPE.SOURCE
def do_work(self):
if(len(self.event_list) == 0):
self.signal_done()
return
event = self.event_list.pop(0)
self.send(zp.namedict(event))
+6 -4
View File
@@ -11,6 +11,7 @@ from zipline.sources import RandomEquityTrades
from zipline.test.client import TestClient
from zipline.test.transform import DivideByZeroTransform
from nose.tools import timed
# Should not inherit form TestCase since test runners will pick
# it up as a test. Its a Mixin of sorts at this point.
@@ -73,6 +74,7 @@ class SimulatorTestCase(object):
# Cases
# -------
@timed(2)
def test_simple(self):
# Simple test just to make sure that the archiecture is
@@ -100,7 +102,7 @@ class SimulatorTestCase(object):
ret1 = RandomEquityTrades(133, "ret1", 1)
ret2 = RandomEquityTrades(134, "ret2", 1)
client = TestClient(expected_msg_count=(ret1.count + ret2.count))
client = TestClient()
sim.register_controller( con )
sim.register_components([ret1, ret2, client])
@@ -149,7 +151,7 @@ class SimulatorTestCase(object):
ret1 = RandomEquityTrades(133, "ret1", 1)
ret2 = RandomEquityTrades(134, "ret2", 1)
fail_transform = DivideByZeroTransform("fail")
client = TestClient(self, expected_msg_count=ret1.count + ret2.count)
client = TestClient()
sim.register_controller( con )
sim.register_components([ret1, ret2, fail_transform, client])
@@ -194,7 +196,7 @@ class SimulatorTestCase(object):
ret1 = RandomEquityTrades(133, "ret1", 400)
ret2 = RandomEquityTrades(134, "ret2", 400)
client = TestClient(expected_msg_count=ret1.count + ret2.count)
client = TestClient()
sim.register_controller( con )
sim.register_components([ret1, ret2, client])
@@ -240,7 +242,7 @@ class SimulatorTestCase(object):
ret2 = RandomEquityTrades(134, "ret2", 5000)
mavg1 = MovingAverage("mavg1", 30)
mavg2 = MovingAverage("mavg2", 60)
client = TestClient(expected_msg_count=10000)
client = TestClient()
sim.register_components([ret1, ret2, mavg1, mavg2, client])
sim.register_controller( con )