Files
catalyst/tests/test_batchtransform.py
T
Eddie Hebert a4ea33218d TST: Move batch transform tests to their own file.
From @twiecki's rolling batch transform work.
2013-04-26 23:56:35 -04:00

115 lines
4.2 KiB
Python

from collections import deque
import pytz
import numpy as np
import pandas as pd
from datetime import datetime
from unittest import TestCase
from zipline.utils.test_utils import setup_logger
import zipline.utils.factory as factory
from zipline.test_algorithms import BatchTransformAlgorithm
class TestBatchTransform(TestCase):
def setUp(self):
self.sim_params = factory.create_simulation_parameters(
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
end=datetime(1990, 1, 8, tzinfo=pytz.utc)
)
setup_logger(self)
self.source, self.df = \
factory.create_test_df_source(self.sim_params)
def test_event_window(self):
algo = BatchTransformAlgorithm(sim_params=self.sim_params)
algo.run(self.source)
wl = algo.window_length
# The following assertion depend on window length of 3
self.assertEqual(wl, 3)
self.assertEqual(algo.history_return_price_class[:wl],
[None] * wl,
"First three iterations should return None." + "\n" +
"i.e. no returned values until window is full'" +
"%s" % (algo.history_return_price_class,))
self.assertEqual(algo.history_return_price_decorator[:wl],
[None] * wl,
"First three iterations should return None." + "\n" +
"i.e. no returned values until window is full'" +
"%s" % (algo.history_return_price_decorator,))
# After three Nones, the next value should be a data frame
self.assertTrue(isinstance(
algo.history_return_price_class[wl],
pd.DataFrame)
)
# Test whether arbitrary fields can be added to datapanel
field = algo.history_return_arbitrary_fields[-1]
self.assertTrue(
'arbitrary' in field.items,
'datapanel should contain column arbitrary'
)
self.assertTrue(all(
field['arbitrary'].values.flatten() ==
[123] * algo.window_length),
'arbitrary dataframe should contain only "test"'
)
for data in algo.history_return_sid_filter[wl:]:
self.assertIn(0, data.columns)
self.assertNotIn(1, data.columns)
for data in algo.history_return_field_filter[wl:]:
self.assertIn('price', data.items)
self.assertNotIn('ignore', data.items)
for data in algo.history_return_field_no_filter[wl:]:
self.assertIn('price', data.items)
self.assertIn('ignore', data.items)
for data in algo.history_return_ticks[wl:]:
self.assertTrue(isinstance(data, deque))
for data in algo.history_return_not_full:
self.assertIsNot(data, None)
# test overloaded class
for test_history in [algo.history_return_price_class,
algo.history_return_price_decorator]:
# starting at window length, the window should contain
# consecutive (of window length) numbers up till the end.
for i in range(algo.window_length, len(test_history)):
np.testing.assert_array_equal(
range(i - algo.window_length + 1, i + 1),
test_history[i].values.flatten()
)
def test_passing_of_args(self):
algo = BatchTransformAlgorithm(1,
kwarg='str', sim_params=self.sim_params)
self.assertEqual(algo.args, (1,))
self.assertEqual(algo.kwargs, {'kwarg': 'str'})
algo.run(self.source)
expected_item = ((1, ), {'kwarg': 'str'})
self.assertEqual(
algo.history_return_args,
[
# 1990-01-01 - market holiday, no event
# 1990-01-02 - window not full
None,
# 1990-01-03 - window not full
None,
# 1990-01-04 - window not full, 3rd event
None,
# 1990-01-05 - window now full
expected_item,
# 1990-01-08 - window now full
expected_item
])