mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 06:28:42 +08:00
FIX: Check types of args passed to api methods on data
This commit is contained in:
@@ -107,7 +107,22 @@ from zipline.test_algorithms import (
|
||||
call_without_kwargs,
|
||||
call_with_bad_kwargs_current,
|
||||
call_with_bad_kwargs_history,
|
||||
bad_type_history_assets,
|
||||
bad_type_history_fields,
|
||||
bad_type_history_bar_count,
|
||||
bad_type_history_frequency,
|
||||
bad_type_current_assets,
|
||||
bad_type_current_fields,
|
||||
bad_type_can_trade_assets,
|
||||
bad_type_is_stale_assets,
|
||||
bad_type_history_assets_kwarg,
|
||||
bad_type_history_fields_kwarg,
|
||||
bad_type_history_bar_count_kwarg,
|
||||
bad_type_history_frequency_kwarg,
|
||||
bad_type_current_assets_kwarg,
|
||||
bad_type_current_fields_kwarg,
|
||||
no_handle_data)
|
||||
|
||||
from zipline.testing import (
|
||||
make_jagged_equity_info,
|
||||
to_utc,
|
||||
@@ -1431,6 +1446,29 @@ class TestBeforeTradingStart(TestCase):
|
||||
|
||||
class TestAlgoScript(TestCase):
|
||||
|
||||
ARG_TYPE_TEST_CASES = [
|
||||
('history__assets', (bad_type_history_assets, 'Asset, str', True)),
|
||||
('history__fields', (bad_type_history_fields, 'str', True)),
|
||||
('history__bar_count', (bad_type_history_bar_count, 'int', False)),
|
||||
('history__frequency', (bad_type_history_frequency, 'str', False)),
|
||||
('current__assets', (bad_type_current_assets, 'Asset, str', True)),
|
||||
('current__fields', (bad_type_current_fields, 'str', True)),
|
||||
('is_stale__assets', (bad_type_is_stale_assets, 'Asset', True)),
|
||||
('can_trade__assets', (bad_type_can_trade_assets, 'Asset', True)),
|
||||
('history_kwarg__assets',
|
||||
(bad_type_history_assets_kwarg, 'Asset, str', True)),
|
||||
('history_kwarg__fields',
|
||||
(bad_type_history_fields_kwarg, 'str', True)),
|
||||
('history_kwarg__bar_count',
|
||||
(bad_type_history_bar_count_kwarg, 'int', False)),
|
||||
('history_kwarg__frequency',
|
||||
(bad_type_history_frequency_kwarg, 'str', False)),
|
||||
('current_kwarg__assets',
|
||||
(bad_type_current_assets_kwarg, 'Asset, str', True)),
|
||||
('current_kwarg__fields',
|
||||
(bad_type_current_fields_kwarg, 'str', True)),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_logger(cls)
|
||||
@@ -1846,6 +1884,27 @@ def handle_data(context, data):
|
||||
self.assertEqual("%s() got an unexpected keyword argument 'blahblah'"
|
||||
% name, cm.exception.args[0])
|
||||
|
||||
@parameterized.expand(ARG_TYPE_TEST_CASES)
|
||||
def test_arg_types(self, name, inputs):
|
||||
|
||||
keyword = name.split('__')[1]
|
||||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
algo = TradingAlgorithm(
|
||||
script=inputs[0],
|
||||
sim_params=self.sim_params,
|
||||
env=self.env
|
||||
)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
expected = "Expected %s argument to be of type %s%s" % (
|
||||
keyword,
|
||||
'or iterable of type ' if inputs[2] else '',
|
||||
inputs[1]
|
||||
)
|
||||
|
||||
self.assertEqual(expected, cm.exception.args[0])
|
||||
|
||||
|
||||
class TestGetDatetime(TestCase):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user