nascent order protocol implemented.

This commit is contained in:
fawce
2012-02-28 21:50:40 -05:00
parent 50f5c4ab20
commit 3e5092d92a
8 changed files with 148 additions and 105 deletions
+33 -31
View File
@@ -13,6 +13,7 @@ class TradeSimulationClient(qmsg.Component):
qmsg.Component.__init__(self)
self.received_count = 0
self.prev_dt = None
self.event_queue = []
@property
def get_id(self):
@@ -33,7 +34,7 @@ class TradeSimulationClient(qmsg.Component):
if len(rlist) == 0 or len(xlist) > 0:
raise Exception("unexpected end of feed stream")
message = rlist[0].recv()
if message == str(CONTROL_PROTOCOL.DONE):
if message == str(zp.CONTROL_PROTOCOL.DONE):
self.signal_done()
return #leave open orders hanging? client requests for orders?
@@ -45,34 +46,34 @@ class TradeSimulationClient(qmsg.Component):
def _handle_event(self, event):
self.event_queue.append(event)
if event['TRADE_SIM']['ALGO_TIME'] <= event['dt']:
del(event['TRADE_SIM'])
event['dt'] = qutil.parse_date(event['dt'])
if event.ALGO_TIME <= event.dt:
#event occurred in the present, send the queue to be processed
self.handle_events(self.event_queue)
self.order_socket.send(str(zp.CONTROL_PROTOCOL.DONE))
def handle_events(self, event_queue):
raise NotImplementedError
def order(self, sid, volume):
order = {'sid':sid, 'volume':volume}
self.order_feed.send(zp.ORDER_FRAME(order))
def order(self, sid, amount):
self.order_socket.send(zp.ORDER_FRAME(sid, amount))
class TradeSimulator(qmsg.BaseTransform):
def __init__(self):
qmsq.BaseTransform.__init__(self, "")
qmsg.BaseTransform.__init__(self, "")
self.open_orders = {}
self.algo_time = None
self.event_start = None
self.last_event_time = None
self.last_iteration_duration = None
@property
def get_id(self):
return "TRADE_SIM"
return "ALGO_TIME"
def open():
def open(self):
qmsg.BaseTransform.open(self)
self.order_socket = self.bind_order()
@@ -95,54 +96,55 @@ class TradeSimulator(qmsg.BaseTransform):
if len(rlist) == 0 or len(xlist) > 0:
raise Exception("unexpected end of feed stream")
message = rlist[0].recv()
if message == str(CONTROL_PROTOCOL.DONE):
if message == str(zp.CONTROL_PROTOCOL.DONE):
self.signal_done()
return #leave open orders hanging? client requests for orders?
event = qp.FEED_UNFRAME(message)
event = zp.FEED_UNFRAME(message)
if self.last_iteration_duration != None:
self.algo_time = self.last_event_time + self.last_iteration_duration
else:
self.algo_time = event.dt #base case, first event we're transporting.
self.last_event_time = qutil.parse_date(event['dt'])
self.last_event_time = event.dt
if self.algo_time < self.last_event_time:
#compress time, move algo's clock to the time of this event
self.algo_time = self.last_event_time
fill = self.process_orders(event)
#TODO: decide what this transform should send downstream, maybe fills? effective algo time?
self.state['value'] = {'FILL':fill,
'ALGO_TIME':qutil.format_date(self.algo_time)}
#self.process_orders(event)
#mark the start time for client's processing of this event.
self.event_start = datetime.datetime.utcnow()
self.result_socket.send(qf.MERGE_FRAME(cur_state), self.zmq.NOBLOCK)
self.result_socket.send(zp.TRANSFORM_FRAME('ALGO_TIME', self.algo_time), self.zmq.NOBLOCK)
while True: #this loop should also poll for portfolio state req/rep
(rlist, wlist, xlist) = select([self.order_socket],
[],
[self.order_socket],
timeout=self.heartbeat_timeout/1000) #select timeout is in sec
[],
[self.order_socket],
timeout=self.heartbeat_timeout/1000) #select timeout is in sec
#no more orders, should this be an error condition?
if len(rlist) == 0 or len(xlist) > 0:
break
continue
order_msg = rlist[0].recv()
if order_msg == str(CONTROL_PROTOCOL.DONE):
if order_msg == str(zp.CONTROL_PROTOCOL.DONE):
qutil.LOGGER.info("order loop finished")
break
order = qp.ORDER_UNFRAME(order_msg)
self.add_open_order(order)
sid, amount = zp.ORDER_UNFRAME(order_msg)
self.add_open_order(sid, amount)
#end of order processing loop
self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start
def add_open_order(self, order):
self.open_orders[order['sid']] = order
def add_open_order(self, sid, amount):
pass
def process_orders(self, event):
#TODO put real fill logic here, return a list of fills
return [{'sid':133, 'amount':-100}]
+50 -21
View File
@@ -20,10 +20,6 @@ class ComponentHost(Component):
Component.__init__(self)
self.addresses = addresses
# workaround for defect in threaded use of strptime:
# http://bugs.python.org/issue11108
qutil.parse_date("2012/02/13-10:04:28.114")
self.components = {}
self.sync_register = {}
self.timeout = datetime.timedelta(seconds=5)
@@ -196,10 +192,10 @@ class ParallelBuffer(Component):
self.drain()
self.signal_done()
else:
event = zp.DATASOURCE_UNFRAME(message)
self.append(event.source_id, event)
event = self.unframe(message)
self.append(event)
self.send_next()
def __len__(self):
"""
Buffer's length is same as internal map holding separate
@@ -207,12 +203,12 @@ class ParallelBuffer(Component):
"""
return len(self.data_buffer)
def append(self, source_id, value):
def append(self, event):
"""
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[source_id].append(value)
self.data_buffer[event.source_id].append(event)
self.received_count += 1
def next(self):
@@ -228,7 +224,7 @@ class ParallelBuffer(Component):
if len(events) == 0:
continue
cur = events
if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
if (earliest == None) or (cur[0].dt <= earliest[0].dt):
earliest = cur
if earliest != None:
@@ -271,9 +267,15 @@ class ParallelBuffer(Component):
event = self.next()
if(event != None):
self.feed_socket.send(zp.FEED_FRAME(event), self.zmq.NOBLOCK)
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
self.sent_count += 1
def unframe(self, msg):
return zp.DATASOURCE_UNFRAME(msg)
def frame(self, event):
return zp.FEED_FRAME(event)
class MergedParallelBuffer(ParallelBuffer):
"""
@@ -293,18 +295,35 @@ class MergedParallelBuffer(ParallelBuffer):
return
#get the raw event from the passthrough transform.
result = self.data_buffer["PASSTHROUGH"].pop(0)['value']
result = self.data_buffer["PASSTHROUGH"].pop(0).PASSTHROUGH
for source, events in self.data_buffer.iteritems():
if source == "PASSTHROUGH":
continue
if len(events) > 0:
cur = events.pop(0)
result[source] = cur['value']
result.merge(cur)
return result
@property
def get_id(self):
return "MERGE"
def unframe(self, msg):
return zp.TRANSFORM_UNFRAME(msg)
def frame(self, event):
return zp.MERGE_FRAME(event)
#
def append(self, event):
"""
:param event: a namedict with one entry. key is the name of the transform, value is the transformed value.
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[event.__dict__.keys()[0]].append(event)
self.received_count += 1
class BaseTransform(Component):
@@ -351,13 +370,10 @@ class BaseTransform(Component):
self.signal_done()
return
event = qp.FEED_UNFRAME(message)
event = zp.FEED_UNFRAME(message)
cur_state = self.transform(event)
#TODO: do we want to relay the datetime again? maybe drop this?
#cur_state['dt'] = event['dt']
cur_state['id'] = self.state['name']
self.result_socket.send(qp.TRANSFORM_FRAME(cur_state), self.zmq.NOBLOCK)
qutil.LOGGER.info("state of transform is: {state}".format(state=cur_state))
self.result_socket.send(zp.TRANSFORM_FRAME(cur_state['name'], cur_state['value']), self.zmq.NOBLOCK)
def transform(self, event):
"""
@@ -380,8 +396,21 @@ class PassthroughTransform(BaseTransform):
def __init__(self):
BaseTransform.__init__(self, "PASSTHROUGH")
def transform(self, event):
return {'value':event}
def do_work(self):
"""
Loops until feed's DONE message is received:
- receive an event from the data feed
- call transform (subclass' method) on event
- send the transformed event
"""
socks = dict(self.poll.poll(self.heartbeat_timeout)) #timeout after 2 seconds.
if self.feed_socket in socks and socks[self.feed_socket] == self.zmq.POLLIN:
message = self.feed_socket.recv()
if message == str(CONTROL_PROTOCOL.DONE):
self.signal_done()
return
#message is already FEED_FRAMEd, send it as the value.
self.result_socket.send(zp.TRANSFORM_FRAME("PASSTHROUGH", message), self.zmq.NOBLOCK)
class DataSource(Component):
+34 -9
View File
@@ -107,6 +107,9 @@ class namedict(object):
def __eq__(self, other):
return self.__dict__ == other.__dict__
def has_attr(self, name):
return self.__dict__.has_key(name)
# ================
# Control Protocol
@@ -235,7 +238,7 @@ def FEED_FRAME(event):
assert isinstance(event, namedict)
source_id = event.source_id
ds_type = event.type
pack_date(event)
PACK_DATE(event)
payload = event.__dict__
return msgpack.dumps(payload)
@@ -245,7 +248,7 @@ def FEED_UNFRAME(msg):
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
rval = namedict(payload)
unpack_date(rval)
UNPACK_DATE(rval)
return rval
except TypeError:
raise INVALID_FEED_FRAME(msg)
@@ -265,21 +268,39 @@ def TRANSFORM_FRAME(name, value):
"""
assert isinstance(name, basestring)
assert value != None
if(name == 'ALGO_TIME'):
value = PACK_ALGO_DT(value)
return msgpack.dumps(tuple([name, value]))
def TRANSFORM_UNFRAME(msg):
"""
:rtype: namedict with <transform_name>:<transform_value>
"""
try:
name, value = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(name, basestring)
if(name == "PASSTHROUGH"):
value = FEED_UNFRAME(value)
elif(name == "ALGO_TIME"):
value = UNPACK_ALGO_DT(value)
return namedict({name : value})
except TypeError:
raise INVALID_TRANSFORM_FRAME(msg)
except ValueError:
raise INVALID_TRANSFORM_FRAME(msg)
def PACK_ALGO_DT(value):
value = namedict({'dt' : value})
PACK_DATE(value)
return value.__dict__
def UNPACK_ALGO_DT(value):
value = namedict(value)
UNPACK_DATE(value)
return value.dt
# ==================
# Merge Protocol
@@ -294,7 +315,9 @@ def MERGE_FRAME(event):
"""
assert isinstance(event, namedict)
assert isinstance(event.dt, datetime.datetime)
pack_date(event)
PACK_DATE(event)
if(event.has_attr('ALGO_TIME')):
event.ALGO_TIME = PACK_ALGO_DT(event.ALGO_TIME)
payload = event.__dict__
return msgpack.dumps(payload)
@@ -304,9 +327,11 @@ def MERGE_UNFRAME(msg):
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
payload = namedict(payload)
if(payload.has_attr('ALGO_TIME')):
payload.ALGO_TIME = UNPACK_ALGO_DT(payload.ALGO_TIME)
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
unpack_date(payload)
UNPACK_DATE(payload)
return payload
except TypeError:
raise INVALID_MERGE_FRAME(msg)
@@ -340,7 +365,7 @@ def TRADE_FRAME(event):
assert isinstance(event.sid, int)
assert isinstance(event.price, float)
assert isinstance(event.volume, int)
pack_date(event)
PACK_DATE(event)
qutil.LOGGER.info("event is: {event}".format(event=event.__dict__))
return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id]))
@@ -355,7 +380,7 @@ def TRADE_UNFRAME(msg):
assert isinstance(epoch, numbers.Integral)
assert isinstance(micros, numbers.Integral)
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id})
unpack_date(rval)
UNPACK_DATE(rval)
qutil.LOGGER.info("unpacked Trade: {trade}".format(trade=rval.__dict__))
return rval
except TypeError:
@@ -389,7 +414,7 @@ def ORDER_UNFRAME(msg):
# Date Helpers
# =================
def pack_date(event):
def PACK_DATE(event):
assert isinstance(event.dt, datetime.datetime)
assert event.dt.tzinfo == pytz.utc #utc only please
epoch = long(event.dt.strftime('%s'))
@@ -398,7 +423,7 @@ def pack_date(event):
del(event.__dict__['dt'])
return event
def unpack_date(payload):
def UNPACK_DATE(payload):
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
dt = datetime.datetime.fromtimestamp(payload.epoch)
+8 -6
View File
@@ -3,6 +3,7 @@ Provides data handlers that can push messages to a zipline.core.DataFeed
"""
import datetime
import random
import pytz
import zipline.util as qutil
import zipline.messaging as zm
@@ -14,7 +15,8 @@ class TradeDataSource(zm.DataSource):
""" :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME`
:rtype: None
"""
message = zp.TRADE_FRAME(self.get_id, event)
event.source_id = self.get_id
message = zp.DATASOURCE_FRAME(event)
self.data_socket.send(message)
class RandomEquityTrades(TradeDataSource):
@@ -25,7 +27,7 @@ class RandomEquityTrades(TradeDataSource):
self.count = count
self.incr = 0
self.sid = sid
self.trade_start = datetime.datetime.now()
self.trade_start = datetime.datetime.now().replace(tzinfo=pytz.utc)
self.minute = datetime.timedelta(minutes=1)
self.price = random.uniform(5.0, 50.0)
@@ -40,12 +42,12 @@ class RandomEquityTrades(TradeDataSource):
return
self.price = self.price + random.uniform(-0.05, 0.05)
self.send(self.sid, self.price, random.randrange(100,10000,100), qutil.format_date(self.trade_start + (self.minute * self.incr)))
self._send(self.sid, self.price, random.randrange(100,10000,100), self.trade_start + (self.minute * self.incr))
self.incr += 1
def send(self, sid, price, volume, dt):
message = zp.TRADE_FRAME(self.get_id(), sid, price, volume, dt)
self.data_socket.send(message)
def _send(self, sid, price, volume, dt):
event = zp.namedict({'source_id': self.get_id, "type" : "TRADE", "sid":sid, "price":price, "volume":volume, "dt":dt})
self.send(event)
class SpecificEquityTrades(TradeDataSource):
+11 -8
View File
@@ -3,7 +3,7 @@ import zipline.util as qutil
import zipline.messaging as qmsg
from zipline.finance.trading import TradeSimulationClient
from zipline.protocol import CONTROL_PROTOCOL
import zipline.protocol as zp
class TestClient(qmsg.Component):
"""no-op client - Just connects to the merge and counts messages. compares received message count to the expected count."""
@@ -29,7 +29,7 @@ class TestClient(qmsg.Component):
if self.data_feed in socks and socks[self.data_feed] == self.zmq.POLLIN:
msg = self.data_feed.recv()
if msg == str(CONTROL_PROTOCOL.DONE):
if msg == str(zp.CONTROL_PROTOCOL.DONE):
qutil.LOGGER.info("Client is DONE!")
self.signal_done()
self.utest.assertEqual(self.expected_msg_count, self.received_count,
@@ -40,20 +40,23 @@ class TestClient(qmsg.Component):
self.received_count += 1
event = zp.MERGE_UNFRAME(msg)
if(self.prev_dt != None):
if(not event['dt'] >= self.prev_dt):
if(not event.dt >= self.prev_dt):
raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt))
self.prev_dt = event['dt']
self.prev_dt = event.dt
if(self.received_count % 100 == 0):
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
class TestTradingClient(TradeSimulationClient):
def __init__(self):
def __init__(self, count):
TradeSimulationClient.__init__(self)
self.count = count
self.incr = 0
def handle_events(self, event_queue):
#place an order for 100 shares of sid:133
self.order(133,100)
if(self.incr >= self.count):
self.order(133, 100)
self.incr += 1
+8 -3
View File
@@ -13,6 +13,7 @@ import zipline.protocol as zp
from zipline.test.client import TestTradingClient
from zipline.test.dummy import ThreadPoolExecutorMixin
from zipline.sources import SpecificEquityTrades
from zipline.finance.trading import TradeSimulator
class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
@@ -53,7 +54,10 @@ class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
self.assertEqual(zp.namedict(trade), event)
def test_order_protocol(self):
raise NotImplementedError
order_msg = zp.ORDER_FRAME(133, 100)
sid, amount = zp.ORDER_UNFRAME(order_msg)
self.assertEqual(sid, 133)
self.assertEqual(amount, 100)
def test_trading_calendar(self):
known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y")
@@ -91,9 +95,10 @@ class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
[100,100,100,100],
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
datetime.timedelta(days=1)))
client = TestTradingClient()
client = TestTradingClient(0)
order_sim = TradeSimulator()
sim.register_components([set1, client])
sim.register_components([client, order_sim, set1])
sim.register_controller( con )
# Simulation
+4 -4
View File
@@ -24,15 +24,15 @@ class MovingAverage(BaseTransform):
"""Update the moving average with the latest data point."""
self.events.append(event)
self.current_total += event['price']
event_date = qutil.parse_date(event['dt'])
self.current_total += event.price
event_date = event.dt
index = 0
for cur_event in self.events:
cur_date = qutil.parse_date(cur_event['dt'])
cur_date = cur_event.dt
if(cur_date - event_date) >= self.window:
self.events.pop(index)
self.current_total -= cur_event['price']
self.current_total -= cur_event.price
index += 1
else:
break
-23
View File
@@ -26,26 +26,3 @@ def configure_logging(loglevel=logging.DEBUG):
)
LOGGER.addHandler(handler)
LOGGER.info("logging started...")
def parse_date(dt_str):
"""
Parse strings according to the same format as generated by
format_date.
"""
if(dt_str == None):
return None
parts = dt_str.split(".")
dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace(
microsecond=int(parts[1]+"000")).replace(tzinfo = pytz.utc
)
return dt
def format_date(dt):
"""
Format the date into a date with millesecond resolution and
string/alphabetical sorting that is equivalent to datetime sorting.
"""
if(dt == None):
return None
dt_str = dt.strftime('%Y/%m/%d-%H:%M:%S') + "." + str(dt.microsecond / 1000)
return dt_str