From 4a582e8952efeef6ed6dfb18517bdb619102c522 Mon Sep 17 00:00:00 2001 From: fawce Date: Sat, 4 Aug 2012 12:58:07 -0400 Subject: [PATCH] modified zmq_gen method to yield None when there is no waiting message. This prevents blocking in the next() method of a component. But it requires generators wrapping the component to handle None. Also modified component's receiver creation to be triggered on the first call to next, rather than iter. This change means that the zmq context and socket for the component's receiver should always be created in the same process as the consumer of the generator. Chaining together component wrapped generators will result in the send process of the last component actually instantiating the receive socket of the prior component. In this way, the components are actually communicating directly via zmq. Component's send method now calls the wait_ready(), which waits for the monitor's GO message, inside the generator loop. This guarantees that the generator's next method is called before the send loop blocks on the monitor. As a result, components will call __init__ and next() without blocking, mimicking the behavior of plain generators. --- tests/test_components.py | 186 ++++++++++++++++++++++++++++++++++++-- zipline/core/component.py | 55 ++++++----- zipline/gens/merge.py | 10 +- zipline/gens/transform.py | 14 ++- zipline/gens/utils.py | 4 +- 5 files changed, 228 insertions(+), 41 deletions(-) diff --git a/tests/test_components.py b/tests/test_components.py index 896f4afc..54049723 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -1,13 +1,20 @@ import zmq import pytz +from pprint import pformat as pf from datetime import datetime, timedelta from unittest2 import TestCase from collections import defaultdict -from zipline.gens.composites import date_sorted_sources +from zipline.gens.composites import date_sorted_sources, merged_transforms from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator +from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform +from zipline.gens.tradesimulation import TradeSimulationClient as tsc + +from zipline.utils.factory import create_trading_environment +from zipline.test_algorithms import TestAlgorithm + from zipline.utils.test_utils import ( setup_logger, @@ -19,7 +26,12 @@ from zipline.utils.test_utils import ( from zipline.core import Component from zipline.protocol import ( DATASOURCE_FRAME, - DATASOURCE_UNFRAME + DATASOURCE_UNFRAME, + FEED_FRAME, + FEED_UNFRAME, + MERGE_FRAME, + MERGE_UNFRAME, + SIMULATION_STYLE ) from zipline.gens.tradegens import SpecificEquityTrades @@ -65,9 +77,7 @@ class ComponentTestCase(TestCase): } trade_gen = SpecificEquityTrades(*args_a, **kwargs_a) - monitor.add_to_topology(trade_gen.get_hash()) - launch_monitor(monitor) comp_a = Component( trade_gen, @@ -77,9 +87,14 @@ class ComponentTestCase(TestCase): DATASOURCE_UNFRAME ) + launch_monitor(monitor) + for event in comp_a: log.info(event) + # wait for the sending process to exit + comp_a.proc.join() + def test_sort(self): monitor = create_monitor(allocator) @@ -97,7 +112,6 @@ class ComponentTestCase(TestCase): 'count' : count } trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a) - monitor.add_to_topology(trade_gen_a.get_hash()) #Set up source b. Two minutes between events. args_b = tuple() @@ -109,7 +123,6 @@ class ComponentTestCase(TestCase): 'count' : count } trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b) - monitor.add_to_topology(trade_gen_b.get_hash()) #Set up source c. Three minutes between events. args_c = tuple() @@ -122,9 +135,7 @@ class ComponentTestCase(TestCase): } trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c) - monitor.add_to_topology(trade_gen_c.get_hash()) - launch_monitor(monitor) comp_a = Component( trade_gen_a, @@ -154,6 +165,8 @@ class ComponentTestCase(TestCase): sorted_out = date_sorted_sources(*sources) + launch_monitor(monitor) + prev = None sort_count = 0 for msg in sorted_out: @@ -164,3 +177,160 @@ class ComponentTestCase(TestCase): sort_count += 1 self.assertEqual(count*3, sort_count) + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + comp_c.proc.join() + + + def test_full(self): + monitor = create_monitor(allocator) + + filter = [2,3] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'count' : 325, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(hours = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Two minutes between events. + args_b = tuple() + kwargs_b = { + 'count' : 7500, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + # ------------------------ + # Run sources in dedicated processes + comp_a = Component( + source_a, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + source_a.get_hash() + ) + + comp_b = Component( + source_b, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + source_b.get_hash() + ) + + # Date sort the sources, and run the sort in a dedicated + # process + sources = [comp_a, comp_b] + + sorted_out = date_sorted_sources(*sources) + + #launch_monitor(monitor) + #import nose.tools; nose.tools.set_trace() + #for feed_msg in sorted_out: + # log.info(pf(feed_msg)) + + #return + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME, + "sort" + ) + + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) + + merged_gen = merged_transforms(sorted, passthrough, mavg_price) + + merged = Component( + merged_gen, + monitor, + allocator.lease(1)[0], + MERGE_FRAME, + MERGE_UNFRAME, + "merge" + ) + + algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) + environment = create_trading_environment(year = 2012) + style = SIMULATION_STYLE.FIXED_SLIPPAGE + + trading_client = tsc(algo, environment, style) + + launch_monitor(monitor) + for message in trading_client.simulate(merged): + log.info(pf(message)) + + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + sorted.proc.join() + merged.proc.join() + return + + + + def test_compound(self): + monitor = create_monitor(allocator) + + filter = [2,3] + #Set up source a. One minute between events. + args_a = tuple() + kwargs_a = { + 'count' : 325, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(hours = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Two minutes between events. + args_b = tuple() + kwargs_b = { + 'count' : 7500, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + sorted_out = date_sorted_sources(source_a, source_b) + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME + ) + + launch_monitor(monitor) + + for event in sorted: + log.info(event) + + + sorted.proc.join() diff --git a/zipline/core/component.py b/zipline/core/component.py index 313e76e5..c9ce56e6 100644 --- a/zipline/core/component.py +++ b/zipline/core/component.py @@ -51,7 +51,8 @@ class Component(object): monitor, socket_uri, frame, - unframe + unframe, + component_id ): # ----------------- @@ -59,7 +60,7 @@ class Component(object): # ----------------- self.generator = generator self.frame = frame - self.component_id = self.generator.get_hash() + self.component_id = component_id # lock for waiting on monitor "GO" self.waiting = None @@ -99,15 +100,16 @@ class Component(object): # first, start the generator in its own process. Once # Monitor says "go", Events from the generator will be # FRAME'd and PUSH'd to self.socket_uri. - proc = multiprocessing.Process( - target=self.loop_send - ) - proc.start() + monitor.add_to_topology(self.component_id) - # ------------ - # Message Receiver/Generator - # ------------ - self.recv_gen = self.create_recv_gen() + self.proc = multiprocessing.Process( + target=self.loop_send + ) + self.proc.start() + + # Placeholder for receive generator, which will be + # created in __iter__ + self.recv_gen = None # ------------ @@ -123,8 +125,8 @@ class Component(object): """ try: # The process title so you can watch it in top, ps. - setproctitle(self.generator.__class__.__name__) self.prefix = "FORK-" + setproctitle(self.get_id) log.info("Start %r" % self) log.info("Pid %s" % os.getpid()) @@ -134,14 +136,15 @@ class Component(object): self.signal_ready() self.lock_ready() - self.wait_ready() - - # ----------------------- - # YOU SHALL NOT PASS!!!!! - # ----------------------- - # ... until the monitor signals GO + msg = None for event in self.generator: + + if hasattr(event, 'dt') and event.dt == 'DONE': + continue + + self.wait_ready() + self.heartbeat() msg = self.frame(event) self.out_socket.send(msg) @@ -163,9 +166,6 @@ class Component(object): def create_recv_gen(self): try: - self.open(send=False) - self.signal_ready() - self.lock_ready() # return the generator return self.loop_recv() except Exception as exc: @@ -175,8 +175,12 @@ class Component(object): def loop_recv(self): try: + self.open(send=False) + self.signal_ready() + self.lock_ready() + # we block on ready here until monitor sends the GO - self.wait_ready() + # self.wait_ready() for event in self.gen_from_poller(self.poll, self.in_socket, self.unframe): yield event @@ -189,7 +193,10 @@ class Component(object): def gen_from_poller(self, poller, in_socket, unframe): while True: - socks = dict(poller.poll(0)) + # Since we will yield None to avoid blocking, we need + # to have a small delay to give the poller a chance + # to receive a message from upstream. + socks = dict(poller.poll(100)) self.heartbeat() if socks.get(in_socket) == zmq.POLLIN: message = in_socket.recv() @@ -198,6 +205,8 @@ class Component(object): else: event = unframe(message) yield event + else: + yield def handle_exception(self, exc, re_raise=False): if isinstance(exc, KillSignal): @@ -215,6 +224,8 @@ class Component(object): return self def next(self): + if not self.recv_gen: + self.recv_gen = self.create_recv_gen() return self.recv_gen.next() # ---------------------------- diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 32035492..f6434918 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -17,9 +17,9 @@ def merge(stream_in, tnfm_ids): and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. """ - + assert isinstance(tnfm_ids, list) - + # Set up an internal queue for each expected source. tnfms = {} for id in tnfm_ids: @@ -36,7 +36,7 @@ def merge(stream_in, tnfm_ids): id = message.tnfm_id assert id in tnfm_ids, \ "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) - + tnfms[id].append(message) # Only pop messages when we have a pending message from @@ -58,13 +58,13 @@ def merge_one(sources): event_fields = ndict() for key, queue in sources.iteritems(): - + # Add transform value to the transforms dict. message = queue.popleft() event_fields[message.tnfm_id] = message.tnfm_value del message['tnfm_id'] del message['tnfm_value'] - + # Merge any remaining fields into the event dict. event_fields.merge(message) return event_fields diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 564284b5..2883733a 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -53,16 +53,16 @@ class StatefulTransform(object): "Stateful transform requires a class." assert tnfm_class.__dict__.has_key('update'), \ "Stateful transform requires the class to have an update method" - + self.forward_all = tnfm_class.__dict__.get('FORWARDER', False) self.update_in_place = tnfm_class.__dict__.get('UPDATER', False) # You can't be both a forwarded and an updater. assert not all([self.forward_all, self.update_in_place]) - + # Create an instance of our transform class. self.state = tnfm_class(*args, **kwargs) - + # Create the string associated with this generator's output. self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) @@ -76,7 +76,11 @@ class StatefulTransform(object): # IMPORTANT: Messages may contain pointers that are shared with # other streams, so we only manipulate copies. for message in stream_in: - + # allow upstream generators to yield None to avoid + # blocking. + if message == None: + continue + assert_sort_unframe_protocol(message) message_copy = deepcopy(message) @@ -90,7 +94,7 @@ class StatefulTransform(object): out_message.tnfm_id = self.namestring out_message.tnfm_value = tnfm_value yield out_message - + # Our expectation is that the transform simply updated the # message it was passed. Useful for chaining together # multiple transforms, e.g. TransactionSimulator/PerformanceTracker. diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index f979e063..071ce5fc 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -49,7 +49,9 @@ def roundrobin(sources, namestrings): for namestring, source in mapping.iteritems(): try: message = source.next() - yield message + # allow sources to yield None to avoid blocking. + if message: + yield message except StopIteration: yield done_message(namestring) del mapping[namestring]