diff --git a/tests/test_exception_handling.py b/tests/test_exception_handling.py index d1561837..ac6339bc 100644 --- a/tests/test_exception_handling.py +++ b/tests/test_exception_handling.py @@ -3,7 +3,8 @@ import zmq from unittest2 import TestCase from collections import defaultdict -from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm +from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm, \ + InitializeTimeoutAlgorithm, TooMuchProcessingAlgorithm from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator from zipline.lines import SimulatedTrading @@ -143,3 +144,35 @@ class ExceptionTestCase(TestCase): # make sure our path shortening is working self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py') self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py') + + def test_initialize_timeout(self): + + self.zipline_test_config['algorithm'] = \ + InitializeTimeoutAlgorithm( + self.zipline_test_config['sid'] + ) + + zipline = SimulatedTrading.create_test_zipline( + **self.zipline_test_config + ) + output, _ = drain_zipline(self, zipline) + self.assertEqual(output[-1]['prefix'], 'EXCEPTION') + payload = output[-1]['payload'] + self.assertEqual(payload['name'],'Timeout') + self.assertEqual(payload['message'], 'Call to initialize timed out') + +# def test_heartbeat(self): + +# self.zipline_test_config['algorithm'] = \ +# TooMuchProcessingAlgorithm( +# self.zipline_test_config['sid'] +# ) +# zipline = SimulatedTrading.create_test_zipline( +# **self.zipline_test_config +# ) +# output, _ = drain_zipline(self, zipline) +# self.assertEqual(output[-1]['prefix'], 'EXCEPTION') +# payload = output[-1]['payload'] +# self.assertEqual(payload['name'],'Timeout') +# self.assertEqual(payload['message'], 'Too much time spent in handle_data call') + diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 8784a4fc..811f99ac 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -6,7 +6,7 @@ from numbers import Integral from itertools import groupby from zipline import ndict -from zipline.utils.timeout import heartbeat +from zipline.utils.timeout import timeout, heartbeat, Timeout from zipline.gens.transform import StatefulTransform from zipline.finance.trading import TransactionSimulator @@ -16,8 +16,10 @@ from zipline.gens.utils import hash_args log = Logger('Trade Simulation') -class AlgoTimeoutException(Exception): - pass +# TODO: make these arguments rather than global constants +INIT_TIMEOUT = 5 +HEARTBEAT_INTERVAL = 1 # seconds +MAX_HEARTBEAT_INTERVALS = 15 class TradeSimulationClient(object): """ @@ -145,6 +147,21 @@ class AlgorithmSimulator(object): self.algolog = Logger("AlgoLog") self.algo.set_logger(self.algolog) + # Handler for heartbeats during calls to handle_data. + def log_heartbeats(beat_count, stackframe): + t = beat_count * HEARTBEAT_INTERVAL + warning = "handle_data has been processing for %i seconds" %t + self.algolog.warn(warning) + + # Context manager that calls log_heartbeats every HEARTBEAT_INTERVAL + # seconds, raising an exception after MAX_HEARTBEATS + self.heartbeat_monitor = heartbeat( + HEARTBEAT_INTERVAL, + MAX_HEARTBEAT_INTERVALS, + frame_handler=log_heartbeats, + timeout_message="Too much time spent in handle_data call" + ) + # ============== # Snapshot Setup # ============== @@ -223,17 +240,11 @@ class AlgorithmSimulator(object): # Capture any output of this generator to stdout and pipe it # to a logbook interface. Also inject the current algo # snapshot time to any log record generated. - with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''): - #Set an alarm to go off if initialize takes more than 5 seconds. - signal.signal(signal.SIGALRM, self.handle_init_timeout) - signal.alarm(5) - # Call the user's initialize method. - self.algo.initialize() - # Deactivate the alarm. - signal.alarm(0) - signal.signal(signal.SIGALRM, signal.SIG_DFL) + # Call user's initialize method with a timeout. + with timeout(INIT_TIMEOUT, message="Call to initialize timed out"): + self.algo.initialize() # Group together events with the same dt field. This depends on the # events already being sorted. @@ -243,7 +254,7 @@ class AlgorithmSimulator(object): # This should only occur once, at the start of the test. if self.simulation_dt == None: self.simulation_dt = date - + # Done message has the risk report, so we yield before exiting. if date == 'DONE': for event in snapshot: @@ -296,8 +307,6 @@ class AlgorithmSimulator(object): for field in event.keys(): self.universe[event.sid][field] = event[field] - # Ping every 10 seconds. Timeout after 9 pings. - # @heartbeat(10, 9, self.handle_simulation_ping) def simulate_snapshot(self, date): """ Run the user's algo against our current snapshot and update @@ -308,8 +317,8 @@ class AlgorithmSimulator(object): self.snapshot_dt = date start_tic = datetime.now() - - self.algo.handle_data(self.universe) + with self.heartbeat_monitor: + self.algo.handle_data(self.universe) stop_tic = datetime.now() # How long did you take? @@ -317,17 +326,3 @@ class AlgorithmSimulator(object): # Update the simulation time. self.simulation_dt = date + delta - - def handle_init_timeout(self, signum, frame): - """ - Handler method for initialize timeout. - """ - log.error("Algorithm timed out during initialize.") - raise AlgoTimeoutException("More than 5 seconds in initialize.") - - def handle_simulation_ping(self, frame): - """ - Frame handler for decorated simulate_snapshot method. - """ - print 'foo' - diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index a7881fa8..f6bdfd4d 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -221,6 +221,56 @@ class DivByZeroAlgorithm(): def get_sid_filter(self): return [self.sid] +class InitializeTimeoutAlgorithm(): + def __init__(self, sid): + self.sid = sid + self.incr = 0 + + def initialize(self): + import time + from zipline.gens.tradesimulation import INIT_TIMEOUT + time.sleep(INIT_TIMEOUT + 1) + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + pass + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + pass + + def get_sid_filter(self): + return [self.sid] + +class TooMuchProcessingAlgorithm(): + def __init__(self, sid): + self.sid = sid + + def initialize(self): + pass + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + pass + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + # Unless we're running on some sort of + # supercomputer this will hit timeout. + for i in xrange(100000000): + self.foo = i + + def get_sid_filter(self): + return [self.sid] + class TimeoutAlgorithm(): def __init__(self, sid):