Merge branch 'master', remote-tracking branch 'origin' into refactor

This commit is contained in:
Stephen Diehl
2012-05-09 09:15:07 -04:00
12 changed files with 495 additions and 184 deletions
+5 -3
View File
@@ -10,6 +10,7 @@ import uuid
import time
import socket
import gevent
import traceback
import humanhash
# pyzmq
@@ -305,11 +306,11 @@ class Component(object):
self._exception = exc
exc_type, exc_value, exc_traceback = sys.exc_info()
self.stack_trace = exc_traceback
trace = '\n>>>'.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
exception_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.EXCEPTION,
str(exc)
trace
)
self.control_out.send(exception_frame)
@@ -492,7 +493,8 @@ class Component(object):
"""
The descriptive name of the component.
"""
return 'UNKNOWN COMPONENT'
# Prevents the bug that Thomas ran into
raise NotImplementedError
@property
def get_type(self):
+15 -13
View File
@@ -224,6 +224,10 @@ class PerformanceTracker():
self.order_log.append(order)
def process_event(self, event):
if self.exceeded_max_loss:
return
assert isinstance(event, zp.namedict)
self.event_count += 1
@@ -265,7 +269,13 @@ class PerformanceTracker():
self.day_count += 1.0
# calculate progress of test
self.progress = self.day_count / self.total_days
# Output results
if self.result_stream:
msg = zp.PERF_FRAME(self.to_dict())
self.result_stream.send(msg)
#
if self.trading_environment.max_drawdown:
returns = self.todays_performance.returns
max_dd = -1 * self.trading_environment.max_drawdown
@@ -276,16 +286,8 @@ class PerformanceTracker():
# so it shows up in the update, but don't end the test
# here. Let the update go out before stopping
self.exceeded_max_loss = True
# Output results
if self.result_stream:
msg = zp.PERF_FRAME(self.to_dict())
self.result_stream.send(msg)
return
if self.exceeded_max_loss:
# now that we've sent the day's update, kill this test
self.handle_simulation_end(skip_close=True)
return
#move the market day markers forward
self.market_open = self.market_open + self.calendar_day
@@ -307,7 +309,7 @@ class PerformanceTracker():
keep_transactions = True
)
def handle_simulation_end(self, skip_close=False):
def handle_simulation_end(self):
"""
When the simulation is complete, run the full period risk report
and send it out on the result_stream.
@@ -319,8 +321,8 @@ class PerformanceTracker():
# the stream will end on the last trading day, but will not trigger
# an end of day, so we trigger the final market close here.
# In the case of errors, we needn't close again.
if not skip_close:
# In the case of max drawdown, we needn't close again.
if not self.exceeded_max_loss:
self.handle_market_close()
self.risk_report = risk.RiskReport(
+16 -13
View File
@@ -59,8 +59,10 @@ class TradeSimulationClient(qmsg.Component):
:py:mod:`zipline.test.algorithm`
"""
self.algorithm = algorithm
#register the trading_client's order method with the algorithm
# register the trading_client's order method with the algorithm
self.algorithm.set_order(self.order)
# ask the algorithm to initialize
self.algorithm.initialize()
def open(self):
self.result_feed = self.connect_result()
@@ -80,14 +82,7 @@ class TradeSimulationClient(qmsg.Component):
# if the feed is done, shut 'er down
if msg == str(zp.CONTROL_PROTOCOL.DONE):
qutil.LOGGER.info("Client is DONE!")
# signal the performance tracker that the simulation has
# ended. Perf will internally calculate the full risk report.
self.perf.handle_simulation_end()
# signal Simulator, our ComponentHost, that this component is
# done and Simulator needn't block exit on this component.
self.signal_done()
self.finish_simulation()
return
# result_feed is a merge component, so unframe accordingly
@@ -95,14 +90,22 @@ class TradeSimulationClient(qmsg.Component):
self.received_count += 1
# update performance and relay the event to the algorithm
self.process_event(event)
if self.perf.exceeded_max_loss:
self.finish_simulation()
def finish_simulation(self):
qutil.LOGGER.info("Client is DONE!")
# signal the performance tracker that the simulation has
# ended. Perf will internally calculate the full risk report.
self.perf.handle_simulation_end()
# signal Simulator, our ComponentHost, that this component is
# done and Simulator needn't block exit on this component.
self.signal_done()
def process_event(self, event):
if self.perf.exceeded_max_loss:
self.control_out.send(str(zp.CONTROL_PROTOCOL.SHUTDOWN))
return
# generate transactions, if applicable
txn = self.txn_sim.apply_trade_to_open_orders(event)
if txn:
+67
View File
@@ -0,0 +1,67 @@
from datetime import timedelta
from itertools import ifilter
from collections import defaultdict
from zipline.messaging import BaseTransform
class VWAPTransform(BaseTransform):
def init(self, daycount=3):
self.daycount = daycount
self.by_sid = defaultdict(DailyVWAP)
def transform(self, event):
cur = self.by_sid(event.sid)
cur.update(event)
self.state['value'] = cur.vwap
return self.state
class DailyVWAP:
"""A class that tracks the volume weighted average price
based on tick updates."""
def __init__(self, daycount=3):
self.ticks = []
self.dropped_ticks = []
self.flux = 0.0
self.volume = 0
self.lastTick = None
self.vwap = 0.0
self.delta = timedelta(days=daycount)
def update(self, event):
self.ticks.append(event)
flux, volume = self.calculate_flux([event])
self.flux += flux
self.volume += volume
self.last_date = event['dt']
self.first_date = self.last_date - self.delta
#use a list comprehension to filter the ticks to those within
#desired day range. The dt properties are full datetime objects
#and provide overloads for arithmetic operations.
self.dropped_ticks = []
for tick in self.ticks:
if tick['dt'] < self.first_date:
self.dropped_ticks.append(tick)
slice_index = len(self.dropped_ticks)
self.ticks = self.ticks[slice_index:]
dropped_flux, dropped_volume = self.calculate_flux(self.dropped_ticks)
self.flux -= dropped_flux
self.volume -= dropped_volume
if(self.volume != 0):
self.vwap = self.flux / self.volume
else:
self.vwap = None
def calculate_flux(self, ticks):
flux = 0.0
volume = 0
for tick in ticks:
flux += tick['volume'] * tick['price']
volume += tick['volume']
return flux, volume
+43 -52
View File
@@ -4,9 +4,7 @@ messaging. All ziplines follow a general topology of parallel sources,
datetimestamp serialization, parallel transformations, and finally sinks.
Furthermore, many ziplines have common needs. For example, all trade
simulations require a
:py:class:`~zipline.finance.trading.TradeSimulationClient`, an
:py:class:`~zipline.finance.trading.OrderSource`, and a
:py:class:`~zipline.finance.trading.TransactionSimulator` (a transform).
:py:class:`~zipline.finance.trading.TradeSimulationClient`.
To establish best practices and minimize code replication, the lines module
provides complete zipline topologies. You can extend any zipline without
@@ -17,56 +15,49 @@ before invoking simulate.
Here is a diagram of the SimulatedTrading zipline:
+----------------------+ +------------------------+
+-->| Orders DataSource | | (DataSource added |
| | Integrates algo | | via add_source) |
| | orders into history | | |
| +--------------------+-+ +-+----------------------+
| | |
| | |
| v v
| +---------+
| | Feed |
| +-+------++
| | |
| | |
| v v
| +----------------------+ +----------------------+
| | Transaction | | |
| | Transform simulates | | (Transforms added |
| | trades based on | | via add_transform) |
| | orders from algo. | | |
| +-------------------+--+ +-+--------------------+
| | |
| | |
| v v
| +------------+
| | Merge |
| +------+-----+
| |
| |
| V
| +--------------------------------+
| | |
| | TradingSimulationClient |
| orders | tracks performance and |
+---------------+ provides API to algorithm. |
| |
+---------------------+----------+
^ |
| orders | frames
| |
| v
+---------+-----------------------+
| |
| Algorithm added via |
| __init__. |
| |
| |
| |
+---------------------------------+
+----------------------+ +------------------------+
| Trade History | | (DataSource added |
| | | via add_source) |
| | | |
+--------------------+-+ +-+----------------------+
| |
| |
v v
+---------+
| Feed | (ensures events are serialized
+-+------++ in chronological order)
| |
| |
v v
+----------------------+ +----------------------+
| (Transforms added | | (Transforms added |
| via add_transform) | | via add_transform) |
+-------------------+--+ +-+--------------------+
| |
| |
v v
+------------+
| Merge | (combines original event and
+------+-----+ transforms into one vector)
|
|
V
+---------------+ +--------------------------------+
| Risk and Perf | | |
| Tracker | | TradingSimulationClient |
+---------------+ | tracks performance and |
^ Trades and | provides API to algorithm. |
| simulated | |
| transactions +--+------------------+----------+
| | ^ |
+---------------------+ | orders | frames
| |
| v
+---------------------------------+
| Algorithm added via |
| __init__. |
+---------------------------------+
"""
import mock
+7 -7
View File
@@ -227,7 +227,7 @@ class Feed(Component):
# -- Soft Kill --
elif event == CONTROL_PROTOCOL.SHUTDOWN:
self.done()
self.signal_done()
self.shutdown()
# -- Hard Kill --
@@ -440,14 +440,14 @@ class BaseTransform(Component):
method to create a new derived value from the combined feed.
"""
def __init__(self, name):
def __init__(self, name, **kwargs):
Component.__init__(self)
self.state = {
'name': name
}
self.init()
self.init(**kwargs)
def init(self):
pass
@@ -496,7 +496,7 @@ class BaseTransform(Component):
# -- Soft Kill --
elif event == CONTROL_PROTOCOL.SHUTDOWN:
self.done()
self.signal_done()
self.shutdown()
# -- Hard Kill --
@@ -564,11 +564,11 @@ class PassthroughTransform(BaseTransform):
"""
def __init__(self):
def __init__(self, **kwargs):
BaseTransform.__init__(self, "PASSTHROUGH")
self.init()
self.init(**kwargs)
def init(self):
def init(self, **kwargs):
pass
@property
+164 -48
View File
@@ -3,13 +3,14 @@ import gevent
import itertools
# pyzmq
import zmq
# gevent_zeromq
import gevent_zeromq
# zmq_ctypes
#import zmq_ctypes
from collections import OrderedDict
from protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \
states = CONTROL_STATES
from gpoll import _Poller as GeventPoller
@@ -103,6 +104,17 @@ from gpoll import _Poller as GeventPoller
# | 0 | | 0 | | 0 | | 0 |
# +---+ +---+ +---+ +---+
INIT, SOURCES_READY, RUNNING, TERMINATE = CONTROL_STATES
state_transitions = frozenset([
(-1 , INIT),
(INIT , SOURCES_READY),
(SOURCES_READY , RUNNING),
(INIT , TERMINATE),
(SOURCES_READY , TERMINATE),
(RUNNING , TERMINATE),
])
class UnknownChatter(Exception):
def __init__(self, name):
self.named = name
@@ -149,27 +161,26 @@ class Controller(object):
def __init__(self, pub_socket, route_socket, logging = None):
self.context = None
self.zmq = None
self.context = None
self.zmq = None
self.zmq_poller = None
polling = False
self.polling = polling
self.running = False
self.polling = False
self.tracked = set()
self.responses = set()
self.ctime = 0
self.tic = time.time()
self.ctime = 0
self.tic = time.time()
self.freeform = False
self._state = -1
self.associated = []
self.pub_socket = pub_socket
self.pub_socket = pub_socket
self.route_socket = route_socket
self.error_replay = {}
self.error_replay = OrderedDict()
if logging:
self.logging = logging
@@ -182,23 +193,23 @@ class Controller(object):
assert self.zmq_flavor in ['thread', 'mp', 'green']
if flavor == 'mp':
self.zmq = zmq
self.context = self.zmq.Context()
self.zmq = zmq
self.context = self.zmq.Context()
self.zmq_poller = self.zmq.Poller
return
if flavor == 'thread':
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = self.zmq.Poller
return
if flavor == 'green':
self.zmq = gevent_zeromq.zmq
self.context = self.zmq.Context.instance()
self.zmq = gevent_zeromq.zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = GeventPoller
return
if flavor == 'pypy':
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = self.zmq.Poller
return
@@ -217,19 +228,24 @@ class Controller(object):
self.freeform = False
self.topology = frozenset(topology)
default_states = [
CONTROL_STATES.RUNNING,
CONTROL_STATES.SHUTDOWN,
CONTROL_STATES.TERMINATE,
]
self.states = states or default_states
self.polling = True
self.state = CONTROL_STATES.INIT
# Start off in RUNNING, state
self.state = self.states[0]
@property
def state(self):
return self._state
@state.setter
def state(self, new):
old, self._state = self._state, new
if (old, new) not in state_transitions:
raise RuntimeError("[Controller] Invalid State Transition : %s -> %s" %(old, new))
else:
self.logging.info("[Controller] State Transition : %s -> %s" %(old, new))
def run(self):
self.running = True
self.init_zmq(self.zmq_flavor)
try:
@@ -256,6 +272,9 @@ class Controller(object):
# -------------
def send_heart(self):
if not self.running:
return
heartbeat_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.HEARTBEAT,
str(self.ctime)
@@ -263,6 +282,9 @@ class Controller(object):
self.pub.send(heartbeat_frame)
def send_hardkill(self):
if not self.running:
return
kill_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.KILL,
''
@@ -270,6 +292,9 @@ class Controller(object):
self.pub.send(kill_frame)
def send_softkill(self):
if not self.running:
return
soft_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.SHUTDOWN,
''
@@ -282,16 +307,35 @@ class Controller(object):
def _poll(self):
assert self.route_socket
assert self.pub_socket
assert self.cancel_socket
# -- Publish --
# =============
self.pub = self.context.socket(self.zmq.PUB)
self.pub.bind(self.pub_socket)
# -- Cancel --
# =============
assert isinstance(self.cancel_socket,basestring), self.cancel_socket
self.cancel = self.context.socket(self.zmq.REP)
self.cancel.connect(self.cancel_socket)
# -- Router --
# =============
self.router = self.context.socket(self.zmq.ROUTER)
self.router.bind(self.route_socket)
self.associated.extend([self.pub, self.router])
poller = self.zmq.Poller()
poller.register(self.router, self.zmq.POLLIN)
poller.register(self.cancel, self.zmq.POLLIN)
self.associated += [self.pub, self.router, self.cancel]
# TODO: actually do this
self.state = CONTROL_STATES.SOURCES_READY
buffer = []
@@ -325,11 +369,20 @@ class Controller(object):
self.logging.error('Invalid frame', rawmessage)
pass
if self.cancel in socks and socks[self.cancel] == self.zmq.POLLIN:
self.logging.info('[Controller] Received Cancellation')
rawmessage = self.cancel.recv()
self.shutdown(soft=True)
break
self.beat()
if self.zmq_flavor == 'green':
gevent.sleep(0)
if self.state is CONTROL_STATES.TERMINATE:
break
if not self.polling:
break
@@ -359,35 +412,89 @@ class Controller(object):
for component in bad:
self.fail(component)
# ------------------
# Component Handlers
# ------------------
# --------------
# Init Handlers
# --------------
def new_source(self):
if self.state is CONTROL_STATES.RUNNING:
self.state = SOURCES_READY
def new_universal(self):
pass
# The various "states of being that a component can inform us
# of
def new(self, component):
self.logging.info('[Controller] Alive "%s" ' % component)
if self.state is CONTROL_STATES.TERMINATE:
return
self.logging.info('[Controller] Now Tracking "%s" ' % component)
universal = self.new_universal
init_handlers = {
'FEED' : self.new_source,
}
if component in self.topology or self.freeform:
init_handlers.get(component, universal)()
self.tracked.add(component)
else:
# Some sort of socket collision has occured, this is
# a very bad failure mode.
raise UnknownChatter(component)
# ------------------
# Epic Fail Handling
# ------------------
def fail_universal(self):
pass
# TODO: this requires higher order functionality
#self.logging.error('[Controller] System in exception state, shutting down')
#self.shutdown(soft=True)
def fail(self, component):
self.logging.info('[Controller] Component "%s" timed out' % component)
self.tracked.remove(component)
if self.state is CONTROL_STATES.TERMINATE:
return
universal = self.fail_universal
fail_handlers = { }
if component in self.topology or self.freeform:
self.logging.info('[Controller] Component "%s" timed out' % component)
self.tracked.remove(component)
fail_handlers.get(component, universal)()
# -------------------
# Completion Handling
# -------------------
def done(self, component):
# TODO: This will be what we ship off to vbench at some
# point...
# print component finished at self.ctime
self.logging.info('[Controller] Component "%s" done.' % component)
# --------------
# Error Handling
# --------------
def exception_universal(self):
"""
Shutdown the system on failure.
"""
self.logging.error('[Controller] System in exception state, shutting down')
self.shutdown(soft=True)
def exception(self, component, failure):
self.error_replay[time.time()] = failure
self.logging.error('Component "%s" in exception state' % component)
universal = self.exception_universal
exception_handlers = { }
if component in self.topology or self.freeform:
self.error_replay[(component, time.time())] = failure
self.logging.error('[Controller] Component "%s" in exception state' % component)
exception_handlers.get(component, universal)()
else:
raise UnknownChatter(component)
# -----------------
# Protocol Handling
@@ -462,6 +569,11 @@ class Controller(object):
self.associated.append(s)
return s
def do_error_replay(self):
for (component, time), error in self.error_replay.iteritems():
self.logging.info('[Controller] Error Log for -- %s --:\n%s' %
(component, error))
def shutdown(self, hard=False, soft=True, context=None):
if not self.polling:
@@ -472,7 +584,7 @@ class Controller(object):
assert hard or soft, """ Must specify kill hard or soft """
if hard:
self.state = CONTROL_STATES.SHUTDOWN
self.state = CONTROL_STATES.TERMINATE
self.logging.info('[Controller] Hard Shutdown')
@@ -488,18 +600,22 @@ class Controller(object):
#for asoc in self.associated:
#asoc.close()
self.do_error_replay()
if __name__ == '__main__':
print 'Running on ',\
'tcp://127.0.0.1:5000', \
'tcp://127.0.0.1:5001',
print 'Running on '\
'tcp://127.0.0.1:5000 '\
'tcp://127.0.0.1:5001 '
controller = Controller(
'tcp://127.0.0.1:5000',
'tcp://127.0.0.1:5001',
)
controller.zmq_flavor = 'green'
controller.manage(
'freeform',
[]
)
controller.run('green')
controller.run()
+4 -17
View File
@@ -126,9 +126,6 @@ from collections import namedtuple
from protocol_utils import Enum, FrameExceptionFactory, namedict
from date_utils import EPOCH, UN_EPOCH
#import ujson
#import ultrajson_numpy
# -----------------------
# Control Protocol
# -----------------------
@@ -136,9 +133,10 @@ from date_utils import EPOCH, UN_EPOCH
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
CONTROL_STATES = Enum(
'INIT',
'SOURCES_READY',
'RUNNING',
'SHUTDOWN', # a soft kill
'TERMINATE', # a hard kill
'TERMINATE',
)
CONTROL_PROTOCOL = Enum(
@@ -149,6 +147,7 @@ CONTROL_PROTOCOL = Enum(
'OK' , # 3 - rep
'DONE' , # 4 - rep
'EXCEPTION' , # 5 - rep
'SIGNAL' , # 6 - rep
)
def CONTROL_FRAME(event, payload):
@@ -174,18 +173,6 @@ def CONTROL_UNFRAME(msg):
except ValueError:
raise INVALID_CONTROL_FRAME(msg)
# -----------------------
# Heartbeat Protocol
# -----------------------
# These encode the msgpack equivelant of 1 and 2. The heartbeat
# frame should only be 1 byte on the wire.
HEARTBEAT_PROTOCOL = namedict({
'REQ' : b'\x01',
'REP' : b'\x02',
})
# -----------------------
# Component State
# -----------------------
+1
View File
@@ -12,6 +12,7 @@ def Enum(*options):
"""
class cstruct(Structure):
_fields_ = [(o, c_ubyte) for o in options]
__iter__ = lambda s: iter(range(len(options)))
return cstruct(*range(len(options)))
def FrameExceptionFactory(name):
+14 -2
View File
@@ -9,6 +9,9 @@ are provided below.
The algorithm must expose methods:
- initialize: method that takes no args, no returns. Simply called to
enable the algorithm to set any internal state needed.
- get_sid_filter: method that takes no args, and returns a list
of valid sids. List must have a length between 1 and 10. If None is returned
the filter will block all events.
@@ -18,7 +21,7 @@ The algorithm must expose methods:
+-----------------+--------------+----------------+--------------------+
| | SID(133) | SID(134) | SID(135) |
+=================+==============+=====================================+
+=================+==============+================+====================+
| price | $10.10 | $22.50 | $13.37 |
+-----------------+--------------+----------------+--------------------+
| volume | 10,000 | 5,000 | 50,000 |
@@ -61,6 +64,9 @@ class TestAlgorithm():
self.order = None
self.frame_count = 0
self.portfolio = None
def initialize(self):
pass
def set_order(self, order_callable):
self.order = order_callable
@@ -94,7 +100,10 @@ class HeavyBuyAlgorithm():
self.order = None
self.frame_count = 0
self.portfolio = None
def initialize(self):
pass
def set_order(self, order_callable):
self.order = order_callable
@@ -114,6 +123,9 @@ class NoopAlgorithm(object):
"""
Dolce fa niente.
"""
def initialize(self):
pass
def set_order(self, order_callable):
pass
+72 -29
View File
@@ -368,33 +368,54 @@ class FinanceTestCase(TestCase):
# same scenario, but short sales.
params2 = {
'trade_count':100,
'trade_amount':100,
'trade_delay': timedelta(minutes=5),
'trade_interval': timedelta(days=1),
'order_count':3,
'order_amount':1000,
'order_interval': timedelta(minutes=30),
'trade_count' : 100,
'trade_amount' : 100,
'trade_delay' : timedelta(minutes=5),
'trade_interval' : timedelta(days=1),
'order_count' : 3,
'order_amount' :-1000,
'order_interval' : timedelta(minutes=30),
# because we placed an orders totaling less than 25% of one trade
# the simulator should produce just one transaction.
'expected_txn_count' : 1,
'expected_txn_volume' : 25
'expected_txn_count' : 1,
'expected_txn_volume' : -25
}
self.transaction_sim(**params2)
@timed(DEFAULT_TIMEOUT)
def test_alternating_long_short(self):
# create a scenario where we alternate buys and sells
params1 = {
'trade_count' : int(6.5 * 60 * 4),
'trade_amount' : 100,
'trade_interval' : timedelta(minutes=1),
'order_count' : 4,
'order_amount' : 10,
'order_interval' : timedelta(hours=24),
'alternate' : True,
'complete_fill' : True,
'expected_txn_count' : 4,
'expected_txn_volume' : 0 #equal buys and sells
}
self.transaction_sim(**params1)
def transaction_sim(self, **params):
trade_count = params['trade_count']
trade_amount = params['trade_amount']
trade_interval = params['trade_interval']
trade_delay = params.get('trade_delay')
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
trade_count = params['trade_count']
trade_amount = params['trade_amount']
trade_interval = params['trade_interval']
trade_delay = params.get('trade_delay')
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
expected_txn_volume = params['expected_txn_volume']
# optional parameters
# ---------------------
# if present, alternate between long and short sales
alternate = params.get('alternate')
# if present, expect transaction amounts to match orders exactly.
complete_fill = params.get('complete_fill')
trading_environment = factory.create_trading_environment()
trade_sim = TransactionSimulator()
@@ -411,17 +432,31 @@ class FinanceTestCase(TestCase):
trading_environment
)
for i in range(order_count):
if alternate:
alternator = -1
else:
alternator = 1
order_date = start_date
for i in xrange(order_count):
order = namedict(
{
'sid':sid,
'amount':order_amount,
'type':zp.DATASOURCE_TYPE.ORDER,
'dt' : start_date + i * order_interval
'sid' : sid,
'amount' : order_amount * alternator**i,
'type' : zp.DATASOURCE_TYPE.ORDER,
'dt' : order_date
})
trade_sim.add_open_order(order)
order_date = order_date + order_interval
# move after market orders to just after market next
# market open.
if order_date.hour >= 21:
if order_date.minute >= 00:
order_date = order_date + timedelta(days=1)
order_date = order_date.replace(hour=14, minute=30)
# there should now be one open order list stored under the sid
oo = trade_sim.open_orders
self.assertEqual(len(oo), 1)
@@ -429,9 +464,10 @@ class FinanceTestCase(TestCase):
order_list = oo[sid]
self.assertEqual(order_count, len(order_list))
for order in order_list:
for i in xrange(order_count):
order = order_list[i]
self.assertEqual(order.sid, sid)
self.assertEqual(order.amount, order_amount)
self.assertEqual(order.amount, order_amount * alternator**i)
tracker = PerformanceTracker(trading_environment)
@@ -450,10 +486,17 @@ class FinanceTestCase(TestCase):
trade.TRANSACTION = None
tracker.process_event(trade)
if complete_fill:
self.assertEqual(len(transactions), len(order_list))
total_volume = 0
for txn in transactions:
for i in xrange(len(transactions)):
txn = transactions[i]
total_volume += txn.amount
if complete_fill:
order = order_list[i]
self.assertEqual(order.amount, txn.amount)
self.assertEqual(total_volume, expected_txn_volume)
self.assertEqual(len(transactions), expected_txn_count)
+87
View File
@@ -0,0 +1,87 @@
import types
from collections import Container, Hashable, Callable
class Any(object): pass
class Workflow(Container, Callable):
def __init__(self, states, transitions, initial_state):
self.simple = set()
self.complx = []
if isinstance(states[0], tuple):
self.groups = {b for _,b in states}
else:
self.groups = set()
matcher = lambda b: lambda f,t : t == b
for (a, b) in transitions.itervalues():
if a is Any:
self.complx.append(matcher(b))
if isinstance(a, Hashable) and isinstance(b, Hashable):
self.simple.add((a,b))
def __call__(self, **kwargs):
if 'group' in kwargs:
return self.groups
def __contains__(self, state):
if state in self.simple:
return True
for match in self.complx:
if match(*state):
return True
else:
return False
class Flowable:
@property
def state(self):
if not hasattr(self, '_state'):
self._state = self.initial_state
else:
return self._state
@state.setter
def state(self, new):
if not hasattr(self, '_state'):
self._state = self.initial_state
old = self._state
if (old, new) in self.workflow:
self._state = new
else:
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
class WorkflowMeta(type):
"""
Base metaclass component workflows.
"""
def __new__(cls, name, mro, attrs):
state = attrs.get('states', None)
transitions = attrs.get('transitions', None)
initial_state = attrs.get('initial_state', None)
if attrs.get('workflow'):
raise RuntimeError('`workflow` is a reserved attribute.')
if not state:
raise RuntimeError('Must specify states')
if not transitions:
raise RuntimeError('Must specify transitions')
if not transitions:
raise RuntimeError('Must specify initial_state')
new_class = super(WorkflowMeta, cls).__new__(
cls, name, mro+(Flowable,), attrs
)
new_class.workflow = Workflow(state, transitions, initial_state)
return new_class