Merge branch 'master' of github.com:quantopian/zipline

This commit is contained in:
Thomas Wiecki
2012-05-14 16:36:02 -04:00
14 changed files with 676 additions and 167 deletions
+5 -3
View File
@@ -10,6 +10,7 @@ import uuid
import time
import socket
import gevent
import traceback
import humanhash
# pyzmq
@@ -305,11 +306,11 @@ class Component(object):
self._exception = exc
exc_type, exc_value, exc_traceback = sys.exc_info()
self.stack_trace = exc_traceback
trace = '\n>>>'.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
exception_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.EXCEPTION,
str(exc)
trace
)
self.control_out.send(exception_frame)
@@ -492,7 +493,8 @@ class Component(object):
"""
The descriptive name of the component.
"""
return 'UNKNOWN COMPONENT'
# Prevents the bug that Thomas ran into
raise NotImplementedError
@property
def get_type(self):
+67
View File
@@ -0,0 +1,67 @@
from datetime import timedelta
from collections import defaultdict
from zipline.messaging import BaseTransform
class MovingAverageTransform(BaseTransform):
def init(self, daycount=3):
self.daycount = daycount
self.by_sid = defaultdict(self._create)
def transform(self, event):
cur = self.by_sid[event.sid]
cur.update(event)
self.state['value'] = cur.average
return self.state
def _create(self):
return MovingAverage(self.daycount)
class MovingAverage(object):
def __init__(self, daycount):
self.window = EventWindow(daycount)
self.total = 0.0
self.average = 0.0
def update(self, event):
self.window.update(event)
self.total += event.price
for dropped in self.window.dropped_ticks:
self.total -= dropped.price
if len(self.window.ticks) > 0:
self.average = self.total / len(self.window.ticks)
else:
self.average = 0.0
class EventWindow(object):
"""
Tracks a window of the event history. Use an instance to track the events
inside your window to efficiently calculate rolling statistics.
"""
def __init__(self, daycount):
self.ticks = []
self.dropped_ticks = []
self.delta = timedelta(days=daycount)
def update(self, event):
# add new event
self.ticks.append(event)
# determine which events are expired
last_date = event['dt']
first_date = last_date - self.delta
self.dropped_ticks = []
for tick in self.ticks:
if tick['dt'] <= first_date:
self.dropped_ticks.append(tick)
# remove the expired events
slice_index = len(self.dropped_ticks)
self.ticks = self.ticks[slice_index:]
+46
View File
@@ -0,0 +1,46 @@
import pandas
from datetime import timedelta
from collections import defaultdict
from zipline.messaging import BaseTransform
class ReturnsTransform(BaseTransform):
def init(self):
self.by_sid = defaultdict(self._create)
def transform(self, event):
cur = self.by_sid[event.sid]
cur.update(event)
self.state['value'] = cur.returns
return self.state
def _create(self):
return ReturnsFromPriorClose()
class ReturnsFromPriorClose(object):
"""
Calculates a security's returns since the previous close, using the
current price.
"""
def __init__(self):
self.last_close = None
self.last_event = None
self.returns = 0.0
def update(self, event):
next_close = None
if self.last_close:
change = event.price - self.last_close.price
self.returns = change / self.last_close.price
if self.last_event:
if self.last_event.dt.day != event.dt.day:
# the current event is from the day after
# the last event. Therefore the last event was
# the last close
self.last_close = self.last_event
# the current event is now the last_event
self.last_event = event
+6 -9
View File
@@ -1,7 +1,6 @@
import datetime
import pytz
import math
import pandas
import time
from collections import Counter
@@ -14,7 +13,7 @@ import zipline.util as qutil
import zipline.protocol as zp
import zipline.finance.performance as perf
from zipline.protocol_utils import Enum, namedict
from zipline.protocol_utils import Enum, ndict
# the simulation style enumerates the available transaction simulation
# strategies.
@@ -43,10 +42,7 @@ class TradeSimulationClient(qmsg.Component):
self.txn_sim = TransactionSimulator(sim_style)
assert self.trading_environment.frame_index != None
self.event_frame = pandas.DataFrame(
index=self.trading_environment.frame_index
)
self.event_frame = ndict()
self.perf = perf.PerformanceTracker(self.trading_environment)
@property
@@ -59,8 +55,10 @@ class TradeSimulationClient(qmsg.Component):
:py:mod:`zipline.test.algorithm`
"""
self.algorithm = algorithm
#register the trading_client's order method with the algorithm
# register the trading_client's order method with the algorithm
self.algorithm.set_order(self.order)
# ask the algorithm to initialize
self.algorithm.initialize()
def open(self):
self.result_feed = self.connect_result()
@@ -176,8 +174,7 @@ class TradeSimulationClient(qmsg.Component):
def queue_event(self, event):
if self.event_queue == None:
self.event_queue = []
series = event.as_series()
self.event_queue.append(series)
self.event_queue.append(event)
def get_frame(self):
for event in self.event_queue:
+61
View File
@@ -0,0 +1,61 @@
import pandas
from datetime import timedelta
from collections import defaultdict
from zipline.messaging import BaseTransform
from zipline.finance.movingaverage import EventWindow
class VWAPTransform(BaseTransform):
def init(self, daycount=3):
self.daycount = daycount
self.by_sid = defaultdict(self.create_vwap)
def transform(self, event):
cur = self.by_sid[event.sid]
cur.update(event)
self.state['value'] = cur.vwap
return self.state
def create_vwap(self):
return DailyVWAP(self.daycount)
class DailyVWAP:
"""A class that tracks the volume weighted average price
based on tick updates."""
def __init__(self, daycount):
self.window = EventWindow(daycount)
self.flux = 0.0
self.volume = 0
self.vwap = 0.0
self.delta = timedelta(days=daycount)
def update(self, event):
# update the event window
self.window.update(event)
# add the current event's flux and volume to the tracker
flux, volume = self.calculate_flux([event])
self.flux += flux
self.volume += volume
# subract the expired events flux and volume from the tracker
dropped = self.window.dropped_ticks
dropped_flux, dropped_volume = self.calculate_flux(dropped)
self.flux -= dropped_flux
self.volume -= dropped_volume
if(self.volume != 0):
self.vwap = self.flux / self.volume
else:
self.vwap = None
def calculate_flux(self, ticks):
flux = 0.0
volume = 0
for tick in ticks:
flux += tick['volume'] * tick['price']
volume += tick['volume']
return flux, volume
+45 -52
View File
@@ -4,9 +4,7 @@ messaging. All ziplines follow a general topology of parallel sources,
datetimestamp serialization, parallel transformations, and finally sinks.
Furthermore, many ziplines have common needs. For example, all trade
simulations require a
:py:class:`~zipline.finance.trading.TradeSimulationClient`, an
:py:class:`~zipline.finance.trading.OrderSource`, and a
:py:class:`~zipline.finance.trading.TransactionSimulator` (a transform).
:py:class:`~zipline.finance.trading.TradeSimulationClient`.
To establish best practices and minimize code replication, the lines module
provides complete zipline topologies. You can extend any zipline without
@@ -17,56 +15,49 @@ before invoking simulate.
Here is a diagram of the SimulatedTrading zipline:
+----------------------+ +------------------------+
+-->| Orders DataSource | | (DataSource added |
| | Integrates algo | | via add_source) |
| | orders into history | | |
| +--------------------+-+ +-+----------------------+
| | |
| | |
| v v
| +---------+
| | Feed |
| +-+------++
| | |
| | |
| v v
| +----------------------+ +----------------------+
| | Transaction | | |
| | Transform simulates | | (Transforms added |
| | trades based on | | via add_transform) |
| | orders from algo. | | |
| +-------------------+--+ +-+--------------------+
| | |
| | |
| v v
| +------------+
| | Merge |
| +------+-----+
| |
| |
| V
| +--------------------------------+
| | |
| | TradingSimulationClient |
| orders | tracks performance and |
+---------------+ provides API to algorithm. |
| |
+---------------------+----------+
^ |
| orders | frames
| |
| v
+---------+-----------------------+
| |
| Algorithm added via |
| __init__. |
| |
| |
| |
+---------------------------------+
+----------------------+ +------------------------+
| Trade History | | (DataSource added |
| | | via add_source) |
| | | |
+--------------------+-+ +-+----------------------+
| |
| |
v v
+---------+
| Feed | (ensures events are serialized
+-+------++ in chronological order)
| |
| |
v v
+----------------------+ +----------------------+
| (Transforms added | | (Transforms added |
| via add_transform) | | via add_transform) |
+-------------------+--+ +-+--------------------+
| |
| |
v v
+------------+
| Merge | (combines original event and
+------+-----+ transforms into one vector)
|
|
V
+---------------+ +--------------------------------+
| Risk and Perf | | |
| Tracker | | TradingSimulationClient |
+---------------+ | tracks performance and |
^ Trades and | provides API to algorithm. |
| simulated | |
| transactions +--+------------------+----------+
| | ^ |
+---------------------+ | orders | frames
| |
| v
+---------------------------------+
| Algorithm added via |
| __init__. |
+---------------------------------+
"""
import mock
@@ -152,6 +143,8 @@ class SimulatedTrading(object):
sockets[7],
logging = qutil.LOGGER
)
self.con.cancel_socket = self.allocator.lease(1)[0]
# TODO: Not freeform
self.con.manage(
+7 -7
View File
@@ -227,7 +227,7 @@ class Feed(Component):
# -- Soft Kill --
elif event == CONTROL_PROTOCOL.SHUTDOWN:
self.done()
self.signal_done()
self.shutdown()
# -- Hard Kill --
@@ -440,14 +440,14 @@ class BaseTransform(Component):
method to create a new derived value from the combined feed.
"""
def __init__(self, name):
def __init__(self, name, **kwargs):
Component.__init__(self)
self.state = {
'name': name
}
self.init()
self.init(**kwargs)
def init(self):
pass
@@ -496,7 +496,7 @@ class BaseTransform(Component):
# -- Soft Kill --
elif event == CONTROL_PROTOCOL.SHUTDOWN:
self.done()
self.signal_done()
self.shutdown()
# -- Hard Kill --
@@ -564,11 +564,11 @@ class PassthroughTransform(BaseTransform):
"""
def __init__(self):
def __init__(self, **kwargs):
BaseTransform.__init__(self, "PASSTHROUGH")
self.init()
self.init(**kwargs)
def init(self):
def init(self, **kwargs):
pass
@property
+164 -48
View File
@@ -3,13 +3,14 @@ import gevent
import itertools
# pyzmq
import zmq
# gevent_zeromq
import gevent_zeromq
# zmq_ctypes
#import zmq_ctypes
from collections import OrderedDict
from protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \
states = CONTROL_STATES
from gpoll import _Poller as GeventPoller
@@ -103,6 +104,17 @@ from gpoll import _Poller as GeventPoller
# | 0 | | 0 | | 0 | | 0 |
# +---+ +---+ +---+ +---+
INIT, SOURCES_READY, RUNNING, TERMINATE = CONTROL_STATES
state_transitions = frozenset([
(-1 , INIT),
(INIT , SOURCES_READY),
(SOURCES_READY , RUNNING),
(INIT , TERMINATE),
(SOURCES_READY , TERMINATE),
(RUNNING , TERMINATE),
])
class UnknownChatter(Exception):
def __init__(self, name):
self.named = name
@@ -149,27 +161,26 @@ class Controller(object):
def __init__(self, pub_socket, route_socket, logging = None):
self.context = None
self.zmq = None
self.context = None
self.zmq = None
self.zmq_poller = None
polling = False
self.polling = polling
self.running = False
self.polling = False
self.tracked = set()
self.responses = set()
self.ctime = 0
self.tic = time.time()
self.ctime = 0
self.tic = time.time()
self.freeform = False
self._state = -1
self.associated = []
self.pub_socket = pub_socket
self.pub_socket = pub_socket
self.route_socket = route_socket
self.error_replay = {}
self.error_replay = OrderedDict()
if logging:
self.logging = logging
@@ -182,23 +193,23 @@ class Controller(object):
assert self.zmq_flavor in ['thread', 'mp', 'green']
if flavor == 'mp':
self.zmq = zmq
self.context = self.zmq.Context()
self.zmq = zmq
self.context = self.zmq.Context()
self.zmq_poller = self.zmq.Poller
return
if flavor == 'thread':
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = self.zmq.Poller
return
if flavor == 'green':
self.zmq = gevent_zeromq.zmq
self.context = self.zmq.Context.instance()
self.zmq = gevent_zeromq.zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = GeventPoller
return
if flavor == 'pypy':
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq = zmq
self.context = self.zmq.Context.instance()
self.zmq_poller = self.zmq.Poller
return
@@ -217,19 +228,24 @@ class Controller(object):
self.freeform = False
self.topology = frozenset(topology)
default_states = [
CONTROL_STATES.RUNNING,
CONTROL_STATES.SHUTDOWN,
CONTROL_STATES.TERMINATE,
]
self.states = states or default_states
self.polling = True
self.state = CONTROL_STATES.INIT
# Start off in RUNNING, state
self.state = self.states[0]
@property
def state(self):
return self._state
@state.setter
def state(self, new):
old, self._state = self._state, new
if (old, new) not in state_transitions:
raise RuntimeError("[Controller] Invalid State Transition : %s -> %s" %(old, new))
else:
self.logging.info("[Controller] State Transition : %s -> %s" %(old, new))
def run(self):
self.running = True
self.init_zmq(self.zmq_flavor)
try:
@@ -256,6 +272,9 @@ class Controller(object):
# -------------
def send_heart(self):
if not self.running:
return
heartbeat_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.HEARTBEAT,
str(self.ctime)
@@ -263,6 +282,9 @@ class Controller(object):
self.pub.send(heartbeat_frame)
def send_hardkill(self):
if not self.running:
return
kill_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.KILL,
''
@@ -270,6 +292,9 @@ class Controller(object):
self.pub.send(kill_frame)
def send_softkill(self):
if not self.running:
return
soft_frame = CONTROL_FRAME(
CONTROL_PROTOCOL.SHUTDOWN,
''
@@ -282,16 +307,35 @@ class Controller(object):
def _poll(self):
assert self.route_socket
assert self.pub_socket
assert self.cancel_socket
# -- Publish --
# =============
self.pub = self.context.socket(self.zmq.PUB)
self.pub.bind(self.pub_socket)
# -- Cancel --
# =============
assert isinstance(self.cancel_socket,basestring), self.cancel_socket
self.cancel = self.context.socket(self.zmq.REP)
self.cancel.connect(self.cancel_socket)
# -- Router --
# =============
self.router = self.context.socket(self.zmq.ROUTER)
self.router.bind(self.route_socket)
self.associated.extend([self.pub, self.router])
poller = self.zmq.Poller()
poller.register(self.router, self.zmq.POLLIN)
poller.register(self.cancel, self.zmq.POLLIN)
self.associated += [self.pub, self.router, self.cancel]
# TODO: actually do this
self.state = CONTROL_STATES.SOURCES_READY
buffer = []
@@ -325,11 +369,20 @@ class Controller(object):
self.logging.error('Invalid frame', rawmessage)
pass
if self.cancel in socks and socks[self.cancel] == self.zmq.POLLIN:
self.logging.info('[Controller] Received Cancellation')
rawmessage = self.cancel.recv()
self.shutdown(soft=True)
break
self.beat()
if self.zmq_flavor == 'green':
gevent.sleep(0)
if self.state is CONTROL_STATES.TERMINATE:
break
if not self.polling:
break
@@ -359,35 +412,89 @@ class Controller(object):
for component in bad:
self.fail(component)
# ------------------
# Component Handlers
# ------------------
# --------------
# Init Handlers
# --------------
def new_source(self):
if self.state is CONTROL_STATES.RUNNING:
self.state = SOURCES_READY
def new_universal(self):
pass
# The various "states of being that a component can inform us
# of
def new(self, component):
self.logging.info('[Controller] Alive "%s" ' % component)
if self.state is CONTROL_STATES.TERMINATE:
return
self.logging.info('[Controller] Now Tracking "%s" ' % component)
universal = self.new_universal
init_handlers = {
'FEED' : self.new_source,
}
if component in self.topology or self.freeform:
init_handlers.get(component, universal)()
self.tracked.add(component)
else:
# Some sort of socket collision has occured, this is
# a very bad failure mode.
raise UnknownChatter(component)
# ------------------
# Epic Fail Handling
# ------------------
def fail_universal(self):
pass
# TODO: this requires higher order functionality
#self.logging.error('[Controller] System in exception state, shutting down')
#self.shutdown(soft=True)
def fail(self, component):
self.logging.info('[Controller] Component "%s" timed out' % component)
self.tracked.remove(component)
if self.state is CONTROL_STATES.TERMINATE:
return
universal = self.fail_universal
fail_handlers = { }
if component in self.topology or self.freeform:
self.logging.info('[Controller] Component "%s" timed out' % component)
self.tracked.remove(component)
fail_handlers.get(component, universal)()
# -------------------
# Completion Handling
# -------------------
def done(self, component):
# TODO: This will be what we ship off to vbench at some
# point...
# print component finished at self.ctime
self.logging.info('[Controller] Component "%s" done.' % component)
# --------------
# Error Handling
# --------------
def exception_universal(self):
"""
Shutdown the system on failure.
"""
self.logging.error('[Controller] System in exception state, shutting down')
self.shutdown(soft=True)
def exception(self, component, failure):
self.error_replay[time.time()] = failure
self.logging.error('Component "%s" in exception state' % component)
universal = self.exception_universal
exception_handlers = { }
if component in self.topology or self.freeform:
self.error_replay[(component, time.time())] = failure
self.logging.error('[Controller] Component "%s" in exception state' % component)
exception_handlers.get(component, universal)()
else:
raise UnknownChatter(component)
# -----------------
# Protocol Handling
@@ -462,6 +569,11 @@ class Controller(object):
self.associated.append(s)
return s
def do_error_replay(self):
for (component, time), error in self.error_replay.iteritems():
self.logging.info('[Controller] Error Log for -- %s --:\n%s' %
(component, error))
def shutdown(self, hard=False, soft=True, context=None):
if not self.polling:
@@ -472,7 +584,7 @@ class Controller(object):
assert hard or soft, """ Must specify kill hard or soft """
if hard:
self.state = CONTROL_STATES.SHUTDOWN
self.state = CONTROL_STATES.TERMINATE
self.logging.info('[Controller] Hard Shutdown')
@@ -488,18 +600,22 @@ class Controller(object):
#for asoc in self.associated:
#asoc.close()
self.do_error_replay()
if __name__ == '__main__':
print 'Running on ',\
'tcp://127.0.0.1:5000', \
'tcp://127.0.0.1:5001',
print 'Running on '\
'tcp://127.0.0.1:5000 '\
'tcp://127.0.0.1:5001 '
controller = Controller(
'tcp://127.0.0.1:5000',
'tcp://127.0.0.1:5001',
)
controller.zmq_flavor = 'green'
controller.manage(
'freeform',
[]
)
controller.run('green')
controller.run()
+4 -17
View File
@@ -126,9 +126,6 @@ from collections import namedtuple
from protocol_utils import Enum, FrameExceptionFactory, namedict
from date_utils import EPOCH, UN_EPOCH
#import ujson
#import ultrajson_numpy
# -----------------------
# Control Protocol
# -----------------------
@@ -136,9 +133,10 @@ from date_utils import EPOCH, UN_EPOCH
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
CONTROL_STATES = Enum(
'INIT',
'SOURCES_READY',
'RUNNING',
'SHUTDOWN', # a soft kill
'TERMINATE', # a hard kill
'TERMINATE',
)
CONTROL_PROTOCOL = Enum(
@@ -149,6 +147,7 @@ CONTROL_PROTOCOL = Enum(
'OK' , # 3 - rep
'DONE' , # 4 - rep
'EXCEPTION' , # 5 - rep
'SIGNAL' , # 6 - rep
)
def CONTROL_FRAME(event, payload):
@@ -174,18 +173,6 @@ def CONTROL_UNFRAME(msg):
except ValueError:
raise INVALID_CONTROL_FRAME(msg)
# -----------------------
# Heartbeat Protocol
# -----------------------
# These encode the msgpack equivelant of 1 and 2. The heartbeat
# frame should only be 1 byte on the wire.
HEARTBEAT_PROTOCOL = namedict({
'REQ' : b'\x01',
'REP' : b'\x02',
})
# -----------------------
# Component State
# -----------------------
+1
View File
@@ -12,6 +12,7 @@ def Enum(*options):
"""
class cstruct(Structure):
_fields_ = [(o, c_ubyte) for o in options]
__iter__ = lambda s: iter(range(len(options)))
return cstruct(*range(len(options)))
def FrameExceptionFactory(name):
+14 -2
View File
@@ -9,6 +9,9 @@ are provided below.
The algorithm must expose methods:
- initialize: method that takes no args, no returns. Simply called to
enable the algorithm to set any internal state needed.
- get_sid_filter: method that takes no args, and returns a list
of valid sids. List must have a length between 1 and 10. If None is returned
the filter will block all events.
@@ -18,7 +21,7 @@ The algorithm must expose methods:
+-----------------+--------------+----------------+--------------------+
| | SID(133) | SID(134) | SID(135) |
+=================+==============+=====================================+
+=================+==============+================+====================+
| price | $10.10 | $22.50 | $13.37 |
+-----------------+--------------+----------------+--------------------+
| volume | 10,000 | 5,000 | 50,000 |
@@ -61,6 +64,9 @@ class TestAlgorithm():
self.order = None
self.frame_count = 0
self.portfolio = None
def initialize(self):
pass
def set_order(self, order_callable):
self.order = order_callable
@@ -94,7 +100,10 @@ class HeavyBuyAlgorithm():
self.order = None
self.frame_count = 0
self.portfolio = None
def initialize(self):
pass
def set_order(self, order_callable):
self.order = order_callable
@@ -114,6 +123,9 @@ class NoopAlgorithm(object):
"""
Dolce fa niente.
"""
def initialize(self):
pass
def set_order(self, order_callable):
pass
+72 -29
View File
@@ -368,33 +368,54 @@ class FinanceTestCase(TestCase):
# same scenario, but short sales.
params2 = {
'trade_count':100,
'trade_amount':100,
'trade_delay': timedelta(minutes=5),
'trade_interval': timedelta(days=1),
'order_count':3,
'order_amount':1000,
'order_interval': timedelta(minutes=30),
'trade_count' : 100,
'trade_amount' : 100,
'trade_delay' : timedelta(minutes=5),
'trade_interval' : timedelta(days=1),
'order_count' : 3,
'order_amount' :-1000,
'order_interval' : timedelta(minutes=30),
# because we placed an orders totaling less than 25% of one trade
# the simulator should produce just one transaction.
'expected_txn_count' : 1,
'expected_txn_volume' : 25
'expected_txn_count' : 1,
'expected_txn_volume' : -25
}
self.transaction_sim(**params2)
@timed(DEFAULT_TIMEOUT)
def test_alternating_long_short(self):
# create a scenario where we alternate buys and sells
params1 = {
'trade_count' : int(6.5 * 60 * 4),
'trade_amount' : 100,
'trade_interval' : timedelta(minutes=1),
'order_count' : 4,
'order_amount' : 10,
'order_interval' : timedelta(hours=24),
'alternate' : True,
'complete_fill' : True,
'expected_txn_count' : 4,
'expected_txn_volume' : 0 #equal buys and sells
}
self.transaction_sim(**params1)
def transaction_sim(self, **params):
trade_count = params['trade_count']
trade_amount = params['trade_amount']
trade_interval = params['trade_interval']
trade_delay = params.get('trade_delay')
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
trade_count = params['trade_count']
trade_amount = params['trade_amount']
trade_interval = params['trade_interval']
trade_delay = params.get('trade_delay')
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
expected_txn_volume = params['expected_txn_volume']
# optional parameters
# ---------------------
# if present, alternate between long and short sales
alternate = params.get('alternate')
# if present, expect transaction amounts to match orders exactly.
complete_fill = params.get('complete_fill')
trading_environment = factory.create_trading_environment()
trade_sim = TransactionSimulator()
@@ -411,17 +432,31 @@ class FinanceTestCase(TestCase):
trading_environment
)
for i in range(order_count):
if alternate:
alternator = -1
else:
alternator = 1
order_date = start_date
for i in xrange(order_count):
order = namedict(
{
'sid':sid,
'amount':order_amount,
'type':zp.DATASOURCE_TYPE.ORDER,
'dt' : start_date + i * order_interval
'sid' : sid,
'amount' : order_amount * alternator**i,
'type' : zp.DATASOURCE_TYPE.ORDER,
'dt' : order_date
})
trade_sim.add_open_order(order)
order_date = order_date + order_interval
# move after market orders to just after market next
# market open.
if order_date.hour >= 21:
if order_date.minute >= 00:
order_date = order_date + timedelta(days=1)
order_date = order_date.replace(hour=14, minute=30)
# there should now be one open order list stored under the sid
oo = trade_sim.open_orders
self.assertEqual(len(oo), 1)
@@ -429,9 +464,10 @@ class FinanceTestCase(TestCase):
order_list = oo[sid]
self.assertEqual(order_count, len(order_list))
for order in order_list:
for i in xrange(order_count):
order = order_list[i]
self.assertEqual(order.sid, sid)
self.assertEqual(order.amount, order_amount)
self.assertEqual(order.amount, order_amount * alternator**i)
tracker = PerformanceTracker(trading_environment)
@@ -450,10 +486,17 @@ class FinanceTestCase(TestCase):
trade.TRANSACTION = None
tracker.process_event(trade)
if complete_fill:
self.assertEqual(len(transactions), len(order_list))
total_volume = 0
for txn in transactions:
for i in xrange(len(transactions)):
txn = transactions[i]
total_volume += txn.amount
if complete_fill:
order = order_list[i]
self.assertEqual(order.amount, txn.amount)
self.assertEqual(total_volume, expected_txn_volume)
self.assertEqual(len(transactions), expected_txn_count)
+97
View File
@@ -0,0 +1,97 @@
from datetime import timedelta
from collections import defaultdict
from unittest2 import TestCase
import zipline.test.factory as factory
import zipline.util as qutil
from zipline.finance.vwap import DailyVWAP, VWAPTransform
from zipline.finance.returns import ReturnsFromPriorClose
from zipline.finance.movingaverage import MovingAverage
from zipline.lines import SimulatedTrading
from zipline.simulator import AddressAllocator, Simulator
allocator = AddressAllocator(1000)
class ZiplineWithTransformsTestCase(TestCase):
leased_sockets = defaultdict(list)
def setUp(self):
# skip ahead 100 spots
allocator.lease(100)
qutil.configure_logging()
self.trading_environment = factory.create_trading_environment()
self.zipline_test_config = {
'allocator':allocator,
'sid':133
}
def test_vwap_tnfm(self):
zipline = SimulatedTrading.create_test_zipline(
**self.zipline_test_config
)
vwap = VWAPTransform("vwap_10", daycount=10)
zipline.add_transform(vwap)
zipline.simulate(blocking=True)
self.assertTrue(zipline.sim.ready())
self.assertFalse(zipline.sim.exception)
class FinanceTransformsTestCase(TestCase):
def setUp(self):
self.trading_environment = factory.create_trading_environment()
def test_vwap(self):
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
)
vwap = DailyVWAP(daycount=2)
for trade in trade_history:
vwap.update(trade)
self.assertEqual(vwap.vwap, 10.75)
def test_returns(self):
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
)
returns = ReturnsFromPriorClose()
for trade in trade_history:
returns.update(trade)
self.assertEqual(returns.returns, .1)
def test_moving_average(self):
trade_history = factory.create_trade_history(
133,
[10.0, 10.0, 10.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.trading_environment
)
ma = MovingAverage(daycount=2)
for trade in trade_history:
ma.update(trade)
self.assertEqual(ma.average, 10.5)
+87
View File
@@ -0,0 +1,87 @@
import types
from collections import Container, Hashable, Callable
class Any(object): pass
class Workflow(Container, Callable):
def __init__(self, states, transitions, initial_state):
self.simple = set()
self.complx = []
if isinstance(states[0], tuple):
self.groups = {b for _,b in states}
else:
self.groups = set()
matcher = lambda b: lambda f,t : t == b
for (a, b) in transitions.itervalues():
if a is Any:
self.complx.append(matcher(b))
if isinstance(a, Hashable) and isinstance(b, Hashable):
self.simple.add((a,b))
def __call__(self, **kwargs):
if 'group' in kwargs:
return self.groups
def __contains__(self, state):
if state in self.simple:
return True
for match in self.complx:
if match(*state):
return True
else:
return False
class Flowable:
@property
def state(self):
if not hasattr(self, '_state'):
self._state = self.initial_state
else:
return self._state
@state.setter
def state(self, new):
if not hasattr(self, '_state'):
self._state = self.initial_state
old = self._state
if (old, new) in self.workflow:
self._state = new
else:
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
class WorkflowMeta(type):
"""
Base metaclass component workflows.
"""
def __new__(cls, name, mro, attrs):
state = attrs.get('states', None)
transitions = attrs.get('transitions', None)
initial_state = attrs.get('initial_state', None)
if attrs.get('workflow'):
raise RuntimeError('`workflow` is a reserved attribute.')
if not state:
raise RuntimeError('Must specify states')
if not transitions:
raise RuntimeError('Must specify transitions')
if not transitions:
raise RuntimeError('Must specify initial_state')
new_class = super(WorkflowMeta, cls).__new__(
cls, name, mro+(Flowable,), attrs
)
new_class.workflow = Workflow(state, transitions, initial_state)
return new_class