Fixed an issue with reader array size

This commit is contained in:
fredfortier
2017-10-18 18:33:37 -04:00
parent 339fa21c35
commit 874a4bb682
2 changed files with 17 additions and 9 deletions
+1 -1
View File
@@ -17,7 +17,6 @@ from catalyst.exchange.exchange_utils import get_exchange_bundles_folder
from catalyst.utils.deprecate import deprecated
from catalyst.utils.paths import data_path
EXCHANGE_NAMES = ['bitfinex', 'bittrex', 'poloniex']
API_URL = 'http://data.enigma.co/api/v1'
@@ -198,6 +197,7 @@ def get_ffill_candles(candles, bar_count, end_dt, data_frequency,
start_dt = get_start_dt(end_dt, bar_count, data_frequency)
date = start_dt
# TODO: this works well with a small number of candles; consider using numpy for larger inputs.
while date <= end_dt:
candle = next((
candle for candle in candles if candle['last_traded'] == date
+16 -8
View File
@@ -3,6 +3,7 @@ import numpy as np
from catalyst import get_calendar
from catalyst.data.minute_bars import BcolzMinuteBarReader, \
BcolzMinuteBarWriter
from catalyst.exchange.bundle_utils import get_periods, get_periods_range
class BcolzExchangeBarWriter(BcolzMinuteBarWriter):
@@ -34,6 +35,10 @@ class BcolzExchangeBarReader(BcolzMinuteBarReader):
super(BcolzExchangeBarReader, self).__init__(*args, **kwargs)
@property
def data_frequency(self):
    """Return the data frequency held in ``self._data_frequency``.

    Read-only accessor; it performs no computation or validation and
    simply exposes the private attribute.
    """
    return self._data_frequency
def load_raw_arrays(self, fields, start_dt, end_dt, sids):
if self._data_frequency == 'minute':
@@ -47,24 +52,27 @@ class BcolzExchangeBarReader(BcolzMinuteBarReader):
start_idx = self._find_position_of_minute(start_dt)
end_idx = self._find_position_of_minute(end_dt)
num_days = (end_idx - start_idx + 1)
periods = get_periods_range(start_dt, end_dt, self.data_frequency)
num_days = len(periods)
shape = num_days, len(sids)
if len(fields) == 1 and fields[0] == 'volume':
fields.insert(0, 'close')
mask = None
data = []
for field in fields:
if field != 'volume':
out = np.full(shape, np.nan)
else:
out = np.zeros(shape, dtype=np.float64)
out = np.full(shape, np.nan)
for i, sid in enumerate(sids):
carray = self._open_minute_file(field, sid)
a = carray[start_idx:end_idx + 1]
where = a != 0
if mask is None:
mask = a != 0
out[:len(where), i][where] = (
a[where] * self._ohlc_ratio_inverse_for_sid(sid)
out[:len(mask), i][mask] = (
a[mask] * self._ohlc_ratio_inverse_for_sid(sid)
)
data.append(out)