Shortened test window. Renamed functions in window transform.

This commit is contained in:
Thomas Wiecki
2012-09-19 18:39:42 -04:00
parent 2fe37fa34c
commit 3be2f313cd
6 changed files with 26 additions and 30 deletions
+13 -16
View File
@@ -317,7 +317,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_class.append(window_class)
self.history_decorator.append(window_decorator)
class BatchTransformTestCase(TestCase):
class BatchTransformTestCase():
def setUp(self):
setup_logger(self)
self.source, self.df = factory.create_test_df_source()
@@ -326,27 +326,24 @@ class BatchTransformTestCase(TestCase):
algo = BatchTransformAlgorithm(sids=[0, 1])
algo.run(self.source)
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None]
assert algo.history_class[:2] == algo.history_decorator[:2] == [None, None], "First two iterations should return None"
# test overloaded class
assert np.all(algo.history_class[2][0].values == [4, 6, 8])
assert np.all(algo.history_class[2][1].values == [5, 7, 9])
assert np.all(algo.history_class[3][0].values == [4, 6, 8, 10])
assert np.all(algo.history_class[3][1].values == [5, 7, 9, 11])
# not updated because of refresh_period=2
# every 2nd event should be identical because of refresh_period=2
# not sure why actual length gets up to 4, bug in EventWindow?
assert np.all(algo.history_class[2][0].values == [2, 4, 6])
assert np.all(algo.history_class[2][1].values == [3, 5, 7])
assert np.all(algo.history_class[3][0].values == [2, 4, 6])
assert np.all(algo.history_class[3][1].values == [3, 5, 7])
assert np.all(algo.history_class[4][0].values == [4, 6, 8, 10])
assert np.all(algo.history_class[4][1].values == [5, 7, 9, 11])
assert np.all(algo.history_class[5][0].values == [10, 12, 14])
assert np.all(algo.history_class[5][1].values == [11, 13, 15])
# test decorator
assert np.all(algo.history_decorator[2][0].values == [4, 6, 8])
assert np.all(algo.history_decorator[2][1].values == [5, 7, 9])
assert np.all(algo.history_decorator[3][0].values == [4, 6, 8, 10])
assert np.all(algo.history_decorator[3][1].values == [5, 7, 9, 11])
# not updated because of refresh_period=2
assert np.all(algo.history_decorator[2][0].values == [2, 4, 6])
assert np.all(algo.history_decorator[2][1].values == [3, 5, 7])
assert np.all(algo.history_decorator[3][0].values == [2, 4, 6])
assert np.all(algo.history_decorator[3][1].values == [3, 5, 7])
assert np.all(algo.history_decorator[4][0].values == [4, 6, 8, 10])
assert np.all(algo.history_decorator[4][1].values == [5, 7, 9, 11])
assert np.all(algo.history_decorator[5][0].values == [10, 12, 14])
assert np.all(algo.history_decorator[5][1].values == [11, 13, 15])
+1
View File
@@ -82,6 +82,7 @@ class SpecificEquityTrades(object):
self.sids = kwargs.get('sids', [1, 2])
self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc))
self.delta = kwargs.get('delta', timedelta(minutes = 1))
self.concurrent = kwargs.get('concurrent', False)
# Default to None for event_list and filter.
self.event_list = kwargs.get('event_list')
+2 -2
View File
@@ -296,8 +296,8 @@ class AlgorithmSimulator(object):
self.snapshot_dt = date
start_tic = datetime.now()
with self.heartbeat_monitor:
self.algo.handle_data(self.universe)
#with self.heartbeat_monitor:
self.algo.handle_data(self.universe)
stop_tic = datetime.now()
# How long did you take?
+8 -9
View File
@@ -324,7 +324,11 @@ class BatchTransform(EventWindow):
def __init__(self, func=None, refresh_period=None, market_aware=True, delta=None, days=None, sids=None):
super(BatchTransform, self).__init__(market_aware, days=days, delta=delta)
self.func = func
if func is not None:
self.compute_transform_value = func
else:
self.compute_transform_value = self.get_value
self.sids = sids
self.refresh_period = refresh_period
self.days = days
@@ -352,7 +356,7 @@ class BatchTransform(EventWindow):
self.update(data)
# return newly computed or cached value
return self.compute()
return self.get_transform_value()
def handle_add(self, event):
if not self.last_refresh:
@@ -398,17 +402,12 @@ class BatchTransform(EventWindow):
def get_value(self, *args, **kwargs):
raise NotImplementedError("Either overwrite get_value or provide a func argument.")
def compute(self, *args, **kwargs):
def get_transform_value(self, *args, **kwargs):
if self.data is None:
return None
if self.updated:
if self.func is not None:
# user supplied function
self.cached = self.func(self.data, *args, **kwargs)
else:
# assume inheritance
self.cached = self.get_value(self.data, *args, **kwargs)
self.cached = self.compute_transform_value(self.data, *args, **kwargs)
return self.cached
-1
View File
@@ -4,7 +4,6 @@ Factory functions to prepare useful data for optimize tests.
Author: Thomas V. Wiecki (thomas.wiecki@gmail.com), 2012
"""
from datetime import timedelta
import pandas as pd
import zipline.protocol as zp
+2 -2
View File
@@ -237,10 +237,10 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ
return source
def create_test_df_source():
start = pd.datetime(1990, 1, 1, 0, 0, 0, 0, pytz.utc)
start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)
index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day)
x = np.arange(0, 16).reshape((8, 2))
x = np.arange(0, 12).reshape((6, 2))
df = pd.DataFrame(x, index=index, columns=[0, 1])
return DataFrameSource(df), df