BUG: Ensure matched input length to minute writer.

When the length of dts and the lengths of the arrays in cols are
mismatched, the writer behaves in unintended ways. For example, if a
consumer passes dts from which the minutes with no trades have been
removed, but passes regular-sized (one entry per market minute of the
day) data arrays with `0`s on the minutes without trades, the non-trade
minutes from cols are written to slots in the output where a trade is
intended.

Protect against this misuse by checking that all lengths are equal when
using the `write_cols` method.

Make a separate `_write_cols` method for use by both `write_cols` and
`write`, since the `write` method which takes a DataFrame has the
matched input length enforced by the DataFrame.
This commit is contained in:
Eddie Hebert
2016-04-07 13:15:51 -04:00
parent 005299e7e4
commit 0a3a2f3653
2 changed files with 51 additions and 1 deletions
+15
View File
@@ -42,6 +42,7 @@ from zipline.data.minute_bars import (
BcolzMinuteBarReader,
BcolzMinuteOverlappingData,
US_EQUITIES_MINUTES_PER_DAY,
BcolzMinuteWriterColumnMismatch
)
from zipline.finance.trading import TradingEnvironment
@@ -594,6 +595,20 @@ class BcolzMinuteBarTestCase(TestCase):
self.assertEquals(51.0, volume_price)
def test_write_cols_mismatch_length(self):
    """
    Test that `write_cols` raises when column lengths do not match dts.
    """
    # Only two minutes of dts are supplied ...
    first_open = self.market_opens[self.test_calendar_start]
    minute_range = date_range(first_open, periods=2, freq='min')
    dts = minute_range.asi8.astype('datetime64[s]')
    sid = 1
    # ... while 'open', 'low' and 'volume' carry three or four values.
    cols = dict(
        open=array([10.0, 11.0, 12.0]),
        high=array([20.0, 21.0]),
        low=array([30.0, 31.0, 33.0, 34.0]),
        close=array([40.0, 41.0]),
        volume=array([50.0, 51.0, 52.0]),
    )
    self.assertRaises(
        BcolzMinuteWriterColumnMismatch,
        self.writer.write_cols,
        sid,
        dts,
        cols,
    )
def test_unadjusted_minutes(self):
"""
Test unadjusted minutes.
+36 -1
View File
@@ -43,6 +43,10 @@ class BcolzMinuteOverlappingData(Exception):
pass
class BcolzMinuteWriterColumnMismatch(Exception):
    """
    Raised by `BcolzMinuteBarWriter.write_cols` when the length of `dts`
    does not equal the length of every column array in `cols`.
    """
def _calc_minute_index(market_opens, minutes_per_day):
minutes = np.zeros(len(market_opens) * minutes_per_day,
dtype='datetime64[ns]')
@@ -212,6 +216,8 @@ class BcolzMinuteBarWriter(object):
The datetimes which correspond to each position are written in the metadata
as integer nanoseconds since the epoch into the `minute_index` key.
"""
COL_NAMES = ('open', 'high', 'low', 'close', 'volume')
def __init__(self,
first_trading_day,
rootdir,
@@ -465,7 +471,9 @@ class BcolzMinuteBarWriter(object):
'volume': df.volume.values,
}
dts = df.index.values
self.write_cols(sid, dts, cols)
# Call internal method, since DataFrame has already ensured matching
# index and value lengths.
self._write_cols(sid, dts, cols)
def write_cols(self, sid, dts, cols):
"""
@@ -474,6 +482,33 @@ class BcolzMinuteBarWriter(object):
If the length of the bcolz ctable is not exactly to the date before
the first day provided, fill the ctable with 0s up to that date.
Writes in blocks of the size of the days times minutes per day.
Parameters:
-----------
sid : int
The asset identifier for the data being written.
dts : datetime64 array
The dts corresponding to values in cols.
cols : dict of str -> np.array
dict of market data with the following characteristics.
keys are ('open', 'high', 'low', 'close', 'volume')
open : float64
high : float64
low : float64
close : float64
volume : float64|int64
"""
if not all(len(dts) == len(cols[name]) for name in self.COL_NAMES):
raise BcolzMinuteWriterColumnMismatch(
"Length of dts={0} should match cols: {1}".format(
len(dts),
" ".join("{0}={1}".format(name, len(cols[name]))
for name in self.COL_NAMES)))
self._write_cols(sid, dts, cols)
def _write_cols(self, sid, dts, cols):
"""
Internal method for `write_cols` and `write`.
Parameters:
-----------
sid : int