From a4ea33218d97c0712e6a09e529f9ab0722f3152a Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Fri, 26 Apr 2013 23:56:35 -0400 Subject: [PATCH] TST: Move batch transform tests to their own file. From @twiecki's rolling batch transform work. --- tests/test_batchtransform.py | 114 +++++++++++++++++++++++++++++++++++ tests/test_transforms.py | 109 --------------------------------- 2 files changed, 114 insertions(+), 109 deletions(-) create mode 100644 tests/test_batchtransform.py diff --git a/tests/test_batchtransform.py b/tests/test_batchtransform.py new file mode 100644 index 00000000..35e8e2cc --- /dev/null +++ b/tests/test_batchtransform.py @@ -0,0 +1,114 @@ +from collections import deque + +import pytz +import numpy as np +import pandas as pd + +from datetime import datetime +from unittest import TestCase + +from zipline.utils.test_utils import setup_logger + +import zipline.utils.factory as factory + +from zipline.test_algorithms import BatchTransformAlgorithm + + +class TestBatchTransform(TestCase): + def setUp(self): + self.sim_params = factory.create_simulation_parameters( + start=datetime(1990, 1, 1, tzinfo=pytz.utc), + end=datetime(1990, 1, 8, tzinfo=pytz.utc) + ) + setup_logger(self) + self.source, self.df = \ + factory.create_test_df_source(self.sim_params) + + def test_event_window(self): + algo = BatchTransformAlgorithm(sim_params=self.sim_params) + algo.run(self.source) + wl = algo.window_length + # The following assertion depend on window length of 3 + self.assertEqual(wl, 3) + self.assertEqual(algo.history_return_price_class[:wl], + [None] * wl, + "First three iterations should return None." + "\n" + + "i.e. no returned values until window is full'" + + "%s" % (algo.history_return_price_class,)) + self.assertEqual(algo.history_return_price_decorator[:wl], + [None] * wl, + "First three iterations should return None." + "\n" + + "i.e. no returned values until window is full'" + + "%s" % (algo.history_return_price_decorator,)) + + # After three Nones, the next value should be a data frame + self.assertTrue(isinstance( + algo.history_return_price_class[wl], + pd.DataFrame) + ) + + # Test whether arbitrary fields can be added to datapanel + field = algo.history_return_arbitrary_fields[-1] + self.assertTrue( + 'arbitrary' in field.items, + 'datapanel should contain column arbitrary' + ) + + self.assertTrue(all( + field['arbitrary'].values.flatten() == + [123] * algo.window_length), + 'arbitrary dataframe should contain only "test"' + ) + + for data in algo.history_return_sid_filter[wl:]: + self.assertIn(0, data.columns) + self.assertNotIn(1, data.columns) + + for data in algo.history_return_field_filter[wl:]: + self.assertIn('price', data.items) + self.assertNotIn('ignore', data.items) + + for data in algo.history_return_field_no_filter[wl:]: + self.assertIn('price', data.items) + self.assertIn('ignore', data.items) + + for data in algo.history_return_ticks[wl:]: + self.assertTrue(isinstance(data, deque)) + + for data in algo.history_return_not_full: + self.assertIsNot(data, None) + + # test overloaded class + for test_history in [algo.history_return_price_class, + algo.history_return_price_decorator]: + # starting at window length, the window should contain + # consecutive (of window length) numbers up till the end. + for i in range(algo.window_length, len(test_history)): + np.testing.assert_array_equal( + range(i - algo.window_length + 1, i + 1), + test_history[i].values.flatten() + ) + + def test_passing_of_args(self): + algo = BatchTransformAlgorithm(1, + kwarg='str', sim_params=self.sim_params) + self.assertEqual(algo.args, (1,)) + self.assertEqual(algo.kwargs, {'kwarg': 'str'}) + + algo.run(self.source) + expected_item = ((1, ), {'kwarg': 'str'}) + self.assertEqual( + algo.history_return_args, + [ + # 1990-01-01 - market holiday, no event + # 1990-01-02 - window not full + None, + # 1990-01-03 - window not full + None, + # 1990-01-04 - window not full, 3rd event + None, + # 1990-01-05 - window now full + expected_item, + # 1990-01-08 - window now full + expected_item + ]) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3fd4948a..c08eb27a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -12,12 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import deque - import pytz import numpy as np -import pandas as pd from datetime import timedelta, datetime from unittest import TestCase @@ -33,8 +29,6 @@ from zipline.transforms import MovingStandardDev from zipline.transforms import Returns import zipline.utils.factory as factory -from zipline.test_algorithms import BatchTransformAlgorithm - def to_dt(msg): return Event({'dt': msg}) @@ -276,106 +270,3 @@ class TestFinanceTransforms(TestCase): self.assertIsNone(v2) continue self.assertEquals(round(v1, 5), round(v2, 5)) - - -############################################################ -# Test BatchTransform - -class TestBatchTransform(TestCase): - def setUp(self): - self.sim_params = factory.create_simulation_parameters( - start=datetime(1990, 1, 1, tzinfo=pytz.utc), - end=datetime(1990, 1, 8, tzinfo=pytz.utc) - ) - setup_logger(self) - self.source, self.df = \ - factory.create_test_df_source(self.sim_params) - - def test_event_window(self): - algo = BatchTransformAlgorithm(sim_params=self.sim_params) - algo.run(self.source) - wl = algo.window_length - # The following assertion depend on window length of 3 - self.assertEqual(wl, 3) - self.assertEqual(algo.history_return_price_class[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + - "i.e. no returned values until window is full'" + - "%s" % (algo.history_return_price_class,)) - self.assertEqual(algo.history_return_price_decorator[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + - "i.e. no returned values until window is full'" + - "%s" % (algo.history_return_price_decorator,)) - - # After three Nones, the next value should be a data frame - self.assertTrue(isinstance( - algo.history_return_price_class[wl], - pd.DataFrame) - ) - - # Test whether arbitrary fields can be added to datapanel - field = algo.history_return_arbitrary_fields[-1] - self.assertTrue( - 'arbitrary' in field.items, - 'datapanel should contain column arbitrary' - ) - - self.assertTrue(all( - field['arbitrary'].values.flatten() == - [123] * algo.window_length), - 'arbitrary dataframe should contain only "test"' - ) - - for data in algo.history_return_sid_filter[wl:]: - self.assertIn(0, data.columns) - self.assertNotIn(1, data.columns) - - for data in algo.history_return_field_filter[wl:]: - self.assertIn('price', data.items) - self.assertNotIn('ignore', data.items) - - for data in algo.history_return_field_no_filter[wl:]: - self.assertIn('price', data.items) - self.assertIn('ignore', data.items) - - for data in algo.history_return_ticks[wl:]: - self.assertTrue(isinstance(data, deque)) - - for data in algo.history_return_not_full: - self.assertIsNot(data, None) - - # test overloaded class - for test_history in [algo.history_return_price_class, - algo.history_return_price_decorator]: - # starting at window length, the window should contain - # consecutive (of window length) numbers up till the end. - for i in range(algo.window_length, len(test_history)): - np.testing.assert_array_equal( - range(i - algo.window_length + 1, i + 1), - test_history[i].values.flatten() - ) - - def test_passing_of_args(self): - algo = BatchTransformAlgorithm(1, - kwarg='str', sim_params=self.sim_params) - self.assertEqual(algo.args, (1,)) - self.assertEqual(algo.kwargs, {'kwarg': 'str'}) - - algo.run(self.source) - expected_item = ((1, ), {'kwarg': 'str'}) - self.assertEqual( - algo.history_return_args, - [ - # 1990-01-01 - market holiday, no event - # 1990-01-02 - window not full - None, - # 1990-01-03 - window not full - None, - # 1990-01-04 - window not full, 3rd event - None, - # 1990-01-05 - window now full - expected_item, - # 1990-01-08 - window now full - expected_item - ])