mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 02:43:12 +08:00
TST: Use fixture's trading env for FakeDataPortal or TradingAlgo
to avoid a new trading env needing to download data unnecessarily
This commit is contained in:
+15
-7
@@ -314,6 +314,7 @@ def handle_data(algo, data):
|
||||
initialize=lambda context: None,
|
||||
handle_data=lambda context, data: None,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
# Verify that api methods get resolved dynamically by patching them out
|
||||
@@ -892,7 +893,8 @@ def before_trading_start(context, data):
|
||||
def test_run_twice(self):
|
||||
algo1 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res1 = algo1.run(self.data_portal)
|
||||
@@ -901,7 +903,8 @@ def before_trading_start(context, data):
|
||||
# use the newly instantiated environment.
|
||||
algo2 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res2 = algo2.run(self.data_portal)
|
||||
@@ -1569,15 +1572,16 @@ class TestAlgoScript(WithLogger,
|
||||
|
||||
def test_noop(self):
|
||||
algo = TradingAlgorithm(initialize=initialize_noop,
|
||||
handle_data=handle_data_noop)
|
||||
handle_data=handle_data_noop,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_noop_string(self):
|
||||
algo = TradingAlgorithm(script=noop_algo)
|
||||
algo = TradingAlgorithm(script=noop_algo, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_no_handle_data(self):
|
||||
algo = TradingAlgorithm(script=no_handle_data)
|
||||
algo = TradingAlgorithm(script=no_handle_data, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_api_calls(self):
|
||||
@@ -1593,7 +1597,8 @@ class TestAlgoScript(WithLogger,
|
||||
def test_api_get_environment(self):
|
||||
platform = 'zipline'
|
||||
algo = TradingAlgorithm(script=api_get_environment_algo,
|
||||
platform=platform)
|
||||
platform=platform,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
self.assertEqual(algo.environment, platform)
|
||||
|
||||
@@ -1779,6 +1784,7 @@ def handle_data(context, data):
|
||||
test_algo = TradingAlgorithm(
|
||||
script=record_variables,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
set_algo_instance(test_algo)
|
||||
|
||||
@@ -3785,6 +3791,7 @@ class TestTradingAlgorithm(ZiplineTestCase):
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
analyze=analyze,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
with empty_trading_env() as env:
|
||||
@@ -4642,7 +4649,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase):
|
||||
self.assertEqual(expected_message, w.message)
|
||||
|
||||
|
||||
class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
def test_reject_passing_both_api_methods_and_script(self):
|
||||
script = dedent(
|
||||
@@ -4668,6 +4675,7 @@ class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
TradingAlgorithm(
|
||||
script=script,
|
||||
env=self.env,
|
||||
**{method: lambda *args, **kwargs: None}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user