mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 08:51:44 +08:00
TST: Move batch transform tests to their own file.
From @twiecki's rolling batch transform work.
This commit is contained in:
@@ -0,0 +1,114 @@
|
||||
from collections import deque
|
||||
|
||||
import pytz
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from unittest import TestCase
|
||||
|
||||
from zipline.utils.test_utils import setup_logger
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
|
||||
from zipline.test_algorithms import BatchTransformAlgorithm
|
||||
|
||||
|
||||
class TestBatchTransform(TestCase):
|
||||
def setUp(self):
|
||||
self.sim_params = factory.create_simulation_parameters(
|
||||
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
|
||||
end=datetime(1990, 1, 8, tzinfo=pytz.utc)
|
||||
)
|
||||
setup_logger(self)
|
||||
self.source, self.df = \
|
||||
factory.create_test_df_source(self.sim_params)
|
||||
|
||||
def test_event_window(self):
|
||||
algo = BatchTransformAlgorithm(sim_params=self.sim_params)
|
||||
algo.run(self.source)
|
||||
wl = algo.window_length
|
||||
# The following assertion depend on window length of 3
|
||||
self.assertEqual(wl, 3)
|
||||
self.assertEqual(algo.history_return_price_class[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_class,))
|
||||
self.assertEqual(algo.history_return_price_decorator[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_decorator,))
|
||||
|
||||
# After three Nones, the next value should be a data frame
|
||||
self.assertTrue(isinstance(
|
||||
algo.history_return_price_class[wl],
|
||||
pd.DataFrame)
|
||||
)
|
||||
|
||||
# Test whether arbitrary fields can be added to datapanel
|
||||
field = algo.history_return_arbitrary_fields[-1]
|
||||
self.assertTrue(
|
||||
'arbitrary' in field.items,
|
||||
'datapanel should contain column arbitrary'
|
||||
)
|
||||
|
||||
self.assertTrue(all(
|
||||
field['arbitrary'].values.flatten() ==
|
||||
[123] * algo.window_length),
|
||||
'arbitrary dataframe should contain only "test"'
|
||||
)
|
||||
|
||||
for data in algo.history_return_sid_filter[wl:]:
|
||||
self.assertIn(0, data.columns)
|
||||
self.assertNotIn(1, data.columns)
|
||||
|
||||
for data in algo.history_return_field_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertNotIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_field_no_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_ticks[wl:]:
|
||||
self.assertTrue(isinstance(data, deque))
|
||||
|
||||
for data in algo.history_return_not_full:
|
||||
self.assertIsNot(data, None)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
algo.history_return_price_decorator]:
|
||||
# starting at window length, the window should contain
|
||||
# consecutive (of window length) numbers up till the end.
|
||||
for i in range(algo.window_length, len(test_history)):
|
||||
np.testing.assert_array_equal(
|
||||
range(i - algo.window_length + 1, i + 1),
|
||||
test_history[i].values.flatten()
|
||||
)
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm(1,
|
||||
kwarg='str', sim_params=self.sim_params)
|
||||
self.assertEqual(algo.args, (1,))
|
||||
self.assertEqual(algo.kwargs, {'kwarg': 'str'})
|
||||
|
||||
algo.run(self.source)
|
||||
expected_item = ((1, ), {'kwarg': 'str'})
|
||||
self.assertEqual(
|
||||
algo.history_return_args,
|
||||
[
|
||||
# 1990-01-01 - market holiday, no event
|
||||
# 1990-01-02 - window not full
|
||||
None,
|
||||
# 1990-01-03 - window not full
|
||||
None,
|
||||
# 1990-01-04 - window not full, 3rd event
|
||||
None,
|
||||
# 1990-01-05 - window now full
|
||||
expected_item,
|
||||
# 1990-01-08 - window now full
|
||||
expected_item
|
||||
])
|
||||
@@ -12,12 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import pytz
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from unittest import TestCase
|
||||
@@ -33,8 +29,6 @@ from zipline.transforms import MovingStandardDev
|
||||
from zipline.transforms import Returns
|
||||
import zipline.utils.factory as factory
|
||||
|
||||
from zipline.test_algorithms import BatchTransformAlgorithm
|
||||
|
||||
|
||||
def to_dt(msg):
|
||||
return Event({'dt': msg})
|
||||
@@ -276,106 +270,3 @@ class TestFinanceTransforms(TestCase):
|
||||
self.assertIsNone(v2)
|
||||
continue
|
||||
self.assertEquals(round(v1, 5), round(v2, 5))
|
||||
|
||||
|
||||
############################################################
|
||||
# Test BatchTransform
|
||||
|
||||
class TestBatchTransform(TestCase):
|
||||
def setUp(self):
|
||||
self.sim_params = factory.create_simulation_parameters(
|
||||
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
|
||||
end=datetime(1990, 1, 8, tzinfo=pytz.utc)
|
||||
)
|
||||
setup_logger(self)
|
||||
self.source, self.df = \
|
||||
factory.create_test_df_source(self.sim_params)
|
||||
|
||||
def test_event_window(self):
|
||||
algo = BatchTransformAlgorithm(sim_params=self.sim_params)
|
||||
algo.run(self.source)
|
||||
wl = algo.window_length
|
||||
# The following assertion depend on window length of 3
|
||||
self.assertEqual(wl, 3)
|
||||
self.assertEqual(algo.history_return_price_class[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_class,))
|
||||
self.assertEqual(algo.history_return_price_decorator[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_decorator,))
|
||||
|
||||
# After three Nones, the next value should be a data frame
|
||||
self.assertTrue(isinstance(
|
||||
algo.history_return_price_class[wl],
|
||||
pd.DataFrame)
|
||||
)
|
||||
|
||||
# Test whether arbitrary fields can be added to datapanel
|
||||
field = algo.history_return_arbitrary_fields[-1]
|
||||
self.assertTrue(
|
||||
'arbitrary' in field.items,
|
||||
'datapanel should contain column arbitrary'
|
||||
)
|
||||
|
||||
self.assertTrue(all(
|
||||
field['arbitrary'].values.flatten() ==
|
||||
[123] * algo.window_length),
|
||||
'arbitrary dataframe should contain only "test"'
|
||||
)
|
||||
|
||||
for data in algo.history_return_sid_filter[wl:]:
|
||||
self.assertIn(0, data.columns)
|
||||
self.assertNotIn(1, data.columns)
|
||||
|
||||
for data in algo.history_return_field_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertNotIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_field_no_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_ticks[wl:]:
|
||||
self.assertTrue(isinstance(data, deque))
|
||||
|
||||
for data in algo.history_return_not_full:
|
||||
self.assertIsNot(data, None)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
algo.history_return_price_decorator]:
|
||||
# starting at window length, the window should contain
|
||||
# consecutive (of window length) numbers up till the end.
|
||||
for i in range(algo.window_length, len(test_history)):
|
||||
np.testing.assert_array_equal(
|
||||
range(i - algo.window_length + 1, i + 1),
|
||||
test_history[i].values.flatten()
|
||||
)
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm(1,
|
||||
kwarg='str', sim_params=self.sim_params)
|
||||
self.assertEqual(algo.args, (1,))
|
||||
self.assertEqual(algo.kwargs, {'kwarg': 'str'})
|
||||
|
||||
algo.run(self.source)
|
||||
expected_item = ((1, ), {'kwarg': 'str'})
|
||||
self.assertEqual(
|
||||
algo.history_return_args,
|
||||
[
|
||||
# 1990-01-01 - market holiday, no event
|
||||
# 1990-01-02 - window not full
|
||||
None,
|
||||
# 1990-01-03 - window not full
|
||||
None,
|
||||
# 1990-01-04 - window not full, 3rd event
|
||||
None,
|
||||
# 1990-01-05 - window now full
|
||||
expected_item,
|
||||
# 1990-01-08 - window now full
|
||||
expected_item
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user