mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 00:53:42 +08:00
file moves and style changes in sorted/merged layers
This commit is contained in:
+37
-12
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
from itertools import tee
|
||||
from itertools import tee, starmap
|
||||
from collections import namedtuple
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import roundrobin, hash_args
|
||||
@@ -7,6 +8,9 @@ from zipline.gens.sort import date_sort
|
||||
from zipline.gens.merge import merge
|
||||
from zipline.gens.transform import stateful_transform
|
||||
|
||||
SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs'])
|
||||
MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs'])
|
||||
|
||||
def date_sorted_sources(sources, source_args, source_kwargs):
|
||||
"""
|
||||
Takes a list of generator functions, a list of tuples of positional arguments,
|
||||
@@ -15,17 +19,31 @@ def date_sorted_sources(sources, source_args, source_kwargs):
|
||||
"""
|
||||
assert len(sources) == len(source_args) == len(source_kwargs)
|
||||
# Package up sources and arguments.
|
||||
arg_bundles = zip(sources, source_args, source_kwargs)
|
||||
|
||||
# Create a generator of SortBundle objects to be turned into
|
||||
# namestrings and generator objects.
|
||||
bundle_gen = starmap(SortBundle, zip(sources, source_args, source_kwargs))
|
||||
|
||||
# Load the results of the generator into a tuple so that the
|
||||
# results can be used twice (once in namestring comprehension,
|
||||
# once in the generator comprehension for intialized sources.
|
||||
bundles = tuple(bundle_gen)
|
||||
|
||||
# Calculate namestring hashes to pass to date_sort.
|
||||
namestrings = [source.__name__ + hash_args(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles]
|
||||
names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
# Pass each source its arguments.
|
||||
initialized = tuple(source(*args, **kwargs)
|
||||
for source, args, kwargs in arg_bundles)
|
||||
initialized = [bundle.source(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
|
||||
# Convert the list of generators into a flat stream by pulling
|
||||
# one element at a time from each.
|
||||
stream_in = roundrobin(*initialized)
|
||||
return date_sort(stream_in, namestrings)
|
||||
|
||||
# Guarantee the flat stream will be sorted by date, using source_id as
|
||||
# tie-breaker, which is fully deterministic (given deterministic string
|
||||
# representation for all args/kwargs)
|
||||
return date_sort(stream_in, names)
|
||||
|
||||
|
||||
def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
@@ -47,16 +65,23 @@ def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
|
||||
# Package each transform with a stream copy and set of args. Use a list
|
||||
# so that we can re-use this for calculating hashes.
|
||||
bundles = zip(split, tnfms, tnfm_args, tnfm_kwargs)
|
||||
bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs))
|
||||
|
||||
bundles = tuple(bundle_gen)
|
||||
# list comprehension to create transform generators from
|
||||
# bundles
|
||||
tnfm_gens = [stateful_transform(stream, tnfm, *args, **kwargs)
|
||||
for stream, tnfm, args, kwargs in bundles]
|
||||
tnfm_gens = [
|
||||
stateful_transform(
|
||||
bundle.stream,
|
||||
bundle.tnfm,
|
||||
*bundle.args,
|
||||
**bundle.kwargs
|
||||
)
|
||||
for bundle in bundles]
|
||||
|
||||
# Generate expected hashes for each transform
|
||||
hashes = [tnfm.__name__ + hash_args(*args, **kwargs)
|
||||
for _, tnfm, args, kwargs in bundles]
|
||||
hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(*tnfm_gens)
|
||||
|
||||
@@ -15,7 +15,7 @@ def date_sort(stream_in, source_ids):
|
||||
message and yield it.
|
||||
"""
|
||||
|
||||
assert isinstance(source_ids, list)
|
||||
assert isinstance(source_ids, (list, tuple))
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
sources = {}
|
||||
|
||||
@@ -9,9 +9,9 @@ from datetime import datetime, timedelta
|
||||
from zipline.utils.factory import create_trade
|
||||
from zipline.gens.utils import hash_args, mock_done
|
||||
|
||||
def date_gen(start=datetime(2012, 6, 6, 0),
|
||||
delta=timedelta(minutes = 1),
|
||||
count=100):
|
||||
def date_gen(start = datetime(2012, 6, 6, 0),
|
||||
delta = timedelta(minutes = 1),
|
||||
count = 100):
|
||||
"""
|
||||
Utility to generate a stream of dates.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user