diff --git a/tests/test_components.py b/tests/test_components.py index 896f4afc..54049723 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -1,13 +1,20 @@ import zmq import pytz +from pprint import pformat as pf from datetime import datetime, timedelta from unittest2 import TestCase from collections import defaultdict -from zipline.gens.composites import date_sorted_sources +from zipline.gens.composites import date_sorted_sources, merged_transforms from zipline.finance.trading import SIMULATION_STYLE from zipline.core.devsimulator import AddressAllocator +from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform +from zipline.gens.tradesimulation import TradeSimulationClient as tsc + +from zipline.utils.factory import create_trading_environment +from zipline.test_algorithms import TestAlgorithm + from zipline.utils.test_utils import ( setup_logger, @@ -19,7 +26,12 @@ from zipline.utils.test_utils import ( from zipline.core import Component from zipline.protocol import ( DATASOURCE_FRAME, - DATASOURCE_UNFRAME + DATASOURCE_UNFRAME, + FEED_FRAME, + FEED_UNFRAME, + MERGE_FRAME, + MERGE_UNFRAME, + SIMULATION_STYLE ) from zipline.gens.tradegens import SpecificEquityTrades @@ -65,9 +77,7 @@ class ComponentTestCase(TestCase): } trade_gen = SpecificEquityTrades(*args_a, **kwargs_a) - monitor.add_to_topology(trade_gen.get_hash()) - launch_monitor(monitor) comp_a = Component( trade_gen, @@ -77,9 +87,14 @@ class ComponentTestCase(TestCase): DATASOURCE_UNFRAME ) + launch_monitor(monitor) + for event in comp_a: log.info(event) + # wait for the sending process to exit + comp_a.proc.join() + def test_sort(self): monitor = create_monitor(allocator) @@ -97,7 +112,6 @@ class ComponentTestCase(TestCase): 'count' : count } trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a) - monitor.add_to_topology(trade_gen_a.get_hash()) #Set up source b. Two minutes between events. 
args_b = tuple() @@ -109,7 +123,6 @@ class ComponentTestCase(TestCase): 'count' : count } trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b) - monitor.add_to_topology(trade_gen_b.get_hash()) #Set up source c. Three minutes between events. args_c = tuple() @@ -122,9 +135,7 @@ class ComponentTestCase(TestCase): } trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c) - monitor.add_to_topology(trade_gen_c.get_hash()) - launch_monitor(monitor) comp_a = Component( trade_gen_a, @@ -154,6 +165,8 @@ class ComponentTestCase(TestCase): sorted_out = date_sorted_sources(*sources) + launch_monitor(monitor) + prev = None sort_count = 0 for msg in sorted_out: @@ -164,3 +177,160 @@ class ComponentTestCase(TestCase): sort_count += 1 self.assertEqual(count*3, sort_count) + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + comp_c.proc.join() + + + def test_full(self): + monitor = create_monitor(allocator) + + filter = [2,3] + #Set up source a. Six hours between events. + args_a = tuple() + kwargs_a = { + 'count' : 325, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(hours = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Five minutes between events. 
+ args_b = tuple() + kwargs_b = { + 'count' : 7500, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + # ------------------------ + # Run sources in dedicated processes + comp_a = Component( + source_a, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + source_a.get_hash() + ) + + comp_b = Component( + source_b, + monitor, + allocator.lease(1)[0], + DATASOURCE_FRAME, + DATASOURCE_UNFRAME, + source_b.get_hash() + ) + + # Date sort the sources, and run the sort in a dedicated + # process + sources = [comp_a, comp_b] + + sorted_out = date_sorted_sources(*sources) + + #launch_monitor(monitor) + #import nose.tools; nose.tools.set_trace() + #for feed_msg in sorted_out: + # log.info(pf(feed_msg)) + + #return + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME, + "sort" + ) + + + passthrough = StatefulTransform(Passthrough) + mavg_price = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) + + merged_gen = merged_transforms(sorted, passthrough, mavg_price) + + merged = Component( + merged_gen, + monitor, + allocator.lease(1)[0], + MERGE_FRAME, + MERGE_UNFRAME, + "merge" + ) + + algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) + environment = create_trading_environment(year = 2012) + style = SIMULATION_STYLE.FIXED_SLIPPAGE + + trading_client = tsc(algo, environment, style) + + launch_monitor(monitor) + for message in trading_client.simulate(merged): + log.info(pf(message)) + + + # wait for processes to finish + comp_a.proc.join() + comp_b.proc.join() + sorted.proc.join() + merged.proc.join() + return + + + + def test_compound(self): + monitor = create_monitor(allocator) + + filter = [2,3] + #Set up source a. Six hours between events. 
+ args_a = tuple() + kwargs_a = { + 'count' : 325, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(hours = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. Five minutes between events. + args_b = tuple() + kwargs_b = { + 'count' : 7500, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + sorted_out = date_sorted_sources(source_a, source_b) + + sorted = Component( + sorted_out, + monitor, + allocator.lease(1)[0], + FEED_FRAME, + FEED_UNFRAME + ) + + launch_monitor(monitor) + + for event in sorted: + log.info(event) + + + sorted.proc.join() diff --git a/zipline/core/component.py b/zipline/core/component.py index 313e76e5..c9ce56e6 100644 --- a/zipline/core/component.py +++ b/zipline/core/component.py @@ -51,7 +51,8 @@ class Component(object): monitor, socket_uri, frame, - unframe + unframe, + component_id ): # ----------------- @@ -59,7 +60,7 @@ class Component(object): # ----------------- self.generator = generator self.frame = frame - self.component_id = self.generator.get_hash() + self.component_id = component_id # lock for waiting on monitor "GO" self.waiting = None @@ -99,15 +100,16 @@ class Component(object): # first, start the generator in its own process. Once # Monitor says "go", Events from the generator will be # FRAME'd and PUSH'd to self.socket_uri. 
- proc = multiprocessing.Process( - target=self.loop_send - ) - proc.start() + monitor.add_to_topology(self.component_id) - # ------------ - # Message Receiver/Generator - # ------------ - self.recv_gen = self.create_recv_gen() + self.proc = multiprocessing.Process( + target=self.loop_send + ) + self.proc.start() + + # Placeholder for the receive generator, which is + # created lazily on the first call to next() + self.recv_gen = None # ------------ @@ -123,8 +125,8 @@ class Component(object): """ try: # The process title so you can watch it in top, ps. - setproctitle(self.generator.__class__.__name__) self.prefix = "FORK-" + setproctitle(self.get_id) log.info("Start %r" % self) log.info("Pid %s" % os.getpid()) @@ -134,14 +136,15 @@ class Component(object): self.signal_ready() self.lock_ready() - self.wait_ready() - - # ----------------------- - # YOU SHALL NOT PASS!!!!! - # ----------------------- - # ... until the monitor signals GO + msg = None for event in self.generator: + + if hasattr(event, 'dt') and event.dt == 'DONE': + continue + + self.wait_ready() + self.heartbeat() msg = self.frame(event) self.out_socket.send(msg) @@ -163,9 +166,6 @@ class Component(object): def create_recv_gen(self): try: - self.open(send=False) - self.signal_ready() - self.lock_ready() # return the generator return self.loop_recv() except Exception as exc: @@ -175,8 +175,12 @@ class Component(object): def loop_recv(self): try: + self.open(send=False) + self.signal_ready() + self.lock_ready() + # we block on ready here until monitor sends the GO - self.wait_ready() + # self.wait_ready() for event in self.gen_from_poller(self.poll, self.in_socket, self.unframe): yield event @@ -189,7 +193,10 @@ class Component(object): def gen_from_poller(self, poller, in_socket, unframe): while True: - socks = dict(poller.poll(0)) + # Since we will yield None to avoid blocking, we need + # to have a small delay to give the poller a chance + # to receive a message from upstream. 
+ socks = dict(poller.poll(100)) self.heartbeat() if socks.get(in_socket) == zmq.POLLIN: message = in_socket.recv() @@ -198,6 +205,8 @@ class Component(object): else: event = unframe(message) yield event + else: + yield def handle_exception(self, exc, re_raise=False): if isinstance(exc, KillSignal): @@ -215,6 +224,8 @@ class Component(object): return self def next(self): + if not self.recv_gen: + self.recv_gen = self.create_recv_gen() return self.recv_gen.next() # ---------------------------- diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index 32035492..f6434918 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -17,9 +17,9 @@ def merge(stream_in, tnfm_ids): and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. """ - + assert isinstance(tnfm_ids, list) - + # Set up an internal queue for each expected source. tnfms = {} for id in tnfm_ids: @@ -36,7 +36,7 @@ def merge(stream_in, tnfm_ids): id = message.tnfm_id assert id in tnfm_ids, \ "Message from unexpected tnfm: %s, %s" % (id, tnfm_ids) - + tnfms[id].append(message) # Only pop messages when we have a pending message from @@ -58,13 +58,13 @@ def merge_one(sources): event_fields = ndict() for key, queue in sources.iteritems(): - + # Add transform value to the transforms dict. message = queue.popleft() event_fields[message.tnfm_id] = message.tnfm_value del message['tnfm_id'] del message['tnfm_value'] - + # Merge any remaining fields into the event dict. event_fields.merge(message) return event_fields diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 564284b5..2883733a 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -53,16 +53,16 @@ class StatefulTransform(object): "Stateful transform requires a class." 
assert tnfm_class.__dict__.has_key('update'), \ "Stateful transform requires the class to have an update method" - + self.forward_all = tnfm_class.__dict__.get('FORWARDER', False) self.update_in_place = tnfm_class.__dict__.get('UPDATER', False) # You can't be both a forwarded and an updater. assert not all([self.forward_all, self.update_in_place]) - + # Create an instance of our transform class. self.state = tnfm_class(*args, **kwargs) - + # Create the string associated with this generator's output. self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs) @@ -76,7 +76,11 @@ class StatefulTransform(object): # IMPORTANT: Messages may contain pointers that are shared with # other streams, so we only manipulate copies. for message in stream_in: - + # allow upstream generators to yield None to avoid + # blocking. + if message == None: + continue + assert_sort_unframe_protocol(message) message_copy = deepcopy(message) @@ -90,7 +94,7 @@ class StatefulTransform(object): out_message.tnfm_id = self.namestring out_message.tnfm_value = tnfm_value yield out_message - + # Our expectation is that the transform simply updated the # message it was passed. Useful for chaining together # multiple transforms, e.g. TransactionSimulator/PerformanceTracker. diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index f979e063..071ce5fc 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -49,7 +49,9 @@ def roundrobin(sources, namestrings): for namestring, source in mapping.iteritems(): try: message = source.next() - yield message + # allow sources to yield None to avoid blocking. + if message: + yield message except StopIteration: yield done_message(namestring) del mapping[namestring]