TST: Don't modify master security lists directory during tests

Rather than drop files temporarily into the master security lists
directory during unit tests, create temporary directories for the
tests. This avoids issues when the tests are being run at the same
time as other code that uses the real security lists data.
This commit is contained in:
Jonathan Kamens
2015-04-29 21:41:32 -04:00
parent 00ea7b04d1
commit ca0f906b11
2 changed files with 30 additions and 27 deletions
+6 -16
View File
@@ -7,7 +7,7 @@ from zipline.algorithm import TradingAlgorithm
from zipline.errors import TradingControlViolation
from zipline.sources import SpecificEquityTrades
from zipline.utils.test_utils import (
setup_logger, add_security_data, remove_security_data_directory)
setup_logger, security_list_copy, add_security_data)
from zipline.utils import factory
from zipline.utils.security_list import (
SecurityListSet, load_from_directory)
@@ -105,26 +105,22 @@ class SecurityListTestCase(TestCase):
def test_security_add(self):
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
try:
with security_list_copy():
add_security_data(['AAPL', 'GOOG'], [])
rl = SecurityListSet(get_datetime)
self.assertIn("AAPL", rl.leveraged_etf_list)
self.assertIn("GOOG", rl.leveraged_etf_list)
self.assertIn("BZQ", rl.leveraged_etf_list)
self.assertIn("URTY", rl.leveraged_etf_list)
finally:
remove_security_data_directory()
def test_security_add_delete(self):
try:
with security_list_copy():
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
add_security_data([], ['BZQ', 'URTY'])
rl = SecurityListSet(get_datetime)
self.assertNotIn("BZQ", rl.leveraged_etf_list)
self.assertNotIn("URTY", rl.leveraged_etf_list)
finally:
remove_security_data_directory()
def test_algo_without_rl_violation_via_check(self):
sim_params = factory.create_simulation_parameters(
@@ -228,7 +224,7 @@ class SecurityListTestCase(TestCase):
start=list(
LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=4)
try:
with security_list_copy():
add_security_data(['AAPL'], [])
trade_history = factory.create_trade_history(
'BZQ',
@@ -244,11 +240,9 @@ class SecurityListTestCase(TestCase):
algo.run(self.source)
self.check_algo_exception(algo, ctx, 0)
finally:
remove_security_data_directory()
def test_algo_without_rl_violation_after_delete(self):
try:
with security_list_copy():
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
@@ -266,11 +260,9 @@ class SecurityListTestCase(TestCase):
algo = RestrictedAlgoWithoutCheck(
sid='BZQ', sim_params=sim_params)
algo.run(self.source)
finally:
remove_security_data_directory()
def test_algo_with_rl_violation_after_add(self):
try:
with security_list_copy():
add_security_data(['AAPL'], [])
sim_params = factory.create_simulation_parameters(
start=self.trading_day_before_first_kd, num_days=4)
@@ -288,8 +280,6 @@ class SecurityListTestCase(TestCase):
algo.run(self.source)
self.check_algo_exception(algo, ctx, 2)
finally:
remove_security_data_directory()
def check_algo_exception(self, algo, ctx, expected_order_count):
self.assertEqual(algo.order_count, expected_order_count)
+24 -11
View File
@@ -1,13 +1,15 @@
from contextlib import contextmanager
from logbook import FileHandler
from mock import patch
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils.security_list import SECURITY_LISTS_DIR
from zipline.utils import security_list
from six import itervalues
import os
import pandas as pd
import shutil
import tempfile
def to_utc(time_str):
@@ -115,15 +117,34 @@ def nullctx():
Null context manager. Useful for conditionally adding a contextmanager in
a single line, e.g.:
with SomeContextManager() if some_expr else nullcontext:
with SomeContextManager() if some_expr else nullctx():
do_stuff()
"""
yield
@contextmanager
def security_list_copy():
old_dir = security_list.SECURITY_LISTS_DIR
new_dir = tempfile.mkdtemp()
try:
for subdir in os.listdir(old_dir):
shutil.copytree(os.path.join(old_dir, subdir),
os.path.join(new_dir, subdir))
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
patch.object(security_list, 'using_copy', True,
create=True):
yield
finally:
shutil.rmtree(new_dir, True)
def add_security_data(adds, deletes):
if not hasattr(security_list, 'using_copy'):
raise Exception('add_security_data must be used within '
'security_list_copy context')
directory = os.path.join(
SECURITY_LISTS_DIR,
security_list.SECURITY_LISTS_DIR,
"leveraged_etf_list/20150127/20150125"
)
if not os.path.exists(directory):
@@ -138,11 +159,3 @@ def add_security_data(adds, deletes):
for sym in adds:
f.write(sym)
f.write('\n')
def remove_security_data_directory():
directory = os.path.join(
SECURITY_LISTS_DIR,
"leveraged_etf_list/20150127/"
)
shutil.rmtree(directory)