mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 21:36:39 +08:00
TST: Don't modify master security lists directory during tests
Rather than drop files temporarily into the master security lists directory during unit tests, create temporary directories for the tests. This avoids issues when the tests are being run at the same time as other code that uses the real security lists data.
This commit is contained in:
@@ -7,7 +7,7 @@ from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.errors import TradingControlViolation
|
||||
from zipline.sources import SpecificEquityTrades
|
||||
from zipline.utils.test_utils import (
|
||||
setup_logger, add_security_data, remove_security_data_directory)
|
||||
setup_logger, security_list_copy, add_security_data)
|
||||
from zipline.utils import factory
|
||||
from zipline.utils.security_list import (
|
||||
SecurityListSet, load_from_directory)
|
||||
@@ -105,26 +105,22 @@ class SecurityListTestCase(TestCase):
|
||||
def test_security_add(self):
|
||||
def get_datetime():
|
||||
return datetime(2015, 1, 27, tzinfo=pytz.utc)
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL', 'GOOG'], [])
|
||||
rl = SecurityListSet(get_datetime)
|
||||
self.assertIn("AAPL", rl.leveraged_etf_list)
|
||||
self.assertIn("GOOG", rl.leveraged_etf_list)
|
||||
self.assertIn("BZQ", rl.leveraged_etf_list)
|
||||
self.assertIn("URTY", rl.leveraged_etf_list)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_security_add_delete(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
def get_datetime():
|
||||
return datetime(2015, 1, 27, tzinfo=pytz.utc)
|
||||
add_security_data([], ['BZQ', 'URTY'])
|
||||
rl = SecurityListSet(get_datetime)
|
||||
self.assertNotIn("BZQ", rl.leveraged_etf_list)
|
||||
self.assertNotIn("URTY", rl.leveraged_etf_list)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_without_rl_violation_via_check(self):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
@@ -228,7 +224,7 @@ class SecurityListTestCase(TestCase):
|
||||
start=list(
|
||||
LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=4)
|
||||
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL'], [])
|
||||
trade_history = factory.create_trade_history(
|
||||
'BZQ',
|
||||
@@ -244,11 +240,9 @@ class SecurityListTestCase(TestCase):
|
||||
algo.run(self.source)
|
||||
|
||||
self.check_algo_exception(algo, ctx, 0)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_without_rl_violation_after_delete(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
# add a delete statement removing bzq
|
||||
# write a new delete statement file to disk
|
||||
add_security_data([], ['BZQ'])
|
||||
@@ -266,11 +260,9 @@ class SecurityListTestCase(TestCase):
|
||||
algo = RestrictedAlgoWithoutCheck(
|
||||
sid='BZQ', sim_params=sim_params)
|
||||
algo.run(self.source)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_with_rl_violation_after_add(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL'], [])
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=self.trading_day_before_first_kd, num_days=4)
|
||||
@@ -288,8 +280,6 @@ class SecurityListTestCase(TestCase):
|
||||
algo.run(self.source)
|
||||
|
||||
self.check_algo_exception(algo, ctx, 2)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def check_algo_exception(self, algo, ctx, expected_order_count):
|
||||
self.assertEqual(algo.order_count, expected_order_count)
|
||||
|
||||
+24
-11
@@ -1,13 +1,15 @@
|
||||
from contextlib import contextmanager
|
||||
from logbook import FileHandler
|
||||
from mock import patch
|
||||
from zipline.finance.blotter import ORDER_STATUS
|
||||
from zipline.utils.security_list import SECURITY_LISTS_DIR
|
||||
from zipline.utils import security_list
|
||||
|
||||
from six import itervalues
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
|
||||
def to_utc(time_str):
|
||||
@@ -115,15 +117,34 @@ def nullctx():
|
||||
Null context manager. Useful for conditionally adding a contextmanager in
|
||||
a single line, e.g.:
|
||||
|
||||
with SomeContextManager() if some_expr else nullcontext:
|
||||
with SomeContextManager() if some_expr else nullctx():
|
||||
do_stuff()
|
||||
"""
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def security_list_copy():
|
||||
old_dir = security_list.SECURITY_LISTS_DIR
|
||||
new_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
for subdir in os.listdir(old_dir):
|
||||
shutil.copytree(os.path.join(old_dir, subdir),
|
||||
os.path.join(new_dir, subdir))
|
||||
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
|
||||
patch.object(security_list, 'using_copy', True,
|
||||
create=True):
|
||||
yield
|
||||
finally:
|
||||
shutil.rmtree(new_dir, True)
|
||||
|
||||
|
||||
def add_security_data(adds, deletes):
|
||||
if not hasattr(security_list, 'using_copy'):
|
||||
raise Exception('add_security_data must be used within '
|
||||
'security_list_copy context')
|
||||
directory = os.path.join(
|
||||
SECURITY_LISTS_DIR,
|
||||
security_list.SECURITY_LISTS_DIR,
|
||||
"leveraged_etf_list/20150127/20150125"
|
||||
)
|
||||
if not os.path.exists(directory):
|
||||
@@ -138,11 +159,3 @@ def add_security_data(adds, deletes):
|
||||
for sym in adds:
|
||||
f.write(sym)
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def remove_security_data_directory():
|
||||
directory = os.path.join(
|
||||
SECURITY_LISTS_DIR,
|
||||
"leveraged_etf_list/20150127/"
|
||||
)
|
||||
shutil.rmtree(directory)
|
||||
|
||||
Reference in New Issue
Block a user