mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 22:40:20 +08:00
Fixed batch_transform window length not being market aware.
Added accompanying unittest. Minor refactoring of unittests and factory.
This commit is contained in:
committed by
Eddie Hebert
parent
ec6ad7182c
commit
126e9fdf26
@@ -15,6 +15,7 @@
|
||||
|
||||
import pytz
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from unittest2 import TestCase
|
||||
@@ -57,8 +58,7 @@ class NoopEventWindow(EventWindow):
|
||||
self.removed.append(event)
|
||||
|
||||
|
||||
class EventWindowTestCase(TestCase):
|
||||
|
||||
class TestEventWindow(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
|
||||
@@ -154,7 +154,7 @@ class EventWindowTestCase(TestCase):
|
||||
setup_logger(self)
|
||||
|
||||
|
||||
class FinanceTransformsTestCase(TestCase):
|
||||
class TestFinanceTransforms(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.trading_environment = factory.create_trading_environment()
|
||||
@@ -309,7 +309,7 @@ class FinanceTransformsTestCase(TestCase):
|
||||
############################################################
|
||||
# Test BatchTransform
|
||||
|
||||
class BatchTransformTestCase(TestCase):
|
||||
class TestBatchTransform(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
self.source, self.df = factory.create_test_df_source()
|
||||
@@ -324,6 +324,9 @@ class BatchTransformTestCase(TestCase):
|
||||
self.assertEqual(algo.history_return_price_decorator[:2],
|
||||
[None, None],
|
||||
"First two iterations should return None")
|
||||
self.assertEqual(algo.history_return_price_market_aware[:2],
|
||||
[None, None],
|
||||
"First two iterations should return None")
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
@@ -354,3 +357,27 @@ class BatchTransformTestCase(TestCase):
|
||||
algo.history_return_args,
|
||||
[None, None, expected_item, expected_item,
|
||||
expected_item, expected_item])
|
||||
|
||||
|
||||
class TestBatchTransformMarketAware(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = pd.datetime(1994, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
|
||||
self.data = factory.load_from_yahoo(stocks=['AAPL'],
|
||||
indexes={},
|
||||
start=start, end=end)
|
||||
|
||||
def test_event_window(self):
|
||||
days = 50
|
||||
algo = BatchTransformAlgorithm(days=days, refresh_period=days)
|
||||
algo.run(self.data)
|
||||
|
||||
self.assertEqual(algo.history_return_price_market_aware[:days],
|
||||
[None] * days,
|
||||
"First {days} iterations should return None"
|
||||
.format(days=days))
|
||||
self.assertFalse(algo.history_return_price_market_aware[days + 1]
|
||||
is None,
|
||||
"Window is contains too many Nones.")
|
||||
|
||||
@@ -117,7 +117,7 @@ if __name__ == '__main__':
|
||||
pairtrade = Pairtrade()
|
||||
results = pairtrade.run(data)
|
||||
data['spreads'] = np.nan
|
||||
data.spreads[70:] = pairtrade.spreads
|
||||
data.spreads[pairtrade.window_length:] = pairtrade.spreads
|
||||
|
||||
ax1 = plt.subplot(211)
|
||||
data[['PEP', 'KO']].plot(ax=ax1)
|
||||
|
||||
@@ -246,33 +246,41 @@ def return_args_batch_decorator(data, *args, **kwargs):
|
||||
|
||||
class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.history_return_price_class = []
|
||||
self.history_return_price_decorator = []
|
||||
self.history_return_args = []
|
||||
|
||||
self.days = 3
|
||||
self.refresh_period = kwargs.pop('refresh_period', 2)
|
||||
self.days = kwargs.pop('days', 3)
|
||||
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.history_return_price_class = []
|
||||
self.history_return_price_decorator = []
|
||||
self.history_return_args = []
|
||||
self.history_return_price_market_aware = []
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
refresh_period=self.refresh_period,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_price_decorator = return_price_batch_decorator(
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
refresh_period=self.refresh_period,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_args_batch = return_args_batch_decorator(
|
||||
market_aware=False,
|
||||
refresh_period=2,
|
||||
refresh_period=self.refresh_period,
|
||||
delta=timedelta(days=self.days)
|
||||
)
|
||||
|
||||
self.return_price_market_aware = ReturnPriceBatchTransform(
|
||||
market_aware=True,
|
||||
refresh_period=self.refresh_period,
|
||||
days=self.days
|
||||
)
|
||||
|
||||
self.set_slippage(FixedSlippage())
|
||||
|
||||
def handle_data(self, data):
|
||||
@@ -283,6 +291,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.history_return_args.append(
|
||||
self.return_args_batch.handle_data(
|
||||
data, *self.args, **self.kwargs))
|
||||
self.history_return_price_market_aware.append(
|
||||
self.return_price_market_aware.handle_data(data))
|
||||
|
||||
|
||||
class SetPortfolioAlgorithm(TradingAlgorithm):
|
||||
|
||||
@@ -360,9 +360,10 @@ class BatchTransform(EventWindow):
|
||||
|
||||
self.refresh_period = refresh_period
|
||||
self.days = days
|
||||
self.trading_days_since_update = 0
|
||||
|
||||
self.full = False
|
||||
self.last_refresh = None
|
||||
self.last_dt = None
|
||||
|
||||
self.updated = False
|
||||
self.data = None
|
||||
@@ -392,12 +393,15 @@ class BatchTransform(EventWindow):
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_refresh:
|
||||
self.last_refresh = event.dt
|
||||
if not self.last_dt:
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
age = event.dt - self.last_refresh
|
||||
if age.days >= self.refresh_period:
|
||||
if self.last_dt.day != event.dt.day:
|
||||
self.last_dt = event.dt
|
||||
self.trading_days_since_update += 1
|
||||
|
||||
if self.trading_days_since_update >= self.refresh_period:
|
||||
# Create a pandas.Panel (i.e. 3d DataFrame) from the
|
||||
# events in the current window.
|
||||
#
|
||||
@@ -432,7 +436,7 @@ class BatchTransform(EventWindow):
|
||||
self.data = pd.Panel.from_dict(fields, orient='items')
|
||||
|
||||
self.updated = True
|
||||
self.last_refresh = event.dt
|
||||
self.trading_days_since_update = 0
|
||||
else:
|
||||
self.updated = False
|
||||
|
||||
|
||||
@@ -274,7 +274,8 @@ def create_test_df_source():
|
||||
start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
|
||||
end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)
|
||||
index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day)
|
||||
x = np.arange(2., 14.).reshape((6, 2))
|
||||
x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2))
|
||||
|
||||
df = pd.DataFrame(x, index=index, columns=[0, 1])
|
||||
|
||||
return DataFrameSource(df), df
|
||||
|
||||
Reference in New Issue
Block a user