mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 09:43:37 +08:00
FSM for Feed
This commit is contained in:
@@ -3,13 +3,27 @@ from collections import Counter
|
||||
|
||||
from zipline.core.component import Component
|
||||
from zipline.components.aggregator import Aggregate
|
||||
from zipline.utils.protocol_utils import Enum
|
||||
import zipline.protocol as zp
|
||||
from zipline.transitions import WorkflowMeta
|
||||
|
||||
from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \
|
||||
CONTROL_FRAME, CONTROL_UNFRAME
|
||||
|
||||
LOGGER = logging.getLogger('ZiplineLogger')
|
||||
|
||||
# FSM
|
||||
# ---
|
||||
|
||||
INIT, READY, DRAINING = FEED_STATES = \
|
||||
Enum( 'INIT', 'READY', 'DRAINING')
|
||||
|
||||
state_transitions = dict(
|
||||
do_start = (-1 , INIT) ,
|
||||
do_run = (INIT , READY) ,
|
||||
do_drain = (READY , DRAINING) ,
|
||||
)
|
||||
|
||||
class Feed(Aggregate):
|
||||
"""
|
||||
Connects to N PULL sockets, publishing all messages received to a
|
||||
@@ -18,20 +32,26 @@ class Feed(Aggregate):
|
||||
one execution context (thread, process, etc) and run in another.
|
||||
"""
|
||||
|
||||
__metaclass__ = WorkflowMeta
|
||||
|
||||
states = list(FEED_STATES)
|
||||
transitions = state_transitions
|
||||
initial_state = -1
|
||||
|
||||
def init(self):
|
||||
self.sent_count = 0
|
||||
self.received_count = 0
|
||||
self.draining = False
|
||||
self.ds_finished_counter = 0
|
||||
|
||||
# Depending on the size of this, might want to use a data
|
||||
# structure with better asymptotics.
|
||||
self.data_buffer = {}
|
||||
|
||||
# source_id -> integer count
|
||||
self.sent_counters = Counter()
|
||||
self.recv_counters = Counter()
|
||||
|
||||
self.state = INIT
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return "FEED"
|
||||
@@ -71,6 +91,8 @@ class Feed(Aggregate):
|
||||
"""
|
||||
Get the next message in chronological order.
|
||||
"""
|
||||
|
||||
# is_full and draining defined in aggregator
|
||||
if not(self.is_full() or self.draining):
|
||||
return
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ class PassthroughTransform(BaseTransform):
|
||||
"""
|
||||
|
||||
def init(self):
|
||||
self.state = { 'name': 'PASSTHROUGH' }
|
||||
self.props = { 'name': 'PASSTHROUGH' }
|
||||
|
||||
#TODO, could save some cycles by skipping the _UNFRAME call
|
||||
# and just setting value to original msg string.
|
||||
|
||||
@@ -568,3 +568,23 @@ class Component(object):
|
||||
pid = os.getpid() ,
|
||||
pointer = hex(id(self)) ,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if not hasattr(self, '_state'):
|
||||
self._state = self.initial_state
|
||||
else:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, new):
|
||||
if not hasattr(self, '_state'):
|
||||
self._state = self.initial_state
|
||||
|
||||
old = self._state
|
||||
|
||||
if (old, new) in self.workflow:
|
||||
self._state = new
|
||||
else:
|
||||
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
|
||||
|
||||
@@ -7,20 +7,20 @@ from zipline.finance.movingaverage import EventWindow
|
||||
class VWAPTransform(BaseTransform):
|
||||
|
||||
def init(self, name, daycount=3):
|
||||
self.state = {}
|
||||
self.state['name'] = name
|
||||
self.props = {}
|
||||
self.props['name'] = name
|
||||
self.daycount = daycount
|
||||
self.by_sid = defaultdict(self.create_vwap)
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
return self.props['name']
|
||||
|
||||
def transform(self, event):
|
||||
cur = self.by_sid[event.sid]
|
||||
cur.update(event)
|
||||
self.state['value'] = cur.vwap
|
||||
return self.state
|
||||
self.props['value'] = cur.vwap
|
||||
return self.props
|
||||
|
||||
def create_vwap(self):
|
||||
return DailyVWAP(self.daycount)
|
||||
|
||||
@@ -25,7 +25,7 @@ class BaseTransform(Component):
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
return self.props['name']
|
||||
|
||||
@property
|
||||
def get_type(self):
|
||||
@@ -116,9 +116,9 @@ class BaseTransform(Component):
|
||||
|
||||
Transforms run in parallel and results are merged into a
|
||||
single map, so transform names must be unique. Best practice
|
||||
is to use the self.state object initialized from the transform
|
||||
is to use the self.props object initialized from the transform
|
||||
configuration, and only set the transformed value::
|
||||
|
||||
self.state['value'] = transformed_value
|
||||
self.props['value'] = transformed_value
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -39,24 +39,6 @@ class WorkflowMeta(type):
|
||||
"""
|
||||
Base metaclass component workflows.
|
||||
"""
|
||||
@property
|
||||
def state(self):
|
||||
if not hasattr(self, '_state'):
|
||||
self._state = self.initial_state
|
||||
else:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, new):
|
||||
if not hasattr(self, '_state'):
|
||||
self._state = self.initial_state
|
||||
|
||||
old = self._state
|
||||
|
||||
if (old, new) in self.workflow:
|
||||
self._state = new
|
||||
else:
|
||||
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
|
||||
|
||||
def __new__(cls, name, mro, attrs):
|
||||
base = 'Component'
|
||||
@@ -71,7 +53,6 @@ class WorkflowMeta(type):
|
||||
raise RuntimeError('`workflow` is a reserved attribute.')
|
||||
|
||||
if not state:
|
||||
import pdb; pdb.set_trace()
|
||||
raise RuntimeError('Must specify states')
|
||||
|
||||
if not transitions:
|
||||
|
||||
Reference in New Issue
Block a user