From 7c8a44ecd7fe90f456b2373acbfe5196020d6f47 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Wed, 3 Feb 2016 16:16:06 -0500 Subject: [PATCH] ENH: allow users to specify odo kwargs in from_blaze --- zipline/pipeline/loaders/blaze/core.py | 58 +++++++++++--------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index c167120f..7b8b71a8 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -193,25 +193,7 @@ is_invalid_deltas_node = complement(flip(isinstance, valid_deltas_node_types)) get__name__ = op.attrgetter('__name__') -class _ExprRepr(object): - """Box for repring expressions with the str of the expression. - - Parameters - ---------- - expr : Expr - The expression to box for repring. - """ - __slots__ = 'expr', - - def __init__(self, expr): - self.expr = expr - - def __repr__(self): - return str(self.expr) - __str__ = __repr__ - - -class ExprData(namedtuple('ExprData', 'expr deltas resources')): +class ExprData(namedtuple('ExprData', 'expr deltas odo_kwargs')): """A pair of expressions and data resources. The expresions will be computed using the resources as the starting scope. @@ -221,20 +203,25 @@ class ExprData(namedtuple('ExprData', 'expr deltas resources')): The baseline values. deltas : Expr, optional The deltas for the data. - resources : resource or dict of resources, optional - The resources to compute the exprs against. + odo_kwargs : dict, optional + The keyword arguments to forward to the odo calls internally. """ - def __new__(cls, expr, deltas=None, resources=None): - return super(ExprData, cls).__new__(cls, expr, deltas, resources) + def __new__(cls, expr, deltas=None, odo_kwargs=None): + return super(ExprData, cls).__new__( + cls, + expr, + deltas, + odo_kwargs or {}, + ) def __repr__(self): # If the expressions have _resources() then the repr will - # drive computation so we box them. + # drive computation so we take the str here. cls = type(self) return super(ExprData, cls).__repr__(cls( - _ExprRepr(self.expr), - _ExprRepr(self.deltas), - self.resources, + str(self.expr), + str(self.deltas), + self.odo_kwargs, )) @@ -484,6 +471,7 @@ def from_blaze(expr, deltas='auto', loader=None, resources=None, + odo_kwargs=None, no_deltas_rule=no_deltas_rules.warn): """Create a Pipeline API object from a blaze expression. @@ -503,6 +491,8 @@ def from_blaze(expr, resources : dict or any, optional The data to execute the blaze expressions against. This is used as the scope for ``bz.compute``. + odo_kwargs : dict, optional + The keyword arguments to pass to odo when evaluating the expressions. no_deltas_rule : no_deltas_rule What should happen if ``deltas='auto'`` but no deltas can be found. 'warn' says to raise a warning but continue. @@ -595,9 +585,11 @@ def from_blaze(expr, ds = new_dataset(dataset_expr, deltas) # Register our new dataset with the loader. (loader if loader is not None else global_loader)[ds] = ExprData( - dataset_expr, - deltas, - resources, + bind_expression_to_resources(dataset_expr, resources), + bind_expression_to_resources(deltas, resources) + if deltas is not None else + None, + odo_kwargs=odo_kwargs, ) if single_column is not None: # We were passed a single column, extract and return it. @@ -838,7 +830,7 @@ class BlazeLoader(dict): except ValueError: raise AssertionError('all columns must come from the same dataset') - expr, deltas, resources = self[dataset] + expr, deltas, odo_kwargs = self[dataset] have_sids = SID_FIELD_NAME in expr.fields asset_idx = pd.Series(index=assets, data=np.arange(len(assets))) assets = list(map(int, assets)) # coerce from numpy.int64 @@ -900,7 +892,7 @@ class BlazeLoader(dict): (e[TS_FIELD_NAME] <= upper_dt) ][added_query_fields + [colname]] - def collect_expr(e, _kwargs={'d': resources} if resources else {}): + def collect_expr(e): """Execute and merge all of the per-column subqueries. Parameters @@ -921,7 +913,7 @@ class BlazeLoader(dict): return reduce( partial(pd.merge, on=added_query_fields, how='outer'), ( - odo(where(e, column), pd.DataFrame, **_kwargs) + odo(where(e, column), pd.DataFrame, **odo_kwargs) for column in columns ), ).sort(TS_FIELD_NAME) # sort for the groupby later