mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 09:39:14 +08:00
pipeline through merge
This commit is contained in:
+20
-26
@@ -8,8 +8,8 @@ from zipline.gens.sort import date_sort
|
||||
from zipline.gens.merge import merge
|
||||
from zipline.gens.transform import stateful_transform
|
||||
|
||||
SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs'])
|
||||
MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs'])
|
||||
SourceBundle = namedtuple("SourceBundle", ['source', 'args', 'kwargs'])
|
||||
TransformBundle = namedtuple("TransformBundle", ['tnfm', 'args', 'kwargs'])
|
||||
|
||||
def date_sorted_sources(bundles):
|
||||
"""
|
||||
@@ -18,19 +18,19 @@ def date_sorted_sources(bundles):
|
||||
"""
|
||||
assert isinstance(bundles, (list, tuple))
|
||||
for bundle in bundles:
|
||||
assert isinstance(bundle, SortBundle)
|
||||
assert isinstance(bundle, SourceBundle)
|
||||
|
||||
# Calculate namestring hashes to pass to date_sort.
|
||||
names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
|
||||
# Pass each source its arguments.
|
||||
initialized = [bundle.source(*bundle.args, **bundle.kwargs)
|
||||
source_gens = [bundle.source(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
|
||||
# Convert the list of generators into a flat stream by pulling
|
||||
# one element at a time from each.
|
||||
stream_in = roundrobin(initialized, names)
|
||||
stream_in = roundrobin(source_gens, names)
|
||||
|
||||
# Guarantee the flat stream will be sorted by date, using source_id as
|
||||
# tie-breaker, which is fully deterministic (given deterministic string
|
||||
@@ -38,7 +38,7 @@ def date_sorted_sources(bundles):
|
||||
return date_sort(stream_in, names)
|
||||
|
||||
|
||||
def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
def merged_transforms(sorted_stream, bundles):
|
||||
"""
|
||||
A generator that takes the expected output of a date_sort, pipes it
|
||||
through a given set of transforms, and runs the results throught a
|
||||
@@ -48,36 +48,30 @@ def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
"""
|
||||
|
||||
# We should have as many sets of args as we have transforms.
|
||||
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
|
||||
# Generate expected hashes for each transform
|
||||
namestrings = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(sorted_stream, len(tnfms))
|
||||
split = tee(sorted_stream, len(bundles))
|
||||
# Package a stream copy with each bundle
|
||||
tnfms_with_streams = zip(split, bundles)
|
||||
|
||||
# Package each transform with a stream copy and set of args. Use a list
|
||||
# so that we can re-use this for calculating hashes.
|
||||
bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs))
|
||||
|
||||
bundles = tuple(bundle_gen)
|
||||
# list comprehension to create transform generators from
|
||||
# bundles
|
||||
# Convert the copies into transform streams.
|
||||
tnfm_gens = [
|
||||
stateful_transform(
|
||||
bundle.stream,
|
||||
stream_copy,
|
||||
bundle.tnfm,
|
||||
*bundle.args,
|
||||
**bundle.kwargs
|
||||
)
|
||||
for bundle in bundles]
|
||||
|
||||
# Generate expected hashes for each transform
|
||||
hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
for stream_copy, bundle in tnfms_with_streams
|
||||
]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(*tnfm_gens)
|
||||
to_merge = roundrobin(tnfm_gens, namestrings)
|
||||
|
||||
# Pipe the stream into merge.
|
||||
merged = merge(to_merge, hashes)
|
||||
return merged_transforms
|
||||
merged = merge(to_merge, namestrings)
|
||||
# Return the merged events.
|
||||
return merged
|
||||
|
||||
+50
-31
@@ -1,38 +1,57 @@
|
||||
from zipline.gens.composites import
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
from zipline.gens.composites import SourceBundle, TransformBundle, date_sorted_sources, merged_transforms
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import MovingAverage, Passthrough
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
filter = [1,2,3,4]
|
||||
#Set up source a. One hour between events.
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = ),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
kwargs_a = {
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a)
|
||||
|
||||
sources = (SpecificEquityTrades,) * 4
|
||||
source_args = (args_a, args_b, args_c, args_d)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
|
||||
|
||||
# Generate our expected source_ids.
|
||||
zip_args = zip(source_args, source_kwargs)
|
||||
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
|
||||
for args, kwargs in zip_args]
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'sids' : [2,3],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 2),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'sids' : [3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 3),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c)
|
||||
|
||||
source_bundles = (bundle_a, bundle_b, bundle_c)
|
||||
# Pipe our sources into sort.
|
||||
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
|
||||
sort_out = date_sorted_sources(source_bundles)
|
||||
|
||||
passthrough = TransformBundle(Passthrough, (), {})
|
||||
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price', 'volume']), {})
|
||||
tnfm_bundles = (passthrough, mavg_price)
|
||||
|
||||
merge_out = merged_transforms(sort_out, tnfm_bundles)
|
||||
|
||||
for message in merge_out:
|
||||
print "Event: \n", message.event
|
||||
print "Transforms: \n", message.tnfms
|
||||
|
||||
|
||||
|
||||
+25
-14
@@ -7,7 +7,7 @@ from collections import deque
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, \
|
||||
assert_merge_protocol
|
||||
|
||||
from itertools import repeat
|
||||
|
||||
def merge(stream_in, tnfm_ids):
|
||||
"""
|
||||
@@ -17,7 +17,7 @@ def merge(stream_in, tnfm_ids):
|
||||
and merge them together into an event. We raise an error if we
|
||||
do not receive the same number of events from all sources.
|
||||
"""
|
||||
|
||||
|
||||
assert isinstance(tnfm_ids, list)
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
@@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids):
|
||||
|
||||
# Process incoming streams.
|
||||
for message in stream_in:
|
||||
assert isinstance(message, tuple), \
|
||||
"Bad message in merge: %s" %message
|
||||
assert len(message) == 2
|
||||
id, value = message
|
||||
assert isinstance(message, ndict)
|
||||
assert message.has_key('tnfm_id')
|
||||
assert message.has_key('tnfm_value')
|
||||
assert message.has_key('dt')
|
||||
|
||||
id = message.tnfm_id
|
||||
assert id in tnfm_ids, \
|
||||
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
|
||||
assert isinstance(value, ndict), "Bad message in merge: %s" %message
|
||||
|
||||
tnfms[id].append(value)
|
||||
|
||||
tnfms[id].append(message)
|
||||
|
||||
# Only pop messages when we have a pending message from
|
||||
# all datasources. Stop if all sources have signalled done.
|
||||
|
||||
while ready(tnfms) and not done(tnfms):
|
||||
message = merge_one(tnfms)
|
||||
assert_merge_protocol(tnfm_ids, message)
|
||||
yield message
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
@@ -53,11 +53,22 @@ def merge(stream_in, tnfm_ids):
|
||||
"Bad last message in merge on exit: %s" % queue
|
||||
|
||||
def merge_one(sources):
    """
    Pop one pending message from every queue in sources and combine them
    into a single merged message.

    sources maps a transform id to a queue of that transform's pending
    messages. Each message is expected to carry 'tnfm_id' and 'tnfm_value'
    entries. Every queue must hold at least one message; popleft() raises
    IndexError on an empty queue.

    Returns an ndict with two entries:
        'event' : the fields left on the popped messages once tnfm_id and
                  tnfm_value are removed, merged together
        'tnfms' : an ndict mapping each tnfm_id to its tnfm_value
    """
    # Prime the transforms dict with every expected id so that each key
    # is present in the output.
    transforms = ndict([(tnfm_id, None) for tnfm_id in sources.keys()])
    event_fields = ndict()

    for key, queue in sources.iteritems():
        # Add transform value to the transforms dict.
        message = queue.popleft()
        transforms[message.tnfm_id] = message.tnfm_value

        # Strip the transform bookkeeping so that only event fields remain.
        del message['tnfm_id']
        del message['tnfm_value']

        # Merge any remaining fields into the event dict.
        event_fields.merge(message)

    return ndict({'event' : event_fields, 'tnfms' : transforms})
|
||||
|
||||
|
||||
#TODO: This is replicated in sort. Probably should be one source file.
|
||||
|
||||
@@ -106,10 +106,8 @@ def SpecificEquityTrades(*args, **config):
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Add a done message to the end of the stream. For a live
|
||||
# datasource this would be handled by the containing Component.
|
||||
out = chain(filtered, [mock_done(namestring)])
|
||||
return out
|
||||
# Return the filtered event stream.
|
||||
return filtered
|
||||
|
||||
def RandomEquityTrades(*args, **config):
|
||||
# We shouldn't get any positional args.
|
||||
|
||||
+34
-15
@@ -3,6 +3,7 @@ Generator versions of transforms.
|
||||
"""
|
||||
import types
|
||||
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from collections import deque, defaultdict
|
||||
from numbers import Number
|
||||
@@ -12,6 +13,7 @@ from zipline.gens.utils import assert_sort_unframe_protocol, \
|
||||
assert_transform_protocol, hash_args
|
||||
|
||||
class Passthrough(object):
|
||||
FORWARDER = True
|
||||
"""
|
||||
Trivial class for forwarding events.
|
||||
"""
|
||||
@@ -19,10 +21,7 @@ class Passthrough(object):
|
||||
pass
|
||||
|
||||
def update(self, event):
|
||||
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
|
||||
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
|
||||
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
|
||||
return event
|
||||
pass
|
||||
|
||||
def functional_transform(stream_in, func, *args, **kwargs):
|
||||
"""
|
||||
@@ -44,9 +43,13 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and passes it to a state class. For each call to update, the state
|
||||
class must produce a message to be fed downstream.
|
||||
class must produce a message to be fed downstream. Any transform class
|
||||
with the FORWARDER class variable set to true will forward all fields
|
||||
in the original message. Otherwise only dt, tnfm_id, and tnfm_value
|
||||
are forwarded.
|
||||
"""
|
||||
|
||||
forward_all_fields = tnfm_class.__dict__.get('FORWARDER', False)
|
||||
|
||||
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
|
||||
"Stateful transform requires a class."
|
||||
assert tnfm_class.__dict__.has_key('update'), \
|
||||
@@ -58,11 +61,31 @@ def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
|
||||
# Generate the string associated with this generator's output.
|
||||
namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
for message in stream_in:
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
out_value = state.update(message)
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
message_copy = deepcopy(message)
|
||||
|
||||
# Same shared pointer issue here as above.
|
||||
tnfm_value = state.update(deepcopy(message_copy))
|
||||
|
||||
# If we want to keep all original values, just append tnfm_id
|
||||
# and tnfm_value.
|
||||
if forward_all_fields:
|
||||
out_message = message_copy
|
||||
out_message.tnfm_id = namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# Otherwise send tnfm_id, tnfm_value, and the message date.
|
||||
else:
|
||||
out_message = ndict()
|
||||
out_message.tnfm_id = namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
out_message.dt = message_copy.dt
|
||||
yield out_message
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
@@ -70,6 +93,7 @@ class MovingAverage(object):
|
||||
Upon receipt of each message we update the
|
||||
corresponding window and return the calculated average.
|
||||
"""
|
||||
FORWARDER = False
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
self.delta = delta
|
||||
@@ -93,16 +117,11 @@ class MovingAverage(object):
|
||||
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
|
||||
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
|
||||
|
||||
output = ndict({'sid': event.sid, 'dt': event.dt})
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
averages = window.get_averages()
|
||||
|
||||
# Return the calculated averages along with
|
||||
output.merge(averages)
|
||||
return output
|
||||
return window.get_averages()
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
|
||||
+14
-5
@@ -1,6 +1,7 @@
|
||||
import pytz
|
||||
import numbers
|
||||
|
||||
from collections import OrderedDict
|
||||
from hashlib import md5
|
||||
from datetime import datetime
|
||||
from itertools import izip_longest
|
||||
@@ -16,8 +17,16 @@ def mock_raw_event(sid, dt):
|
||||
}
|
||||
return event
|
||||
|
||||
def mock_done(id):
    """
    Build the sentinel message that marks the end of a stream.

    The same identifier is stored under both 'source_id' and 'tnfm_id',
    and 'tnfm_value' is set to None. The message's 'dt' is the string
    "DONE" and its 'type' is 0.
    """
    # NOTE: the parameter name shadows the builtin id(); it is kept as-is
    # so that the function's interface is unchanged.
    return ndict({
        'dt'         : "DONE",
        "source_id"  : id,
        'tnfm_id'    : id,
        'tnfm_value' : None,
        'type'       : 0
    })
|
||||
|
||||
done_message = mock_done
|
||||
|
||||
def alternate(g1, g2):
|
||||
"""Specialized version of roundrobin for just 2 generators."""
|
||||
@@ -36,14 +45,14 @@ def roundrobin(sources, namestrings):
|
||||
mapping = OrderedDict(zip(namestrings, sources))
|
||||
|
||||
# While our generators have not been exhausted, pull elements
|
||||
while mapping != []:
|
||||
for namestring, source in mapping:
|
||||
while mapping.keys() != []:
|
||||
for namestring, source in mapping.iteritems():
|
||||
try:
|
||||
message = source.next()
|
||||
yield message
|
||||
except StopIteration:
|
||||
yield done_message(namestring)
|
||||
del mapping(namestring)
|
||||
del mapping[namestring]
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user