diff --git a/tests/test_finance.py b/tests/test_finance.py index bbceac25..b9783471 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -6,6 +6,7 @@ import pytz from unittest2 import TestCase from datetime import datetime, timedelta from collections import defaultdict +from logbook.compat import LoggingHandler from nose.tools import timed @@ -25,16 +26,22 @@ EXTENDED_TIMEOUT = 90 allocator = AddressAllocator(1000) + class FinanceTestCase(TestCase): leased_sockets = defaultdict(list) def setUp(self): - #qutil.configure_logging() self.zipline_test_config = { - 'allocator':allocator, - 'sid':133 + 'allocator' : allocator, + 'sid' : 133, + 'devel' : True } + self.log_handler = LoggingHandler() + self.log_handler.push_application() + + def tearDown(self): + self.log_handler.pop_application() @timed(DEFAULT_TIMEOUT) def test_factory_daily(self): @@ -106,7 +113,7 @@ class FinanceTestCase(TestCase): # non blocking. HUNCH: The trades are streaming through before the orders # are placed. - @timed(EXTENDED_TIMEOUT) + #@timed(EXTENDED_TIMEOUT) def test_orders(self): # Simulation @@ -135,7 +142,7 @@ class FinanceTestCase(TestCase): ) - @timed(DEFAULT_TIMEOUT) + #@timed(DEFAULT_TIMEOUT) def test_aggressive_buying(self): # Simulation @@ -232,7 +239,7 @@ class FinanceTestCase(TestCase): ) self.assertEqual( - zipline.sources['flat'].count, + zipline.sources['SpecificEquityTrades'].count, self.zipline_test_config['trade_count'], "The simulated trade source should send all trades." ) @@ -243,7 +250,7 @@ class FinanceTestCase(TestCase): "The algorithm should receive all trades." ) - @timed(DEFAULT_TIMEOUT) + #@timed(DEFAULT_TIMEOUT) def test_sid_filter(self): """Ensure the algorithm's filter prevents events from arriving.""" # create a test algorithm whose filter will not match any of the diff --git a/tests/test_monitor.py b/tests/test_monitor.py new file mode 100644 index 00000000..5da84b2a --- /dev/null +++ b/tests/test_monitor.py @@ -0,0 +1,43 @@ +import gevent +from logbook.compat import LoggingHandler +from unittest2 import TestCase, skip + +from zipline.core.monitor import Controller + + +class TestMonitor(TestCase): + def setUp(self): + self.log_handler = LoggingHandler() + self.log_handler.push_application() + + def tearDown(self): + self.log_handler.pop_application() + + def test_init(self): + pub_socket = 'tcp://127.0.0.1:5000' + route_socket = 'tcp://127.0.0.1:5001' + + con = Controller(pub_socket, route_socket) + con.manage([]) + + def test_init_topology(self): + pub_socket = 'tcp://127.0.0.1:5000' + route_socket = 'tcp://127.0.0.1:5001' + + con = Controller(pub_socket, route_socket, ) + con.manage([ 'a', 'b', 'c', 'd' ]) + + @skip + def test_poll(self): + from mock_zmq import zmq_synthetic + pub_socket = 'tcp://127.0.0.1:5000' + route_socket = 'tcp://127.0.0.1:5001' + cancel_socket = 'tcp://127.0.0.1:5002' + + con = Controller(pub_socket, route_socket, cancel_socket) + con.manage([ 'a', 'b', 'c', 'd' ]) + con.zmq = zmq_synthetic + con.zmq_flavor = 'green' + + con.period = 0.00001 + gevent.spawn(con.run).join(timeout=con.period) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index d979473d..1a77818c 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -16,11 +16,13 @@ class PerformanceTestCase(unittest.TestCase): self.benchmark_returns, self.treasury_curves = \ factory.load_market_data() - random_index = random.randint( - 0, - len(self.treasury_curves) - ) for n in range(100): + + random_index = random.randint( + 0, + len(self.treasury_curves) + ) + self.dt = self.treasury_curves.keys()[random_index] self.end_dt = self.dt + datetime.timedelta(days=365) @@ -29,6 +31,10 @@ class PerformanceTestCase(unittest.TestCase): if self.end_dt <= now: break + assert self.end_dt <= now, """ +failed to find a date suitable daterange after 100 attempts. please double +check treasury and benchmark data in findb, and re-run the test.""" + self.trading_environment = TradingEnvironment( self.benchmark_returns, self.treasury_curves, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 0d267e60..1fec3e7a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -2,6 +2,8 @@ from datetime import timedelta from collections import defaultdict from unittest2 import TestCase +from logbook.compat import LoggingHandler + import zipline.utils.factory as factory from zipline.finance.vwap import DailyVWAP, VWAPTransform from zipline.finance.returns import ReturnsFromPriorClose @@ -19,9 +21,15 @@ class ZiplineWithTransformsTestCase(TestCase): allocator.lease(100) self.trading_environment = factory.create_trading_environment() self.zipline_test_config = { - 'allocator':allocator, - 'sid':133 + 'allocator' : allocator, + 'sid' : 133, + 'devel' : True } + self.log_handler = LoggingHandler() + self.log_handler.push_application() + + def tearDown(self): + self.log_handler.pop_application() def test_vwap_tnfm(self): zipline = SimulatedTrading.create_test_zipline( @@ -36,8 +44,14 @@ class ZiplineWithTransformsTestCase(TestCase): self.assertFalse(zipline.sim.exception) class FinanceTransformsTestCase(TestCase): + def setUp(self): self.trading_environment = factory.create_trading_environment() + self.log_handler = LoggingHandler() + self.log_handler.push_application() + + def tearDown(self): + self.log_handler.pop_application() def test_vwap(self): diff --git a/zipline/components/aggregator.py b/zipline/components/aggregator.py index a7bd6c82..451fc28a 100644 --- a/zipline/components/aggregator.py +++ b/zipline/components/aggregator.py @@ -12,12 +12,12 @@ Abstract base class for Feed and Merge. import logbook import zipline.protocol as zp - from zipline.core.component import Component -from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \ - CONTROL_FRAME, CONTROL_UNFRAME -from zipline.utils.protocol_utils import Enum +from zipline.core.controlled import do_handle_control_events +from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE from zipline.transitions import WorkflowMeta +from zipline.utils.protocol_utils import Enum + log = logbook.Logger('Aggregate') @@ -34,6 +34,10 @@ AGGREGATE_TRANSITIONS = dict( do_drain = (READY , DRAINING) , ) +# ========= +# Component +# ========= + class Aggregate(Component): """ Abstract superclass to Merge & Feed. Acts on two sockets @@ -41,7 +45,7 @@ class Aggregate(Component): - pull_socket - feed_socket - Both use ``data_buffer`` for buffering. + Both use ``sources`` for buffering. Feed and Merge define these differently. """ @@ -53,6 +57,9 @@ class Aggregate(Component): def get_type(self): return COMPONENT_TYPE.CONDUIT + def add_source(self, source_id): + self.sources[source_id] = [] + # ------------- # Core Methods # ------------- @@ -61,39 +68,26 @@ class Aggregate(Component): # wait for synchronization reply from the host socks = dict(self.poll.poll(self.heartbeat_timeout)) - # TODO: Abstract this out, maybe on base component - if socks.get(self.control_in) == self.zmq.POLLIN: - msg = self.control_in.recv() - event, payload = CONTROL_UNFRAME(msg) - - # -- Heartbeat -- - if event == CONTROL_PROTOCOL.HEARTBEAT: - # Heart outgoing - heartbeat_frame = CONTROL_FRAME( - CONTROL_PROTOCOL.OK, - payload - ) - self.control_out.send(heartbeat_frame) - - # -- Soft Kill -- - elif event == CONTROL_PROTOCOL.SHUTDOWN: - self.signal_done() - self.shutdown() - - # -- Hard Kill -- - elif event == CONTROL_PROTOCOL.KILL: - self.kill() - + # ---------------- + # Control Dispatch + # ---------------- + #do_handle_control_events(self, socks) + # ------------- + # Work Dispatch + # ------------- if socks.get(self.pull_socket) == self.zmq.POLLIN: message = self.pull_socket.recv() if message == str(CONTROL_PROTOCOL.DONE): self.ds_finished_counter += 1 - if len(self.data_buffer) == self.ds_finished_counter: - #drain any remaining messages in the buffer + if len(self.sources) == self.ds_finished_counter: + # Drain any remaining messages in the buffer log.debug("Draining Feed") + + self.state = DRAINING + self.drain() self.signal_done() else: @@ -105,7 +99,15 @@ class Aggregate(Component): try: self.append(event) - self.send_next() + + if not (self.is_full() or self.draining): + event = self.next() + + if event: + self.send(event) + else: + pass + except zp.INVALID_DATASOURCE_FRAME as exc: # Invalid message return self.signal_exception(exc) @@ -118,30 +120,25 @@ class Aggregate(Component): """ Send all messages in the buffer. """ - self.state = DRAINING while self.pending_messages() > 0: - self.send_next() + event = self.next() + if event: + self.send(event) - def send_next(self): + def send(self, event): """ Send the (chronologically) next message in the buffer. """ - if not (self.is_full() or self.draining): - return - - event = self.next() - - if event: - self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK) - self.sent_counters[event.source_id] += 1 - self.sent_count += 1 + self.feed_socket.send(self.frame(event), self.zmq.NOBLOCK) + self.sent_counters[event.source_id] += 1 + self.sent_count += 1 def is_full(self): """ Indicates whether the buffer has messages in buffer for all un-DONE, blocking sources. """ - for source_id, events in self.data_buffer.iteritems(): + for source_id, events in self.sources.iteritems(): if len(events) == 0: return False return True @@ -152,19 +149,13 @@ class Aggregate(Component): buffer. """ total = 0 - for events in self.data_buffer.itervalues(): + for events in self.sources.itervalues(): total += len(events) return total - def add_source(self, source_id): - """ - Add a data source to the buffer. - """ - self.data_buffer[source_id] = [] - def __len__(self): """ Buffer's length is same as internal map holding separate sorted arrays of events keyed by source id. """ - return len(self.data_buffer) + return len(self.sources) diff --git a/zipline/components/datasource.py b/zipline/components/datasource.py index 4cafdb5f..e40e4e93 100644 --- a/zipline/components/datasource.py +++ b/zipline/components/datasource.py @@ -2,14 +2,10 @@ Commonly used messaging components. """ -import logging - import zipline.protocol as zp from zipline.core.component import Component from zipline.protocol import COMPONENT_TYPE -LOGGER = logging.getLogger('ZiplineLogger') - class DataSource(Component): """ Abstract baseclass for data sources. Subclass and implement send_all diff --git a/zipline/components/feed.py b/zipline/components/feed.py index f51e515d..719a99a6 100644 --- a/zipline/components/feed.py +++ b/zipline/components/feed.py @@ -1,5 +1,5 @@ import logbook -from collections import Counter +from collections import defaultdict, Counter from zipline.components.aggregator import Aggregate, \ AGGREGATE_STATES, AGGREGATE_TRANSITIONS @@ -28,7 +28,7 @@ class Feed(Aggregate): self.received_count = 0 self.ds_finished_counter = 0 - self.data_buffer = {} + self.sources = defaultdict(list) # source_id -> integer count self.sent_counters = Counter() @@ -71,7 +71,7 @@ class Feed(Aggregate): Add an event to the buffer for the source specified by source_id. """ - self.data_buffer[event.source_id].append(event) + self.sources[event.source_id].append(event) self.recv_counters[event.source_id] += 1 self.received_count += 1 @@ -84,24 +84,27 @@ class Feed(Aggregate): if not(self.is_full() or self.draining): return - cur_source = None earliest_source = None earliest_event = None - #iterate over the queues of events from all sources - #(1 queue per datasource) - for events in self.data_buffer.itervalues(): - if len(events) == 0: - continue - cur_source = events - first_in_list = events[0] - if first_in_list.dt == None: - #this is a filler event, discard - events.pop(0) + # iterate over the queues of source from all sources + # (1 queue per datasource) + + for source in self.sources.itervalues(): + if len(source) == 0: continue - if (earliest_event == None) or (first_in_list.dt <= earliest_event.dt): - earliest_event = first_in_list - earliest_source = cur_source + head = source[0] + + if head.dt == None: + #this is a filler event, discard + source.pop(0) + continue + + if (earliest_event == None) or (head.dt <= earliest_event.dt): + earliest_event = head + earliest_source = source if earliest_event != None: return earliest_source.pop(0) + else: + return False diff --git a/zipline/components/merge.py b/zipline/components/merge.py index bbb9dc6a..6dc7661e 100644 --- a/zipline/components/merge.py +++ b/zipline/components/merge.py @@ -2,7 +2,7 @@ import zipline.protocol as zp from zipline.components.aggregator import Aggregate, \ AGGREGATE_STATES, AGGREGATE_TRANSITIONS -from collections import Counter +from collections import defaultdict, Counter class Merge(Aggregate): """ @@ -19,9 +19,7 @@ class Merge(Aggregate): self.draining = False self.ds_finished_counter = 0 - # Depending on the size of this, might want to use a data - # structure with better asymptotics. - self.data_buffer = {} + self.sources = defaultdict(list) # source_id -> integer count self.sent_counters = Counter() @@ -61,7 +59,7 @@ class Merge(Aggregate): source_id. """ - self.data_buffer[event.keys()[0]].append(event) + self.sources[event.keys()[0]].append(event) self.received_count += 1 def next(self): @@ -73,8 +71,10 @@ class Merge(Aggregate): return #get the raw event from the passthrough transform. - result = self.data_buffer[zp.TRANSFORM_TYPE.PASSTHROUGH].pop(0).PASSTHROUGH - for source, events in self.data_buffer.iteritems(): + passthrough = self.sources[zp.TRANSFORM_TYPE.PASSTHROUGH] + result = passthrough.pop(0).PASSTHROUGH + + for source, events in self.sources.iteritems(): if source == zp.TRANSFORM_TYPE.PASSTHROUGH: continue if len(events) > 0: diff --git a/zipline/core/component.py b/zipline/core/component.py index 4839e147..e7a1bfd0 100644 --- a/zipline/core/component.py +++ b/zipline/core/component.py @@ -19,6 +19,7 @@ import gevent_zeromq # zmq_ctypes #import zmq_ctypes +from zipline.protocol import CONTROL_UNFRAME from zipline.utils.gpoll import _Poller as GeventPoller from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_STATE, \ COMPONENT_FAILURE, CONTROL_FRAME @@ -28,7 +29,7 @@ log = logbook.Logger('Component') from zipline.exceptions import ComponentNoInit from zipline.transitions import WorkflowMeta -# LOGBOOK - embed PID in log output +log = logbook.Logger('Base') class Component(object): @@ -39,11 +40,6 @@ class Component(object): :param addresses: a dict of name_string -> zmq port address strings. Must have the following entries - :param sync_address: socket address used for synchronizing the start of - all workers, heartbeating, and exit notification - will be used in REP/REQ sockets. Bind is always on - the REP side. - :param data_address: socket address used for data sources to stream their records. Will be used in PUSH/PULL sockets between data sources and a Feed. Bind will always @@ -82,12 +78,15 @@ class Component(object): self.zmq = None self.context = None self.addresses = None + self.waiting = None self.out_socket = None self.killed = False self.controller = None # timeout after a full minute self.heartbeat_timeout = 60 *1000 + # TODO: state_flag is deprecated, remove + # TODO: error_state is deprecated, remove self.state_flag = COMPONENT_STATE.OK self.error_state = COMPONENT_FAILURE.NOFAILURE self.on_done = None @@ -98,6 +97,7 @@ class Component(object): self.stop_tic = None self.note = None self.confirmed = False + self.devel = False # Humanhashes make this way easier to debug because they stick # in your mind unlike a 32 byte string of random hex. @@ -190,6 +190,12 @@ class Component(object): raise Exception("Unknown ZeroMQ Flavor") def _run(self): + """ + The main component loop. This is wrapped inside a + exception reporting context inside of run. + + The core logic of the all components is run here. + """ self.start_tic = time.time() self.done = False # TODO: use state flag @@ -200,11 +206,20 @@ class Component(object): self.setup_poller() self.open() - self.setup_sync() self.setup_control() + self.signal_ready() + self.lock_ready() + + self.wait_ready() + # ----------------------- + # YOU SHALL NOT PASS!!!!! + # ----------------------- + # ... until the controller signals GO + self.loop() self.shutdown() + log.info("Shutdown %r" % self) self.stop_tic = time.time() @@ -246,20 +261,8 @@ class Component(object): Loop to do work while we still have work to do. """ while self.working(): - self.confirm() self.do_work() - def confirm(self): - """ - Send a synchronization request to the host. - """ - if not self.confirmed: - # TODO: proper framing - self.sync_socket.send(self.get_id + ":RUN") - - self.receive_sync_ack() # blocking - self.confirmed = True - def runtime(self): if self.ready() and self.start_tic and self.stop_tic: return self.stop_tic - self.start_tic @@ -299,6 +302,81 @@ class Component(object): # Internal Maintenance # ---------------------- + def lock_ready(self): + """ + Unlock the component, topology is now ready to run. + """ + self.waiting = True + + def unlock_ready(self): + """ + Unlock the component, topology is still pending. + """ + self.waiting = False + + def wait_ready(self): + # Implicit side-effect of unlocking the component iff + # the GO message is received from the monitor level. + # This then unlocks the barrier and proceeds to the + # do_work state. + + # Poll on a subset of the control protocol while we exist + # in the locked quasimode. Respond to HEARTBEAT and GO + # messages. + + while self.waiting: + socks = dict(self.poll.poll(self.heartbeat_timeout)) + + msg = self.control_in.recv() + event, payload = CONTROL_UNFRAME(msg) + + # ==== + # Go + # ==== + + # A distributed lock from the controller to ensure + # synchronized start. + + if event == CONTROL_PROTOCOL.HEARTBEAT: + heartbeat_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.OK, + payload + ) + self.control_out.send(heartbeat_frame) + log.info('Prestart Heartbeat ' + self.get_id) + + elif event == CONTROL_PROTOCOL.GO: + # Side effectful call from the controller to unlock + # and begin doing work only when the entire topology + # of the system beings to come online + log.info('Unlocking ' + self.__class__.__name__) + self.unlock_ready() + + def signal_ready(self): + log.info(self.__class__.__name__ + ' is ready') + + if hasattr(self, 'control_out'): + frame = CONTROL_FRAME( + CONTROL_PROTOCOL.READY, + '' + ) + self.control_out.send(frame) + + def signal_cancel(self): + self.done = True + + # TODO: no hasattr hacks + #if not self.controller: + if hasattr(self, 'control_out'): + frame = CONTROL_FRAME( + CONTROL_PROTOCOL.SHUTDOWN, + None + ) + self.control_out.send(frame) + + # then proceeds to do shutdown(), and teardown_sockets() + # to complete the process + def signal_exception(self, exc=None, scope=None): """ This is *very* important error tracking handler. @@ -320,7 +398,7 @@ class Component(object): self._exception = exc exc_type, exc_value, exc_traceback = sys.exc_info() - trace = '\n>>>'.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) sys.stdout.write(trace) if hasattr(self, 'control_out'): @@ -342,10 +420,6 @@ class Component(object): if self.out_socket: self.out_socket.send(str(CONTROL_PROTOCOL.DONE)) - #notify host we're done - # TODO: proper framing - self.sync_socket.send(self.get_id + ":" + str(CONTROL_PROTOCOL.DONE)) - #notify controller we're done done_frame = CONTROL_FRAME( CONTROL_PROTOCOL.DONE, @@ -353,7 +427,6 @@ class Component(object): ) self.control_out.send(done_frame) - self.receive_sync_ack() #notify internal work look that we're done self.done = True # TODO: use state flag @@ -373,19 +446,6 @@ class Component(object): # ZeroMQ. Either zmq.Poller or gpoll.Poller . self.poll = self.zmq_poller() - def receive_sync_ack(self): - """ - Wait for synchronization reply from the host. - - DEPRECATED, left in for compatability for now. - """ - - socks = dict(self.sync_poller.poll(self.heartbeat_timeout)) - if self.sync_socket in socks and socks[self.sync_socket] == self.zmq.POLLIN: - message = self.sync_socket.recv() - #else: - #raise Exception("Sync ack timed out on response for {id}".format(id=self.get_id)) - def bind_data(self): return self.bind_pull_socket(self.addresses['data_address']) @@ -405,10 +465,27 @@ class Component(object): return self.connect_push_socket(self.addresses['merge_address']) def bind_result(self): - return self.bind_pub_socket(self.addresses['result_address']) + return self.bind_push_socket(self.addresses['result_address']) def connect_result(self): - return self.connect_sub_socket(self.addresses['result_address']) + return self.connect_pull_socket(self.addresses['result_address']) + + def bind_push_socket(self, addr): + push_socket = self.context.socket(self.zmq.PUSH) + push_socket.bind(addr) + self.out_socket = push_socket + self.sockets.append(push_socket) + + return push_socket + + def connect_pull_socket(self, addr): + pull_socket = self.context.socket(self.zmq.PULL) + pull_socket.connect(addr) + self.sockets.append(pull_socket) + self.poll.register(pull_socket, self.zmq.POLLIN) + + return pull_socket + def bind_pull_socket(self, addr): pull_socket = self.context.socket(self.zmq.PULL) @@ -470,24 +547,6 @@ class Component(object): self.poll.register(self.control_in, self.zmq.POLLIN) self.sockets.extend([self.control_in, self.control_out]) - def setup_sync(self): - """ - Setup the sync socket and poller. ( Connect ) - - DEPRECATED, left in for compatability for now. - """ - - #LOGGER.debug("Connecting sync client for {id}".format(id=self.get_id)) - - self.sync_socket = self.context.socket(self.zmq.REQ) - self.sync_socket.connect(self.addresses['sync_address']) - #self.sync_socket.setsockopt(self.zmq.LINGER,0) - - self.sync_poller = self.zmq_poller() - self.sync_poller.register(self.sync_socket, self.zmq.POLLIN) - - self.sockets.append(self.sync_socket) - # ----------- # FSM Actions # ----------- diff --git a/zipline/core/controlled.py b/zipline/core/controlled.py new file mode 100644 index 00000000..19073fa1 --- /dev/null +++ b/zipline/core/controlled.py @@ -0,0 +1,74 @@ +""" +Poller logic for a component which is controlled by the monitor, this is +largely universal and thus we break it out into a seperate module and +splice it into the dispatch loops for each component instance. + +Example usage:: + + def do_work(): + socks = self.poll.poll() + + # Handle control events + do_handle_control_events() + + # Handle other events + if socks.get(socket) == zmq.POLLIN: + ... +""" + +import zmq +from zipline.core.component import Component +from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, CONTROL_UNFRAME + +def do_handle_control_events(cls, poller): + assert isinstance(cls, Component) + assert cls.control_in, 'Component does not have a control_in socket' + + # If we're in devel mode drop out because the controller + # isn't guaranteed to be around anymore + if cls.devel: + return + + if poller.get(cls.control_in) == zmq.POLLIN: + msg = cls.control_in.recv() + event, payload = CONTROL_UNFRAME(msg) + + # =========== + # Heartbeat + # =========== + + # The controller will send out a single number packed in + # a CONTROL_FRAME with ``heartbeat`` event every + # (n)-seconds. The component then has n seconds to + # respond to it. If not then it will be considered as + # malfunctioning or maybe CPU bound. + + if event == CONTROL_PROTOCOL.HEARTBEAT: + # Heart outgoing + heartbeat_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.OK, + payload + ) + # Echo back the heartbeat identifier to tell the + # controller that this component is still alive and + # doing work + cls.control_out.send(heartbeat_frame) + + # ========= + # Soft Kill + # ========= + + # Try and clean up properly and send out any reports or + # data that are done during a clean shutdown. Inform the + # controller that we're done. + elif event == CONTROL_PROTOCOL.SHUTDOWN: + cls.signal_done() + cls.shutdown() + + # ========= + # Hard Kill + # ========= + + # Just exit. + elif event == CONTROL_PROTOCOL.KILL: + cls.kill() diff --git a/zipline/core/devsimulator.py b/zipline/core/devsimulator.py index 7f382056..8835cb07 100644 --- a/zipline/core/devsimulator.py +++ b/zipline/core/devsimulator.py @@ -3,9 +3,18 @@ Simulator hosts all the components necessary to execute a simulation. See :py:method"" """ +import logbook import threading from zipline.core.simulatorref import SimulatorBase +log = logbook.Logger('Dev Simulator') + +DEPRECATION_WARNING = """ +WARNING WARNING WARNING +THE DEVSIMULATOR IS DEPRECATED, IT WILL NOT BEHAVE LIKE ANY OTHER +SYSTEM USED IN TESTS OR IN PRODUCTION +""" + class AddressAllocator(object): """ Produces a iterator of 10000 sockets to allocate as needed. @@ -38,6 +47,8 @@ class Simulator(SimulatorBase): self.subthreads = [] self.running = False + log.warn(DEPRECATION_WARNING) + @property def get_id(self): return 'Simple Simulator' diff --git a/zipline/core/host.py b/zipline/core/host.py index 96ab5c05..fe97904f 100644 --- a/zipline/core/host.py +++ b/zipline/core/host.py @@ -1,26 +1,21 @@ +import os +import sys import logbook -import datetime - -from component import Component from zipline.transforms import BaseTransform from zipline.components import Feed, Merge, PassthroughTransform, \ DataSource from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_STATE -log = logbook.Logger('Host') +log = logbook.Logger('Topology') -class ComponentHost(Component): +class ComponentHost(object): """ Components that can launch multiple sub-components, synchronize their start, and then wait for all components to be finished. """ - def init(self, addresses): - assert hasattr(self, 'zmq_flavor'), \ - """ You must specify a flavor of ZeroMQ for all ComponentHost - subclasses. """ - + def __init__(self, addresses): self.addresses = addresses self.running = False @@ -32,17 +27,30 @@ class ComponentHost(Component): self._components = {} # ---------------------- - self.sync_register = {} - self.timeout = datetime.timedelta(seconds=60) + self.exception = None self.feed = Feed() self.merge = Merge() self.passthrough = PassthroughTransform() self.controller = None - #register the feed and the merge self.register_components([self.feed, self.merge, self.passthrough]) + def _run(self): + self.open() + + def run(self, catch_exceptions=True): + """ + Run the host. + """ + log.info('===== PARENT PID: %s' % os.getppid()) + + self.open() + #self.shutdown() + + def shutdown(self, ensure_clean=True): + raise NotImplementedError + def register_controller(self, controller): """ Add the given components to the registry. Establish @@ -74,33 +82,26 @@ class ComponentHost(Component): self._components[component.guid] = component self.components[component.get_id] = component - self.sync_register[component.get_id] = datetime.datetime.utcnow() if isinstance(component, DataSource): - self.feed.add_source(component.source_id) + self.feed.add_source(component.get_id) if isinstance(component, BaseTransform): self.merge.add_source(component.get_id) def unregister_component(self, component_id): del self.components[component_id] - del self.sync_register[component_id] - - def setup_sync(self): - """ - Setup the sync socket and poller. ( Bind ) - """ - #log.debug("Connecting sync server.") - - self.sync_socket = self.context.socket(self.zmq.REP) - self.sync_socket.bind(self.addresses['sync_address']) - - self.sync_poller = self.zmq_poller() - self.sync_poller.register(self.sync_socket, self.zmq.POLLIN) - - self.sockets.append(self.sync_socket) def open(self): + assert hasattr(self, 'zmq_flavor'), \ + """ You must specify a flavor of ZeroMQ for all Topology + subclasses. """ + + log.info('== Roll Call ==') + log.info('Controller') + + self.launch_controller() + for component in self.components.itervalues(): log.info(component) @@ -109,8 +110,6 @@ class ComponentHost(Component): for component in self.components.itervalues(): self.launch_component(component) - self.launch_controller() - def is_running(self): """ DEPRECATED, left in for compatability for now. @@ -122,38 +121,15 @@ class ComponentHost(Component): return True - def loop(self, lockstep=True): - - while self.is_running(): - # wait for synchronization request at start, and DONE at end. - # don't timeout. - socks = dict(self.sync_poller.poll()) - - if socks.get(self.sync_socket) == self.zmq.POLLIN: - msg = self.sync_socket.recv() - - try: - parts = msg.split(':') - sync_id, status = parts - except ValueError as exc: - self.signal_exception(exc) - - # TODO: other way around - if status == str(CONTROL_PROTOCOL.DONE): - #log.debug("Component claims done: {id}".format(id=sync_id)) - self.unregister_component(sync_id) - self.state_flag = COMPONENT_STATE.DONE - else: - self.sync_register[sync_id] = datetime.datetime.utcnow() - - #log.info("confirmed {id}".format(id=msg)) - # send synchronization reply - self.sync_socket.send('ack', self.zmq.NOBLOCK) + def ready(self): + return True # ------------------ # Simulation Control # ------------------ + # Overloaded by simulator + def launch_controller(self, controller): raise NotImplementedError diff --git a/zipline/core/monitor.py b/zipline/core/monitor.py index 0ea7d355..56061944 100644 --- a/zipline/core/monitor.py +++ b/zipline/core/monitor.py @@ -1,9 +1,12 @@ +import os import zmq +import sys import time import gevent import itertools import logbook import gevent_zeromq +from signal import SIGHUP, SIGINT from collections import OrderedDict @@ -11,14 +14,17 @@ from zipline.utils.gpoll import _Poller as GeventPoller from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \ CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \ +from zipline.utils.protocol_utils import ndict + INIT, SOURCES_READY, RUNNING, TERMINATE = CONTROL_STATES CONTROLLER_TRANSITIONS = frozenset([ (-1 , INIT), (INIT , SOURCES_READY), (SOURCES_READY , RUNNING), - (INIT , TERMINATE), - (SOURCES_READY , TERMINATE), + + (INIT , TERMINATE), # pseudo failure mode + (SOURCES_READY , TERMINATE), # pseudo failure mode (RUNNING , TERMINATE), ]) @@ -32,6 +38,18 @@ class UnknownChatter(Exception): log = logbook.Logger('Controller') +# The scalars determining the timing of the monitor behavior for +# the system. + +PARAMETERS = ndict(dict( + GENERATIONAL_PERIOD = 1, + ALLOWED_SKIPPED_HEARTBEATS = 3, + ALLOWED_INVALID_HEARTBEATS = 3, + PRESTART_HEARBEATS = 3, + SOURCES_START_HEARTBEATS = 3, + SYSTEM_TIMEOUT = 50, +)) + class Controller(object): """ A N to M messaging system for inter component communication. @@ -43,38 +61,25 @@ class Controller(object): the individual components. :func message_sender: . - Topology is the set of components we expect to show up. - States are the transitions the sytems go through. The - simplest is from RUNNING -> NOT RUNNING . - - Usage:: - - controller = Controller( - 'tcp://127.0.0.1:5000', - 'tcp://127.0.0.1:5001', - ) - - # typically you'd want to run this async to your main - # program since it blocks indefinetely. - controller.manage( - [ TOPOLOGY ] - [ STATES ] - ) - """ - debug = False - period = 1 + # Turn on debug for verbose logging of the system. + debug = True + period = PARAMETERS.GENERATIONAL_PERIOD - def __init__(self, pub_socket, route_socket): + def __init__(self, pub_socket, route_socket, devel=True): + self.devel = devel + self.nosignals = False self.context = None self.zmq = None self.zmq_poller = None self.running = False - self.polling = False + self.alive = False self.tracked = set() + self.finished = set() + self.responses = set() self.ctime = 0 @@ -89,6 +94,8 @@ class Controller(object): self.error_replay = OrderedDict() + log.warn("Running Controller in development mode, will ONLY synchronize start.") + def init_zmq(self, flavor): assert self.zmq_flavor in ['thread', 'mp', 'green'] @@ -97,6 +104,8 @@ class Controller(object): self.zmq = zmq self.context = self.zmq.Context() self.zmq_poller = self.zmq.Poller + + log.warning("USING DEVELOPMENT MODE IN MP CONTEXT NOT RECOMMENDED") return if flavor == 'thread': self.zmq = zmq @@ -114,7 +123,7 @@ class Controller(object): self.zmq_poller = self.zmq.Poller return - def manage(self, topology, states=None, context=None): + def manage(self, topology): """ Give the controller a set set of components to manage and a set of state transitions for the entire system. @@ -128,31 +137,43 @@ class Controller(object): else: self.freeform = False self.topology = frozenset(topology) - - self.polling = True - self.state = CONTROL_STATES.INIT + self.alive = True @property def state(self): + #log.info('returned %s' % self._state) return self._state @state.setter def state(self, new): - old, self._state = self._state, new + old = self._state - if (old, new) not in CONTROLLER_TRANSITIONS: - raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) + if (old, new) in CONTROLLER_TRANSITIONS: + self._state = new + log.info("State Transition : %s -> %s" % (old, self._state)) else: - log.info("State Transition : %s -> %s" %(old, new)) + raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) def run(self): self.running = True self.init_zmq(self.zmq_flavor) + self.state = CONTROL_STATES.INIT + + # Interpreter SIDE EFFECT + # ----------------------- + # The last breathe of the interpreter will assume that we've + # failed unless we specify otherwise. + if not self.devel: + sys.exitfunc = self.signal_interrupt + # We overload this if ( and only if ) the topology exits + # cleanly. This prevents failure modes where the monitor + # dies. + try: return self._poll() # use a python loop except KeyboardInterrupt: - log.debug('Shutdown event loop') + log.info('Shutdown event loop') def log_status(self): """ @@ -172,6 +193,13 @@ class Controller(object): # Publications # ------------- + def send_go(self): + go_frame = CONTROL_FRAME( + CONTROL_PROTOCOL.GO, + '' + ) + self.pub.send(go_frame) + def send_heart(self): if not self.running: return @@ -210,45 +238,55 @@ class Controller(object): assert self.route_socket assert self.pub_socket - #assert self.cancel_socket + + assert self.topology,\ + """"Must define topology to monitor, call setup_controller() on + your Zipline. """ # -- Publish -- # ============= self.pub = self.context.socket(self.zmq.PUB) self.pub.bind(self.pub_socket) - - # -- Cancel -- - # ============= - #assert isinstance(self.cancel_socket,basestring), self.cancel_socket - #self.cancel = self.context.socket(self.zmq.REP) - #self.cancel.connect(self.cancel_socket) + self.pub.setsockopt(zmq.LINGER, 0) # -- Router -- # ============= self.router = self.context.socket(self.zmq.ROUTER) self.router.bind(self.route_socket) - + self.router.setsockopt(zmq.LINGER, 0) poller = self.zmq.Poller() poller.register(self.router, self.zmq.POLLIN) #poller.register(self.cancel, self.zmq.POLLIN) - self.associated += [self.pub, self.router]# self.cancel] + self.associated += [self.pub, self.router] # TODO: actually do this self.state = CONTROL_STATES.SOURCES_READY + self.state = CONTROL_STATES.RUNNING buffer = [] + # =================== + # Heartbeat Iteration + # =================== + for i in itertools.count(0): self.log_status() + + # Reset the responses for this cycle self.responses = set() + # broadcast the heartbeat packet self.ctime = time.time() self.send_heart() - while self.polling: - # Reset the responses for this cycle + # ============== + # Hearbeat Cycle + # ============== + + # Wait the responses + while self.alive: socks = dict(poller.poll(self.period)) tic = time.time() @@ -270,28 +308,90 @@ class Controller(object): log.error('Invalid frame', rawmessage) pass - #if socks.get(self.cancel) == self.zmq.POLLIN: - # self.logging.info('[Controller] Received Cancellation') - # rawmessage = self.cancel.recv() - # self.cancel.send('') - # self.shutdown(soft=True) - # break + # ================ + # Heartbeat Stats + # ================ - self.beat() + complete = self.beat() + + # ================ + # Topology Status + # ================ + + # Is the entire topology told us its DONE + done = len(self.finished) == len(self.topology) + + # Is the entire topology shown up to the party + complete = len(self.tracked) == len(self.topology) + + if complete: + self.send_go() + + # If we're running in development stop here + # because our responsibilites are over. The + # zipline will either run to completion or die, + # monitor doesn't care anymore because its all + # threads. + + if self.devel: + log.warn("Shutting down Controller because in devel mode") + #sys.exitfunc = lambda: None + self.shutdown(soft=True) + + log.info('Heartbeat (%s, %s)' % (done, complete)) + + # ================ + # Exit Strategies + # ================ if self.zmq_flavor == 'green': gevent.sleep(0) - if self.state is CONTROL_STATES.TERMINATE: + # Will also fall out of loop when done, if using + # non-freeform topology + if done: + log.info('Entire topology exited cleanly') + self.shutdown(soft=True) + + # Noop exit func + #sys.exitfunc = lambda: None + + # Send SIGHUP to buritto + self.signal_hangup() + + if not self.alive: break - if not self.polling: - break + def signal_hangup(self): + """ + A clean exit, inform the burrito ( and arbiter ) that + we're good. The topology exited cleanly and we can prove + it. + """ + if not self.nosignals: + ppid = os.getpid() + os.kill(ppid, SIGHUP) + else: + log.warning("Would SIGHUP here, but disabled") - # After loop exits - self.terminated = True + def signal_interrupt(self): + """ + Send a SIGINT in the error mode that the monitor's + interpreter exits. If the monitor dies the system is + considered a failure. + """ + if not self.nosignals: + ppid = os.getpid() + os.kill(ppid, SIGINT) + else: + log.warning("Would SIGINT here, but disabled") def beat(self): + """ + The tracking logic of the system. It's the "stethoscope" + that inspects to the heartbeats in a generation and + infers the state of the system from the responses. + """ # These the set overloaded operations # A & B ~ set.intersection @@ -303,24 +403,44 @@ class Controller(object): # send us back a response. # * new - Components we haven't heard from yet, but sent back the # right response. + # * finished - Components we were tracking but have now + # finished, when this set goes to zero this + # triggers the end of the topology. good = self.tracked & self.responses bad = self.tracked - good new = self.responses - good + missing = self.topology - self.tracked + for component in new: self.new(component) + if self.debug: + log.info('New component %r' % component) + for component in bad: self.fail(component) + if self.debug: + log.info('Bad component %r' % component) + + if self.debug: + for component in missing: + log.info('Missing component %r' % component) + + for component in self.tracked: + if component not in self.topology: + log.info('Uninvited component %r' % component) + # -------------- # Init Handlers # -------------- def new_source(self): - if self.state is CONTROL_STATES.RUNNING: - self.state = SOURCES_READY + #if self.state is CONTROL_STATES.RUNNING: + #self.state = SOURCES_READY + pass def new_universal(self): pass @@ -331,7 +451,11 @@ class Controller(object): if self.state is CONTROL_STATES.TERMINATE: return - log.info(' Now Tracking "%s" ' % component) + if component in self.finished: + #log.info("Got heartbeat from supposedly finished component") + return + + log.info('Now Tracking "%s" ' % component) universal = self.new_universal init_handlers = { @@ -342,7 +466,7 @@ class Controller(object): init_handlers.get(component, universal)() self.tracked.add(component) else: - # Some sort of socket collision has occured, this is + # Some sort of socket collision has occurred, this is # a very bad failure mode. raise UnknownChatter(component) @@ -353,8 +477,8 @@ class Controller(object): def fail_universal(self): pass # TODO: this requires higher order functionality - #log.error('System in exception state, shutting down') - #self.shutdown(soft=True) + log.error('System in exception state, shutting down') + self.shutdown(soft=True) def fail(self, component): if self.state is CONTROL_STATES.TERMINATE: @@ -364,16 +488,22 @@ class Controller(object): fail_handlers = { } if component in self.topology or self.freeform: - log.warning('Component "%s" timed out' % component) + log.warning('Component "%s" missed heartbeat' % component) self.tracked.remove(component) - fail_handlers.get(component, universal)() + + # TODO: determine when this propogates to a true + # failure, missing one heartbeat could just mean that + # its CPU overloaded + #fail_handlers.get(component, universal)() # ------------------- # Completion Handling # ------------------- def done(self, component): - log.info('Component "%s" signaled done.' % component) + self.finished.add(component) + self.tracked.discard(component) + log.info('Component "%s" finished.' % component) # -------------- # Error Handling @@ -393,6 +523,7 @@ class Controller(object): if component in self.topology or self.freeform: self.error_replay[(component, time.time())] = failure log.error('Component in exception state: %s' % component) + log.error(str(failure)) exception_handlers.get(component, universal)() else: @@ -406,30 +537,48 @@ class Controller(object): """ Check for proper framing at the transport layer. Seperates the proper frames from anything else that might - be coming over the wire. Which shouldn't happen ... right? + be coming over the wire. """ - identity = msg[0] + + identity = msg[0] # identity of the socket id, status = CONTROL_UNFRAME(msg[1]) - # A component is telling us its alive: + # I'm alive, condemned to be a free process in the cold + # cold dark absurd Zipline universe. + if id is CONTROL_PROTOCOL.READY: + self.responses.add(identity) + return + + # The heartbeat love song between a component and the + # controller if id is CONTROL_PROTOCOL.OK: if status == str(self.ctime): + # Go to your bosom; knock there, and ask your heart what + # it doth know... self.responses.add(identity) + elif status < self.ctime: + # False face must hide what the false heart doth know. + log.warning('Delayed heartbeat received.') else: # Otherwise its something weird and we don't know # what to do so just say so, probably line noise # from ZeroMQ - log.error("Weird stuff happened: %s" % msg) + + # What's in a name? that which we call a rose... + log.error("Weird heartbeat packet happened: %s" % msg) + return # A component is telling us it failed, and how if id is CONTROL_PROTOCOL.EXCEPTION: self.exception(identity, status) + return # A component is telling us its done with work and won't # be talking to us anymore if id is CONTROL_PROTOCOL.DONE: self.done(identity) + return # ------------------- # Hooks for Endpoints @@ -474,54 +623,27 @@ class Controller(object): def do_error_replay(self): for (component, time), error in self.error_replay.iteritems(): - log.info('Error Log for -- %s --:\n%s' % - (component, error)) + log.debug('Component Log for -- %s --:\n%s' % (component, error)) - def shutdown(self, hard=False, soft=True, context=None): + def shutdown(self, hard=False, soft=True): + + assert hard or soft, """ Must specify kill hard or soft """ if self.state is CONTROL_STATES.TERMINATE: return - if not self.polling: - return + self.alive = False - self.polling = False - - assert hard or soft, """ Must specify kill hard or soft """ - - if hard: + if hard and not self.devel: self.state = CONTROL_STATES.TERMINATE - log.info('Hard Shutdown') - #for asoc in self.associated: - #asoc.close() - - if soft: + if soft and not self.devel: self.state = CONTROL_STATES.TERMINATE - log.info('Soft Shutdown') self.send_softkill() - #for asoc in self.associated: - #asoc.close() + #self.do_error_replay() - self.do_error_replay() - -if __name__ == '__main__': - - print 'Running on '\ - 'tcp://127.0.0.1:5000 '\ - 'tcp://127.0.0.1:5001 ' - - controller = Controller( - 'tcp://127.0.0.1:5000', - 'tcp://127.0.0.1:5001', - ) - controller.zmq_flavor = 'green' - - controller.manage( - 'freeform', - [] - ) - controller.run() + #self.pub.close() + #self.router.close() diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index 7021c581..8e687c49 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -36,14 +36,14 @@ Risk Report """ -import logging +import logbook import datetime import math import numpy as np import numpy.linalg as la from zipline.utils.date_utils import epoch_now -LOGGER = logging.getLogger('ZiplineLogger') +log = logbook.Logger('Risk') def advance_by_months(dt, jump_in_months): month = dt.month + jump_in_months @@ -255,7 +255,7 @@ class RiskMetrics(): cur_return = math.log(1.0 + r) + cur_return #this is a guard for a single day returning -100% except ValueError: - LOGGER.warn("{cur} return, zeroing the returns".format(cur=cur_return)) + log.warn("{cur} return, zeroing the returns".format(cur=cur_return)) cur_return = 0.0 compounded_returns.append(cur_return) diff --git a/zipline/finance/sources.py b/zipline/finance/sources.py index 346e4673..0aba2186 100644 --- a/zipline/finance/sources.py +++ b/zipline/finance/sources.py @@ -21,13 +21,12 @@ import zipline.protocol as zp class TradeDataSource(DataSource): - def init(self, source_id): - self.source_id = source_id + def init(self): self.setup_source() - @property - def get_id(self): - return 'TradeDataSource' + #@property + #def get_id(self): + # return 'TradeDataSource' def send(self, event): """ @@ -37,14 +36,14 @@ class TradeDataSource(DataSource): :rtype: None """ - event.source_id = self.source_id + event.source_id = self.get_id if event.sid in self.filter['sid']: message = zp.DATASOURCE_FRAME(event) else: blank = ndict({ "type" : zp.DATASOURCE_TYPE.TRADE, - "source_id" : self.source_id + "source_id" : self.get_id }) message = zp.DATASOURCE_FRAME(blank) @@ -56,8 +55,7 @@ class RandomEquityTrades(TradeDataSource): Generates a random stream of trades for testing. """ - def init(self, sid, source_id, count): - self.source_id = source_id + def init(self, sid, count): self.count = count self.incr = 0 self.sid = sid @@ -95,7 +93,7 @@ class SpecificEquityTrades(TradeDataSource): Generates a random stream of trades for testing. """ - def init(self, source_id, event_list): + def init(self, event_list): """ :param event_list: should be a chronologically ordered list of dictionaries in the following form:: @@ -107,7 +105,6 @@ class SpecificEquityTrades(TradeDataSource): 'volume' : integer for volume } """ - self.source_id = source_id self.event_list = event_list self.count = 0 diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index b2aacdc1..bf3a5374 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -1,12 +1,12 @@ import pytz import math -import logging +import logbook import datetime import zipline.protocol as zp from zipline.protocol import SIMULATION_STYLE -LOGGER = logging.getLogger('ZiplineLogger') +log = logbook.Logger('Transaction Simulator') class TransactionSimulator(object): @@ -39,7 +39,7 @@ class TransactionSimulator(object): log = "requested to trade zero shares of {sid}".format( sid=event.sid ) - LOGGER.debug(log) + log.debug(log) return if not self.open_orders.has_key(event.sid): diff --git a/zipline/lines.py b/zipline/lines.py index 768b1ef9..23eb8690 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -60,6 +60,8 @@ before invoking simulate. +---------------------------------+ """ +import inspect +import logbook import zipline.utils.factory as factory from zipline.components import DataSource @@ -71,6 +73,8 @@ from zipline.core.devsimulator import Simulator from zipline.core.monitor import Controller from zipline.finance.trading import SIMULATION_STYLE +log = logbook.Logger('Lines') + class SimulatedTrading(object): """ Zipline with:: @@ -113,6 +117,8 @@ class SimulatedTrading(object): self.trading_environment = config['trading_environment'] self.sim_style = config.get('simulation_style') + self.devel = config.get('devel', False) + self.leased_sockets = [] self.sim_context = None @@ -123,14 +129,15 @@ class SimulatedTrading(object): 'feed_address' : sockets[2], 'merge_address' : sockets[3], 'result_address' : sockets[4], - 'order_address' : sockets[5] + 'order_address' : sockets[5] } self.con = Controller( sockets[6], sockets[7], + devel = self.devel ) - + # TODO: Not freeform self.con.manage( 'freeform' @@ -167,7 +174,7 @@ class SimulatedTrading(object): def create_test_zipline(**config): """ :param config: A configuration object that is a dict with: - + - environment - a \ :py:class:`zipline.finance.trading.TradingEnvironment` - allocator - a :py:class:`zipline.simulator.AddressAllocator` @@ -190,7 +197,7 @@ class SimulatedTrading(object): a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading` """ assert isinstance(config, dict) - + allocator = config['allocator'] sid = config['sid'] @@ -266,12 +273,20 @@ class SimulatedTrading(object): 'allocator' : allocator, 'simulator_class' : simulator_class, 'simulation_style' : simulation_style, - 'log_socket' : log_socket + 'log_socket' : log_socket, + 'devel' : config.get('devel', False) }) #------------------- zipline.add_source(trade_source) + # Save us from needless debugging + inside_test = 'nose' in inspect.stack()[-1][1] + if inside_test and not config.get('devel', False): + assert False, """ + You need to run the SimulatedTrading inside a test with devel=True + """ + return zipline def add_source(self, source): @@ -285,7 +300,7 @@ class SimulatedTrading(object): self.sim.register_components([source]) # ``id`` is name of source_id, ``get_id`` is the class name - self.sources[source.source_id] = source + self.sources[source.get_id] = source def add_transform(self, transform): assert isinstance(transform, BaseTransform) @@ -324,12 +339,76 @@ class SimulatedTrading(object): return leased - def simulate(self, blocking=False): + @property + def components(self): + """ + Return the component instances inside of this topology + """ + + base = set(self.sim.components.values()) + transforms = set(self.transforms.values()) + sources = set(self.sources.values()) + + return base | transforms | sources + + @property + def topology(self): + """ + Returns the Component names in the topology of the + backtest. + """ + + # A complete topology is the union of three classes of + # components added individually to the simulation client + # at various places. + # + # base : ['FEED', 'MERGE', 'TRADING_CLIENT', 'PASSTHROUGH'] + # transforms : ['vwap__01', ... ] + # sources : ['MongoTradeHistory', ... ] + + base = set(self.sim.components.keys()) + transforms = set(self.transforms.keys()) + sources = set(self.sources.keys()) + + return base | transforms | sources + + def setup_controller(self): + """ + Prepare the controller tro manage the topology specified + by this line. + """ + self.con.manage(self.topology) + + def simulate(self, blocking=True): + self.setup_controller() + self.started = True self.sim_context = self.sim.simulate() - if blocking: - self.sim_context.join() + # If we're in development mode then flag all the + # components in the topology as devel so as to indicate + # that they won't poll on the control channels for + # anything other than the synchronized start. + if self.devel: + for component in self.components: + component.devel = True + + # If we're using a threaded simulator block on the pool + # of thread since we're only ever in a test and we don't + # generally monitor the state of the system as a hold at + # the supervisory layer + + # TODO: better way of identifying concurrency substrate + if self.sim.zmq_flavor == 'thread': + log.debug('Blocking') + for thread in self.sim.subthreads: + #log.debug('Waiting on %r' % thread) + log.debug('Waiting on %r' % thread) + thread.join() + log.debug('Yielded on %r' % thread) + else: + for process in self.sim.subprocesses: + process.join() @property def is_success(self): diff --git a/zipline/optimize/factory.py b/zipline/optimize/factory.py index d9776a9c..153f1d57 100644 --- a/zipline/optimize/factory.py +++ b/zipline/optimize/factory.py @@ -65,7 +65,7 @@ def create_updown_trade_source(sid, trade_count, trading_environment, base_price trading_environment.period_end = cur - source = SpecificEquityTrades("updown_" + str(sid), events) + source = SpecificEquityTrades(events) return source @@ -128,6 +128,7 @@ def create_predictable_zipline(config, offset=0, simulate=True): config['trade_source'] = source config['environment'] = trading_environment config['simulation_style'] = SIMULATION_STYLE.FIXED_SLIPPAGE + config['devel'] = True zipline = SimulatedTrading.create_test_zipline(**config) diff --git a/zipline/protocol.py b/zipline/protocol.py index 20d35f1c..35cef68b 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -141,11 +141,12 @@ CONTROL_PROTOCOL = Enum( 'HEARTBEAT' , # 0 - req 'SHUTDOWN' , # 1 - req 'KILL' , # 2 - req + 'GO' , # - req 'OK' , # 3 - rep 'DONE' , # 4 - rep 'EXCEPTION' , # 5 - rep - 'SIGNAL' , # 6 - rep + 'READY' , # 6 - rep ) def CONTROL_FRAME(event, payload): @@ -305,7 +306,7 @@ def FEED_FRAME(event): - source_id - type """ - assert isinstance(event, ndict) + assert isinstance(event, ndict), 'unknown type %s' % str(event) source_id = event.source_id ds_type = event.type PACK_DATE(event) diff --git a/zipline/transforms/base.py b/zipline/transforms/base.py index 2c8ab856..6e406efc 100644 --- a/zipline/transforms/base.py +++ b/zipline/transforms/base.py @@ -1,12 +1,9 @@ -import logging from zipline.core.component import Component import zipline.protocol as zp from zipline.protocol import CONTROL_PROTOCOL, COMPONENT_TYPE, \ CONTROL_FRAME, CONTROL_UNFRAME -LOGGER = logging.getLogger('ZiplineLogger') - class BaseTransform(Component): """ Top level execution entry point for the transform diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index b979849f..6101c3b2 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -171,7 +171,7 @@ def create_returns_from_list(returns, trading_calendar): def create_random_trade_source(sid, trade_count, trading_environment): # create the source - source = RandomEquityTrades(sid, "rand-"+str(sid), trade_count) + source = RandomEquityTrades(sid, trade_count) # make the period_end of trading_environment match cur = trading_environment.first_open @@ -244,5 +244,5 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ #history. trading_environment.period_end = trade_history[-1].dt - source = SpecificEquityTrades("flat", trade_history) + source = SpecificEquityTrades(trade_history) return source