mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 05:22:49 +08:00
Merge pull request #980 from quantopian/odo-kwargs
ENH: allow users to specify odo kwargs in from_blaze
This commit is contained in:
@@ -193,25 +193,7 @@ is_invalid_deltas_node = complement(flip(isinstance, valid_deltas_node_types))
|
||||
get__name__ = op.attrgetter('__name__')
|
||||
|
||||
|
||||
class _ExprRepr(object):
|
||||
"""Box for repring expressions with the str of the expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The expression to box for repring.
|
||||
"""
|
||||
__slots__ = 'expr',
|
||||
|
||||
def __init__(self, expr):
|
||||
self.expr = expr
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.expr)
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class ExprData(namedtuple('ExprData', 'expr deltas resources')):
|
||||
class ExprData(namedtuple('ExprData', 'expr deltas odo_kwargs')):
|
||||
"""A pair of expressions and data resources. The expresions will be
|
||||
computed using the resources as the starting scope.
|
||||
|
||||
@@ -221,20 +203,25 @@ class ExprData(namedtuple('ExprData', 'expr deltas resources')):
|
||||
The baseline values.
|
||||
deltas : Expr, optional
|
||||
The deltas for the data.
|
||||
resources : resource or dict of resources, optional
|
||||
The resources to compute the exprs against.
|
||||
odo_kwargs : dict, optional
|
||||
The keyword arguments to forward to the odo calls internally.
|
||||
"""
|
||||
def __new__(cls, expr, deltas=None, resources=None):
|
||||
return super(ExprData, cls).__new__(cls, expr, deltas, resources)
|
||||
def __new__(cls, expr, deltas=None, odo_kwargs=None):
|
||||
return super(ExprData, cls).__new__(
|
||||
cls,
|
||||
expr,
|
||||
deltas,
|
||||
odo_kwargs or {},
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
# If the expressions have _resources() then the repr will
|
||||
# drive computation so we box them.
|
||||
# drive computation so we take the str here.
|
||||
cls = type(self)
|
||||
return super(ExprData, cls).__repr__(cls(
|
||||
_ExprRepr(self.expr),
|
||||
_ExprRepr(self.deltas),
|
||||
self.resources,
|
||||
str(self.expr),
|
||||
str(self.deltas),
|
||||
self.odo_kwargs,
|
||||
))
|
||||
|
||||
|
||||
@@ -484,6 +471,7 @@ def from_blaze(expr,
|
||||
deltas='auto',
|
||||
loader=None,
|
||||
resources=None,
|
||||
odo_kwargs=None,
|
||||
no_deltas_rule=no_deltas_rules.warn):
|
||||
"""Create a Pipeline API object from a blaze expression.
|
||||
|
||||
@@ -503,6 +491,8 @@ def from_blaze(expr,
|
||||
resources : dict or any, optional
|
||||
The data to execute the blaze expressions against. This is used as the
|
||||
scope for ``bz.compute``.
|
||||
odo_kwargs : dict, optional
|
||||
The keyword arguments to pass to odo when evaluating the expressions.
|
||||
no_deltas_rule : no_deltas_rule
|
||||
What should happen if ``deltas='auto'`` but no deltas can be found.
|
||||
'warn' says to raise a warning but continue.
|
||||
@@ -595,9 +585,11 @@ def from_blaze(expr,
|
||||
ds = new_dataset(dataset_expr, deltas)
|
||||
# Register our new dataset with the loader.
|
||||
(loader if loader is not None else global_loader)[ds] = ExprData(
|
||||
dataset_expr,
|
||||
deltas,
|
||||
resources,
|
||||
bind_expression_to_resources(dataset_expr, resources),
|
||||
bind_expression_to_resources(deltas, resources)
|
||||
if deltas is not None else
|
||||
None,
|
||||
odo_kwargs=odo_kwargs,
|
||||
)
|
||||
if single_column is not None:
|
||||
# We were passed a single column, extract and return it.
|
||||
@@ -838,7 +830,7 @@ class BlazeLoader(dict):
|
||||
except ValueError:
|
||||
raise AssertionError('all columns must come from the same dataset')
|
||||
|
||||
expr, deltas, resources = self[dataset]
|
||||
expr, deltas, odo_kwargs = self[dataset]
|
||||
have_sids = SID_FIELD_NAME in expr.fields
|
||||
asset_idx = pd.Series(index=assets, data=np.arange(len(assets)))
|
||||
assets = list(map(int, assets)) # coerce from numpy.int64
|
||||
@@ -900,7 +892,7 @@ class BlazeLoader(dict):
|
||||
(e[TS_FIELD_NAME] <= upper_dt)
|
||||
][added_query_fields + [colname]]
|
||||
|
||||
def collect_expr(e, _kwargs={'d': resources} if resources else {}):
|
||||
def collect_expr(e):
|
||||
"""Execute and merge all of the per-column subqueries.
|
||||
|
||||
Parameters
|
||||
@@ -921,7 +913,7 @@ class BlazeLoader(dict):
|
||||
return reduce(
|
||||
partial(pd.merge, on=added_query_fields, how='outer'),
|
||||
(
|
||||
odo(where(e, column), pd.DataFrame, **_kwargs)
|
||||
odo(where(e, column), pd.DataFrame, **odo_kwargs)
|
||||
for column in columns
|
||||
),
|
||||
).sort(TS_FIELD_NAME) # sort for the groupby later
|
||||
|
||||
Reference in New Issue
Block a user