MAINT: Allow algo.run() to write to db

This commit is contained in:
Stewart Douglas
2015-08-14 14:39:41 -04:00
committed by jfkirk
parent 1ef2274d11
commit 7be2cf8652
2 changed files with 35 additions and 11 deletions
+30 -10
View File
@@ -489,7 +489,8 @@ class TestTransformAlgorithm(TestCase):
def setUp(self):
setup_logger(self)
self.sim_params = factory.create_simulation_parameters(num_days=4)
trading.environment.write_data(equities_identifiers=[0, 1, 133])
trading.environment = trading.TradingEnvironment()
trading.environment.write_data(equities_identifiers=[133])
trade_history = factory.create_trade_history(
133,
@@ -499,7 +500,6 @@ class TestTransformAlgorithm(TestCase):
self.sim_params
)
self.source = SpecificEquityTrades(event_list=trade_history)
self.df_source, self.df = \
factory.create_test_df_source(self.sim_params)
@@ -526,13 +526,14 @@ class TestTransformAlgorithm(TestCase):
algo.run(self.source)
def test_multi_source_as_input(self):
trading.environment.write_data(equities_identifiers=[0, 1])
sim_params = SimulationParameters(
self.df.index[0],
self.df.index[-1]
)
algo = TestRegisterTransformAlgorithm(
sim_params=sim_params,
sids=[0, 1, 133]
sids=[0, 1]
)
algo.run([self.source, self.df_source], overwrite_sim_params=False)
self.assertEqual(len(algo.sources), 2)
@@ -540,7 +541,6 @@ class TestTransformAlgorithm(TestCase):
def test_df_as_input(self):
algo = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
)
algo.run(self.df)
assert isinstance(algo.sources[0], DataFrameSource)
@@ -553,13 +553,23 @@ class TestTransformAlgorithm(TestCase):
assert isinstance(algo.sources[0], DataPanelSource)
def test_run_twice(self):
algo = TestRegisterTransformAlgorithm(
algo1 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
)
res1 = algo.run(self.df)
res2 = algo.run(self.df)
res1 = algo1.run(self.df)
# Create a new trading environment
trading.environment = trading.TradingEnvironment()
# Create a new trading algorithm, which will
# use the newly instantiated environment.
algo2 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
)
res2 = algo2.run(self.df)
np.testing.assert_array_equal(res1, res2)
@@ -617,6 +627,7 @@ class TestTransformAlgorithm(TestCase):
'order_target_value']
for name in method_names_to_test:
trading.environment = trading.TradingEnvironment()
algo = TestOrderStyleForwardingAlgorithm(
sim_params=self.sim_params,
instant_fill=False,
@@ -630,6 +641,7 @@ class TestTransformAlgorithm(TestCase):
algo.run(self.df)
def test_minute_data(self):
trading.environment.write_data(equities_identifiers=[0, 1])
source = RandomWalkSource(freq='minute',
start=pd.Timestamp('2000-1-3',
tz='UTC'),
@@ -646,7 +658,7 @@ class TestPositions(TestCase):
def setUp(self):
setup_logger(self)
self.sim_params = factory.create_simulation_parameters(num_days=4)
trading.environment.write_data(equities_identifiers=[0, 1, 133])
trading.environment.write_data(equities_identifiers=[1, 133])
trade_history = factory.create_trade_history(
1,
@@ -691,10 +703,11 @@ class TestPositions(TestCase):
class TestAlgoScript(TestCase):
def setUp(self):
days = 251
# Note that create_simulation_parameters creates a new TradingEnvironment
# Note that create_simulation_parameters creates
# a new TradingEnvironment
self.sim_params = factory.create_simulation_parameters(num_days=days)
setup_logger(self)
trading.environment.write_data(equities_identifiers=[0, 1, 133])
trading.environment.write_data(equities_identifiers=[1, 133])
trade_history = factory.create_trade_history(
133,
[10.0] * days,
@@ -755,6 +768,7 @@ class TestAlgoScript(TestCase):
def test_fixed_slippage(self):
# verify order -> transaction -> portfolio position.
# --------------
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script="""
from zipline.api import (slippage,
@@ -809,6 +823,7 @@ def handle_data(context, data):
def test_volshare_slippage(self):
# verify order -> transaction -> portfolio position.
# --------------
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script="""
from zipline.api import *
@@ -876,6 +891,7 @@ def handle_data(context, data):
self.zipline_test_config['algorithm'] = test_algo
self.zipline_test_config['trade_count'] = 200
trading.environment.write_data(equities_identifiers=[0])
zipline = simfactory.create_test_zipline(
**self.zipline_test_config)
output, _ = drain_zipline(self, zipline)
@@ -902,6 +918,7 @@ def handle_data(context, data):
test_algo.record(foo=MagicMock())
def _algo_record_float_magic_should_pass(self, var_type):
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script=record_float_magic % var_type,
sim_params=self.sim_params,
@@ -928,6 +945,7 @@ def handle_data(context, data):
Only test that order methods can be called without error.
Correct filling of orders is tested in zipline.
"""
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script=call_all_order_methods,
sim_params=self.sim_params,
@@ -959,6 +977,7 @@ def handle_data(context, data):
"""
Test that accessing portfolio in init doesn't break.
"""
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script=access_portfolio_in_init,
sim_params=self.sim_params,
@@ -977,6 +996,7 @@ def handle_data(context, data):
"""
Test that accessing account in init doesn't break.
"""
trading.environment.write_data(equities_identifiers=[0])
test_algo = TradingAlgorithm(
script=access_account_in_init,
sim_params=self.sim_params,
+5 -1
View File
@@ -31,6 +31,7 @@ from six import (
)
from operator import attrgetter
from zipline.errors import (
AddTermPostInit,
OrderDuringInitialize,
@@ -192,7 +193,6 @@ class TradingAlgorithm(object):
# set the capital base
self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE)
self.sim_params = kwargs.pop('sim_params', None)
if self.sim_params is None:
self.sim_params = create_simulation_parameters(
@@ -500,6 +500,8 @@ class TradingAlgorithm(object):
# if DataFrame provided, map columns to sids and wrap
# in DataFrameSource
copy_frame = source.copy()
self.trading_environment.write_data(
equities_identifiers=source.columns)
copy_frame.columns = \
self.asset_finder.map_identifier_index_to_sids(
source.columns, source.index[0]
@@ -510,6 +512,8 @@ class TradingAlgorithm(object):
# If Panel provided, map items to sids and wrap
# in DataPanelSource
copy_panel = source.copy()
self.trading_environment.write_data(
equities_identifiers=source.items)
copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
source.items, source.major_axis[0]
)