mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 06:47:57 +08:00
merged master branch that contains control component, error handling with finance branch.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import datetime
|
||||
import pytz
|
||||
import math
|
||||
@@ -76,8 +75,9 @@ class OrderDataSource(qmsg.DataSource):
|
||||
self.last_iteration_duration = datetime.timedelta(seconds=0)
|
||||
self.sent_count = 0
|
||||
|
||||
@property
|
||||
def get_type(self):
|
||||
return zp.FINANCE_COMPONENT.ORDER_SOURCE
|
||||
return zp.DATASOURCE_TYPE.ORDER
|
||||
|
||||
def open(self):
|
||||
qmsg.DataSource.open(self)
|
||||
@@ -123,7 +123,7 @@ class OrderDataSource(qmsg.DataSource):
|
||||
|
||||
self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start
|
||||
dt = self.simulation_dt + self.last_iteration_duration
|
||||
order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER})
|
||||
order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt})
|
||||
|
||||
self.send(order_event)
|
||||
count += 1
|
||||
@@ -133,14 +133,10 @@ class OrderDataSource(qmsg.DataSource):
|
||||
if(count == 0):
|
||||
self.send_dummy()
|
||||
self.sent_count += 1
|
||||
|
||||
def send(self, order_event):
|
||||
message = zp.DATASOURCE_FRAME(order_event)
|
||||
self.data_socket.send(message)
|
||||
|
||||
def send_dummy(self):
|
||||
dt = self.simulation_dt + self.last_iteration_duration
|
||||
dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt, "source_id":self.get_id, "type":zp.DATASOURCE_TYPE.ORDER})
|
||||
dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt})
|
||||
self.send(dummy_order)
|
||||
|
||||
|
||||
|
||||
+61
-31
@@ -3,10 +3,10 @@ Commonly used messaging components.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import ujson as json
|
||||
|
||||
import zipline.util as qutil
|
||||
from zipline.component import Component
|
||||
import zipline.protocol as zp
|
||||
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
|
||||
COMPONENT_STATE
|
||||
|
||||
@@ -220,20 +220,27 @@ class ParallelBuffer(Component):
|
||||
self.signal_done()
|
||||
else:
|
||||
try:
|
||||
event = json.loads(message)
|
||||
|
||||
# JSON deserialization error
|
||||
except ValueError as exc:
|
||||
event = self.unframe(message)
|
||||
# deserialization error
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
try:
|
||||
self.append(event[u'id'], event)
|
||||
self.append(event)
|
||||
self.send_next()
|
||||
|
||||
# Invalid message
|
||||
except KeyError as exc:
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
#
|
||||
def unframe(self, msg):
|
||||
return zp.DATASOURCE_UNFRAME(msg)
|
||||
|
||||
def frame(self, event):
|
||||
return zp.FEED_FRAME(event)
|
||||
|
||||
|
||||
# -------------
|
||||
# Flow Control
|
||||
# -------------
|
||||
@@ -254,17 +261,16 @@ class ParallelBuffer(Component):
|
||||
return
|
||||
|
||||
event = self.next()
|
||||
|
||||
if event != None:
|
||||
self.feed_socket.send(json.dumps(event), self.zmq.NOBLOCK)
|
||||
if(event != None):
|
||||
self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK)
|
||||
self.sent_count += 1
|
||||
|
||||
def append(self, source_id, value):
|
||||
def append(self, event):
|
||||
"""
|
||||
Add an event to the buffer for the source specified by
|
||||
source_id.
|
||||
"""
|
||||
self.data_buffer[source_id].append(value)
|
||||
self.data_buffer[event.source_id].append(event)
|
||||
self.received_count += 1
|
||||
|
||||
def next(self):
|
||||
@@ -280,7 +286,7 @@ class ParallelBuffer(Component):
|
||||
if len(events) == 0:
|
||||
continue
|
||||
cur = events
|
||||
if (earliest == None) or (cur[0]['dt'] <= earliest[0]['dt']):
|
||||
if (earliest == None) or (cur[0].dt <= earliest[0].dt):
|
||||
earliest = cur
|
||||
|
||||
if earliest != None:
|
||||
@@ -350,15 +356,32 @@ class MergedParallelBuffer(ParallelBuffer):
|
||||
if(not(self.is_full() or self.draining)):
|
||||
return
|
||||
|
||||
#
|
||||
#get the raw event from the passthrough transform.
|
||||
result = self.data_buffer["PASSTHROUGH"].pop(0)['value']
|
||||
result = self.data_buffer[zp.TRANSFORM_TYPE.PASSTHROUGH].pop(0).PASSTHROUGH
|
||||
for source, events in self.data_buffer.iteritems():
|
||||
if source == "PASSTHROUGH":
|
||||
if source == zp.TRANSFORM_TYPE.PASSTHROUGH:
|
||||
continue
|
||||
if len(events) > 0:
|
||||
cur = events.pop(0)
|
||||
result[source] = cur['value']
|
||||
result.merge(cur)
|
||||
return result
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.TRANSFORM_UNFRAME(msg)
|
||||
|
||||
def frame(self, event):
|
||||
return zp.MERGE_FRAME(event)
|
||||
|
||||
def append(self, event):
|
||||
"""
|
||||
:param event: a namedict with one entry. key is the name of the transform, value is the transformed value.
|
||||
Add an event to the buffer for the source specified by
|
||||
source_id.
|
||||
"""
|
||||
|
||||
self.data_buffer[event.__dict__.keys()[0]].append(event)
|
||||
self.received_count += 1
|
||||
|
||||
|
||||
class BaseTransform(Component):
|
||||
@@ -425,14 +448,12 @@ class BaseTransform(Component):
|
||||
return
|
||||
|
||||
try:
|
||||
event = json.loads(message)
|
||||
except ValueError as exc:
|
||||
event = self.unframe(message)
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
try:
|
||||
cur_state = self.transform(event)
|
||||
cur_state['dt'] = event['dt']
|
||||
cur_state['id'] = self.state['name']
|
||||
|
||||
# This is overloaded, so it can fail in all sorts of
|
||||
# unknown ways. Its best to catch it in the
|
||||
@@ -441,12 +462,18 @@ class BaseTransform(Component):
|
||||
return self.signal_exception(exc)
|
||||
|
||||
try:
|
||||
json_frame = json.dumps(cur_state)
|
||||
except ValueError as exc:
|
||||
transform_frame = self.frame(cur_state)
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
self.result_socket.send(json_frame, self.zmq.NOBLOCK)
|
||||
|
||||
self.result_socket.send(transform_frame, self.zmq.NOBLOCK)
|
||||
|
||||
def frame(self, cur_state):
|
||||
return zp.TRANSFORM_FRAME(cur_state['name'], cur_state['value'])
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.FEED_UNFRAME(msg)
|
||||
|
||||
def transform(self, event):
|
||||
"""
|
||||
Must return the transformed value as a map with::
|
||||
@@ -487,7 +514,7 @@ class PassthroughTransform(BaseTransform):
|
||||
|
||||
#TODO, could save some cycles by skipping the _UNFRAME call and just setting value to original msg string.
|
||||
def transform(self, event):
|
||||
return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.DATASOURCE_FRAME(event) }
|
||||
return {'name':zp.TRANSFORM_TYPE.PASSTHROUGH, 'value': zp.FEED_FRAME(event) }
|
||||
|
||||
|
||||
class DataSource(Component):
|
||||
@@ -520,14 +547,17 @@ class DataSource(Component):
|
||||
"""
|
||||
Emit data.
|
||||
"""
|
||||
assert isinstance(event, dict)
|
||||
assert isinstance(event, zp.namedict)
|
||||
|
||||
event['id'] = self.id
|
||||
event['type'] = self.get_type()
|
||||
event.__dict__['source_id'] = self.get_id
|
||||
event.__dict__['type'] = self.get_type
|
||||
|
||||
try:
|
||||
json_frame = json.dumps(event)
|
||||
except ValueError as exc:
|
||||
ds_frame = self.frame(event)
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
self.data_socket.send(json_frame)
|
||||
self.data_socket.send(ds_frame)
|
||||
|
||||
def frame(self, event):
|
||||
return zp.DATASOURCE_FRAME(event)
|
||||
|
||||
+5
-3
@@ -61,13 +61,15 @@ def Enum(*options):
|
||||
_fields_ = [(o, c_ubyte) for o in options]
|
||||
return cstruct(*range(len(options)))
|
||||
|
||||
class InvalidFrame(Exception):
|
||||
def __init__(self, got):
|
||||
self.got = got
|
||||
|
||||
def FrameExceptionFactory(name):
|
||||
"""
|
||||
Exception factory with a closure around the frame class name.
|
||||
"""
|
||||
class InvalidFrame(Exception):
|
||||
def __init__(self, got):
|
||||
self.got = got
|
||||
class NamedInvalidFrame(InvalidFrame):
|
||||
def __str__(self):
|
||||
return "Invalid {framecls} Frame: {got}".format(
|
||||
framecls = name,
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Simulator hosts all the components necessary to execute a simluation. See :py:method""
|
||||
"""
|
||||
|
||||
import threading
|
||||
import mock
|
||||
from collections import defaultdict
|
||||
from zipline.monitor import Controller
|
||||
from zipline.messaging import ComponentHost
|
||||
import zipline.util as qutil
|
||||
|
||||
class AddressAllocator(object):
|
||||
|
||||
def __init__(self, ns):
|
||||
self.idx = 0
|
||||
self.sockets = [
|
||||
'tcp://127.0.0.1:%s' % (10000 + n)
|
||||
for n in xrange(ns)
|
||||
]
|
||||
|
||||
def lease(self, n):
|
||||
sockets = self.sockets[self.idx:self.idx+n]
|
||||
self.idx += n
|
||||
return sockets
|
||||
|
||||
def reaquire(self, *conn):
|
||||
pass
|
||||
|
||||
#
|
||||
class Simulator(ComponentHost):
|
||||
|
||||
def __init__(self, addresses):
|
||||
ComponentHost.__init__(self, addresses)
|
||||
self.subthreads = []
|
||||
self.running = False
|
||||
|
||||
def launch_controller(self):
|
||||
thread = threading.Thread(target=self.controller.run)
|
||||
thread.start()
|
||||
|
||||
self.subthreads.append(thread)
|
||||
return thread
|
||||
|
||||
def simulate(self):
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
|
||||
self.subthreads.append(thread)
|
||||
self.running = True
|
||||
|
||||
return thread
|
||||
|
||||
def did_clean_shutdown(self):
|
||||
return not any([t.isAlive() for t in self.subthreads])
|
||||
|
||||
def shutdown(self):
|
||||
"""
|
||||
Destroy all tracked components.
|
||||
"""
|
||||
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
self.controller.shutdown(context=self.context)
|
||||
except:
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
for component in self.components.itervalues():
|
||||
component.shutdown()
|
||||
|
||||
for thread in self.subthreads:
|
||||
if thread.is_alive():
|
||||
thread._Thread__stop()
|
||||
|
||||
self.running = False
|
||||
|
||||
assert self.did_clean_shutdown()
|
||||
|
||||
def launch_component(self, component):
|
||||
thread = threading.Thread(target=component.run)
|
||||
thread.start()
|
||||
|
||||
self.subthreads.append(thread)
|
||||
return thread
|
||||
@@ -1,6 +1,3 @@
|
||||
#import logging
|
||||
import ujson as json
|
||||
|
||||
import zipline.util as qutil
|
||||
import zipline.messaging as qmsg
|
||||
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE
|
||||
@@ -45,10 +42,10 @@ class TestClient(qmsg.Component):
|
||||
self.received_count += 1
|
||||
|
||||
try:
|
||||
event = json.loads(msg)
|
||||
event = self.unframe(msg)
|
||||
|
||||
# JSON deserialization error
|
||||
except ValueError as exc:
|
||||
# deserialization error
|
||||
except zp.InvalidFrame as exc:
|
||||
return self.signal_exception(exc)
|
||||
|
||||
if self.prev_dt != None:
|
||||
@@ -59,10 +56,14 @@ class TestClient(qmsg.Component):
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.prev_dt = event['dt']
|
||||
self.prev_dt = event.dt
|
||||
|
||||
if self.received_count % 100 == 0:
|
||||
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
|
||||
|
||||
def unframe(self, msg):
|
||||
return zp.MERGE_UNFRARME(msg)
|
||||
|
||||
|
||||
class TestTradingClient(TradeSimulationClient):
|
||||
|
||||
|
||||
@@ -157,13 +157,8 @@ class FinanceTestCase(TestCase):
|
||||
# ----------
|
||||
sim.simulate().join()
|
||||
|
||||
# Stop Running
|
||||
# ------------
|
||||
|
||||
# TODO: less abrupt later, just shove a StopIteration
|
||||
# down the pipe to make it stop spinning
|
||||
sim.cuc._Thread__stop()
|
||||
|
||||
|
||||
# TODO: Make more assertions about the final state of the components.
|
||||
self.assertEqual(sim.feed.pending_messages(), 0,
|
||||
"The feed should be drained of all messages, found {n} remaining."
|
||||
.format(n=sim.feed.pending_messages())
|
||||
|
||||
Reference in New Issue
Block a user