From 3e5092d92ac97cc2606402f86bd3b49b09a9ef0b Mon Sep 17 00:00:00 2001 From: fawce Date: Tue, 28 Feb 2012 21:50:40 -0500 Subject: [PATCH] nascent order protocol implemented. --- zipline/finance/trading.py | 64 +++++++++++++++-------------- zipline/messaging.py | 71 +++++++++++++++++++++++---------- zipline/protocol.py | 43 +++++++++++++++----- zipline/sources.py | 14 ++++--- zipline/test/client.py | 19 +++++---- zipline/test/test_finance.py | 11 +++-- zipline/transforms/technical.py | 8 ++-- zipline/util.py | 23 ----------- 8 files changed, 148 insertions(+), 105 deletions(-) diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index d5c5b444..cc5ca523 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -13,6 +13,7 @@ class TradeSimulationClient(qmsg.Component): qmsg.Component.__init__(self) self.received_count = 0 self.prev_dt = None + self.event_queue = [] @property def get_id(self): @@ -33,7 +34,7 @@ class TradeSimulationClient(qmsg.Component): if len(rlist) == 0 or len(xlist) > 0: raise Exception("unexpected end of feed stream") message = rlist[0].recv() - if message == str(CONTROL_PROTOCOL.DONE): + if message == str(zp.CONTROL_PROTOCOL.DONE): self.signal_done() return #leave open orders hanging? client requests for orders? @@ -45,34 +46,34 @@ class TradeSimulationClient(qmsg.Component): def _handle_event(self, event): self.event_queue.append(event) - if event['TRADE_SIM']['ALGO_TIME'] <= event['dt']: - del(event['TRADE_SIM']) - event['dt'] = qutil.parse_date(event['dt']) + if event.ALGO_TIME <= event.dt: #event occurred in the present, send the queue to be processed self.handle_events(self.event_queue) + self.order_socket.send(str(zp.CONTROL_PROTOCOL.DONE)) def handle_events(self, event_queue): raise NotImplementedError - def order(self, sid, volume): - order = {'sid':sid, 'volume':volume} - self.order_feed.send(zp.ORDER_FRAME(order)) + def order(self, sid, amount): + self.order_socket.send(zp.ORDER_FRAME(sid, amount)) + class TradeSimulator(qmsg.BaseTransform): def __init__(self): - qmsq.BaseTransform.__init__(self, "") + qmsg.BaseTransform.__init__(self, "") self.open_orders = {} self.algo_time = None self.event_start = None self.last_event_time = None self.last_iteration_duration = None + @property def get_id(self): - return "TRADE_SIM" + return "ALGO_TIME" - def open(): + def open(self): qmsg.BaseTransform.open(self) self.order_socket = self.bind_order() @@ -95,54 +96,55 @@ class TradeSimulator(qmsg.BaseTransform): if len(rlist) == 0 or len(xlist) > 0: raise Exception("unexpected end of feed stream") message = rlist[0].recv() - if message == str(CONTROL_PROTOCOL.DONE): + if message == str(zp.CONTROL_PROTOCOL.DONE): self.signal_done() return #leave open orders hanging? client requests for orders? - event = qp.FEED_UNFRAME(message) + event = zp.FEED_UNFRAME(message) if self.last_iteration_duration != None: self.algo_time = self.last_event_time + self.last_iteration_duration + else: + self.algo_time = event.dt #base case, first event we're transporting. - self.last_event_time = qutil.parse_date(event['dt']) + self.last_event_time = event.dt if self.algo_time < self.last_event_time: #compress time, move algo's clock to the time of this event self.algo_time = self.last_event_time - fill = self.process_orders(event) - - #TODO: decide what this transform should send downstream, maybe fills? effective algo time? - self.state['value'] = {'FILL':fill, - 'ALGO_TIME':qutil.format_date(self.algo_time)} - + #self.process_orders(event) + #mark the start time for client's processing of this event. self.event_start = datetime.datetime.utcnow() - self.result_socket.send(qf.MERGE_FRAME(cur_state), self.zmq.NOBLOCK) + self.result_socket.send(zp.TRANSFORM_FRAME('ALGO_TIME', self.algo_time), self.zmq.NOBLOCK) while True: #this loop should also poll for portfolio state req/rep (rlist, wlist, xlist) = select([self.order_socket], - [], - [self.order_socket], - timeout=self.heartbeat_timeout/1000) #select timeout is in sec - + [], + [self.order_socket], + timeout=self.heartbeat_timeout/1000) #select timeout is in sec #no more orders, should this be an error condition? if len(rlist) == 0 or len(xlist) > 0: - break + continue order_msg = rlist[0].recv() - if order_msg == str(CONTROL_PROTOCOL.DONE): + if order_msg == str(zp.CONTROL_PROTOCOL.DONE): + qutil.LOGGER.info("order loop finished") break - order = qp.ORDER_UNFRAME(order_msg) - self.add_open_order(order) + sid, amount = zp.ORDER_UNFRAME(order_msg) + self.add_open_order(sid, amount) #end of order processing loop self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start - def add_open_order(self, order): - self.open_orders[order['sid']] = order - \ No newline at end of file + def add_open_order(self, sid, amount): + pass + + def process_orders(self, event): + #TODO put real fill logic here, return a list of fills + return [{'sid':133, 'amount':-100}] \ No newline at end of file diff --git a/zipline/messaging.py b/zipline/messaging.py index 7adfd802..70268774 100644 --- a/zipline/messaging.py +++ b/zipline/messaging.py @@ -20,10 +20,6 @@ class ComponentHost(Component): Component.__init__(self) self.addresses = addresses - # workaround for defect in threaded use of strptime: - # http://bugs.python.org/issue11108 - qutil.parse_date("2012/02/13-10:04:28.114") - self.components = {} self.sync_register = {} self.timeout = datetime.timedelta(seconds=5) @@ -196,10 +192,10 @@ class ParallelBuffer(Component): self.drain() self.signal_done() else: - event = zp.DATASOURCE_UNFRAME(message) - self.append(event.source_id, event) + event = self.unframe(message) + self.append(event) self.send_next() - + def __len__(self): """ Buffer's length is same as internal map holding separate @@ -207,12 +203,12 @@ class ParallelBuffer(Component): """ return len(self.data_buffer) - def append(self, source_id, value): + def append(self, event): """ Add an event to the buffer for the source specified by source_id. """ - self.data_buffer[source_id].append(value) + self.data_buffer[event.source_id].append(event) self.received_count += 1 def next(self): @@ -228,7 +224,7 @@ class ParallelBuffer(Component): if len(events) == 0: continue cur = events - if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']): + if (earliest == None) or (cur[0].dt <= earliest[0].dt): earliest = cur if earliest != None: @@ -271,9 +267,15 @@ class ParallelBuffer(Component): event = self.next() if(event != None): - self.feed_socket.send(zp.FEED_FRAME(event), self.zmq.NOBLOCK) + self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK) self.sent_count += 1 + def unframe(self, msg): + return zp.DATASOURCE_UNFRAME(msg) + + def frame(self, event): + return zp.FEED_FRAME(event) + class MergedParallelBuffer(ParallelBuffer): """ @@ -293,18 +295,35 @@ class MergedParallelBuffer(ParallelBuffer): return #get the raw event from the passthrough transform. - result = self.data_buffer["PASSTHROUGH"].pop(0)['value'] + result = self.data_buffer["PASSTHROUGH"].pop(0).PASSTHROUGH for source, events in self.data_buffer.iteritems(): if source == "PASSTHROUGH": continue if len(events) > 0: cur = events.pop(0) - result[source] = cur['value'] + result.merge(cur) return result @property def get_id(self): return "MERGE" + + def unframe(self, msg): + return zp.TRANSFORM_UNFRAME(msg) + + def frame(self, event): + return zp.MERGE_FRAME(event) + + # + def append(self, event): + """ + :param event: a namedict with one entry. key is the name of the transform, value is the transformed value. + Add an event to the buffer for the source specified by + source_id. + """ + + self.data_buffer[event.__dict__.keys()[0]].append(event) + self.received_count += 1 class BaseTransform(Component): @@ -351,13 +370,10 @@ class BaseTransform(Component): self.signal_done() return - event = qp.FEED_UNFRAME(message) + event = zp.FEED_UNFRAME(message) cur_state = self.transform(event) - #TODO: do we want to relay the datetime again? maybe drop this? - #cur_state['dt'] = event['dt'] - cur_state['id'] = self.state['name'] - - self.result_socket.send(qp.TRANSFORM_FRAME(cur_state), self.zmq.NOBLOCK) + qutil.LOGGER.info("state of transform is: {state}".format(state=cur_state)) + self.result_socket.send(zp.TRANSFORM_FRAME(cur_state['name'], cur_state['value']), self.zmq.NOBLOCK) def transform(self, event): """ @@ -380,8 +396,21 @@ class PassthroughTransform(BaseTransform): def __init__(self): BaseTransform.__init__(self, "PASSTHROUGH") - def transform(self, event): - return {'value':event} + def do_work(self): + """ + Loops until feed's DONE message is received: + - receive an event from the data feed + - call transform (subclass' method) on event + - send the transformed event + """ + socks = dict(self.poll.poll(self.heartbeat_timeout)) #timeout after 2 seconds. + if self.feed_socket in socks and socks[self.feed_socket] == self.zmq.POLLIN: + message = self.feed_socket.recv() + if message == str(CONTROL_PROTOCOL.DONE): + self.signal_done() + return + #message is already FEED_FRAMEd, send it as the value. + self.result_socket.send(zp.TRANSFORM_FRAME("PASSTHROUGH", message), self.zmq.NOBLOCK) class DataSource(Component): diff --git a/zipline/protocol.py b/zipline/protocol.py index 2b60fbc9..864f4274 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -107,6 +107,9 @@ class namedict(object): def __eq__(self, other): return self.__dict__ == other.__dict__ + + def has_attr(self, name): + return self.__dict__.has_key(name) # ================ # Control Protocol @@ -235,7 +238,7 @@ def FEED_FRAME(event): assert isinstance(event, namedict) source_id = event.source_id ds_type = event.type - pack_date(event) + PACK_DATE(event) payload = event.__dict__ return msgpack.dumps(payload) @@ -245,7 +248,7 @@ def FEED_UNFRAME(msg): #TODO: anything we can do to assert more about the content of the dict? assert isinstance(payload, dict) rval = namedict(payload) - unpack_date(rval) + UNPACK_DATE(rval) return rval except TypeError: raise INVALID_FEED_FRAME(msg) @@ -265,21 +268,39 @@ def TRANSFORM_FRAME(name, value): """ assert isinstance(name, basestring) assert value != None + + if(name == 'ALGO_TIME'): + value = PACK_ALGO_DT(value) + return msgpack.dumps(tuple([name, value])) def TRANSFORM_UNFRAME(msg): + """ + :rtype: namedict with : + """ try: name, value = msgpack.loads(msg) #TODO: anything we can do to assert more about the content of the dict? assert isinstance(name, basestring) if(name == "PASSTHROUGH"): value = FEED_UNFRAME(value) + elif(name == "ALGO_TIME"): + value = UNPACK_ALGO_DT(value) return namedict({name : value}) except TypeError: raise INVALID_TRANSFORM_FRAME(msg) except ValueError: raise INVALID_TRANSFORM_FRAME(msg) - + +def PACK_ALGO_DT(value): + value = namedict({'dt' : value}) + PACK_DATE(value) + return value.__dict__ + +def UNPACK_ALGO_DT(value): + value = namedict(value) + UNPACK_DATE(value) + return value.dt # ================== # Merge Protocol @@ -294,7 +315,9 @@ def MERGE_FRAME(event): """ assert isinstance(event, namedict) assert isinstance(event.dt, datetime.datetime) - pack_date(event) + PACK_DATE(event) + if(event.has_attr('ALGO_TIME')): + event.ALGO_TIME = PACK_ALGO_DT(event.ALGO_TIME) payload = event.__dict__ return msgpack.dumps(payload) @@ -304,9 +327,11 @@ def MERGE_UNFRAME(msg): #TODO: anything we can do to assert more about the content of the dict? assert isinstance(payload, dict) payload = namedict(payload) + if(payload.has_attr('ALGO_TIME')): + payload.ALGO_TIME = UNPACK_ALGO_DT(payload.ALGO_TIME) assert isinstance(payload.epoch, numbers.Integral) assert isinstance(payload.micros, numbers.Integral) - unpack_date(payload) + UNPACK_DATE(payload) return payload except TypeError: raise INVALID_MERGE_FRAME(msg) @@ -340,7 +365,7 @@ def TRADE_FRAME(event): assert isinstance(event.sid, int) assert isinstance(event.price, float) assert isinstance(event.volume, int) - pack_date(event) + PACK_DATE(event) qutil.LOGGER.info("event is: {event}".format(event=event.__dict__)) return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id])) @@ -355,7 +380,7 @@ def TRADE_UNFRAME(msg): assert isinstance(epoch, numbers.Integral) assert isinstance(micros, numbers.Integral) rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id}) - unpack_date(rval) + UNPACK_DATE(rval) qutil.LOGGER.info("unpacked Trade: {trade}".format(trade=rval.__dict__)) return rval except TypeError: @@ -389,7 +414,7 @@ def ORDER_UNFRAME(msg): # Date Helpers # ================= -def pack_date(event): +def PACK_DATE(event): assert isinstance(event.dt, datetime.datetime) assert event.dt.tzinfo == pytz.utc #utc only please epoch = long(event.dt.strftime('%s')) @@ -398,7 +423,7 @@ def pack_date(event): del(event.__dict__['dt']) return event -def unpack_date(payload): +def UNPACK_DATE(payload): assert isinstance(payload.epoch, numbers.Integral) assert isinstance(payload.micros, numbers.Integral) dt = datetime.datetime.fromtimestamp(payload.epoch) diff --git a/zipline/sources.py b/zipline/sources.py index 7c8e7a59..ea2edf73 100644 --- a/zipline/sources.py +++ b/zipline/sources.py @@ -3,6 +3,7 @@ Provides data handlers that can push messages to a zipline.core.DataFeed """ import datetime import random +import pytz import zipline.util as qutil import zipline.messaging as zm @@ -14,7 +15,8 @@ class TradeDataSource(zm.DataSource): """ :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME` :rtype: None """ - message = zp.TRADE_FRAME(self.get_id, event) + event.source_id = self.get_id + message = zp.DATASOURCE_FRAME(event) self.data_socket.send(message) class RandomEquityTrades(TradeDataSource): @@ -25,7 +27,7 @@ class RandomEquityTrades(TradeDataSource): self.count = count self.incr = 0 self.sid = sid - self.trade_start = datetime.datetime.now() + self.trade_start = datetime.datetime.now().replace(tzinfo=pytz.utc) self.minute = datetime.timedelta(minutes=1) self.price = random.uniform(5.0, 50.0) @@ -40,12 +42,12 @@ class RandomEquityTrades(TradeDataSource): return self.price = self.price + random.uniform(-0.05, 0.05) - self.send(self.sid, self.price, random.randrange(100,10000,100), qutil.format_date(self.trade_start + (self.minute * self.incr))) + self._send(self.sid, self.price, random.randrange(100,10000,100), self.trade_start + (self.minute * self.incr)) self.incr += 1 - def send(self, sid, price, volume, dt): - message = zp.TRADE_FRAME(self.get_id(), sid, price, volume, dt) - self.data_socket.send(message) + def _send(self, sid, price, volume, dt): + event = zp.namedict({'source_id': self.get_id, "type" : "TRADE", "sid":sid, "price":price, "volume":volume, "dt":dt}) + self.send(event) class SpecificEquityTrades(TradeDataSource): diff --git a/zipline/test/client.py b/zipline/test/client.py index 3279bc2a..c8b7674a 100644 --- a/zipline/test/client.py +++ b/zipline/test/client.py @@ -3,7 +3,7 @@ import zipline.util as qutil import zipline.messaging as qmsg from zipline.finance.trading import TradeSimulationClient -from zipline.protocol import CONTROL_PROTOCOL +import zipline.protocol as zp class TestClient(qmsg.Component): """no-op client - Just connects to the merge and counts messages. compares received message count to the expected count.""" @@ -29,7 +29,7 @@ class TestClient(qmsg.Component): if self.data_feed in socks and socks[self.data_feed] == self.zmq.POLLIN: msg = self.data_feed.recv() - if msg == str(CONTROL_PROTOCOL.DONE): + if msg == str(zp.CONTROL_PROTOCOL.DONE): qutil.LOGGER.info("Client is DONE!") self.signal_done() self.utest.assertEqual(self.expected_msg_count, self.received_count, @@ -40,20 +40,23 @@ class TestClient(qmsg.Component): self.received_count += 1 event = zp.MERGE_UNFRAME(msg) if(self.prev_dt != None): - if(not event['dt'] >= self.prev_dt): + if(not event.dt >= self.prev_dt): raise Exception("Message out of order: {date} after {prev}".format(date=event['dt'], prev=prev_dt)) - self.prev_dt = event['dt'] + self.prev_dt = event.dt if(self.received_count % 100 == 0): qutil.LOGGER.info("received {n} messages".format(n=self.received_count)) class TestTradingClient(TradeSimulationClient): - def __init__(self): + def __init__(self, count): TradeSimulationClient.__init__(self) + self.count = count + self.incr = 0 def handle_events(self, event_queue): #place an order for 100 shares of sid:133 - self.order(133,100) - - \ No newline at end of file + if(self.incr >= self.count): + self.order(133, 100) + self.incr += 1 + diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index b9a1442d..5e69fd0e 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -13,6 +13,7 @@ import zipline.protocol as zp from zipline.test.client import TestTradingClient from zipline.test.dummy import ThreadPoolExecutorMixin from zipline.sources import SpecificEquityTrades +from zipline.finance.trading import TradeSimulator class FinanceTestCase(ThreadPoolExecutorMixin, TestCase): @@ -53,7 +54,10 @@ class FinanceTestCase(ThreadPoolExecutorMixin, TestCase): self.assertEqual(zp.namedict(trade), event) def test_order_protocol(self): - raise NotImplementedError + order_msg = zp.ORDER_FRAME(133, 100) + sid, amount = zp.ORDER_UNFRAME(order_msg) + self.assertEqual(sid, 133) + self.assertEqual(amount, 100) def test_trading_calendar(self): known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y") @@ -91,9 +95,10 @@ class FinanceTestCase(ThreadPoolExecutorMixin, TestCase): [100,100,100,100], datetime.datetime.strptime("02/15/2012","%m/%d/%Y"), datetime.timedelta(days=1))) - client = TestTradingClient() + client = TestTradingClient(0) + order_sim = TradeSimulator() - sim.register_components([set1, client]) + sim.register_components([client, order_sim, set1]) sim.register_controller( con ) # Simulation diff --git a/zipline/transforms/technical.py b/zipline/transforms/technical.py index 3ad41ccb..d235a22a 100644 --- a/zipline/transforms/technical.py +++ b/zipline/transforms/technical.py @@ -24,15 +24,15 @@ class MovingAverage(BaseTransform): """Update the moving average with the latest data point.""" self.events.append(event) - self.current_total += event['price'] - event_date = qutil.parse_date(event['dt']) + self.current_total += event.price + event_date = event.dt index = 0 for cur_event in self.events: - cur_date = qutil.parse_date(cur_event['dt']) + cur_date = cur_event.dt if(cur_date - event_date) >= self.window: self.events.pop(index) - self.current_total -= cur_event['price'] + self.current_total -= cur_event.price index += 1 else: break diff --git a/zipline/util.py b/zipline/util.py index c972f58a..b064306a 100644 --- a/zipline/util.py +++ b/zipline/util.py @@ -26,26 +26,3 @@ def configure_logging(loglevel=logging.DEBUG): ) LOGGER.addHandler(handler) LOGGER.info("logging started...") - -def parse_date(dt_str): - """ - Parse strings according to the same format as generated by - format_date. - """ - if(dt_str == None): - return None - parts = dt_str.split(".") - dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S').replace( - microsecond=int(parts[1]+"000")).replace(tzinfo = pytz.utc - ) - return dt - -def format_date(dt): - """ - Format the date into a date with millesecond resolution and - string/alphabetical sorting that is equivalent to datetime sorting. - """ - if(dt == None): - return None - dt_str = dt.strftime('%Y/%m/%d-%H:%M:%S') + "." + str(dt.microsecond / 1000) - return dt_str