mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 13:38:52 +08:00
ENH: Adds truncate method to BcolzMinuteBarWriter (#1499)
This commit is contained in:
@@ -33,6 +33,7 @@ from pandas import (
|
||||
Timedelta,
|
||||
NaT,
|
||||
date_range,
|
||||
isnull,
|
||||
)
|
||||
|
||||
from zipline.data.bar_reader import NoDataOnDate
|
||||
@@ -943,3 +944,151 @@ class BcolzMinuteBarTestCase(WithTradingCalendars,
|
||||
# Read the attributes
|
||||
for k, v in attrs.items():
|
||||
self.assertEqual(self.reader.get_sid_attr(sid, k), v)
|
||||
|
||||
def test_truncate_between_data_points(self):
    """Truncate to the first day holding data and check both bars.

    Two bars are written for sid 1 on consecutive days of the slice.
    After truncating to the first of those days the test asserts that:

    * the writer reports that day as the last date for the sid,
    * the bar on the retained day reads back with its written values,
    * the bar on the following day reads back as null prices and a
      volume of 0.0.
    """
    tds = self.market_opens.index
    days = tds[tds.slice_indexer(
        start=self.test_calendar_start + 1,
        end=self.test_calendar_start + 3
    )]
    minutes = DatetimeIndex([
        self.market_opens[days[0]] + timedelta(minutes=60),
        self.market_opens[days[1]] + timedelta(minutes=120),
    ])
    sid = 1
    data = DataFrame(
        data={
            'open': [10.0, 11.0],
            'high': [20.0, 21.0],
            'low': [30.0, 31.0],
            'close': [40.0, 41.0],
            'volume': [50.0, 51.0]
        },
        index=minutes)
    self.writer.write_sid(sid, data)

    # Truncate to first day with data.
    self.writer.truncate(days[0])

    self.assertEqual(self.writer.last_date_in_output_for_sid(sid), days[0])

    # The bar on the retained day must read back exactly as written.
    retained_bar = (
        ('open', 10.0),
        ('high', 20.0),
        ('low', 30.0),
        ('close', 40.0),
        ('volume', 50.0),
    )
    for field, expected in retained_bar:
        self.assertEqual(
            expected,
            self.reader.get_value(sid, minutes[0], field),
        )

    # The bar after the truncate date must read back as missing:
    # null for every price field and 0.0 for volume.
    for field in ('open', 'high', 'low', 'close'):
        self.assertTrue(
            isnull(self.reader.get_value(sid, minutes[1], field))
        )
    self.assertEqual(
        0.0,
        self.reader.get_value(sid, minutes[1], 'volume'),
    )
|
||||
|
||||
def test_truncate_all_data_points(self):
    """Truncate to a day before any written data and check both bars.

    Two bars are written for sid 1, then the writer is truncated to
    ``self.test_calendar_start``. The test asserts that the writer
    reports that date as the last date for the sid, and that both
    written minutes read back as null prices and a volume of 0.0.
    """
    tds = self.market_opens.index
    days = tds[tds.slice_indexer(
        start=self.test_calendar_start + 1,
        end=self.test_calendar_start + 3
    )]
    minutes = DatetimeIndex([
        self.market_opens[days[0]] + timedelta(minutes=60),
        self.market_opens[days[1]] + timedelta(minutes=120),
    ])
    sid = 1
    data = DataFrame(
        data={
            'open': [10.0, 11.0],
            'high': [20.0, 21.0],
            'low': [30.0, 31.0],
            'close': [40.0, 41.0],
            'volume': [50.0, 51.0]
        },
        index=minutes)
    self.writer.write_sid(sid, data)

    # Truncate to first day in the calendar, a day before the first
    # day with minute data.
    self.writer.truncate(self.test_calendar_start)

    self.assertEqual(
        self.writer.last_date_in_output_for_sid(sid),
        self.test_calendar_start,
    )

    # Both bars lie after the truncate date, so each must read back
    # as missing: null for every price field and 0.0 for volume.
    for minute in (minutes[0], minutes[1]):
        for field in ('open', 'high', 'low', 'close'):
            self.assertTrue(
                isnull(self.reader.get_value(sid, minute, field))
            )
        self.assertEqual(
            0.0,
            self.reader.get_value(sid, minute, 'volume'),
        )
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from glob import glob
|
||||
from os.path import join
|
||||
from textwrap import dedent
|
||||
|
||||
@@ -20,6 +22,7 @@ from lru import LRU
|
||||
import bcolz
|
||||
from bcolz import ctable
|
||||
from intervaltree import IntervalTree
|
||||
import logbook
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from toolz import keymap, valmap
|
||||
@@ -37,6 +40,9 @@ from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.cli import maybe_show_progress
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
|
||||
logger = logbook.Logger('MinuteBars')
|
||||
|
||||
US_EQUITIES_MINUTES_PER_DAY = 390
|
||||
FUTURES_MINUTES_PER_DAY = 1440
|
||||
|
||||
@@ -739,6 +745,56 @@ class BcolzMinuteBarWriter(object):
|
||||
])
|
||||
table.flush()
|
||||
|
||||
def data_len_for_day(self, day):
    """
    Count the data points from the start of the session labels
    through the end of ``day``.

    The result is the 1-based position of ``day`` within
    ``self._session_labels`` multiplied by ``self._minutes_per_day``.
    """
    # ``get_loc`` is 0-based; shift by one so ``day`` itself is counted.
    sessions_through_day = 1 + self._session_labels.get_loc(day)
    return self._minutes_per_day * sessions_through_day
|
||||
|
||||
def truncate(self, date):
    """Truncate data beyond this date in all ctables.

    Every ``*.bcolz`` directory two levels below ``self._rootdir`` is
    opened; tables longer than ``self.data_len_for_day(date)`` rows are
    rewritten to hold only that many leading rows.

    The rewrite is done by moving the existing table aside to
    ``<path>.bak``, writing the shortened table at the original path,
    and then deleting the backup. If the write fails, any partial
    output is removed and the backup is moved back into place.

    Parameters
    ----------
    date : session label
        Last day to keep. It must be a member of
        ``self._session_labels`` and provide a ``.date()`` method
        (used for logging).
    """
    truncate_slice_end = self.data_len_for_day(date)

    glob_path = os.path.join(self._rootdir, "*", "*", "*.bcolz")
    sid_paths = glob(glob_path)

    for sid_path in sid_paths:
        file_name = os.path.basename(sid_path)

        try:
            table = bcolz.open(rootdir=sid_path)
        except IOError:
            # Unreadable tables are skipped rather than aborting the
            # whole truncation.
            continue
        if table.len <= truncate_slice_end:
            logger.info("{0} not past truncate date={1}.", file_name, date)
            continue

        logger.info(
            "Truncating {0} back at end_date={1}", file_name, date.date()
        )

        # Read the rows to keep before the on-disk table is moved away.
        new_table = table[:truncate_slice_end]
        tmp_path = sid_path + '.bak'
        shutil.move(sid_path, tmp_path)

        try:
            # Flush explicitly so buffered rows reach disk before the
            # backup is discarded.
            bcolz.ctable(new_table, rootdir=sid_path).flush()
        except Exception as err:
            # On any ctable write error, restore the original table.
            logger.warn(
                "Could not write {0}, err={1}", file_name, err
            )
            # A failed write can leave a partial directory at sid_path.
            # shutil.move would then place the backup *inside* it
            # instead of restoring it, so clear the partial output
            # first.
            shutil.rmtree(sid_path, ignore_errors=True)
            shutil.move(tmp_path, sid_path)
            continue

        # The new table is in place; failing to delete the backup is
        # not fatal, so it is only logged.
        try:
            shutil.rmtree(tmp_path)
        except Exception as err:
            logger.info(
                "Could not delete tmp_path={0}, err={1}", tmp_path, err
            )
|
||||
|
||||
|
||||
class BcolzMinuteBarReader(MinuteBarReader):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user