From 11ba76874c41b25e97d28befac6bd05950271ee5 Mon Sep 17 00:00:00 2001 From: llllllllll Date: Fri, 9 Oct 2015 18:36:41 -0400 Subject: [PATCH] ENH: allow columns from different datasets --- zipline/pipeline/loaders/blaze.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/zipline/pipeline/loaders/blaze.py b/zipline/pipeline/loaders/blaze.py index 6bc6c4a5..0262cbfd 100644 --- a/zipline/pipeline/loaders/blaze.py +++ b/zipline/pipeline/loaders/blaze.py @@ -126,7 +126,6 @@ from __future__ import division, absolute_import from abc import ABCMeta, abstractproperty from collections import namedtuple, defaultdict from itertools import count -from operator import attrgetter import warnings from weakref import WeakKeyDictionary @@ -148,10 +147,12 @@ from toolz import ( compose, concat, flip, + groupby, identity, memoize, ) -from six import with_metaclass, PY2 +import toolz.curried.operator as op +from six import with_metaclass, PY2, iteritems from ..data.dataset import DataSet, Column @@ -172,7 +173,7 @@ traversable_nodes = ( bz.expr.Label, ) is_invalid_deltas_node = complement(flip(isinstance, valid_deltas_node_types)) -getname = attrgetter('__name__') +getname = op.attrgetter('__name__') class _ExprRepr(object): @@ -553,8 +554,8 @@ def from_blaze(expr, return ds -getdataset = attrgetter('dataset') -dataset_name = attrgetter('name') +getdataset = op.attrgetter('dataset') +dataset_name = op.attrgetter('name') def inline_novel_deltas(base, deltas, dates): @@ -712,6 +713,17 @@ class BlazeLoader(dict): return cls() def load_adjusted_array(self, columns, dates, assets, mask): + return map( + op.getitem( + dict(concat( + self._load_dataset(cs, dates, assets, mask) + for _, cs in iteritems(groupby(getdataset, columns)) + )), + ), + columns, + ) + + def _load_dataset(self, columns, dates, assets, mask): try: dataset, = set(map(getdataset, columns)) except ValueError: @@ -826,7 +838,7 @@ class BlazeLoader(dict): for column_idx, column in enumerate(columns): column_name = column.name - yield adjusted_array( + yield column, adjusted_array( column_view( dense_output[column_name].values.astype(column.dtype), ),