mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 01:21:18 +08:00
TST: add test for blaze loader with null data in date col
MAINT: fix blaze query
This commit is contained in:
@@ -6,8 +6,9 @@ from unittest import TestCase
|
||||
|
||||
import blaze as bz
|
||||
from nose_parameterized import parameterized
|
||||
from numpy.testing import assert_array_equal
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_series_equal, assert_frame_equal
|
||||
from pandas.util.testing import assert_series_equal
|
||||
|
||||
from zipline.pipeline.common import (
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
@@ -26,7 +27,7 @@ from zipline.pipeline.loaders.events import (
|
||||
WRONG_SINGLE_COL_DATA_FORMAT_ERROR
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype
|
||||
|
||||
OTHER_FIELD = "other_field"
|
||||
|
||||
@@ -38,11 +39,6 @@ class EventDataSet(DataSet):
|
||||
previous_announcement = Column(datetime64ns_dtype)
|
||||
|
||||
|
||||
class OtherFieldEventDataSet(DataSet):
    """Dataset fixture with two "previous" event columns.

    ``previous_announcement`` is declared with ``datetime64ns_dtype`` and
    ``previous_other_field`` with ``float64_dtype``.
    """

    # Declared with the datetime64[ns] dtype helper.
    previous_announcement = Column(datetime64ns_dtype)

    # Declared with the float64 dtype helper.
    previous_other_field = Column(float64_dtype)
|
||||
|
||||
|
||||
class EventDataSetLoader(EventsLoader):
|
||||
expected_cols = frozenset([ANNOUNCEMENT_FIELD_NAME])
|
||||
|
||||
@@ -67,29 +63,6 @@ class EventDataSetLoader(EventsLoader):
|
||||
)
|
||||
|
||||
|
||||
class EventDataSetLoaderMultipleExpectedCols(EventDataSetLoader):
    """Loader fixture that expects two columns in each per-sid frame.

    Both the announcement column and ``OTHER_FIELD`` are expected; the
    announcement column is used as the event date column.
    """

    event_date_col = ANNOUNCEMENT_FIELD_NAME
    expected_cols = frozenset([ANNOUNCEMENT_FIELD_NAME, OTHER_FIELD])

    def __init__(self,
                 all_dates,
                 events_by_sid,
                 infer_timestamps=False,
                 dataset=OtherFieldEventDataSet):
        """Forward all arguments to the base loader's initializer.

        The only difference from the inherited signature visible here is
        that ``dataset`` defaults to ``OtherFieldEventDataSet``.
        """
        # NOTE(review): the lookup starts *after* EventDataSetLoader in the
        # MRO, so EventDataSetLoader.__init__ itself is skipped. Confirm
        # this is intentional and not a copy-paste from the parent class.
        base = super(EventDataSetLoader, self)
        base.__init__(
            all_dates,
            events_by_sid,
            infer_timestamps=infer_timestamps,
            dataset=dataset,
        )

    @lazyval
    def previous_other_field_loader(self):
        """Lazily build the loader backing ``previous_other_field``."""
        # NOTE(review): this is built from ``previous_announcement`` rather
        # than ``previous_other_field`` -- confirm that is the intent.
        source_column = self.dataset.previous_announcement
        return self._previous_event_date_loader(source_column)
|
||||
|
||||
|
||||
# Test case just for catching an error when multiple columns are in the wrong
|
||||
# data format, so no loader defined.
|
||||
class EventDataSetLoaderMultipleExpectedColsNoColumnLoaders(EventsLoader):
|
||||
@@ -144,28 +117,6 @@ class EventLoaderTestCase(TestCase):
|
||||
EventDataSetLoader
|
||||
)
|
||||
|
||||
def test_null_in_event_date_col(self):
    """A null in the event date column removes that whole row."""
    null_position = 2

    # Announcement dates identical to ``dtx`` except for one NaT entry.
    announcement_dates = pd.Series(dtx)
    announcement_dates[null_position] = pd.NaT

    # A second expected column holding one integer per row.
    other_values = pd.Series(range(len(dtx)))

    sid_frame = pd.DataFrame({
        ANNOUNCEMENT_FIELD_NAME: announcement_dates,
        OTHER_FIELD: other_values,
        TS_FIELD_NAME: dtx,
    })
    events_by_sid = {0: sid_frame}
    loader = EventDataSetLoaderMultipleExpectedCols(dtx, events_by_sid)

    # The expected frame is the input without the null-dated row, indexed
    # by the timestamp column. It is read back from ``events_by_sid``
    # only after the loader has been constructed, as in the original.
    without_null_row = events_by_sid[0].drop(null_position, axis=0)
    expected = without_null_row.set_index(TS_FIELD_NAME)

    assert_frame_equal(loader.events_by_sid[0], expected)
|
||||
|
||||
@parameterized.expand([
|
||||
# DataFrame without timestamp column and infer_timestamps = True
|
||||
[pd.DataFrame({ANNOUNCEMENT_FIELD_NAME: dtx}), True],
|
||||
@@ -293,3 +244,42 @@ class BlazeEventLoaderTestCase(TestCase):
|
||||
SID_FIELD_NAME: 0})
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BlazeEventDataSetLoader(BlazeEventsLoader):
    """Blaze events loader fixture bound to ``EventDataSetLoader``.

    The expected fields are the announcement, timestamp and sid columns,
    and ``dataset`` defaults to ``EventDataSet``.
    """

    concrete_loader = EventDataSetLoader
    _expected_fields = frozenset([
        ANNOUNCEMENT_FIELD_NAME,
        TS_FIELD_NAME,
        SID_FIELD_NAME,
    ])

    def __init__(self, expr, dataset=EventDataSet, **kwargs):
        """Pass ``expr``, ``dataset`` and any extra keywords to the base."""
        base = super(BlazeEventDataSetLoader, self)
        base.__init__(expr, dataset=dataset, **kwargs)
|
||||
|
||||
|
||||
class BlazeEventLoaderNullInDateColumnTestCase(TestCase):
    """Checks the blaze events loader against a null event date."""

    def test_null_in_event_date_col(self):
        # Tests that if there is a null date in the event date column, it
        # is filtered out and does not break on loading the adjusted array.
        null_position = 2
        dates_with_null = pd.Series(dtx)
        dates_with_null[null_position] = pd.NaT

        raw_events = pd.DataFrame({
            SID_FIELD_NAME: 0,
            ANNOUNCEMENT_FIELD_NAME: dates_with_null,
            TS_FIELD_NAME: dtx,
        })
        loader = BlazeEventDataSetLoader(bz.data(raw_events))

        column = EventDataSet.previous_announcement
        adjusted = loader.load_adjusted_array({column}, dtx, [0], [True])
        # Keep only the values for the single requested sid.
        result = adjusted[column].data[:, 0]

        # At the null row, the expected value is the preceding date.
        expected = dates_with_null.copy(deep=True)
        expected[null_position] = dtx[1]
        assert_array_equal(result, expected)
|
||||
|
||||
@@ -77,8 +77,9 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
)
|
||||
|
||||
expected_fields = self._expected_fields
|
||||
expr = expr[list(expected_fields)]
|
||||
self._expr = bind_expression_to_resources(
|
||||
expr[list(expected_fields)][expr[
|
||||
expr[expr[
|
||||
self.concrete_loader.event_date_col
|
||||
].notnull()],
|
||||
resources,
|
||||
|
||||
Reference in New Issue
Block a user