diff --git a/tests/pipeline/test_events.py b/tests/pipeline/test_events.py
index 50ba8a2f..0a9c8b1d 100644
--- a/tests/pipeline/test_events.py
+++ b/tests/pipeline/test_events.py
@@ -6,8 +6,9 @@ from unittest import TestCase
 
 import blaze as bz
 from nose_parameterized import parameterized
+from numpy.testing import assert_array_equal
 import pandas as pd
-from pandas.util.testing import assert_series_equal, assert_frame_equal
+from pandas.util.testing import assert_series_equal
 
 from zipline.pipeline.common import (
     ANNOUNCEMENT_FIELD_NAME,
@@ -26,7 +27,7 @@ from zipline.pipeline.loaders.events import (
     WRONG_SINGLE_COL_DATA_FORMAT_ERROR
 )
 from zipline.utils.memoize import lazyval
-from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
+from zipline.utils.numpy_utils import datetime64ns_dtype
 
 
 OTHER_FIELD = "other_field"
@@ -38,11 +39,6 @@ class EventDataSet(DataSet):
     previous_announcement = Column(datetime64ns_dtype)
 
 
-class OtherFieldEventDataSet(DataSet):
-    previous_announcement = Column(datetime64ns_dtype)
-    previous_other_field = Column(float64_dtype)
-
-
 class EventDataSetLoader(EventsLoader):
     expected_cols = frozenset([ANNOUNCEMENT_FIELD_NAME])
 
@@ -67,29 +63,6 @@ class EventDataSetLoader(EventsLoader):
         )
 
 
-class EventDataSetLoaderMultipleExpectedCols(EventDataSetLoader):
-    expected_cols = frozenset([ANNOUNCEMENT_FIELD_NAME, OTHER_FIELD])
-    event_date_col = ANNOUNCEMENT_FIELD_NAME
-
-    def __init__(self,
-                 all_dates,
-                 events_by_sid,
-                 infer_timestamps=False,
-                 dataset=OtherFieldEventDataSet):
-        super(EventDataSetLoader, self).__init__(
-            all_dates,
-            events_by_sid,
-            infer_timestamps=infer_timestamps,
-            dataset=dataset,
-        )
-
-    @lazyval
-    def previous_other_field_loader(self):
-        return self._previous_event_date_loader(
-            self.dataset.previous_announcement,
-        )
-
-
 # Test case just for catching an error when multiple columns are in the wrong
 # data format, so no loader defined.
 class EventDataSetLoaderMultipleExpectedColsNoColumnLoaders(EventsLoader):
@@ -144,28 +117,6 @@ class EventLoaderTestCase(TestCase):
             EventDataSetLoader
         )
 
-    def test_null_in_event_date_col(self):
-        # Tests that getting a null date in the event date column filters the
-        # entire row from the data.
-        dates_with_null = pd.Series(dtx)
-        dates_with_null[2] = pd.NaT
-        other_col_data = pd.Series(range(0, len(dtx)))
-        events_by_sid = {0: pd.DataFrame({ANNOUNCEMENT_FIELD_NAME:
-                                          dates_with_null,
-                                          OTHER_FIELD: other_col_data,
-                                          TS_FIELD_NAME: dtx})}
-        loader = EventDataSetLoaderMultipleExpectedCols(
-            dtx,
-            events_by_sid,
-        )
-
-        expected = events_by_sid[0].drop(2, axis=0).set_index(TS_FIELD_NAME)
-        # Check that index by first given date has been added
-        assert_frame_equal(
-            loader.events_by_sid[0],
-            expected,
-        )
-
     @parameterized.expand([
         # DataFrame without timestamp column and infer_timestamps = True
         [pd.DataFrame({ANNOUNCEMENT_FIELD_NAME: dtx}), True],
@@ -293,3 +244,42 @@ class BlazeEventLoaderTestCase(TestCase):
                                   SID_FIELD_NAME: 0})
                 )
             )
+
+
+class BlazeEventDataSetLoader(BlazeEventsLoader):
+    concrete_loader = EventDataSetLoader
+    _expected_fields = frozenset({ANNOUNCEMENT_FIELD_NAME,
+                                  TS_FIELD_NAME,
+                                  SID_FIELD_NAME})
+
+    def __init__(self,
+                 expr,
+                 dataset=EventDataSet,
+                 **kwargs):
+        super(
+            BlazeEventDataSetLoader, self
+        ).__init__(expr,
+                   dataset=dataset,
+                   **kwargs)
+
+
+class BlazeEventLoaderNullInDateColumnTestCase(TestCase):
+    def test_null_in_event_date_col(self):
+        # Tests that if there is a null date in the event date column, it is
+        # filtered out and does not break on loading the adjusted array.
+        dates_with_null = pd.Series(dtx)
+        dates_with_null[2] = pd.NaT
+        events_by_sid = pd.DataFrame({SID_FIELD_NAME: 0,
+                                      ANNOUNCEMENT_FIELD_NAME: dates_with_null,
+                                      TS_FIELD_NAME: dtx})
+        loader = BlazeEventDataSetLoader(
+            bz.data(events_by_sid),
+        )
+
+        result = loader.load_adjusted_array({
+            EventDataSet.previous_announcement
+        }, dtx, [0], [True])[EventDataSet.previous_announcement].data[:, 0]
+
+        expected = dates_with_null.copy(True)
+        expected[2] = dtx[1]
+        assert_array_equal(result, expected)
diff --git a/zipline/pipeline/loaders/blaze/events.py b/zipline/pipeline/loaders/blaze/events.py
index e951cac0..34d7f37d 100644
--- a/zipline/pipeline/loaders/blaze/events.py
+++ b/zipline/pipeline/loaders/blaze/events.py
@@ -77,8 +77,9 @@ class BlazeEventsLoader(PipelineLoader):
             )
 
         expected_fields = self._expected_fields
+        expr = expr[list(expected_fields)]
         self._expr = bind_expression_to_resources(
-            expr[list(expected_fields)][expr[
+            expr[expr[
                 self.concrete_loader.event_date_col
             ].notnull()],
             resources,