From dc964a7e7da7e975b22fabfe962d2210704c1d6b Mon Sep 17 00:00:00 2001 From: jfkirk Date: Mon, 24 Aug 2015 16:04:10 -0400 Subject: [PATCH] MAINT: Removes the ability to reference a global TradingEnvironment This commit removes the ability to reference a shared TradingEnvironment through the zipline.finance.trading module. In place, the classes that require a TradingEnvironment, or its child AssetFinder, contain their own references to those objects. This commit also adds serialization utilities that allow for the pickling/unpickling of objects without unintentionally their TradingEnvironments or AssetFinders. --- tests/history_cases.py | 41 +- .../test_us_equity_pricing_loader.py | 11 +- tests/risk/test_risk_cumulative.py | 12 +- tests/risk/test_risk_period.py | 45 +- tests/serialization_cases.py | 34 +- tests/test_algorithm.py | 451 ++++++++++++------ tests/test_algorithm_gen.py | 32 +- tests/test_assets.py | 76 +-- tests/test_batchtransform.py | 69 ++- tests/test_events_through_risk.py | 320 ++++++------- tests/test_examples.py | 4 - tests/test_exception_handling.py | 18 +- tests/test_finance.py | 27 +- tests/test_history.py | 62 ++- tests/test_perf_tracking.py | 311 ++++++------ tests/test_pickle_serialization.py | 11 +- tests/test_rolling_panel.py | 21 +- tests/test_security_list.py | 146 +++--- tests/test_serialization.py | 2 +- tests/test_sources.py | 12 +- tests/test_transforms.py | 5 +- tests/test_transforms_talib.py | 12 +- tests/utils/test_events.py | 57 ++- zipline/algorithm.py | 73 ++- zipline/assets/assets.py | 8 +- zipline/errors.py | 12 + zipline/finance/performance/period.py | 25 +- .../finance/performance/position_tracker.py | 62 ++- zipline/finance/performance/tracker.py | 53 +- zipline/finance/risk/cumulative.py | 29 +- zipline/finance/risk/period.py | 36 +- zipline/finance/risk/report.py | 8 +- zipline/finance/risk/risk.py | 9 +- zipline/finance/trading.py | 120 +---- zipline/gens/tradesimulation.py | 10 +- zipline/history/history.py | 71 ++- zipline/history/history_container.py | 58 +-- zipline/protocol.py | 5 +- zipline/sources/test_source.py | 15 +- zipline/transforms/batch_transform.py | 16 +- zipline/utils/events.py | 107 ++--- zipline/utils/factory.py | 83 ++-- zipline/utils/security_list.py | 15 +- zipline/utils/serialization_utils.py | 60 +++ zipline/utils/simfactory.py | 3 +- 45 files changed, 1484 insertions(+), 1173 deletions(-) diff --git a/tests/history_cases.py b/tests/history_cases.py index ef11dcd9..1db3560f 100644 --- a/tests/history_cases.py +++ b/tests/history_cases.py @@ -10,18 +10,19 @@ from zipline.history.history import HistorySpec from zipline.protocol import BarData from zipline.utils.test_utils import to_utc +_cases_env = TradingEnvironment() + def mixed_frequency_expected_index(count, frequency): """ Helper for enumerating expected indices for test_mixed_frequency. """ - env = TradingEnvironment.instance() minute = MIXED_FREQUENCY_MINUTES[count] if frequency == '1d': - return [env.previous_open_and_close(minute)[1], minute] + return [_cases_env.previous_open_and_close(minute)[1], minute] elif frequency == '1m': - return [env.previous_market_minute(minute), minute] + return [_cases_env.previous_market_minute(minute), minute] def mixed_frequency_expected_data(count, frequency): @@ -41,32 +42,36 @@ def mixed_frequency_expected_data(count, frequency): return [count - 1, count] -MIXED_FREQUENCY_MINUTES = TradingEnvironment.instance().market_minute_window( +MIXED_FREQUENCY_MINUTES = _cases_env.market_minute_window( to_utc('2013-07-03 9:31AM'), 600, ) ONE_MINUTE_PRICE_ONLY_SPECS = [ - HistorySpec(1, '1m', 'price', True, data_frequency='minute'), + HistorySpec(1, '1m', 'price', True, _cases_env, data_frequency='minute'), ] DAILY_OPEN_CLOSE_SPECS = [ - HistorySpec(3, '1d', 'open_price', False, data_frequency='minute'), - HistorySpec(3, '1d', 'close_price', False, data_frequency='minute'), + HistorySpec(3, '1d', 'open_price', False, _cases_env, + data_frequency='minute'), + HistorySpec(3, '1d', 'close_price', False, _cases_env, + data_frequency='minute'), ] ILLIQUID_PRICES_SPECS = [ - HistorySpec(3, '1m', 'price', False, data_frequency='minute'), - HistorySpec(5, '1m', 'price', True, data_frequency='minute'), + HistorySpec(3, '1m', 'price', False, _cases_env, data_frequency='minute'), + HistorySpec(5, '1m', 'price', True, _cases_env, data_frequency='minute'), ] MIXED_FREQUENCY_SPECS = [ - HistorySpec(1, '1m', 'price', False, data_frequency='minute'), - HistorySpec(2, '1m', 'price', False, data_frequency='minute'), - HistorySpec(2, '1d', 'price', False, data_frequency='minute'), + HistorySpec(1, '1m', 'price', False, _cases_env, data_frequency='minute'), + HistorySpec(2, '1m', 'price', False, _cases_env, data_frequency='minute'), + HistorySpec(2, '1d', 'price', False, _cases_env, data_frequency='minute'), ] MIXED_FIELDS_SPECS = [ - HistorySpec(3, '1m', 'price', True, data_frequency='minute'), - HistorySpec(3, '1m', 'open_price', True, data_frequency='minute'), - HistorySpec(3, '1m', 'close_price', True, data_frequency='minute'), - HistorySpec(3, '1m', 'high', True, data_frequency='minute'), - HistorySpec(3, '1m', 'low', True, data_frequency='minute'), - HistorySpec(3, '1m', 'volume', True, data_frequency='minute'), + HistorySpec(3, '1m', 'price', True, _cases_env, data_frequency='minute'), + HistorySpec(3, '1m', 'open_price', True, _cases_env, + data_frequency='minute'), + HistorySpec(3, '1m', 'close_price', True, _cases_env, + data_frequency='minute'), + HistorySpec(3, '1m', 'high', True, _cases_env, data_frequency='minute'), + HistorySpec(3, '1m', 'low', True, _cases_env, data_frequency='minute'), + HistorySpec(3, '1m', 'volume', True, _cases_env, data_frequency='minute'), ] diff --git a/tests/modelling/test_us_equity_pricing_loader.py b/tests/modelling/test_us_equity_pricing_loader.py index 0aa4f929..f9f0de00 100644 --- a/tests/modelling/test_us_equity_pricing_loader.py +++ b/tests/modelling/test_us_equity_pricing_loader.py @@ -96,13 +96,16 @@ TEST_QUERY_ASSETS = EQUITY_INFO.index class BcolzDailyBarTestCase(TestCase): - def setUp(self): - all_trading_days = TradingEnvironment.instance().trading_days - self.trading_days = all_trading_days[ + @classmethod + def setUpClass(cls): + all_trading_days = TradingEnvironment().trading_days + cls.trading_days = all_trading_days[ all_trading_days.get_loc(TEST_CALENDAR_START): all_trading_days.get_loc(TEST_CALENDAR_STOP) + 1 ] + def setUp(self): + self.asset_info = EQUITY_INFO self.writer = SyntheticDailyBarWriter( self.asset_info, @@ -401,7 +404,7 @@ class USEquityPricingLoaderTestCase(TestCase): writer.write(SPLITS, MERGERS, DIVIDENDS) cls.assets = TEST_QUERY_ASSETS - all_days = TradingEnvironment.instance().trading_days + all_days = TradingEnvironment().trading_days cls.calendar_days = all_days[ all_days.slice_indexer(TEST_CALENDAR_START, TEST_CALENDAR_STOP) ] diff --git a/tests/risk/test_risk_cumulative.py b/tests/risk/test_risk_cumulative.py index 856883d1..578c72c6 100644 --- a/tests/risk/test_risk_cumulative.py +++ b/tests/risk/test_risk_cumulative.py @@ -21,7 +21,7 @@ import pytz import zipline.finance.risk as risk from zipline.utils import factory -from zipline.finance.trading import SimulationParameters +from zipline.finance.trading import SimulationParameters, TradingEnvironment from . import answer_key ANSWER_KEY = answer_key.ANSWER_KEY @@ -29,6 +29,10 @@ ANSWER_KEY = answer_key.ANSWER_KEY class TestRisk(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + def setUp(self): start_date = datetime.datetime( year=2006, @@ -42,7 +46,8 @@ class TestRisk(unittest.TestCase): self.sim_params = SimulationParameters( period_start=start_date, - period_end=end_date + period_end=end_date, + env=self.env, ) self.algo_returns_06 = factory.create_returns_from_list( @@ -51,7 +56,8 @@ class TestRisk(unittest.TestCase): ) self.cumulative_metrics_06 = risk.RiskMetricsCumulative( - self.sim_params) + self.sim_params, env=self.env + ) for dt, returns in answer_key.RETURNS_DATA.iterrows(): self.cumulative_metrics_06.update(dt, diff --git a/tests/risk/test_risk_period.py b/tests/risk/test_risk_period.py index f3f18735..83c73430 100644 --- a/tests/risk/test_risk_period.py +++ b/tests/risk/test_risk_period.py @@ -21,7 +21,7 @@ import pytz import zipline.finance.risk as risk from zipline.utils import factory -from zipline.finance.trading import SimulationParameters +from zipline.finance.trading import SimulationParameters, TradingEnvironment from . import answer_key from . answer_key import AnswerKey @@ -33,6 +33,10 @@ RETURNS = ANSWER_KEY.RETURNS class TestRisk(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + def setUp(self): start_date = datetime.datetime( @@ -47,7 +51,8 @@ class TestRisk(unittest.TestCase): self.sim_params = SimulationParameters( period_start=start_date, - period_end=end_date + period_end=end_date, + env=self.env, ) self.algo_returns_06 = factory.create_returns_from_list( @@ -61,7 +66,8 @@ class TestRisk(unittest.TestCase): self.metrics_06 = risk.RiskReport( self.algo_returns_06, self.sim_params, - benchmark_returns=self.benchmark_returns_06 + benchmark_returns=self.benchmark_returns_06, + env=self.env, ) start_08 = datetime.datetime( @@ -80,7 +86,8 @@ class TestRisk(unittest.TestCase): ) self.sim_params08 = SimulationParameters( period_start=start_08, - period_end=end_08 + period_end=end_08, + env=self.env, ) def tearDown(self): @@ -97,9 +104,13 @@ class TestRisk(unittest.TestCase): returns = factory.create_returns_from_list( [1.0, -0.5, 0.8, .17, 1.0, -0.1, -0.45], self.sim_params) # 200, 100, 180, 210.6, 421.2, 379.8, 208.494 - metrics = risk.RiskMetricsPeriod(returns.index[0], - returns.index[-1], - returns) + metrics = risk.RiskMetricsPeriod( + returns.index[0], + returns.index[-1], + returns, + env=self.env, + benchmark_returns=self.env.benchmark_returns, + ) self.assertEqual(metrics.max_drawdown, 0.505) def test_benchmark_returns_06(self): @@ -123,7 +134,7 @@ class TestRisk(unittest.TestCase): def test_trading_days_06(self): returns = factory.create_returns_from_range(self.sim_params) - metrics = risk.RiskReport(returns, self.sim_params) + metrics = risk.RiskReport(returns, self.sim_params, env=self.env) self.assertEqual([x.num_trading_days for x in metrics.year_periods], [251]) self.assertEqual([x.num_trading_days for x in metrics.month_periods], @@ -347,7 +358,7 @@ class TestRisk(unittest.TestCase): def test_benchmark_returns_08(self): returns = factory.create_returns_from_range(self.sim_params08) - metrics = risk.RiskReport(returns, self.sim_params08) + metrics = risk.RiskReport(returns, self.sim_params08, env=self.env) self.assertEqual([round(x.benchmark_period_returns, 3) for x in metrics.month_periods], @@ -393,7 +404,7 @@ class TestRisk(unittest.TestCase): def test_trading_days_08(self): returns = factory.create_returns_from_range(self.sim_params08) - metrics = risk.RiskReport(returns, self.sim_params08) + metrics = risk.RiskReport(returns, self.sim_params08, env=self.env) self.assertEqual([x.num_trading_days for x in metrics.year_periods], [253]) @@ -402,7 +413,7 @@ class TestRisk(unittest.TestCase): def test_benchmark_volatility_08(self): returns = factory.create_returns_from_range(self.sim_params08) - metrics = risk.RiskReport(returns, self.sim_params08) + metrics = risk.RiskReport(returns, self.sim_params08, env=self.env) self.assertEqual([round(x.benchmark_volatility, 3) for x in metrics.month_periods], @@ -450,7 +461,7 @@ class TestRisk(unittest.TestCase): def test_treasury_returns_06(self): returns = factory.create_returns_from_range(self.sim_params) - metrics = risk.RiskReport(returns, self.sim_params) + metrics = risk.RiskReport(returns, self.sim_params, env=self.env) self.assertEqual([round(x.treasury_period_return, 4) for x in metrics.month_periods], [0.0037, @@ -513,22 +524,24 @@ class TestRisk(unittest.TestCase): end = start + datetime.timedelta(days=total_days) sim_params90s = SimulationParameters( period_start=start, - period_end=end + period_end=end, + env=self.env, ) returns = factory.create_returns_from_range(sim_params90s) returns = returns[:-10] # truncate the returns series to end mid-month - metrics = risk.RiskReport(returns, sim_params90s) + metrics = risk.RiskReport(returns, sim_params90s, env=self.env) total_months = 60 self.check_metrics(metrics, total_months, start) def check_year_range(self, start_date, years): sim_params = SimulationParameters( period_start=start_date, - period_end=start_date.replace(year=(start_date.year + years)) + period_end=start_date.replace(year=(start_date.year + years)), + env=self.env, ) returns = factory.create_returns_from_range(sim_params) - metrics = risk.RiskReport(returns, self.sim_params) + metrics = risk.RiskReport(returns, self.sim_params, env=self.env) total_months = years * 12 self.check_metrics(metrics, total_months, start_date) diff --git a/tests/serialization_cases.py b/tests/serialization_cases.py index af71b065..2a6bd094 100644 --- a/tests/serialization_cases.py +++ b/tests/serialization_cases.py @@ -23,7 +23,7 @@ from zipline.protocol import Account from zipline.protocol import Portfolio from zipline.protocol import Position as ProtocolPosition -from zipline.finance.trading import SimulationParameters +from zipline.finance.trading import SimulationParameters, TradingEnvironment from zipline.utils import factory @@ -41,17 +41,19 @@ def stringify_cases(cases, func=None): results.append(new_case) return results - +cases_env = TradingEnvironment() sim_params_daily = SimulationParameters( datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC), datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC), 10000, - emission_rate='daily') + emission_rate='daily', + env=cases_env) sim_params_minute = SimulationParameters( datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC), datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC), 10000, - emission_rate='minute') + emission_rate='minute', + env=cases_env) returns = factory.create_returns_from_list( [1.0], sim_params_daily) @@ -65,14 +67,17 @@ def object_serialization_cases(skip_daily=False): (PerTrade, (), {}, 'dict'), (PerDollar, (), {}, 'dict'), (PerformancePeriod, - (10000,), {'position_tracker': PositionTracker()}, 'to_dict'), + (10000, cases_env.asset_finder), + {'position_tracker': PositionTracker(cases_env.asset_finder)}, + 'to_dict'), (Position, (8554,), {}, 'dict'), - (PositionTracker, (), {}, 'dict'), - (PerformanceTracker, (sim_params_minute,), {}, 'to_dict'), - (RiskMetricsCumulative, (sim_params_minute,), {}, 'to_dict'), + (PositionTracker, (cases_env.asset_finder,), {}, 'dict'), + (PerformanceTracker, (sim_params_minute, cases_env), {}, 'to_dict'), + (RiskMetricsCumulative, (sim_params_minute, cases_env), {}, 'to_dict'), (RiskMetricsPeriod, - (returns.index[0], returns.index[0], returns), {}, 'to_dict'), - (RiskReport, (returns, sim_params_minute), {}, 'to_dict'), + (returns.index[0], returns.index[0], returns, cases_env), + {}, 'to_dict'), + (RiskReport, (returns, sim_params_minute, cases_env), {}, 'to_dict'), (FixedSlippage, (), {}, 'dict'), (Transaction, (8554, 10, datetime.datetime(2013, 6, 19), 100, "0000"), {}, @@ -85,9 +90,12 @@ def object_serialization_cases(skip_daily=False): if not skip_daily: cases.extend([ - (PerformanceTracker, (sim_params_daily,), {}, 'to_dict'), - (RiskMetricsCumulative, (sim_params_daily,), {}, 'to_dict'), - (RiskReport, (returns, sim_params_daily), {}, 'to_dict'), + (PerformanceTracker, + (sim_params_daily, cases_env), {}, 'to_dict'), + (RiskMetricsCumulative, + (sim_params_daily, cases_env), {}, 'to_dict'), + (RiskReport, + (returns, sim_params_daily, cases_env), {}, 'to_dict'), ]) return stringify_cases(cases) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index e27b1a6c..505d7cb7 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -93,7 +93,6 @@ from zipline.sources import (SpecificEquityTrades, from zipline.assets import Equity from zipline.finance.execution import LimitOrder -from zipline.finance import trading from zipline.finance.trading import SimulationParameters from zipline.utils.api_support import set_algo_instance from zipline.utils.events import DateRuleFactory, TimeRuleFactory @@ -107,24 +106,31 @@ _multiprocess_can_split_ = False class TestRecordAlgorithm(TestCase): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[133]) + def setUp(self): - self.sim_params = factory.create_simulation_parameters(num_days=4) - trading.environment.write_data(equities_identifiers=[133]) + self.sim_params = factory.create_simulation_parameters(num_days=4, + env=self.env) trade_history = factory.create_trade_history( 133, [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - self.sim_params + self.sim_params, + self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) self.df_source, self.df = \ - factory.create_test_df_source(self.sim_params) + factory.create_test_df_source(self.sim_params, self.env) def test_record_incr(self): - algo = RecordAlgorithm( - sim_params=self.sim_params) + algo = RecordAlgorithm(sim_params=self.sim_params, env=self.env) output = algo.run(self.source) np.testing.assert_array_equal(output['incr'].values, @@ -138,19 +144,68 @@ class TestRecordAlgorithm(TestCase): class TestMiscellaneousAPI(TestCase): + + @classmethod + def setUpClass(cls): + cls.sids = [1, 2] + cls.env = TradingEnvironment() + + metadata = {3: {'symbol': 'PLAY', + 'asset_type': 'equity', + 'start_date': '2002-01-01', + 'end_date': '2004-01-01'}, + 4: {'symbol': 'PLAY', + 'asset_type': 'equity', + 'start_date': '2005-01-01', + 'end_date': '2006-01-01'}} + + futures_metadata = { + 5: { + 'symbol': 'CLG06', + 'root_symbol': 'CL', + 'asset_type': 'future', + 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), + 'notice_date': pd.Timestamp('2005-12-20', tz='UTC'), + 'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')}, + 6: { + 'root_symbol': 'CL', + 'symbol': 'CLK06', + 'asset_type': 'future', + 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), + 'notice_date': pd.Timestamp('2006-03-20', tz='UTC'), + 'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')}, + 7: { + 'symbol': 'CLQ06', + 'root_symbol': 'CL', + 'asset_type': 'future', + 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), + 'notice_date': pd.Timestamp('2006-06-20', tz='UTC'), + 'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')}, + 8: { + 'symbol': 'CLX06', + 'root_symbol': 'CL', + 'asset_type': 'future', + 'start_date': pd.Timestamp('2006-02-01', tz='UTC'), + 'notice_date': pd.Timestamp('2006-09-20', tz='UTC'), + 'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')} + } + cls.env.write_data(equities_identifiers=cls.sids, + equities_data=metadata, + futures_data=futures_metadata) + def setUp(self): setup_logger(self) - sids = [1, 2] self.sim_params = factory.create_simulation_parameters( num_days=2, data_frequency='minute', emission_rate='minute', + env=self.env, ) - trading.environment.write_data(equities_identifiers=sids) self.source = factory.create_minutely_trade_source( - sids, + self.sids, sim_params=self.sim_params, concurrent=True, + env=self.env, ) def tearDown(self): @@ -195,7 +250,8 @@ class TestMiscellaneousAPI(TestCase): algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data, - sim_params=self.sim_params) + sim_params=self.sim_params, + env=self.env) algo.run(self.source) def test_get_open_orders(self): @@ -245,7 +301,8 @@ class TestMiscellaneousAPI(TestCase): algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data, - sim_params=self.sim_params) + sim_params=self.sim_params, + env=self.env) algo.run(self.source) def test_schedule_function(self): @@ -281,6 +338,7 @@ class TestMiscellaneousAPI(TestCase): initialize=initialize, handle_data=handle_data, sim_params=self.sim_params, + env=self.env, ) algo.run(self.source) @@ -296,7 +354,10 @@ class TestMiscellaneousAPI(TestCase): self.sim_params.data_frequency = mode algo = TradingAlgorithm( - initialize=nop, handle_data=nop, sim_params=self.sim_params, + initialize=nop, + handle_data=nop, + sim_params=self.sim_params, + env=self.env, ) # Schedule something for NOT Always. @@ -324,16 +385,7 @@ class TestMiscellaneousAPI(TestCase): def test_asset_lookup(self): - trading.environment = trading.TradingEnvironment() - metadata = {0: {'symbol': 'PLAY', - 'asset_type': 'equity', - 'start_date': '2002-01-01', - 'end_date': '2004-01-01'}, - 1: {'symbol': 'PLAY', - 'asset_type': 'equity', - 'start_date': '2005-01-01', - 'end_date': '2006-01-01'}} - algo = TradingAlgorithm(asset_metadata=metadata) + algo = TradingAlgorithm(env=self.env) # Test before either PLAY existed algo.sim_params.period_end = pd.Timestamp('2001-12-01', tz='UTC') @@ -345,63 +397,30 @@ class TestMiscellaneousAPI(TestCase): # Test when first PLAY exists algo.sim_params.period_end = pd.Timestamp('2002-12-01', tz='UTC') list_result = algo.symbols('PLAY') - self.assertEqual(0, list_result[0]) + self.assertEqual(3, list_result[0]) # Test after first PLAY ends algo.sim_params.period_end = pd.Timestamp('2004-12-01', tz='UTC') - self.assertEqual(0, algo.symbol('PLAY')) + self.assertEqual(3, algo.symbol('PLAY')) # Test after second PLAY begins algo.sim_params.period_end = pd.Timestamp('2005-12-01', tz='UTC') - self.assertEqual(1, algo.symbol('PLAY')) + self.assertEqual(4, algo.symbol('PLAY')) # Test after second PLAY ends algo.sim_params.period_end = pd.Timestamp('2006-12-01', tz='UTC') - self.assertEqual(1, algo.symbol('PLAY')) + self.assertEqual(4, algo.symbol('PLAY')) list_result = algo.symbols('PLAY') - self.assertEqual(1, list_result[0]) + self.assertEqual(4, list_result[0]) # Test lookup SID - self.assertIsInstance(algo.sid(0), Equity) - self.assertIsInstance(algo.sid(1), Equity) + self.assertIsInstance(algo.sid(3), Equity) + self.assertIsInstance(algo.sid(4), Equity) def test_future_chain(self): """ Tests the future_chain API function. """ - trading.environment = trading.TradingEnvironment() - metadata = { - 0: { - 'symbol': 'CLG06', - 'root_symbol': 'CL', - 'asset_type': 'future', - 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), - 'notice_date': pd.Timestamp('2005-12-20', tz='UTC'), - 'expiration_date': pd.Timestamp('2006-01-20', tz='UTC')}, - 1: { - 'root_symbol': 'CL', - 'symbol': 'CLK06', - 'asset_type': 'future', - 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), - 'notice_date': pd.Timestamp('2006-03-20', tz='UTC'), - 'expiration_date': pd.Timestamp('2006-04-20', tz='UTC')}, - 2: { - 'symbol': 'CLQ06', - 'root_symbol': 'CL', - 'asset_type': 'future', - 'start_date': pd.Timestamp('2005-12-01', tz='UTC'), - 'notice_date': pd.Timestamp('2006-06-20', tz='UTC'), - 'expiration_date': pd.Timestamp('2006-07-20', tz='UTC')}, - 3: { - 'symbol': 'CLX06', - 'root_symbol': 'CL', - 'asset_type': 'future', - 'start_date': pd.Timestamp('2006-02-01', tz='UTC'), - 'notice_date': pd.Timestamp('2006-09-20', tz='UTC'), - 'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')} - } - - trading.environment.write_data(futures_data=metadata) - algo = TradingAlgorithm() + algo = TradingAlgorithm(env=self.env) algo.datetime = pd.Timestamp('2006-12-01', tz='UTC') # Check that the fields of the FutureChain object are set correctly @@ -485,25 +504,37 @@ class TestMiscellaneousAPI(TestCase): class TestTransformAlgorithm(TestCase): + + @classmethod + def setUpClass(cls): + futures_metadata = {3: {'asset_type': 'future', + 'contract_multiplier': 10}} + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[0, 1, 133], + futures_data=futures_metadata) + def setUp(self): setup_logger(self) - self.sim_params = factory.create_simulation_parameters(num_days=4) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[133]) + self.sim_params = factory.create_simulation_parameters(num_days=4, + env=self.env) trade_history = factory.create_trade_history( 133, [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - self.sim_params + self.sim_params, + self.env + ) + self.source = SpecificEquityTrades( + event_list=trade_history, + env=self.env, ) - self.source = SpecificEquityTrades(event_list=trade_history) self.df_source, self.df = \ - factory.create_test_df_source(self.sim_params) + factory.create_test_df_source(self.sim_params, self.env) self.panel_source, self.panel = \ - factory.create_test_panel_source(self.sim_params) + factory.create_test_panel_source(self.sim_params, self.env) def tearDown(self): teardown_logger(self) @@ -511,6 +542,7 @@ class TestTransformAlgorithm(TestCase): def test_source_as_input(self): algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, + env=self.env, sids=[133] ) algo.run(self.source) @@ -520,19 +552,21 @@ class TestTransformAlgorithm(TestCase): def test_invalid_order_parameters(self): algo = InvalidOrderAlgorithm( sids=[133], - sim_params=self.sim_params + sim_params=self.sim_params, + env=self.env, ) algo.run(self.source) def test_multi_source_as_input(self): - trading.environment.write_data(equities_identifiers=[0, 1]) sim_params = SimulationParameters( self.df.index[0], - self.df.index[-1] + self.df.index[-1], + env=self.env, ) algo = TestRegisterTransformAlgorithm( sim_params=sim_params, - sids=[0, 1] + sids=[0, 1], + env=self.env, ) algo.run([self.source, self.df_source], overwrite_sim_params=False) self.assertEqual(len(algo.sources), 2) @@ -540,6 +574,7 @@ class TestTransformAlgorithm(TestCase): def test_df_as_input(self): algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, + env=self.env, ) algo.run(self.df) assert isinstance(algo.sources[0], DataFrameSource) @@ -547,6 +582,7 @@ class TestTransformAlgorithm(TestCase): def test_panel_as_input(self): algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, + env=self.env, sids=[0, 1]) algo.run(self.panel) assert isinstance(algo.sources[0], DataPanelSource) @@ -559,8 +595,6 @@ class TestTransformAlgorithm(TestCase): res1 = algo1.run(self.df) - # Create a new trading environment - trading.environment = trading.TradingEnvironment() # Create a new trading algorithm, which will # use the newly instantiated environment. algo2 = TestRegisterTransformAlgorithm( @@ -576,12 +610,14 @@ class TestTransformAlgorithm(TestCase): self.sim_params.data_frequency = 'daily' algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, + env=self.env, ) self.assertEqual(algo.sim_params.data_frequency, 'daily') self.sim_params.data_frequency = 'minute' algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, + env=self.env, ) self.assertEqual(algo.sim_params.data_frequency, 'minute') @@ -596,6 +632,7 @@ class TestTransformAlgorithm(TestCase): def test_order_methods(self, algo_class): algo = algo_class( sim_params=self.sim_params, + env=self.env, ) algo.run(self.df) @@ -607,12 +644,9 @@ class TestTransformAlgorithm(TestCase): (TestTargetValueAlgorithm,), ]) def test_order_methods_for_future(self, algo_class): - # Use sid not already in test database. - metadata = {3: {'asset_type': 'future', - 'contract_multiplier': 10}} algo = algo_class( sim_params=self.sim_params, - asset_metadata=metadata + env=self.env, ) algo.run(self.df) @@ -626,7 +660,8 @@ class TestTransformAlgorithm(TestCase): 'order_target_value'] for name in method_names_to_test: - trading.environment = trading.TradingEnvironment() + # Don't supply an env so the TradingAlgorithm builds a new one for + # each method algo = TestOrderStyleForwardingAlgorithm( sim_params=self.sim_params, instant_fill=False, @@ -636,11 +671,11 @@ class TestTransformAlgorithm(TestCase): def test_order_instant(self): algo = TestOrderInstantAlgorithm(sim_params=self.sim_params, + env=self.env, instant_fill=True) algo.run(self.df) def test_minute_data(self): - trading.environment.write_data(equities_identifiers=[0, 1]) source = RandomWalkSource(freq='minute', start=pd.Timestamp('2000-1-3', tz='UTC'), @@ -648,6 +683,7 @@ class TestTransformAlgorithm(TestCase): tz='UTC')) self.sim_params.data_frequency = 'minute' algo = TestOrderInstantAlgorithm(sim_params=self.sim_params, + env=self.env, instant_fill=True) algo.run(source) @@ -656,26 +692,33 @@ class TestPositions(TestCase): def setUp(self): setup_logger(self) - self.sim_params = factory.create_simulation_parameters(num_days=4) - trading.environment.write_data(equities_identifiers=[1, 133]) + self.env = TradingEnvironment() + self.sim_params = factory.create_simulation_parameters(num_days=4, + env=self.env) + self.env.write_data(equities_identifiers=[1, 133]) trade_history = factory.create_trade_history( 1, [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - self.sim_params + self.sim_params, + self.env + ) + self.source = SpecificEquityTrades( + event_list=trade_history, + env=self.env, ) - self.source = SpecificEquityTrades(event_list=trade_history) self.df_source, self.df = \ - factory.create_test_df_source(self.sim_params) + factory.create_test_df_source(self.sim_params, self.env) def tearDown(self): teardown_logger(self) def test_empty_portfolio(self): - algo = EmptyPositionsAlgorithm(sim_params=self.sim_params) + algo = EmptyPositionsAlgorithm(sim_params=self.sim_params, + env=self.env) daily_stats = algo.run(self.df) expected_position_count = [ @@ -691,7 +734,9 @@ class TestPositions(TestCase): def test_noop_orders(self): - algo = AmbitiousStopLimitAlgorithm(sid=1) + algo = AmbitiousStopLimitAlgorithm(sid=1, + sim_params=self.sim_params, + env=self.env) daily_stats = algo.run(self.source) # Verify that possitions are empty for all dates. @@ -700,25 +745,39 @@ class TestPositions(TestCase): class TestAlgoScript(TestCase): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data( + equities_identifiers=[0, 1, 133] + ) + def setUp(self): days = 251 # Note that create_simulation_parameters creates # a new TradingEnvironment - self.sim_params = factory.create_simulation_parameters(num_days=days) + self.sim_params = factory.create_simulation_parameters(num_days=days, + env=self.env) + setup_logger(self) - trading.environment.write_data(equities_identifiers=[1, 133]) trade_history = factory.create_trade_history( 133, [10.0] * days, [100] * days, timedelta(days=1), - self.sim_params + self.sim_params, + self.env + ) + + self.source = SpecificEquityTrades( + sids=[133], + event_list=trade_history, + env=self.env, ) - self.source = SpecificEquityTrades(sids=[133], - event_list=trade_history) self.df_source, self.df = \ - factory.create_test_df_source(self.sim_params) + factory.create_test_df_source(self.sim_params, self.env) self.zipline_test_config = { 'sid': 0, @@ -767,7 +826,6 @@ class TestAlgoScript(TestCase): def test_fixed_slippage(self): # verify order -> transaction -> portfolio position. # -------------- - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=""" from zipline.api import (slippage, @@ -792,6 +850,7 @@ def handle_data(context, data): context.incr += 1""", sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -822,7 +881,6 @@ def handle_data(context, data): def test_volshare_slippage(self): # verify order -> transaction -> portfolio position. # -------------- - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=""" from zipline.api import * @@ -848,6 +906,7 @@ def handle_data(context, data): context.incr += 1 """, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -884,13 +943,13 @@ def handle_data(context, data): test_algo = TradingAlgorithm( script=record_variables, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) self.zipline_test_config['algorithm'] = test_algo self.zipline_test_config['trade_count'] = 200 - trading.environment.write_data(equities_identifiers=[0]) zipline = simfactory.create_test_zipline( **self.zipline_test_config) output, _ = drain_zipline(self, zipline) @@ -917,10 +976,10 @@ def handle_data(context, data): test_algo.record(foo=MagicMock()) def _algo_record_float_magic_should_pass(self, var_type): - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=record_float_magic % var_type, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -944,10 +1003,10 @@ def handle_data(context, data): Only test that order methods can be called without error. Correct filling of orders is tested in zipline. """ - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=call_all_order_methods, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -968,6 +1027,7 @@ def handle_data(context, data): test_algo = TradingAlgorithm( script=call_order_in_init, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) test_algo.run(self.source) @@ -976,10 +1036,10 @@ def handle_data(context, data): """ Test that accessing portfolio in init doesn't break. """ - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=access_portfolio_in_init, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -995,10 +1055,10 @@ def handle_data(context, data): """ Test that accessing account in init doesn't break. """ - trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=access_account_in_init, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -1015,8 +1075,6 @@ class TestHistory(TestCase): def setUp(self): setup_logger(self) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[0, 1]) def tearDown(self): teardown_logger(self) @@ -1025,9 +1083,12 @@ class TestHistory(TestCase): def setUpClass(cls): cls._start = pd.Timestamp('1991-01-01', tz='UTC') cls._end = pd.Timestamp('1991-01-15', tz='UTC') + cls.env = TradingEnvironment() cls.sim_params = factory.create_simulation_parameters( data_frequency='minute', + env=cls.env ) + cls.env.write_data(equities_identifiers=[0, 1]) @property def source(self): @@ -1047,6 +1108,7 @@ def handle_data(context, data): algo = TradingAlgorithm( script=history_algo, sim_params=self.sim_params, + env=self.env, ) output = algo.run(self.source) self.assertIsNot(output, None) @@ -1059,6 +1121,7 @@ def handle_data(context, data): initialize=lambda _: None, handle_data=handle_data, sim_params=self.sim_params, + env=self.env, ) algo.run(self.source) @@ -1073,6 +1136,7 @@ def handle_data(context, data): initialize=lambda _: None, handle_data=handle_data, sim_params=self.sim_params, + env=self.env, ) algo.run(self.source) @@ -1082,10 +1146,13 @@ def handle_data(context, data): class TestGetDatetime(TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[0, 1]) + def setUp(self): setup_logger(self) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[0, 1]) def tearDown(self): teardown_logger(self) @@ -1128,11 +1195,13 @@ class TestGetDatetime(TestCase): end=end, ) sim_params = factory.create_simulation_parameters( - data_frequency='minute' + data_frequency='minute', + env=self.env, ) algo = TradingAlgorithm( script=algo, sim_params=sim_params, + env=self.env, ) algo.run(source) self.assertFalse(algo.first_bar) @@ -1140,19 +1209,28 @@ class TestGetDatetime(TestCase): class TestTradingControls(TestCase): + @classmethod + def setUpClass(cls): + cls.sid = 133 + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[cls.sid]) + def setUp(self): - self.sim_params = factory.create_simulation_parameters(num_days=4) - self.sid = 133 - trading.environment.write_data(equities_identifiers=[self.sid]) + self.sim_params = factory.create_simulation_parameters(num_days=4, + env=self.env) self.trade_history = factory.create_trade_history( self.sid, [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - self.sim_params + self.sim_params, + self.env ) - self.source = SpecificEquityTrades(event_list=self.trade_history) + self.source = SpecificEquityTrades( + event_list=self.trade_history, + env=self.env, + ) def _check_algo(self, algo, @@ -1184,7 +1262,9 @@ class TestTradingControls(TestCase): algo.order_count += 1 algo = SetMaxPositionSizeAlgorithm(sid=self.sid, max_shares=10, - max_notional=500.0) + max_notional=500.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data) # Buy three shares four times. Should bail on the fourth before it's @@ -1195,7 +1275,9 @@ class TestTradingControls(TestCase): algo = SetMaxPositionSizeAlgorithm(sid=self.sid, max_shares=10, - max_notional=500.0) + max_notional=500.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 3) # Buy two shares four times. Should bail due to max_notional on the @@ -1206,7 +1288,9 @@ class TestTradingControls(TestCase): algo = SetMaxPositionSizeAlgorithm(sid=self.sid, max_shares=10, - max_notional=61.0) + max_notional=61.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 2) # Set the trading control to a different sid, then BUY ALL THE THINGS!. @@ -1216,7 +1300,9 @@ class TestTradingControls(TestCase): algo.order_count += 1 algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1, max_shares=10, - max_notional=61.0) + max_notional=61.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data) # Set the trading control sid to None, then BUY ALL THE THINGS!. Should @@ -1224,14 +1310,19 @@ class TestTradingControls(TestCase): def handle_data(algo, data): algo.order(algo.sid(self.sid), 10000) algo.order_count += 1 - algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0) + algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 0) def test_set_do_not_order_list(self): # set the restricted list to be the sid, and fail. algo = SetDoNotOrderListAlgorithm( sid=self.sid, - restricted_list=[self.sid]) + restricted_list=[self.sid], + sim_params=self.sim_params, + env=self.env, + ) def handle_data(algo, data): algo.order(algo.sid(self.sid), 100) @@ -1242,7 +1333,10 @@ class TestTradingControls(TestCase): # set the restricted list to exclude the sid, and succeed algo = SetDoNotOrderListAlgorithm( sid=self.sid, - restricted_list=[134, 135, 136]) + restricted_list=[134, 135, 136], + sim_params=self.sim_params, + env=self.env, + ) def handle_data(algo, data): algo.order(algo.sid(self.sid), 100) @@ -1258,7 +1352,9 @@ class TestTradingControls(TestCase): algo.order_count += 1 algo = SetMaxOrderSizeAlgorithm(sid=self.sid, max_shares=10, - max_notional=500.0) + max_notional=500.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data) # Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt @@ -1269,7 +1365,9 @@ class TestTradingControls(TestCase): algo = SetMaxOrderSizeAlgorithm(sid=self.sid, max_shares=3, - max_notional=500.0) + max_notional=500.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 3) # Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt @@ -1280,7 +1378,9 @@ class TestTradingControls(TestCase): algo = SetMaxOrderSizeAlgorithm(sid=self.sid, max_shares=10, - max_notional=40.0) + max_notional=40.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 3) # Set the trading control to a different sid, then BUY ALL THE THINGS!. @@ -1290,7 +1390,9 @@ class TestTradingControls(TestCase): algo.order_count += 1 algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1, max_shares=1, - max_notional=1.0) + max_notional=1.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data) # Set the trading control sid to None, then BUY ALL THE THINGS!. @@ -1300,7 +1402,9 @@ class TestTradingControls(TestCase): algo.order(algo.sid(self.sid), 10000) algo.order_count += 1 algo = SetMaxOrderSizeAlgorithm(max_shares=1, - max_notional=1.0) + max_notional=1.0, + sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 0) def test_set_max_order_count(self): @@ -1312,26 +1416,31 @@ class TestTradingControls(TestCase): [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(hours=6), - self.sim_params + self.sim_params, + self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) def handle_data(algo, data): for i in range(5): algo.order(algo.sid(self.sid), 1) algo.order_count += 1 - algo = SetMaxOrderCountAlgorithm(3) + algo = SetMaxOrderCountAlgorithm(3, sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 3) # Second call to handle_data is the same day as the first, so the last # order of the second call should fail. - algo = SetMaxOrderCountAlgorithm(9) + algo = SetMaxOrderCountAlgorithm(9, sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data, 9) # Only ten orders are placed per day, so this should pass even though # in total more than 20 orders are placed. - algo = SetMaxOrderCountAlgorithm(10) + algo = SetMaxOrderCountAlgorithm(10, sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data, order_count=20) def test_long_only(self): @@ -1339,7 +1448,7 @@ class TestTradingControls(TestCase): def handle_data(algo, data): algo.order(algo.sid(self.sid), -1) algo.order_count += 1 - algo = SetLongOnlyAlgorithm() + algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env) self.check_algo_fails(algo, handle_data, 0) # Buy on even days, sell on odd days. Never takes a short position, so @@ -1350,7 +1459,7 @@ class TestTradingControls(TestCase): else: algo.order(algo.sid(self.sid), -1) algo.order_count += 1 - algo = SetLongOnlyAlgorithm() + algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env) self.check_algo_succeeds(algo, handle_data) # Buy on first three days, then sell off holdings. Should succeed. @@ -1358,7 +1467,7 @@ class TestTradingControls(TestCase): amounts = [1, 1, 1, -3] algo.order(algo.sid(self.sid), amounts[algo.order_count]) algo.order_count += 1 - algo = SetLongOnlyAlgorithm() + algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env) self.check_algo_succeeds(algo, handle_data) # Buy on first three days, then sell off holdings plus an extra share. @@ -1367,7 +1476,7 @@ class TestTradingControls(TestCase): amounts = [1, 1, 1, -4] algo.order(algo.sid(self.sid), amounts[algo.order_count]) algo.order_count += 1 - algo = SetLongOnlyAlgorithm() + algo = SetLongOnlyAlgorithm(sim_params=self.sim_params, env=self.env) self.check_algo_fails(algo, handle_data, 3) def test_register_post_init(self): @@ -1387,59 +1496,80 @@ class TestTradingControls(TestCase): algo.set_long_only() algo = TradingAlgorithm(initialize=initialize, - handle_data=handle_data) + handle_data=handle_data, + sim_params=self.sim_params, + env=self.env) algo.run(self.source) self.source.rewind() def test_asset_date_bounds(self): # Run the algorithm with a sid that ends far in the future - df_source, _ = factory.create_test_df_source(self.sim_params) + temp_env = TradingEnvironment() + df_source, _ = factory.create_test_df_source(self.sim_params, temp_env) metadata = {0: {'start_date': '1990-01-01', 'end_date': '2020-01-01'}} algo = SetAssetDateBoundsAlgorithm( asset_metadata=metadata, - sim_params=self.sim_params,) + sim_params=self.sim_params, + env=temp_env, + ) algo.run(df_source) # Run the algorithm with a sid that has already ended - trading.environment = trading.TradingEnvironment() - df_source, _ = factory.create_test_df_source(self.sim_params) + temp_env = TradingEnvironment() + df_source, _ = factory.create_test_df_source(self.sim_params, temp_env) metadata = {0: {'start_date': '1989-01-01', 'end_date': '1990-01-01'}} algo = SetAssetDateBoundsAlgorithm( asset_metadata=metadata, - sim_params=self.sim_params,) + sim_params=self.sim_params, + env=temp_env, + ) with self.assertRaises(TradingControlViolation): algo.run(df_source) # Run the algorithm with a sid that has not started - trading.environment = trading.TradingEnvironment() - df_source, _ = factory.create_test_df_source(self.sim_params) + temp_env = TradingEnvironment() + df_source, _ = factory.create_test_df_source(self.sim_params, temp_env) metadata = {0: {'start_date': '2020-01-01', 'end_date': '2021-01-01'}} algo = SetAssetDateBoundsAlgorithm( asset_metadata=metadata, - sim_params=self.sim_params,) + sim_params=self.sim_params, + env=temp_env, + ) with self.assertRaises(TradingControlViolation): algo.run(df_source) class TestAccountControls(TestCase): + @classmethod + def setUpClass(cls): + cls.sidint = 133 + cls.env = TradingEnvironment() + cls.env.write_data( + equities_identifiers=[cls.sidint] + ) + def setUp(self): - self.sim_params = factory.create_simulation_parameters(num_days=4) - self.sidint = 133 - trading.environment.write_data(equities_identifiers=[self.sidint]) + self.sim_params = factory.create_simulation_parameters( + num_days=4, env=self.env + ) self.trade_history = factory.create_trade_history( self.sidint, [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - self.sim_params + self.sim_params, + self.env, ) - self.source = SpecificEquityTrades(event_list=self.trade_history) + self.source = SpecificEquityTrades( + event_list=self.trade_history, + env=self.env, + ) def _check_algo(self, algo, @@ -1466,22 +1596,24 @@ class TestAccountControls(TestCase): def handle_data(algo, data): algo.order(algo.sid(self.sidint), 1) - algo = SetMaxLeverageAlgorithm(0) + algo = SetMaxLeverageAlgorithm(0, sim_params=self.sim_params, + env=self.env) self.check_algo_fails(algo, handle_data) # Set max leverage to 1 so buying one share passes def handle_data(algo, data): algo.order(algo.sid(self.sidint), 1) - algo = SetMaxLeverageAlgorithm(1) + algo = SetMaxLeverageAlgorithm(1, sim_params=self.sim_params, + env=self.env) self.check_algo_succeeds(algo, handle_data) class TestClosePosAlgo(TestCase): def setUp(self): - trading.environment = trading.TradingEnvironment() - self.days = TradingEnvironment().trading_days + self.env = TradingEnvironment() + self.days = self.env.trading_days self.index = [self.days[0], self.days[1], self.days[2]] self.panel = pd.Panel({1: pd.DataFrame({ 'price': [1, 2, 4], 'volume': [1e9, 0, 0], @@ -1502,9 +1634,10 @@ class TestClosePosAlgo(TestCase): metadata = {1: {'symbol': 'TEST', 'asset_type': 'equity', 'end_date': self.days[3]}} + self.env.write_data(equities_data=metadata) self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, instant_fill=True, commission=PerShare(0), - asset_metadata=metadata) + env=self.env) self.data = DataPanelSource(self.panel) # Check results @@ -1518,9 +1651,10 @@ class TestClosePosAlgo(TestCase): metadata = {1: {'symbol': 'TEST', 'asset_type': 'future', }} - trading.environment.write_data(futures_data=metadata) + self.env.write_data(futures_data=metadata) self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, - instant_fill=True, commission=PerShare(0),) + instant_fill=True, commission=PerShare(0), + env=self.env) self.data = DataPanelSource(self.panel) # Check results @@ -1535,9 +1669,10 @@ class TestClosePosAlgo(TestCase): 'asset_type': 'future', 'notice_date': self.days[3], 'expiration_date': self.days[4]}} - trading.environment.write_data(futures_data=metadata) + self.env.write_data(futures_data=metadata) self.algo = TestAlgorithm(sid=1, amount=1, order_count=1, - instant_fill=True, commission=PerShare(0),) + instant_fill=True, commission=PerShare(0), + env=self.env) self.data = DataPanelSource(self.no_close_panel) # Check results diff --git a/tests/test_algorithm_gen.py b/tests/test_algorithm_gen.py index 1aa5babd..bdd1fa58 100644 --- a/tests/test_algorithm_gen.py +++ b/tests/test_algorithm_gen.py @@ -135,18 +135,17 @@ class AlgorithmGeneratorTestCase(TestCase): Ensure the pipeline of generators are in sync, at least as far as their current dates. """ - # Ensure we are pointing to the TradingEnvironment for this class - trading.environment = AlgorithmGeneratorTestCase.env sim_params = factory.create_simulation_parameters( start=datetime(2011, 7, 30, tzinfo=pytz.utc), - end=datetime(2012, 7, 30, tzinfo=pytz.utc) + end=datetime(2012, 7, 30, tzinfo=pytz.utc), + env=self.env, ) - algo = TestAlgo(self, sim_params=sim_params, - env=AlgorithmGeneratorTestCase.env) + algo = TestAlgo(self, sim_params=sim_params, env=self.env) trade_source = factory.create_daily_trade_source( [8229], - sim_params + sim_params, + env=self.env, ) algo.set_sources([trade_source]) @@ -168,10 +167,10 @@ class AlgorithmGeneratorTestCase(TestCase): sim_params = SimulationParameters( period_start=datetime(2012, 7, 30, tzinfo=pytz.utc), period_end=datetime(2012, 7, 30, tzinfo=pytz.utc), - data_frequency='minute' + data_frequency='minute', + env=self.env, ) - algo = TestAlgo(self, sim_params=sim_params, - env=AlgorithmGeneratorTestCase.env) + algo = TestAlgo(self, sim_params=sim_params, env=self.env) midnight_custom_source = [Event({ 'custom_field': 42.0, @@ -214,13 +213,14 @@ class AlgorithmGeneratorTestCase(TestCase): sim_params = factory.create_simulation_parameters( start=datetime(2008, 1, 1, tzinfo=pytz.utc), - end=datetime(2008, 1, 5, tzinfo=pytz.utc) + end=datetime(2008, 1, 5, tzinfo=pytz.utc), + env=self.env, ) - algo = TestAlgo(self, sim_params=sim_params, - env=AlgorithmGeneratorTestCase.env) + algo = TestAlgo(self, sim_params=sim_params, env=self.env) trade_source = factory.create_daily_trade_source( [8229], - sim_params + sim_params, + env=self.env, ) algo.set_sources([trade_source]) @@ -238,8 +238,8 @@ class AlgorithmGeneratorTestCase(TestCase): See https://github.com/quantopian/zipline/issues/241 """ sim_params = create_simulation_parameters(num_days=1, - data_frequency='minute') - algo = TestAlgo(self, sim_params=sim_params, - env=AlgorithmGeneratorTestCase.env) + data_frequency='minute', + env=self.env) + algo = TestAlgo(self, sim_params=sim_params, env=self.env) algo.run(source=[], overwrite_sim_params=False) self.assertEqual(algo.datetime, sim_params.last_close) diff --git a/tests/test_assets.py b/tests/test_assets.py index bad5d914..8d7e8d8a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -40,8 +40,7 @@ from zipline.errors import ( SidAssignmentError, RootSymbolNotFound, ) -from zipline.finance import trading -from zipline.finance.trading import with_environment +from zipline.finance.trading import TradingEnvironment from zipline.utils.test_utils import ( all_subindices, make_rotating_asset_info, @@ -87,9 +86,9 @@ def build_lookup_generic_cases(): }, ], index='sid') - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_df=frame) - finder = AssetFinder(trading.environment.engine) + env = TradingEnvironment() + env.write_data(equities_df=frame) + finder = env.asset_finder dupe_0, dupe_1, unique = assets = [ finder.retrieve_asset(i) for i in range(3) @@ -281,7 +280,7 @@ class TestFuture(TestCase): class AssetFinderTestCase(TestCase): def setUp(self): - trading.environment = trading.TradingEnvironment() + self.env = TradingEnvironment() def test_lookup_symbol_fuzzy(self): as_of = pd.Timestamp('2013-01-01', tz='UTC') @@ -299,8 +298,8 @@ class AssetFinderTestCase(TestCase): for i in range(3) ] ) - trading.environment.write_data(equities_df=frame) - finder = AssetFinder(trading.environment.engine, fuzzy_char='@') + self.env.write_data(equities_df=frame) + finder = AssetFinder(self.env.engine, fuzzy_char='@') asset_0, asset_1, asset_2 = ( finder.retrieve_asset(i) for i in range(3) ) @@ -344,8 +343,8 @@ class AssetFinderTestCase(TestCase): for i, date in enumerate(dates) ] ) - trading.environment.write_data(equities_df=df) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(equities_df=df) + finder = AssetFinder(self.env.engine) for _ in range(2): # Run checks twice to test for caching bugs. with self.assertRaises(SymbolNotFound): finder.lookup_symbol_resolve_multiple('non_existing', dates[0]) @@ -411,8 +410,8 @@ class AssetFinderTestCase(TestCase): }, ] ) - trading.environment.write_data(equities_df=data) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(equities_df=data) + finder = AssetFinder(self.env.engine) results, missing = finder.lookup_generic( ['real', 1, 'fake', 'real_but_old', 'real_but_in_the_future'], pd.Timestamp('2013-02-01', tz='UTC'), @@ -436,8 +435,8 @@ class AssetFinderTestCase(TestCase): 'end_date': '2015-01-01', 'symbol': "PLAY", 'foo_data': "FOO"}} - trading.environment.write_data(equities_data=data) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(equities_data=data) + finder = AssetFinder(self.env.engine) # Test proper insertion equity = finder.retrieve_asset(0) self.assertIsInstance(equity, Equity) @@ -454,8 +453,8 @@ class AssetFinderTestCase(TestCase): # Test dict consumption dict_to_consume = {0: {'symbol': 'PLAY'}, 1: {'symbol': 'MSFT'}} - trading.environment.write_data(equities_data=dict_to_consume) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(equities_data=dict_to_consume) + finder = AssetFinder(self.env.engine) equity = finder.retrieve_asset(0) self.assertIsInstance(equity, Equity) @@ -467,9 +466,9 @@ class AssetFinderTestCase(TestCase): df['exchange'][0] = "NASDAQ" df['asset_name'][1] = "Microsoft" df['exchange'][1] = "NYSE" - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_df=df) - finder = AssetFinder(trading.environment.engine) + self.env = TradingEnvironment() + self.env.write_data(equities_df=df) + finder = AssetFinder(self.env.engine) self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange) self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name) @@ -483,9 +482,9 @@ class AssetFinderTestCase(TestCase): future_asset = Future(200, symbol="TESTFUT", end_date=fut_end) # Consume the Assets - trading.environment.write_data(equities_identifiers=[equity_asset], - futures_identifiers=[future_asset]) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(equities_identifiers=[equity_asset], + futures_identifiers=[future_asset]) + finder = AssetFinder(self.env.engine) # Test equality with newly built Assets self.assertEqual(equity_asset, finder.retrieve_asset(1)) @@ -501,11 +500,11 @@ class AssetFinderTestCase(TestCase): today = normalize_date(pd.Timestamp('2015-07-09', tz='UTC')) # Write data with sid assignment - trading.environment.write_data(equities_identifiers=metadata, - allow_sid_assignment=True) + self.env.write_data(equities_identifiers=metadata, + allow_sid_assignment=True) # Verify that Assets were built and different sids were assigned - finder = AssetFinder(trading.environment.engine) + finder = AssetFinder(self.env.engine) play = finder.lookup_symbol('PLAY', today) msft = finder.lookup_symbol('MSFT', today) self.assertEqual('PLAY', play.symbol) @@ -519,8 +518,8 @@ class AssetFinderTestCase(TestCase): # Write data without sid assignment, asserting failure with self.assertRaises(SidAssignmentError): - trading.environment.write_data(equities_identifiers=metadata, - allow_sid_assignment=False) + self.env.write_data(equities_identifiers=metadata, + allow_sid_assignment=False) def test_security_dates_warning(self): @@ -577,8 +576,8 @@ class AssetFinderTestCase(TestCase): }, } - trading.environment.write_data(futures_data=metadata) - finder = AssetFinder(trading.environment.engine) + self.env.write_data(futures_data=metadata) + finder = AssetFinder(self.env.engine) dt = pd.Timestamp('2015-05-14', tz='UTC') last_year = pd.Timestamp('2014-01-01', tz='UTC') first_day = pd.Timestamp('2015-01-01', tz='UTC') @@ -609,7 +608,7 @@ class AssetFinderTestCase(TestCase): def test_map_identifier_index_to_sids(self): # Build an empty finder and some Assets dt = pd.Timestamp('2014-01-01', tz='UTC') - finder = AssetFinder(trading.environment.engine) + finder = AssetFinder(self.env.engine) asset1 = Equity(1, symbol="AAPL") asset2 = Equity(2, symbol="GOOG") asset200 = Future(200, symbol="CLK15") @@ -627,9 +626,9 @@ class AssetFinderTestCase(TestCase): post_map = finder.map_identifier_index_to_sids(pre_map, dt) self.assertListEqual([201, 2, 200, 1], post_map) - @with_environment() - def test_compute_lifetimes(self, env=None): + def test_compute_lifetimes(self): num_assets = 4 + env = TradingEnvironment() trading_day = env.trading_day first_start = pd.Timestamp('2015-04-01', tz='UTC') @@ -641,8 +640,8 @@ class AssetFinderTestCase(TestCase): asset_lifetime=5 ) - trading.environment.write_data(equities_df=frame) - finder = AssetFinder(trading.environment.engine) + env.write_data(equities_df=frame) + finder = env.asset_finder all_dates = pd.date_range( start=first_start, @@ -676,7 +675,8 @@ class AssetFinderTestCase(TestCase): class TestFutureChain(TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): metadata = { 0: { 'symbol': 'CLG06', @@ -708,9 +708,9 @@ class TestFutureChain(TestCase): 'expiration_date': pd.Timestamp('2006-10-20', tz='UTC')} } - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(futures_data=metadata) - self.asset_finder = AssetFinder(trading.environment.engine) + env = TradingEnvironment() + env.write_data(futures_data=metadata) + cls.asset_finder = env.asset_finder def test_len(self): """ Test the __len__ method of FutureChain. diff --git a/tests/test_batchtransform.py b/tests/test_batchtransform.py index 806cfbe8..1ec1a864 100644 --- a/tests/test_batchtransform.py +++ b/tests/test_batchtransform.py @@ -30,10 +30,9 @@ import zipline.utils.factory as factory from zipline.transforms import batch_transform from zipline.test_algorithms import (BatchTransformAlgorithm, - BatchTransformAlgorithmMinute, - ReturnPriceBatchTransform) + BatchTransformAlgorithmMinute) -from zipline.finance import trading +from zipline.finance.trading import TradingEnvironment from zipline.algorithm import TradingAlgorithm from zipline.utils.tradingcalendar import trading_days from copy import deepcopy @@ -107,16 +106,19 @@ class DifferentSidSource(DataSource): class TestChangeOfSids(TestCase): def setUp(self): self.sids = range(90) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=self.sids) + self.env = TradingEnvironment() + self.env.write_data(equities_identifiers=self.sids) + self.sim_params = factory.create_simulation_parameters( start=datetime(1990, 1, 1, tzinfo=pytz.utc), - end=datetime(1990, 1, 8, tzinfo=pytz.utc) + end=datetime(1990, 1, 8, tzinfo=pytz.utc), + env=self.env, ) def test_all_sids_passed(self): algo = BatchTransformAlgorithmSetSid( sim_params=self.sim_params, + env=self.env, ) source = DifferentSidSource() algo.run(source) @@ -131,26 +133,32 @@ class TestChangeOfSids(TestCase): class TestBatchTransformMinutely(TestCase): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[0]) + def setUp(self): setup_logger(self) start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) self.sim_params = factory.create_simulation_parameters( - start=start, - end=end, + start=start, end=end, env=self.env, ) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[0]) self.sim_params.emission_rate = 'daily' self.sim_params.data_frequency = 'minute' self.source, self.df = \ - factory.create_test_df_source(bars='minute') + factory.create_test_df_source(sim_params=self.sim_params, + env=self.env, + bars='minute') def tearDown(self): teardown_logger(self) def test_core(self): - algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params) + algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params, + env=self.env) algo.run(self.source) wl = int(algo.window_length * 6.5 * 60) for bt in algo.history[wl:]: @@ -158,7 +166,9 @@ class TestBatchTransformMinutely(TestCase): def test_window_length(self): algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params, - window_length=1, refresh_period=0) + env=self.env, + window_length=1, + refresh_period=0) algo.run(self.source) wl = int(algo.window_length * 6.5 * 60) np.testing.assert_array_equal(algo.history[:(wl - 1)], @@ -171,24 +181,25 @@ class TestBatchTransform(TestCase): @classmethod def setUpClass(cls): - cls.env = trading.TradingEnvironment() + cls.env = TradingEnvironment() cls.env.write_data(equities_identifiers=[0]) def setUp(self): setup_logger(self) self.sim_params = factory.create_simulation_parameters( start=datetime(1990, 1, 1, tzinfo=pytz.utc), - end=datetime(1990, 1, 8, tzinfo=pytz.utc) + end=datetime(1990, 1, 8, tzinfo=pytz.utc), + env=self.env ) - trading.environment = TestBatchTransform.env self.source, self.df = \ - factory.create_test_df_source(self.sim_params) + factory.create_test_df_source(self.sim_params, self.env) def tearDown(self): teardown_logger(self) def test_core_functionality(self): - algo = BatchTransformAlgorithm(sim_params=self.sim_params) + algo = BatchTransformAlgorithm(sim_params=self.sim_params, + env=self.env) algo.run(self.source) wl = algo.window_length # The following assertion depend on window length of 3 @@ -257,7 +268,8 @@ class TestBatchTransform(TestCase): def test_passing_of_args(self): algo = BatchTransformAlgorithm(1, kwarg='str', - sim_params=self.sim_params) + sim_params=self.sim_params, + env=self.env) algo.run(self.source) self.assertEqual(algo.args, (1,)) self.assertEqual(algo.kwargs, {'kwarg': 'str'}) @@ -278,22 +290,3 @@ class TestBatchTransform(TestCase): # 1990-01-08 - window now full expected_item ]) - - -def run_batchtransform(window_length=10): - sim_params = factory.create_simulation_parameters( - start=datetime(1990, 1, 1, tzinfo=pytz.utc), - end=datetime(1995, 1, 8, tzinfo=pytz.utc) - ) - source, df = factory.create_test_df_source(sim_params) - - return_price_class = ReturnPriceBatchTransform( - refresh_period=1, - window_length=window_length, - clean_nans=False - ) - - for raw_event in source: - raw_event['datetime'] = raw_event.dt - event = {0: raw_event} - return_price_class.handle_data(event) diff --git a/tests/test_events_through_risk.py b/tests/test_events_through_risk.py index b67b4209..f98410cc 100644 --- a/tests/test_events_through_risk.py +++ b/tests/test_events_through_risk.py @@ -19,8 +19,7 @@ import pytz import numpy as np -from zipline.finance.trading import SimulationParameters -from zipline.finance import trading +from zipline.finance.trading import SimulationParameters, TradingEnvironment from zipline.algorithm import TradingAlgorithm from zipline.protocol import ( Event, @@ -43,10 +42,12 @@ class BuyAndHoldAlgorithm(TradingAlgorithm): class TestEventsThroughRisk(unittest.TestCase): - def test_daily_buy_and_hold(self): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[1]) - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[1]) + def test_daily_buy_and_hold(self): start_date = datetime.datetime( year=2006, @@ -70,8 +71,7 @@ class TestEventsThroughRisk(unittest.TestCase): emission_rate='daily' ) - algo = BuyAndHoldAlgorithm( - sim_params=sim_params) + algo = BuyAndHoldAlgorithm(sim_params=sim_params, env=self.env) first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc) second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc) @@ -169,178 +169,176 @@ class TestEventsThroughRisk(unittest.TestCase): err_msg="Mismatch at %s" % (current_dt,)) def test_minute_buy_and_hold(self): - with trading.TradingEnvironment(): - start_date = datetime.datetime( - year=2006, - month=1, - day=3, - hour=0, - minute=0, - tzinfo=pytz.utc) - end_date = datetime.datetime( - year=2006, - month=1, - day=5, - hour=0, - minute=0, - tzinfo=pytz.utc) - sim_params = SimulationParameters( - period_start=start_date, - period_end=end_date, - emission_rate='daily', - data_frequency='minute') + start_date = datetime.datetime( + year=2006, + month=1, + day=3, + hour=0, + minute=0, + tzinfo=pytz.utc) + end_date = datetime.datetime( + year=2006, + month=1, + day=5, + hour=0, + minute=0, + tzinfo=pytz.utc) - algo = BuyAndHoldAlgorithm( - identifiers=[1], - sim_params=sim_params) + sim_params = SimulationParameters( + period_start=start_date, + period_end=end_date, + emission_rate='daily', + data_frequency='minute', + env=self.env) - first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc) - first_open, first_close = \ - trading.environment.get_open_and_close(first_date) + algo = BuyAndHoldAlgorithm( + sim_params=sim_params, + env=self.env) - second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc) - second_open, second_close = \ - trading.environment.get_open_and_close(second_date) + first_date = datetime.datetime(2006, 1, 3, tzinfo=pytz.utc) + first_open, first_close = self.env.get_open_and_close(first_date) - third_date = datetime.datetime(2006, 1, 5, tzinfo=pytz.utc) - third_open, third_close = \ - trading.environment.get_open_and_close(third_date) + second_date = datetime.datetime(2006, 1, 4, tzinfo=pytz.utc) + second_open, second_close = self.env.get_open_and_close(second_date) - benchmark_data = [ - Event({ - 'returns': 0.1, - 'dt': first_close, - 'source_id': 'test-benchmark-source', - 'type': DATASOURCE_TYPE.BENCHMARK - }), - Event({ - 'returns': 0.2, - 'dt': second_close, - 'source_id': 'test-benchmark-source', - 'type': DATASOURCE_TYPE.BENCHMARK - }), - Event({ - 'returns': 0.4, - 'dt': third_close, - 'source_id': 'test-benchmark-source', - 'type': DATASOURCE_TYPE.BENCHMARK - }), - ] + third_date = datetime.datetime(2006, 1, 5, tzinfo=pytz.utc) + third_open, third_close = self.env.get_open_and_close(third_date) - trade_bar_data = [ - Event({ - 'open_price': 10, - 'close_price': 15, - 'price': 15, - 'volume': 1000, - 'sid': 1, - 'dt': first_open, - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - Event({ - 'open_price': 10, - 'close_price': 15, - 'price': 15, - 'volume': 1000, - 'sid': 1, - 'dt': first_open + datetime.timedelta(minutes=10), - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - Event({ - 'open_price': 15, - 'close_price': 20, - 'price': 20, - 'volume': 2000, - 'sid': 1, - 'dt': second_open, - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - Event({ - 'open_price': 15, - 'close_price': 20, - 'price': 20, - 'volume': 2000, - 'sid': 1, - 'dt': second_open + datetime.timedelta(minutes=10), - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - Event({ - 'open_price': 20, - 'close_price': 15, - 'price': 15, - 'volume': 1000, - 'sid': 1, - 'dt': third_open, - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - Event({ - 'open_price': 20, - 'close_price': 15, - 'price': 15, - 'volume': 1000, - 'sid': 1, - 'dt': third_open + datetime.timedelta(minutes=10), - 'source_id': 'test-trade-source', - 'type': DATASOURCE_TYPE.TRADE - }), - ] + benchmark_data = [ + Event({ + 'returns': 0.1, + 'dt': first_close, + 'source_id': 'test-benchmark-source', + 'type': DATASOURCE_TYPE.BENCHMARK + }), + Event({ + 'returns': 0.2, + 'dt': second_close, + 'source_id': 'test-benchmark-source', + 'type': DATASOURCE_TYPE.BENCHMARK + }), + Event({ + 'returns': 0.4, + 'dt': third_close, + 'source_id': 'test-benchmark-source', + 'type': DATASOURCE_TYPE.BENCHMARK + }), + ] - algo.benchmark_return_source = benchmark_data - algo.set_sources(list([trade_bar_data])) - gen = algo._create_generator(sim_params) + trade_bar_data = [ + Event({ + 'open_price': 10, + 'close_price': 15, + 'price': 15, + 'volume': 1000, + 'sid': 1, + 'dt': first_open, + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + Event({ + 'open_price': 10, + 'close_price': 15, + 'price': 15, + 'volume': 1000, + 'sid': 1, + 'dt': first_open + datetime.timedelta(minutes=10), + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + Event({ + 'open_price': 15, + 'close_price': 20, + 'price': 20, + 'volume': 2000, + 'sid': 1, + 'dt': second_open, + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + Event({ + 'open_price': 15, + 'close_price': 20, + 'price': 20, + 'volume': 2000, + 'sid': 1, + 'dt': second_open + datetime.timedelta(minutes=10), + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + Event({ + 'open_price': 20, + 'close_price': 15, + 'price': 15, + 'volume': 1000, + 'sid': 1, + 'dt': third_open, + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + Event({ + 'open_price': 20, + 'close_price': 15, + 'price': 15, + 'volume': 1000, + 'sid': 1, + 'dt': third_open + datetime.timedelta(minutes=10), + 'source_id': 'test-trade-source', + 'type': DATASOURCE_TYPE.TRADE + }), + ] - crm = algo.perf_tracker.cumulative_risk_metrics - dt_loc = crm.cont_index.get_loc(algo.datetime) + algo.benchmark_return_source = benchmark_data + algo.set_sources(list([trade_bar_data])) + gen = algo._create_generator(sim_params) - first_msg = next(gen) + crm = algo.perf_tracker.cumulative_risk_metrics + dt_loc = crm.cont_index.get_loc(algo.datetime) - self.assertIsNotNone(first_msg, - "There should be a message emitted.") + first_msg = next(gen) - # Protects against bug where the positions appeared to be - # a day late, because benchmarks were triggering - # calculations before the events for the day were - # processed. - self.assertEqual(1, len(algo.portfolio.positions), "There should " - "be one position after the first day.") + self.assertIsNotNone(first_msg, + "There should be a message emitted.") - self.assertEquals( - 0, - crm.algorithm_volatility[dt_loc], - "On the first day algorithm volatility does not exist.") + # Protects against bug where the positions appeared to be + # a day late, because benchmarks were triggering + # calculations before the events for the day were + # processed. + self.assertEqual(1, len(algo.portfolio.positions), "There should " + "be one position after the first day.") - second_msg = next(gen) + self.assertEquals( + 0, + crm.algorithm_volatility[dt_loc], + "On the first day algorithm volatility does not exist.") - self.assertIsNotNone(second_msg, "There should be a message " - "emitted.") + second_msg = next(gen) - self.assertEqual(1, len(algo.portfolio.positions), - "Number of positions should stay the same.") + self.assertIsNotNone(second_msg, "There should be a message " + "emitted.") - # TODO: Hand derive. Current value is just a canary to - # detect changes. - np.testing.assert_almost_equal( - 0.050022510129558301, - crm.algorithm_returns[-1], - decimal=6) + self.assertEqual(1, len(algo.portfolio.positions), + "Number of positions should stay the same.") - third_msg = next(gen) + # TODO: Hand derive. Current value is just a canary to + # detect changes. + np.testing.assert_almost_equal( + 0.050022510129558301, + crm.algorithm_returns[-1], + decimal=6) - self.assertEqual(1, len(algo.portfolio.positions), - "Number of positions should stay the same.") + third_msg = next(gen) - self.assertIsNotNone(third_msg, "There should be a message " - "emitted.") + self.assertEqual(1, len(algo.portfolio.positions), + "Number of positions should stay the same.") - # TODO: Hand derive. Current value is just a canary to - # detect changes. - np.testing.assert_almost_equal( - -0.047639464532418657, - crm.algorithm_returns[-1], - decimal=6) + self.assertIsNotNone(third_msg, "There should be a message " + "emitted.") + + # TODO: Hand derive. Current value is just a canary to + # detect changes. + np.testing.assert_almost_equal( + -0.047639464532418657, + crm.algorithm_returns[-1], + decimal=6) diff --git a/tests/test_examples.py b/tests/test_examples.py index fa43d52f..5c93547b 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -31,8 +31,6 @@ from zipline.utils import parse_args, run_pipeline # Otherwise the next line sometimes complains about being run too late. _multiprocess_can_split_ = False -from zipline.finance import trading - matplotlib.use('Agg') @@ -47,8 +45,6 @@ class ExamplesTests(TestCase): @parameterized.expand(((os.path.basename(f).replace('.', '_'), f) for f in glob.glob(os.path.join(example_dir(), '*.py')))) def test_example(self, name, example): - # Create a new trading environment for each test. - trading.environment = trading.TradingEnvironment() imp.load_source('__main__', os.path.basename(example), open(example)) # Test algorithm as if scripts/run_algo.py is being used. diff --git a/tests/test_exception_handling.py b/tests/test_exception_handling.py index 505253f3..848455da 100644 --- a/tests/test_exception_handling.py +++ b/tests/test_exception_handling.py @@ -24,6 +24,7 @@ from zipline.test_algorithms import ( SetPortfolioAlgorithm, ) from zipline.finance.slippage import FixedSlippage +from zipline.finance.trading import TradingEnvironment from zipline.utils.test_utils import ( @@ -39,6 +40,11 @@ EXTENDED_TIMEOUT = 90 class ExceptionTestCase(TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[133]) + def setUp(self): self.zipline_test_config = { 'sid': 133, @@ -65,7 +71,8 @@ class ExceptionTestCase(TestCase): ExceptionAlgorithm( 'handle_data', self.zipline_test_config['sid'], - sim_params=factory.create_simulation_parameters() + sim_params=factory.create_simulation_parameters(), + env=self.env ) zipline = simfactory.create_test_zipline( @@ -75,8 +82,7 @@ class ExceptionTestCase(TestCase): with self.assertRaises(Exception) as ctx: output, _ = drain_zipline(self, zipline) - self.assertEqual(str(ctx.exception), - 'Algo exception in handle_data') + self.assertEqual(str(ctx.exception), 'Algo exception in handle_data') def test_zerodivision_exception_in_handle_data(self): @@ -85,7 +91,8 @@ class ExceptionTestCase(TestCase): self.zipline_test_config['algorithm'] = \ DivByZeroAlgorithm( self.zipline_test_config['sid'], - sim_params=factory.create_simulation_parameters() + sim_params=factory.create_simulation_parameters(), + env=self.env ) zipline = simfactory.create_test_zipline( @@ -105,7 +112,8 @@ class ExceptionTestCase(TestCase): self.zipline_test_config['algorithm'] = \ SetPortfolioAlgorithm( self.zipline_test_config['sid'], - sim_params=factory.create_simulation_parameters() + sim_params=factory.create_simulation_parameters(), + env=self.env ) zipline = simfactory.create_test_zipline( diff --git a/tests/test_finance.py b/tests/test_finance.py index b9f831db..5f4b6b25 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -39,7 +39,6 @@ import zipline.utils.simfactory as simfactory from zipline.finance.blotter import Blotter from zipline.gens.composites import date_sorted_sources -from zipline.finance import trading from zipline.finance.trading import TradingEnvironment from zipline.finance.execution import MarketOrder, LimitOrder from zipline.finance.trading import SimulationParameters @@ -59,9 +58,12 @@ _multiprocess_can_split_ = False class FinanceTestCase(TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[1, 133]) + def setUp(self): - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[1, 133]) self.zipline_test_config = { 'sid': 133, } @@ -76,7 +78,8 @@ class FinanceTestCase(TestCase): sim_params = factory.create_simulation_parameters() trade_source = factory.create_daily_trade_source( [133], - sim_params + sim_params, + env=self.env, ) prev = None for trade in trade_source: @@ -94,7 +97,6 @@ class FinanceTestCase(TestCase): # No transactions can be filled on the first trade, so # we have one extra trade to ensure all orders are filled. self.zipline_test_config['trade_count'] = 101 - trading.environment = trading.TradingEnvironment() full_zipline = simfactory.create_test_zipline( **self.zipline_test_config) assert_single_position(self, full_zipline) @@ -231,7 +233,8 @@ class FinanceTestCase(TestCase): price, volume, trade_interval, - sim_params + sim_params, + env=self.env, ) if alternate: @@ -265,7 +268,7 @@ class FinanceTestCase(TestCase): self.assertEqual(order.sid, sid) self.assertEqual(order.amount, order_amount * alternator ** i) - tracker = PerformanceTracker(sim_params) + tracker = PerformanceTracker(sim_params, env=self.env) benchmark_returns = [ Event({'dt': dt, @@ -273,7 +276,7 @@ class FinanceTestCase(TestCase): 'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK, 'source_id': 'benchmarks'}) - for dt, ret in trading.environment.benchmark_returns.iteritems() + for dt, ret in self.env.benchmark_returns.iteritems() if dt.date() >= sim_params.period_start.date() and dt.date() <= sim_params.period_end.date() ] @@ -412,6 +415,7 @@ class TradingEnvironmentTestCase(TestCase): period_start=datetime(2008, 1, 1, tzinfo=pytz.utc), period_end=datetime(2008, 12, 31, tzinfo=pytz.utc), capital_base=100000, + env=self.env, ) self.assertTrue(env.last_close.month == 12) @@ -428,10 +432,11 @@ class TradingEnvironmentTestCase(TestCase): # 20 21 22 23 24 25 26 # 27 28 29 30 31 - env = SimulationParameters( + params = SimulationParameters( period_start=datetime(2007, 12, 31, tzinfo=pytz.utc), period_end=datetime(2008, 1, 7, tzinfo=pytz.utc), capital_base=100000, + env=self.env, ) expected_trading_days = ( @@ -447,9 +452,9 @@ class TradingEnvironmentTestCase(TestCase): ) num_expected_trading_days = 5 - self.assertEquals(num_expected_trading_days, env.days_in_period) + self.assertEquals(num_expected_trading_days, params.days_in_period) np.testing.assert_array_equal(expected_trading_days, - env.trading_days.tolist()) + params.trading_days.tolist()) @timed(DEFAULT_TIMEOUT) def test_market_minute_window(self): diff --git a/tests/test_history.py b/tests/test_history.py index 2b77ac00..d3641fde 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -32,7 +32,6 @@ from zipline.finance import trading from zipline.finance.trading import ( SimulationParameters, TradingEnvironment, - with_environment, ) from zipline.errors import IncompatibleHistoryFrequency @@ -133,29 +132,30 @@ def convert_cases(cases): INDEX_TEST_CASES = convert_cases(INDEX_TEST_CASES_RAW) -def get_index_at_dt(case_input): +def get_index_at_dt(case_input, env): history_spec = history.HistorySpec( case_input['bar_count'], case_input['frequency'], None, False, + env=env, data_frequency='minute', ) - return history.index_at_dt(history_spec, case_input['algo_dt']) + return history.index_at_dt(history_spec, case_input['algo_dt'], env=env) class TestHistoryIndex(TestCase): @classmethod def setUpClass(cls): - cls.environment = TradingEnvironment.instance() + cls.environment = TradingEnvironment() @parameterized.expand( [(name, case['input'], case['expected']) for name, case in INDEX_TEST_CASES.items()] ) def test_index_at_dt(self, name, case_input, expected): - history_index = get_index_at_dt(case_input) + history_index = get_index_at_dt(case_input, self.environment) history_series = pd.Series(index=history_index) expected_series = pd.Series(index=expected) @@ -167,7 +167,7 @@ class TestHistoryContainer(TestCase): @classmethod def setUpClass(cls): - cls.env = TradingEnvironment.instance() + cls.env = TradingEnvironment() def bar_data_dt(self, bar_data, require_unique=True): """ @@ -205,6 +205,7 @@ class TestHistoryContainer(TestCase): container = HistoryContainer( {spec.key_str: spec for spec in specs}, sids, dt, 'minute', + env=self.env, ) for update_count, update in enumerate(updates): @@ -232,14 +233,16 @@ class TestHistoryContainer(TestCase): frequency='1m', field='price', ffill=True, - data_frequency='minute' + data_frequency='minute', + env=self.env, ) no_fill_spec = history.HistorySpec( bar_count=3, frequency='1m', field='price', ffill=False, - data_frequency='minute' + data_frequency='minute', + env=self.env, ) specs = {spec.key_str: spec, no_fill_spec.key_str: no_fill_spec} @@ -248,7 +251,7 @@ class TestHistoryContainer(TestCase): '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC') container = HistoryContainer( - specs, initial_sids, initial_dt, 'minute' + specs, initial_sids, initial_dt, 'minute', env=self.env, ) bar_data = BarData() @@ -282,7 +285,8 @@ class TestHistoryContainer(TestCase): frequency='1d', field='price', ffill=True, - data_frequency='minute' + data_frequency='minute', + env=self.env, ) specs = {spec.key_str: spec} initial_sids = [1, ] @@ -290,7 +294,7 @@ class TestHistoryContainer(TestCase): '2013-06-28 9:31AM', tz='US/Eastern').tz_convert('UTC') container = HistoryContainer( - specs, initial_sids, initial_dt, 'minute' + specs, initial_sids, initial_dt, 'minute', env=self.env, ) bar_data = BarData() @@ -440,9 +444,10 @@ def handle_data(context, data): end = pd.Timestamp('2006-03-30', tz='UTC') sim_params = factory.create_simulation_parameters( - start=start, end=end, data_frequency='daily') + start=start, end=end, data_frequency='daily', env=self.env, + ) - _, df = factory.create_test_df_source(sim_params) + _, df = factory.create_test_df_source(sim_params, self.env) df = df.astype(np.float64) source = DataFrameSource(df) @@ -1039,14 +1044,15 @@ def handle_data(context, data): period_end=end, capital_base=float("1.0e5"), data_frequency='minute', - emission_rate='daily' + emission_rate='daily', + env=self.env, ) test_algo = TradingAlgorithm( script=algo_text, data_frequency='minute', sim_params=sim_params, - env=TestHistoryAlgo.env, + env=self.env, ) test_algo.test_case = self @@ -1089,14 +1095,15 @@ def handle_data(context, data): period_end=end, capital_base=float("1.0e5"), data_frequency='minute', - emission_rate='daily' + emission_rate='daily', + env=self.env, ) test_algo = TradingAlgorithm( script=algo_text, data_frequency='minute', sim_params=sim_params, - env=TestHistoryAlgo.env, + env=self.env, ) test_algo.test_case = self @@ -1107,6 +1114,11 @@ def handle_data(context, data): class TestHistoryContainerResize(TestCase): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + @parameterized.expand( (freq, field, data_frequency, construct_digest) for freq in ('1m', '1d') @@ -1127,6 +1139,7 @@ class TestHistoryContainerResize(TestCase): field=field, ffill=True, data_frequency=data_frequency, + env=self.env, ) specs = {spec.key_str: spec} initial_sids = [1] @@ -1138,7 +1151,7 @@ class TestHistoryContainerResize(TestCase): ) container = HistoryContainer( - specs, initial_sids, initial_dt, data_frequency, + specs, initial_sids, initial_dt, data_frequency, env=self.env, ) if construct_digest: @@ -1156,6 +1169,7 @@ class TestHistoryContainerResize(TestCase): field=field, ffill=True, data_frequency=data_frequency, + env=self.env, ), history.HistorySpec( bar_count=bar_count + 2, @@ -1163,6 +1177,7 @@ class TestHistoryContainerResize(TestCase): field=field, ffill=True, data_frequency=data_frequency, + env=self.env, ), ) @@ -1192,6 +1207,7 @@ class TestHistoryContainerResize(TestCase): field=first, ffill=True, data_frequency=data_frequency, + env=self.env, ) specs = {spec.key_str: spec} initial_sids = [1] @@ -1203,7 +1219,7 @@ class TestHistoryContainerResize(TestCase): ) container = HistoryContainer( - specs, initial_sids, initial_dt, data_frequency, + specs, initial_sids, initial_dt, data_frequency, env=self.env ) if bar_count > 1: @@ -1220,6 +1236,7 @@ class TestHistoryContainerResize(TestCase): field=second, ffill=True, data_frequency=data_frequency, + env=self.env, ) container.ensure_spec(new_spec, initial_dt, bar_data) @@ -1252,6 +1269,7 @@ class TestHistoryContainerResize(TestCase): field=field, ffill=True, data_frequency=data_frequency, + env=self.env, ) specs = {spec.key_str: spec} initial_sids = [1] @@ -1263,7 +1281,7 @@ class TestHistoryContainerResize(TestCase): ) container = HistoryContainer( - specs, initial_sids, initial_dt, data_frequency, + specs, initial_sids, initial_dt, data_frequency, env=self.env, ) if bar_count > 1: @@ -1280,6 +1298,7 @@ class TestHistoryContainerResize(TestCase): field=field, ffill=True, data_frequency=data_frequency, + env=self.env, ) container.ensure_spec(new_spec, initial_dt, bar_data) @@ -1292,8 +1311,7 @@ class TestHistoryContainerResize(TestCase): self.assert_history(container, new_spec, initial_dt) - @with_environment() - def assert_history(self, container, spec, dt, env=None): + def assert_history(self, container, spec, dt): hst = container.get_history(spec, dt) self.assertEqual(len(hst), spec.bar_count) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index f95f4421..8f5fad91 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -15,7 +15,6 @@ from __future__ import division -import pickle import collections from datetime import ( datetime, @@ -40,12 +39,14 @@ from zipline.finance.slippage import Transaction, create_transaction import zipline.utils.math_utils as zp_math from zipline.gens.composites import date_sorted_sources -from zipline.finance import trading from zipline.finance.trading import SimulationParameters from zipline.finance.blotter import Order from zipline.finance.commission import PerShare, PerTrade, PerDollar -from zipline.finance.trading import with_environment +from zipline.finance.trading import TradingEnvironment from zipline.utils.factory import create_random_simulation_parameters +from zipline.utils.serialization_utils import ( + load_with_persistent_ids, dump_with_persistent_ids +) import zipline.protocol as zp from zipline.protocol import Event, DATASOURCE_TYPE from zipline.sources.data_frame_source import DataPanelSource @@ -128,8 +129,7 @@ def create_txn(trade_event, price, amount): return create_transaction(trade_event, mock_order, price, amount) -@with_environment() -def benchmark_events_in_range(sim_params, env=None): +def benchmark_events_in_range(sim_params, env): return [ Event({'dt': dt, 'returns': ret, @@ -174,7 +174,7 @@ def calculate_results(host, txns = txns or [] splits = splits or [] - perf_tracker = perf.PerformanceTracker(host.sim_params) + perf_tracker = perf.PerformanceTracker(host.sim_params, host.env) if dividend_events is not None: dividend_frame = pd.DataFrame( @@ -246,9 +246,9 @@ def check_perf_tracker_serialization(perf_tracker): 'total_days', ] - p_string = pickle.dumps(perf_tracker) + p_string = dump_with_persistent_ids(perf_tracker) - test = pickle.loads(p_string) + test = load_with_persistent_ids(p_string, env=perf_tracker.env) for k in scalar_keys: nt.assert_equal(getattr(test, k), getattr(perf_tracker, k), k) @@ -259,13 +259,15 @@ def check_perf_tracker_serialization(perf_tracker): class TestSplitPerformance(unittest.TestCase): def setUp(self): + self.env = TradingEnvironment() + self.env.write_data(equities_identifiers=[1]) self.sim_params, self.dt, self.end_dt = \ create_random_simulation_parameters() - trading.environment.write_data(equities_identifiers=[1]) # start with $10,000 self.sim_params.capital_base = 10e3 - self.benchmark_events = benchmark_events_in_range(self.sim_params) + self.benchmark_events = benchmark_events_in_range(self.sim_params, + self.env) def test_split_long_position(self): events = factory.create_trade_history( @@ -273,7 +275,8 @@ class TestSplitPerformance(unittest.TestCase): [20, 20], [100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) # set up a long position in sid 1 @@ -359,17 +362,20 @@ class TestSplitPerformance(unittest.TestCase): class TestCommissionEvents(unittest.TestCase): def setUp(self): + self.env = TradingEnvironment() + self.env.write_data( + equities_identifiers=[0, 1, 133] + ) self.sim_params, self.dt, self.end_dt = \ create_random_simulation_parameters() - trading.environment.write_data(equities_identifiers=[0, 1, 133]) - logger.info("sim_params: %s, dt: %s, end_dt: %s" % (self.sim_params, self.dt, self.end_dt)) self.sim_params.capital_base = 10e3 - self.benchmark_events = benchmark_events_in_range(self.sim_params) + self.benchmark_events = benchmark_events_in_range(self.sim_params, + self.env) def test_commission_event(self): events = factory.create_trade_history( @@ -377,7 +383,8 @@ class TestCommissionEvents(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) # Test commission models and validate result @@ -454,7 +461,8 @@ class TestCommissionEvents(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) # Buy and sell the same sid so that we have a zero position by the @@ -484,7 +492,8 @@ class TestCommissionEvents(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) # Add a cash adjustment at the time of event[3]. @@ -500,21 +509,26 @@ class TestCommissionEvents(unittest.TestCase): class TestDividendPerformance(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[1, 2]) + def setUp(self): self.sim_params, self.dt, self.end_dt = \ create_random_simulation_parameters() - trading.environment.write_data(equities_identifiers=[1, 2]) self.sim_params.capital_base = 10e3 - self.benchmark_events = benchmark_events_in_range(self.sim_params) + self.benchmark_events = benchmark_events_in_range(self.sim_params, + self.env) def test_market_hours_calculations(self): # DST in US/Eastern began on Sunday March 14, 2010 before = datetime(2010, 3, 12, 14, 31, tzinfo=pytz.utc) after = factory.get_next_trading_dt( before, - timedelta(days=1) + timedelta(days=1), + self.env, ) self.assertEqual(after.hour, 13) @@ -525,7 +539,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( 1, @@ -576,7 +591,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params) + self.sim_params, + env=self.env) ) dividend = factory.create_stock_dividend( @@ -626,7 +642,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( @@ -667,7 +684,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( @@ -708,7 +726,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10, 10], [100, 100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( @@ -749,13 +768,14 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) pay_date = self.sim_params.first_open # find pay date that is much later. for i in range(30): - pay_date = factory.get_next_trading_dt(pay_date, oneday) + pay_date = factory.get_next_trading_dt(pay_date, oneday, self.env) dividend = factory.create_dividend( 1, 10.00, @@ -795,7 +815,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( @@ -836,7 +857,8 @@ class TestDividendPerformance(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( @@ -865,15 +887,15 @@ class TestDividendPerformance(unittest.TestCase): [event['cumulative_perf']['capital_used'] for event in results] self.assertEqual(cumulative_cash_flows, [0, 0, 0, 0, 0]) - @with_environment() - def test_no_dividend_at_simulation_end(self, env=None): + def test_no_dividend_at_simulation_end(self): # post some trades in the market events = factory.create_trade_history( 1, [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - self.sim_params + self.sim_params, + env=self.env ) dividend = factory.create_dividend( 1, @@ -886,12 +908,12 @@ class TestDividendPerformance(unittest.TestCase): events[-2].dt, # pay date, when the algorithm receives the dividend. # This pays out on the day after the last event - env.next_trading_day(events[-1].dt) + self.env.next_trading_day(events[-1].dt) ) # Set the last day to be the last event self.sim_params.period_end = events[-1].dt - self.sim_params._update_internal() + self.sim_params.update_internal_from_env(self.env) # Simulate a transaction being filled prior to the ex_date. txns = [create_txn(events[0], 10.0, 100)] @@ -929,18 +951,29 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance): self.end_dt = datetime(2004, 11, 25, tzinfo=pytz.utc) self.sim_params = SimulationParameters( self.dt, - self.end_dt) - self.benchmark_events = benchmark_events_in_range(self.sim_params) + self.end_dt, + env=self.env) + + self.sim_params.capital_base = 10e3 + + self.benchmark_events = benchmark_events_in_range(self.sim_params, + self.env) class TestPositionPerformance(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[1, 2]) + def setUp(self): self.sim_params, self.dt, self.end_dt = \ create_random_simulation_parameters() - trading.environment.write_data(equities_identifiers=[1, 2]) - self.benchmark_events = benchmark_events_in_range(self.sim_params) + self.finder = self.env.asset_finder + self.benchmark_events = benchmark_events_in_range(self.sim_params, + self.env) def test_long_short_positions(self): """ @@ -956,7 +989,8 @@ class TestPositionPerformance(unittest.TestCase): [10, 10, 10, 9], [100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) trades_2 = factory.create_trade_history( @@ -964,13 +998,14 @@ class TestPositionPerformance(unittest.TestCase): [10, 10, 10, 11], [100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) txn1 = create_txn(trades_1[1], 10.0, 100) txn2 = create_txn(trades_2[1], 10.0, -100) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt pt.execute_transaction(txn1) pp.handle_execution(txn1) @@ -1046,12 +1081,13 @@ class TestPositionPerformance(unittest.TestCase): [10, 10, 10, 11], [100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) txn = create_txn(trades[1], 10.0, 1000) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1125,12 +1161,13 @@ class TestPositionPerformance(unittest.TestCase): [10, 10, 10, 11], [100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) txn = create_txn(trades[1], 10.0, 100) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1228,14 +1265,15 @@ single short-sale transaction""" [10, 10, 10, 11, 10, 9], [100, 100, 100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) trades_1 = trades[:-2] txn = create_txn(trades[1], 10.0, -100) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt pt.execute_transaction(txn) @@ -1352,8 +1390,8 @@ single short-sale transaction""" ) # now run a performance period encompassing the entire trade sample. - ptTotal = perf.PositionTracker() - ppTotal = perf.PerformancePeriod(1000.0) + ptTotal = perf.PositionTracker(self.env.asset_finder) + ppTotal = perf.PerformancePeriod(1000.0, self.env.asset_finder) ppTotal.position_tracker = pt for trade in trades_1: @@ -1447,7 +1485,8 @@ trade after cover""" [10, 10, 10, 11, 9, 8, 7, 8, 9, 10], [100, 100, 100, 100, 100, 100, 100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + env=self.env ) short_txn = create_txn( @@ -1457,8 +1496,8 @@ trade after cover""" ) cover_txn = create_txn(trades[6], 7.0, 100) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt pt.execute_transaction(short_txn) @@ -1551,13 +1590,14 @@ shares in position" [10, 11, 11, 12], [100, 100, 100, 100], onesec, - self.sim_params + self.sim_params, + self.env ) trades = factory.create_trade_history(*history_args) transactions = factory.create_txn_history(*history_args) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt average_cost = 0 @@ -1623,8 +1663,8 @@ shares in position" self.assertEqual(pp.pnl, -800, "this period goes from +400 to -400") - pt3 = perf.PositionTracker() - pp3 = perf.PerformancePeriod(1000.0) + pt3 = perf.PositionTracker(self.env.asset_finder) + pp3 = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp3.position_tracker = pt3 average_cost = 0 @@ -1666,15 +1706,16 @@ shares in position" [10, 9, 11, 8, 9, 12, 13, 14], [200, -100, -100, 100, -300, 100, 500, 400], onesec, - self.sim_params + self.sim_params, + self.env ) cost_bases = [10, 10, 0, 8, 9, 9, 13, 13.5] trades = factory.create_trade_history(*history_args) transactions = factory.create_txn_history(*history_args) - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(1000.0) + pt = perf.PositionTracker(self.env.asset_finder) + pp = perf.PerformancePeriod(1000.0, self.env.asset_finder) pp.position_tracker = pt for txn, cb in zip(transactions, cost_bases): @@ -1692,9 +1733,10 @@ shares in position" class TestPerformanceTracker(unittest.TestCase): - def setUp(self): - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[133, 134]) + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=[1, 2, 133, 134]) NumDaysToDelete = collections.namedtuple( 'NumDaysToDelete', ('start', 'middle', 'end')) @@ -1733,8 +1775,6 @@ class TestPerformanceTracker(unittest.TestCase): # 12 13 14 15 16 17 18 # 19 20 21 22 23 24 25 # 26 27 28 29 30 31 - trading.environment = trading.TradingEnvironment() - trading.environment.write_data(equities_identifiers=[133, 134]) start_dt = datetime(year=2008, month=10, day=9, @@ -1753,10 +1793,11 @@ class TestPerformanceTracker(unittest.TestCase): sim_params = SimulationParameters( period_start=start_dt, - period_end=end_dt + period_end=end_dt, + env=self.env, ) - benchmark_events = benchmark_events_in_range(sim_params) + benchmark_events = benchmark_events_in_range(sim_params, self.env) trade_history = factory.create_trade_history( sid, @@ -1764,7 +1805,8 @@ class TestPerformanceTracker(unittest.TestCase): volume, trade_time_increment, sim_params, - source_id="factory1" + source_id="factory1", + env=self.env ) sid2 = 134 @@ -1776,7 +1818,8 @@ class TestPerformanceTracker(unittest.TestCase): volume, trade_time_increment, sim_params, - source_id="factory2" + source_id="factory2", + env=self.env ) # 'middle' start of 3 depends on number of days == 7 middle = 3 @@ -1796,10 +1839,6 @@ class TestPerformanceTracker(unittest.TestCase): del trade_history[-days_to_delete.end:] del trade_history2[-days_to_delete.end:] - sim_params.first_open = \ - sim_params.calculate_first_open() - sim_params.last_close = \ - sim_params.calculate_last_close() sim_params.capital_base = 1000.0 sim_params.frame_index = [ 'sid', @@ -1808,7 +1847,7 @@ class TestPerformanceTracker(unittest.TestCase): 'price', 'changed'] perf_tracker = perf.PerformanceTracker( - sim_params + sim_params, self.env ) events = date_sorted_sources(trade_history, trade_history2) @@ -1887,23 +1926,21 @@ class TestPerformanceTracker(unittest.TestCase): else: yield event - @with_environment() - def test_minute_tracker(self, env=None): + def test_minute_tracker(self): """ Tests minute performance tracking.""" - start_dt = env.exchange_dt_in_utc(datetime(2013, 3, 1, 9, 31)) - end_dt = env.exchange_dt_in_utc(datetime(2013, 3, 1, 16, 0)) - - sim_params = SimulationParameters( - period_start=start_dt, - period_end=end_dt, - emission_rate='minute' - ) - tracker = perf.PerformanceTracker(sim_params) + start_dt = self.env.exchange_dt_in_utc(datetime(2013, 3, 1, 9, 31)) + end_dt = self.env.exchange_dt_in_utc(datetime(2013, 3, 1, 16, 0)) foosid = 1 barsid = 2 - env.write_data(equities_identifiers=[foosid, barsid]) + sim_params = SimulationParameters( + period_start=start_dt, + period_end=end_dt, + emission_rate='minute', + env=self.env, + ) + tracker = perf.PerformanceTracker(sim_params, env=self.env) foo_event_1 = factory.create_trade(foosid, 10.0, 20, start_dt) order_event_1 = Order(sid=foo_event_1.sid, @@ -1996,10 +2033,8 @@ class TestPerformanceTracker(unittest.TestCase): check_perf_tracker_serialization(tracker) - @with_environment() - def test_close_position_event(self, env=None): - env.write_data(equities_identifiers=[1, 2]) - pt = perf.PositionTracker() + def test_close_position_event(self): + pt = perf.PositionTracker(asset_finder=self.env.asset_finder) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(120.0), last_sale_date=dt, last_sale_price=3.4) @@ -2037,11 +2072,12 @@ class TestPerformanceTracker(unittest.TestCase): [10, 10, 10, 10, 10], [100, 100, 100, 100, 100], oneday, - sim_params + sim_params, + env=self.env ) # Create a tracker and a dividend - perf_tracker = perf.PerformanceTracker(sim_params) + perf_tracker = perf.PerformanceTracker(sim_params, env=self.env) dividend = factory.create_dividend( 1, 10.00, @@ -2081,11 +2117,12 @@ class TestPerformanceTracker(unittest.TestCase): sim_params = SimulationParameters( period_start=start_dt, - period_end=end_dt + period_end=end_dt, + env=self.env, ) perf_tracker = perf.PerformanceTracker( - sim_params + sim_params, env=self.env ) check_perf_tracker_serialization(perf_tracker) @@ -2099,16 +2136,26 @@ class TestPosition(unittest.TestCase): pos = perf.Position(10, amount=np.float64(120.0), last_sale_date=dt, last_sale_price=3.4) - p_string = pickle.dumps(pos) + p_string = dump_with_persistent_ids(pos) - test = pickle.loads(p_string) + test = load_with_persistent_ids(p_string, env=None) nt.assert_dict_equal(test.__dict__, pos.__dict__) class TestPositionTracker(unittest.TestCase): - def setUp(self): - trading.environment = trading.TradingEnvironment() + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + + equities_metadata = {1: {'asset_type': 'equity'}, + 2: {'asset_type': 'equity'}} + futures_metadata = {3: {'asset_type': 'future', + 'contract_multiplier': 1000}, + 4: {'asset_type': 'future', + 'contract_multiplier': 1000}} + cls.env.write_data(equities_data=equities_metadata, + futures_data=futures_metadata) def test_empty_positions(self): """ @@ -2117,7 +2164,7 @@ class TestPositionTracker(unittest.TestCase): Originally this bug was due to np.dot([], []) returning np.bool_(False) """ - pt = perf.PositionTracker() + pt = perf.PositionTracker(self.env.asset_finder) stats = [ 'calculate_positions_value', @@ -2137,41 +2184,28 @@ class TestPositionTracker(unittest.TestCase): self.assertEquals(val, 0) self.assertNotIsInstance(val, (bool, np.bool_)) - def test_update_last_sale(self, env=None): - equities_metadata = {1: {'asset_type': 'equity'}} - futures_metadata = {2: {'asset_type': 'future', - 'contract_multiplier': 1000}} - trading.environment.write_data(equities_data=equities_metadata, - futures_data=futures_metadata) - pt = perf.PositionTracker() + def test_update_last_sale(self): + pt = perf.PositionTracker(self.env.asset_finder) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(100.0), last_sale_date=dt, last_sale_price=10) - pos2 = perf.Position(2, amount=np.float64(100.0), + pos3 = perf.Position(3, amount=np.float64(100.0), last_sale_date=dt, last_sale_price=10) - pt.update_positions({1: pos1, 2: pos2}) + pt.update_positions({1: pos1, 3: pos3}) event1 = Event({'sid': 1, 'price': 11, 'dt': dt}) - event2 = Event({'sid': 2, + event3 = Event({'sid': 3, 'price': 11, 'dt': dt}) # Check cash-adjustment return value self.assertEqual(0, pt.update_last_sale(event1)) - self.assertEqual(100000, pt.update_last_sale(event2)) + self.assertEqual(100000, pt.update_last_sale(event3)) - def test_position_values_and_exposures(self, env=None): - equities_metadata = {1: {'asset_type': 'equity'}, - 2: {'asset_type': 'equity'}} - futures_metadata = {3: {'asset_type': 'future', - 'contract_multiplier': 1000}, - 4: {'asset_type': 'future', - 'contract_multiplier': 1000}} - trading.environment.write_data(equities_data=equities_metadata, - futures_data=futures_metadata) - pt = perf.PositionTracker() + def test_position_values_and_exposures(self): + pt = perf.PositionTracker(self.env.asset_finder) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), last_sale_date=dt, last_sale_price=10) @@ -2199,21 +2233,17 @@ class TestPositionTracker(unittest.TestCase): self.assertEqual(100 + 200 + 300000 + 400000, pt._gross_exposure()) self.assertEqual(100 - 200 + 300000 - 400000, pt._net_exposure()) - def test_serialization(self, env=None): - metadata = {1: {'asset_type': 'equity'}, - 2: {'asset_type': 'future', - 'contract_multiplier': 1000}} - trading.environment.write_data(equities_data=metadata) - pt = perf.PositionTracker() + def test_serialization(self): + pt = perf.PositionTracker(self.env.asset_finder) dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(120.0), last_sale_date=dt, last_sale_price=3.4) - pos2 = perf.Position(2, amount=np.float64(100.0), + pos3 = perf.Position(3, amount=np.float64(100.0), last_sale_date=dt, last_sale_price=3.4) - pt.update_positions({1: pos1, 2: pos2}) - p_string = pickle.dumps(pt) - test = pickle.loads(p_string) + pt.update_positions({1: pos1, 3: pos3}) + p_string = dump_with_persistent_ids(pt) + test = load_with_persistent_ids(p_string, env=self.env) nt.assert_dict_equal(test._position_amounts, pt._position_amounts) nt.assert_dict_equal(test._position_last_sale_prices, pt._position_last_sale_prices) @@ -2224,16 +2254,15 @@ class TestPositionTracker(unittest.TestCase): class TestPerformancePeriod(unittest.TestCase): - def setUp(self): - pass def test_serialization(self): - pt = perf.PositionTracker() - pp = perf.PerformancePeriod(100) + env = TradingEnvironment() + pt = perf.PositionTracker(env.asset_finder) + pp = perf.PerformancePeriod(100, env.asset_finder) pp.position_tracker = pt - p_string = pickle.dumps(pp) - test = pickle.loads(p_string) + p_string = dump_with_persistent_ids(pp) + test = load_with_persistent_ids(p_string, env=env) correct = pp.__dict__.copy() del correct['_position_tracker'] diff --git a/tests/test_pickle_serialization.py b/tests/test_pickle_serialization.py index 40b6b543..4ef4e1a3 100644 --- a/tests/test_pickle_serialization.py +++ b/tests/test_pickle_serialization.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pickle +from zipline.utils.serialization_utils import ( + load_with_persistent_ids, dump_with_persistent_ids +) from nose_parameterized import parameterized from unittest import TestCase from .serialization_cases import ( object_serialization_cases, - assert_dict_equal + assert_dict_equal, + cases_env, ) @@ -37,9 +40,9 @@ class PickleSerializationTestCase(TestCase): obj = cls(*initargs) for k, v in di_vars.items(): setattr(obj, k, v) - state = pickle.dumps(obj) + state = dump_with_persistent_ids(obj) - obj2 = pickle.loads(state) + obj2 = load_with_persistent_ids(state, env=cases_env) for k, v in di_vars.items(): setattr(obj2, k, v) diff --git a/tests/test_rolling_panel.py b/tests/test_rolling_panel.py index 82274f56..aa3cc62e 100644 --- a/tests/test_rolling_panel.py +++ b/tests/test_rolling_panel.py @@ -23,17 +23,21 @@ import pandas as pd import pandas.util.testing as tm from zipline.utils.data import MutableIndexRollingPanel, RollingPanel -from zipline.finance.trading import with_environment +from zipline.finance.trading import TradingEnvironment class TestRollingPanel(unittest.TestCase): - @with_environment() - def test_alignment(self, env): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + + def test_alignment(self): items = ('a', 'b') sids = (1, 2) - dts = env.market_minute_window( - env.open_and_closes.market_open[0], 4, + dts = self.env.market_minute_window( + self.env.open_and_closes.market_open[0], 4, ).values rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1]) @@ -90,8 +94,7 @@ class TestRollingPanel(unittest.TestCase): expected, ) - @with_environment() - def test_get_current_multiple_call_same_tick(self, env): + def test_get_current_multiple_call_same_tick(self): """ In old get_current, each call the get_current would copy the data. Thus changing that object would have no side effects. @@ -104,8 +107,8 @@ class TestRollingPanel(unittest.TestCase): items = ('a', 'b') sids = (1, 2) - dts = env.market_minute_window( - env.open_and_closes.market_open[0], 4, + dts = self.env.market_minute_window( + self.env.open_and_closes.market_open[0], 4, ).values rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1]) diff --git a/tests/test_security_list.py b/tests/test_security_list.py index 4fb354db..9dbb116c 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -6,8 +6,7 @@ from unittest import TestCase from zipline.algorithm import TradingAlgorithm from zipline.errors import TradingControlViolation from zipline.sources import SpecificEquityTrades -from zipline.finance import trading -from zipline.finance.trading import with_environment +from zipline.finance.trading import TradingEnvironment from zipline.utils.test_utils import ( setup_logger, teardown_logger, security_list_copy, add_security_data,) from zipline.utils import factory @@ -19,7 +18,7 @@ LEVERAGED_ETFS = load_from_directory('leveraged_etf_list') class RestrictedAlgoWithCheck(TradingAlgorithm): def initialize(self, symbol): - self.rl = SecurityListSet(self.get_datetime) + self.rl = SecurityListSet(self.get_datetime, self.asset_finder) self.set_do_not_order_list(self.rl.leveraged_etf_list) self.order_count = 0 self.sid = self.symbol(symbol) @@ -34,7 +33,7 @@ class RestrictedAlgoWithCheck(TradingAlgorithm): class RestrictedAlgoWithoutCheck(TradingAlgorithm): def initialize(self, symbol): - self.rl = SecurityListSet(self.get_datetime) + self.rl = SecurityListSet(self.get_datetime, self.asset_finder) self.set_do_not_order_list(self.rl.leveraged_etf_list) self.order_count = 0 self.sid = self.symbol(symbol) @@ -46,7 +45,7 @@ class RestrictedAlgoWithoutCheck(TradingAlgorithm): class IterateRLAlgo(TradingAlgorithm): def initialize(self, symbol): - self.rl = SecurityListSet(self.get_datetime) + self.rl = SecurityListSet(self.get_datetime, self.asset_finder) self.set_do_not_order_list(self.rl.leveraged_etf_list) self.order_count = 0 self.sid = self.symbol(symbol) @@ -60,6 +59,12 @@ class IterateRLAlgo(TradingAlgorithm): class SecurityListTestCase(TestCase): + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + cls.env.write_data(equities_identifiers=['AAPL', 'GOOG', 'BZQ', + 'URTY', 'JFT']) + def setUp(self, env=None): self.extra_knowledge_date = \ @@ -69,43 +74,38 @@ class SecurityListTestCase(TestCase): setup_logger(self) - trading.environment = trading.TradingEnvironment() - def tearDown(self): teardown_logger(self) def test_iterate_over_rl(self): sim_params = factory.create_simulation_parameters( - start=list(LEVERAGED_ETFS.keys())[0], num_days=4) - trading.environment.write_data(equities_identifiers=['BZQ']) + start=list(LEVERAGED_ETFS.keys())[0], num_days=4, env=self.env) trade_history = factory.create_trade_history( 'BZQ', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) - algo = IterateRLAlgo(symbol='BZQ', sim_params=sim_params) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) + algo = IterateRLAlgo(symbol='BZQ', sim_params=sim_params, env=self.env) algo.run(self.source) self.assertTrue(algo.found) - @with_environment() - def test_security_list(self, env=None): + def test_security_list(self): # set the knowledge date to the first day of the # leveraged etf knowledge date. def get_datetime(): return list(LEVERAGED_ETFS.keys())[0] - env.write_data(equities_identifiers=['AAPL', 'GOOG', 'BZQ', - 'URTY', 'JFT']) - - rl = SecurityListSet(get_datetime) + rl = SecurityListSet(get_datetime, self.env.asset_finder) # assert that a sample from the leveraged list are in restricted should_exist = [ asset.sid for asset in - [env.asset_finder.lookup_symbol( + [self.env.asset_finder.lookup_symbol( symbol, as_of_date=self.extra_knowledge_date) for symbol in ["BZQ", "URTY", "JFT"]] @@ -116,7 +116,7 @@ class SecurityListTestCase(TestCase): # assert that a sample of allowed stocks are not in restricted shouldnt_exist = [ asset.sid for asset in - [env.asset_finder.lookup_symbol( + [self.env.asset_finder.lookup_symbol( symbol, as_of_date=self.extra_knowledge_date) for symbol in ["AAPL", "GOOG"]] @@ -124,18 +124,15 @@ class SecurityListTestCase(TestCase): for sid in shouldnt_exist: self.assertNotIn(sid, rl.leveraged_etf_list) - @with_environment() - def test_security_add(self, env=None): + def test_security_add(self): def get_datetime(): return datetime(2015, 1, 27, tzinfo=pytz.utc) with security_list_copy(): add_security_data(['AAPL', 'GOOG'], []) - env.write_data(equities_identifiers=['AAPL', 'GOOG', - 'BZQ', 'URTY']) - rl = SecurityListSet(get_datetime) + rl = SecurityListSet(get_datetime, self.env.asset_finder) should_exist = [ asset.sid for asset in - [env.asset_finder.lookup_symbol( + [self.env.asset_finder.lookup_symbol( symbol, as_of_date=self.extra_knowledge_date ) for symbol in ["AAPL", "GOOG", "BZQ", "URTY"]] @@ -147,57 +144,67 @@ class SecurityListTestCase(TestCase): with security_list_copy(): def get_datetime(): return datetime(2015, 1, 27, tzinfo=pytz.utc) - trading.environment.write_data(equities_identifiers=['BZQ', - 'URTY']) - rl = SecurityListSet(get_datetime) + rl = SecurityListSet(get_datetime, self.env.asset_finder) self.assertNotIn("BZQ", rl.leveraged_etf_list) self.assertNotIn("URTY", rl.leveraged_etf_list) def test_algo_without_rl_violation_via_check(self): sim_params = factory.create_simulation_parameters( - start=list(LEVERAGED_ETFS.keys())[0], num_days=4) - trading.environment.write_data(equities_identifiers=['BZQ']) + start=list(LEVERAGED_ETFS.keys())[0], num_days=4, + env=self.env) trade_history = factory.create_trade_history( 'BZQ', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) - algo = RestrictedAlgoWithCheck(symbol='BZQ', sim_params=sim_params) + algo = RestrictedAlgoWithCheck(symbol='BZQ', + sim_params=sim_params, + env=self.env) algo.run(self.source) def test_algo_without_rl_violation(self): sim_params = factory.create_simulation_parameters( - start=list(LEVERAGED_ETFS.keys())[0], num_days=4) - trading.environment.write_data(equities_identifiers=['AAPL']) + start=list(LEVERAGED_ETFS.keys())[0], num_days=4, + env=self.env) trade_history = factory.create_trade_history( 'AAPL', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) - algo = RestrictedAlgoWithoutCheck(symbol='AAPL', sim_params=sim_params) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) + algo = RestrictedAlgoWithoutCheck(symbol='AAPL', + sim_params=sim_params, + env=self.env) algo.run(self.source) def test_algo_with_rl_violation(self): sim_params = factory.create_simulation_parameters( - start=list(LEVERAGED_ETFS.keys())[0], num_days=4) - trading.environment.write_data(equities_identifiers=['BZQ', 'JFT']) + start=list(LEVERAGED_ETFS.keys())[0], num_days=4, + env=self.env) trade_history = factory.create_trade_history( 'BZQ', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) - algo = RestrictedAlgoWithoutCheck(symbol='BZQ', sim_params=sim_params) + algo = RestrictedAlgoWithoutCheck(symbol='BZQ', + sim_params=sim_params, + env=self.env) with self.assertRaises(TradingControlViolation) as ctx: algo.run(self.source) @@ -209,11 +216,15 @@ class SecurityListTestCase(TestCase): [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) - algo = RestrictedAlgoWithoutCheck(symbol='JFT', sim_params=sim_params) + algo = RestrictedAlgoWithoutCheck(symbol='JFT', + sim_params=sim_params, + env=self.env) with self.assertRaises(TradingControlViolation) as ctx: algo.run(self.source) @@ -222,17 +233,21 @@ class SecurityListTestCase(TestCase): def test_algo_with_rl_violation_after_knowledge_date(self): sim_params = factory.create_simulation_parameters( start=list( - LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=5) - trading.environment.write_data(equities_identifiers=['BZQ']) + LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=5, + env=self.env) trade_history = factory.create_trade_history( 'BZQ', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) - algo = RestrictedAlgoWithoutCheck(symbol='BZQ', sim_params=sim_params) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) + algo = RestrictedAlgoWithoutCheck(symbol='BZQ', + sim_params=sim_params, + env=self.env) with self.assertRaises(TradingControlViolation) as ctx: algo.run(self.source) @@ -255,12 +270,13 @@ class SecurityListTestCase(TestCase): [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env, ) - trading.environment.write_data(equities_identifiers=['BZQ']) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) algo = RestrictedAlgoWithoutCheck( - symbol='BZQ', sim_params=sim_params) + symbol='BZQ', sim_params=sim_params, env=self.env) with self.assertRaises(TradingControlViolation) as ctx: algo.run(self.source) @@ -273,18 +289,19 @@ class SecurityListTestCase(TestCase): add_security_data([], ['BZQ']) sim_params = factory.create_simulation_parameters( start=self.extra_knowledge_date, num_days=3) - trading.environment.write_data(equities_identifiers=['BZQ']) trade_history = factory.create_trade_history( 'BZQ', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env, ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) algo = RestrictedAlgoWithoutCheck( - symbol='BZQ', sim_params=sim_params + symbol='BZQ', sim_params=sim_params, env=self.env ) algo.run(self.source) @@ -293,17 +310,18 @@ class SecurityListTestCase(TestCase): add_security_data(['AAPL'], []) sim_params = factory.create_simulation_parameters( start=self.trading_day_before_first_kd, num_days=4) - trading.environment.write_data(equities_identifiers=['AAPL']) trade_history = factory.create_trade_history( 'AAPL', [10.0, 10.0, 11.0, 11.0], [100, 100, 100, 300], timedelta(days=1), - sim_params + sim_params, + env=self.env ) - self.source = SpecificEquityTrades(event_list=trade_history) + self.source = SpecificEquityTrades(event_list=trade_history, + env=self.env) algo = RestrictedAlgoWithoutCheck( - symbol='AAPL', sim_params=sim_params) + symbol='AAPL', sim_params=sim_params, env=self.env) with self.assertRaises(TradingControlViolation) as ctx: algo.run(self.source) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index eabbdf91..103de348 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -39,7 +39,7 @@ def gather_bad_dicts(state): class SerializationTestCase(TestCase): @classmethod def setUpClass(cls): - cls.env = TradingEnvironment.instance() + cls.env = TradingEnvironment() @parameterized.expand(object_serialization_cases()) def test_object_serialization(self, diff --git a/tests/test_sources.py b/tests/test_sources.py index f4dd8114..ff82e2fe 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -27,12 +27,12 @@ from zipline.sources import (DataFrameSource, RandomWalkSource) from zipline.utils import tradingcalendar as calendar_nyse from zipline.assets import AssetFinder -from zipline.finance import trading +from zipline.finance.trading import TradingEnvironment class TestDataFrameSource(TestCase): def test_df_source(self): - source, df = factory.create_test_df_source() + source, df = factory.create_test_df_source(env=None) assert isinstance(source.start, pd.lib.Timestamp) assert isinstance(source.end, pd.lib.Timestamp) @@ -43,7 +43,7 @@ class TestDataFrameSource(TestCase): assert expected_price[0] == sid0.price def test_df_sid_filtering(self): - _, df = factory.create_test_df_source() + _, df = factory.create_test_df_source(env=None) source = DataFrameSource(df) assert 1 not in [event.sid for event in source], \ "DataFrameSource should only stream selected sid 0, not sid 1." @@ -65,10 +65,10 @@ class TestDataFrameSource(TestCase): self.assertTrue(isinstance(event['arbitrary'], float)) def test_yahoo_bars_to_panel_source(self): - trading.environment = trading.TradingEnvironment() - finder = AssetFinder(trading.environment.engine) + env = TradingEnvironment() + finder = AssetFinder(env.engine) stocks = ['AAPL', 'GE'] - trading.environment.write_data(equities_identifiers=stocks) + env.write_data(equities_identifiers=stocks) start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc) end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) data = factory.load_bars_from_yahoo(stocks=stocks, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4d882d22..02b01688 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -103,6 +103,7 @@ def with_algo(f): initialize=initialize_with(self, tfm_name, days), handle_data=handle_data_wrapper(f), sim_params=sim_params, + env=self.env, ) algo.run(source) @@ -127,17 +128,19 @@ class TransformTestCase(TestCase): data_frequency='daily', emission_rate='daily', ) - cls.env = TradingEnvironment.instance() + cls.env = TradingEnvironment() cls.env.write_data(equities_identifiers=[1, 2, 3]) cls.sim_and_source = { 'minute': (minute_sim_ps, factory.create_minutely_trade_source( cls.sids, sim_params=minute_sim_ps, + env=cls.env, )), 'daily': (daily_sim_ps, factory.create_trade_source( cls.sids, trade_time_increment=timedelta(days=1), sim_params=daily_sim_ps, + env=cls.env, )), } diff --git a/tests/test_transforms_talib.py b/tests/test_transforms_talib.py index 203d0190..e8cc6d81 100644 --- a/tests/test_transforms_talib.py +++ b/tests/test_transforms_talib.py @@ -16,6 +16,7 @@ import pytz import numpy as np import pandas as pd +import talib from datetime import timedelta, datetime from unittest import TestCase, skip @@ -23,21 +24,26 @@ from unittest import TestCase, skip from zipline.utils.test_utils import setup_logger, teardown_logger import zipline.utils.factory as factory +from zipline.finance.trading import TradingEnvironment from zipline.test_algorithms import TALIBAlgorithm -import talib import zipline.transforms.ta as ta class TestTALIB(TestCase): + + @classmethod + def setUpClass(cls): + cls.env = TradingEnvironment() + def setUp(self): setup_logger(self) sim_params = factory.create_simulation_parameters( start=datetime(1990, 1, 1, tzinfo=pytz.utc), end=datetime(1990, 3, 30, tzinfo=pytz.utc)) self.source, self.panel = \ - factory.create_test_panel_ohlc_source(sim_params) + factory.create_test_panel_ohlc_source(sim_params, self.env) def tearDown(self): teardown_logger(self) @@ -60,7 +66,7 @@ class TestTALIB(TestCase): sim_params = factory.create_simulation_parameters( start=start, end=end) source, panel = \ - factory.create_test_panel_ohlc_source(sim_params) + factory.create_test_panel_ohlc_source(sim_params, self.env) algo = TALIBAlgorithm(talib=zipline_transform) algo.run(source) diff --git a/tests/utils/test_events.py b/tests/utils/test_events.py index b708445c..d89d0587 100644 --- a/tests/utils/test_events.py +++ b/tests/utils/test_events.py @@ -19,10 +19,12 @@ import random from six.moves import range, map from nose_parameterized import parameterized from unittest import TestCase +from functools import partial +from collections import namedtuple import numpy as np -from zipline.finance.trading import TradingEnvironment, with_environment +from zipline.finance.trading import TradingEnvironment import zipline.utils.events from zipline.utils.events import ( EventRule, @@ -161,7 +163,7 @@ class TestEventManager(TestCase): class CountingRule(Always): count = 0 - def should_trigger(self, dt): + def should_trigger(self, dt, env): CountingRule.count += 1 return True @@ -170,7 +172,10 @@ class TestEventManager(TestCase): Event(r(), lambda context, data: None) ) - self.em.handle_data(None, None, datetime.datetime.now()) + mock_algo_class = namedtuple('FakeAlgo', ['trading_environment']) + mock_algo = mock_algo_class(trading_environment="fake_env") + self.em.handle_data(mock_algo, None, datetime.datetime.now(), + mock_algo.trading_environment) self.assertEqual(CountingRule.count, 5) @@ -182,11 +187,10 @@ class TestEventRule(TestCase): def test_not_implemented(self): with self.assertRaises(NotImplementedError): - super(Always, Always()).should_trigger('a') + super(Always, Always()).should_trigger('a', env=None) -@with_environment() -def minutes_for_days(env=None): +def minutes_for_days(): """ 500 randomly selected days. This is used to make sure our test coverage is unbaised towards any rules. @@ -202,6 +206,7 @@ def minutes_for_days(env=None): Iterating over this yeilds a single day, iterating over the day yields the minutes for that day. """ + env = TradingEnvironment() random.seed('deterministic') return ((env.market_minutes_for_day(random.choice(env.trading_days)),) for _ in range(500)) @@ -210,7 +215,7 @@ def minutes_for_days(env=None): class RuleTestCase(TestCase): @classmethod def setUpClass(cls): - cls.env = TradingEnvironment.instance() + cls.env = TradingEnvironment() cls.class_ = None # Mark that this is the base class. def test_completeness(self): @@ -256,17 +261,18 @@ class TestStatelessRules(RuleTestCase): @parameterized.expand(minutes_for_days()) def test_Always(self, ms): - should_trigger = Always().should_trigger - self.assertTrue(all(map(should_trigger, ms))) + should_trigger = partial(Always().should_trigger, env=self.env) + self.assertTrue(all(map(partial(should_trigger, env=self.env), ms))) @parameterized.expand(minutes_for_days()) def test_Never(self, ms): - should_trigger = Never().should_trigger + should_trigger = partial(Never().should_trigger, env=self.env) self.assertFalse(any(map(should_trigger, ms))) @parameterized.expand(minutes_for_days()) def test_AfterOpen(self, ms): - should_trigger = AfterOpen(minutes=5, hours=1).should_trigger + should_trigger = partial(AfterOpen(minutes=5, hours=1).should_trigger, + env=self.env) for m in islice(ms, 64): # Check the first 64 minutes of data. # We use 64 because the offset is from market open @@ -280,20 +286,23 @@ class TestStatelessRules(RuleTestCase): @parameterized.expand(minutes_for_days()) def test_BeforeClose(self, ms): ms = list(ms) - should_trigger = BeforeClose(hours=1, minutes=5).should_trigger + should_trigger = partial( + BeforeClose(hours=1, minutes=5).should_trigger, env=self.env + ) for m in ms[0:-66]: self.assertFalse(should_trigger(m)) for m in ms[-66:]: self.assertTrue(should_trigger(m)) def test_NotHalfDay(self): - should_trigger = NotHalfDay().should_trigger + should_trigger = partial(NotHalfDay().should_trigger, env=self.env) self.assertTrue(should_trigger(FULL_DAY)) self.assertFalse(should_trigger(HALF_DAY)) @parameterized.expand(param_range(MAX_WEEK_RANGE)) def test_NthTradingDayOfWeek(self, n): - should_trigger = NthTradingDayOfWeek(n).should_trigger + should_trigger = partial(NthTradingDayOfWeek(n).should_trigger, + env=self.env) prev_day = self.sept_week[0].date() n_tdays = 0 for m in self.sept_week: @@ -308,7 +317,9 @@ class TestStatelessRules(RuleTestCase): @parameterized.expand(param_range(MAX_WEEK_RANGE)) def test_NDaysBeforeLastTradingDayOfWeek(self, n): - should_trigger = NDaysBeforeLastTradingDayOfWeek(n).should_trigger + should_trigger = partial( + NDaysBeforeLastTradingDayOfWeek(n).should_trigger, env=self.env + ) for m in self.sept_week: if should_trigger(m): n_tdays = 0 @@ -323,7 +334,8 @@ class TestStatelessRules(RuleTestCase): @parameterized.expand(param_range(MAX_MONTH_RANGE)) def test_NthTradingDayOfMonth(self, n): - should_trigger = NthTradingDayOfMonth(n).should_trigger + should_trigger = partial(NthTradingDayOfMonth(n).should_trigger, + env=self.env) for n_tdays, d in enumerate(self.sept_days): for m in self.env.market_minutes_for_day(d): if should_trigger(m): @@ -333,7 +345,9 @@ class TestStatelessRules(RuleTestCase): @parameterized.expand(param_range(MAX_MONTH_RANGE)) def test_NDaysBeforeLastTradingDayOfMonth(self, n): - should_trigger = NDaysBeforeLastTradingDayOfMonth(n).should_trigger + should_trigger = partial( + NDaysBeforeLastTradingDayOfMonth(n).should_trigger, env=self.env + ) for n_days_before, d in enumerate(reversed(self.sept_days)): for m in self.env.market_minutes_for_day(d): if should_trigger(m): @@ -347,10 +361,11 @@ class TestStatelessRules(RuleTestCase): rule2 = Never() composed = rule1 & rule2 + should_trigger = partial(composed.should_trigger, env=self.env) self.assertIsInstance(composed, ComposedRule) self.assertIs(composed.first, rule1) self.assertIs(composed.second, rule2) - self.assertFalse(any(map(composed.should_trigger, ms))) + self.assertFalse(any(map(should_trigger, ms))) class TestStatefulRules(RuleTestCase): @@ -369,14 +384,14 @@ class TestStatefulRules(RuleTestCase): """ count = 0 - def should_trigger(self, dt): - st = self.rule.should_trigger(dt) + def should_trigger(self, dt, env): + st = self.rule.should_trigger(dt, env) if st: self.count += 1 return st rule = RuleCounter(OncePerDay()) for m in ms: - rule.should_trigger(m) + rule.should_trigger(m, env=self.env) self.assertEqual(rule.count, 1) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index e850f164..dc95359f 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -191,6 +191,18 @@ class TradingAlgorithm(object): self.instant_fill = kwargs.pop('instant_fill', False) + # If an env has been provided, pop it + self.trading_environment = kwargs.pop('env', None) + + if self.trading_environment is None: + self.trading_environment = TradingEnvironment() + + # Update the TradingEnvironment with the provided asset metadata + self.trading_environment.write_data( + equities_data=kwargs.pop('asset_metadata', {}), + equities_identifiers=kwargs.pop('identifiers', []), + ) + # set the capital base self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE) self.sim_params = kwargs.pop('sim_params', None) @@ -198,17 +210,15 @@ class TradingAlgorithm(object): self.sim_params = create_simulation_parameters( capital_base=self.capital_base, start=kwargs.pop('start', None), - end=kwargs.pop('end', None) + end=kwargs.pop('end', None), + env=self.trading_environment, ) - self.perf_tracker = PerformanceTracker(self.sim_params) + else: + self.sim_params.update_internal_from_env(self.trading_environment) - # Update the TradingEnvironment with the provided asset metadata - self.trading_environment = kwargs.pop('env', - TradingEnvironment.instance()) - self.trading_environment.write_data( - equities_data=kwargs.pop('asset_metadata', {}), - equities_identifiers=kwargs.pop('identifiers', []), - ) + # Build a perf_tracker + self.perf_tracker = PerformanceTracker(sim_params=self.sim_params, + env=self.trading_environment) # Pull in the environment's new AssetFinder for quick reference self.asset_finder = self.trading_environment.asset_finder @@ -441,7 +451,9 @@ class TradingAlgorithm(object): if self.perf_tracker is None: # HACK: When running with the `run` method, we set perf_tracker to # None so that it will be overwritten here. - self.perf_tracker = PerformanceTracker(sim_params) + self.perf_tracker = PerformanceTracker( + sim_params=sim_params, env=self.trading_environment + ) self.portfolio_needs_update = True self.account_needs_update = True @@ -500,8 +512,21 @@ class TradingAlgorithm(object): # if DataFrame provided, map columns to sids and wrap # in DataFrameSource copy_frame = source.copy() + + # Build new Assets for identifiers that can't be resolved as + # sids/Assets + identifiers_to_build = [] + for identifier in source.columns: + if hasattr(identifier, '__int__'): + asset = self.asset_finder.retrieve_asset(sid=identifier, + default_none=True) + if asset is None: + identifiers_to_build.append(identifier) + else: + identifiers_to_build.append(identifier) + self.trading_environment.write_data( - equities_identifiers=source.columns) + equities_identifiers=identifiers_to_build) copy_frame.columns = \ self.asset_finder.map_identifier_index_to_sids( source.columns, source.index[0] @@ -512,8 +537,21 @@ class TradingAlgorithm(object): # If Panel provided, map items to sids and wrap # in DataPanelSource copy_panel = source.copy() + + # Build new Assets for identifiers that can't be resolved as + # sids/Assets + identifiers_to_build = [] + for identifier in source.items: + if hasattr(identifier, '__int__'): + asset = self.asset_finder.retrieve_asset(sid=identifier, + default_none=True) + if asset is None: + identifiers_to_build.append(identifier) + else: + identifiers_to_build.append(identifier) + self.trading_environment.write_data( - equities_identifiers=source.items) + equities_identifiers=identifiers_to_build) copy_panel.items = self.asset_finder.map_identifier_index_to_sids( source.items, source.major_axis[0] ) @@ -532,7 +570,9 @@ class TradingAlgorithm(object): self.sim_params.period_end = source.end # Changing period_start and period_close might require updating # of first_open and last_close. - self.sim_params._update_internal() + self.sim_params.update_internal_from_env( + env=self.trading_environment + ) # The sids field of the source is the reference for the universe at # the start of the run @@ -560,6 +600,7 @@ class TradingAlgorithm(object): self.current_universe(), self.sim_params.first_open, self.sim_params.data_frequency, + self.trading_environment, ) # loop through simulated_trading, each iteration returns a @@ -1137,7 +1178,8 @@ class TradingAlgorithm(object): def add_history(self, bar_count, frequency, field, ffill=True): data_frequency = self.sim_params.data_frequency history_spec = HistorySpec(bar_count, frequency, field, ffill, - data_frequency=data_frequency) + data_frequency=data_frequency, + env=self.trading_environment) self.history_specs[history_spec.key_str] = history_spec if self.initialized: if self.history_container: @@ -1150,6 +1192,7 @@ class TradingAlgorithm(object): self.current_universe(), self.sim_params.first_open, self.sim_params.data_frequency, + env=self.trading_environment, ) def get_history_spec(self, bar_count, frequency, field, ffill): @@ -1162,6 +1205,7 @@ class TradingAlgorithm(object): field, ffill, data_frequency=data_freq, + env=self.trading_environment, ) self.history_specs[spec_key] = spec if not self.history_container: @@ -1171,6 +1215,7 @@ class TradingAlgorithm(object): self.datetime, self.sim_params.data_frequency, bar_data=self._most_recent_data, + env=self.trading_environment, ) self.history_container.ensure_spec( spec, self.datetime, self._most_recent_data, diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index b06da9e8..d53d1604 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -46,6 +46,10 @@ log = Logger('assets.py') class AssetFinder(object): + # Token used as a substitute for pickling objects that contain a + # reference to an AssetFinder + PERSISTENT_TOKEN = "" + def __init__(self, engine, allow_sid_assignment=True, fuzzy_char=None): self.fuzzy_char = fuzzy_char @@ -160,7 +164,9 @@ class AssetFinder(object): else: asset = None - self._asset_cache[sid] = asset + # Cache the asset if it has been retrieved + if asset is not None: + self._asset_cache[sid] = asset if asset is not None: return asset diff --git a/zipline/errors.py b/zipline/errors.py index b8941934..8fd97ba5 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -402,3 +402,15 @@ class UnsupportedDatetimeFormat(ZiplineError): """ msg = ("The input '{input}' passed to '{method}' is not " "coercible to a pandas.Timestamp object.") + + +class PositionTrackerMissingAssetFinder(ZiplineError): + """ + Raised by a PositionTracker if it is asked to update an Asset but does not + have an AssetFinder + """ + msg = ( + "PositionTracker attempted to update its Asset information but does " + "not have an AssetFinder. This may be caused by a failure to properly " + "de-serialize a TradingAlgorithm." + ) diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 4dc02146..6721cc9a 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -75,7 +75,6 @@ import logbook import numpy as np -from zipline.finance.trading import TradingEnvironment from zipline.assets import Future try: @@ -92,8 +91,6 @@ from zipline.utils.serialization_utils import ( VERSION_LABEL ) -from .position_tracker import PositionTracker - log = logbook.Logger('Performance') TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE @@ -103,12 +100,15 @@ class PerformancePeriod(object): def __init__( self, starting_cash, + asset_finder, period_open=None, period_close=None, keep_transactions=True, keep_orders=False, serialize_positions=True): + self.asset_finder = asset_finder + self.period_open = period_open self.period_close = period_close @@ -225,8 +225,7 @@ class PerformancePeriod(object): try: multiplier = self._execution_cash_flow_multipliers[txn.sid] except KeyError: - asset = TradingEnvironment.instance().asset_finder.\ - retrieve_asset(txn.sid) + asset = self.asset_finder.retrieve_asset(txn.sid) # Futures experience no cash flow on transactions if isinstance(asset, Future): multiplier = 0 @@ -424,13 +423,13 @@ class PerformancePeriod(object): state_dict['orders_by_modified'] = \ dict(self.orders_by_modified) - STATE_VERSION = 2 + STATE_VERSION = 3 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 1 + OLDEST_SUPPORTED_STATE = 3 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: @@ -450,16 +449,4 @@ class PerformancePeriod(object): self._execution_cash_flow_multipliers = {} - # pop positions to use for v1 - positions = state.pop('positions', None) self.__dict__.update(state) - - if version == 1: - # version 1 had PositionTracker logic inside of Period - # we create the PositionTracker here. - # Note: that in V2 it is assumed that the position_tracker - # will be dependency injected and so is not reconstructed - assert positions is not None, "positions should exist in v1" - position_tracker = PositionTracker() - position_tracker.update_positions(positions) - self.position_tracker = position_tracker diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index 4d0851d6..b3a9bdd6 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -21,7 +21,7 @@ import zipline.protocol as zp from zipline.assets import ( Equity, Future ) -from zipline.finance.trading import with_environment +from zipline.errors import PositionTrackerMissingAssetFinder from . position import positiondict log = logbook.Logger('Performance') @@ -29,7 +29,9 @@ log = logbook.Logger('Performance') class PositionTracker(object): - def __init__(self): + def __init__(self, asset_finder): + self.asset_finder = asset_finder + # sid => position object self.positions = positiondict() # Arrays for quick calculations of positions value @@ -47,18 +49,18 @@ class PositionTracker(object): # for any Assets in this tracker's positions self._auto_close_position_sids = {} - @with_environment() - def _retrieve_asset(self, sid, env=None): - return env.asset_finder.retrieve_asset(sid) - def _update_asset(self, sid): try: self._position_value_multipliers[sid] self._position_exposure_multipliers[sid] self._position_payout_multipliers[sid] except KeyError: + # Check if there is an AssetFinder + if self.asset_finder is None: + raise PositionTrackerMissingAssetFinder() + # Collect the value multipliers from applicable sids - asset = self._retrieve_asset(sid) + asset = self.asset_finder.retrieve_asset(sid) if isinstance(asset, Equity): self._position_value_multipliers[sid] = 1 self._position_exposure_multipliers[sid] = 1 @@ -400,20 +402,31 @@ class PositionTracker(object): def __getstate__(self): state_dict = {} + state_dict['asset_finder'] = self.asset_finder state_dict['positions'] = dict(self.positions) state_dict['unpaid_dividends'] = self._unpaid_dividends - STATE_VERSION = 1 + # Asset-finder dependent dicts must be serialized + state_dict['position_value_multipliers'] = \ + serialize_ordered_dict(self._position_value_multipliers) + state_dict['position_exposure_multipliers'] = \ + serialize_ordered_dict(self._position_exposure_multipliers) + state_dict['position_payout_multipliers'] = \ + serialize_ordered_dict(self._position_payout_multipliers) + state_dict['auto_close_position_sids'] = self._auto_close_position_sids + + STATE_VERSION = 3 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 1 + OLDEST_SUPPORTED_STATE = 3 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: raise BaseException("PositionTracker saved state is too old.") + self.asset_finder = state['asset_finder'] self.positions = positiondict() # note that positions_store is temporary and gets regened from # .positions @@ -421,12 +434,35 @@ class PositionTracker(object): self._unpaid_dividends = state['unpaid_dividends'] + # AssetFinder-dependent dicts are de-serialized + self._position_value_multipliers = \ + deserialize_ordered_dict(state['position_value_multipliers']) + self._position_exposure_multipliers = \ + deserialize_ordered_dict(state['position_exposure_multipliers']) + self._position_payout_multipliers = \ + deserialize_ordered_dict(state['position_payout_multipliers']) + self._auto_close_position_sids = state['auto_close_position_sids'] + # Arrays for quick calculations of positions value self._position_amounts = OrderedDict() self._position_last_sale_prices = OrderedDict() - self._position_value_multipliers = OrderedDict() - self._position_exposure_multipliers = OrderedDict() - self._position_payout_multipliers = OrderedDict() - self._auto_close_position_sids = {} + # Update positions is called without a finder self.update_positions(state['positions']) + + +def serialize_ordered_dict(ordered_dict): + """ + Converts an OrderedDict in to a list of key/value pair tuples + """ + return [(key, value) for key, value in ordered_dict.items()] + + +def deserialize_ordered_dict(serialized_ordered_dict): + """ + Converts a list of key/value pair tuples in to an OrderedDict + """ + result = OrderedDict() + for key, value in serialized_ordered_dict: + result[key] = value + return result diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index 376ed5fb..3db80d2d 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -68,7 +68,6 @@ import pandas as pd from pandas.tseries.tools import normalize_date import zipline.finance.risk as risk -from zipline.finance.trading import TradingEnvironment from . period import PerformancePeriod from zipline.utils.serialization_utils import ( @@ -83,15 +82,17 @@ class PerformanceTracker(object): """ Tracks the performance of the algorithm. """ - def __init__(self, sim_params): + def __init__(self, sim_params, env): self.sim_params = sim_params - env = TradingEnvironment.instance() + self.env = env self.period_start = self.sim_params.period_start self.period_end = self.sim_params.period_end self.last_close = self.sim_params.last_close - first_open = self.sim_params.first_open.tz_convert(env.exchange_tz) + first_open = self.sim_params.first_open.tz_convert( + self.env.exchange_tz + ) self.day = pd.Timestamp(datetime(first_open.year, first_open.month, first_open.day), tz='UTC') self.market_open, self.market_close = env.get_open_and_close(self.day) @@ -108,7 +109,7 @@ class PerformanceTracker(object): self.dividend_frame = pd.DataFrame() self._dividend_count = 0 - self.position_tracker = PositionTracker() + self.position_tracker = PositionTracker(asset_finder=env.asset_finder) self.perf_periods = [] @@ -116,7 +117,7 @@ class PerformanceTracker(object): self.all_benchmark_returns = pd.Series( index=self.trading_days) self.cumulative_risk_metrics = \ - risk.RiskMetricsCumulative(self.sim_params) + risk.RiskMetricsCumulative(self.sim_params, self.env) elif self.emission_rate == 'minute': self.all_benchmark_returns = pd.Series(index=pd.date_range( @@ -124,22 +125,23 @@ class PerformanceTracker(object): freq='Min')) self.cumulative_risk_metrics = \ - risk.RiskMetricsCumulative(self.sim_params, + risk.RiskMetricsCumulative(self.sim_params, self.env, create_first_day_stats=True) self.minute_performance = PerformancePeriod( # initial cash is your capital base. - self.capital_base, + starting_cash=self.capital_base, # the cumulative period will be calculated over the # entire test. - self.period_start, - self.period_end, + period_open=self.period_start, + period_close=self.period_end, # don't save the transactions for the cumulative # period keep_transactions=False, keep_orders=False, # don't serialize positions for cumualtive period - serialize_positions=False + serialize_positions=False, + asset_finder=self.env.asset_finder, ) self.minute_performance.position_tracker = self.position_tracker self.perf_periods.append(self.minute_performance) @@ -148,16 +150,17 @@ class PerformanceTracker(object): # inception. self.cumulative_performance = PerformancePeriod( # initial cash is your capital base. - self.capital_base, + starting_cash=self.capital_base, # the cumulative period will be calculated over the entire test. - self.period_start, - self.period_end, + period_open=self.period_start, + period_close=self.period_end, # don't save the transactions for the cumulative # period keep_transactions=False, keep_orders=False, # don't serialize positions for cumualtive period serialize_positions=False, + asset_finder=self.env.asset_finder, ) self.cumulative_performance.position_tracker = self.position_tracker self.perf_periods.append(self.cumulative_performance) @@ -165,13 +168,14 @@ class PerformanceTracker(object): # this performance period will span just the current market day self.todays_performance = PerformancePeriod( # initial cash is your capital base. - self.capital_base, + starting_cash=self.capital_base, # the daily period will be calculated for the market day - self.market_open, - self.market_close, + period_open=self.market_open, + period_close=self.market_close, keep_transactions=True, keep_orders=True, serialize_positions=True, + asset_finder=self.env.asset_finder, ) self.todays_performance.position_tracker = self.position_tracker @@ -490,8 +494,7 @@ class PerformanceTracker(object): # Get the next trading day and, if it is past the bounds of this # simulation, return the daily perf packet - next_trading_day = TradingEnvironment.instance().\ - next_trading_day(completed_date) + next_trading_day = self.env.next_trading_day(completed_date) # Check if any assets need to be auto-closed before generating today's # perf period @@ -509,10 +512,9 @@ class PerformanceTracker(object): return daily_update # move the market day markers forward - env = TradingEnvironment.instance() self.market_open, self.market_close = \ - env.next_open_and_close(self.day) - self.day = env.next_trading_day(self.day) + self.env.next_open_and_close(self.day) + self.day = self.env.next_trading_day(self.day) # Roll over positions to current day. self.todays_performance.rollover() @@ -552,7 +554,8 @@ class PerformanceTracker(object): ars, self.sim_params, benchmark_returns=bms, - algorithm_leverages=acl) + algorithm_leverages=acl, + env=self.env) risk_dict = self.risk_report.to_dict() return risk_dict @@ -569,14 +572,14 @@ class PerformanceTracker(object): # we already store perf periods as attributes del state_dict['perf_periods'] - STATE_VERSION = 3 + STATE_VERSION = 4 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 3 + OLDEST_SUPPORTED_STATE = 4 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: diff --git a/zipline/finance/risk/cumulative.py b/zipline/finance/risk/cumulative.py index 41b0ff5f..227b46c4 100644 --- a/zipline/finance/risk/cumulative.py +++ b/zipline/finance/risk/cumulative.py @@ -18,7 +18,6 @@ import logbook import math import numpy as np -from zipline.finance import trading import zipline.utils.math_utils as zp_math import pandas as pd @@ -91,10 +90,10 @@ class RiskMetricsCumulative(object): 'information', ) - def __init__(self, sim_params, + def __init__(self, sim_params, env, create_first_day_stats=False, account=None): - self.treasury_curves = trading.environment.treasury_curves + self.treasury_curves = env.treasury_curves self.start_date = sim_params.period_start.replace( hour=0, minute=0, second=0, microsecond=0 ) @@ -102,15 +101,12 @@ class RiskMetricsCumulative(object): hour=0, minute=0, second=0, microsecond=0 ) - self.trading_days = trading.environment.days_in_range( - self.start_date, - self.end_date) + self.trading_days = env.days_in_range(self.start_date, self.end_date) # Hold on to the trading day before the start, # used for index of the zero return value when forcing returns # on the first day. - self.day_before_start = self.start_date - \ - trading.environment.trading_days.freq + self.day_before_start = self.start_date - env.trading_days.freq last_day = normalize_date(sim_params.period_end) if last_day not in self.trading_days: @@ -120,6 +116,7 @@ class RiskMetricsCumulative(object): self.trading_days = self.trading_days.append(last_day) self.sim_params = sim_params + self.env = env self.create_first_day_stats = create_first_day_stats @@ -276,7 +273,8 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}" treasury_period_return = choose_treasury( self.treasury_curves, self.start_date, - treasury_end + treasury_end, + self.env, ) self.daily_treasury[treasury_end] = treasury_period_return self.treasury_period_return = self.daily_treasury[treasury_end] @@ -459,18 +457,17 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}" return beta def __getstate__(self): - state_dict = \ - {k: v for k, v in iteritems(self.__dict__) if - (not k.startswith('_') and not k == 'treasury_curves')} + state_dict = {k: v for k, v in iteritems(self.__dict__) + if not k.startswith('_')} - STATE_VERSION = 2 + STATE_VERSION = 3 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 2 + OLDEST_SUPPORTED_STATE = 3 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: @@ -478,7 +475,3 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}" saved state is too old.") self.__dict__.update(state) - - # This are big and we don't need to serialize them - # pop them back in now - self.treasury_curves = trading.environment.treasury_curves diff --git a/zipline/finance/risk/period.py b/zipline/finance/risk/period.py index 2ed81896..32cc2fe3 100644 --- a/zipline/finance/risk/period.py +++ b/zipline/finance/risk/period.py @@ -22,8 +22,6 @@ import numpy.linalg as la from six import iteritems -from zipline.finance import trading - import pandas as pd from . import risk @@ -47,11 +45,11 @@ choose_treasury = functools.partial(risk.choose_treasury, class RiskMetricsPeriod(object): - def __init__(self, start_date, end_date, returns, - benchmark_returns=None, - algorithm_leverages=None): + def __init__(self, start_date, end_date, returns, env, + benchmark_returns=None, algorithm_leverages=None): - treasury_curves = trading.environment.treasury_curves + self.env = env + treasury_curves = env.treasury_curves if treasury_curves.index[-1] >= start_date: mask = ((treasury_curves.index >= start_date) & (treasury_curves.index <= end_date)) @@ -66,12 +64,14 @@ class RiskMetricsPeriod(object): self.end_date = end_date if benchmark_returns is None: - br = trading.environment.benchmark_returns + br = env.benchmark_returns benchmark_returns = br[(br.index >= returns.index[0]) & (br.index <= returns.index[-1])] - self.algorithm_returns = self.mask_returns_to_period(returns) - self.benchmark_returns = self.mask_returns_to_period(benchmark_returns) + self.algorithm_returns = self.mask_returns_to_period(returns, + env) + self.benchmark_returns = self.mask_returns_to_period(benchmark_returns, + env) self.algorithm_leverages = algorithm_leverages self.calculate_metrics() @@ -114,7 +114,8 @@ class RiskMetricsPeriod(object): self.treasury_period_return = choose_treasury( self.treasury_curves, self.start_date, - self.end_date + self.end_date, + self.env, ) self.sharpe = self.calculate_sharpe() # The consumer currently expects a 0.0 value for sharpe in period, @@ -193,14 +194,14 @@ class RiskMetricsPeriod(object): return '\n'.join(statements) - def mask_returns_to_period(self, daily_returns): + def mask_returns_to_period(self, daily_returns, env): if isinstance(daily_returns, list): returns = pd.Series([x.returns for x in daily_returns], index=[x.date for x in daily_returns]) else: # otherwise we're receiving an index already returns = daily_returns - trade_days = trading.environment.trading_days + trade_days = env.trading_days trade_day_mask = returns.index.normalize().isin(trade_days) mask = ((returns.index >= self.start_date) & @@ -321,18 +322,17 @@ class RiskMetricsPeriod(object): return max(self.algorithm_leverages) def __getstate__(self): - state_dict = \ - {k: v for k, v in iteritems(self.__dict__) if - (not k.startswith('_') and not k == 'treasury_curves')} + state_dict = {k: v for k, v in iteritems(self.__dict__) + if not k.startswith('_')} - STATE_VERSION = 2 + STATE_VERSION = 3 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 2 + OLDEST_SUPPORTED_STATE = 3 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: @@ -340,5 +340,3 @@ class RiskMetricsPeriod(object): is too old.") self.__dict__.update(state) - - self.treasury_curves = trading.environment.treasury_curves diff --git a/zipline/finance/risk/report.py b/zipline/finance/risk/report.py index b769a3a2..651b3834 100644 --- a/zipline/finance/risk/report.py +++ b/zipline/finance/risk/report.py @@ -72,7 +72,7 @@ log = logbook.Logger('Risk Report') class RiskReport(object): - def __init__(self, algorithm_returns, sim_params, + def __init__(self, algorithm_returns, sim_params, env, benchmark_returns=None, algorithm_leverages=None): """ algorithm_returns needs to be a list of daily_return objects @@ -84,6 +84,7 @@ class RiskReport(object): self.algorithm_returns = algorithm_returns self.sim_params = sim_params + self.env = env self.benchmark_returns = benchmark_returns self.algorithm_leverages = algorithm_leverages @@ -144,6 +145,7 @@ class RiskReport(object): end_date=cur_end, returns=self.algorithm_returns, benchmark_returns=self.benchmark_returns, + env=self.env, algorithm_leverages=self.algorithm_leverages, ) @@ -160,14 +162,14 @@ class RiskReport(object): if '_dividend_count' in dir(self): state_dict['_dividend_count'] = self._dividend_count - STATE_VERSION = 1 + STATE_VERSION = 2 state_dict[VERSION_LABEL] = STATE_VERSION return state_dict def __setstate__(self, state): - OLDEST_SUPPORTED_STATE = 1 + OLDEST_SUPPORTED_STATE = 2 version = state.pop(VERSION_LABEL) if version < OLDEST_SUPPORTED_STATE: diff --git a/zipline/finance/risk/risk.py b/zipline/finance/risk/risk.py index f3ec615a..a3f99ac3 100644 --- a/zipline/finance/risk/risk.py +++ b/zipline/finance/risk/risk.py @@ -62,7 +62,6 @@ import logbook import math import numpy as np -from zipline.finance import trading import zipline.utils.math_utils as zp_math log = logbook.Logger('Risk') @@ -203,8 +202,8 @@ def get_treasury_rate(treasury_curves, treasury_duration, day): return rate -def search_day_distance(end_date, dt): - tdd = trading.environment.trading_day_distance(dt, end_date) +def search_day_distance(end_date, dt, env): + tdd = env.trading_day_distance(dt, end_date) if tdd is None: return None assert tdd >= 0 @@ -238,7 +237,7 @@ def select_treasury_duration(start_date, end_date): def choose_treasury(select_treasury, treasury_curves, start_date, end_date, - compound=True): + env, compound=True): """ Find the latest known interest rate for a given duration within a date range. @@ -270,7 +269,7 @@ def choose_treasury(select_treasury, treasury_curves, start_date, end_date, prev_day) if rate is not None: search_day = prev_day - search_dist = search_day_distance(end_date, prev_day) + search_dist = search_day_distance(end_date, prev_day, env) break if search_day: diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index f1bdf930..a2c71b51 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -16,7 +16,6 @@ import bisect import logbook import datetime -from functools import wraps import pandas as pd import numpy as np @@ -51,40 +50,17 @@ log = logbook.Logger('Trading') # for serialization and storage, and the timezone is used to # ensure proper rollover through daylight savings and so on. # -# This module maintains a global variable, environment, which is -# subsequently referenced directly by zipline financial -# components. To set the environment, you can set the property on -# the module directly: -# from zipline.finance import trading -# trading.environment = TradingEnvironment() -# -# or if you want to switch the environment for a limited context -# you can use a TradingEnvironment in a with clause: -# lse = TradingEnvironment(bm_index="^FTSE", exchange_tz="Europe/London") -# with lse: -# the code here will have lse as the global trading.environment -# algo.run(start, end) -# # User code will not normally need to use TradingEnvironment # directly. If you are extending zipline's core financial -# compponents and need to use the environment, you must import the module -# NOT the variable. If you import the module, you will get a -# reference to the environment at import time, which will prevent -# your code from responding to user code that changes the global -# state. - -environment = None - +# components and need to use the environment, you must import the module and +# build a new TradingEnvironment object, then pass that TradingEnvironment as +# the 'env' arg to your TradingAlgorithm. class TradingEnvironment(object): - @classmethod - def instance(cls): - global environment - if not environment: - environment = TradingEnvironment() - - return environment + # Token used as a substitute for pickling objects that contain a + # reference to a TradingEnvironment + PERSISTENT_TOKEN = "" def __init__( self, @@ -140,21 +116,6 @@ class TradingEnvironment(object): AssetDBWriterFromDictionary().init_db(engine) self.asset_finder = AssetFinder(engine) - def __enter__(self, *args, **kwargs): - global environment - self.prev_environment = environment - environment = self - # return value here is associated with "as such_and_such" on the - # with clause. - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - global environment - environment = self.prev_environment - # signal that any exceptions need to be propagated up the - # stack. - return False - def write_data(self, engine=None, equities_data={}, @@ -486,7 +447,8 @@ class SimulationParameters(object): def __init__(self, period_start, period_end, capital_base=10e3, emission_rate='daily', - data_frequency='daily'): + data_frequency='daily', + env=None): self.period_start = period_start self.period_end = period_end @@ -498,55 +460,53 @@ class SimulationParameters(object): # copied to algorithm's environment for runtime access self.arena = 'backtest' - self._update_internal() + if env is not None: + self.update_internal_from_env(env=env) - def _update_internal(self): - # This is the global environment for trading simulation. - environment = TradingEnvironment.instance() + def update_internal_from_env(self, env): assert self.period_start <= self.period_end, \ "Period start falls after period end." - assert self.period_start <= environment.last_trading_day, \ + assert self.period_start <= env.last_trading_day, \ "Period start falls after the last known trading day." - assert self.period_end >= environment.first_trading_day, \ + assert self.period_end >= env.first_trading_day, \ "Period end falls before the first known trading day." - self.first_open = self.calculate_first_open() - self.last_close = self.calculate_last_close() - start_index = \ - environment.get_index(self.first_open) - end_index = environment.get_index(self.last_close) + self.first_open = self._calculate_first_open(env) + self.last_close = self._calculate_last_close(env) + + start_index = env.get_index(self.first_open) + end_index = env.get_index(self.last_close) # take an inclusive slice of the environment's # trading_days. - self.trading_days = \ - environment.trading_days[start_index:end_index + 1] + self.trading_days = env.trading_days[start_index:end_index + 1] - def calculate_first_open(self): + def _calculate_first_open(self, env): """ Finds the first trading day on or after self.period_start. """ first_open = self.period_start one_day = datetime.timedelta(days=1) - while not environment.is_trading_day(first_open): + while not env.is_trading_day(first_open): first_open = first_open + one_day - mkt_open, _ = environment.get_open_and_close(first_open) + mkt_open, _ = env.get_open_and_close(first_open) return mkt_open - def calculate_last_close(self): + def _calculate_last_close(self, env): """ Finds the last trading day on or before self.period_end """ last_close = self.period_end one_day = datetime.timedelta(days=1) - while not environment.is_trading_day(last_close): + while not env.is_trading_day(last_close): last_close = last_close - one_day - _, mkt_close = environment.get_open_and_close(last_close) + _, mkt_close = env.get_open_and_close(last_close) return mkt_close @property @@ -572,33 +532,3 @@ class SimulationParameters(object): emission_rate=self.emission_rate, first_open=self.first_open, last_close=self.last_close) - - -def with_environment(asname='env'): - """ - Decorator to automagically pass TradingEnvironment to the function - under the name asname. If the environment is passed explicitly as a keyword - then the explicitly passed value will be used instead. - - usage: - with_environment() - def f(env=None): - pass - - with_environment(asname='my_env') - def g(my_env=None): - pass - """ - def with_environment_decorator(f): - @wraps(f) - def wrapper(*args, **kwargs): - # inject env into the namespace for the function. - # This doesn't use setdefault so that grabbing the trading env - # is lazy. - if asname not in kwargs: - kwargs[asname] = TradingEnvironment.instance() - return f(*args, **kwargs) - - return wrapper - - return with_environment_decorator diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 15ff4411..e63ca853 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -20,7 +20,7 @@ from pandas.tslib import normalize_date from zipline.utils.api_support import ZiplineAPI -from zipline.finance import trading +from zipline.finance.trading import NoFurtherDataError from zipline.protocol import ( BarData, SIDData, @@ -50,6 +50,7 @@ class AlgorithmSimulator(object): # ============== self.algo = algo self.algo_start = normalize_date(self.sim_params.first_open) + self.env = algo.trading_environment # ============== # Snapshot Setup @@ -132,10 +133,9 @@ class AlgorithmSimulator(object): mkt_close < self.algo.perf_tracker.last_close try: mkt_open, mkt_close = \ - trading.environment \ - .next_open_and_close(mkt_close) + self.env.next_open_and_close(mkt_close) - except trading.NoFurtherDataError: + except NoFurtherDataError: # If at the end of backtest history, # skip advancing market close. pass @@ -144,7 +144,7 @@ class AlgorithmSimulator(object): self._call_before_trading_start(mkt_open) elif data_frequency == 'daily': - next_day = trading.environment.next_trading_day(date) + next_day = self.env.next_trading_day(date) if next_day is not None and \ next_day < self.algo.perf_tracker.last_close: diff --git a/zipline/history/history.py b/zipline/history/history.py index 817f49aa..8244aa01 100644 --- a/zipline/history/history.py +++ b/zipline/history/history.py @@ -19,8 +19,6 @@ import numpy as np import pandas as pd import re -from zipline.finance import trading -from zipline.finance.trading import with_environment from zipline.errors import IncompatibleHistoryFrequency @@ -45,7 +43,7 @@ class Frequency(object): MAX_MINUTES = {'m': 1, 'd': 390} MAX_DAYS = {'d': 1} - def __init__(self, freq_str, data_frequency): + def __init__(self, freq_str, data_frequency, env): if freq_str not in self.SUPPORTED_FREQUENCIES: raise ValueError( @@ -61,6 +59,7 @@ class Frequency(object): self.num, self.unit_str = parse_freq_str(freq_str) self.data_frequency = data_frequency + self.env = env def next_window_start(self, previous_window_close): """ @@ -68,35 +67,25 @@ class Frequency(object): finished on @previous_window_close. """ if self.unit_str == 'd': - return self.next_day_window_start(previous_window_close, + return self.next_day_window_start(previous_window_close, self.env, self.data_frequency) elif self.unit_str == 'm': - return self.next_minute_window_start(previous_window_close) + return self.env.next_market_minute(previous_window_close) @staticmethod - def next_day_window_start(previous_window_close, data_frequency='minute'): + def next_day_window_start(previous_window_close, env, + data_frequency='minute'): """ Get the next day window start after @previous_window_close. This is defined as the first market open strictly greater than @previous_window_close. """ - env = trading.environment if data_frequency == 'daily': next_open = env.next_trading_day(previous_window_close) else: next_open = env.next_market_minute(previous_window_close) return next_open - @staticmethod - def next_minute_window_start(previous_window_close): - """ - Get the next minute window start after @previous_window_close. This is - defined as the first market minute strictly greater than - @previous_window_close. - """ - env = trading.environment - return env.next_market_minute(previous_window_close) - def window_open(self, window_close): """ For a period ending on `window_end`, calculate the date of the first @@ -123,8 +112,7 @@ class Frequency(object): minute @window_close. This is calculated by searching backward until @num_days market_closes are encountered. """ - env = trading.environment - open_ = env.open_close_window( + open_ = self.env.open_close_window( window_close, 1, offset=-(num_days - 1) @@ -147,8 +135,9 @@ class Frequency(object): # Short circuit this case. return window_close - env = trading.environment - return env.market_minute_window(window_close, count=-num_minutes)[-1] + return self.env.market_minute_window( + window_close, count=-num_minutes + )[-1] def day_window_close(self, window_start, num_days): """ @@ -159,15 +148,13 @@ class Frequency(object): If the data_frequency is minute, this will be midnight utc of the last day of the window. """ - env = trading.environment - if self.data_frequency != 'daily': - return env.get_open_and_close( - env.add_trading_days(num_days - 1, window_start), + return self.env.get_open_and_close( + self.env.add_trading_days(num_days - 1, window_start), )[1] return pd.tslib.normalize_date( - env.add_trading_days(num_days - 1, window_start), + self.env.add_trading_days(num_days - 1, window_start), ) def minute_window_close(self, window_start, num_minutes): @@ -182,23 +169,23 @@ class Frequency(object): # Short circuit this case. return window_start - env = trading.environment - return env.market_minute_window(window_start, count=num_minutes)[-1] + return self.env.market_minute_window( + window_start, count=num_minutes + )[-1] - @with_environment() - def prev_bar(self, dt, env=None): + def prev_bar(self, dt): """ Returns the previous bar for dt. """ if self.unit_str == 'd': if self.data_frequency == 'minute': def func(dt): - return env.get_open_and_close( - env.previous_trading_day(dt))[1] + return self.env.get_open_and_close( + self.env.previous_trading_day(dt))[1] else: - func = env.previous_trading_day + func = self.env.previous_trading_day else: - func = env.previous_market_minute + func = self.env.previous_market_minute # Cache the function dispatch. self.prev_bar = func @@ -262,13 +249,13 @@ class HistorySpec(object): return "{0}:{1}:{2}:{3}".format( bar_count, freq_str, field, ffill) - def __init__(self, bar_count, frequency, field, ffill, + def __init__(self, bar_count, frequency, field, ffill, env, data_frequency='daily'): # Number of bars to look back. self.bar_count = bar_count if isinstance(frequency, str): - frequency = Frequency(frequency, data_frequency) + frequency = Frequency(frequency, data_frequency, env) if frequency.unit_str == 'm' and data_frequency == 'daily': raise IncompatibleHistoryFrequency( frequency=frequency.unit_str, @@ -299,12 +286,11 @@ class HistorySpec(object): return ''.join([self.__class__.__name__, "('", self.key_str, "')"]) -def days_index_at_dt(history_spec, algo_dt): +def days_index_at_dt(history_spec, algo_dt, env): """ Get the index of a frame to be used for a get_history call with daily frequency. """ - env = trading.environment # Get the previous (bar_count - 1) days' worth of market closes. day_delta = (history_spec.bar_count - 1) * history_spec.frequency.num market_closes = env.open_close_window( @@ -323,13 +309,12 @@ def days_index_at_dt(history_spec, algo_dt): return np.append(market_closes.values, algo_dt) -def minutes_index_at_dt(history_spec, algo_dt): +def minutes_index_at_dt(history_spec, algo_dt, env): """ Get the index of a frame to be used for a get_history_call with minutely frequency. """ # TODO: This is almost certainly going to be too slow for production. - env = trading.environment return env.market_minute_window( algo_dt, history_spec.bar_count, @@ -337,7 +322,7 @@ def minutes_index_at_dt(history_spec, algo_dt): )[::-1] -def index_at_dt(history_spec, algo_dt): +def index_at_dt(history_spec, algo_dt, env): """ Returns index of a frame returned by get_history() with the given history_spec and algo_dt. @@ -352,6 +337,6 @@ def index_at_dt(history_spec, algo_dt): """ frequency = history_spec.frequency if frequency.unit_str == 'd': - return days_index_at_dt(history_spec, algo_dt) + return days_index_at_dt(history_spec, algo_dt, env) elif frequency.unit_str == 'm': - return minutes_index_at_dt(history_spec, algo_dt) + return minutes_index_at_dt(history_spec, algo_dt, env) diff --git a/zipline/history/history_container.py b/zipline/history/history_container.py index 83c069b7..8afea5e3 100644 --- a/zipline/history/history_container.py +++ b/zipline/history/history_container.py @@ -23,7 +23,6 @@ from six import itervalues, iteritems, iterkeys from . history import HistorySpec -from zipline.finance.trading import with_environment from zipline.utils.data import RollingPanel, _ensure_index from zipline.utils.munge import ffill, bfill @@ -112,7 +111,6 @@ def freq_str_and_bar_count(history_spec): return (history_spec.frequency.freq_str, history_spec.bar_count) -@with_environment() def next_bar(spec, env): """ Returns a function that will return the next bar for a given datetime. @@ -208,6 +206,7 @@ class HistoryContainer(object): initial_sids, initial_dt, data_frequency, + env, bar_data=None): """ A container to hold a rolling window of historical data within a user's @@ -229,6 +228,9 @@ class HistoryContainer(object): An instance of a new HistoryContainer """ + # Store a reference to the env + self.env = env + # History specs to be served by this container. self.history_specs = history_specs self.largest_specs = compute_largest_specs( @@ -315,8 +317,7 @@ class HistoryContainer(object): """ return iterkeys(self.largest_specs) - @with_environment() - def _add_frequency(self, spec, dt, data, env=None): + def _add_frequency(self, spec, dt, data): """ Adds a new frequency to the container. This reshapes the buffer_panel if needed. @@ -350,9 +351,7 @@ class HistoryContainer(object): if spec.bar_count > 1: # This spec has more than one bar, construct a digest panel for it. - self.digest_panels[freq] = self._create_digest_panel( - dt, spec=spec, env=env, - ) + self.digest_panels[freq] = self._create_digest_panel(dt, spec=spec) else: self.cur_window_starts[freq] = dt self.cur_window_closes[freq] = freq.window_close( @@ -383,8 +382,7 @@ class HistoryContainer(object): ) return field - @with_environment() - def _add_length(self, spec, dt, env=None): + def _add_length(self, spec, dt): """ Increases the length of the digest panel for spec.frequency. If this does not have a panel, and one is needed; a digest panel will be @@ -399,21 +397,17 @@ class HistoryContainer(object): if panel is None: # The old length for this frequency was 1 bar, meaning no digest # panel was held. We must construct a new one here. - panel = self._create_digest_panel( - dt, spec=spec, env=env, - ) + panel = self._create_digest_panel(dt, spec=spec) else: - self._resize_panel( - panel, spec.bar_count - 1, dt, freq=spec.frequency, env=env, - ) + self._resize_panel(panel, spec.bar_count - 1, dt, + freq=spec.frequency) self.digest_panels[spec.frequency] = panel return LengthDelta(spec.frequency, delta) - @with_environment() - def _resize_panel(self, panel, size, dt, freq, env=None): + def _resize_panel(self, panel, size, dt, freq): """ Resizes a panel, fills the date_buf with the correct values. """ @@ -429,26 +423,24 @@ class HistoryContainer(object): panel.extend_back(missing_dts) - @with_environment() def _create_window_date_buf(self, window, unit_str, data_frequency, - dt, - env=None): + dt): """ Creates a window length date_buf looking backwards from dt. """ if unit_str == 'd': # Get the properly key'd datetime64 out of the pandas Timestamp if data_frequency != 'daily': - arr = env.open_close_window( + arr = self.env.open_close_window( dt, window, offset=-window, ).market_close.astype('datetime64[ns]').values else: - arr = env.open_close_window( + arr = self.env.open_close_window( dt, window, offset=-window, @@ -456,14 +448,13 @@ class HistoryContainer(object): return arr else: - return env.market_minute_window( - env.previous_market_minute(dt), + return self.env.market_minute_window( + self.env.previous_market_minute(dt), window, step=-1, )[::-1].values - @with_environment() - def _create_panel(self, dt, spec, env=None): + def _create_panel(self, dt, spec): """ Constructs a rolling panel with a properly aligned date_buf. """ @@ -476,7 +467,6 @@ class HistoryContainer(object): spec.frequency.unit_str, spec.frequency.data_frequency, dt, - env=env, ) panel = RollingPanel( @@ -488,13 +478,11 @@ class HistoryContainer(object): return panel - @with_environment() def _create_digest_panel(self, dt, spec, window_starts=None, - window_closes=None, - env=None): + window_closes=None): """ Creates a digest panel, setting the window_starts and window_closes. If window_starts or window_closes are None, then self.cur_window_starts @@ -510,7 +498,7 @@ class HistoryContainer(object): window_starts[freq] = freq.normalize(dt) window_closes[freq] = freq.window_close(window_starts[freq]) - return self._create_panel(dt, spec, env=env) + return self._create_panel(dt, spec) def ensure_spec(self, spec, dt, bar_data): """ @@ -565,11 +553,9 @@ class HistoryContainer(object): for panel in self.all_panels: panel.set_items(self.fields) - @with_environment() def create_digest_panels(self, initial_sids, - initial_dt, - env=None): + initial_dt): """ Initialize a RollingPanel for each unique panel frequency being stored by this container. Each RollingPanel pre-allocates enough storage @@ -601,7 +587,6 @@ class HistoryContainer(object): spec=largest_spec, window_starts=first_window_starts, window_closes=first_window_closes, - env=env, ) panels[freq] = rp @@ -618,7 +603,8 @@ class HistoryContainer(object): ) freq = '1m' if self.data_frequency == 'minute' else '1d' spec = HistorySpec( - max_bars_needed + 1, freq, None, None, self.data_frequency, + max_bars_needed + 1, freq, None, None, self.env, + self.data_frequency, ) rp = self._create_panel( diff --git a/zipline/protocol.py b/zipline/protocol.py index 83dada7b..26aa7832 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -23,7 +23,6 @@ import numpy as np from . utils.protocol_utils import Enum from . utils.math_utils import nanstd, nanmean, nansum -from zipline.finance.trading import with_environment from zipline.utils.algo_instance import get_algo_instance from zipline.utils.serialization_utils import ( VERSION_LABEL @@ -400,8 +399,7 @@ class SIDData(object): def daily_get_bars(days): return days - @with_environment() - def minute_get_bars(days, env=None): + def minute_get_bars(days): cls = self.__class__ now = get_algo_instance().datetime @@ -412,6 +410,7 @@ class SIDData(object): if days not in cls._minute_bar_cache: # Cache this calculation to happen once per bar, even if we # use another transform with the same number of days. + env = get_algo_instance().trading_environment prev = env.previous_trading_day(now) ds = env.days_in_range( env.add_trading_days(-days + 2, prev), diff --git a/zipline/sources/test_source.py b/zipline/sources/test_source.py index 278940c0..c4d02708 100644 --- a/zipline/sources/test_source.py +++ b/zipline/sources/test_source.py @@ -30,7 +30,6 @@ from zipline.protocol import ( DATASOURCE_TYPE ) from zipline.gens.utils import hash_args -from zipline.finance.trading import with_environment def create_trade(sid, price, amount, datetime, source_id="test_factory"): @@ -51,12 +50,11 @@ def create_trade(sid, price, amount, datetime, source_id="test_factory"): return trade -@with_environment() def date_gen(start, end, + env, delta=timedelta(minutes=1), - repeats=None, - env=None): + repeats=None): """ Utility to generate a stream of dates. """ @@ -111,11 +109,12 @@ class SpecificEquityTrades(object): delta : timedelta between internal events filter : filter to remove the sids """ - @with_environment() - def __init__(self, env=None, *args, **kwargs): + def __init__(self, env, *args, **kwargs): # We shouldn't get any positional arguments. assert len(args) == 0 + self.env = env + # Default to None for event_list and filter. self.event_list = kwargs.get('event_list') self.filter = kwargs.get('filter') @@ -206,12 +205,14 @@ class SpecificEquityTrades(object): end=self.end, delta=self.delta, repeats=len(self.sids), + env=self.env, ) else: date_generator = date_gen( start=self.start, end=self.end, - delta=self.delta + delta=self.delta, + env=self.env, ) source_id = self.get_hash() diff --git a/zipline/transforms/batch_transform.py b/zipline/transforms/batch_transform.py index ec6f3dc5..7d5f9d6b 100644 --- a/zipline/transforms/batch_transform.py +++ b/zipline/transforms/batch_transform.py @@ -34,8 +34,12 @@ from six import ( from zipline.utils.data import MutableIndexRollingPanel from zipline.protocol import Event +from zipline.finance.trading import TradingEnvironment -from zipline.finance import trading +# HACK the BatchTransform module stores a trading environment to be used by +# the transforms +# TODO remove this hack, if not this whole module +_batch_transform_env = TradingEnvironment() log = logbook.Logger('BatchTransform') func_map = {'open_price': 'first', @@ -67,8 +71,8 @@ def downsample_panel(minute_rp, daily_rp, mkt_close): cur_panel = minute_rp.get_current() sids = minute_rp.minor_axis day_frame = pd.DataFrame(columns=sids, index=cur_panel.items) - dt1 = trading.environment.normalize_date(mkt_close) - dt2 = trading.environment.next_trading_day(mkt_close) + dt1 = _batch_transform_env.normalize_date(mkt_close) + dt2 = _batch_transform_env.next_trading_day(mkt_close) by_close = functools.partial(get_date, mkt_close, dt1, dt2) for item in minute_rp.items: frame = cur_panel[item] @@ -333,11 +337,11 @@ class BatchTransform(object): # we may get events from non-trading sources which occurr on # non-trading days. The book-keeping for market close and # trading day counting should only consider trading days. - if trading.environment.is_trading_day(event.dt): - _, mkt_close = trading.environment.get_open_and_close(event.dt) + if _batch_transform_env.is_trading_day(event.dt): + _, mkt_close = _batch_transform_env.get_open_and_close(event.dt) if self.bars == 'daily': # Daily bars have their dt set to midnight. - mkt_close = trading.environment.normalize_date(mkt_close) + mkt_close = _batch_transform_env.normalize_date(mkt_close) if event.dt == mkt_close: if self.downsample: downsample_panel(self.rolling_panel, diff --git a/zipline/utils/events.py b/zipline/utils/events.py index 1514f587..55197698 100644 --- a/zipline/utils/events.py +++ b/zipline/utils/events.py @@ -20,8 +20,6 @@ import datetime import pandas as pd import pytz -from zipline.finance.trading import TradingEnvironment - __all__ = [ 'EventManager', @@ -191,7 +189,7 @@ class EventManager(object): def handle_data(self, context, data, dt): for event in self._events: - event.handle_data(context, data, dt) + event.handle_data(context, data, dt, context.trading_environment) class Event(namedtuple('Event', ['rule', 'callback'])): @@ -204,11 +202,11 @@ class Event(namedtuple('Event', ['rule', 'callback'])): callback = callback or (lambda *args, **kwargs: None) return super(cls, cls).__new__(cls, rule=rule, callback=callback) - def handle_data(self, context, data, dt): + def handle_data(self, context, data, dt, env): """ Calls the callable only when the rule is triggered. """ - if self.rule.should_trigger(dt): + if self.rule.should_trigger(dt, env): self.callback(context, data) @@ -216,12 +214,8 @@ class EventRule(six.with_metaclass(ABCMeta)): """ An event rule checks a datetime and sees if it should trigger. """ - @property - def env(self): - return TradingEnvironment.instance() - @abstractmethod - def should_trigger(self, dt): + def should_trigger(self, dt, env): """ Checks if the rule should trigger with it's current state. This method should be pure and NOT mutate any state on the object. @@ -267,7 +261,7 @@ class ComposedRule(StatelessRule): self.second = second self.composer = composer - def should_trigger(self, dt): + def should_trigger(self, dt, env): """ Composes the two rules with a lazy composer. """ @@ -275,15 +269,16 @@ class ComposedRule(StatelessRule): self.first.should_trigger, self.second.should_trigger, dt, + env, ) @staticmethod - def lazy_and(first_should_trigger, second_should_trigger, dt): + def lazy_and(first_should_trigger, second_should_trigger, dt, env): """ Lazily ands the two rules. This will NOT call the should_trigger of the second rule if the first one returns False. """ - return first_should_trigger(dt) and second_should_trigger(dt) + return first_should_trigger(dt, env) and second_should_trigger(dt, env) class Always(StatelessRule): @@ -291,7 +286,7 @@ class Always(StatelessRule): A rule that always triggers. """ @staticmethod - def always_trigger(dt): + def always_trigger(dt, env): """ A should_trigger implementation that will always trigger. """ @@ -304,7 +299,7 @@ class Never(StatelessRule): A rule that never triggers. """ @staticmethod - def never_trigger(dt): + def never_trigger(dt, env): """ A should_trigger implementation that will never trigger. """ @@ -328,15 +323,15 @@ class AfterOpen(StatelessRule): self._dt = None - def should_trigger(self, dt): - return self._get_open(dt) + self.offset <= dt + def should_trigger(self, dt, env): + return self._get_open(dt, env) + self.offset <= dt - def _get_open(self, dt): + def _get_open(self, dt, env): """ Cache the open for each day. """ if self._dt is None or (self._dt.date() != dt.date()): - self._dt = self.env.get_open_and_close(dt)[0] \ + self._dt = env.get_open_and_close(dt)[0] \ - datetime.timedelta(minutes=1) return self._dt @@ -358,15 +353,15 @@ class BeforeClose(StatelessRule): self._dt = None - def should_trigger(self, dt): - return self._get_close(dt) - self.offset <= dt + def should_trigger(self, dt, env): + return self._get_close(dt, env) - self.offset <= dt - def _get_close(self, dt): + def _get_close(self, dt, env): """ Cache the close for each day. """ if self._dt is None or (self._dt.date() != dt.date()): - self._dt = self.env.get_open_and_close(dt)[1] + self._dt = env.get_open_and_close(dt)[1] return self._dt @@ -375,8 +370,8 @@ class NotHalfDay(StatelessRule): """ A rule that only triggers when it is not a half day. """ - def should_trigger(self, dt): - return dt.date() not in self.env.early_closes + def should_trigger(self, dt, env): + return dt.date() not in env.early_closes class NthTradingDayOfWeek(StatelessRule): @@ -389,18 +384,18 @@ class NthTradingDayOfWeek(StatelessRule): raise _out_of_range_error(MAX_WEEK_RANGE) self.td_delta = n - def should_trigger(self, dt): - return _coerce_datetime(self.env.add_trading_days( + def should_trigger(self, dt, env): + return _coerce_datetime(env.add_trading_days( self.td_delta, - self.get_first_trading_day_of_week(dt), + self.get_first_trading_day_of_week(dt, env), )).date() == dt.date() - def get_first_trading_day_of_week(self, dt): + def get_first_trading_day_of_week(self, dt, env): prev = dt - dt = self.env.previous_trading_day(dt) + dt = env.previous_trading_day(dt) while dt.date().weekday() < prev.date().weekday(): prev = dt - dt = self.env.previous_trading_day(dt) + dt = env.previous_trading_day(dt) return prev.date() @@ -414,20 +409,20 @@ class NDaysBeforeLastTradingDayOfWeek(StatelessRule): self.td_delta = -n self.date = None - def should_trigger(self, dt): - return _coerce_datetime(self.env.add_trading_days( + def should_trigger(self, dt, env): + return _coerce_datetime(env.add_trading_days( self.td_delta, - self.get_last_trading_day_of_week(dt), + self.get_last_trading_day_of_week(dt, env), )).date() == dt.date() - def get_last_trading_day_of_week(self, dt): + def get_last_trading_day_of_week(self, dt, env): prev = dt - dt = self.env.next_trading_day(dt) + dt = env.next_trading_day(dt) # Traverse forward until we hit a week border, then jump back to the # previous trading day. while dt.date().weekday() > prev.date().weekday(): prev = dt - dt = self.env.next_trading_day(dt) + dt = env.next_trading_day(dt) return prev.date() @@ -443,30 +438,30 @@ class NthTradingDayOfMonth(StatelessRule): self.month = None self.day = None - def should_trigger(self, dt): - return self.get_nth_trading_day_of_month(dt) == dt.date() + def should_trigger(self, dt, env): + return self.get_nth_trading_day_of_month(dt, env) == dt.date() - def get_nth_trading_day_of_month(self, dt): + def get_nth_trading_day_of_month(self, dt, env): if self.month == dt.month: # We already computed the day for this month. return self.day if not self.td_delta: - self.day = self.get_first_trading_day_of_month(dt) + self.day = self.get_first_trading_day_of_month(dt, env) else: - self.day = self.env.add_trading_days( + self.day = env.add_trading_days( self.td_delta, - self.get_first_trading_day_of_month(dt), + self.get_first_trading_day_of_month(dt, env), ).date() return self.day - def get_first_trading_day_of_month(self, dt): + def get_first_trading_day_of_month(self, dt, env): self.month = dt.month dt = dt.replace(day=1) - self.first_day = (dt if self.env.is_trading_day(dt) - else self.env.next_trading_day(dt)).date() + self.first_day = (dt if env.is_trading_day(dt) + else env.next_trading_day(dt)).date() return self.first_day @@ -481,25 +476,25 @@ class NDaysBeforeLastTradingDayOfMonth(StatelessRule): self.month = None self.day = None - def should_trigger(self, dt): - return self.get_nth_to_last_trading_day_of_month(dt) == dt.date() + def should_trigger(self, dt, env): + return self.get_nth_to_last_trading_day_of_month(dt, env) == dt.date() - def get_nth_to_last_trading_day_of_month(self, dt): + def get_nth_to_last_trading_day_of_month(self, dt, env): if self.month == dt.month: # We already computed the last day for this month. return self.day if not self.td_delta: - self.day = self.get_last_trading_day_of_month(dt) + self.day = self.get_last_trading_day_of_month(dt, env) else: - self.day = self.env.add_trading_days( + self.day = env.add_trading_days( self.td_delta, - self.get_last_trading_day_of_month(dt), + self.get_last_trading_day_of_month(dt, env), ).date() return self.day - def get_last_trading_day_of_month(self, dt): + def get_last_trading_day_of_month(self, dt, env): self.month = dt.month if dt.month == 12: @@ -511,7 +506,7 @@ class NDaysBeforeLastTradingDayOfMonth(StatelessRule): year = dt.year month = dt.month + 1 - self.last_day = self.env.previous_trading_day( + self.last_day = env.previous_trading_day( dt.replace(year=year, month=month, day=1) ).date() return self.last_day @@ -543,14 +538,14 @@ class OncePerDay(StatefulRule): self.triggered = False super(OncePerDay, self).__init__(rule) - def should_trigger(self, dt): + def should_trigger(self, dt, env): dt_date = dt.date() if self.date is None or self.date != dt_date: # initialize or reset for new date self.triggered = False self.date = dt_date - if not self.triggered and self.rule.should_trigger(dt): + if not self.triggered and self.rule.should_trigger(dt, env): self.triggered = True return True diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 628f0f3f..fddd5a24 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -28,8 +28,7 @@ from zipline.protocol import Event, DATASOURCE_TYPE from zipline.sources import (SpecificEquityTrades, DataFrameSource, DataPanelSource) -from zipline.finance.trading import SimulationParameters -from zipline.finance import trading +from zipline.finance.trading import SimulationParameters, TradingEnvironment from zipline.sources.test_source import create_trade @@ -44,16 +43,18 @@ def create_simulation_parameters(year=2006, start=None, end=None, capital_base=float("1.0e5"), num_days=None, load=None, data_frequency='daily', - emission_rate='daily'): + emission_rate='daily', + env=None): """Construct a complete environment with reasonable defaults""" + if env is None: + env = TradingEnvironment(load=load) if start is None: start = datetime(year, 1, 1, tzinfo=pytz.utc) if end is None: if num_days: - trading.environment = trading.TradingEnvironment(load=load) - start_index = trading.environment.trading_days.searchsorted( + start_index = env.trading_days.searchsorted( start) - end = trading.environment.trading_days[start_index + num_days - 1] + end = env.trading_days[start_index + num_days - 1] else: end = datetime(year, 12, 31, tzinfo=pytz.utc) sim_params = SimulationParameters( @@ -62,14 +63,15 @@ def create_simulation_parameters(year=2006, start=None, end=None, capital_base=capital_base, data_frequency=data_frequency, emission_rate=emission_rate, + env=env, ) return sim_params def create_random_simulation_parameters(): - trading.environment = trading.TradingEnvironment() - treasury_curves = trading.environment.treasury_curves + env = TradingEnvironment() + treasury_curves = env.treasury_curves for n in range(100): @@ -92,30 +94,31 @@ check treasury and benchmark data in findb, and re-run the test.""" sim_params = SimulationParameters( period_start=start_dt, - period_end=end_dt + period_end=end_dt, + env=env, ) return sim_params, start_dt, end_dt -def get_next_trading_dt(current, interval): - next_dt = pd.Timestamp(current).tz_convert(trading.environment.exchange_tz) +def get_next_trading_dt(current, interval, env): + next_dt = pd.Timestamp(current).tz_convert(env.exchange_tz) while True: # Convert timestamp to naive before adding day, otherwise the when # stepping over EDT an hour is added. next_dt = pd.Timestamp(next_dt.replace(tzinfo=None)) next_dt = next_dt + interval - next_dt = pd.Timestamp(next_dt, tz=trading.environment.exchange_tz) + next_dt = pd.Timestamp(next_dt, tz=env.exchange_tz) next_dt_utc = next_dt.tz_convert('UTC') - if trading.environment.is_market_hours(next_dt_utc): + if env.is_market_hours(next_dt_utc): break - next_dt = next_dt_utc.tz_convert(trading.environment.exchange_tz) + next_dt = next_dt_utc.tz_convert(env.exchange_tz) return next_dt_utc -def create_trade_history(sid, prices, amounts, interval, sim_params, +def create_trade_history(sid, prices, amounts, interval, sim_params, env, source_id="test_factory"): trades = [] current = sim_params.first_open @@ -129,7 +132,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params, trade_dt = current trade = create_trade(sid, price, amount, trade_dt, source_id) trades.append(trade) - current = get_next_trading_dt(current, interval) + current = get_next_trading_dt(current, interval, env) assert len(trades) == len(prices) return trades @@ -200,12 +203,12 @@ def create_commission(sid, value, datetime): return txn -def create_txn_history(sid, priceList, amtList, interval, sim_params): +def create_txn_history(sid, priceList, amtList, interval, sim_params, env): txns = [] current = sim_params.first_open for price, amount in zip(priceList, amtList): - current = get_next_trading_dt(current, interval) + current = get_next_trading_dt(current, interval, env) txns.append(create_txn(sid, price, amount, current)) current = current + interval @@ -222,7 +225,7 @@ def create_returns_from_list(returns, sim_params): data=returns) -def create_daily_trade_source(sids, sim_params, concurrent=False): +def create_daily_trade_source(sids, sim_params, env, concurrent=False): """ creates trade_count trades for each sid in sids list. first trade will be on sim_params.period_start, and daily @@ -233,11 +236,12 @@ def create_daily_trade_source(sids, sim_params, concurrent=False): sids, timedelta(days=1), sim_params, - concurrent=concurrent + env=env, + concurrent=concurrent, ) -def create_minutely_trade_source(sids, sim_params, concurrent=False): +def create_minutely_trade_source(sids, sim_params, env, concurrent=False): """ creates trade_count trades for each sid in sids list. first trade will be on sim_params.period_start, and every minute @@ -248,16 +252,17 @@ def create_minutely_trade_source(sids, sim_params, concurrent=False): sids, timedelta(minutes=1), sim_params, - concurrent=concurrent + env=env, + concurrent=concurrent, ) -def create_trade_source(sids, trade_time_increment, sim_params, +def create_trade_source(sids, trade_time_increment, sim_params, env, concurrent=False): # If the sim_params define an end that is during market hours, that will be # used as the end of the data source - if trading.environment.is_market_hours(sim_params.period_end): + if env.is_market_hours(sim_params.period_end): end = sim_params.period_end # Otherwise, the last_close after the period_end is used as the end of the # data source @@ -271,14 +276,15 @@ def create_trade_source(sids, trade_time_increment, sim_params, 'end': end, 'delta': trade_time_increment, 'filter': sids, - 'concurrent': concurrent + 'concurrent': concurrent, + 'env': env, } source = SpecificEquityTrades(*args, **kwargs) return source -def create_test_df_source(sim_params=None, bars='daily'): +def create_test_df_source(sim_params=None, env=None, bars='daily'): if bars == 'daily': freq = pd.datetools.BDay() elif bars == 'minute': @@ -286,16 +292,16 @@ def create_test_df_source(sim_params=None, bars='daily'): else: raise ValueError('%s bars not understood.' % bars) - if sim_params: + if sim_params and bars == 'daily': index = sim_params.trading_days else: - if trading.environment is None: - trading.environment = trading.TradingEnvironment() + if env is None: + env = TradingEnvironment() start = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) - days = trading.environment.days_in_range(start, end) + days = env.days_in_range(start, end) if bars == 'daily': index = days @@ -303,7 +309,7 @@ def create_test_df_source(sim_params=None, bars='daily'): index = pd.DatetimeIndex([], freq=freq) for day in days: - day_index = trading.environment.market_minutes_for_day(day) + day_index = env.market_minutes_for_day(day) index = index.append(day_index) x = np.arange(1, len(index) + 1) @@ -313,17 +319,17 @@ def create_test_df_source(sim_params=None, bars='daily'): return DataFrameSource(df), df -def create_test_panel_source(sim_params=None, source_type=None): +def create_test_panel_source(sim_params=None, env=None, source_type=None): start = sim_params.first_open \ if sim_params else pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = sim_params.last_close \ if sim_params else pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) - if trading.environment is None: - trading.environment = trading.TradingEnvironment() + if env is None: + env = TradingEnvironment() - index = trading.environment.days_in_range(start, end) + index = env.days_in_range(start, end) price = np.arange(0, len(index)) volume = np.ones(len(index)) * 1000 @@ -343,17 +349,14 @@ def create_test_panel_source(sim_params=None, source_type=None): return DataPanelSource(panel), panel -def create_test_panel_ohlc_source(sim_params=None): +def create_test_panel_ohlc_source(sim_params, env): start = sim_params.first_open \ if sim_params else pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc) end = sim_params.last_close \ if sim_params else pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc) - if trading.environment is None: - trading.environment = trading.TradingEnvironment() - - index = trading.environment.days_in_range(start, end) + index = env.days_in_range(start, end) price = np.arange(0, len(index)) + 100 high = price * 1.05 low = price * 0.95 diff --git a/zipline/utils/security_list.py b/zipline/utils/security_list.py index c9ba8658..47590d17 100644 --- a/zipline/utils/security_list.py +++ b/zipline/utils/security_list.py @@ -5,7 +5,6 @@ import os.path import pandas as pd import pytz import zipline -from zipline.finance.trading import with_environment DATE_FORMAT = "%Y%m%d" @@ -15,7 +14,7 @@ SECURITY_LISTS_DIR = os.path.join(zipline_dir, 'resources', 'security_lists') class SecurityList(object): - def __init__(self, data, current_date_func): + def __init__(self, data, current_date_func, asset_finder): """ data: a nested dictionary: knowledge_date -> lookup_date -> @@ -29,6 +28,7 @@ class SecurityList(object): self.current_date = current_date_func self.count = 0 self._current_set = set() + self.asset_finder = asset_finder def make_knowledge_dates(self, data): knowledge_dates = sorted( @@ -68,10 +68,9 @@ class SecurityList(object): self._cache[kd] = self._current_set return self._current_set - @with_environment() - def update_current(self, effective_date, symbols, change_func, env=None): + def update_current(self, effective_date, symbols, change_func): for symbol in symbols: - asset = env.asset_finder.lookup_symbol( + asset = self.asset_finder.lookup_symbol( symbol, as_of_date=effective_date ) @@ -86,8 +85,9 @@ class SecurityListSet(object): # list implementations. security_list_type = SecurityList - def __init__(self, current_date_func): + def __init__(self, current_date_func, asset_finder): self.current_date_func = current_date_func + self.asset_finder = asset_finder self._leveraged_etf = None @property @@ -95,7 +95,8 @@ class SecurityListSet(object): if self._leveraged_etf is None: self._leveraged_etf = self.security_list_type( load_from_directory('leveraged_etf_list'), - self.current_date_func + self.current_date_func, + asset_finder=self.asset_finder ) return self._leveraged_etf diff --git a/zipline/utils/serialization_utils.py b/zipline/utils/serialization_utils.py index f3d3d155..b1a7a7dd 100644 --- a/zipline/utils/serialization_utils.py +++ b/zipline/utils/serialization_utils.py @@ -13,6 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six import BytesIO +import pickle +from functools import partial + +from zipline.assets import AssetFinder +from zipline.finance.trading import TradingEnvironment + # Label for the serialization version field in the state returned by # __getstate__. VERSION_LABEL = '_stateversion_' + + +def _persistent_id(obj): + if isinstance(obj, AssetFinder): + return AssetFinder.PERSISTENT_TOKEN + if isinstance(obj, TradingEnvironment): + return TradingEnvironment.PERSISTENT_TOKEN + return None + + +def _persistent_load(persid, env): + if persid == AssetFinder.PERSISTENT_TOKEN: + return env.asset_finder + if persid == TradingEnvironment.PERSISTENT_TOKEN: + return env + + +def dump_with_persistent_ids(obj, protocol=None): + """ + Performs a pickle dump on the given object, substituting all references to + a TradingEnvironment or AssetFinder with tokenized representations. + + All arguments are passed to pickle.Pickler and are described therein. + """ + file = BytesIO() + pickler = pickle.Pickler(file, protocol) + pickler.persistent_id = _persistent_id + pickler.dump(obj) + return file.getvalue() + + +def load_with_persistent_ids(str, env): + """ + Performs a pickle load on the given string, substituting the given + TradingEnvironment in to any tokenized representations of a + TradingEnvironment or AssetFinder. + + Parameters + __________ + str : String + The string representation of the object to be unpickled. + env : TradingEnvironment + The TradingEnvironment to be inserted to the unpickled object. + + Returns + _______ + obj + An unpickled object formed from the parameter 'str'. + """ + file = BytesIO(str) + unpickler = pickle.Unpickler(file) + unpickler.persistent_load = partial(_persistent_load, env=env) + return unpickler.load() diff --git a/zipline/utils/simfactory.py b/zipline/utils/simfactory.py index 876d6734..5b0dc74b 100644 --- a/zipline/utils/simfactory.py +++ b/zipline/utils/simfactory.py @@ -72,7 +72,8 @@ def create_test_zipline(**config): trade_source = factory.create_daily_trade_source( sid_list, test_algo.sim_params, - concurrent=concurrent_trades + test_algo.trading_environment, + concurrent=concurrent_trades, ) if trade_source: test_algo.set_sources([trade_source])