merged master branch that contains control component, error handling with finance branch.

This commit is contained in:
fawce
2012-03-01 23:44:35 -05:00
parent 6283906434
commit 2aa4af86aa
6 changed files with 165 additions and 56 deletions
+4 -8
View File
@@ -1,4 +1,3 @@
import json
import datetime
import pytz
import math
@@ -76,8 +75,9 @@ class OrderDataSource(qmsg.DataSource):
self.last_iteration_duration = datetime.timedelta(seconds=0)
self.sent_count = 0
@property
def get_type(self):
return zp.FINANCE_COMPONENT.ORDER_SOURCE
return zp.DATASOURCE_TYPE.ORDER
def open(self):
qmsg.DataSource.open(self)
@@ -123,7 +123,7 @@ class OrderDataSource(qmsg.DataSource):
self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start
dt = self.simulation_dt + self.last_iteration_duration
order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER})
order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt})
self.send(order_event)
count += 1
@@ -133,14 +133,10 @@ class OrderDataSource(qmsg.DataSource):
if(count == 0):
self.send_dummy()
self.sent_count += 1
def send(self, order_event):
message = zp.DATASOURCE_FRAME(order_event)
self.data_socket.send(message)
def send_dummy(self):
dt = self.simulation_dt + self.last_iteration_duration
dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER})
dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt})
self.send(dummy_order)
+61 -31
View File
@@ -3,10 +3,10 @@ Commonly used messaging components.
"""
import datetime
import ujson as json
import zipline.util as qutil
from zipline.component import Component
import zipline.protocol as zp
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
COMPONENT_STATE
@@ -220,20 +220,27 @@ class ParallelBuffer(Component):
self.signal_done()
else:
try:
event = json.loads(message)
# JSON deserialization error
except ValueError as exc:
event = self.unframe(message)
# deserialization error
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
try:
self.append(event[u'id'], event)
self.append(event)
self.send_next()
# Invalid message
except KeyError as exc:
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
#
def unframe(self, msg):
return zp.DATASOURCE_UNFRAME(msg)
def frame(self, event):
return zp.FEED_FRAME(event)
# -------------
# Flow Control
# -------------
@@ -254,17 +261,16 @@ class ParallelBuffer(Component):
return
event = self.next()
if event != None:
self.feed_socket.send(json.dumps(event), self.zmq.NOBLOCK)
if(event != None):
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
self.sent_count += 1
def append(self, source_id, value):
def append(self, event):
"""
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[source_id].append(value)
self.data_buffer[event.source_id].append(event)
self.received_count += 1
def next(self):
@@ -280,7 +286,7 @@ class ParallelBuffer(Component):
if len(events) == 0:
continue
cur = events
if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
if (earliest == None) or (cur[0].dt <= earliest[0].dt):
earliest = cur
if earliest != None:
@@ -350,15 +356,32 @@ class MergedParallelBuffer(ParallelBuffer):
if(not(self.is_full() or self.draining)):
return
#
#get the raw event from the passthrough transform.
result = self.data_buffer["PASSTHROUGH"].pop(0)['value']
result = self.data_buffer[zp.TRANSFORM_TYPE.PASSTHROUGH].pop(0).PASSTHROUGH
for source, events in self.data_buffer.iteritems():
if source == "PASSTHROUGH":
if source == zp.TRANSFORM_TYPE.PASSTHROUGH:
continue
if len(events) > 0:
cur = events.pop(0)
result[source] = cur['value']
result.merge(cur)
return result
def unframe(self, msg):
return zp.TRANSFORM_UNFRAME(msg)
def frame(self, event):
return zp.MERGE_FRAME(event)
def append(self, event):
"""
:param event: a namedict with one entry. key is the name of the transform, value is the transformed value.
Add an event to the buffer for the source specified by
source_id.
"""
self.data_buffer[event.__dict__.keys()[0]].append(event)
self.received_count += 1
class BaseTransform(Component):
@@ -425,14 +448,12 @@ class BaseTransform(Component):
return
try:
event = json.loads(message)
except ValueError as exc:
event = self.unframe(message)
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
try:
cur_state = self.transform(event)
cur_state['dt'] = event['dt']
cur_state['id'] = self.state['name']
# This is overloaded, so it can fail in all sorts of
# unknown ways. Its best to catch it in the
@@ -441,12 +462,18 @@ class BaseTransform(Component):
return self.signal_exception(exc)
try:
json_frame = json.dumps(cur_state)
except ValueError as exc:
transform_frame = self.frame(cur_state)
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
self.result_socket.send(json_frame, self.zmq.NOBLOCK)
self.result_socket.send(transform_frame, self.zmq.NOBLOCK)
def frame(self, cur_state):
return zp.TRANSFORM_FRAME(cur_state['name'], cur_state['value'])
def unframe(self, msg):
return zp.FEED_UNFRAME(msg)
def transform(self, event):
"""
Must return the transformed value as a map with::
@@ -487,7 +514,7 @@ class PassthroughTransform(BaseTransform):
#TODO, could save some cycles by skipping the _UNFRAME call and just setting value to original msg string.
def transform(self, event):
return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.DATASOURCE_FRAME(event) }
return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.FEED_FRAME(event) }
class DataSource(Component):
@@ -520,14 +547,17 @@ class DataSource(Component):
"""
Emit data.
"""
assert isinstance(event, dict)
assert isinstance(event, zp.namedict)
event['id'] = self.id
event['type'] = self.get_type()
event.__dict__['source_id'] = self.get_id
event.__dict__['type'] = self.get_type
try:
json_frame = json.dumps(event)
except ValueError as exc:
ds_frame = self.frame(event)
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
self.data_socket.send(json_frame)
self.data_socket.send(ds_frame)
def frame(self, event):
return zp.DATASOURCE_FRAME(event)
+5 -3
View File
@@ -61,13 +61,15 @@ def Enum(*options):
_fields_ = [(o, c_ubyte) for o in options]
return cstruct(*range(len(options)))
class InvalidFrame(Exception):
def __init__(self, got):
self.got = got
def FrameExceptionFactory(name):
"""
Exception factory with a closure around the frame class name.
"""
class InvalidFrame(Exception):
def __init__(self, got):
self.got = got
class NamedInvalidFrame(InvalidFrame):
def __str__(self):
return "Invalid {framecls} Frame: {got}".format(
framecls = name,
+85
View File
@@ -0,0 +1,85 @@
"""
Simulator hosts all the components necessary to execute a simluation. See :py:method""
"""
import threading
import mock
from collections import defaultdict
from zipline.monitor import Controller
from zipline.messaging import ComponentHost
import zipline.util as qutil
class AddressAllocator(object):
def __init__(self, ns):
self.idx = 0
self.sockets = [
'tcp://127.0.0.1:%s' % (10000 + n)
for n in xrange(ns)
]
def lease(self, n):
sockets = self.sockets[self.idx:self.idx+n]
self.idx += n
return sockets
def reaquire(self, *conn):
pass
#
class Simulator(ComponentHost):
def __init__(self, addresses):
ComponentHost.__init__(self, addresses)
self.subthreads = []
self.running = False
def launch_controller(self):
thread = threading.Thread(target=self.controller.run)
thread.start()
self.subthreads.append(thread)
return thread
def simulate(self):
thread = threading.Thread(target=self.run)
thread.start()
self.subthreads.append(thread)
self.running = True
return thread
def did_clean_shutdown(self):
return not any([t.isAlive() for t in self.subthreads])
def shutdown(self):
"""
Destroy all tracked components.
"""
if not self.running:
return
try:
self.controller.shutdown(context=self.context)
except:
import pdb; pdb.set_trace()
for component in self.components.itervalues():
component.shutdown()
for thread in self.subthreads:
if thread.is_alive():
thread._Thread__stop()
self.running = False
assert self.did_clean_shutdown()
def launch_component(self, component):
thread = threading.Thread(target=component.run)
thread.start()
self.subthreads.append(thread)
return thread
+8 -7
View File
@@ -1,6 +1,3 @@
#import logging
import ujson as json
import zipline.util as qutil
import zipline.messaging as qmsg
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE
@@ -45,10 +42,10 @@ class TestClient(qmsg.Component):
self.received_count += 1
try:
event = json.loads(msg)
event = self.unframe(msg)
# JSON deserialization error
except ValueError as exc:
# deserialization error
except zp.InvalidFrame as exc:
return self.signal_exception(exc)
if self.prev_dt != None:
@@ -59,10 +56,14 @@ class TestClient(qmsg.Component):
)
)
else:
self.prev_dt = event['dt']
self.prev_dt = event.dt
if self.received_count % 100 == 0:
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
def unframe(self, msg):
return zp.MERGE_UNFRARME(msg)
class TestTradingClient(TradeSimulationClient):
+2 -7
View File
@@ -157,13 +157,8 @@ class FinanceTestCase(TestCase):
# ----------
sim.simulate().join()
# Stop Running
# ------------
# TODO: less abrupt later, just shove a StopIteration
# down the pipe to make it stop spinning
sim.cuc._Thread__stop()
# TODO: Make more assertions about the final state of the components.
self.assertEqual(sim.feed.pending_messages(), 0,
"The feed should be drained of all messages, found {n} remaining."
.format(n=sim.feed.pending_messages())