mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 17:05:35 +08:00
MAINT: Make load_adjusted_array return a dict.
load_adjusted_array now returns a dict keyed by column, rather than a list that's ordered the same as the received columns. Most nontrivial loaders were constructing dicts internally and then converting back to lists, only to have the engine convert **back again** into a dict. This cuts out the middleman and prevents bugs due to incorrect ordering of the output arrays.
This commit is contained in:
@@ -80,7 +80,7 @@ class DataFrameLoaderTestCase(TestCase):
|
||||
self.dates[dates_slice],
|
||||
self.sids[sids_slice],
|
||||
self.mask[dates_slice, sids_slice],
|
||||
)
|
||||
).values()
|
||||
|
||||
for idx, window in enumerate(adj_array.traverse(window_length=3)):
|
||||
expected = baseline.values[dates_slice, sids_slice][idx:idx + 3]
|
||||
|
||||
@@ -35,6 +35,7 @@ from pandas import (
|
||||
Timestamp,
|
||||
)
|
||||
from testfixtures import TempDirectory
|
||||
from toolz.curried.operator import getitem
|
||||
|
||||
from zipline.lib.adjustment import Float64Multiply
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
@@ -422,12 +423,13 @@ class USEquityPricingLoaderTestCase(TestCase):
|
||||
adjustment_reader,
|
||||
)
|
||||
|
||||
closes, volumes = pricing_loader.load_adjusted_array(
|
||||
results = pricing_loader.load_adjusted_array(
|
||||
columns,
|
||||
dates=query_days,
|
||||
assets=self.assets,
|
||||
mask=ones((len(query_days), len(self.assets)), dtype=bool),
|
||||
)
|
||||
closes, volumes = map(getitem(results), columns)
|
||||
|
||||
expected_baseline_closes = self.bcolz_writer.expected_values_2d(
|
||||
shifted_query_days,
|
||||
@@ -500,12 +502,13 @@ class USEquityPricingLoaderTestCase(TestCase):
|
||||
adjustment_reader,
|
||||
)
|
||||
|
||||
highs, volumes = pricing_loader.load_adjusted_array(
|
||||
results = pricing_loader.load_adjusted_array(
|
||||
columns,
|
||||
dates=query_days,
|
||||
assets=Int64Index(arange(1, 7)),
|
||||
mask=ones((len(query_days), 6), dtype=bool),
|
||||
)
|
||||
highs, volumes = map(getitem(results), columns)
|
||||
|
||||
expected_baseline_highs = self.bcolz_writer.expected_values_2d(
|
||||
shifted_query_days,
|
||||
|
||||
@@ -11,7 +11,6 @@ from six import (
|
||||
iteritems,
|
||||
with_metaclass,
|
||||
)
|
||||
from six.moves import zip_longest
|
||||
from numpy import array
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
@@ -342,12 +341,10 @@ class SimplePipelineEngine(object):
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = get_loader(term)
|
||||
loaded = tuple(loader.load_adjusted_array(
|
||||
loaded = loader.load_adjusted_array(
|
||||
to_load, mask_dates, assets, mask,
|
||||
))
|
||||
assert len(to_load) == len(loaded)
|
||||
for loaded_term, adj_array in zip_longest(to_load, loaded):
|
||||
workspace[loaded_term] = adj_array
|
||||
)
|
||||
workspace.update(loaded)
|
||||
else:
|
||||
workspace[term] = term._compute(
|
||||
self._inputs_for_term(term, workspace, graph),
|
||||
|
||||
@@ -761,19 +761,11 @@ class BlazeLoader(dict):
|
||||
raise KeyError(column)
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
return map(
|
||||
op.getitem(
|
||||
dict(concat(map(
|
||||
partial(
|
||||
self._load_dataset,
|
||||
dates,
|
||||
assets,
|
||||
mask
|
||||
),
|
||||
itervalues(groupby(getdataset, columns))
|
||||
))),
|
||||
),
|
||||
columns,
|
||||
return dict(
|
||||
concat(map(
|
||||
partial(self._load_dataset, dates, assets, mask),
|
||||
itervalues(groupby(getdataset, columns))
|
||||
))
|
||||
)
|
||||
|
||||
def _load_dataset(self, dates, assets, mask, columns):
|
||||
|
||||
@@ -56,18 +56,18 @@ class USEquityPricingLoader(PipelineLoader):
|
||||
end_date,
|
||||
assets,
|
||||
)
|
||||
|
||||
adjustments = self.adjustments_loader.load_adjustments(
|
||||
columns,
|
||||
dates,
|
||||
assets,
|
||||
)
|
||||
|
||||
return [
|
||||
adjusted_arrays = [
|
||||
adjusted_array(raw_array, mask, col_adjustments)
|
||||
for raw_array, col_adjustments in zip(raw_arrays, adjustments)
|
||||
]
|
||||
|
||||
return dict(zip(columns, adjusted_arrays))
|
||||
|
||||
|
||||
def _shift_dates(dates, start_date, end_date, shift):
|
||||
try:
|
||||
|
||||
@@ -177,10 +177,11 @@ class DataFrameLoader(PipelineLoader):
|
||||
good_dates = (date_indexer != -1)
|
||||
good_assets = (assets_indexer != -1)
|
||||
|
||||
return [adjusted_array(
|
||||
arrays = [adjusted_array(
|
||||
# Pull out requested columns/rows from our baseline data.
|
||||
data=self.baseline[ix_(date_indexer, assets_indexer)],
|
||||
# Mask out requested columns/rows that didn't match.
|
||||
mask=(good_assets & good_dates[:, None]) & mask,
|
||||
adjustments=self.format_adjustments(dates, assets),
|
||||
)]
|
||||
return dict(zip(columns, arrays))
|
||||
|
||||
@@ -71,13 +71,15 @@ class ConstantLoader(PipelineLoader):
|
||||
"""
|
||||
Load by delegating to sub-loaders.
|
||||
"""
|
||||
out = []
|
||||
out = {}
|
||||
for col in columns:
|
||||
try:
|
||||
loader = self._loaders[col]
|
||||
except KeyError:
|
||||
raise ValueError("Couldn't find loader for %s" % col)
|
||||
out.extend(loader.load_adjusted_array([col], dates, assets, mask))
|
||||
out.update(
|
||||
loader.load_adjusted_array([col], dates, assets, mask)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user