mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 20:54:47 +08:00
Removes code branching on sequential/merged flags in StatefulTransform.
So that the unit tests exercise the same transform logic as what is executed a TradingAlgorithm object.
This commit is contained in:
@@ -183,7 +183,7 @@ class TestFinanceTransforms(TestCase):
|
||||
transformed = list(vwap.transform(self.source))
|
||||
|
||||
# Output values
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
tnfm_vals = [message[vwap.get_hash()] for message in transformed]
|
||||
# "Hand calculated" values.
|
||||
expected = [
|
||||
(10.0 * 100) / 100.0,
|
||||
@@ -202,7 +202,7 @@ class TestFinanceTransforms(TestCase):
|
||||
returns = Returns(1)
|
||||
|
||||
transformed = list(returns.transform(self.source))
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
tnfm_vals = [message[returns.get_hash()] for message in transformed]
|
||||
|
||||
# No returns for the first event because we don't have a
|
||||
# previous close.
|
||||
@@ -226,7 +226,7 @@ class TestFinanceTransforms(TestCase):
|
||||
returns = StatefulTransform(Returns, 2)
|
||||
|
||||
transformed = list(returns.transform(self.source))
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
tnfm_vals = [message[returns.get_hash()] for message in transformed]
|
||||
|
||||
expected = [
|
||||
0.0,
|
||||
@@ -248,8 +248,10 @@ class TestFinanceTransforms(TestCase):
|
||||
|
||||
transformed = list(mavg.transform(self.source))
|
||||
# Output values.
|
||||
tnfm_prices = [message.tnfm_value.price for message in transformed]
|
||||
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
|
||||
tnfm_prices = [message[mavg.get_hash()].price
|
||||
for message in transformed]
|
||||
tnfm_volumes = [message[mavg.get_hash()].volume
|
||||
for message in transformed]
|
||||
|
||||
# "Hand-calculated" values
|
||||
expected_prices = [
|
||||
@@ -289,7 +291,7 @@ class TestFinanceTransforms(TestCase):
|
||||
|
||||
transformed = list(stddev.transform(self.source))
|
||||
|
||||
vals = [message.tnfm_value for message in transformed]
|
||||
vals = [message[stddev.get_hash()] for message in transformed]
|
||||
|
||||
expected = [
|
||||
None,
|
||||
|
||||
@@ -39,13 +39,6 @@ def sequential_transforms(stream_in, *transforms):
|
||||
Each transform application will add a new entry indexed to the transform's
|
||||
hash string.
|
||||
"""
|
||||
|
||||
assert isinstance(transforms, (list, tuple))
|
||||
|
||||
for tnfm in transforms:
|
||||
tnfm.sequential = True
|
||||
tnfm.merged = False
|
||||
|
||||
# Recursively apply all transforms to the stream.
|
||||
stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream),
|
||||
transforms,
|
||||
|
||||
@@ -85,12 +85,6 @@ class StatefulTransform(object):
|
||||
# behavior if we are being fed to merged_transforms.
|
||||
self.passthrough = hasattr(tnfm_class, 'PASSTHROUGH')
|
||||
|
||||
# Flags specifying how to append the calculated value.
|
||||
# Merged is the default for ease of testing, but we use sequential
|
||||
# in production.
|
||||
self.sequential = False
|
||||
self.merged = True
|
||||
|
||||
# Create an instance of our transform class.
|
||||
if isinstance(tnfm_class, TransformMeta):
|
||||
# Classes derived TransformMeta have their __call__
|
||||
@@ -130,48 +124,11 @@ class StatefulTransform(object):
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
|
||||
# This flag is set by by merged_transforms to ensure
|
||||
# isolation of messages.
|
||||
if self.merged:
|
||||
message = deepcopy(message)
|
||||
|
||||
tnfm_value = self.state.update(message)
|
||||
|
||||
# PASSTHROUGH flag means we want to keep all original
|
||||
# values, plus append tnfm_id and tnfm_value. Used for
|
||||
# preserving the original event fields when our output
|
||||
# will be fed into a merge. Currently only Passthrough
|
||||
# uses this flag.
|
||||
if self.passthrough and self.merged:
|
||||
out_message = message
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# If the merged flag is set, we create a new message
|
||||
# containing just the tnfm_id, the event's datetime, and
|
||||
# the calculated tnfm_value. This is the default behavior
|
||||
# for a non-passthrough transform being fed into a merge.
|
||||
elif self.merged:
|
||||
out_message = TransformMessage()
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
out_message.dt = message.dt
|
||||
yield out_message
|
||||
|
||||
# Sequential flag should be used to add a single new
|
||||
# key-value pair to the event. The new key is this
|
||||
# transform's namestring, and its value is the value
|
||||
# returned by state.update(event). This is almost
|
||||
# identical to the behavior of FORWARDER, except we
|
||||
# compress the two calculated values (tnfm_id, and
|
||||
# tnfm_value) into a single field. This mode is used by
|
||||
# the sequential_transforms composite and is the default
|
||||
# if no behavior is specified by the internal state class.
|
||||
elif self.sequential:
|
||||
out_message = message
|
||||
out_message[self.namestring] = tnfm_value
|
||||
yield out_message
|
||||
out_message = message
|
||||
out_message[self.namestring] = tnfm_value
|
||||
yield out_message
|
||||
|
||||
log.info('Finished StatefulTransform [%s]' % self.get_hash())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user