diff --git a/dev/cli.py b/dev/cli.py deleted file mode 100644 index 3fd8a4b3..00000000 --- a/dev/cli.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: move qexec console here diff --git a/dev/topos.py b/dev/topos.py deleted file mode 100644 index 23e37ff7..00000000 --- a/dev/topos.py +++ /dev/null @@ -1,189 +0,0 @@ -import uuid -import copy -import atexit -import pickle - -from datetime import datetime -from collections import defaultdict - -from UserDict import DictMixin - -class Snapshot(object, DictMixin): - """ - A snapshot in time of a history container. - """ - - def __init__(self, state, version, ts): - self.version = version - self.timestamp = ts - self._state = state - - def keys(self): - return self._state.keys() - - def values(self): - return self._state.values() - - def items(self): - return self._state.items() - - def __getitem__(self, key): - return self._state.__getitem__(key) - - def has_key(self, key): - return self._state.has_key(key) - - def copy(self): - return copy.copy(self._state) - -class History(object, DictMixin): - """ - A duck-typed dictionary that tracks its time evolution. - - Worth noting this not a particuarly high-performance - data structure due to the copious amount of copying going on. - """ - - def __init__(self, default=None): - if default: - initial = defaultdict(default) - else: - initial = {} - - self.version = 0 - self.changeset = [('CREATE', None)] - self.current = Snapshot(initial, version=self.version, ts=datetime.now()) - self._history = [self.current] - - def items(self, version=-1): - return self._history[version].items() - - def keys(self, version=-1): - return self._history[version].keys() - - def rollback(self, version): - pass - - def event(self, tup): - self.changeset.append(tup) - - def __getitem__(self, key, version=-1): - return self._history[version].__getitem__(key) - - def __setitem__(self, key, val): - if self.current.has_key(key): - self.changeset.append(('CHANGE', key)) - else: - self.changeset.append(('ADD', key)) - - state = self.current.copy() - state[key] = val - - self.version += 1 - self.current = Snapshot(state, self.version, datetime.now()) - self._history.append(self.current) - - def __delitem__(self, key): - self.changeset.append(('REMOVE', key)) - - state = self.current.copy() - del state[key] - - self.version += 1 - self.current = Snapshot(state, self.version, datetime.now()) - self._history.append(self.current) - - def history(self): - for change in self.changeset: - print change - - def __repr__(self): - return ':'.join(['historical', self.current._state.__repr__()]) - -SocketHistory = History() -ContextHistory = History() - -def patch_zmq(_zmq=None): - """ - Monkey patch zeromq to allow for socket tracking. - """ - if _zmq: - zmq = _zmq - else: - import zmq - - _Context = zmq.Context - _Socket = zmq.Socket - - class TrackedSocket(zmq.Socket): - - def __init__(self, context, socket_type): - self.context = context - self.uuid = str(uuid.uuid4()) - SocketHistory[self.uuid] = self - _Socket.__init__(self, context, socket_type) - - def connect(self, address): - SocketHistory.event(('CONNECT', self.uuid, address)) - _Socket.connect(self, address) - - def bind(self, address): - SocketHistory.event(('BIND', self.uuid, address)) - _Socket.bind(self, address) - - def close(self, *args, **kwargs): - del SocketHistory[self.uuid] - _Socket.close(self, *args, **kwargs) - - def setsockopt(self, option, optval): - if option == zmq.IDENTITY: - old = SocketHistory[self.uuid] - SocketHistory[optval] = old - del SocketHistory[self.uuid] - self.uuid = optval - - _Socket.setsockopt(self, option, optval) - - class TrackedContext(zmq.Context): - - def __init__(self, *args, **kwargs): - self.sockets = {} - _Context.__init__(self, *args, **kwargs) - self.uuid = str(uuid.uuid4()) - ContextHistory[self.uuid] = self - - def socket(self, socket_type): - sock = TrackedSocket(self, socket_type) - ContextHistory.event(('EMBED', self.uuid, sock.uuid)) - self.sockets[sock.uuid] = sock - return sock - - def name(self, name): - """ - Name the context. Is a superset of the vanilla pyzmq - API. - """ - old = ContextHistory[self.context.uuid] - ContextHistory[name] = old - del ContextHistory[self.context.uuid] - self.uuid = name - - def term(self, *args, **kwargs): - for uid, sock in self.sockets.iteritems(): - if not sock.closed: - del SocketHistory[sock.uuid] - del ContextHistory[self.uuid] - _Context.term(self, *args, **kwargs) - - def destroy(self, *args, **kwargs): - ContextHistory.event(('DESTROY', self.uuid)) - _Context.destroy(self, *args, **kwargs) - - zmq.Context = TrackedContext - zmq.Socket = TrackedSocket - return TrackedContext, TrackedSocket - -def track_to_file(f): - def write_track(): - pickle.dump(SocketHistory.changeset, file(f, 'wb+')) - atexit.register(write_track) diff --git a/etc/requirements_sci.txt b/etc/requirements_sci.txt index b74b3e57..a0104713 100644 --- a/etc/requirements_sci.txt +++ b/etc/requirements_sci.txt @@ -4,7 +4,7 @@ python-dateutil==1.5 # Core scientific python numpy>=1.6.1 -pandas>=0.7.0rc1 +pandas==0.8.0 scipy>=0.10.0 matplotlib==1.1.0 @@ -12,8 +12,9 @@ matplotlib==1.1.0 numexpr==2.0.1 Cython==0.15.1 -#tables>=2.3.1 -#scikits.statsmodels>=0.3.1 +patsy==0.1.0 +statsmodels==0.5.0-tutorial-beta + # ZeroMQ pyzmq==2.1.11 diff --git a/tests/test_delayed_signals.py b/tests/test_delayed_signals.py new file mode 100644 index 00000000..d840bb0a --- /dev/null +++ b/tests/test_delayed_signals.py @@ -0,0 +1,53 @@ +import os +from signal import signal, SIGHUP, SIGINT +import time +from types import FrameType +import unittest + +from zipline.utils.delayed_signals import delayed_signals + +class DelayedSignals(unittest.TestCase): + def handler(self, signum, frame): + print "Got signal " + str(signum) + self.got[signum] = time.time() + self.assertTrue(isinstance(frame, FrameType)) + + def setUp(self): + signal(SIGHUP, self.handler) + signal(SIGINT, self.handler) + + def reset(self): + self.got = {} + + def test_delayed_signals(self): + self.reset() + with delayed_signals([SIGHUP]): + os.kill(os.getpid(), SIGHUP) + time.sleep(2) + self.assertTrue(self.got[SIGHUP]) + self.assertTrue(time.time() - self.got[SIGHUP] < 2) + + def test_immediate_signals(self): + self.reset() + os.kill(os.getpid(), SIGHUP) + time.sleep(2) + self.assertTrue(self.got[SIGHUP]) + self.assertTrue(time.time() - self.got[SIGHUP] > 1) + + def test_multiple_signals(self): + self.reset() + with delayed_signals([SIGHUP, SIGINT]): + os.kill(os.getpid(), SIGINT) + self.assertFalse(SIGHUP in self.got) + self.assertTrue(SIGINT in self.got) + + @delayed_signals([SIGHUP]) + def kill_and_sleep(self): + os.kill(os.getpid(), SIGHUP) + time.sleep(2) + + def test_decorator(self): + self.reset() + self.kill_and_sleep() + self.assertTrue(SIGHUP in self.got) + self.assertTrue(time.time() - self.got[SIGHUP] < 2) diff --git a/tests/test_exception_handling.py b/tests/test_exception_handling.py index d1561837..f16111ee 100644 --- a/tests/test_exception_handling.py +++ b/tests/test_exception_handling.py @@ -3,11 +3,14 @@ import zmq from unittest2 import TestCase from collections import defaultdict -from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm +from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm, \ + InitializeTimeoutAlgorithm, TooMuchProcessingAlgorithm from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator from zipline.lines import SimulatedTrading from zipline.gens.transform import StatefulTransform +from zipline.gens.tradesimulation import HEARTBEAT_INTERVAL, \ + MAX_HEARTBEAT_INTERVALS from zipline.utils.test_utils import \ drain_zipline, \ @@ -143,3 +146,46 @@ class ExceptionTestCase(TestCase): # make sure our path shortening is working self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py') self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py') + + def test_initialize_timeout(self): + + self.zipline_test_config['algorithm'] = \ + InitializeTimeoutAlgorithm( + self.zipline_test_config['sid'] + ) + + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + output, _ = drain_zipline(self, zipline) + self.assertEqual(output[-1]['prefix'], 'EXCEPTION') + payload = output[-1]['payload'] + self.assertEqual(payload['name'],'Timeout') + self.assertEqual(payload['message'], 'Call to initialize timed out') + + def test_heartbeat(self): + + self.zipline_test_config['algorithm'] = \ + TooMuchProcessingAlgorithm( + self.zipline_test_config['sid'] + ) + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + output, _ = drain_zipline(self, zipline) + + # There should be a message for each hearbeat, plus a message + # for the final timeout. + assert len(output) == MAX_HEARTBEAT_INTERVALS + 1 + + # Assert that everything but the last message is a heartbeat log. + for message in output[0:-1]: + assert message['prefix'] == 'LOG' + assert message['payload']['func_name'] == 'log_heartbeats' + + # Assert that the last message is a timeout exception. + self.assertEqual(output[-1]['prefix'], 'EXCEPTION') + payload = output[-1]['payload'] + self.assertEqual(payload['name'],'Timeout') + self.assertEqual(payload['message'], 'Too much time spent in handle_data call') + diff --git a/tests/test_finance.py b/tests/test_finance.py index e5f26240..5562b671 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -20,7 +20,6 @@ from zipline.finance.performance import PerformanceTracker from zipline.utils.protocol_utils import ndict from zipline.finance.trading import TransactionSimulator from zipline.utils.test_utils import \ - drain_zipline, \ setup_logger, \ teardown_logger,\ assert_single_position @@ -121,36 +120,6 @@ class FinanceTestCase(TestCase): zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config) assert_single_position(self, zipline) - #@timed(DEFAULT_TIMEOUT) - def test_sid_filter(self): - # Ensure the algorithm's filter prevents events from arriving. - # create a test algorithm whose filter will not match any of the - # trade events sourced inside the zipline. - order_amount = 100 - order_count = 100 - no_match_sid = 222 - test_algo = TestAlgorithm( - no_match_sid, - order_amount, - order_count - ) - - self.zipline_test_config['trade_count'] = 200 - self.zipline_test_config['algorithm'] = test_algo - - zipline = SimulatedTrading.create_test_zipline( - **self.zipline_test_config - ) - output, transaction_count = drain_zipline(self, zipline) - - #check that the algorithm received no events - self.assertEqual( - 0, - transaction_count, - "The algorithm should not receive any events due to filtering." - ) - - # TODO: write tests for short sales # TODO: write a test to do massive buying or shorting. diff --git a/tests/test_logger.py b/tests/test_logger.py index 9c7eb685..5a9b8e31 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,7 +1,15 @@ import logging +import logbook import uuid +import zmq + +from zipline import ndict from zipline.utils.logger import configure_logging, tail +from zipline.utils.log_utils import ZeroMQLogHandler + +from zipline.utils.test_utils import create_receiver, drain_receiver + from unittest2 import TestCase @@ -20,3 +28,35 @@ class LoggerTestCase(TestCase): last_line = tail(logfile, window=1) logged_msg = last_line.split(" - ")[1] self.assertEqual(test_msg, logged_msg) + + + def test_zmq_handler(self): + socket_addr = 'tcp://127.0.0.1:10000' + ctx = zmq.Context() + socket_push = ctx.socket(zmq.PUSH) + socket_push.connect(socket_addr) + recv = create_receiver(socket_addr, ctx) + zmq_out = ZeroMQLogHandler( + socket = socket_push, + filter = lambda r, h: r.channel in ['test zmq logger'], + context=ctx, + #bubble=False + ) + + log = logbook.Logger('test zmq logger') + x = ndict({}) + x.a = 1 + ex = example(133) + with zmq_out.threadbound(): + log.info(ex.num) + + + output, _ = drain_receiver(recv, count=1) + self.assertEqual(output[-1]['prefix'], 'LOG') + self.assertTrue(isinstance(output[-1]['payload']['msg'], basestring)) + + +class example(object): + + def __init__(self, num): + self.num = num diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 491df1b5..e515d725 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,4 +1,5 @@ import pytz +import numpy from datetime import timedelta, datetime from collections import defaultdict @@ -15,6 +16,7 @@ from zipline.gens.tradegens import SpecificEquityTrades from zipline.gens.transform import StatefulTransform, EventWindow from zipline.gens.vwap import VWAP from zipline.gens.mavg import MovingAverage +from zipline.gens.stddev import MovingStandardDev from zipline.gens.returns import Returns import zipline.utils.factory as factory @@ -70,6 +72,7 @@ class EventWindowTestCase(TestCase): delta = timedelta(minutes = 5), days = None ) + now = utcnow() # 15 dates, increasing in 1 minute increments. @@ -99,6 +102,7 @@ class EventWindowTestCase(TestCase): delta = None, days = 1 ) + dates = ([self.pre_open]*3) dates += ([self.mid_day]*3) dates += ([self.post_close]*3) @@ -239,11 +243,12 @@ class FinanceTransformsTestCase(TestCase): fields = ['price', 'volume'], delta = timedelta(days = 2), ) + transformed = list(mavg.transform(self.source)) # Output values. tnfm_prices = [message.tnfm_value.price for message in transformed] tnfm_volumes = [message.tnfm_value.volume for message in transformed] - + # "Hand-calculated" values expected_prices = [ ((10.0) / 1.0), @@ -264,3 +269,46 @@ class FinanceTransformsTestCase(TestCase): assert tnfm_prices == expected_prices assert tnfm_volumes == expected_volumes + + def test_moving_stddev(self): + trade_history = factory.create_trade_history( + 133, + [10.0, 15.0, 13.0, 12.0], + [100, 100, 100, 100], + timedelta(hours = 1), + self.trading_environment + ) + + stddev = StatefulTransform( + MovingStandardDev, + market_aware = False, + delta = timedelta(minutes = 150), + ) + self.source = SpecificEquityTrades(event_list=trade_history) + + transformed = list(stddev.transform(self.source)) + + vals = [message.tnfm_value for message in transformed] + + expected = [ + None, + numpy.std([10.0, 15.0], ddof = 1), + numpy.std([10.0, 15.0, 13.0], ddof = 1), + numpy.std([15.0, 13.0, 12.0], ddof = 1), + ] + + # numpy has odd rounding behavior, cf. + # http://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html + for v1, v2 in zip(vals, expected): + + if v1 == None: + assert v2 == None + continue + assert round(v1, 5) == round(v2, 5) + + + + + + + diff --git a/zipline/finance/performance.py b/zipline/finance/performance.py index a1fbf65c..278d8189 100644 --- a/zipline/finance/performance.py +++ b/zipline/finance/performance.py @@ -166,6 +166,7 @@ class PerformanceTracker(object): self.event_count = 0 self.last_dict = None self.exceeded_max_loss = False + self.no_more_updates = False self.compute_risk_metrics = True @@ -205,9 +206,12 @@ class PerformanceTracker(object): self.todays_performance.positions[sid] = Position(sid) def update(self, event): - if event.dt == "DONE": + if self.no_more_updates: + return zp.ndict({'dt':0}) + elif event.dt == "DONE": event.perf_message = self.handle_simulation_end() del event['TRANSACTION'] + self.no_more_updates = True return event elif self.exceeded_max_loss: # in case of max_loss, signal to downstream @@ -215,6 +219,7 @@ class PerformanceTracker(object): event.dt = "DONE" event.perf_message = self.handle_simulation_end() del event['TRANSACTION'] + self.no_more_updates = True return event else: event.perf_message = self.process_event(event) diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index 8e52ad3d..66825eba 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -307,24 +307,22 @@ class RiskMetrics(): for i in xrange(7): if(self.treasury_curves.has_key(self.end_date + i * one_day)): curve = self.treasury_curves[self.end_date + i * one_day] - break + self.treasury_curve = curve + rate = self.treasury_curve[self.treasury_duration] + #1month note data begins in 8/2001, so we can use 3month instead. + if rate == None and self.treasury_duration == '1month': + rate = self.treasury_curve['3month'] - if curve: - self.treasury_curve = curve - rate = self.treasury_curve[self.treasury_duration] - #1month note data begins in 8/2001, so we can use 3month instead. - if rate == None and self.treasury_duration == '1month': - rate = self.treasury_curve['3month'] - if rate != None: - return rate * (td.days + 1) / 365 + if rate != None: + return rate * (td.days + 1) / 365 - message = "no rate for end date = {dt} and term = {term}. Check \ - that date doesn't exceed treasury history range." - message = message.format( - dt=self.end_date, - term=self.treasury_duration - ) - raise Exception(message) + message = "no rate for end date = {dt} and term = {term}. Check \ + that date doesn't exceed treasury history range." + message = message.format( + dt=self.end_date, + term=self.treasury_duration + ) + raise Exception(message) diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 1d82bc66..2fa99d9b 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -184,6 +184,8 @@ class TradingEnvironment(object): self.first_open = self.calculate_first_open() self.last_close = self.calculate_last_close() + self.prior_day_open = self.calculate_prior_day_open() + def calculate_first_open(self): """ Finds the first trading day on or after self.period_start. @@ -197,6 +199,24 @@ class TradingEnvironment(object): first_open = self.set_NYSE_time(first_open, 9, 30) return first_open + def calculate_prior_day_open(self): + """ + Finds the first trading day open that falls at least a day + before period_start. + """ + one_day = datetime.timedelta(days=1) + first_open = self.period_start - one_day + + if first_open <= self.trading_days[0]: + log.warn("Cannot calculate prior day open.") + return self.period_start + + while not self.is_trading_day(first_open): + first_open = first_open - one_day + + first_open = self.set_NYSE_time(first_open, 9, 30) + return first_open + def calculate_last_close(self): """ Finds the last trading day on or before self.period_end diff --git a/zipline/gens/sort.py b/zipline/gens/sort.py index 9755da74..d8ebd173 100644 --- a/zipline/gens/sort.py +++ b/zipline/gens/sort.py @@ -1,12 +1,16 @@ """ -Generator version of Feed. +Sorting generator. """ +import logbook + from collections import deque from zipline import ndict from zipline.gens.utils import \ assert_datasource_unframe_protocol, \ assert_sort_protocol +log = logbook.Logger('Sorting') + def date_sort(stream_in, source_ids): """ A generator that takes a generator and a list of source_ids. We @@ -27,7 +31,7 @@ def date_sort(stream_in, source_ids): # Incoming messages should be the output of DATASOURCE_UNFRAME. assert_datasource_unframe_protocol(message), \ "Bad message in date_sort: %s" % message - + # Only allow messages from sources we expect. assert message.source_id in sources, "Unexpected source: %s" % message @@ -40,7 +44,7 @@ def date_sort(stream_in, source_ids): message = pop_oldest(sources) assert_sort_protocol(message) yield message - + # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue diff --git a/zipline/gens/stddev.py b/zipline/gens/stddev.py new file mode 100644 index 00000000..1f46429a --- /dev/null +++ b/zipline/gens/stddev.py @@ -0,0 +1,100 @@ +from numbers import Number +from datetime import datetime, timedelta +from collections import defaultdict +from math import sqrt + +from zipline import ndict +from zipline.gens.transform import EventWindow + +class MovingStandardDev(object): + """ + Class that maintains a dicitonary from sids to + MovingStandardDevWindows. For each sid, we maintain a the + standard deviation of all events falling within the specified + window. + """ + + def __init__(self, market_aware, days = None, delta = None): + + self.market_aware = market_aware + + self.delta = delta + self.days = days + + # Market-aware mode only works with full-day windows. + if self.market_aware: + assert self.days and not self.delta,\ + "Market-aware mode only works with full-day windows." + + # Non-market-aware mode requires a timedelta. + else: + assert self.delta and not self.days, \ + "Non-market-aware mode requires a timedelta." + + # No way to pass arguments to the defaultdict factory, so we + # need to define a method to generate the correct EventWindows. + self.sid_windows = defaultdict(self.create_window) + + def create_window(self): + """ + Factory method for self.sid_windows. + """ + return MovingStandardDevWindow( + self.market_aware, + self.days, + self.delta + ) + + def update(self, event): + """ + Update the event window for this event's sid. Return an ndict + from tracked fields to moving averages. + """ + # This will create a new EventWindow if this is the first + # message for this sid. + window = self.sid_windows[event.sid] + window.update(event) + return window.get_stddev() + +class MovingStandardDevWindow(EventWindow): + """ + Iteratively calculates standard deviation for a particular sid + over a given time window. The expected functionality of this + class is to be instantiated inside a MovingStandardDev. + """ + + def __init__(self, market_aware, days, delta): + + # Call the superclass constructor to set up base EventWindow + # infrastructure. + EventWindow.__init__(self, market_aware, days, delta) + + self.sum = 0.0 + self.sum_sqr = 0.0 + + def handle_add(self, event): + assert event.has_key('price') + assert isinstance(event.price, Number) + + self.sum += event.price + self.sum_sqr += event.price ** 2 + + def handle_remove(self, event): + assert event.has_key('price') + assert isinstance(event.price, Number) + + self.sum -= event.price + self.sum_sqr -= event.price ** 2 + + def get_stddev(self): + + # Sample standard deviation is undefined for a single event or + # no events. + if len(self) <= 1: + return None + + else: + average = self.sum /len(self) + s_squared = (self.sum_sqr - self.sum*average) / (len(self) - 1) + stddev = sqrt(s_squared) + return stddev diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 5c22ef93..30d575aa 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -4,24 +4,23 @@ and zipline development """ import random import pytz -from copy import copy -import pandas as pd -from zipline import ndict -from zipline.protocol import DATASOURCE_TYPE - -from itertools import chain, cycle, ifilter, izip +from itertools import chain, cycle, ifilter, izip, repeat from datetime import datetime, timedelta from zipline.gens.utils import hash_args, create_trade def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc), delta = timedelta(minutes = 1), - count = 100): + count = 100, + repeats = None): """ Utility to generate a stream of dates. """ - return (start + (i * delta) for i in xrange(count)) + if repeats: + return (start + (i * delta) for i in xrange(count) for n in xrange(repeats)) + else: + return (start + (i * delta) for i in xrange(count)) def mock_prices(count, rand = False): """ @@ -101,20 +100,41 @@ class SpecificEquityTrades(object): def get_hash(self): return self.__class__.__name__ + "-" + self.arg_string + def update_source_id(self, gen): + for event in gen: + event.source_id = self.get_hash() + yield event + def create_fresh_generator(self): + if self.event_list: - for event in self.event_list: - event['source_id'] = self.get_hash() - unfiltered = (event for event in self.event_list) + event_gen = (event for event in self.event_list) + unfiltered = self.update_source_id(event_gen) # Set up iterators for each expected field. else: - dates = date_gen(count=self.count, - start=self.start, - delta=self.delta - ) + if self.concurrent: + # in this context the count is the number of + # trades per sid, not the total. + dates = date_gen( + count=self.count, + start=self.start, + delta=self.delta, + repeats=len(self.sids), + ) + + + else: + + dates = date_gen( + count=self.count, + start=self.start, + delta=self.delta + ) + prices = mock_prices(self.count) volumes = mock_volumes(self.count) + sids = cycle(self.sids) # Combine the iterators into a single iterator of arguments diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index a049a5b8..8c650dbc 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -1,9 +1,12 @@ +import signal from logbook import Logger, Processor from datetime import datetime, timedelta from numbers import Integral +from itertools import groupby from zipline import ndict +from zipline.utils.timeout import timeout, heartbeat, Timeout from zipline.gens.transform import StatefulTransform from zipline.finance.trading import TransactionSimulator @@ -13,10 +16,15 @@ from zipline.gens.utils import hash_args log = Logger('Trade Simulation') +# TODO: make these arguments rather than global constants +INIT_TIMEOUT = 5 +HEARTBEAT_INTERVAL = 1 # seconds +MAX_HEARTBEAT_INTERVALS = 15 #count + class TradeSimulationClient(object): """ - Generator that takes the expected output of a merge, a user - algorithm, a trading environment, and a simulator style as + Generator-style class that takes the expected output of a merge, a + user algorithm, a trading environment, and a simulator style as arguments. Pipes the merge stream through a TransactionSimulator and a PerformanceTracker, which keep track of the current state of our algorithm's simulated universe. Results are fed to the user's @@ -24,7 +32,7 @@ class TradeSimulationClient(object): TransactionSimulator's order book. TransactionSimulator maintains a dictionary from sids to the - unfulfilled orders placed by the user's algorithm. As trade + as-yet unfilled orders placed by the user's algorithm. As trade events arrive, if the algorithm has open orders against the trade's sid, the simulator will fill orders up to 25% of market cap. Applied transactions are added to a txn field on the event @@ -40,9 +48,9 @@ class TradeSimulationClient(object): performance report, which is appended to event's perf_report field. - Fully processed events are run through a batcher generator, which - batches together events with the same dt field into a single event - to be fed to the algo. The portfolio object is repeatedly + Fully processed events are fed to AlgorithmSimulator, which + batches together events with the same dt field into a single + snapshot to be fed to the algo. The portfolio object is repeatedly overwritten so that only the most recent snapshot of the universe is sent to the algo. """ @@ -55,9 +63,13 @@ class TradeSimulationClient(object): self.style = sim_style self.algo_sim = None + self.warmup_start = self.environment.prior_day_open + self.algo_start = self.environment.first_open + def get_hash(self): """ - There should only ever be one TSC in the system. + There should only ever be one TSC in the system, so + we don't bother passing args into the hash. """ return self.__class__.__name__ + hash_args() @@ -89,25 +101,31 @@ class TradeSimulationClient(object): with_portfolio = perf_tracker.transform(with_filled_orders) # Pass the messages from perf along with the trading client's - # state into the algorithm for simulation. We provide the - # trading client so that the algorithm can place new orders - # into the client's order book. + # state into the algorithm for simulation. We provide a + # pointer to the ordering client's internal state so that the + # algorithm can place new orders into the client's order book. self.algo_sim = AlgorithmSimulator( with_portfolio, ordering_client.state, self.algo, + self.algo_start ) # The algorithm will yield a daily_results message (as # calculated by the performance tracker) at the end of each # day. It will also yield a risk report at the end of the # simulation. + for message in self.algo_sim: yield message class AlgorithmSimulator(object): - def __init__(self, stream_in, order_book, algo): + def __init__(self, + stream_in, + order_book, + algo, + algo_start): self.stream_in = stream_in @@ -121,12 +139,28 @@ class AlgorithmSimulator(object): self.algo = algo self.sids = algo.get_sid_filter() + self.algo_start = algo_start # Monkey patch the user algorithm to place orders in the - # TransactionSimulator's order book. + # TransactionSimulator's order book and use our logger. self.algo.set_order(self.order) - self.algo.set_logger(Logger("AlgoLog")) + self.algolog = Logger("AlgoLog") + self.algo.set_logger(self.algolog) + # Handler for heartbeats during calls to handle_data. + def log_heartbeats(beat_count, stackframe): + t = beat_count * HEARTBEAT_INTERVAL + warning = "handle_data has been processing for %i seconds" %t + self.algolog.warn(warning) + + # Context manager that calls log_heartbeats every HEARTBEAT_INTERVAL + # seconds, raising an exception after MAX_HEARTBEATS + self.heartbeat_monitor = heartbeat( + HEARTBEAT_INTERVAL, + MAX_HEARTBEAT_INTERVALS, + frame_handler=log_heartbeats, + timeout_message="Too much time spent in handle_data call" + ) # ============== # Snapshot Setup @@ -142,7 +176,7 @@ class AlgorithmSimulator(object): # We don't have a datetime for the current snapshot until we # receive a message. self.simulation_dt = None - self.this_snapshot_dt = None + self.snapshot_dt = None # ============= # Logging Setup @@ -151,7 +185,7 @@ class AlgorithmSimulator(object): # Processor function for injecting the algo_dt into # user prints/logs. def inject_algo_dt(record): - record.extra['algo_dt'] = self.this_snapshot_dt + record.extra['algo_dt'] = self.snapshot_dt self.processor = Processor(inject_algo_dt) # This is a class, which is instantiated later @@ -206,94 +240,62 @@ class AlgorithmSimulator(object): # Capture any output of this generator to stdout and pipe it # to a logbook interface. Also inject the current algo # snapshot time to any log record generated. - with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - # Call the user's initialize method. - self.algo.initialize() - for event in self.stream_in: - # Yield any perf messages received to be relayed back to - # the browser. + # Call user's initialize method with a timeout. + with timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() - if event.perf_message: - yield event.perf_message - del event['perf_message'] + # Group together events with the same dt field. This depends on the + # events already being sorted. + for date, snapshot in groupby(self.stream_in, lambda e: e.dt): - if event.dt == "DONE": - if self.this_snapshot_dt: - # stop iteration happened - # mid-snapshot, so we have a universe - # snapshot that is not yet processed - # by the algorithm. - self.simulate_current_snapshot() - break - - # This should only happen for the first event we run. + # Set the simulation date to be the first event we see. + # This should only occur once, at the start of the test. if self.simulation_dt == None: - self.simulation_dt = event.dt + self.simulation_dt = date - # ====================== - # Time Compression Logic - # ====================== + # Done message has the risk report, so we yield before exiting. + if date == 'DONE': + for event in snapshot: + yield event.perf_message + break - if self.this_snapshot_dt != None: - self.update_current_snapshot(event) + # We're still in the warmup period. Use the event to + # update our universe, but don't yield any perf messages, + # and don't send a snapshot to handle_data. + elif date < self.algo_start: + for event in snapshot: + del event['perf_message'] + self.update_universe(event) - # The algorithm has been missing events because it took - # too long processing. Update the universe with data from - # this event, then check if enough time has passed that we - # can start a new snapshot. + # The algo has taken so long to process events that + # its simulated time is later than the event time. + # Update the universe and yield any perf messages + # encountered, but don't call handle_data. + elif date < self.simulation_dt: + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: + yield event.perf_message + # Delete the message before updating so we don't send it + # to the user. + del event['perf_message'] + self.update_universe(event) + + # Regular snapshot. Update the universe and send a snapshot + # to handle data. else: - self.update_universe(event) - if event.dt >= self.simulation_dt: - self.this_snapshot_dt = event.dt + for event in snapshot: + # Only yield if we have something interesting to say. + if event.perf_message != None: + yield event.perf_message + del event['perf_message'] + self.update_universe(event) - - def update_current_snapshot(self, event): - """ - Update our current snapshot of the universe. Call handle_data if - """ - # The new event matches our snapshot dt. Just update the - # universe and move on. - if event.dt == self.this_snapshot_dt: - self.update_universe(event) - - # The new event does not match our snapshot. - else: - self.simulate_current_snapshot() - - # Once we've finished simulating the old snapshot, - # we can update the universe with the new event. - self.update_universe(event) - - # The current event is later than the simulation time, - # which means the algorithm finished quickly enough to - # receive the new event. Start a new snapshot with this - # event's dt. - if event.dt >= self.simulation_dt: - self.this_snapshot_dt = event.dt - - # The algorithm spent enough time processing that it - # missed the new event. Wait to start a new snapshot until - # the events catch up to the algo's simulated dt. - else: - self.this_snapshot_dt = None - - def simulate_current_snapshot(self): - """ - Run the user's algo against our current snapshot and update the algo's - simulated time. - """ - start_tic = datetime.now() - self.algo.handle_data(self.universe) - stop_tic = datetime.now() - - # How long did you take? - delta = stop_tic - start_tic - - # Update the simulation time. - self.simulation_dt = self.this_snapshot_dt + delta + # Send the current state of the universe to the user's algo. + self.simulate_snapshot(date) def update_universe(self, event): """ @@ -305,3 +307,23 @@ class AlgorithmSimulator(object): # Update our knowledge of this event's sid for field in event.keys(): self.universe[event.sid][field] = event[field] + + def simulate_snapshot(self, date): + """ + Run the user's algo against our current snapshot and update + the algo's simulated time. + """ + # Needs to be set so that we inject the proper date into algo + # log/print lines. + self.snapshot_dt = date + + start_tic = datetime.now() + with self.heartbeat_monitor: + self.algo.handle_data(self.universe) + stop_tic = datetime.now() + + # How long did you take? + delta = stop_tic - start_tic + + # Update the simulation time. + self.simulation_dt = date + delta diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 099fc1eb..275a07ec 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -3,6 +3,7 @@ Generator versions of transforms. """ import types import pytz +import logbook from copy import deepcopy from datetime import datetime, timedelta @@ -15,6 +16,8 @@ from zipline.utils.tradingcalendar import trading_days_between from zipline.gens.utils import assert_sort_unframe_protocol, \ assert_transform_protocol, hash_args +log = logbook.Logger('Transform') + class Passthrough(object): FORWARDER = True """ @@ -72,6 +75,7 @@ class StatefulTransform(object): # Create the string associated with this generator's output. self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) + log.info('StatefulTransform [%s] initialized' % self.namestring) def get_hash(self): return self.namestring @@ -82,7 +86,7 @@ class StatefulTransform(object): def _gen(self, stream_in): # IMPORTANT: Messages may contain pointers that are shared with # other streams, so we only manipulate copies. - + log.info('Running StatefulTransform [%s]' % self.get_hash()) for message in stream_in: # allow upstream generators to yield None to avoid @@ -101,7 +105,8 @@ class StatefulTransform(object): # FORWARDER flag means we want to keep all original # values, plus append tnfm_id and tnfm_value. Used for # preserving the original event fields when our output - # will be fed into a merge. + # will be fed into a merge. Currently only Passthrough + # uses this flag. if self.forward_all: out_message = message_copy out_message.tnfm_id = self.namestring @@ -143,6 +148,7 @@ class StatefulTransform(object): out_message.dt = message_copy.dt yield out_message + log.info('Finished StatefulTransform [%s]' % self.get_hash()) class EventWindow: """ Abstract base class for transform classes that calculate iterative @@ -153,8 +159,9 @@ class EventWindow: from the window. Subclass these methods along with init(*args, **kwargs) to calculate metrics over the window. - The market_aware flag is used to toggle whether the eventwindow - calculates + If the market_aware flag is True, the EventWindow drops old events + based on the number of elapsed trading days between newest and oldest. + Otherwise old events are dropped based on a raw timedelta. See zipline/gens/mavg.py and zipline/gens/vwap.py for example implementations of moving average and volume-weighted average diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 1ac85df6..b8ee6ac4 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -67,12 +67,17 @@ def hash_args(*args, **kwargs): return hasher.hexdigest() def create_trade(sid, price, amount, datetime, source_id = "test_factory"): + row = ndict({ 'source_id' : source_id, 'type' : DATASOURCE_TYPE.TRADE, 'sid' : sid, 'dt' : datetime, 'price' : price, + 'close' : price, + 'open' : price, + 'low' : price * .95, + 'high' : price * 1.05, 'volume' : amount }) return row diff --git a/zipline/lines.py b/zipline/lines.py index c05c5c1c..b1463e38 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -63,30 +63,20 @@ import sys import zmq import os from signal import SIGHUP, SIGINT -import datetime -import pytz -import pandas as pd -import numpy as np - import multiprocessing from setproctitle import setproctitle from zipline.test_algorithms import TestAlgorithm from zipline.finance.trading import SIMULATION_STYLE -from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe +from zipline.utils.log_utils import ZeroMQLogHandler from zipline.utils import factory -from zipline.utils.factory import create_trading_environment -from zipline.gens.tradegens import SpecificEquityTrades -from zipline import ndict -from zipline.protocol import DATASOURCE_TYPE -from zipline.test_algorithms import TestAlgorithm - -from zipline.gens.composites import \ - date_sorted_sources, merged_transforms, sequential_transforms -from zipline.gens.transform import Passthrough, StatefulTransform +from zipline.gens.composites import ( + date_sorted_sources, + sequential_transforms +) from zipline.gens.tradesimulation import TradeSimulationClient as tsc -from logbook import Logger, NestedSetup, Processor +from logbook import Logger import zipline.protocol as zp @@ -181,6 +171,8 @@ class SimulatedTrading(object): def close(self): log.info("Closing Simulation: {id}".format(id=self.sim_id)) + if self.results_socket: + self.results_socket.close() if self.proc and self.send_sighup: ppid = os.getppid() if self.success: @@ -253,17 +245,10 @@ class SimulatedTrading(object): else: return [] - def __iter__(self): - return self - - def next(self): - return self.gen.next() - @staticmethod def create_test_zipline(**config): """ - :param config: A configuration object that is a dict with - (all optional): + :param config: A configuration object that is a dict with: - environment - a \ :py:class:`zipline.finance.trading.TradingEnvironment` @@ -285,7 +270,12 @@ class SimulatedTrading(object): of StatefulTransform objects. """ assert isinstance(config, dict) - sid = config.get('sid', 133) + sid_list = config.get('sid_list') + if not sid_list: + sid = config.get('sid') + sid_list = [sid] + + concurrent_trades = config.get('concurrent_trades', False) #-------------------- # Trading Environment @@ -329,11 +319,13 @@ class SimulatedTrading(object): trade_source = config['trade_source'] else: trade_source = factory.create_daily_trade_source( - sids, + sid_list, trade_count, - trading_environment + trading_environment, + concurrent=concurrent_trades ) + #------------------- # Transforms #------------------- @@ -367,98 +359,3 @@ class SimulatedTrading(object): #------------------- return sim - - -def create_sp_source(start_dt=None, end_dt=None): - if start_dt is None: - start_dt = datetime.datetime(2002, 1, 1, tzinfo=pytz.utc) - if end_dt is None: - end_dt = datetime.datetime(2008, 1, 1, tzinfo=pytz.utc) - - sp_events, _ = factory.load_market_data() - sp_transformed = [] - for event in sp_events: - transformed = ndict(event.to_dict()) - if (transformed.dt < start_dt) or (transformed.dt > end_dt): - continue - transformed['sid'] = 0 - transformed['price'] = transformed['returns'] - transformed['type'] = DATASOURCE_TYPE.TRADE - sp_transformed.append(transformed) - - source = SpecificEquityTrades(event_list=sp_transformed) - - return source - -class Zipline(object): - def __init__(self, **kwargs): - algorithm = kwargs.get('algorithm', TestAlgorithm) - source_descrs = kwargs.get('sources', ['S&P']) - if isinstance(source_descrs, str): - source_descrs = [source_descrs] - - sources = [] - for source_descr in source_descrs: - if isinstance(source_descr, str): - if source_descr == 'S&P': - source = create_sp_source() - else: - raise NotImplementedError, "Source with name {source_descr} not known.".format(source_descr=source_descr) - else: - source = source_descr - - sources.append(source) - - environment = kwargs.get('environment', create_trading_environment()) - - try: - transform_descrs = kwargs.get('transforms', algorithm.registered_transforms) - except: - print "Couldn't load any registered_transforms." - transform_descrs = {} - - # Create transforms by wrapping them into StatefulTransforms - transforms = [] - for namestring, trans_descr in transform_descrs.iteritems(): - sf = StatefulTransform( - trans_descr['class'], - *trans_descr['args'], - **trans_descr['kwargs'] - ) - sf.namestring = namestring - - transforms.append(sf) - - results_socket_uri = None - context = None - sim_id = None - style = SIMULATION_STYLE.FIXED_SLIPPAGE - - self.simulated_trading = SimulatedTrading( - sources, - transforms, - algorithm, - environment, - style, - results_socket_uri, - context, - sim_id) - - - def run(self): - # drain simulated_trading - perfs = [perf for perf in self.simulated_trading] - - # create daily stats dataframe - daily_perfs = [] - cum_perfs = [] - for perf in perfs: - if 'daily_perf' in perf: - daily_perfs.append(perf['daily_perf']) - else: - cum_perfs.append(perf) - - daily_dts = [np.datetime64(perf['period_close'], utc=True) for perf in daily_perfs] - daily_stats = pd.DataFrame(daily_perfs, index=daily_dts) - - return daily_stats diff --git a/zipline/optimize/example.py b/zipline/optimize/example.py index 7c0e58b4..b2f9c5d0 100644 --- a/zipline/optimize/example.py +++ b/zipline/optimize/example.py @@ -1,5 +1,8 @@ from zipline.lines import Zipline import pandas as pd +import pandas.io.data as dt +from pandas.io.data import DataReader + import numpy as np #from mpl_toolkits.mplot3d import Axes3D @@ -9,24 +12,28 @@ from zipline.gens.mavg import MovingAverage from zipline.optimize.algorithms import TradingAlgorithm from datetime import timedelta -from mpi4py_map import map +#from mpi4py_map import map # Inherits from Algorithm base class class DMA(TradingAlgorithm): """Dual Moving Average algorithm. """ - def __init__(self, sid, amount=100, short_window=20, long_window=40): - self.sids = [sid] + def __init__(self, sids, amount=100, short_window=20, long_window=40): + self.sids = sids self.amount = amount self.done = False self.order = None self.frame_count = 0 self.portfolio = None self.orders = [] - self.market_entered = False + self.prices = [] self.events = 0 + self.invested = {} + for sid in self.sids: + self.invested[sid] = False + self.add_transform(MovingAverage, 'short_mavg', ['price'], market_aware=False, delta=timedelta(days=int(short_window))) @@ -37,36 +44,139 @@ class DMA(TradingAlgorithm): def handle_data(self, data): self.events += 1 - sid = self.sids[0] - # access transforms via their user-defined tag - if (data[sid].short_mavg > data[sid].long_mavg) and not self.market_entered: - self.order(sid, 100) - self.market_entered = True - elif (data[sid].short_mavg < data[sid].long_mavg) and self.market_entered: - self.order(sid, -100) - self.market_entered = False + + for sid in self.sids: + # access transforms via their user-defined tag + if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]: + self.order(sid, self.amount) + self.invested[sid] = True + elif (data[sid].short_mavg['price'] < data[sid].long_mavg['price']) and self.invested[sid]: + self.order(sid, -self.amount) + self.invested[sid] = False + + +class DanVWAP(TradingAlgorithm): + """Dual Moving Average algorithm. + """ + def __init__(self, sids, amount=100, short_window=20, long_window=40): + self.sids = sids + self.amount = amount + self.done = False + self.order = None + self.frame_count = 0 + self.portfolio = None + self.orders = [] + + self.prices = [] + self.port = 0 + + self.add_transform(MovingAverage, 'short_mavg', ['price'], + market_aware=False, + delta=timedelta(days=int(short_window))) + + self.add_transform(MovingAverage, 'long_mavg', ['price'], + market_aware=False, + delta=timedelta(days=int(long_window))) + + def handle_data(self, data): + for sid in self.sids: + average=data[sid].vwap(5) + price=data[sid].price + + if price>average*1.05: + self.order(sid, self.amount) + + +def load_close_px(indexes=None, stocks=None): + if indexes is None: + indexes = {'SPX' : '^GSPC'} + if stocks is None: + stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP'] + + start = pd.datetime(1990, 1, 1) + end = pd.datetime.today() + + data = {} + for stock in stocks: + print stock + stkd = DataReader(stock, 'yahoo', start, end).sort_index() + data[stock] = stkd + + for name, ticker in indexes.iteritems(): + print name + stkd = DataReader(ticker, 'yahoo', start, end).sort_index() + data[name] = stkd + + df = pd.DataFrame({key: d['Close'] for key, d in data.iteritems()}) + + return df def run((short_window, long_window)): data = pd.DataFrame.from_csv('SP500.csv') - myalgo = DMA(sid=0, amount=100, short_window=short_window, long_window=long_window) + myalgo = DMA([0], amount=100, short_window=short_window, long_window=long_window) stats = myalgo.run(data, compute_risk_metrics=False) stats['sw'] = short_window stats['lw'] = long_window return stats -sws, lws = np.mgrid[50:80:5, 100:140:5] +def explore_params(): + sws, lws = np.mgrid[10:20:5, 10:20:5] -stats_all = map(run, zip(sws.flatten(), lws.flatten())) + stats_all = map(run, zip(sws.flatten(), lws.flatten())) + stats = pd.concat(stats_all) + returns = stats.groupby(['sw', 'lw']).sum() -# for sw, lw in zip(sws.flatten(), lws.flatten()): -# stats = run(short_window=sw, long_window=lw) -# stats_all.append(stats) + plt.contourf(sws, lws, returns.returns.reshape(sws.shape)) + plt.xlabel('Short window length') + plt.ylabel('Long window length') + plt.savefig('DMA_contour.png') + plt.show() -stats = pd.concat(stats_all) -returns = stats.groupby(['sw', 'lw']).sum() -plt.contourf(sws, lws, returns.returns.reshape(sws.shape)) -plt.xlabel('Short window length') -plt.ylabel('Long window length') -plt.savefig('DMA_contour.png') -plt.show() \ No newline at end of file +#stats = run((10, 50)) + +def get_opt_holdings_qp(univ_rets, track_rets): + from cvxopt import matrix + from cvxopt.solvers import qp + # set up the QP for CVXOPT + # .5 x' P x + q'x + # P = 2 * R'R + # q = - 2 * bmk'R + R = univ_rets.values + b = track_rets.values + P = matrix(2 * np.dot(R.T, R)) + q = matrix(-2 * np.dot(R.T, b)) + result = qp(P, q) + if result['status'] != 'optimal': + raise Exception('optimum not reached by QP') + return pd.Series(np.array(result['x']).ravel(), index=univ_rets.columns) + +def opt_portfolio(cov, budget, min_return): + from cvxopt import matrix + from cvxopt.solvers import qp + n = len(cov) + cov = matrix(2 * cov) + q = matrix(np.zeros(n)) + + h = matrix(budget) # G*x < h + # coneqp + result = qp(cov, q, h=h) + if result['status'] != 'optimal': + raise Exception('optimum not reached by QP') + + return pd.Series(np.array(result['x']).ravel()) + +def calc_te(weights, univ_rets, track_rets): + port_rets = (univ_rets * weights).sum(1) + return (port_rets - track_rets).std() + +def plot_returns(port_returns, bmk_returns): + plt.figure() + cum_port = ((1 + port_returns).cumprod() - 1) + cum_bmk = ((1 + bmk_returns).cumprod() - 1) + # cum_port = port_returns.cumsum() + # cum_bmk = bmk_returns.cumsum() + cum_port.plot(label='Portfolio returns') + cum_bmk.plot(label='Benchmark') + plt.title('Portfolio performance') + plt.legend(loc='best') diff --git a/zipline/optimize/factory.py b/zipline/optimize/factory.py index 3a7fdfff..ff91dd35 100644 --- a/zipline/optimize/factory.py +++ b/zipline/optimize/factory.py @@ -133,6 +133,6 @@ def create_predictable_zipline(config, offset=0, simulate=True): zipline = SimulatedTrading.create_test_zipline(**config) if simulate: - zipline.simulate(blocking=True) + zipline.drain_zipline(blocking=True) return zipline, config diff --git a/zipline/protocol.py b/zipline/protocol.py index dd46bd60..f9c1326b 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -132,6 +132,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now # ----------------------- PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG'] +PRICE_FIELDS = ['price', 'open', 'close', 'high', 'low'] INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL') @@ -428,21 +429,26 @@ def TRADE_FRAME(event): assert isinstance(event, ndict) assert event.type == DATASOURCE_TYPE.TRADE assert isinstance(event.sid, int) - assert isinstance(event.price, numbers.Real) + for field in PRICE_FIELDS: + assert isinstance(event[field], numbers.Real) assert isinstance(event.volume, numbers.Integral) PACK_DATE(event) return msgpack.dumps(tuple([ event.sid, event.price, + event.open, + event.close, + event.high, + event.low, event.volume, event.dt, - event.type, + event.type ])) def TRADE_UNFRAME(msg): try: packed = msgpack.loads(msg) - sid, price, volume, dt, source_type = packed + sid, price, open, close, high, low, volume, dt, source_type = packed assert isinstance(sid, int) assert isinstance(price, numbers.Real) @@ -450,6 +456,10 @@ def TRADE_UNFRAME(msg): rval = ndict({ 'sid' : sid, 'price' : price, + 'open' : open, + 'close' : close, + 'high' : high, + 'low' : low, 'volume' : volume, 'dt' : dt, 'type' : source_type @@ -654,7 +664,13 @@ def tuple_to_date(date_tuple): dt = dt.replace(microsecond = micros, tzinfo = pytz.utc) return dt +# Datasource type should completely determine the other fields of a +# message with its type. DATASOURCE_TYPE = Enum( + 'AS_TRADED_EQUITY', + 'MERGER', + 'SPLIT', + 'DIVIDEND', 'TRADE', 'EMPTY', 'DONE' @@ -720,6 +736,9 @@ def LOG_FRAME(payload): assert payload.has_key('msg'),\ "LOG_FRAME with no message" + # truncation will only work with strings and msgpack will + # preserve primitives. + payload['msg'] = str(payload['msg']) return BT_UPDATE_FRAME('LOG', payload) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 723c5fb9..0e5a0273 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -237,6 +237,56 @@ class DivByZeroAlgorithm(): def get_sid_filter(self): return [self.sid] +class InitializeTimeoutAlgorithm(): + def __init__(self, sid): + self.sid = sid + self.incr = 0 + + def initialize(self): + import time + from zipline.gens.tradesimulation import INIT_TIMEOUT + time.sleep(INIT_TIMEOUT + 1) + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + pass + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + pass + + def get_sid_filter(self): + return [self.sid] + +class TooMuchProcessingAlgorithm(): + def __init__(self, sid): + self.sid = sid + + def initialize(self): + pass + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + pass + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + # Unless we're running on some sort of + # supercomputer this will hit timeout. + for i in xrange(1000000000): + self.foo = i + + def get_sid_filter(self): + return [self.sid] + class TimeoutAlgorithm(): def __init__(self, sid): diff --git a/zipline/utils/delayed_signals.py b/zipline/utils/delayed_signals.py new file mode 100644 index 00000000..9ca8c811 --- /dev/null +++ b/zipline/utils/delayed_signals.py @@ -0,0 +1,40 @@ +from functools import wraps +from signal import signal + +class delayed_signals(object): + """ + Utility to temporary intercept one or more signals while a function or code + block is executed, restore their signal handlers at the end of execution, + and invoke them if the signals were in fact received during execution. + + Can be used either as a decorator or a context manager. + + Pass in an iterable of signals to intercept. + """ + + def handler(self, signum, frame=None): + self.got.append({'signum': signum, 'frame': frame}) + + def __init__(self, signals): + self.signals = signals + self.handlers = {} + self.got = [] + + def __enter__(self): + for signum in self.signals: + # signal() returns the old signal handler + self.handlers[signum] = signal(signum, self.handler) + + def __exit__(self, time, value, traceback): + for signum, handler in self.handlers.items(): + signal(signum, handler) + for signum, frame in ((i['signum'], i['frame']) for i in self.got): + self.handlers[signum](signum, frame) + + def __call__(self, fn): + @wraps(fn) + def call_fn(*args, **kwargs): + with self: + outval = fn(*args, **kwargs) + return outval + return call_fn diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index e3d92443..cf2168fb 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -174,7 +174,7 @@ def create_random_trade_source(sid, trade_count, trading_environment): return source -def create_daily_trade_source(sids, trade_count, trading_environment): +def create_daily_trade_source(sids, trade_count, trading_environment, concurrent=False): """ creates trade_count trades for each sid in sids list. @@ -189,11 +189,12 @@ def create_daily_trade_source(sids, trade_count, trading_environment): sids, trade_count, timedelta(days=1), - trading_environment + trading_environment, + concurrent=concurrent ) -def create_minutely_trade_source(sids, trade_count, trading_environment): +def create_minutely_trade_source(sids, trade_count, trading_environment, concurrent=False): """ creates trade_count trades for each sid in sids list. @@ -208,10 +209,11 @@ def create_minutely_trade_source(sids, trade_count, trading_environment): sids, trade_count, timedelta(minutes=1), - trading_environment + trading_environment, + concurrent=concurrent ) -def create_trade_source(sids, trade_count, trade_time_increment, trading_environment): +def create_trade_source(sids, trade_count, trade_time_increment, trading_environment, concurrent=False): args = tuple() kwargs = { @@ -219,7 +221,8 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ 'sids' : sids, 'start' : trading_environment.first_open, 'delta' : trade_time_increment, - 'filter' : sids + 'filter' : sids, + 'concurrent' : concurrent } source = SpecificEquityTrades(*args, **kwargs) diff --git a/zipline/utils/log_utils.py b/zipline/utils/log_utils.py index f9fbc57c..ea1abf18 100644 --- a/zipline/utils/log_utils.py +++ b/zipline/utils/log_utils.py @@ -89,6 +89,7 @@ class ZeroMQLogHandler(Handler): def __init__(self, socket=None, level=NOTSET, filter=None, bubble=False, context=None, fds = LOG_FIELDS, extra_fds = LOG_EXTRA_FIELDS): Handler.__init__(self, level, filter, bubble) + try: import zmq except ImportError: diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 3a9858a3..3a9a0906 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -15,6 +15,7 @@ def setup_logger(test, path='/var/log/zipline/zipline.log'): def teardown_logger(test): test.log_handler.pop_application() + test.log_handler.close() def check_list(test, a, b, label): test.assertTrue(isinstance(a, (list, blist.blist))) @@ -91,11 +92,13 @@ def create_receiver(socket_addr, ctx): return receiver -def drain_receiver(receiver): +def drain_receiver(receiver, count=None): output = [] transaction_count = 0 + msg_counter = 0 while True: msg = receiver.recv() + msg_counter += 1 update = zp.BT_UPDATE_UNFRAME(msg) output.append(update) if update['prefix'] == 'PERF': @@ -106,14 +109,18 @@ def drain_receiver(receiver): elif update['prefix'] == 'DONE': break + if count and msg_counter >= count: + break + receiver.close() del receiver return output, transaction_count -def assert_single_position(test, zipline): - output, transaction_count = drain_zipline(test, zipline) +def assert_single_position(test, zipline, blocking=False): + output, transaction_count = drain_zipline(test, zipline, p_blocking=blocking) + test.assertEqual(output[-1]['prefix'], 'DONE') test.assertEqual( test.zipline_test_config['order_count'], diff --git a/zipline/utils/timeout.py b/zipline/utils/timeout.py new file mode 100644 index 00000000..2985222e --- /dev/null +++ b/zipline/utils/timeout.py @@ -0,0 +1,128 @@ +import signal + +from functools import wraps + +from pprint import pprint as pp +from numbers import Number +from logbook import Logger + +class Timeout(Exception): + + def __init__(self, frame, message=''): + self.frame = frame + self.message = message + +class timeout(object): + """ + Utility to make a function raise TimeoutException if it spends + more than a specified number of seconds executing. Can be used + as a decorator to apply a static timeout to a function, or as + a context manager to dynamically add a timeout to a code block. + """ + + def __init__(self, seconds, message=''): + self.seconds = seconds + self.message = message + assert isinstance(seconds, Number), "Failed to specify a timeout." + assert seconds > 0, "Timeout must be greater than 0" + + def handler(self, signum, frame): + raise Timeout(frame, self.message) + + def __call__(self, fn): + + @wraps(fn) + def call_fn_with_timeout(*args, **kwargs): + # Set the alarm. + signal.signal(signal.SIGALRM, self.handler) + signal.setitimer(signal.ITIMER_REAL, self.seconds, 0) + try: + outval = fn(*args, **kwargs) + + # Deactivate the alarm once we're done so that the + # decorator doesn't have unexpected side-effects later. + # Note that this will still raise Timeout if the + # call to fn takes too long. + finally: + signal.setitimer(signal.ITIMER_REAL, 0, 0) + signal.signal(signal.SIGALRM, signal.SIG_DFL) + + # Return the value of fn if it finished before the alarm. This + # won't execute if the Timeout was raised. + return outval + return call_fn_with_timeout + + def __enter__(self): + # Set the alarm on entrance. + signal.signal(signal.SIGALRM, self.handler) + signal.setitimer(signal.ITIMER_REAL, self.seconds, 0) + + def __exit__(self, type, value, traceback): + # Deactivate the alarm on exit. This will re-raise + # any exceptions raised inside the with block. + signal.signal(signal.SIGALRM, self.handler) + signal.setitimer(signal.ITIMER_REAL, 0, 0) + +class heartbeat(object): + """ + Utility to perform pseudo-heartbeat checks on a single-threaded + function. Calls frame_handler on the current stack frame of the + wrapped function every ``interval`` seconds. After ``max_interval`` + intervals, raises Timeout. Can be used either as a decorator or + a context manager. + """ + def __init__(self, + interval, + max_intervals, + frame_handler=None, + timeout_message=''): + + self.interval = interval + self.max_intervals = max_intervals + self.frame_handler = frame_handler + self.timeout_message = timeout_message + self.count = 0 + + def handler(self, signum, frame): + self.count += 1 + if self.frame_handler: + self.frame_handler(self.count, frame) + + if self.count >= self.max_intervals: + raise Timeout(frame, self.timeout_message) + + def __call__(self, fn): + + @wraps(fn) + def call_fn_with_heartbeat(*args, **kwargs): + # Set a timer to call our handler every ``interval`` seconds. + signal.signal(signal.SIGALRM, self.handler) + signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval) + try: + outval = fn(*args, **kwargs) + + finally: + # Deactivate the timer once we're done so that the + # decorator doesn't have unexpected side-effects later. + signal.setitimer(signal.ITIMER_REAL, 0, 0) + signal.signal(signal.SIGALRM, signal.SIG_DFL) + self.count = 0 + + # Return the value of fn if it finished without tripping + # an exception. This won't execute if the Timeout or any + # other exception was raised by self.handle. + return outval + return call_fn_with_heartbeat + + def __enter__(self): + # Set a timer to call our handler every N seconds. + self.count = 0 + signal.signal(signal.SIGALRM, self.handler) + signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval) + + def __exit__(self, type, value, traceback): + # Turn off the timer on exit. This will re-raise any exception raised + # during execution of the with-block + self.count = 0 + signal.setitimer(signal.ITIMER_REAL, 0, 0) + signal.signal(signal.SIGALRM, signal.SIG_DFL)