diff --git a/tests/test_security_list.py b/tests/test_security_list.py index 7827ffce..8596444d 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -7,7 +7,7 @@ from zipline.algorithm import TradingAlgorithm from zipline.errors import TradingControlViolation from zipline.sources import SpecificEquityTrades from zipline.utils.test_utils import ( - setup_logger, add_security_data, remove_security_data_directory) + setup_logger, security_list_copy, add_security_data) from zipline.utils import factory from zipline.utils.security_list import ( SecurityListSet, load_from_directory) @@ -105,26 +105,22 @@ class SecurityListTestCase(TestCase): def test_security_add(self): def get_datetime(): return datetime(2015, 1, 27, tzinfo=pytz.utc) - try: + with security_list_copy(): add_security_data(['AAPL', 'GOOG'], []) rl = SecurityListSet(get_datetime) self.assertIn("AAPL", rl.leveraged_etf_list) self.assertIn("GOOG", rl.leveraged_etf_list) self.assertIn("BZQ", rl.leveraged_etf_list) self.assertIn("URTY", rl.leveraged_etf_list) - finally: - remove_security_data_directory() def test_security_add_delete(self): - try: + with security_list_copy(): def get_datetime(): return datetime(2015, 1, 27, tzinfo=pytz.utc) add_security_data([], ['BZQ', 'URTY']) rl = SecurityListSet(get_datetime) self.assertNotIn("BZQ", rl.leveraged_etf_list) self.assertNotIn("URTY", rl.leveraged_etf_list) - finally: - remove_security_data_directory() def test_algo_without_rl_violation_via_check(self): sim_params = factory.create_simulation_parameters( @@ -228,7 +224,7 @@ class SecurityListTestCase(TestCase): start=list( LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=4) - try: + with security_list_copy(): add_security_data(['AAPL'], []) trade_history = factory.create_trade_history( 'BZQ', @@ -244,11 +240,9 @@ class SecurityListTestCase(TestCase): algo.run(self.source) self.check_algo_exception(algo, ctx, 0) - finally: - remove_security_data_directory() def test_algo_without_rl_violation_after_delete(self): - try: + with security_list_copy(): # add a delete statement removing bzq # write a new delete statement file to disk add_security_data([], ['BZQ']) @@ -266,11 +260,9 @@ class SecurityListTestCase(TestCase): algo = RestrictedAlgoWithoutCheck( sid='BZQ', sim_params=sim_params) algo.run(self.source) - finally: - remove_security_data_directory() def test_algo_with_rl_violation_after_add(self): - try: + with security_list_copy(): add_security_data(['AAPL'], []) sim_params = factory.create_simulation_parameters( start=self.trading_day_before_first_kd, num_days=4) @@ -288,8 +280,6 @@ class SecurityListTestCase(TestCase): algo.run(self.source) self.check_algo_exception(algo, ctx, 2) - finally: - remove_security_data_directory() def check_algo_exception(self, algo, ctx, expected_order_count): self.assertEqual(algo.order_count, expected_order_count) diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 75faff62..fb217edf 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -1,13 +1,15 @@ from contextlib import contextmanager from logbook import FileHandler +from mock import patch from zipline.finance.blotter import ORDER_STATUS -from zipline.utils.security_list import SECURITY_LISTS_DIR +from zipline.utils import security_list from six import itervalues import os import pandas as pd import shutil +import tempfile def to_utc(time_str): @@ -115,15 +117,34 @@ def nullctx(): Null context manager. Useful for conditionally adding a contextmanager in a single line, e.g.: - with SomeContextManager() if some_expr else nullcontext: + with SomeContextManager() if some_expr else nullctx(): do_stuff() """ yield +@contextmanager +def security_list_copy(): + old_dir = security_list.SECURITY_LISTS_DIR + new_dir = tempfile.mkdtemp() + try: + for subdir in os.listdir(old_dir): + shutil.copytree(os.path.join(old_dir, subdir), + os.path.join(new_dir, subdir)) + with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \ + patch.object(security_list, 'using_copy', True, + create=True): + yield + finally: + shutil.rmtree(new_dir, True) + + def add_security_data(adds, deletes): + if not hasattr(security_list, 'using_copy'): + raise Exception('add_security_data must be used within ' + 'security_list_copy context') directory = os.path.join( - SECURITY_LISTS_DIR, + security_list.SECURITY_LISTS_DIR, "leveraged_etf_list/20150127/20150125" ) if not os.path.exists(directory): @@ -138,11 +159,3 @@ def add_security_data(adds, deletes): for sym in adds: f.write(sym) f.write('\n') - - -def remove_security_data_directory(): - directory = os.path.join( - SECURITY_LISTS_DIR, - "leveraged_etf_list/20150127/" - ) - shutil.rmtree(directory)