From caebdf7cfcd5ba3f7c89c70d330647e7e605fca4 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Mon, 20 Jun 2016 13:35:07 -0400 Subject: [PATCH] MAINT: shuffle the complex expression checks --- tests/pipeline/test_blaze.py | 48 +++++++++++----- zipline/pipeline/loaders/blaze/core.py | 80 ++++++++++++++++++-------- 2 files changed, 91 insertions(+), 37 deletions(-) diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index 29885a30..d54f9476 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -680,49 +680,71 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): ) def test_complex_expr(self): - expr = bz.data(self.df, dshape=self.dshape) + expr = bz.data(self.df, dshape=self.dshape, name='expr') # put an Add in the table expr_with_add = bz.transform(expr, value=expr.value + 1) - # Test that we can have complex expressions with no deltas + # test that we can have complex expressions with no metadata from_blaze( expr_with_add, deltas=None, + checkpoints=None, loader=self.garbage_loader, missing_values=self.missing_values, no_checkpoints_rule='ignore', ) - with self.assertRaises(TypeError): + with self.assertRaises(TypeError) as e: + # test that we cannot create a single column from a non field from_blaze( expr.value + 1, # put an Add in the column deltas=None, + checkpoints=None, loader=self.garbage_loader, missing_values=self.missing_values, no_checkpoints_rule='ignore', ) + assert_equal( + str(e.exception), + "expression 'expr.value + 1' was array-like but not a simple field" + " of some larger table", + ) deltas = bz.data( pd.DataFrame(columns=self.df.columns), dshape=self.dshape, + name='deltas', + ) + checkpoints = bz.data( + pd.DataFrame(columns=self.df.columns), + dshape=self.dshape, + name='checkpoints', ) - with self.assertRaises(TypeError): - from_blaze( - expr_with_add, - deltas=deltas, - loader=self.garbage_loader, - missing_values=self.missing_values, - no_checkpoints_rule='ignore', - ) - with self.assertRaises(TypeError): + # test that we can have complex expressions with explicit metadata + from_blaze( + expr_with_add, + deltas=deltas, + checkpoints=checkpoints, + loader=self.garbage_loader, + missing_values=self.missing_values, + ) + + with self.assertRaises(TypeError) as e: + # test that we cannot create a single column from a non field + # even with explicit metadata from_blaze( expr.value + 1, deltas=deltas, + checkpoints=checkpoints, loader=self.garbage_loader, missing_values=self.missing_values, - no_checkpoints_rule='ignore', ) + assert_equal( + str(e.exception), + "expression 'expr.value + 1' was array-like but not a simple field" + " of some larger table", + ) def _test_id(self, df, dshape, expected, finder, add): expr = bz.data(df, name='expr', dshape=dshape) diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 47cf5967..6a9ec792 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -482,7 +482,27 @@ def _get_metadata(field, expr, metadata_expr, no_metadata_rule): return None -def _ensure_timestamp_field(dataset_expr, deltas): +def _ad_as_ts(expr): + """Duplicate the asof_date column as the timestamp column. + + Parameters + ---------- + expr : Expr or None + The expression to change the columns of. + + Returns + ------- + transformed : Expr or None + The transformed expression or None if ``expr`` is None. + """ + return ( + None + if expr is None else + bz.transform(expr, **{TS_FIELD_NAME: expr[AD_FIELD_NAME]}) + ) + + +def _ensure_timestamp_field(dataset_expr, deltas, checkpoints): """Verify that the baseline and deltas expressions have a timestamp field. If there is not a ``TS_FIELD_NAME`` on either of the expressions, it will @@ -495,6 +515,8 @@ def _ensure_timestamp_field(dataset_expr, deltas): The baseline expression. deltas : Expr or None The deltas expression if any was provided. + checkpoints : Expr or None + The checkpoints expression if any was provided. Returns ------- @@ -507,15 +529,12 @@ def _ensure_timestamp_field(dataset_expr, deltas): dataset_expr, **{TS_FIELD_NAME: dataset_expr[AD_FIELD_NAME]} ) - if deltas is not None: - deltas = bz.transform( - deltas, - **{TS_FIELD_NAME: deltas[AD_FIELD_NAME]} - ) + deltas = _ad_as_ts(deltas) + checkpoints = _ad_as_ts(checkpoints) else: _check_datetime_field(TS_FIELD_NAME, measure) - return dataset_expr, deltas + return dataset_expr, deltas, checkpoints @expect_element( @@ -580,6 +599,22 @@ def from_blaze(expr, is passed, a ``BoundColumn`` on the dataset that would be constructed from passing the parent is returned. """ + if 'auto' in {deltas, checkpoints}: + invalid_nodes = tuple(filter(is_invalid_deltas_node, expr._subterms())) + if invalid_nodes: + raise TypeError( + 'expression with auto %s may only contain (%s) nodes,' + " found: %s" % ( + ' or '.join( + ['deltas'] if deltas is not None else [] + + ['checkpoints'] if checkpoints is not None else [], + ), + ', '.join(map(get__name__, valid_deltas_node_types)), + ', '.join( + set(map(compose(get__name__, type), invalid_nodes)), + ), + ), + ) deltas = _get_metadata( 'deltas', expr, @@ -592,22 +627,6 @@ def from_blaze(expr, checkpoints, no_checkpoints_rule, ) - if 'auto' in {deltas, checkpoints}: - invalid_nodes = tuple(filter(is_invalid_deltas_node, expr._subterms())) - if invalid_nodes: - raise TypeError( - 'expression with %s may only contain (%s) nodes,' - " found: %s" % ( - ' or '.join( - ['deltas'] if deltas is not None else [] + - ['checkpoints'] if checkpoints is not None else [], - ), - ', '.join(map(get__name__, valid_deltas_node_types)), - ', '.join( - set(map(compose(get__name__, type), invalid_nodes)), - ), - ), - ) # Check if this is a single column out of a dataset. if bz.ndim(expr) != 1: @@ -653,7 +672,11 @@ def from_blaze(expr, ), ) _check_datetime_field(AD_FIELD_NAME, measure) - dataset_expr, deltas = _ensure_timestamp_field(dataset_expr, deltas) + dataset_expr, deltas, checkpoints = _ensure_timestamp_field( + dataset_expr, + deltas, + checkpoints, + ) if deltas is not None and (sorted(deltas.dshape.measure.fields) != sorted(measure.fields)): @@ -663,6 +686,15 @@ def from_blaze(expr, deltas.dshape.measure, ), ) + if (checkpoints is not None and + (sorted(checkpoints.dshape.measure.fields) != + sorted(measure.fields))): + raise TypeError( + 'baseline measure != checkpoints measure:\n%s != %s' % ( + measure, + deltas.dshape.measure, + ), + ) # Ensure that we have a data resource to execute the query against. _check_resources('expr', dataset_expr, resources)