diff --git a/tests/test_feed.py b/tests/test_feed.py deleted file mode 100644 index 21e8afb3..00000000 --- a/tests/test_feed.py +++ /dev/null @@ -1,233 +0,0 @@ -from unittest2 import TestCase -from itertools import cycle, chain -from datetime import datetime, timedelta -from collections import deque - -from zipline import ndict -from zipline.gens.sort import \ - date_sort, \ - ready, \ - done, \ - queue_is_ready,\ - queue_is_done -from zipline.gens.utils import hash_args, alternate -from zipline.gens.tradegens import date_gen, SpecificEquityTrades -from zipline.gens.composites import date_sorted_sources - -import zipline.protocol as zp - -class HelperTestCase(TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_individual_queue_logic(self): - queue = deque() - # Empty queues are neither done nor ready. - assert not queue_is_ready(queue) - assert not queue_is_done(queue) - - queue.append(to_dt('foo')) - assert queue_is_ready(queue) - assert not queue_is_done(queue) - - - queue.appendleft(to_dt('DONE')) - assert queue_is_ready(queue) - - # Checking done when we have a message after done will trip an assert. - self.assertRaises(AssertionError, queue_is_done, queue) - - queue.pop() - assert queue_is_ready(queue) - assert queue_is_done(queue) - - def test_pop_logic(self): - sources = {} - ids = ['a', 'b', 'c'] - for id in ids: - sources[id] = deque() - - assert not ready(sources) - assert not done(sources) - - # All sources must have a message to be ready/done - sources['a'].append(to_dt("datetime")) - assert not ready(sources) - assert not done(sources) - sources['a'].pop() - - for id in ids: - sources[id].append(to_dt("datetime")) - - assert ready(sources) - assert not done(sources) - - for id in ids: - sources[id].appendleft(to_dt("DONE")) - - # ["DONE", message] will trip an assert in queue_is_done. - assert ready(sources) - self.assertRaises(AssertionError, done, sources) - - for id in ids: - sources[id].pop() - - assert ready(sources) - assert done(sources) - -class DateSortTestCase(TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def run_date_sort(self, events, expected, source_ids): - """ - Take a list of events, their source_ids, and an expected sorting. - Assert that date_sort's output agrees with expected. - """ - sort_gen = date_sort(events, source_ids) - l = list(sort_gen) - assert l == expected - - def test_single_source(self): - source_ids = ['a'] - # 100 events, increasing by a minute at a time. - type = zp.DATASOURCE_TYPE.TRADE - dates = list(date_gen(count = 100)) - dates.append("DONE") - - # [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)] - event_args = zip(cycle(source_ids), iter(dates), cycle([type])) - - # Turn event_args into proper events. - events = [mock_data_unframe(*args) for args in event_args] - - # We don't expected Feed to yield the last event. - expected = events[:-1] - - event_gen = (e for e in events) - - self.run_date_sort(event_gen, expected, source_ids) - - def test_multi_source(self): - source_ids = ['a', 'b'] - type = zp.DATASOURCE_TYPE.TRADE - - # Set up source 'a'. Outputs 20 events with 2 minute deltas. - delta_a = timedelta(minutes = 2) - dates_a = list(date_gen(delta = delta_a, count = 20)) - dates_a.append("DONE") - - events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type])) - events_a = [mock_data_unframe(*args) for args in events_a_args] - - # Set up source 'b'. Outputs 10 events with 1 minute deltas. - delta_b = timedelta(minutes = 1) - dates_b = list(date_gen(delta = delta_b, count = 10)) - dates_b.append("DONE") - - events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type])) - events_b = [mock_data_unframe(*args) for args in events_b_args] - - # The expected output is all non-DONE events in both a and b, - # sorted first by dt and then by source_id. - non_dones = events_a[:-1] + events_b[:-1] - expected = sorted(non_dones, compare_by_dt_source_id) - - # Alternating between a and b. - interleaved = alternate(iter(events_a), iter(events_b)) - self.run_date_sort(interleaved, expected, source_ids) - - # All of a, then all of b. - - sequential = chain(iter(events_a), iter(events_b)) - self.run_date_sort(sequential, expected, source_ids) - - def test_sorted_sources(self): - - filter = [1,2] - #Set up source a. One hour between events. - args_a = tuple() - kwargs_a = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(hours = 1), - 'filter' : filter - } - #Set up source b. One day between events. - args_b = tuple() - kwargs_b = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(days = 1), - 'filter' : filter - } - #Set up source c. One minute between events. - args_c = tuple() - kwargs_c = {'sids' : [1,2,3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), - 'filter' : filter - } - # Set up source d. This should produce no events because the - # internal sids don't match the filter. - args_d = tuple() - kwargs_d = {'sids' : [3,4], - 'start' : datetime(2012,6,6,0), - 'delta' : timedelta(minutes = 1), - 'filter' : filter - } - - sources = (SpecificEquityTrades,) * 4 - source_args = (args_a, args_b, args_c, args_d) - source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d) - - # Generate our expected source_ids. - zip_args = zip(source_args, source_kwargs) - expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs) - for args, kwargs in zip_args] - - # Pipe our sources into sort. - sort_out = date_sorted_sources(sources, source_args, source_kwargs) - - # Read all the values from sort and assert that they arrive in - # the correct sorting with the expected hash values. - to_list = list(sort_out) - copy = to_list[:] - for e in to_list: - # All events should match one of our expected source_ids. - assert e.source_id in expected_ids - # But none of them should match source_d. - assert e.source_id != hash_args(*args_d, **kwargs_d) - - expected = sorted(copy, compare_by_dt_source_id) - assert to_list == expected - -def mock_data_unframe(source_id, dt, type): - event = ndict() - event.source_id = source_id - event.dt = dt - event.type = type - return event - -def to_dt(val): - return ndict({'dt': val}) - -def compare_by_dt_source_id(x,y): - if x.dt < y.dt: - return -1 - elif x.dt > y.dt: - return 1 - - elif x.source_id < y.source_id: - return -1 - elif x.source_id > y.source_id: - return 1 - - else: - return 0 diff --git a/tests/test_sorting.py b/tests/test_sorting.py new file mode 100644 index 00000000..966dec3f --- /dev/null +++ b/tests/test_sorting.py @@ -0,0 +1,257 @@ +import pytz + +from unittest2 import TestCase +from itertools import cycle, chain, izip, izip_longest +from datetime import datetime, timedelta +from collections import deque + +from zipline import ndict +from zipline.gens.sort import \ + date_sort, \ + ready, \ + done, \ + queue_is_ready,\ + queue_is_done +from zipline.gens.utils import hash_args, alternate, done_message +from zipline.gens.tradegens import date_gen, SpecificEquityTrades +from zipline.gens.composites import date_sorted_sources + +import zipline.protocol as zp + +class HelperTestCase(TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_individual_queue_logic(self): + queue = deque() + # Empty queues are neither done nor ready. + assert not queue_is_ready(queue) + assert not queue_is_done(queue) + + queue.append(to_dt('foo')) + assert queue_is_ready(queue) + assert not queue_is_done(queue) + + + queue.appendleft(to_dt('DONE')) + assert queue_is_ready(queue) + + # Checking done when we have a message after done will trip an assert. + self.assertRaises(AssertionError, queue_is_done, queue) + + queue.pop() + assert queue_is_ready(queue) + assert queue_is_done(queue) + + def test_pop_logic(self): + sources = {} + ids = ['a', 'b', 'c'] + for id in ids: + sources[id] = deque() + + assert not ready(sources) + assert not done(sources) + + # All sources must have a message to be ready/done + sources['a'].append(to_dt("datetime")) + assert not ready(sources) + assert not done(sources) + sources['a'].pop() + + for id in ids: + sources[id].append(to_dt("datetime")) + + assert ready(sources) + assert not done(sources) + + for id in ids: + sources[id].appendleft(to_dt("DONE")) + + # ["DONE", message] will trip an assert in queue_is_done. + assert ready(sources) + self.assertRaises(AssertionError, done, sources) + + for id in ids: + sources[id].pop() + + assert ready(sources) + assert done(sources) + +class DateSortTestCase(TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def run_date_sort(self, event_stream, expected, source_ids): + """ + Take a list of events, their source_ids, and an expected sorting. + Assert that date_sort's output agrees with expected. + """ + sort_out = date_sort(event_stream, source_ids) + for m1, m2 in izip_longest(sort_out, expected): + assert m1 == m2 + + def test_single_source(self): + + # Just using the built-in defaults. See + # zipline/gens/tradegens.py + source = SpecificEquityTrades() + expected = list(source) + source.rewind() + # The raw source doesn't handle done messaging, so we need to + # append a done message for sort to work properly. + with_done = chain(source, [done_message(source.get_hash())]) + self.run_date_sort(with_done, expected, [source.get_hash()]) + + def test_multi_source(self): + + filter = [2,3] + args_a = tuple() + kwargs_a = { + 'count' : 100, + 'sids' : [1,2,3], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 6), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + args_b = tuple() + kwargs_b = { + 'count' : 100, + 'sids' : [2,3,4], + 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 5), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + all_events = list(chain(source_a, source_b)) + + # The expected output is all events, sorted by dt with + # source_id as a tiebreaker. + expected = sorted(all_events, comp) + source_ids = [source_a.get_hash(), source_b.get_hash()] + + # Generating the events list consumes the sources. Rewind them + # for testing. + source_a.rewind() + source_b.rewind() + + # Append a done message to each source. + with_done_a = chain(source_a, [done_message(source_a.get_hash())]) + with_done_b = chain(source_b, [done_message(source_b.get_hash())]) + + interleaved = alternate(with_done_a, with_done_b) + + # Test sort with alternating messages from source_a and + # source_b. + self.run_date_sort(interleaved, expected, source_ids) + + source_a.rewind() + source_b.rewind() + with_done_a = chain(source_a, [done_message(source_a.get_hash())]) + with_done_b = chain(source_b, [done_message(source_b.get_hash())]) + + sequential = chain(with_done_a, with_done_b) + + # Test sort with all messages from a, followed by all messages + # from b. + + self.run_date_sort(sequential, expected, source_ids) + + + def test_sort_composite(self): + + filter = [1,2] + + #Set up source a. One hour between events. + args_a = tuple() + kwargs_a = { + 'count' : 100, + 'sids' : [1], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(hours = 1), + 'filter' : filter + } + source_a = SpecificEquityTrades(*args_a, **kwargs_a) + + #Set up source b. One day between events. + args_b = tuple() + kwargs_b = { + 'count' : 50, + 'sids' : [2], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(days = 1), + 'filter' : filter + } + source_b = SpecificEquityTrades(*args_b, **kwargs_b) + + #Set up source c. One minute between events. + args_c = tuple() + kwargs_c = { + 'count' : 150, + 'sids' : [1,2], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + source_c = SpecificEquityTrades(*args_c, **kwargs_c) + # Set up source d. This should produce no events because the + # internal sids don't match the filter. + args_d = tuple() + kwargs_d = { + 'count' : 50, + 'sids' : [3], + 'start' : datetime(2012,6,6,0), + 'delta' : timedelta(minutes = 1), + 'filter' : filter + } + source_d = SpecificEquityTrades(*args_d, **kwargs_d) + sources = [source_a, source_b, source_c, source_d] + hashes = [source.get_hash() for source in sources] + + sort_out = date_sorted_sources(*sources) + + # Read all the values from sort and assert that they arrive in + # the correct sorting with the expected hash values. + to_list = list(sort_out) + copy = to_list[:] + + # We should have 300 events (100 from a, 150 from b, 50 from c) + assert len(to_list) == 300 + + for e in to_list: + # All events should match one of our expected source_ids. + assert e.source_id in hashes + # But none of them should match source_d. + assert e.source_id != source_d.get_hash() + + # The events should be sorted by dt, with source_id as tiebreaker. + expected = sorted(copy, comp) + + assert to_list == expected + +def compare_by_dt_source_id(x,y): + if x.dt < y.dt: + return -1 + elif x.dt > y.dt: + return 1 + + elif x.source_id < y.source_id: + return -1 + elif x.source_id > y.source_id: + return 1 + + else: + return 0 + +#Alias for ease of use +comp = compare_by_dt_source_id diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index 2716ee32..4b5cd5ac 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -13,8 +13,9 @@ TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs']) def date_sorted_sources(*sources): """ - Takes an iterable of SortBundles, generating namestrings and initialized datasources - for each before piping them into a date_sort. + Takes an iterable of SortBundles, generating namestrings and + initialized datasources for each before piping them into a + date_sort. """ for source in sources: @@ -28,21 +29,21 @@ def date_sorted_sources(*sources): # one element at a time from each. stream_in = roundrobin(sources, names) - # Guarantee the flat stream will be sorted by date, using source_id as - # tie-breaker, which is fully deterministic (given deterministic string - # representation for all args/kwargs) + # Guarantee the flat stream will be sorted by date, using + # source_id as tie-breaker, which is fully deterministic (given + # deterministic string representation for all args/kwargs) return date_sort(stream_in, names) def merged_transforms(sorted_stream, *transforms): """ - A generator that takes the expected output of a date_sort, pipes it - through a given set of transforms, and runs the results throught a - merge to output a unified stream. tnfms should be a list of - pointers to generator functions. tnfm_args should be a list of - tuples, representing the arguments to be passed to each transform. - tnfm_kwargs should be a list of dictionaries representing keyword - arguments to each transform. + A generator that takes the expected output of a date_sort, pipes + it through a given set of transforms, and runs the results + through a merge to output a unified stream. tnfms should be a + list of pointers to generator functions. tnfm_args should be a + list of tuples, representing the arguments to be passed to each + transform. tnfm_kwargs should be a list of dictionaries + representing keyword arguments to each transform. """ for transform in transforms: assert isinstance(transform, StatefulTransform) @@ -62,15 +63,35 @@ def merged_transforms(sorted_stream, *transforms): # Roundrobin the outputs of our transforms to create a single flat # stream. to_merge = roundrobin(tnfm_gens, namestrings) - # Pipe the stream into merge. merged = merge(to_merge, namestrings) # Return the merged events. return merged -def zipline(sources, transforms, endpoint): - assert isinstance(sources, (list, tuple)) +def sequential_transforms(stream_in, *transforms): + """ + Apply each transform in transforms sequentially to each event in stream_in. + Each transform application will add a new entry indexed to the transform's + hash string. + """ + assert isinstance(transforms, (list, tuple)) + for tnfm in transforms: + tnfm.forward_all = False + tnfm.update_in_place = False + tnfm.append_value = True + + # Recursively apply all transforms to the stream. + stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream), + transforms, + stream_in) + return stream_out + + + + + + diff --git a/zipline/gens/examples.py b/zipline/gens/examples.py index a6a95f59..f3a0dd0b 100644 --- a/zipline/gens/examples.py +++ b/zipline/gens/examples.py @@ -1,14 +1,16 @@ import pytz +import time from time import sleep from pprint import pprint as pp from datetime import datetime, timedelta +from itertools import izip from zipline.utils.factory import create_trading_environment from zipline.test_algorithms import TestAlgorithm from zipline.gens.composites import SourceBundle, TransformBundle, \ - date_sorted_sources, merged_transforms + date_sorted_sources, merged_transforms, sequential_transforms from zipline.gens.tradegens import SpecificEquityTrades from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform from zipline.gens.tradesimulation import TradeSimulationClient as tsc @@ -18,43 +20,81 @@ import zipline.protocol as zp if __name__ == "__main__": filter = [2,3] - #Set up source a. One minute between events. + #Set up source a. Six minutes between events. args_a = tuple() kwargs_a = { - 'count' : 325, + 'count' : 1000, 'sids' : [1,2,3], 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), - 'delta' : timedelta(hours = 6), + 'delta' : timedelta(minutes = 6), 'filter' : filter } source_a = SpecificEquityTrades(*args_a, **kwargs_a) + source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a) - #Set up source b. Two minutes between events. + #Set up source b. Five minutes between events. args_b = tuple() kwargs_b = { - 'count' : 7500, + 'count' : 1000, 'sids' : [2,3,4], 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), 'delta' : timedelta(minutes = 5), 'filter' : filter } source_b = SpecificEquityTrades(*args_b, **kwargs_b) - - #Set up source c. Three minutes between events. + source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b) sorted = date_sorted_sources(source_a, source_b) + sorted_prime = date_sorted_sources( + source_a_prime, + source_b_prime + ) passthrough = StatefulTransform(Passthrough) - mavg_price = StatefulTransform(MovingAverage, timedelta(minutes = 20), ['price']) + mavg_price = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) + + passthrough_prime = StatefulTransform(Passthrough) + mavg_price_prime = StatefulTransform( + MovingAverage, + timedelta(minutes = 20), + ['price'] + ) merged = merged_transforms(sorted, passthrough, mavg_price) + start = time.time() + for message in merged: + assert 1 + 1 == 2 + stop = time.time() + merge_time = stop - start + print "Merge time: %s" % str(merge_time) + + sequential = sequential_transforms( + sorted_prime, + passthrough_prime, + mavg_price_prime + ) - algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) - environment = create_trading_environment(year = 2012) - style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE + start = time.time() + for message in sequential: + assert 1 + 1 == 2 + stop = time.time() + seq_time = stop - start + print "Sequential time: %s" % str(seq_time) + print "Merge/Seq: %s" % (str(merge_time/seq_time)) - trading_client = tsc(algo, environment, style) + +# merged = merged_transforms(sorted, passthrough, mavg_price) - for message in trading_client.simulate(merged): - pp(message) + # algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3]) +# environment = create_trading_environment(year = 2012) +# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE + +# trading_client = tsc(algo, environment, style) + +# for message in trading_client.simulate(merged): +# pp(message) diff --git a/zipline/gens/merge.py b/zipline/gens/merge.py index f6434918..c4afb1b4 100644 --- a/zipline/gens/merge.py +++ b/zipline/gens/merge.py @@ -11,8 +11,8 @@ from itertools import repeat def merge(stream_in, tnfm_ids): """ - A generator that takes a generator and a list of source_ids. We - maintain an internal queue for each id in source_ids. Once we + A generator that takes a generator and a list of transform ids. We + maintain an internal queue for each id in tnfm_ids. Once we have a message from every queue, we pop an event from each queue and merge them together into an event. We raise an error if we do not receive the same number of events from all sources. @@ -54,9 +54,8 @@ def merge(stream_in, tnfm_ids): yield done_message('Merge') def merge_one(sources): - dict_primer = zip(sources.keys(), repeat(None)) - event_fields = ndict() + event_fields = ndict() for key, queue in sources.iteritems(): # Add transform value to the transforms dict. diff --git a/zipline/gens/sort.py b/zipline/gens/sort.py index f6ff7a5e..3ff5ee3f 100644 --- a/zipline/gens/sort.py +++ b/zipline/gens/sort.py @@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids): have messages pending from all sources, we pull the earliest message and yield it. """ - assert isinstance(source_ids, (list, tuple)) # Set up an internal queue for each expected source. @@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids): message = pop_oldest(sources) assert_sort_protocol(message) yield message - + # We should have only a done message left in each queue. for queue in sources.itervalues(): assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue diff --git a/zipline/gens/transform.py b/zipline/gens/transform.py index 2883733a..36e15689 100644 --- a/zipline/gens/transform.py +++ b/zipline/gens/transform.py @@ -56,9 +56,12 @@ class StatefulTransform(object): self.forward_all = tnfm_class.__dict__.get('FORWARDER', False) self.update_in_place = tnfm_class.__dict__.get('UPDATER', False) + self.append_value = tnfm_class.__dict__.get('APPENDER', False) - # You can't be both a forwarded and an updater. - assert not all([self.forward_all, self.update_in_place]) + # You only one special behavior mode can be set. + assert sum(map(int, [self.forward_all, + self.update_in_place, + self.append_value])) <= 1 # Create an instance of our transform class. self.state = tnfm_class(*args, **kwargs) @@ -75,11 +78,15 @@ class StatefulTransform(object): def _gen(self, stream_in): # IMPORTANT: Messages may contain pointers that are shared with # other streams, so we only manipulate copies. + for message in stream_in: + # allow upstream generators to yield None to avoid # blocking. if message == None: continue + + #TODO: refactor this to avoid unnecessary copying. assert_sort_unframe_protocol(message) message_copy = deepcopy(message) @@ -87,22 +94,43 @@ class StatefulTransform(object): # Same shared pointer issue here as above. tnfm_value = self.state.update(deepcopy(message_copy)) - # If we want to keep all original values, plus append tnfm_id - # and tnfm_value. Used for Passthrough. + # FORWARDER flag means we want to keep all original + # values, plus append tnfm_id and tnfm_value. Used for + # preserving the original event fields when our output + # will be fed into a merge. if self.forward_all: out_message = message_copy out_message.tnfm_id = self.namestring out_message.tnfm_value = tnfm_value yield out_message - # Our expectation is that the transform simply updated the - # message it was passed. Useful for chaining together - # multiple transforms, e.g. TransactionSimulator/PerformanceTracker. + # UPDATER flag should be used for transforms that + # side-effectfully modify the event they are passed. + # Updated messages are passed along exactly as they are + # returned to use by our state class. Useful for chaining + # specific transforms that won't be fed to a merge. (See + # the implementation of TradeSimulationClient for example + # usage of this flag with PerformanceTracker and + # TransactionSimulator. elif self.update_in_place: yield tnfm_value + + # APPENDER flag should be used to add a single new + # key-value pair to the event. The new key is this + # transform's namestring, and it's value is the value + # returned by state.update(event). This is almost + # identical to the behavior of FORWARDER, except we + # compress the two calculated values (tnfm_id, and + # tnfm_value) into a single field. + elif self.append_value: + out_message = message_copy + out_message[self.namestring] = tnfm_value + yield out_message - # Otherwise send tnfm_id, tnfm_value, and the message - # date. Useful for transforms being piped to a merge. + # If no flags are set, we create a new message containing + # just the tnfm_id, the event's datetime, and the + # calculated tnfm_value. This is the default behavior for + # a transform being fed into a merge. else: out_message = ndict() out_message.tnfm_id = self.namestring diff --git a/zipline/gens/utils.py b/zipline/gens/utils.py index 071ce5fc..83372753 100644 --- a/zipline/gens/utils.py +++ b/zipline/gens/utils.py @@ -66,6 +66,14 @@ def hash_args(*args, **kwargs): hasher.update(combined) return hasher.hexdigest() +def sum_true(bool_iterable): + """ + Takes an iterable of boolean values and returns the number of + those values that are True. + """ + return sum(map(int, bool_iterable)) + + def assert_datasource_protocol(event): """Assert that an event meets the protocol for datasource outputs."""