diff --git a/zipline/component.py b/zipline/component.py index 4e1d6b7f..d82c8fb9 100644 --- a/zipline/component.py +++ b/zipline/component.py @@ -10,6 +10,7 @@ import uuid import time import socket import gevent +import traceback import humanhash # pyzmq @@ -305,11 +306,11 @@ class Component(object): self._exception = exc exc_type, exc_value, exc_traceback = sys.exc_info() - self.stack_trace = exc_traceback + trace = '\n>>>'.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) exception_frame = CONTROL_FRAME( CONTROL_PROTOCOL.EXCEPTION, - str(exc) + trace ) self.control_out.send(exception_frame) @@ -492,7 +493,8 @@ class Component(object): """ The descriptive name of the component. """ - return 'UNKNOWN COMPONENT' + # Prevents the bug that Thomas ran into + raise NotImplementedError @property def get_type(self): diff --git a/zipline/finance/movingaverage.py b/zipline/finance/movingaverage.py new file mode 100644 index 00000000..329b631e --- /dev/null +++ b/zipline/finance/movingaverage.py @@ -0,0 +1,67 @@ +from datetime import timedelta +from collections import defaultdict + +from zipline.messaging import BaseTransform + +class MovingAverageTransform(BaseTransform): + + def init(self, daycount=3): + self.daycount = daycount + self.by_sid = defaultdict(self._create) + + def transform(self, event): + cur = self.by_sid[event.sid] + cur.update(event) + self.state['value'] = cur.average + return self.state + + def _create(self): + return MovingAverage(self.daycount) + +class MovingAverage(object): + + def __init__(self, daycount): + self.window = EventWindow(daycount) + self.total = 0.0 + self.average = 0.0 + + def update(self, event): + self.window.update(event) + + self.total += event.price + + for dropped in self.window.dropped_ticks: + self.total -= dropped.price + + if len(self.window.ticks) > 0: + self.average = self.total / len(self.window.ticks) + else: + self.average = 0.0 + +class EventWindow(object): + """ + Tracks a window of the event history. Use an instance to track the events + inside your window to efficiently calculate rolling statistics. + """ + def __init__(self, daycount): + self.ticks = [] + self.dropped_ticks = [] + self.delta = timedelta(days=daycount) + + def update(self, event): + # add new event + self.ticks.append(event) + # determine which events are expired + last_date = event['dt'] + first_date = last_date - self.delta + + self.dropped_ticks = [] + for tick in self.ticks: + if tick['dt'] <= first_date: + self.dropped_ticks.append(tick) + + # remove the expired events + slice_index = len(self.dropped_ticks) + self.ticks = self.ticks[slice_index:] + + diff --git a/zipline/finance/returns.py b/zipline/finance/returns.py new file mode 100644 index 00000000..e8d3ce34 --- /dev/null +++ b/zipline/finance/returns.py @@ -0,0 +1,46 @@ +import pandas +from datetime import timedelta +from collections import defaultdict + +from zipline.messaging import BaseTransform + +class ReturnsTransform(BaseTransform): + + def init(self): + self.by_sid = defaultdict(self._create) + + def transform(self, event): + cur = self.by_sid[event.sid] + cur.update(event) + self.state['value'] = cur.returns + return self.state + + def _create(self): + return ReturnsFromPriorClose() + +class ReturnsFromPriorClose(object): + """ + Calculates a security's returns since the previous close, using the + current price. + """ + + def __init__(self): + self.last_close = None + self.last_event = None + self.returns = 0.0 + + def update(self, event): + next_close = None + if self.last_close: + change = event.price - self.last_close.price + self.returns = change / self.last_close.price + + if self.last_event: + if self.last_event.dt.day != event.dt.day: + # the current event is from the day after + # the last event. Therefore the last event was + # the last close + self.last_close = self.last_event + + # the current event is now the last_event + self.last_event = event \ No newline at end of file diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 03e45fa8..1910aff5 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -1,7 +1,6 @@ import datetime import pytz import math -import pandas import time from collections import Counter @@ -14,7 +13,7 @@ import zipline.util as qutil import zipline.protocol as zp import zipline.finance.performance as perf -from zipline.protocol_utils import Enum, namedict +from zipline.protocol_utils import Enum, ndict # the simulation style enumerates the available transaction simulation # strategies. @@ -43,10 +42,7 @@ class TradeSimulationClient(qmsg.Component): self.txn_sim = TransactionSimulator(sim_style) assert self.trading_environment.frame_index != None - self.event_frame = pandas.DataFrame( - index=self.trading_environment.frame_index - ) - + self.event_frame = ndict() self.perf = perf.PerformanceTracker(self.trading_environment) @property @@ -59,8 +55,10 @@ class TradeSimulationClient(qmsg.Component): :py:mod:`zipline.test.algorithm` """ self.algorithm = algorithm - #register the trading_client's order method with the algorithm + # register the trading_client's order method with the algorithm self.algorithm.set_order(self.order) + # ask the algorithm to initialize + self.algorithm.initialize() def open(self): self.result_feed = self.connect_result() @@ -176,8 +174,7 @@ class TradeSimulationClient(qmsg.Component): def queue_event(self, event): if self.event_queue == None: self.event_queue = [] - series = event.as_series() - self.event_queue.append(series) + self.event_queue.append(event) def get_frame(self): for event in self.event_queue: diff --git a/zipline/finance/vwap.py b/zipline/finance/vwap.py new file mode 100644 index 00000000..9ef07299 --- /dev/null +++ b/zipline/finance/vwap.py @@ -0,0 +1,61 @@ +import pandas +from datetime import timedelta +from collections import defaultdict + +from zipline.messaging import BaseTransform +from zipline.finance.movingaverage import EventWindow + +class VWAPTransform(BaseTransform): + + def init(self, daycount=3): + self.daycount = daycount + self.by_sid = defaultdict(self.create_vwap) + + def transform(self, event): + cur = self.by_sid[event.sid] + cur.update(event) + self.state['value'] = cur.vwap + return self.state + + def create_vwap(self): + return DailyVWAP(self.daycount) + +class DailyVWAP: + """A class that tracks the volume weighted average price + based on tick updates.""" + def __init__(self, daycount): + self.window = EventWindow(daycount) + self.flux = 0.0 + self.volume = 0 + self.vwap = 0.0 + self.delta = timedelta(days=daycount) + + def update(self, event): + + # update the event window + self.window.update(event) + + # add the current event's flux and volume to the tracker + flux, volume = self.calculate_flux([event]) + self.flux += flux + self.volume += volume + + # subract the expired events flux and volume from the tracker + dropped = self.window.dropped_ticks + dropped_flux, dropped_volume = self.calculate_flux(dropped) + + self.flux -= dropped_flux + self.volume -= dropped_volume + + if(self.volume != 0): + self.vwap = self.flux / self.volume + else: + self.vwap = None + + def calculate_flux(self, ticks): + flux = 0.0 + volume = 0 + for tick in ticks: + flux += tick['volume'] * tick['price'] + volume += tick['volume'] + return flux, volume diff --git a/zipline/lines.py b/zipline/lines.py index 56ec2b9a..932d9e2f 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -4,9 +4,7 @@ messaging. All ziplines follow a general topology of parallel sources, datetimestamp serialization, parallel transformations, and finally sinks. Furthermore, many ziplines have common needs. For example, all trade simulations require a -:py:class:`~zipline.finance.trading.TradeSimulationClient`, an -:py:class:`~zipline.finance.trading.OrderSource`, and a -:py:class:`~zipline.finance.trading.TransactionSimulator` (a transform). +:py:class:`~zipline.finance.trading.TradeSimulationClient`. To establish best practices and minimize code replication, the lines module provides complete zipline topologies. You can extend any zipline without @@ -17,56 +15,49 @@ before invoking simulate. Here is a diagram of the SimulatedTrading zipline: - - +----------------------+ +------------------------+ - +-->| Orders DataSource | | (DataSource added | - | | Integrates algo | | via add_source) | - | | orders into history | | | - | +--------------------+-+ +-+----------------------+ - | | | - | | | - | v v - | +---------+ - | | Feed | - | +-+------++ - | | | - | | | - | v v - | +----------------------+ +----------------------+ - | | Transaction | | | - | | Transform simulates | | (Transforms added | - | | trades based on | | via add_transform) | - | | orders from algo. | | | - | +-------------------+--+ +-+--------------------+ - | | | - | | | - | v v - | +------------+ - | | Merge | - | +------+-----+ - | | - | | - | V - | +--------------------------------+ - | | | - | | TradingSimulationClient | - | orders | tracks performance and | - +---------------+ provides API to algorithm. | - | | - +---------------------+----------+ - ^ | - | orders | frames - | | - | v - +---------+-----------------------+ - | | - | Algorithm added via | - | __init__. | - | | - | | - | | - +---------------------------------+ + +----------------------+ +------------------------+ + | Trade History | | (DataSource added | + | | | via add_source) | + | | | | + +--------------------+-+ +-+----------------------+ + | | + | | + v v + +---------+ + | Feed | (ensures events are serialized + +-+------++ in chronological order) + | | + | | + v v + +----------------------+ +----------------------+ + | (Transforms added | | (Transforms added | + | via add_transform) | | via add_transform) | + +-------------------+--+ +-+--------------------+ + | | + | | + v v + +------------+ + | Merge | (combines original event and + +------+-----+ transforms into one vector) + | + | + V + +---------------+ +--------------------------------+ + | Risk and Perf | | | + | Tracker | | TradingSimulationClient | + +---------------+ | tracks performance and | + ^ Trades and | provides API to algorithm. | + | simulated | | + | transactions +--+------------------+----------+ + | | ^ | + +---------------------+ | orders | frames + | | + | v + +---------------------------------+ + | Algorithm added via | + | __init__. | + +---------------------------------+ """ import mock @@ -152,6 +143,8 @@ class SimulatedTrading(object): sockets[7], logging = qutil.LOGGER ) + + self.con.cancel_socket = self.allocator.lease(1)[0] # TODO: Not freeform self.con.manage( diff --git a/zipline/messaging.py b/zipline/messaging.py index 6cd4154d..e1011071 100644 --- a/zipline/messaging.py +++ b/zipline/messaging.py @@ -227,7 +227,7 @@ class Feed(Component): # -- Soft Kill -- elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.done() + self.signal_done() self.shutdown() # -- Hard Kill -- @@ -440,14 +440,14 @@ class BaseTransform(Component): method to create a new derived value from the combined feed. """ - def __init__(self, name): + def __init__(self, name, **kwargs): Component.__init__(self) self.state = { 'name': name } - self.init() + self.init(**kwargs) def init(self): pass @@ -496,7 +496,7 @@ class BaseTransform(Component): # -- Soft Kill -- elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.done() + self.signal_done() self.shutdown() # -- Hard Kill -- @@ -564,11 +564,11 @@ class PassthroughTransform(BaseTransform): """ - def __init__(self): + def __init__(self, **kwargs): BaseTransform.__init__(self, "PASSTHROUGH") - self.init() + self.init(**kwargs) - def init(self): + def init(self, **kwargs): pass @property diff --git a/zipline/monitor.py b/zipline/monitor.py index f6b55037..6f72989b 100644 --- a/zipline/monitor.py +++ b/zipline/monitor.py @@ -3,13 +3,14 @@ import gevent import itertools # pyzmq import zmq -# gevent_zeromq import gevent_zeromq -# zmq_ctypes -#import zmq_ctypes + +from collections import OrderedDict from protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \ - CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME + CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \ + +states = CONTROL_STATES from gpoll import _Poller as GeventPoller @@ -103,6 +104,17 @@ from gpoll import _Poller as GeventPoller # | 0 | | 0 | | 0 | | 0 | # +---+ +---+ +---+ +---+ +INIT, SOURCES_READY, RUNNING, TERMINATE = CONTROL_STATES + +state_transitions = frozenset([ + (-1 , INIT), + (INIT , SOURCES_READY), + (SOURCES_READY , RUNNING), + (INIT , TERMINATE), + (SOURCES_READY , TERMINATE), + (RUNNING , TERMINATE), +]) + class UnknownChatter(Exception): def __init__(self, name): self.named = name @@ -149,27 +161,26 @@ class Controller(object): def __init__(self, pub_socket, route_socket, logging = None): - self.context = None - self.zmq = None + self.context = None + self.zmq = None self.zmq_poller = None - polling = False - - self.polling = polling - + self.running = False + self.polling = False self.tracked = set() self.responses = set() - self.ctime = 0 - self.tic = time.time() + self.ctime = 0 + self.tic = time.time() self.freeform = False + self._state = -1 self.associated = [] - self.pub_socket = pub_socket + self.pub_socket = pub_socket self.route_socket = route_socket - self.error_replay = {} + self.error_replay = OrderedDict() if logging: self.logging = logging @@ -182,23 +193,23 @@ class Controller(object): assert self.zmq_flavor in ['thread', 'mp', 'green'] if flavor == 'mp': - self.zmq = zmq - self.context = self.zmq.Context() + self.zmq = zmq + self.context = self.zmq.Context() self.zmq_poller = self.zmq.Poller return if flavor == 'thread': - self.zmq = zmq - self.context = self.zmq.Context.instance() + self.zmq = zmq + self.context = self.zmq.Context.instance() self.zmq_poller = self.zmq.Poller return if flavor == 'green': - self.zmq = gevent_zeromq.zmq - self.context = self.zmq.Context.instance() + self.zmq = gevent_zeromq.zmq + self.context = self.zmq.Context.instance() self.zmq_poller = GeventPoller return if flavor == 'pypy': - self.zmq = zmq - self.context = self.zmq.Context.instance() + self.zmq = zmq + self.context = self.zmq.Context.instance() self.zmq_poller = self.zmq.Poller return @@ -217,19 +228,24 @@ class Controller(object): self.freeform = False self.topology = frozenset(topology) - default_states = [ - CONTROL_STATES.RUNNING, - CONTROL_STATES.SHUTDOWN, - CONTROL_STATES.TERMINATE, - ] - - self.states = states or default_states self.polling = True + self.state = CONTROL_STATES.INIT - # Start off in RUNNING, state - self.state = self.states[0] + @property + def state(self): + return self._state + + @state.setter + def state(self, new): + old, self._state = self._state, new + + if (old, new) not in state_transitions: + raise RuntimeError("[Controller] Invalid State Transition : %s -> %s" %(old, new)) + else: + self.logging.info("[Controller] State Transition : %s -> %s" %(old, new)) def run(self): + self.running = True self.init_zmq(self.zmq_flavor) try: @@ -256,6 +272,9 @@ class Controller(object): # ------------- def send_heart(self): + if not self.running: + return + heartbeat_frame = CONTROL_FRAME( CONTROL_PROTOCOL.HEARTBEAT, str(self.ctime) @@ -263,6 +282,9 @@ class Controller(object): self.pub.send(heartbeat_frame) def send_hardkill(self): + if not self.running: + return + kill_frame = CONTROL_FRAME( CONTROL_PROTOCOL.KILL, '' @@ -270,6 +292,9 @@ class Controller(object): self.pub.send(kill_frame) def send_softkill(self): + if not self.running: + return + soft_frame = CONTROL_FRAME( CONTROL_PROTOCOL.SHUTDOWN, '' @@ -282,16 +307,35 @@ class Controller(object): def _poll(self): + assert self.route_socket + assert self.pub_socket + assert self.cancel_socket + + # -- Publish -- + # ============= self.pub = self.context.socket(self.zmq.PUB) self.pub.bind(self.pub_socket) + # -- Cancel -- + # ============= + assert isinstance(self.cancel_socket,basestring), self.cancel_socket + self.cancel = self.context.socket(self.zmq.REP) + self.cancel.connect(self.cancel_socket) + + # -- Router -- + # ============= self.router = self.context.socket(self.zmq.ROUTER) self.router.bind(self.route_socket) - self.associated.extend([self.pub, self.router]) poller = self.zmq.Poller() poller.register(self.router, self.zmq.POLLIN) + poller.register(self.cancel, self.zmq.POLLIN) + + self.associated += [self.pub, self.router, self.cancel] + + # TODO: actually do this + self.state = CONTROL_STATES.SOURCES_READY buffer = [] @@ -325,11 +369,20 @@ class Controller(object): self.logging.error('Invalid frame', rawmessage) pass + if self.cancel in socks and socks[self.cancel] == self.zmq.POLLIN: + self.logging.info('[Controller] Received Cancellation') + rawmessage = self.cancel.recv() + self.shutdown(soft=True) + break + self.beat() if self.zmq_flavor == 'green': gevent.sleep(0) + if self.state is CONTROL_STATES.TERMINATE: + break + if not self.polling: break @@ -359,35 +412,89 @@ class Controller(object): for component in bad: self.fail(component) - # ------------------ - # Component Handlers - # ------------------ + # -------------- + # Init Handlers + # -------------- + + def new_source(self): + if self.state is CONTROL_STATES.RUNNING: + self.state = SOURCES_READY + + def new_universal(self): + pass # The various "states of being that a component can inform us # of def new(self, component): - self.logging.info('[Controller] Alive "%s" ' % component) + if self.state is CONTROL_STATES.TERMINATE: + return + + self.logging.info('[Controller] Now Tracking "%s" ' % component) + + universal = self.new_universal + init_handlers = { + 'FEED' : self.new_source, + } if component in self.topology or self.freeform: + init_handlers.get(component, universal)() self.tracked.add(component) else: # Some sort of socket collision has occured, this is # a very bad failure mode. raise UnknownChatter(component) + # ------------------ + # Epic Fail Handling + # ------------------ + + def fail_universal(self): + pass + # TODO: this requires higher order functionality + #self.logging.error('[Controller] System in exception state, shutting down') + #self.shutdown(soft=True) + def fail(self, component): - self.logging.info('[Controller] Component "%s" timed out' % component) - self.tracked.remove(component) + if self.state is CONTROL_STATES.TERMINATE: + return + + universal = self.fail_universal + fail_handlers = { } + + if component in self.topology or self.freeform: + self.logging.info('[Controller] Component "%s" timed out' % component) + self.tracked.remove(component) + fail_handlers.get(component, universal)() + + # ------------------- + # Completion Handling + # ------------------- def done(self, component): - # TODO: This will be what we ship off to vbench at some - # point... - # print component finished at self.ctime self.logging.info('[Controller] Component "%s" done.' % component) + # -------------- + # Error Handling + # -------------- + + def exception_universal(self): + """ + Shutdown the system on failure. + """ + self.logging.error('[Controller] System in exception state, shutting down') + self.shutdown(soft=True) + def exception(self, component, failure): - self.error_replay[time.time()] = failure - self.logging.error('Component "%s" in exception state' % component) + universal = self.exception_universal + exception_handlers = { } + + if component in self.topology or self.freeform: + self.error_replay[(component, time.time())] = failure + self.logging.error('[Controller] Component "%s" in exception state' % component) + + exception_handlers.get(component, universal)() + else: + raise UnknownChatter(component) # ----------------- # Protocol Handling @@ -462,6 +569,11 @@ class Controller(object): self.associated.append(s) return s + def do_error_replay(self): + for (component, time), error in self.error_replay.iteritems(): + self.logging.info('[Controller] Error Log for -- %s --:\n%s' % + (component, error)) + def shutdown(self, hard=False, soft=True, context=None): if not self.polling: @@ -472,7 +584,7 @@ class Controller(object): assert hard or soft, """ Must specify kill hard or soft """ if hard: - self.state = CONTROL_STATES.SHUTDOWN + self.state = CONTROL_STATES.TERMINATE self.logging.info('[Controller] Hard Shutdown') @@ -488,18 +600,22 @@ class Controller(object): #for asoc in self.associated: #asoc.close() + self.do_error_replay() + if __name__ == '__main__': - print 'Running on ',\ - 'tcp://127.0.0.1:5000', \ - 'tcp://127.0.0.1:5001', + print 'Running on '\ + 'tcp://127.0.0.1:5000 '\ + 'tcp://127.0.0.1:5001 ' controller = Controller( 'tcp://127.0.0.1:5000', 'tcp://127.0.0.1:5001', ) + controller.zmq_flavor = 'green' + controller.manage( 'freeform', [] ) - controller.run('green') + controller.run() diff --git a/zipline/protocol.py b/zipline/protocol.py index c3291556..90a3184a 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -126,9 +126,6 @@ from collections import namedtuple from protocol_utils import Enum, FrameExceptionFactory, namedict from date_utils import EPOCH, UN_EPOCH -#import ujson -#import ultrajson_numpy - # ----------------------- # Control Protocol # ----------------------- @@ -136,9 +133,10 @@ from date_utils import EPOCH, UN_EPOCH INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL') CONTROL_STATES = Enum( + 'INIT', + 'SOURCES_READY', 'RUNNING', - 'SHUTDOWN', # a soft kill - 'TERMINATE', # a hard kill + 'TERMINATE', ) CONTROL_PROTOCOL = Enum( @@ -149,6 +147,7 @@ CONTROL_PROTOCOL = Enum( 'OK' , # 3 - rep 'DONE' , # 4 - rep 'EXCEPTION' , # 5 - rep + 'SIGNAL' , # 6 - rep ) def CONTROL_FRAME(event, payload): @@ -174,18 +173,6 @@ def CONTROL_UNFRAME(msg): except ValueError: raise INVALID_CONTROL_FRAME(msg) -# ----------------------- -# Heartbeat Protocol -# ----------------------- - -# These encode the msgpack equivelant of 1 and 2. The heartbeat -# frame should only be 1 byte on the wire. - -HEARTBEAT_PROTOCOL = namedict({ - 'REQ' : b'\x01', - 'REP' : b'\x02', -}) - # ----------------------- # Component State # ----------------------- diff --git a/zipline/protocol_utils.py b/zipline/protocol_utils.py index 8d21b1a0..60c90814 100644 --- a/zipline/protocol_utils.py +++ b/zipline/protocol_utils.py @@ -12,6 +12,7 @@ def Enum(*options): """ class cstruct(Structure): _fields_ = [(o, c_ubyte) for o in options] + __iter__ = lambda s: iter(range(len(options))) return cstruct(*range(len(options))) def FrameExceptionFactory(name): diff --git a/zipline/test/algorithms.py b/zipline/test/algorithms.py index 17873d27..b3743709 100644 --- a/zipline/test/algorithms.py +++ b/zipline/test/algorithms.py @@ -9,6 +9,9 @@ are provided below. The algorithm must expose methods: + - initialize: method that takes no args, no returns. Simply called to + enable the algorithm to set any internal state needed. + - get_sid_filter: method that takes no args, and returns a list of valid sids. List must have a length between 1 and 10. If None is returned the filter will block all events. @@ -18,7 +21,7 @@ The algorithm must expose methods: +-----------------+--------------+----------------+--------------------+ | | SID(133) | SID(134) | SID(135) | - +=================+==============+=====================================+ + +=================+==============+================+====================+ | price | $10.10 | $22.50 | $13.37 | +-----------------+--------------+----------------+--------------------+ | volume | 10,000 | 5,000 | 50,000 | @@ -61,6 +64,9 @@ class TestAlgorithm(): self.order = None self.frame_count = 0 self.portfolio = None + + def initialize(self): + pass def set_order(self, order_callable): self.order = order_callable @@ -94,7 +100,10 @@ class HeavyBuyAlgorithm(): self.order = None self.frame_count = 0 self.portfolio = None - + + def initialize(self): + pass + def set_order(self, order_callable): self.order = order_callable @@ -114,6 +123,9 @@ class NoopAlgorithm(object): """ Dolce fa niente. """ + + def initialize(self): + pass def set_order(self, order_callable): pass diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index 763d8024..0876e19c 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -368,33 +368,54 @@ class FinanceTestCase(TestCase): # same scenario, but short sales. params2 = { - 'trade_count':100, - 'trade_amount':100, - 'trade_delay': timedelta(minutes=5), - 'trade_interval': timedelta(days=1), - 'order_count':3, - 'order_amount':1000, - 'order_interval': timedelta(minutes=30), + 'trade_count' : 100, + 'trade_amount' : 100, + 'trade_delay' : timedelta(minutes=5), + 'trade_interval' : timedelta(days=1), + 'order_count' : 3, + 'order_amount' :-1000, + 'order_interval' : timedelta(minutes=30), # because we placed an orders totaling less than 25% of one trade # the simulator should produce just one transaction. - 'expected_txn_count' : 1, - 'expected_txn_volume' : 25 + 'expected_txn_count' : 1, + 'expected_txn_volume' : -25 } self.transaction_sim(**params2) - - + + @timed(DEFAULT_TIMEOUT) + def test_alternating_long_short(self): + # create a scenario where we alternate buys and sells + params1 = { + 'trade_count' : int(6.5 * 60 * 4), + 'trade_amount' : 100, + 'trade_interval' : timedelta(minutes=1), + 'order_count' : 4, + 'order_amount' : 10, + 'order_interval' : timedelta(hours=24), + 'alternate' : True, + 'complete_fill' : True, + 'expected_txn_count' : 4, + 'expected_txn_volume' : 0 #equal buys and sells + } + self.transaction_sim(**params1) def transaction_sim(self, **params): - trade_count = params['trade_count'] - trade_amount = params['trade_amount'] - trade_interval = params['trade_interval'] - trade_delay = params.get('trade_delay') - order_count = params['order_count'] - order_amount = params['order_amount'] - order_interval = params['order_interval'] - expected_txn_count = params['expected_txn_count'] + trade_count = params['trade_count'] + trade_amount = params['trade_amount'] + trade_interval = params['trade_interval'] + trade_delay = params.get('trade_delay') + order_count = params['order_count'] + order_amount = params['order_amount'] + order_interval = params['order_interval'] + expected_txn_count = params['expected_txn_count'] expected_txn_volume = params['expected_txn_volume'] + # optional parameters + # --------------------- + # if present, alternate between long and short sales + alternate = params.get('alternate') + # if present, expect transaction amounts to match orders exactly. + complete_fill = params.get('complete_fill') trading_environment = factory.create_trading_environment() trade_sim = TransactionSimulator() @@ -411,17 +432,31 @@ class FinanceTestCase(TestCase): trading_environment ) - for i in range(order_count): + if alternate: + alternator = -1 + else: + alternator = 1 + + order_date = start_date + for i in xrange(order_count): order = namedict( { - 'sid':sid, - 'amount':order_amount, - 'type':zp.DATASOURCE_TYPE.ORDER, - 'dt' : start_date + i * order_interval + 'sid' : sid, + 'amount' : order_amount * alternator**i, + 'type' : zp.DATASOURCE_TYPE.ORDER, + 'dt' : order_date }) trade_sim.add_open_order(order) - + + order_date = order_date + order_interval + # move after market orders to just after market next + # market open. + if order_date.hour >= 21: + if order_date.minute >= 00: + order_date = order_date + timedelta(days=1) + order_date = order_date.replace(hour=14, minute=30) + # there should now be one open order list stored under the sid oo = trade_sim.open_orders self.assertEqual(len(oo), 1) @@ -429,9 +464,10 @@ class FinanceTestCase(TestCase): order_list = oo[sid] self.assertEqual(order_count, len(order_list)) - for order in order_list: + for i in xrange(order_count): + order = order_list[i] self.assertEqual(order.sid, sid) - self.assertEqual(order.amount, order_amount) + self.assertEqual(order.amount, order_amount * alternator**i) tracker = PerformanceTracker(trading_environment) @@ -450,10 +486,17 @@ class FinanceTestCase(TestCase): trade.TRANSACTION = None tracker.process_event(trade) - + + if complete_fill: + self.assertEqual(len(transactions), len(order_list)) + total_volume = 0 - for txn in transactions: + for i in xrange(len(transactions)): + txn = transactions[i] total_volume += txn.amount + if complete_fill: + order = order_list[i] + self.assertEqual(order.amount, txn.amount) self.assertEqual(total_volume, expected_txn_volume) self.assertEqual(len(transactions), expected_txn_count) diff --git a/zipline/test/test_transforms.py b/zipline/test/test_transforms.py new file mode 100644 index 00000000..6a2bf204 --- /dev/null +++ b/zipline/test/test_transforms.py @@ -0,0 +1,97 @@ +from datetime import timedelta +from collections import defaultdict +from unittest2 import TestCase + +import zipline.test.factory as factory +import zipline.util as qutil +from zipline.finance.vwap import DailyVWAP, VWAPTransform +from zipline.finance.returns import ReturnsFromPriorClose +from zipline.finance.movingaverage import MovingAverage +from zipline.lines import SimulatedTrading +from zipline.simulator import AddressAllocator, Simulator + + +allocator = AddressAllocator(1000) + +class ZiplineWithTransformsTestCase(TestCase): + leased_sockets = defaultdict(list) + + def setUp(self): + # skip ahead 100 spots + allocator.lease(100) + qutil.configure_logging() + self.trading_environment = factory.create_trading_environment() + self.zipline_test_config = { + 'allocator':allocator, + 'sid':133 + } + + def test_vwap_tnfm(self): + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + + vwap = VWAPTransform("vwap_10", daycount=10) + zipline.add_transform(vwap) + + zipline.simulate(blocking=True) + + self.assertTrue(zipline.sim.ready()) + self.assertFalse(zipline.sim.exception) + +class FinanceTransformsTestCase(TestCase): + def setUp(self): + self.trading_environment = factory.create_trading_environment() + + def test_vwap(self): + + trade_history = factory.create_trade_history( + 133, + [10.0, 10.0, 10.0, 11.0], + [100, 100, 100, 300], + timedelta(days=1), + self.trading_environment + ) + + vwap = DailyVWAP(daycount=2) + for trade in trade_history: + vwap.update(trade) + + self.assertEqual(vwap.vwap, 10.75) + + + def test_returns(self): + trade_history = factory.create_trade_history( + 133, + [10.0, 10.0, 10.0, 11.0], + [100, 100, 100, 300], + timedelta(days=1), + self.trading_environment + ) + + returns = ReturnsFromPriorClose() + for trade in trade_history: + returns.update(trade) + + + self.assertEqual(returns.returns, .1) + + + def test_moving_average(self): + trade_history = factory.create_trade_history( + 133, + [10.0, 10.0, 10.0, 11.0], + [100, 100, 100, 300], + timedelta(days=1), + self.trading_environment + ) + + ma = MovingAverage(daycount=2) + for trade in trade_history: + ma.update(trade) + + + self.assertEqual(ma.average, 10.5) + + + \ No newline at end of file diff --git a/zipline/transitions.py b/zipline/transitions.py new file mode 100644 index 00000000..9a68926f --- /dev/null +++ b/zipline/transitions.py @@ -0,0 +1,87 @@ +import types +from collections import Container, Hashable, Callable + +class Any(object): pass + +class Workflow(Container, Callable): + + def __init__(self, states, transitions, initial_state): + self.simple = set() + self.complx = [] + + if isinstance(states[0], tuple): + self.groups = {b for _,b in states} + else: + self.groups = set() + + matcher = lambda b: lambda f,t : t == b + + for (a, b) in transitions.itervalues(): + if a is Any: + self.complx.append(matcher(b)) + if isinstance(a, Hashable) and isinstance(b, Hashable): + self.simple.add((a,b)) + + def __call__(self, **kwargs): + if 'group' in kwargs: + return self.groups + + def __contains__(self, state): + if state in self.simple: + return True + for match in self.complx: + if match(*state): + return True + else: + return False + +class Flowable: + + @property + def state(self): + if not hasattr(self, '_state'): + self._state = self.initial_state + else: + return self._state + + @state.setter + def state(self, new): + if not hasattr(self, '_state'): + self._state = self.initial_state + + old = self._state + + if (old, new) in self.workflow: + self._state = new + else: + raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) + +class WorkflowMeta(type): + """ + Base metaclass component workflows. + """ + + def __new__(cls, name, mro, attrs): + + state = attrs.get('states', None) + transitions = attrs.get('transitions', None) + initial_state = attrs.get('initial_state', None) + + if attrs.get('workflow'): + raise RuntimeError('`workflow` is a reserved attribute.') + + if not state: + raise RuntimeError('Must specify states') + + if not transitions: + raise RuntimeError('Must specify transitions') + + if not transitions: + raise RuntimeError('Must specify initial_state') + + new_class = super(WorkflowMeta, cls).__new__( + cls, name, mro+(Flowable,), attrs + ) + new_class.workflow = Workflow(state, transitions, initial_state) + + return new_class