mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 01:21:18 +08:00
MAINT: Allow algo.run() to write to db
This commit is contained in:
+30
-10
@@ -489,7 +489,8 @@ class TestTransformAlgorithm(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
self.sim_params = factory.create_simulation_parameters(num_days=4)
|
||||
trading.environment.write_data(equities_identifiers=[0, 1, 133])
|
||||
trading.environment = trading.TradingEnvironment()
|
||||
trading.environment.write_data(equities_identifiers=[133])
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
@@ -499,7 +500,6 @@ class TestTransformAlgorithm(TestCase):
|
||||
self.sim_params
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
self.df_source, self.df = \
|
||||
factory.create_test_df_source(self.sim_params)
|
||||
|
||||
@@ -526,13 +526,14 @@ class TestTransformAlgorithm(TestCase):
|
||||
algo.run(self.source)
|
||||
|
||||
def test_multi_source_as_input(self):
|
||||
trading.environment.write_data(equities_identifiers=[0, 1])
|
||||
sim_params = SimulationParameters(
|
||||
self.df.index[0],
|
||||
self.df.index[-1]
|
||||
)
|
||||
algo = TestRegisterTransformAlgorithm(
|
||||
sim_params=sim_params,
|
||||
sids=[0, 1, 133]
|
||||
sids=[0, 1]
|
||||
)
|
||||
algo.run([self.source, self.df_source], overwrite_sim_params=False)
|
||||
self.assertEqual(len(algo.sources), 2)
|
||||
@@ -540,7 +541,6 @@ class TestTransformAlgorithm(TestCase):
|
||||
def test_df_as_input(self):
|
||||
algo = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
)
|
||||
algo.run(self.df)
|
||||
assert isinstance(algo.sources[0], DataFrameSource)
|
||||
@@ -553,13 +553,23 @@ class TestTransformAlgorithm(TestCase):
|
||||
assert isinstance(algo.sources[0], DataPanelSource)
|
||||
|
||||
def test_run_twice(self):
|
||||
algo = TestRegisterTransformAlgorithm(
|
||||
algo1 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
)
|
||||
|
||||
res1 = algo.run(self.df)
|
||||
res2 = algo.run(self.df)
|
||||
res1 = algo1.run(self.df)
|
||||
|
||||
# Create a new trading environment
|
||||
trading.environment = trading.TradingEnvironment()
|
||||
# Create a new trading algorithm, which will
|
||||
# use the newly instantiated environment.
|
||||
algo2 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
)
|
||||
|
||||
res2 = algo2.run(self.df)
|
||||
|
||||
np.testing.assert_array_equal(res1, res2)
|
||||
|
||||
@@ -617,6 +627,7 @@ class TestTransformAlgorithm(TestCase):
|
||||
'order_target_value']
|
||||
|
||||
for name in method_names_to_test:
|
||||
trading.environment = trading.TradingEnvironment()
|
||||
algo = TestOrderStyleForwardingAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
instant_fill=False,
|
||||
@@ -630,6 +641,7 @@ class TestTransformAlgorithm(TestCase):
|
||||
algo.run(self.df)
|
||||
|
||||
def test_minute_data(self):
|
||||
trading.environment.write_data(equities_identifiers=[0, 1])
|
||||
source = RandomWalkSource(freq='minute',
|
||||
start=pd.Timestamp('2000-1-3',
|
||||
tz='UTC'),
|
||||
@@ -646,7 +658,7 @@ class TestPositions(TestCase):
|
||||
def setUp(self):
|
||||
setup_logger(self)
|
||||
self.sim_params = factory.create_simulation_parameters(num_days=4)
|
||||
trading.environment.write_data(equities_identifiers=[0, 1, 133])
|
||||
trading.environment.write_data(equities_identifiers=[1, 133])
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
1,
|
||||
@@ -691,10 +703,11 @@ class TestPositions(TestCase):
|
||||
class TestAlgoScript(TestCase):
|
||||
def setUp(self):
|
||||
days = 251
|
||||
# Note that create_simulation_parameters creates a new TradingEnvironment
|
||||
# Note that create_simulation_parameters creates
|
||||
# a new TradingEnvironment
|
||||
self.sim_params = factory.create_simulation_parameters(num_days=days)
|
||||
setup_logger(self)
|
||||
trading.environment.write_data(equities_identifiers=[0, 1, 133])
|
||||
trading.environment.write_data(equities_identifiers=[1, 133])
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0] * days,
|
||||
@@ -755,6 +768,7 @@ class TestAlgoScript(TestCase):
|
||||
def test_fixed_slippage(self):
|
||||
# verify order -> transaction -> portfolio position.
|
||||
# --------------
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script="""
|
||||
from zipline.api import (slippage,
|
||||
@@ -809,6 +823,7 @@ def handle_data(context, data):
|
||||
def test_volshare_slippage(self):
|
||||
# verify order -> transaction -> portfolio position.
|
||||
# --------------
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script="""
|
||||
from zipline.api import *
|
||||
@@ -876,6 +891,7 @@ def handle_data(context, data):
|
||||
self.zipline_test_config['algorithm'] = test_algo
|
||||
self.zipline_test_config['trade_count'] = 200
|
||||
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
zipline = simfactory.create_test_zipline(
|
||||
**self.zipline_test_config)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
@@ -902,6 +918,7 @@ def handle_data(context, data):
|
||||
test_algo.record(foo=MagicMock())
|
||||
|
||||
def _algo_record_float_magic_should_pass(self, var_type):
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script=record_float_magic % var_type,
|
||||
sim_params=self.sim_params,
|
||||
@@ -928,6 +945,7 @@ def handle_data(context, data):
|
||||
Only test that order methods can be called without error.
|
||||
Correct filling of orders is tested in zipline.
|
||||
"""
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script=call_all_order_methods,
|
||||
sim_params=self.sim_params,
|
||||
@@ -959,6 +977,7 @@ def handle_data(context, data):
|
||||
"""
|
||||
Test that accessing portfolio in init doesn't break.
|
||||
"""
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script=access_portfolio_in_init,
|
||||
sim_params=self.sim_params,
|
||||
@@ -977,6 +996,7 @@ def handle_data(context, data):
|
||||
"""
|
||||
Test that accessing account in init doesn't break.
|
||||
"""
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
test_algo = TradingAlgorithm(
|
||||
script=access_account_in_init,
|
||||
sim_params=self.sim_params,
|
||||
|
||||
@@ -31,6 +31,7 @@ from six import (
|
||||
)
|
||||
from operator import attrgetter
|
||||
|
||||
|
||||
from zipline.errors import (
|
||||
AddTermPostInit,
|
||||
OrderDuringInitialize,
|
||||
@@ -192,7 +193,6 @@ class TradingAlgorithm(object):
|
||||
|
||||
# set the capital base
|
||||
self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
|
||||
|
||||
self.sim_params = kwargs.pop('sim_params', None)
|
||||
if self.sim_params is None:
|
||||
self.sim_params = create_simulation_parameters(
|
||||
@@ -500,6 +500,8 @@ class TradingAlgorithm(object):
|
||||
# if DataFrame provided, map columns to sids and wrap
|
||||
# in DataFrameSource
|
||||
copy_frame = source.copy()
|
||||
self.trading_environment.write_data(
|
||||
equities_identifiers=source.columns)
|
||||
copy_frame.columns = \
|
||||
self.asset_finder.map_identifier_index_to_sids(
|
||||
source.columns, source.index[0]
|
||||
@@ -510,6 +512,8 @@ class TradingAlgorithm(object):
|
||||
# If Panel provided, map items to sids and wrap
|
||||
# in DataPanelSource
|
||||
copy_panel = source.copy()
|
||||
self.trading_environment.write_data(
|
||||
equities_identifiers=source.items)
|
||||
copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
|
||||
source.items, source.major_axis[0]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user