MAINT: Create SecurityListRestrictions that takes a SecurityList

The SecurityList implements a non-exposed method
`current_securities(dt)` which SecurityListRestrictions calls to
determine if an asset is restricted. Deprecate the `__iter__` and
`__contains__` methods of security lists in favor of
`current_securities(dt)`
This commit is contained in:
Andrew Liang
2016-09-14 14:23:40 -04:00
parent b70084c6bf
commit e465f64f91
6 changed files with 110 additions and 15 deletions
+43
View File
@@ -9,6 +9,7 @@ from zipline.finance.restrictions import (
Restriction,
HistoricalRestrictions,
StaticRestrictions,
SecurityListRestrictions,
NoopRestrictions,
)
@@ -212,6 +213,48 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
assert_vectorized_results([True, True, False], dt)
def test_security_list_restrictions(self):
"""
Test single- and multi-asset queries on restrictions defined by
zipline.utils.security_list.SecurityList
"""
# A mock SecurityList object filled with fake data
class SecurityList(object):
def __init__(self, assets_by_dt):
self.assets_by_dt = assets_by_dt
def current_securities(self, dt):
return self.assets_by_dt[dt]
assets_by_dt = {
str_to_ts('2011-01-03'): [self.ASSET1],
str_to_ts('2011-01-04'): [self.ASSET2, self.ASSET3],
str_to_ts('2011-01-05'): [self.ASSET1, self.ASSET2, self.ASSET3],
}
rl = SecurityListRestrictions(SecurityList(assets_by_dt))
assert_not_restricted = partial(self.assert_not_restricted, rl)
assert_is_restricted = partial(self.assert_is_restricted, rl)
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03'))
assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03'))
assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03'))
assert_vectorized_results(
[True, False, False], str_to_ts('2011-01-03'))
assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04'))
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04'))
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04'))
assert_vectorized_results([False, True, True], str_to_ts('2011-01-04'))
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05'))
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05'))
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05'))
assert_vectorized_results([True, True, True], str_to_ts('2011-01-05'))
def test_noop_restrictions(self):
"""
Test single- and multi-asset queries on no-op restrictions
+20 -7
View File
@@ -36,7 +36,8 @@ class RestrictedAlgoWithCheck(TradingAlgorithm):
def handle_data(self, data):
if not self.order_count:
if self.sid not in \
self.rl.leveraged_etf_list:
self.rl.leveraged_etf_list.\
current_securities(self.get_datetime()):
self.order(self.sid, 100)
self.order_count += 1
@@ -62,7 +63,8 @@ class IterateRLAlgo(TradingAlgorithm):
self.found = False
def handle_data(self, data):
for stock in self.rl.leveraged_etf_list:
for stock in self.rl.leveraged_etf_list.\
current_securities(self.get_datetime()):
if stock == self.sid:
self.found = True
@@ -151,7 +153,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
for symbol in ["BZQ", "URTY", "JFT"]]
]
for sid in should_exist:
self.assertIn(sid, rl.leveraged_etf_list)
self.assertIn(
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
# assert that a sample of allowed stocks are not in restricted
shouldnt_exist = [
@@ -162,7 +165,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
for symbol in ["AAPL", "GOOG"]]
]
for sid in shouldnt_exist:
self.assertNotIn(sid, rl.leveraged_etf_list)
self.assertNotIn(
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
def test_security_add(self):
def get_datetime():
@@ -178,15 +182,24 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
) for symbol in ["AAPL", "GOOG", "BZQ", "URTY"]]
]
for sid in should_exist:
self.assertIn(sid, rl.leveraged_etf_list)
self.assertIn(
sid,
rl.leveraged_etf_list.current_securities(get_datetime())
)
def test_security_add_delete(self):
with security_list_copy():
def get_datetime():
return pd.Timestamp("2015-01-27", tz='UTC')
rl = SecurityListSet(get_datetime, self.env.asset_finder)
self.assertNotIn("BZQ", rl.leveraged_etf_list)
self.assertNotIn("URTY", rl.leveraged_etf_list)
self.assertNotIn(
"BZQ",
rl.leveraged_etf_list.current_securities(get_datetime())
)
self.assertNotIn(
"URTY",
rl.leveraged_etf_list.current_securities(get_datetime())
)
def test_algo_without_rl_violation_via_check(self):
algo = RestrictedAlgoWithCheck(symbol='BZQ',
+4 -1
View File
@@ -84,7 +84,8 @@ from zipline.finance.slippage import (
from zipline.finance.cancel_policy import NeverCancel, CancelPolicy
from zipline.finance.restrictions import (
NoopRestrictions,
StaticRestrictions
StaticRestrictions,
SecurityListRestrictions,
)
from zipline.assets import Asset, Future
from zipline.gens.tradesimulation import AlgorithmSimulator
@@ -2175,6 +2176,8 @@ class TradingAlgorithm(object):
if isinstance(restricted_list, (list, tuple, set)):
restricted_list = StaticRestrictions(restricted_list)
elif isinstance(restricted_list, SecurityList):
restricted_list = SecurityListRestrictions(restricted_list)
control = RestrictedListOrder(on_error, restricted_list)
self.register_trading_control(control)
+2
View File
@@ -21,6 +21,7 @@ from .finance.restrictions import (
StaticRestrictions,
HistoricalRestrictions,
RESTRICTION_STATES,
SecurityListRestrictions,
)
from .finance import commission, execution, slippage, cancel_policy
from .finance.cancel_policy import (
@@ -46,6 +47,7 @@ __all__ = [
'StaticRestrictions',
'HistoricalRestrictions',
'RESTRICTION_STATES',
'SecurityListRestrictions',
'cancel_policy',
'commission',
'date_rules',
+23
View File
@@ -127,3 +127,26 @@ class HistoricalRestrictions(Restrictions):
break
state = r.state
return state == RESTRICTION_STATES.FROZEN
class SecurityListRestrictions(Restrictions):
"""
Restrictions based on a security list
Parameters
----------
restrictions : zipline.utils.security_list.SecurityList
The restrictions defined by a SecurityList
"""
def __init__(self, security_list_by_dt):
self.current_securities = security_list_by_dt.current_securities
def is_restricted(self, assets, dt):
securities_in_list = self.current_securities(dt)
if isinstance(assets, Asset):
return assets in securities_in_list
return pd.Series(
index=pd.Index(assets),
data=vectorized_is_element(assets, securities_in_list)
)
+18 -7
View File
@@ -1,3 +1,4 @@
import warnings
from datetime import datetime
from os import listdir
import os.path
@@ -7,6 +8,7 @@ import pytz
import zipline
from zipline.errors import SymbolNotFound
from zipline.zipline_warnings import ZiplineDeprecationWarning
DATE_FORMAT = "%Y%m%d"
@@ -38,17 +40,26 @@ class SecurityList(object):
return knowledge_dates
def __iter__(self):
return iter(self.restricted_list)
warnings.warn(
'Iterating over security_lists is deprecated. Use '
'`for sid in <security_list>.current_securities(dt)` instead.',
category=ZiplineDeprecationWarning,
stacklevel=2
)
return iter(self.current_securities(self.current_date()))
def __contains__(self, item):
return item in self.restricted_list
warnings.warn(
'Evaluating inclusion in security_lists is deprecated. Use '
'`sid in <security_list>.current_securities(dt)` instead.',
category=ZiplineDeprecationWarning,
stacklevel=2
)
return item in self.current_securities(self.current_date())
@property
def restricted_list(self):
cd = self.current_date()
def current_securities(self, dt):
for kd in self._knowledge_dates:
if cd < kd:
if dt < kd:
break
if kd in self._cache:
self._current_set = self._cache[kd]