From 2aa4af86aa170a067c0466d12afa261e2c02bb30 Mon Sep 17 00:00:00 2001 From: fawce Date: Thu, 1 Mar 2012 23:44:35 -0500 Subject: [PATCH] merged master branch that contains control component, error handling with finance branch. --- zipline/finance/trading.py | 12 ++--- zipline/messaging.py | 92 ++++++++++++++++++++++++------------ zipline/protocol.py | 8 ++-- zipline/simulator.py | 85 +++++++++++++++++++++++++++++++++ zipline/test/client.py | 15 +++--- zipline/test/test_finance.py | 9 +--- 6 files changed, 165 insertions(+), 56 deletions(-) create mode 100644 zipline/simulator.py diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 9420eb96..2ecdc197 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -1,4 +1,3 @@ -import json import datetime import pytz import math @@ -76,8 +75,9 @@ class OrderDataSource(qmsg.DataSource): self.last_iteration_duration = datetime.timedelta(seconds=0) self.sent_count = 0 + @property def get_type(self): - return zp.FINANCE_COMPONENT.ORDER_SOURCE + return zp.DATASOURCE_TYPE.ORDER def open(self): qmsg.DataSource.open(self) @@ -123,7 +123,7 @@ class OrderDataSource(qmsg.DataSource): self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start dt = self.simulation_dt + self.last_iteration_duration - order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER}) + order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt}) self.send(order_event) count += 1 @@ -133,14 +133,10 @@ class OrderDataSource(qmsg.DataSource): if(count == 0): self.send_dummy() self.sent_count += 1 - - def send(self, order_event): - message = zp.DATASOURCE_FRAME(order_event) - self.data_socket.send(message) def send_dummy(self): dt = self.simulation_dt + self.last_iteration_duration - dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER}) + dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt}) self.send(dummy_order) diff --git a/zipline/messaging.py b/zipline/messaging.py index 5be276b8..79296b1e 100644 --- a/zipline/messaging.py +++ b/zipline/messaging.py @@ -3,10 +3,10 @@ Commonly used messaging components. """ import datetime -import ujson as json import zipline.util as qutil from zipline.component import Component +import zipline.protocol as zp from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \ COMPONENT_STATE @@ -220,20 +220,27 @@ class ParallelBuffer(Component): self.signal_done() else: try: - event = json.loads(message) - - # JSON deserialization error - except ValueError as exc: + event = self.unframe(message) + # deserialization error + except zp.InvalidFrame as exc: return self.signal_exception(exc) try: - self.append(event[u'id'], event) + self.append(event) self.send_next() # Invalid message - except KeyError as exc: + except zp.InvalidFrame as exc: return self.signal_exception(exc) + # + def unframe(self, msg): + return zp.DATASOURCE_UNFRAME(msg) + + def frame(self, event): + return zp.FEED_FRAME(event) + + # ------------- # Flow Control # ------------- @@ -254,17 +261,16 @@ class ParallelBuffer(Component): return event = self.next() - - if event != None: - self.feed_socket.send(json.dumps(event), self.zmq.NOBLOCK) + if(event != None): + self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK) self.sent_count += 1 - def append(self, source_id, value): + def append(self, event): """ Add an event to the buffer for the source specified by source_id. """ - self.data_buffer[source_id].append(value) + self.data_buffer[event.source_id].append(event) self.received_count += 1 def next(self): @@ -280,7 +286,7 @@ class ParallelBuffer(Component): if len(events) == 0: continue cur = events - if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']): + if (earliest == None) or (cur[0].dt <= earliest[0].dt): earliest = cur if earliest != None: @@ -350,15 +356,32 @@ class MergedParallelBuffer(ParallelBuffer): if(not(self.is_full() or self.draining)): return + # #get the raw event from the passthrough transform. - result = self.data_buffer["PASSTHROUGH"].pop(0)['value'] + result = self.data_buffer[zp.TRANSFORM_TYPE.PASSTHROUGH].pop(0).PASSTHROUGH for source, events in self.data_buffer.iteritems(): - if source == "PASSTHROUGH": + if source == zp.TRANSFORM_TYPE.PASSTHROUGH: continue if len(events) > 0: cur = events.pop(0) - result[source] = cur['value'] + result.merge(cur) return result + + def unframe(self, msg): + return zp.TRANSFORM_UNFRAME(msg) + + def frame(self, event): + return zp.MERGE_FRAME(event) + + def append(self, event): + """ + :param event: a namedict with one entry. key is the name of the transform, value is the transformed value. + Add an event to the buffer for the source specified by + source_id. + """ + + self.data_buffer[event.__dict__.keys()[0]].append(event) + self.received_count += 1 class BaseTransform(Component): @@ -425,14 +448,12 @@ class BaseTransform(Component): return try: - event = json.loads(message) - except ValueError as exc: + event = self.unframe(message) + except zp.InvalidFrame as exc: return self.signal_exception(exc) try: cur_state = self.transform(event) - cur_state['dt'] = event['dt'] - cur_state['id'] = self.state['name'] # This is overloaded, so it can fail in all sorts of # unknown ways. Its best to catch it in the @@ -441,12 +462,18 @@ class BaseTransform(Component): return self.signal_exception(exc) try: - json_frame = json.dumps(cur_state) - except ValueError as exc: + transform_frame = self.frame(cur_state) + except zp.InvalidFrame as exc: return self.signal_exception(exc) - self.result_socket.send(json_frame, self.zmq.NOBLOCK) - + self.result_socket.send(transform_frame, self.zmq.NOBLOCK) + + def frame(self, cur_state): + return zp.TRANSFORM_FRAME(cur_state['name'], cur_state['value']) + + def unframe(self, msg): + return zp.FEED_UNFRAME(msg) + def transform(self, event): """ Must return the transformed value as a map with:: @@ -487,7 +514,7 @@ class PassthroughTransform(BaseTransform): #TODO, could save some cycles by skipping the _UNFRAME call and just setting value to original msg string. def transform(self, event): - return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.DATASOURCE_FRAME(event) } + return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.FEED_FRAME(event) } class DataSource(Component): @@ -520,14 +547,17 @@ class DataSource(Component): """ Emit data. """ - assert isinstance(event, dict) + assert isinstance(event, zp.namedict) - event['id'] = self.id - event['type'] = self.get_type() + event.__dict__['source_id'] = self.get_id + event.__dict__['type'] = self.get_type try: - json_frame = json.dumps(event) - except ValueError as exc: + ds_frame = self.frame(event) + except zp.InvalidFrame as exc: return self.signal_exception(exc) - self.data_socket.send(json_frame) + self.data_socket.send(ds_frame) + + def frame(self, event): + return zp.DATASOURCE_FRAME(event) diff --git a/zipline/protocol.py b/zipline/protocol.py index 5e007715..24a9a243 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -61,13 +61,15 @@ def Enum(*options): _fields_ = [(o, c_ubyte) for o in options] return cstruct(*range(len(options))) +class InvalidFrame(Exception): + def __init__(self, got): + self.got = got + def FrameExceptionFactory(name): """ Exception factory with a closure around the frame class name. """ - class InvalidFrame(Exception): - def __init__(self, got): - self.got = got + class NamedInvalidFrame(InvalidFrame): def __str__(self): return "Invalid {framecls} Frame: {got}".format( framecls = name, diff --git a/zipline/simulator.py b/zipline/simulator.py new file mode 100644 index 00000000..1e3e3f5f --- /dev/null +++ b/zipline/simulator.py @@ -0,0 +1,85 @@ +""" +Simulator hosts all the components necessary to execute a simluation. See :py:method"" +""" + +import threading +import mock +from collections import defaultdict +from zipline.monitor import Controller +from zipline.messaging import ComponentHost +import zipline.util as qutil + +class AddressAllocator(object): + + def __init__(self, ns): + self.idx = 0 + self.sockets = [ + 'tcp://127.0.0.1:%s' % (10000 + n) + for n in xrange(ns) + ] + + def lease(self, n): + sockets = self.sockets[self.idx:self.idx+n] + self.idx += n + return sockets + + def reaquire(self, *conn): + pass + +# +class Simulator(ComponentHost): + + def __init__(self, addresses): + ComponentHost.__init__(self, addresses) + self.subthreads = [] + self.running = False + + def launch_controller(self): + thread = threading.Thread(target=self.controller.run) + thread.start() + + self.subthreads.append(thread) + return thread + + def simulate(self): + thread = threading.Thread(target=self.run) + thread.start() + + self.subthreads.append(thread) + self.running = True + + return thread + + def did_clean_shutdown(self): + return not any([t.isAlive() for t in self.subthreads]) + + def shutdown(self): + """ + Destroy all tracked components. + """ + + if not self.running: + return + + try: + self.controller.shutdown(context=self.context) + except: + import pdb; pdb.set_trace() + + for component in self.components.itervalues(): + component.shutdown() + + for thread in self.subthreads: + if thread.is_alive(): + thread._Thread__stop() + + self.running = False + + assert self.did_clean_shutdown() + + def launch_component(self, component): + thread = threading.Thread(target=component.run) + thread.start() + + self.subthreads.append(thread) + return thread diff --git a/zipline/test/client.py b/zipline/test/client.py index 20bb8df7..2b1ac006 100644 --- a/zipline/test/client.py +++ b/zipline/test/client.py @@ -1,6 +1,3 @@ -#import logging -import ujson as json - import zipline.util as qutil import zipline.messaging as qmsg from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE @@ -45,10 +42,10 @@ class TestClient(qmsg.Component): self.received_count += 1 try: - event = json.loads(msg) + event = self.unframe(msg) - # JSON deserialization error - except ValueError as exc: + # deserialization error + except zp.InvalidFrame as exc: return self.signal_exception(exc) if self.prev_dt != None: @@ -59,10 +56,14 @@ class TestClient(qmsg.Component): ) ) else: - self.prev_dt = event['dt'] + self.prev_dt = event.dt if self.received_count % 100 == 0: qutil.LOGGER.info("received {n} messages".format(n=self.received_count)) + + def unframe(self, msg): + return zp.MERGE_UNFRARME(msg) + class TestTradingClient(TradeSimulationClient): diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index 0a9869e0..d1f35271 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -157,13 +157,8 @@ class FinanceTestCase(TestCase): # ---------- sim.simulate().join() - # Stop Running - # ------------ - - # TODO: less abrupt later, just shove a StopIteration - # down the pipe to make it stop spinning - sim.cuc._Thread__stop() - + + # TODO: Make more assertions about the final state of the components. self.assertEqual(sim.feed.pending_messages(), 0, "The feed should be drained of all messages, found {n} remaining." .format(n=sim.feed.pending_messages())