MAINT: Make load_adjusted_array return a dict.

Rather than a list that's ordered the same as the received columns.
Most nontrivial loaders were constructing dicts internally and then
converting back to lists, only to have the engine convert **back again**
into a dict.  This cuts out the middleman, and prevents bugs due to
incorrect ordering of the output arrays.
This commit is contained in:
Scott Sanderson
2015-11-03 10:56:04 -05:00
parent f98349e24d
commit 8cd4f7d100
7 changed files with 23 additions and 28 deletions
+1 -1
View File
@@ -80,7 +80,7 @@ class DataFrameLoaderTestCase(TestCase):
self.dates[dates_slice],
self.sids[sids_slice],
self.mask[dates_slice, sids_slice],
)
).values()
for idx, window in enumerate(adj_array.traverse(window_length=3)):
expected = baseline.values[dates_slice, sids_slice][idx:idx + 3]
@@ -35,6 +35,7 @@ from pandas import (
Timestamp,
)
from testfixtures import TempDirectory
from toolz.curried.operator import getitem
from zipline.lib.adjustment import Float64Multiply
from zipline.pipeline.loaders.synthetic import (
@@ -422,12 +423,13 @@ class USEquityPricingLoaderTestCase(TestCase):
adjustment_reader,
)
closes, volumes = pricing_loader.load_adjusted_array(
results = pricing_loader.load_adjusted_array(
columns,
dates=query_days,
assets=self.assets,
mask=ones((len(query_days), len(self.assets)), dtype=bool),
)
closes, volumes = map(getitem(results), columns)
expected_baseline_closes = self.bcolz_writer.expected_values_2d(
shifted_query_days,
@@ -500,12 +502,13 @@ class USEquityPricingLoaderTestCase(TestCase):
adjustment_reader,
)
highs, volumes = pricing_loader.load_adjusted_array(
results = pricing_loader.load_adjusted_array(
columns,
dates=query_days,
assets=Int64Index(arange(1, 7)),
mask=ones((len(query_days), 6), dtype=bool),
)
highs, volumes = map(getitem(results), columns)
expected_baseline_highs = self.bcolz_writer.expected_values_2d(
shifted_query_days,
+3 -6
View File
@@ -11,7 +11,6 @@ from six import (
iteritems,
with_metaclass,
)
from six.moves import zip_longest
from numpy import array
from pandas import (
DataFrame,
@@ -342,12 +341,10 @@ class SimplePipelineEngine(object):
key=lambda t: t.dataset
)
loader = get_loader(term)
loaded = tuple(loader.load_adjusted_array(
loaded = loader.load_adjusted_array(
to_load, mask_dates, assets, mask,
))
assert len(to_load) == len(loaded)
for loaded_term, adj_array in zip_longest(to_load, loaded):
workspace[loaded_term] = adj_array
)
workspace.update(loaded)
else:
workspace[term] = term._compute(
self._inputs_for_term(term, workspace, graph),
+5 -13
View File
@@ -761,19 +761,11 @@ class BlazeLoader(dict):
raise KeyError(column)
def load_adjusted_array(self, columns, dates, assets, mask):
return map(
op.getitem(
dict(concat(map(
partial(
self._load_dataset,
dates,
assets,
mask
),
itervalues(groupby(getdataset, columns))
))),
),
columns,
return dict(
concat(map(
partial(self._load_dataset, dates, assets, mask),
itervalues(groupby(getdataset, columns))
))
)
def _load_dataset(self, dates, assets, mask, columns):
@@ -56,18 +56,18 @@ class USEquityPricingLoader(PipelineLoader):
end_date,
assets,
)
adjustments = self.adjustments_loader.load_adjustments(
columns,
dates,
assets,
)
return [
adjusted_arrays = [
adjusted_array(raw_array, mask, col_adjustments)
for raw_array, col_adjustments in zip(raw_arrays, adjustments)
]
return dict(zip(columns, adjusted_arrays))
def _shift_dates(dates, start_date, end_date, shift):
try:
+2 -1
View File
@@ -177,10 +177,11 @@ class DataFrameLoader(PipelineLoader):
good_dates = (date_indexer != -1)
good_assets = (assets_indexer != -1)
return [adjusted_array(
arrays = [adjusted_array(
# Pull out requested columns/rows from our baseline data.
data=self.baseline[ix_(date_indexer, assets_indexer)],
# Mask out requested columns/rows that didnt match.
mask=(good_assets & good_dates[:, None]) & mask,
adjustments=self.format_adjustments(dates, assets),
)]
return dict(zip(columns, arrays))
+4 -2
View File
@@ -71,13 +71,15 @@ class ConstantLoader(PipelineLoader):
"""
Load by delegating to sub-loaders.
"""
out = []
out = {}
for col in columns:
try:
loader = self._loaders[col]
except KeyError:
raise ValueError("Couldn't find loader for %s" % col)
out.extend(loader.load_adjusted_array([col], dates, assets, mask))
out.update(
loader.load_adjusted_array([col], dates, assets, mask)
)
return out