mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 22:18:58 +08:00
TST: tests removing of expired data and removes ffill in DataPanelSource
This commit is contained in:
@@ -59,6 +59,7 @@ from zipline.test_algorithms import (
|
||||
TestTargetAlgorithm,
|
||||
TestTargetPercentAlgorithm,
|
||||
TestTargetValueAlgorithm,
|
||||
TestRemoveDataAlgo,
|
||||
SetLongOnlyAlgorithm,
|
||||
SetAssetDateBoundsAlgorithm,
|
||||
SetMaxPositionSizeAlgorithm,
|
||||
@@ -1965,3 +1966,42 @@ class TestTradingAlgorithm(TestCase):
|
||||
analyze=analyze)
|
||||
results = algo.run(self.panel)
|
||||
self.assertIs(results, self.perf_ref)
|
||||
|
||||
|
||||
class TestRemoveData(TestCase):
|
||||
"""
|
||||
tests if futures data is removed after expiry
|
||||
"""
|
||||
def setUp(self):
|
||||
dt = pd.Timestamp('2015-01-01', tz='UTC')
|
||||
metadata = {0: {'symbol': 'X',
|
||||
'expiration_date': dt + timedelta(days=5),
|
||||
'end_date': dt + timedelta(days=5)},
|
||||
1: {'symbol': 'Y',
|
||||
'expiration_date': dt + timedelta(days=7),
|
||||
'end_date': dt + timedelta(days=7)}}
|
||||
|
||||
env = TradingEnvironment()
|
||||
env.write_data(futures_data=metadata)
|
||||
|
||||
index_x = pd.date_range(dt, periods=5)
|
||||
data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
|
||||
[5, 100]],
|
||||
index=index_x, columns=['price', 'volume'])
|
||||
index_y = index_x.shift(2)
|
||||
data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
|
||||
[10, 100]],
|
||||
index=index_y, columns=['price', 'volume'])
|
||||
|
||||
pan = pd.Panel({0: data_x, 1: data_y})
|
||||
self.source = DataPanelSource(pan)
|
||||
self.algo = TestRemoveDataAlgo(env=env)
|
||||
|
||||
def test_remove_data(self):
|
||||
self.algo.run(self.source)
|
||||
|
||||
expected_length = [1, 2, 2, 2, 2, 1]
|
||||
# initially only data for X should be sent and on the last day only
|
||||
# data for Y should be sent since X is expired
|
||||
for i, length in enumerate(self.algo.data):
|
||||
self.assertEqual(expected_length[i], length, i)
|
||||
|
||||
@@ -123,9 +123,12 @@ class TestDataFrameSource(TestCase):
|
||||
self.assertEqual(5, event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual(4, event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual(5, event.sid)
|
||||
self.assertFalse(np.isnan(event.price))
|
||||
try:
|
||||
x = False
|
||||
event = next(source)
|
||||
except StopIteration:
|
||||
x = True
|
||||
self.assertTrue(x)
|
||||
|
||||
|
||||
class TestRandomWalkSource(TestCase):
|
||||
|
||||
@@ -26,6 +26,7 @@ from zipline.protocol import (
|
||||
SIDData,
|
||||
DATASOURCE_TYPE
|
||||
)
|
||||
from zipline.errors import SidNotFound
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
@@ -100,6 +101,16 @@ class AlgorithmSimulator(object):
|
||||
self.simulation_dt = date
|
||||
self.on_dt_changed(date)
|
||||
|
||||
# removing expired futures
|
||||
for sid in self.current_data.keys():
|
||||
try:
|
||||
if self.env.asset_finder.retrieve_asset(sid).end_date \
|
||||
< self.simulation_dt:
|
||||
del self.current_data[sid]
|
||||
except (AttributeError, TypeError, ValueError,
|
||||
SidNotFound):
|
||||
continue
|
||||
|
||||
# If we're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
# and don't send a snapshot to handle_data.
|
||||
|
||||
@@ -114,7 +114,6 @@ class DataPanelSource(DataSource):
|
||||
# TODO is ffilling correct/necessary?
|
||||
# forward fill with volumes of 0
|
||||
self.data = data.fillna(value={'volume': 0})
|
||||
self.data = self.data.fillna(method='ffill')
|
||||
# Unpack config dictionary with default values.
|
||||
self.start = kwargs.get('start', self.data.major_axis[0])
|
||||
self.end = kwargs.get('end', self.data.major_axis[-1])
|
||||
@@ -153,8 +152,7 @@ class DataPanelSource(DataSource):
|
||||
df = self.data.major_xs(dt)
|
||||
for sid, series in df.iteritems():
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(series['price']) and \
|
||||
sid not in self.started_sids:
|
||||
if np.isnan(series['price']):
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
|
||||
|
||||
@@ -937,6 +937,16 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
|
||||
style=style)
|
||||
|
||||
|
||||
class TestRemoveDataAlgo(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.data = np.zeros(6)
|
||||
self.i = 0
|
||||
|
||||
def handle_data(self, data):
|
||||
self.data[self.i] = len(data)
|
||||
self.i += 1
|
||||
|
||||
|
||||
##############################
|
||||
# Quantopian style algorithms
|
||||
|
||||
|
||||
Reference in New Issue
Block a user