diff --git a/tests/data/test_minute_bars.py b/tests/data/test_minute_bars.py index e13c2847..e3c1a954 100644 --- a/tests/data/test_minute_bars.py +++ b/tests/data/test_minute_bars.py @@ -42,6 +42,7 @@ from zipline.data.minute_bars import ( BcolzMinuteBarReader, BcolzMinuteOverlappingData, US_EQUITIES_MINUTES_PER_DAY, + BcolzMinuteWriterColumnMismatch ) from zipline.finance.trading import TradingEnvironment @@ -594,6 +595,20 @@ class BcolzMinuteBarTestCase(TestCase): self.assertEquals(51.0, volume_price) + def test_write_cols_mismatch_length(self): + dts = date_range(self.market_opens[self.test_calendar_start], + periods=2, freq='min').asi8.astype('datetime64[s]') + sid = 1 + cols = { + 'open': array([10.0, 11.0, 12.0]), + 'high': array([20.0, 21.0]), + 'low': array([30.0, 31.0, 33.0, 34.0]), + 'close': array([40.0, 41.0]), + 'volume': array([50.0, 51.0, 52.0]) + } + with self.assertRaises(BcolzMinuteWriterColumnMismatch): + self.writer.write_cols(sid, dts, cols) + def test_unadjusted_minutes(self): """ Test unadjusted minutes. diff --git a/zipline/data/minute_bars.py b/zipline/data/minute_bars.py index 993b61a7..446b69c5 100644 --- a/zipline/data/minute_bars.py +++ b/zipline/data/minute_bars.py @@ -43,6 +43,10 @@ class BcolzMinuteOverlappingData(Exception): pass +class BcolzMinuteWriterColumnMismatch(Exception): + pass + + def _calc_minute_index(market_opens, minutes_per_day): minutes = np.zeros(len(market_opens) * minutes_per_day, dtype='datetime64[ns]') @@ -212,6 +216,8 @@ class BcolzMinuteBarWriter(object): The datetimes which correspond to each position are written in the metadata as integer nanoseconds since the epoch into the `minute_index` key. """ + COL_NAMES = ('open', 'high', 'low', 'close', 'volume') + def __init__(self, first_trading_day, rootdir, @@ -465,7 +471,9 @@ class BcolzMinuteBarWriter(object): 'volume': df.volume.values, } dts = df.index.values - self.write_cols(sid, dts, cols) + # Call internal method, since DataFrame has already ensured matching + # index and value lengths. + self._write_cols(sid, dts, cols) def write_cols(self, sid, dts, cols): """ @@ -474,6 +482,33 @@ class BcolzMinuteBarWriter(object): If the length of the bcolz ctable is not exactly to the date before the first day provided, fill the ctable with 0s up to that date. Writes in blocks of the size of the days times minutes per day. + Parameters: + ----------- + sid : int + The asset identifier for the data being written. + dts : datetime64 array + The dts corresponding to values in cols. + cols : dict of str -> np.array + dict of market data with the following characteristics. + keys are ('open', 'high', 'low', 'close', 'volume') + open : float64 + high : float64 + low : float64 + close : float64 + volume : float64|int64 + """ + if not all(len(dts) == len(cols[name]) for name in self.COL_NAMES): + raise BcolzMinuteWriterColumnMismatch( + "Length of dts={0} should match cols: {1}".format( + len(dts), + " ".join("{0}={1}".format(name, len(cols[name])) + for name in self.COL_NAMES))) + self._write_cols(sid, dts, cols) + + def _write_cols(self, sid, dts, cols): + """ + Internal method for `write_cols` and `write`. + Parameters: ----------- sid : int