From a6ce57ef4f92364fdfa23bbfcdfe4687ddddf8e6 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Wed, 30 Jan 2013 16:23:40 -0500 Subject: [PATCH] Removes code branching on sequential/merged flags in StatefulTransform. So that the unit tests exercise the same transform logic as what is executed a TradingAlgorithm object. --- tests/test_transforms.py | 14 ++++++----- zipline/gens/composites.py | 7 ------ zipline/transforms/utils.py | 49 +++---------------------------------- 3 files changed, 11 insertions(+), 59 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 233331c6..225f1b9d 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -183,7 +183,7 @@ class TestFinanceTransforms(TestCase): transformed = list(vwap.transform(self.source)) # Output values - tnfm_vals = [message.tnfm_value for message in transformed] + tnfm_vals = [message[vwap.get_hash()] for message in transformed] # "Hand calculated" values. expected = [ (10.0 * 100) / 100.0, @@ -202,7 +202,7 @@ class TestFinanceTransforms(TestCase): returns = Returns(1) transformed = list(returns.transform(self.source)) - tnfm_vals = [message.tnfm_value for message in transformed] + tnfm_vals = [message[returns.get_hash()] for message in transformed] # No returns for the first event because we don't have a # previous close. @@ -226,7 +226,7 @@ class TestFinanceTransforms(TestCase): returns = StatefulTransform(Returns, 2) transformed = list(returns.transform(self.source)) - tnfm_vals = [message.tnfm_value for message in transformed] + tnfm_vals = [message[returns.get_hash()] for message in transformed] expected = [ 0.0, @@ -248,8 +248,10 @@ class TestFinanceTransforms(TestCase): transformed = list(mavg.transform(self.source)) # Output values. - tnfm_prices = [message.tnfm_value.price for message in transformed] - tnfm_volumes = [message.tnfm_value.volume for message in transformed] + tnfm_prices = [message[mavg.get_hash()].price + for message in transformed] + tnfm_volumes = [message[mavg.get_hash()].volume + for message in transformed] # "Hand-calculated" values expected_prices = [ @@ -289,7 +291,7 @@ class TestFinanceTransforms(TestCase): transformed = list(stddev.transform(self.source)) - vals = [message.tnfm_value for message in transformed] + vals = [message[stddev.get_hash()] for message in transformed] expected = [ None, diff --git a/zipline/gens/composites.py b/zipline/gens/composites.py index dba86587..3be5a605 100644 --- a/zipline/gens/composites.py +++ b/zipline/gens/composites.py @@ -39,13 +39,6 @@ def sequential_transforms(stream_in, *transforms): Each transform application will add a new entry indexed to the transform's hash string. """ - - assert isinstance(transforms, (list, tuple)) - - for tnfm in transforms: - tnfm.sequential = True - tnfm.merged = False - # Recursively apply all transforms to the stream. stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream), transforms, diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index abf2223a..7abe4158 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -85,12 +85,6 @@ class StatefulTransform(object): # behavior if we are being fed to merged_transforms. self.passthrough = hasattr(tnfm_class, 'PASSTHROUGH') - # Flags specifying how to append the calculated value. - # Merged is the default for ease of testing, but we use sequential - # in production. - self.sequential = False - self.merged = True - # Create an instance of our transform class. if isinstance(tnfm_class, TransformMeta): # Classes derived TransformMeta have their __call__ @@ -130,48 +124,11 @@ class StatefulTransform(object): assert_sort_unframe_protocol(message) - # This flag is set by by merged_transforms to ensure - # isolation of messages. - if self.merged: - message = deepcopy(message) - tnfm_value = self.state.update(message) - # PASSTHROUGH flag means we want to keep all original - # values, plus append tnfm_id and tnfm_value. Used for - # preserving the original event fields when our output - # will be fed into a merge. Currently only Passthrough - # uses this flag. - if self.passthrough and self.merged: - out_message = message - out_message.tnfm_id = self.namestring - out_message.tnfm_value = tnfm_value - yield out_message - - # If the merged flag is set, we create a new message - # containing just the tnfm_id, the event's datetime, and - # the calculated tnfm_value. This is the default behavior - # for a non-passthrough transform being fed into a merge. - elif self.merged: - out_message = TransformMessage() - out_message.tnfm_id = self.namestring - out_message.tnfm_value = tnfm_value - out_message.dt = message.dt - yield out_message - - # Sequential flag should be used to add a single new - # key-value pair to the event. The new key is this - # transform's namestring, and its value is the value - # returned by state.update(event). This is almost - # identical to the behavior of FORWARDER, except we - # compress the two calculated values (tnfm_id, and - # tnfm_value) into a single field. This mode is used by - # the sequential_transforms composite and is the default - # if no behavior is specified by the internal state class. - elif self.sequential: - out_message = message - out_message[self.namestring] = tnfm_value - yield out_message + out_message = message + out_message[self.namestring] = tnfm_value + yield out_message log.info('Finished StatefulTransform [%s]' % self.get_hash())