mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 03:19:03 +08:00
Merge pull request #1116 from quantopian/ensure-matched-write-col-input
BUG: Ensure matched input length to minute writer.
This commit is contained in:
@@ -42,6 +42,7 @@ from zipline.data.minute_bars import (
|
||||
BcolzMinuteBarReader,
|
||||
BcolzMinuteOverlappingData,
|
||||
US_EQUITIES_MINUTES_PER_DAY,
|
||||
BcolzMinuteWriterColumnMismatch
|
||||
)
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
|
||||
@@ -594,6 +595,20 @@ class BcolzMinuteBarTestCase(TestCase):
|
||||
|
||||
self.assertEquals(51.0, volume_price)
|
||||
|
||||
def test_write_cols_mismatch_length(self):
|
||||
dts = date_range(self.market_opens[self.test_calendar_start],
|
||||
periods=2, freq='min').asi8.astype('datetime64[s]')
|
||||
sid = 1
|
||||
cols = {
|
||||
'open': array([10.0, 11.0, 12.0]),
|
||||
'high': array([20.0, 21.0]),
|
||||
'low': array([30.0, 31.0, 33.0, 34.0]),
|
||||
'close': array([40.0, 41.0]),
|
||||
'volume': array([50.0, 51.0, 52.0])
|
||||
}
|
||||
with self.assertRaises(BcolzMinuteWriterColumnMismatch):
|
||||
self.writer.write_cols(sid, dts, cols)
|
||||
|
||||
def test_unadjusted_minutes(self):
|
||||
"""
|
||||
Test unadjusted minutes.
|
||||
|
||||
@@ -43,6 +43,10 @@ class BcolzMinuteOverlappingData(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BcolzMinuteWriterColumnMismatch(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _calc_minute_index(market_opens, minutes_per_day):
|
||||
minutes = np.zeros(len(market_opens) * minutes_per_day,
|
||||
dtype='datetime64[ns]')
|
||||
@@ -212,6 +216,8 @@ class BcolzMinuteBarWriter(object):
|
||||
The datetimes which correspond to each position are written in the metadata
|
||||
as integer nanoseconds since the epoch into the `minute_index` key.
|
||||
"""
|
||||
COL_NAMES = ('open', 'high', 'low', 'close', 'volume')
|
||||
|
||||
def __init__(self,
|
||||
first_trading_day,
|
||||
rootdir,
|
||||
@@ -465,7 +471,9 @@ class BcolzMinuteBarWriter(object):
|
||||
'volume': df.volume.values,
|
||||
}
|
||||
dts = df.index.values
|
||||
self.write_cols(sid, dts, cols)
|
||||
# Call internal method, since DataFrame has already ensured matching
|
||||
# index and value lengths.
|
||||
self._write_cols(sid, dts, cols)
|
||||
|
||||
def write_cols(self, sid, dts, cols):
|
||||
"""
|
||||
@@ -474,6 +482,33 @@ class BcolzMinuteBarWriter(object):
|
||||
If the length of the bcolz ctable is not exactly to the date before
|
||||
the first day provided, fill the ctable with 0s up to that date.
|
||||
Writes in blocks of the size of the days times minutes per day.
|
||||
Parameters:
|
||||
-----------
|
||||
sid : int
|
||||
The asset identifier for the data being written.
|
||||
dts : datetime64 array
|
||||
The dts corresponding to values in cols.
|
||||
cols : dict of str -> np.array
|
||||
dict of market data with the following characteristics.
|
||||
keys are ('open', 'high', 'low', 'close', 'volume')
|
||||
open : float64
|
||||
high : float64
|
||||
low : float64
|
||||
close : float64
|
||||
volume : float64|int64
|
||||
"""
|
||||
if not all(len(dts) == len(cols[name]) for name in self.COL_NAMES):
|
||||
raise BcolzMinuteWriterColumnMismatch(
|
||||
"Length of dts={0} should match cols: {1}".format(
|
||||
len(dts),
|
||||
" ".join("{0}={1}".format(name, len(cols[name]))
|
||||
for name in self.COL_NAMES)))
|
||||
self._write_cols(sid, dts, cols)
|
||||
|
||||
def _write_cols(self, sid, dts, cols):
|
||||
"""
|
||||
Internal method for `write_cols` and `write`.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
sid : int
|
||||
|
||||
Reference in New Issue
Block a user