mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 08:46:15 +08:00
Merge pull request #574 from quantopian/do_not_modify_security_lists_directory
TST: Don't modify master security lists directory during tests
This commit is contained in:
@@ -7,7 +7,7 @@ from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.errors import TradingControlViolation
|
||||
from zipline.sources import SpecificEquityTrades
|
||||
from zipline.utils.test_utils import (
|
||||
setup_logger, add_security_data, remove_security_data_directory)
|
||||
setup_logger, security_list_copy, add_security_data)
|
||||
from zipline.utils import factory
|
||||
from zipline.utils.security_list import (
|
||||
SecurityListSet, load_from_directory)
|
||||
@@ -105,26 +105,22 @@ class SecurityListTestCase(TestCase):
|
||||
def test_security_add(self):
|
||||
def get_datetime():
|
||||
return datetime(2015, 1, 27, tzinfo=pytz.utc)
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL', 'GOOG'], [])
|
||||
rl = SecurityListSet(get_datetime)
|
||||
self.assertIn("AAPL", rl.leveraged_etf_list)
|
||||
self.assertIn("GOOG", rl.leveraged_etf_list)
|
||||
self.assertIn("BZQ", rl.leveraged_etf_list)
|
||||
self.assertIn("URTY", rl.leveraged_etf_list)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_security_add_delete(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
def get_datetime():
|
||||
return datetime(2015, 1, 27, tzinfo=pytz.utc)
|
||||
add_security_data([], ['BZQ', 'URTY'])
|
||||
rl = SecurityListSet(get_datetime)
|
||||
self.assertNotIn("BZQ", rl.leveraged_etf_list)
|
||||
self.assertNotIn("URTY", rl.leveraged_etf_list)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_without_rl_violation_via_check(self):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
@@ -228,7 +224,7 @@ class SecurityListTestCase(TestCase):
|
||||
start=list(
|
||||
LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=4)
|
||||
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL'], [])
|
||||
trade_history = factory.create_trade_history(
|
||||
'BZQ',
|
||||
@@ -244,11 +240,9 @@ class SecurityListTestCase(TestCase):
|
||||
algo.run(self.source)
|
||||
|
||||
self.check_algo_exception(algo, ctx, 0)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_without_rl_violation_after_delete(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
# add a delete statement removing bzq
|
||||
# write a new delete statement file to disk
|
||||
add_security_data([], ['BZQ'])
|
||||
@@ -266,11 +260,9 @@ class SecurityListTestCase(TestCase):
|
||||
algo = RestrictedAlgoWithoutCheck(
|
||||
sid='BZQ', sim_params=sim_params)
|
||||
algo.run(self.source)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def test_algo_with_rl_violation_after_add(self):
|
||||
try:
|
||||
with security_list_copy():
|
||||
add_security_data(['AAPL'], [])
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=self.trading_day_before_first_kd, num_days=4)
|
||||
@@ -288,8 +280,6 @@ class SecurityListTestCase(TestCase):
|
||||
algo.run(self.source)
|
||||
|
||||
self.check_algo_exception(algo, ctx, 2)
|
||||
finally:
|
||||
remove_security_data_directory()
|
||||
|
||||
def check_algo_exception(self, algo, ctx, expected_order_count):
|
||||
self.assertEqual(algo.order_count, expected_order_count)
|
||||
|
||||
+24
-11
@@ -1,13 +1,15 @@
|
||||
from contextlib import contextmanager
|
||||
from logbook import FileHandler
|
||||
from mock import patch
|
||||
from zipline.finance.blotter import ORDER_STATUS
|
||||
from zipline.utils.security_list import SECURITY_LISTS_DIR
|
||||
from zipline.utils import security_list
|
||||
|
||||
from six import itervalues
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
|
||||
def to_utc(time_str):
|
||||
@@ -115,15 +117,34 @@ def nullctx():
|
||||
Null context manager. Useful for conditionally adding a contextmanager in
|
||||
a single line, e.g.:
|
||||
|
||||
with SomeContextManager() if some_expr else nullcontext:
|
||||
with SomeContextManager() if some_expr else nullctx():
|
||||
do_stuff()
|
||||
"""
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def security_list_copy():
|
||||
old_dir = security_list.SECURITY_LISTS_DIR
|
||||
new_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
for subdir in os.listdir(old_dir):
|
||||
shutil.copytree(os.path.join(old_dir, subdir),
|
||||
os.path.join(new_dir, subdir))
|
||||
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
|
||||
patch.object(security_list, 'using_copy', True,
|
||||
create=True):
|
||||
yield
|
||||
finally:
|
||||
shutil.rmtree(new_dir, True)
|
||||
|
||||
|
||||
def add_security_data(adds, deletes):
|
||||
if not hasattr(security_list, 'using_copy'):
|
||||
raise Exception('add_security_data must be used within '
|
||||
'security_list_copy context')
|
||||
directory = os.path.join(
|
||||
SECURITY_LISTS_DIR,
|
||||
security_list.SECURITY_LISTS_DIR,
|
||||
"leveraged_etf_list/20150127/20150125"
|
||||
)
|
||||
if not os.path.exists(directory):
|
||||
@@ -138,11 +159,3 @@ def add_security_data(adds, deletes):
|
||||
for sym in adds:
|
||||
f.write(sym)
|
||||
f.write('\n')
|
||||
|
||||
|
||||
def remove_security_data_directory():
|
||||
directory = os.path.join(
|
||||
SECURITY_LISTS_DIR,
|
||||
"leveraged_etf_list/20150127/"
|
||||
)
|
||||
shutil.rmtree(directory)
|
||||
|
||||
Reference in New Issue
Block a user