TST: tests removing of expired data and removes ffill in DataPanelSource

This commit is contained in:
warren-oneill
2015-11-16 16:38:11 +01:00
parent abbda17239
commit 987d6d4e48
5 changed files with 68 additions and 6 deletions
+40
View File
@@ -59,6 +59,7 @@ from zipline.test_algorithms import (
TestTargetAlgorithm,
TestTargetPercentAlgorithm,
TestTargetValueAlgorithm,
TestRemoveDataAlgo,
SetLongOnlyAlgorithm,
SetAssetDateBoundsAlgorithm,
SetMaxPositionSizeAlgorithm,
@@ -1965,3 +1966,42 @@ class TestTradingAlgorithm(TestCase):
analyze=analyze)
results = algo.run(self.panel)
self.assertIs(results, self.perf_ref)
class TestRemoveData(TestCase):
"""
tests if futures data is removed after expiry
"""
def setUp(self):
dt = pd.Timestamp('2015-01-01', tz='UTC')
metadata = {0: {'symbol': 'X',
'expiration_date': dt + timedelta(days=5),
'end_date': dt + timedelta(days=5)},
1: {'symbol': 'Y',
'expiration_date': dt + timedelta(days=7),
'end_date': dt + timedelta(days=7)}}
env = TradingEnvironment()
env.write_data(futures_data=metadata)
index_x = pd.date_range(dt, periods=5)
data_x = pd.DataFrame([[1, 100], [2, 100], [3, 100], [4, 100],
[5, 100]],
index=index_x, columns=['price', 'volume'])
index_y = index_x.shift(2)
data_y = pd.DataFrame([[6, 100], [7, 100], [8, 100], [9, 100],
[10, 100]],
index=index_y, columns=['price', 'volume'])
pan = pd.Panel({0: data_x, 1: data_y})
self.source = DataPanelSource(pan)
self.algo = TestRemoveDataAlgo(env=env)
def test_remove_data(self):
self.algo.run(self.source)
expected_length = [1, 2, 2, 2, 2, 1]
# initially only data for X should be sent and on the last day only
# data for Y should be sent since X is expired
for i, length in enumerate(self.algo.data):
self.assertEqual(expected_length[i], length, i)
+6 -3
View File
@@ -123,9 +123,12 @@ class TestDataFrameSource(TestCase):
self.assertEqual(5, event.sid)
event = next(source)
self.assertEqual(4, event.sid)
event = next(source)
self.assertEqual(5, event.sid)
self.assertFalse(np.isnan(event.price))
try:
x = False
event = next(source)
except StopIteration:
x = True
self.assertTrue(x)
class TestRandomWalkSource(TestCase):
+11
View File
@@ -26,6 +26,7 @@ from zipline.protocol import (
SIDData,
DATASOURCE_TYPE
)
from zipline.errors import SidNotFound
log = Logger('Trade Simulation')
@@ -100,6 +101,16 @@ class AlgorithmSimulator(object):
self.simulation_dt = date
self.on_dt_changed(date)
# removing expired futures
for sid in self.current_data.keys():
try:
if self.env.asset_finder.retrieve_asset(sid).end_date \
< self.simulation_dt:
del self.current_data[sid]
except (AttributeError, TypeError, ValueError,
SidNotFound):
continue
# If we're still in the warmup period. Use the event to
# update our universe, but don't yield any perf messages,
# and don't send a snapshot to handle_data.
+1 -3
View File
@@ -114,7 +114,6 @@ class DataPanelSource(DataSource):
# TODO is ffilling correct/necessary?
# forward fill with volumes of 0
self.data = data.fillna(value={'volume': 0})
self.data = self.data.fillna(method='ffill')
# Unpack config dictionary with default values.
self.start = kwargs.get('start', self.data.major_axis[0])
self.end = kwargs.get('end', self.data.major_axis[-1])
@@ -153,8 +152,7 @@ class DataPanelSource(DataSource):
df = self.data.major_xs(dt)
for sid, series in df.iteritems():
# Skip SIDs that can not be forward filled
if np.isnan(series['price']) and \
sid not in self.started_sids:
if np.isnan(series['price']):
continue
self.started_sids.add(sid)
+10
View File
@@ -937,6 +937,16 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
style=style)
class TestRemoveDataAlgo(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.data = np.zeros(6)
self.i = 0
def handle_data(self, data):
self.data[self.i] = len(data)
self.i += 1
##############################
# Quantopian style algorithms