BUG: Refactored batch_transform unittests and fixed some bugs.

This commit is contained in:
Thomas Wiecki
2012-12-06 12:36:47 -05:00
parent c69858f8b9
commit 5f6839beea
4 changed files with 30 additions and 37 deletions
+9 -17
View File
@@ -340,27 +340,19 @@ class TestBatchTransform(TestCase):
)
self.assertTrue(all(
field['arbitrary'].values.flatten() == ['test'] * 8),
field['arbitrary'].values.flatten() ==
['test'] * algo.window_length),
'arbitrary dataframe should contain only "test"'
)
# test overloaded class
for test_history in [algo.history_return_price_class,
algo.history_return_price_decorator]:
np.testing.assert_array_equal(
range(2, 8),
test_history[2].values.flatten()
)
np.testing.assert_array_equal(
range(2, 8),
test_history[3].values.flatten()
)
np.testing.assert_array_equal(
range(4, 12),
test_history[4].values.flatten()
)
for i in range(3, 6):
np.testing.assert_array_equal(
range(i - algo.window_length + 1, i + 1),
test_history[i].values.flatten()
)
def test_passing_of_args(self):
algo = BatchTransformAlgorithm(1, kwarg='str')
@@ -371,8 +363,8 @@ class TestBatchTransform(TestCase):
expected_item = ((1, ), {'kwarg': 'str'})
self.assertEqual(
algo.history_return_args,
[None, None, expected_item, expected_item,
expected_item, expected_item])
[None, None, None, expected_item, expected_item,
expected_item])
class TestBatchTransformMarketAware(TestCase):
+2 -2
View File
@@ -214,7 +214,6 @@ class TimeoutAlgorithm(TradingAlgorithm):
time.sleep(100)
pass
from datetime import timedelta
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import BatchTransform, batch_transform
from zipline.transforms import MovingAverage
@@ -237,6 +236,7 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
# Batch transform whose computed value is simply the ``price`` attribute of
# the data window it is handed.
class ReturnPriceBatchTransform(BatchTransform):
def get_value(self, data):
# Fail fast if the second dimension of the window does not match the
# transform's configured ``window_length``.
assert data.shape[1] == self.window_length
# Hand the price data back unchanged.
return data.price
@@ -257,7 +257,7 @@ def return_data(data, *args, **kwargs):
class BatchTransformAlgorithm(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.refresh_period = kwargs.pop('refresh_period', 2)
self.refresh_period = kwargs.pop('refresh_period', 1)
self.window_length = kwargs.pop('window_length', 3)
self.args = args
+17 -16
View File
@@ -345,7 +345,7 @@ class BatchTransform(EventWindow):
self.last_dt = None
self.updated = False
self.data = None
self.cached = None
self.field_names = None
@@ -373,20 +373,22 @@ class BatchTransform(EventWindow):
# return newly computed or cached value
return self.get_transform_value(*args, **kwargs)
def handle_add(self, event):
if not self.last_dt:
self.last_dt = event.dt
return
def _extract_field_names(self, event):
    """Return the per-sid field names (price, volume etc.) found in ``event``.

    The key set of every sid entry in ``event.data`` is collected, all sids
    are checked to expose exactly the same keys, and those keys are returned
    with the bookkeeping fields ('portfolio', 'sid', 'dt', 'type',
    'datetime') removed.

    The result is returned rather than stored on ``self``; the caller
    decides where to keep it.

    Raises:
        AssertionError: if the sids do not all have the same keys.
        IndexError: if ``event.data`` contains no sids.
    """
    sid_keys = [set(sid.keys()) for sid in event.data.itervalues()]

    # Compare every key set with the first one. Comparing the first set
    # with the intersection of all sets would only prove that it is a
    # subset of the others, letting sids with extra fields slip through.
    assert all(keys == sid_keys[0] for keys in sid_keys), \
        "Each sid must have the same keys."

    unwanted_fields = set(['portfolio', 'sid', 'dt', 'type', 'datetime'])
    return sid_keys[0] - unwanted_fields
def handle_add(self, event):
if not self.last_dt:
self.field_names = self._extract_field_names(event)
self.last_dt = event.dt
return
# update trading day counters
if self.last_dt.day != event.dt.day:
@@ -398,13 +400,11 @@ class BatchTransform(EventWindow):
self.trading_days_total >= self.window_length and
self.trading_days_since_update >= self.refresh_period
):
# Create datapanel of running event window.
self.data = self.get_data()
# Setting updated to True will cause get_transform_value()
# to call the user-defined batch-transform with the most
# recent datapanel
self.updated = True
self.full = True
self.trading_days_since_update = 0
else:
self.updated = False
@@ -427,7 +427,8 @@ class BatchTransform(EventWindow):
fields = {}
for field_name in self.field_names:
sids = self.ticks[0].data.keys()
# Extract all used sids
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
values_per_sid = {}
@@ -471,11 +472,11 @@ class BatchTransform(EventWindow):
has actually been updated. Otherwise, the previously cached
value will be returned.
"""
if self.data is None:
if not self.full:
return None
if self.updated:
self.cached = self.compute_transform_value(self.data,
self.cached = self.compute_transform_value(self.get_data(),
*args, **kwargs)
return self.cached
+2 -2
View File
@@ -271,9 +271,9 @@ def create_test_df_source():
start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)
index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day)
x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2))
x = np.arange(0, len(index))
df = pd.DataFrame(x, index=index, columns=[0, 1])
df = pd.DataFrame(x, index=index, columns=[0])
return DataFrameSource(df), df