Removes code branching on sequential/merged flags in StatefulTransform.

So that the unit tests exercise the same transform logic as what
is executed a TradingAlgorithm object.
This commit is contained in:
Eddie Hebert
2013-01-30 16:23:40 -05:00
parent 7443b0c602
commit a6ce57ef4f
3 changed files with 11 additions and 59 deletions
+8 -6
View File
@@ -183,7 +183,7 @@ class TestFinanceTransforms(TestCase):
transformed = list(vwap.transform(self.source))
# Output values
tnfm_vals = [message.tnfm_value for message in transformed]
tnfm_vals = [message[vwap.get_hash()] for message in transformed]
# "Hand calculated" values.
expected = [
(10.0 * 100) / 100.0,
@@ -202,7 +202,7 @@ class TestFinanceTransforms(TestCase):
returns = Returns(1)
transformed = list(returns.transform(self.source))
tnfm_vals = [message.tnfm_value for message in transformed]
tnfm_vals = [message[returns.get_hash()] for message in transformed]
# No returns for the first event because we don't have a
# previous close.
@@ -226,7 +226,7 @@ class TestFinanceTransforms(TestCase):
returns = StatefulTransform(Returns, 2)
transformed = list(returns.transform(self.source))
tnfm_vals = [message.tnfm_value for message in transformed]
tnfm_vals = [message[returns.get_hash()] for message in transformed]
expected = [
0.0,
@@ -248,8 +248,10 @@ class TestFinanceTransforms(TestCase):
transformed = list(mavg.transform(self.source))
# Output values.
tnfm_prices = [message.tnfm_value.price for message in transformed]
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
tnfm_prices = [message[mavg.get_hash()].price
for message in transformed]
tnfm_volumes = [message[mavg.get_hash()].volume
for message in transformed]
# "Hand-calculated" values
expected_prices = [
@@ -289,7 +291,7 @@ class TestFinanceTransforms(TestCase):
transformed = list(stddev.transform(self.source))
vals = [message.tnfm_value for message in transformed]
vals = [message[stddev.get_hash()] for message in transformed]
expected = [
None,
-7
View File
@@ -39,13 +39,6 @@ def sequential_transforms(stream_in, *transforms):
Each transform application will add a new entry indexed to the transform's
hash string.
"""
assert isinstance(transforms, (list, tuple))
for tnfm in transforms:
tnfm.sequential = True
tnfm.merged = False
# Recursively apply all transforms to the stream.
stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream),
transforms,
+3 -46
View File
@@ -85,12 +85,6 @@ class StatefulTransform(object):
# behavior if we are being fed to merged_transforms.
self.passthrough = hasattr(tnfm_class, 'PASSTHROUGH')
# Flags specifying how to append the calculated value.
# Merged is the default for ease of testing, but we use sequential
# in production.
self.sequential = False
self.merged = True
# Create an instance of our transform class.
if isinstance(tnfm_class, TransformMeta):
# Classes derived TransformMeta have their __call__
@@ -130,48 +124,11 @@ class StatefulTransform(object):
assert_sort_unframe_protocol(message)
# This flag is set by by merged_transforms to ensure
# isolation of messages.
if self.merged:
message = deepcopy(message)
tnfm_value = self.state.update(message)
# PASSTHROUGH flag means we want to keep all original
# values, plus append tnfm_id and tnfm_value. Used for
# preserving the original event fields when our output
# will be fed into a merge. Currently only Passthrough
# uses this flag.
if self.passthrough and self.merged:
out_message = message
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
yield out_message
# If the merged flag is set, we create a new message
# containing just the tnfm_id, the event's datetime, and
# the calculated tnfm_value. This is the default behavior
# for a non-passthrough transform being fed into a merge.
elif self.merged:
out_message = TransformMessage()
out_message.tnfm_id = self.namestring
out_message.tnfm_value = tnfm_value
out_message.dt = message.dt
yield out_message
# Sequential flag should be used to add a single new
# key-value pair to the event. The new key is this
# transform's namestring, and its value is the value
# returned by state.update(event). This is almost
# identical to the behavior of FORWARDER, except we
# compress the two calculated values (tnfm_id, and
# tnfm_value) into a single field. This mode is used by
# the sequential_transforms composite and is the default
# if no behavior is specified by the internal state class.
elif self.sequential:
out_message = message
out_message[self.namestring] = tnfm_value
yield out_message
out_message = message
out_message[self.namestring] = tnfm_value
yield out_message
log.info('Finished StatefulTransform [%s]' % self.get_hash())