diff --git a/tests/test_finance.py b/tests/test_finance.py index b9783471..f841e37f 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -150,7 +150,7 @@ class FinanceTestCase(TestCase): # TODO: for some reason the orders aren't filled without an extra # trade. - trade_count = 5 + trade_count = 5000 self.zipline_test_config['order_count'] = trade_count - 1 self.zipline_test_config['trade_count'] = trade_count self.zipline_test_config['order_amount'] = 1 diff --git a/zipline/components/tradesimulation.py b/zipline/components/tradesimulation.py index b1a4ff53..0a711c4b 100644 --- a/zipline/components/tradesimulation.py +++ b/zipline/components/tradesimulation.py @@ -44,36 +44,26 @@ class TradeSimulationClient(Component): return str(zp.FINANCE_COMPONENT.TRADING_CLIENT) def set_algorithm(self, algorithm): - """ :param algorithm: must implement the algorithm protocol. See :py:mod:`zipline.test.algorithm` """ self.algorithm = algorithm - # register the trading_client's order method with the algorithm self.algorithm.set_order(self.order) - - #TODO: re-enable initialization logging. This means we can't call set_algorithm - #until we have a context for this component. Possibly this could happen - # ask the algorithm to initialize, routing stdout to a zmq PUSH socket. - - #with self.zmq_out.threadbound(), self.stdout_capture(self.logger, 'Algo print capture'): - # self.algorithm.initialize() - #if we don't have a log socket, initialize anyway. - #else: - # self.algorithm.initialize() - - self.algorithm.initialize() # we need to provide the performance tracker with the # sids referenced in the algorithm, so portfolio can # initialize with all possible sids. - self.perf.set_sids(self.algorithm.get_sid_filter()) + + # self.algorithm.initialize() + + def open(self): self.result_feed = self.connect_result() if not self.results_socket: log.warn(" No results socket, will not broadcast sim data.") + self.algorithm.set_logger(log) else: sock = self.context.socket(zmq.PUSH) sock.connect(self.results_socket) @@ -81,12 +71,14 @@ class TradeSimulationClient(Component): self.sockets.append(sock) self.out_socket = sock - self.setup_logging(sock) self.perf.publish_to(sock) + # register the trading_client's order method with the algorithm + self.algorithm.set_logger(self.algo_log) + + self.run_logged_op(self.algorithm.initialize) - #Initialize log capture for testing purposes. def setup_logging(self, socket = None): sock = socket or self.results_socket @@ -95,6 +87,8 @@ class TradeSimulationClient(Component): ) self.logger = Logger("Print") + self.algo_log = Logger("AlgoLog") + # N.B. that this is a class, which is instantiated later # in run_algorithm. The class provides a generator. self.stdout_capture = stdout_only_pipe @@ -188,21 +182,25 @@ class TradeSimulationClient(Component): # data injection pipeline for log rerouting # any fields injected here should be added to # LOG_EXTRA_FIELDS in zipline/protocol.py - if self.zmq_out: + self.run_logged_op(self.algorithm.handle_data, data) - def inject_event_data(record): + def run_logged_op(self, callable_op, *args, **kwargs): + """ Wrap a callable operation with the zmq logbook + handler if it exits.""" + if self.zmq_out: - #Record the simulation time. + def inject_event_data(record): + # Record the simulation time. + record.extra['algo_dt'] = self.current_dt - record.extra['algo_dt'] = self.current_dt - - data_injector = Processor(inject_event_data) - log_pipeline = NestedSetup([self.zmq_out,data_injector]) - with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''): - self.algorithm.handle_data(data) + data_injector = Processor(inject_event_data) + log_pipeline = NestedSetup([self.zmq_out,data_injector]) + with log_pipeline.threadbound(), self.stdout_capture(self.logger, ''): + callable_op(*args, **kwargs) # if no log socket, just run the algo normally - else: - self.algorithm.handle_data(data) + else: + callable_op(*args, **kwargs) + #Testing utility for log capture. # TODO: remove test code from here. diff --git a/zipline/core/monitor.py b/zipline/core/monitor.py index 1767958d..efcc6314 100644 --- a/zipline/core/monitor.py +++ b/zipline/core/monitor.py @@ -99,7 +99,6 @@ class Controller(object): log.warn("Running Controller in development mode, will ONLY synchronize start.") def init_zmq(self, flavor): - assert self.zmq_flavor in ['thread', 'mp', 'green'] if flavor == 'mp': @@ -131,7 +130,6 @@ class Controller(object): Give the controller a set set of components to manage and a set of state transitions for the entire system. """ - # A freeform topology is where we heartbeat with anything # that shows up. if topology == 'freeform': @@ -323,9 +321,10 @@ class Controller(object): # if this is the first time heartbeating, break # out early if we get everything tracked no need # to hold out for the full heartbeat. - if initializing and len(self.responses) == len(self.topology): - log.info("breaking out of initial heartbeat") - break + if initializing and not self.freeform: + if len(self.responses) == len(self.topology): + log.info("breaking out of initial heartbeat") + break # ================ # Heartbeat Stats diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 56788c49..aa07fa54 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -70,6 +70,9 @@ class TestAlgorithm(): def set_order(self, order_callable): self.order = order_callable + def set_logger(self, logger): + pass + def set_portfolio(self, portfolio): self.portfolio = portfolio @@ -106,6 +109,9 @@ class HeavyBuyAlgorithm(): def set_order(self, order_callable): self.order = order_callable + def set_logger(self, logger): + pass + def set_portfolio(self, portfolio): self.portfolio = portfolio @@ -129,6 +135,9 @@ class NoopAlgorithm(object): def set_order(self, order_callable): pass + def set_logger(self, logger): + pass + def set_portfolio(self, portfolio): pass @@ -140,7 +149,8 @@ class NoopAlgorithm(object): class ExceptionAlgorithm(object): """ - Dolce fa niente. + Throw an exception from the method name specified in the + constructor. """ def __init__(self, throw_from): @@ -158,6 +168,9 @@ class ExceptionAlgorithm(object): else: pass + def set_logger(self, logger): + pass + def set_portfolio(self, portfolio): if self.throw_from == "set_portfolio": raise Exception("Algo exception in set_portfolio") @@ -174,7 +187,7 @@ class ExceptionAlgorithm(object): if self.throw_from == "get_sid_filter": raise Exception("Algo exception in get_sid_filter") else: - return None + return [1] class TestPrintAlgorithm(): @@ -187,6 +200,9 @@ class TestPrintAlgorithm(): def set_order(self, order_callable): pass + def set_logger(self, logger): + pass + def set_portfolio(self, portfolio): pass @@ -195,4 +211,27 @@ class TestPrintAlgorithm(): pass def get_sid_filter(self): - return None + return [1] + +class TestLoggingAlgorithm(): + + def __init__(self): + self.log = None + + def initialize(self): + self.log.info("Initializing...") + + def set_order(self, order_callable): + pass + + def set_logger(self, logger): + self.log = logger + + def set_portfolio(self, portfolio): + pass + + def handle_data(self, data): + self.log.info("Handling Data...") + + def get_sid_filter(self): + return [1]