diff --git a/tests/test_components.py b/tests/test_components.py index 02b05a69..a9b765d0 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -4,17 +4,14 @@ from datetime import datetime, timedelta from unittest2 import TestCase from collections import defaultdict +from zipline.gens.composite import date_sorted_sources from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator -from zipline.lines import SimulatedTrading from zipline.utils.test_utils import ( - drain_zipline, - check, setup_logger, teardown_logger, - launch_component, create_monitor, launch_monitor ) @@ -28,7 +25,7 @@ from zipline.protocol import ( ) from zipline.gens.tradegens import SpecificEquityTrades -from zipline.gens.utils import hash_args +from zipline.gens.sort import date_sort from zipline.gens.zmqgen import gen_from_poller import logbook @@ -53,10 +50,14 @@ class ComponentTestCase(TestCase): setup_logger(self) def tearDown(self): - self.ctx.term() + #self.ctx.term() teardown_logger(self) - def test_specific_equity_source(self): + def test_source(self): + monitor = create_monitor(allocator) + socket_uri = allocator.lease(1)[0] + count = 100 + filter = [1,2,3,4] #Set up source a. One minute between events. args_a = tuple() @@ -65,42 +66,113 @@ class ComponentTestCase(TestCase): 'start' : datetime(2012,6,6,0,tzinfo=pytz.utc), 'delta' : timedelta(minutes = 1), 'filter' : filter, - 'count' : 100 + 'count' : count } - c_id = SpecificEquityTrades.__name__ + hash_args(args_a, kwargs_a) - mon = create_monitor(allocator) - - out_socket_args = ComponentSocketArgs( - style=zmq.PUSH, - uri=allocator.lease(1)[0], - bind=True + comp_a = Component( + SpecificEquityTrades, + args_a, + kwargs_a, + monitor, + socket_uri, + DATASOURCE_FRAME, + DATASOURCE_UNFRAME ) - c = Component( - SpecificEquityTrades, - args_a, - kwargs_a, - c_id, - out_socket_args, - DATASOURCE_FRAME, - mon - ) + launch_monitor(monitor) - mon.manage(set([c.get_id])) - mon_proc = launch_monitor(mon) + for event in comp_a: + log.info(event) - # launch in a process - proc = launch_component(c) - pull_socket = self.ctx.socket(zmq.PULL) - pull_socket.connect(out_socket_args.uri) - poller = zmq.Poller() - poller.register(pull_socket, zmq.POLLIN) - unframe = DATASOURCE_UNFRAME - for msg in gen_from_poller(poller, pull_socket, unframe): - # assert things about the messages. - log.info(msg) + def test_sort(self): + monitor = create_monitor(allocator) + poller = zmq.Poller() + socket_uris = allocator.lease(3) + count = 100 - pull_socket.close() - log.info("DONE!") + filter = [1,2,3,4] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0,tzinfo=pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + + + comp_a = Component( + SpecificEquityTrades, + args_a, + kwargs_a, + monitor, + socket_uris[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME + ) + + + #Set up source b. Two minutes between events. + args_b = tuple() + kwargs_b = { + 'sids' : [2], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + + + comp_b = Component( + SpecificEquityTrades, + args_b, + kwargs_b, + monitor, + socket_uris[1], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME + ) + + #Set up source c. Three minutes between events. + args_c = tuple() + kwargs_c = { + 'sids' : [3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 1), + 'filter' : filter, + 'count' : count + } + + comp_c = Component( + SpecificEquityTrades, + args_c, + kwargs_c, + monitor, + socket_uris[2], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME + ) + + names = [ + comp_a.get_id, + comp_b.get_id, + comp_c.get_id + ] + + monitor.manage(set(names)) + launch_monitor(monitor) + + sorted_out = date_sorted_sources([comp_a, comp_b, comp_c]) + + prev = None + sort_count = 0 + for msg in sorted_out: + if prev: + self.assertTrue(msg.dt >= prev.dt, \ + "Messages should be in date ascending order") + prev = msg + sort_count += 1 + + self.assertEqual(count*3, sort_count) diff --git a/zipline/core/component.py b/zipline/core/component.py index f9dc63c1..287373b4 100644 --- a/zipline/core/component.py +++ b/zipline/core/component.py @@ -10,8 +10,11 @@ import socket import logbook import traceback import humanhash +import multiprocessing from setproctitle import setproctitle from collections import namedtuple +from zipline.gens.utils import hash_args + # pyzmq import zmq @@ -36,7 +39,7 @@ class KillSignal(Exception): def __init__(self): pass -ComponentSocketArgs = namedtuple('ComponentSocket',['uri','style','bind']) +ComponentSocketArgs = namedtuple('ComponentSocketArgs',['uri','style','bind']) class Component(object): @@ -49,33 +52,27 @@ class Component(object): gen_args, gen_kwargs, component_id, - out_socket_args, - frame, monitor, - in_socket_args=None, - unframe=None + socket_uri, + frame, + unframe ): assert component_id, \ "Every component needs a unique and invariant identifier" assert isinstance(component_id, basestring), \ "Components must have string IDs" - assert isinstance(out_socket_args, ComponentSocketArgs), \ - "out_socket_args args must be ComponentSocketArgs" - - if in_socket_args: - assert isinstance(in_socket_args, ComponentSocketArgs), \ - "in_socket_args args must be ComponentSocketArgs" # ----------------- # Generator # ----------------- - self.component_id = component_id self.gen_args = gen_args self.gen_kwargs = gen_kwargs self.gen_func = gen_func self.generator = None self.frame = frame + self.component_id = self.gen_func.__name__ \ + + hash_args(gen_args, gen_kwargs) # lock for waiting on monitor "GO" self.waiting = None @@ -83,14 +80,27 @@ class Component(object): # ----------------- # ZMQ properties # ----------------- - self.in_socket_args = in_socket_args - self.out_socket_args = out_socket_args + self.in_socket_args = ComponentSocketArgs( + uri = socket_uri, + style = zmq.PULL, + bind = False + ) + self.out_socket_args = ComponentSocketArgs( + uri = socket_uri, + style = zmq.PUSH, + bind = True + ) self.zmq = None self.context = None self.out_socket = None self.in_socket = None - self.monitor = monitor + self.monitor = monitor self.unframe = unframe + self.prefix = "" + + # register two components with the monitor + monitor.add_to_topology(self.component_id) + monitor.add_to_topology("FORK-"+self.component_id) # TODO: state_flag is deprecated, remove self.state_flag = COMPONENT_STATE.OK @@ -109,7 +119,7 @@ class Component(object): # ------------ - def _run(self): + def _run_out(self): """ The main component loop. This is wrapped inside a exception reporting context inside of run. @@ -118,13 +128,12 @@ class Component(object): """ # The process title so you can watch it in top, ps. setproctitle(self.gen_func.__name__) + self.prefix = "FORK-" log.info("Start %r" % self) log.info("Pid %s" % os.getpid()) log.info("Group %s" % os.getpgrp()) - self.sockets = [] - self.open() self.signal_ready() @@ -138,17 +147,36 @@ class Component(object): for event in self.generator: self.heartbeat() + event.source_id = self.get_id msg = self.frame(event) self.out_socket.send(msg) self.signal_done() - def run(self, catch_exceptions=True): + def _run_in(self): + self.open(send=False) + self.signal_ready() + self.lock_ready() + self.wait_ready() + # ----------------------- + # YOU SHALL NOT PASS!!!!! + # ----------------------- + # ... until the monitor signals GO + + # return the generator + for event in gen_from_poller(self.poll, self.in_socket, self.unframe): + event.source_id = self.get_id + yield event + + self.signal_done() + + def run_safe(self, func): """ - Run the component. + Run a function that is assumed to include wait_ready and + heartbeat. Used to wrap fork_generator and consume_gen. """ try: - self._run() + return func() except Exception as exc: if not isinstance(exc, KillSignal): self.signal_exception(exc) @@ -160,6 +188,23 @@ class Component(object): log.info("Exiting %r" % self) + def _launch(self): + # first, start the generator in its own process. Once + # Monitor says "go", Events from the generator will be + # FRAME'd and PUSH'd to self.socket_uri. + proc = multiprocessing.Process( + target=self.run_safe, + args=(self._run_out,) + ) + proc.start() + + # Start the poller-generator, which will PULL messages + # from self.sockiet_uri, UNFRAME'd them, and yield them. + return self.run_safe(self._run_in) + + def __iter__(self): + return self._launch() + # ---------------------------- # Cleanup & Modes of Failure # ---------------------------- @@ -420,8 +465,9 @@ class Component(object): # notify internal work loop that we're done self.done = True # TODO: use state flag - msg = zmq.Message(str(CONTROL_PROTOCOL.DONE)) - self.out_socket.send(msg) + if self.out_socket: + msg = zmq.Message(str(CONTROL_PROTOCOL.DONE)) + self.out_socket.send(msg) # notify monitor we're done @@ -437,40 +483,32 @@ class Component(object): # after the Monitor accepts our prior heartbeat, but just # before the next one is sent. So, we hang around for one # last heartbeat, and wait an unusually long time. - self.heartbeat(timeout=5000) + # TODO: decided if this is really necessary. + # self.heartbeat(timeout=5000) # ----------- # Messaging # ----------- - def open(self): + def open(self, send=True): """ Open the connections needed to start doing work. Perform any setup that must be done within process. """ - + self.sockets = [] self.zmq = zmq self.context = self.zmq.Context() self.poll = self.zmq.Poller() self.setup_control() - if self.in_socket_args: - self.in_socket = self.open_socket(self.in_socket_args) - poller_gen = gen_from_poller( - self.poller, - self.in_socket, - self.unframe - ) - self.generator = self.gen_func( - poller_gen, - *self.gen_args, - **self.gen_kwargs - ) - else: + if send: self.generator = self.gen_func(*self.gen_args, **self.gen_kwargs) - - self.out_socket = self.open_socket(self.out_socket_args) + self.out_socket = self.open_socket(self.out_socket_args) + self.sockets.extend([self.out_socket]) + else: + self.in_socket = self.open_socket(self.in_socket_args) + self.sockets.extend([self.in_socket]) def open_socket(self, sock_args): if sock_args.bind: @@ -577,7 +615,7 @@ class Component(object): The time invariant name for this component. Must be unique within this zipline. """ - return self.component_id + return self.prefix + self.component_id def debug(self): """ diff --git a/zipline/core/monitor.py b/zipline/core/monitor.py index 183de45b..83b022c5 100644 --- a/zipline/core/monitor.py +++ b/zipline/core/monitor.py @@ -105,6 +105,9 @@ class Monitor(object): self.missed_beats = Counter() + # start with an empty topology + self.topology = set([]) + self.send_sighup = send_sighup if self.send_sighup: log.info("Request to send sighup/sigint") @@ -116,6 +119,17 @@ class Monitor(object): self.zmq_poller = self.zmq.Poller return + def add_to_topology(self, component_id): + add = set([component_id]) + self.topology.update(add) + + def freeze_topology(self): + if isinstance(self.topology, frozenset): + return + # we've been incrementally adding components. + # time to freeze. + self.manage(self.topology) + def manage(self, topology): """ Give the controller a set set of components to manage and @@ -147,6 +161,7 @@ class Monitor(object): raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new)) def run(self): + self.freeze_topology() self.running = True self.init_zmq() setproctitle('Monitor') diff --git a/zipline/gens/examples.py b/zipline/gens/examples.py index afc8ebe8..93f77d47 100644 --- a/zipline/gens/examples.py +++ b/zipline/gens/examples.py @@ -1,7 +1,4 @@ import pytz -from time import sleep - -from pprint import pprint as pp from datetime import datetime, timedelta from zipline.utils.factory import create_trading_environment @@ -26,7 +23,7 @@ if __name__ == "__main__": 'delta' : timedelta(minutes = 1), 'filter' : filter } - source_a = SpecificEquityTrades(*args_a, **kwargs_a) + bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a) #Set up source b. Two minutes between events. args_b = tuple() @@ -36,9 +33,10 @@ if __name__ == "__main__": 'delta' : timedelta(minutes = 1), 'filter' : filter } - source_b = SpecificEquityTrades(*args_a, **kwargs_a) - + bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b) + #Set up source c. Three minutes between events. + sort_out = date_sorted_sources(source_a, source_b) # passthrough = TransformBundle(Passthrough, (), {}) @@ -58,6 +56,4 @@ if __name__ == "__main__": # for message in client_out: # pp(message) - - - + diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 2c49ea19..a8a691b5 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -7,7 +7,7 @@ from itertools import chain, cycle, ifilter, izip from datetime import datetime, timedelta from zipline.utils.factory import create_trade -from zipline.gens.utils import hash_args, mock_done +from zipline.gens.utils import hash_args def date_gen(start = datetime(2006, 6, 6, 12), delta = timedelta(minutes = 1), @@ -54,9 +54,9 @@ class SpecificEquityTrades(object): Yields all events in event_list that match the given sid_filter. If no event_list is specified, generates an internal stream of events to filter. Returns all events if filter is None. - + Configuration options: - + count : integer representing number of trades sids : list of values representing simulated internal sids start : start date @@ -67,22 +67,22 @@ class SpecificEquityTrades(object): def __init__(self, *args, **kwargs): # We shouldn't get any positional arguments. assert len(args) == 0 - + # Unpack config dictionary with default values. self.count = kwargs.get('count', 500) self.sids = kwargs.get('sids', [1, 2]) self.start = kwargs.get('start', datetime(2012, 6, 6, 0)) self.delta = kwargs.get('delta', timedelta(minutes = 1)) - + # Default to None for event_list and filter. self.event_list = kwargs.get('event_list') self.filter = kwargs.get('filter') - + # Hash_value for downstream sorting. self.arg_string = hash_args(*args, **kwargs) - + self.generator = self.create_fresh_generator() - + def __iter__(self): return self.generator @@ -94,22 +94,22 @@ class SpecificEquityTrades(object): def get_hash(self): return self.__class__.__name__ + "-" + self.arg_string - + def create_fresh_generator(self): - + if self.event_list: unfiltered = (event for event in self.event_list) # Set up iterators for each expected field. else: - dates = date_gen(count=self.count, - start=self.start, + dates = date_gen(count=self.count, + start=self.start, delta=self.delta ) prices = mock_prices(self.count) volumes = mock_volumes(self.count) sids = cycle(self.sids) - + # Combine the iterators into a single iterator of arguments arg_gen = izip(sids, prices, volumes, dates) @@ -137,7 +137,7 @@ def RandomEquityTrades(object): def __init__(self): # We shouldn't get any positional args. assert args == () - + self.count = config.get('count', 500) self.sids = config.get('sids', [1,2]) self.filter = config.get('filter') diff --git a/zipline/gens/zmqgen.py b/zipline/gens/zmqgen.py index e51e3bab..f9d5f919 100644 --- a/zipline/gens/zmqgen.py +++ b/zipline/gens/zmqgen.py @@ -13,6 +13,11 @@ def gen_from_pull_socket(socket_uri, context, unframe): return gen_from_poller(poller, pull_socket, unframe) + +# this generator needs to know about the source_ids coming in via +# the poller, and need to yield DONE messages for each +# source_id. + def gen_from_poller(poller, in_socket, unframe): while True: