From 0a3a2f365328c48eca7ce6b656f78fa64ec94e5c Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Thu, 7 Apr 2016 13:15:51 -0400 Subject: [PATCH] BUG: Ensure matched input length to minute writer. When the dts and length of cols are mismatched the writer behaves in unintended ways. e.g. in a case where a consumer passed dts which had minutes with no trades removed, but regular (market minute for day) sized arrays for the data with `0`'s on minutes without trades, the non trade minutes from cols are written to slots in the output where a trade is intended. Protect against this misuse by checking that all lengths are equal when using the `write_cols` method. Make a separate `_write_cols` method for use by both `write_cols` and `write`, since the `write` method which takes a DataFrame has the matched input length enforced by the DataFrame. --- tests/data/test_minute_bars.py | 15 ++++++++++++++ zipline/data/minute_bars.py | 37 +++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/data/test_minute_bars.py b/tests/data/test_minute_bars.py index e13c2847..e3c1a954 100644 --- a/tests/data/test_minute_bars.py +++ b/tests/data/test_minute_bars.py @@ -42,6 +42,7 @@ from zipline.data.minute_bars import ( BcolzMinuteBarReader, BcolzMinuteOverlappingData, US_EQUITIES_MINUTES_PER_DAY, + BcolzMinuteWriterColumnMismatch ) from zipline.finance.trading import TradingEnvironment @@ -594,6 +595,20 @@ class BcolzMinuteBarTestCase(TestCase): self.assertEquals(51.0, volume_price) + def test_write_cols_mismatch_length(self): + dts = date_range(self.market_opens[self.test_calendar_start], + periods=2, freq='min').asi8.astype('datetime64[s]') + sid = 1 + cols = { + 'open': array([10.0, 11.0, 12.0]), + 'high': array([20.0, 21.0]), + 'low': array([30.0, 31.0, 33.0, 34.0]), + 'close': array([40.0, 41.0]), + 'volume': array([50.0, 51.0, 52.0]) + } + with self.assertRaises(BcolzMinuteWriterColumnMismatch): + self.writer.write_cols(sid, dts, cols) + def test_unadjusted_minutes(self): """ Test unadjusted minutes. diff --git a/zipline/data/minute_bars.py b/zipline/data/minute_bars.py index 993b61a7..446b69c5 100644 --- a/zipline/data/minute_bars.py +++ b/zipline/data/minute_bars.py @@ -43,6 +43,10 @@ class BcolzMinuteOverlappingData(Exception): pass +class BcolzMinuteWriterColumnMismatch(Exception): + pass + + def _calc_minute_index(market_opens, minutes_per_day): minutes = np.zeros(len(market_opens) * minutes_per_day, dtype='datetime64[ns]') @@ -212,6 +216,8 @@ class BcolzMinuteBarWriter(object): The datetimes which correspond to each position are written in the metadata as integer nanoseconds since the epoch into the `minute_index` key. """ + COL_NAMES = ('open', 'high', 'low', 'close', 'volume') + def __init__(self, first_trading_day, rootdir, @@ -465,7 +471,9 @@ class BcolzMinuteBarWriter(object): 'volume': df.volume.values, } dts = df.index.values - self.write_cols(sid, dts, cols) + # Call internal method, since DataFrame has already ensured matching + # index and value lengths. + self._write_cols(sid, dts, cols) def write_cols(self, sid, dts, cols): """ @@ -474,6 +482,33 @@ class BcolzMinuteBarWriter(object): If the length of the bcolz ctable is not exactly to the date before the first day provided, fill the ctable with 0s up to that date. Writes in blocks of the size of the days times minutes per day. + Parameters: + ----------- + sid : int + The asset identifier for the data being written. + dts : datetime64 array + The dts corresponding to values in cols. + cols : dict of str -> np.array + dict of market data with the following characteristics. + keys are ('open', 'high', 'low', 'close', 'volume') + open : float64 + high : float64 + low : float64 + close : float64 + volume : float64|int64 + """ + if not all(len(dts) == len(cols[name]) for name in self.COL_NAMES): + raise BcolzMinuteWriterColumnMismatch( + "Length of dts={0} should match cols: {1}".format( + len(dts), + " ".join("{0}={1}".format(name, len(cols[name])) + for name in self.COL_NAMES))) + self._write_cols(sid, dts, cols) + + def _write_cols(self, sid, dts, cols): + """ + Internal method for `write_cols` and `write`. + Parameters: ----------- sid : int