mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 05:16:54 +08:00
BUG: Refactored batch_transform unittests and fixed some bugs.
This commit is contained in:
@@ -340,27 +340,19 @@ class TestBatchTransform(TestCase):
|
||||
)
|
||||
|
||||
self.assertTrue(all(
|
||||
field['arbitrary'].values.flatten() == ['test'] * 8),
|
||||
field['arbitrary'].values.flatten() ==
|
||||
['test'] * algo.window_length),
|
||||
'arbitrary dataframe should contain only "test"'
|
||||
)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
algo.history_return_price_decorator]:
|
||||
np.testing.assert_array_equal(
|
||||
range(2, 8),
|
||||
test_history[2].values.flatten()
|
||||
)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
range(2, 8),
|
||||
test_history[3].values.flatten()
|
||||
)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
range(4, 12),
|
||||
test_history[4].values.flatten()
|
||||
)
|
||||
for i in range(3, 6):
|
||||
np.testing.assert_array_equal(
|
||||
range(i - algo.window_length + 1, i + 1),
|
||||
test_history[i].values.flatten()
|
||||
)
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm(1, kwarg='str')
|
||||
@@ -371,8 +363,8 @@ class TestBatchTransform(TestCase):
|
||||
expected_item = ((1, ), {'kwarg': 'str'})
|
||||
self.assertEqual(
|
||||
algo.history_return_args,
|
||||
[None, None, expected_item, expected_item,
|
||||
expected_item, expected_item])
|
||||
[None, None, None, expected_item, expected_item,
|
||||
expected_item])
|
||||
|
||||
|
||||
class TestBatchTransformMarketAware(TestCase):
|
||||
|
||||
@@ -214,7 +214,6 @@ class TimeoutAlgorithm(TradingAlgorithm):
|
||||
time.sleep(100)
|
||||
pass
|
||||
|
||||
from datetime import timedelta
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import BatchTransform, batch_transform
|
||||
from zipline.transforms import MovingAverage
|
||||
@@ -237,6 +236,7 @@ class TestRegisterTransformAlgorithm(TradingAlgorithm):
|
||||
|
||||
class ReturnPriceBatchTransform(BatchTransform):
|
||||
def get_value(self, data):
|
||||
assert data.shape[1] == self.window_length
|
||||
return data.price
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ def return_data(data, *args, **kwargs):
|
||||
|
||||
class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.refresh_period = kwargs.pop('refresh_period', 2)
|
||||
self.refresh_period = kwargs.pop('refresh_period', 1)
|
||||
self.window_length = kwargs.pop('window_length', 3)
|
||||
|
||||
self.args = args
|
||||
|
||||
+17
-16
@@ -345,7 +345,7 @@ class BatchTransform(EventWindow):
|
||||
self.last_dt = None
|
||||
|
||||
self.updated = False
|
||||
self.data = None
|
||||
self.cached = None
|
||||
|
||||
self.field_names = None
|
||||
|
||||
@@ -373,20 +373,22 @@ class BatchTransform(EventWindow):
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_dt:
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
def _extract_field_names(self, event):
|
||||
# extract field names from sids (price, volume etc), make sure
|
||||
# every sid has the same fields.
|
||||
sid_keys = [set(sid.keys()) for sid in event.data.itervalues()]
|
||||
assert sid_keys[0] == set.intersection(*sid_keys),\
|
||||
"Each sid must have the same keys."
|
||||
if self.field_names is None:
|
||||
unwanted_fields = set(['portfolio', 'sid', 'dt', 'type',
|
||||
'datetime'])
|
||||
self.field_names = sid_keys[0] - unwanted_fields
|
||||
|
||||
unwanted_fields = set(['portfolio', 'sid', 'dt', 'type',
|
||||
'datetime'])
|
||||
return sid_keys[0] - unwanted_fields
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_dt:
|
||||
self.field_names = self._extract_field_names(event)
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
# update trading day counters
|
||||
if self.last_dt.day != event.dt.day:
|
||||
@@ -398,13 +400,11 @@ class BatchTransform(EventWindow):
|
||||
self.trading_days_total >= self.window_length and
|
||||
self.trading_days_since_update >= self.refresh_period
|
||||
):
|
||||
|
||||
# Create datapanel of running event window.
|
||||
self.data = self.get_data()
|
||||
# Setting updated to True will cause get_transform_value()
|
||||
# to call the user-defined batch-transform with the most
|
||||
# recent datapanel
|
||||
self.updated = True
|
||||
self.full = True
|
||||
self.trading_days_since_update = 0
|
||||
else:
|
||||
self.updated = False
|
||||
@@ -427,7 +427,8 @@ class BatchTransform(EventWindow):
|
||||
fields = {}
|
||||
|
||||
for field_name in self.field_names:
|
||||
sids = self.ticks[0].data.keys()
|
||||
# Extract all used sids
|
||||
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
|
||||
|
||||
values_per_sid = {}
|
||||
|
||||
@@ -471,11 +472,11 @@ class BatchTransform(EventWindow):
|
||||
has actually been updated. Otherwise, the previously, cached
|
||||
value will be returned.
|
||||
"""
|
||||
if self.data is None:
|
||||
if not self.full:
|
||||
return None
|
||||
|
||||
if self.updated:
|
||||
self.cached = self.compute_transform_value(self.data,
|
||||
self.cached = self.compute_transform_value(self.get_data(),
|
||||
*args, **kwargs)
|
||||
|
||||
return self.cached
|
||||
|
||||
@@ -271,9 +271,9 @@ def create_test_df_source():
|
||||
start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
|
||||
end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)
|
||||
index = pd.DatetimeIndex(start=start, end=end, freq=pd.datetools.day)
|
||||
x = np.arange(2., len(index) * 2 + 2).reshape((-1, 2))
|
||||
x = np.arange(0, len(index))
|
||||
|
||||
df = pd.DataFrame(x, index=index, columns=[0, 1])
|
||||
df = pd.DataFrame(x, index=index, columns=[0])
|
||||
|
||||
return DataFrameSource(df), df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user