Fixed batch_transform window length not being market aware.

Added accompanying unittest.

Minor refactoring of unittests and factory.
This commit is contained in:
Thomas Wiecki
2012-11-06 12:04:01 -05:00
committed by Eddie Hebert
parent ec6ad7182c
commit 126e9fdf26
5 changed files with 62 additions and 20 deletions
+31 -4
View File
@@ -15,6 +15,7 @@
import pytz
import numpy as np
import pandas as pd
from datetime import timedelta, datetime
from unittest2 import TestCase
@@ -57,8 +58,7 @@ class NoopEventWindow(EventWindow):
self.removed.append(event)
class EventWindowTestCase(TestCase):
class TestEventWindow(TestCase):
def setUp(self):
setup_logger(self)
@@ -154,7 +154,7 @@ class EventWindowTestCase(TestCase):
setup_logger(self)
class FinanceTransformsTestCase(TestCase):
class TestFinanceTransforms(TestCase):
def setUp(self):
self.trading_environment = factory.create_trading_environment()
@@ -309,7 +309,7 @@ class FinanceTransformsTestCase(TestCase):
############################################################
# Test BatchTransform
class BatchTransformTestCase(TestCase):
class TestBatchTransform(TestCase):
def setUp(self):
setup_logger(self)
self.source, self.df = factory.create_test_df_source()
@@ -324,6 +324,9 @@ class BatchTransformTestCase(TestCase):
self.assertEqual(algo.history_return_price_decorator[:2],
[None, None],
"First two iterations should return None")
self.assertEqual(algo.history_return_price_market_aware[:2],
[None, None],
"First two iterations should return None")
# test overloaded class
for test_history in [algo.history_return_price_class,
@@ -354,3 +357,27 @@ class BatchTransformTestCase(TestCase):
algo.history_return_args,
[None, None, expected_item, expected_item,
expected_item, expected_item])
class TestBatchTransformMarketAware(TestCase):
def setUp(self):
setup_logger(self)
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(1994, 1, 1, 0, 0, 0, 0, pytz.utc)
self.data = factory.load_from_yahoo(stocks=['AAPL'],
indexes={},
start=start, end=end)
def test_event_window(self):
days = 50
algo = BatchTransformAlgorithm(days=days, refresh_period=days)
algo.run(self.data)
self.assertEqual(algo.history_return_price_market_aware[:days],
[None] * days,
"First {days} iterations should return None"
.format(days=days))
self.assertFalse(algo.history_return_price_market_aware[days + 1]
is None,
"Window is contains too many Nones.")
+1 -1
View File
@@ -117,7 +117,7 @@ if __name__ == '__main__':
pairtrade = Pairtrade()
results = pairtrade.run(data)
data['spreads'] = np.nan
data.spreads[70:] = pairtrade.spreads
data.spreads[pairtrade.window_length:] = pairtrade.spreads
ax1 = plt.subplot(211)
data[['PEP', 'KO']].plot(ax=ax1)
+18 -8
View File
@@ -246,33 +246,41 @@ def return_args_batch_decorator(data, *args, **kwargs):
class BatchTransformAlgorithm(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.history_return_price_class = []
self.history_return_price_decorator = []
self.history_return_args = []
self.days = 3
self.refresh_period = kwargs.pop('refresh_period', 2)
self.days = kwargs.pop('days', 3)
self.args = args
self.kwargs = kwargs
self.history_return_price_class = []
self.history_return_price_decorator = []
self.history_return_args = []
self.history_return_price_market_aware = []
self.return_price_class = ReturnPriceBatchTransform(
market_aware=False,
refresh_period=2,
refresh_period=self.refresh_period,
delta=timedelta(days=self.days)
)
self.return_price_decorator = return_price_batch_decorator(
market_aware=False,
refresh_period=2,
refresh_period=self.refresh_period,
delta=timedelta(days=self.days)
)
self.return_args_batch = return_args_batch_decorator(
market_aware=False,
refresh_period=2,
refresh_period=self.refresh_period,
delta=timedelta(days=self.days)
)
self.return_price_market_aware = ReturnPriceBatchTransform(
market_aware=True,
refresh_period=self.refresh_period,
days=self.days
)
self.set_slippage(FixedSlippage())
def handle_data(self, data):
@@ -283,6 +291,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_return_args.append(
self.return_args_batch.handle_data(
data, *self.args, **self.kwargs))
self.history_return_price_market_aware.append(
self.return_price_market_aware.handle_data(data))
class SetPortfolioAlgorithm(TradingAlgorithm):
+10 -6
View File
@@ -360,9 +360,10 @@ class BatchTransform(EventWindow):
self.refresh_period = refresh_period
self.days = days
self.trading_days_since_update = 0
self.full = False
self.last_refresh = None
self.last_dt = None
self.updated = False
self.data = None
@@ -392,12 +393,15 @@ class BatchTransform(EventWindow):
return self.get_transform_value(*args, **kwargs)
def handle_add(self, event):
if not self.last_refresh:
self.last_refresh = event.dt
if not self.last_dt:
self.last_dt = event.dt
return
age = event.dt - self.last_refresh
if age.days >= self.refresh_period:
if self.last_dt.day != event.dt.day:
self.last_dt = event.dt
self.trading_days_since_update += 1
if self.trading_days_since_update >= self.refresh_period:
# Create a pandas.Panel (i.e. 3d DataFrame) from the
# events in the current window.
#
@@ -432,7 +436,7 @@ class BatchTransform(EventWindow):
self.data = pd.Panel.from_dict(fields, orient='items')
self.updated = True
self.last_refresh = event.dt
self.trading_days_since_update = 0
else:
self.updated = False
+2 -1
View File
@@ -274,7 +274,8 @@ def create_test_df_source():
start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)
index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day)
x = np.arange(2., 14.).reshape((6, 2))
x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2))
df = pd.DataFrame(x, index=index, columns=[0, 1])
return DataFrameSource(df), df