mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 23:36:09 +08:00
MAINT: Create SecurityListRestrictions that takes a SecurityList
The SecurityList implements a non-exposed method `current_securities(dt)` which SecurityListRestrictions calls to determine if an asset is restricted. Deprecate the `__iter__` and `__contains__` methods of security lists in favor of `current_securities(dt)`
This commit is contained in:
@@ -9,6 +9,7 @@ from zipline.finance.restrictions import (
|
||||
Restriction,
|
||||
HistoricalRestrictions,
|
||||
StaticRestrictions,
|
||||
SecurityListRestrictions,
|
||||
NoopRestrictions,
|
||||
)
|
||||
|
||||
@@ -212,6 +213,48 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
|
||||
assert_vectorized_results([True, True, False], dt)
|
||||
|
||||
def test_security_list_restrictions(self):
|
||||
"""
|
||||
Test single- and multi-asset queries on restrictions defined by
|
||||
zipline.utils.security_list.SecurityList
|
||||
"""
|
||||
|
||||
# A mock SecurityList object filled with fake data
|
||||
class SecurityList(object):
|
||||
def __init__(self, assets_by_dt):
|
||||
self.assets_by_dt = assets_by_dt
|
||||
|
||||
def current_securities(self, dt):
|
||||
return self.assets_by_dt[dt]
|
||||
|
||||
assets_by_dt = {
|
||||
str_to_ts('2011-01-03'): [self.ASSET1],
|
||||
str_to_ts('2011-01-04'): [self.ASSET2, self.ASSET3],
|
||||
str_to_ts('2011-01-05'): [self.ASSET1, self.ASSET2, self.ASSET3],
|
||||
}
|
||||
|
||||
rl = SecurityListRestrictions(SecurityList(assets_by_dt))
|
||||
|
||||
assert_not_restricted = partial(self.assert_not_restricted, rl)
|
||||
assert_is_restricted = partial(self.assert_is_restricted, rl)
|
||||
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
|
||||
|
||||
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03'))
|
||||
assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03'))
|
||||
assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03'))
|
||||
assert_vectorized_results(
|
||||
[True, False, False], str_to_ts('2011-01-03'))
|
||||
|
||||
assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04'))
|
||||
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04'))
|
||||
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04'))
|
||||
assert_vectorized_results([False, True, True], str_to_ts('2011-01-04'))
|
||||
|
||||
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05'))
|
||||
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05'))
|
||||
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05'))
|
||||
assert_vectorized_results([True, True, True], str_to_ts('2011-01-05'))
|
||||
|
||||
def test_noop_restrictions(self):
|
||||
"""
|
||||
Test single- and multi-asset queries on no-op restrictions
|
||||
|
||||
@@ -36,7 +36,8 @@ class RestrictedAlgoWithCheck(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
if not self.order_count:
|
||||
if self.sid not in \
|
||||
self.rl.leveraged_etf_list:
|
||||
self.rl.leveraged_etf_list.\
|
||||
current_securities(self.get_datetime()):
|
||||
self.order(self.sid, 100)
|
||||
self.order_count += 1
|
||||
|
||||
@@ -62,7 +63,8 @@ class IterateRLAlgo(TradingAlgorithm):
|
||||
self.found = False
|
||||
|
||||
def handle_data(self, data):
|
||||
for stock in self.rl.leveraged_etf_list:
|
||||
for stock in self.rl.leveraged_etf_list.\
|
||||
current_securities(self.get_datetime()):
|
||||
if stock == self.sid:
|
||||
self.found = True
|
||||
|
||||
@@ -151,7 +153,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
for symbol in ["BZQ", "URTY", "JFT"]]
|
||||
]
|
||||
for sid in should_exist:
|
||||
self.assertIn(sid, rl.leveraged_etf_list)
|
||||
self.assertIn(
|
||||
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
|
||||
|
||||
# assert that a sample of allowed stocks are not in restricted
|
||||
shouldnt_exist = [
|
||||
@@ -162,7 +165,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
for symbol in ["AAPL", "GOOG"]]
|
||||
]
|
||||
for sid in shouldnt_exist:
|
||||
self.assertNotIn(sid, rl.leveraged_etf_list)
|
||||
self.assertNotIn(
|
||||
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
|
||||
|
||||
def test_security_add(self):
|
||||
def get_datetime():
|
||||
@@ -178,15 +182,24 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
) for symbol in ["AAPL", "GOOG", "BZQ", "URTY"]]
|
||||
]
|
||||
for sid in should_exist:
|
||||
self.assertIn(sid, rl.leveraged_etf_list)
|
||||
self.assertIn(
|
||||
sid,
|
||||
rl.leveraged_etf_list.current_securities(get_datetime())
|
||||
)
|
||||
|
||||
def test_security_add_delete(self):
|
||||
with security_list_copy():
|
||||
def get_datetime():
|
||||
return pd.Timestamp("2015-01-27", tz='UTC')
|
||||
rl = SecurityListSet(get_datetime, self.env.asset_finder)
|
||||
self.assertNotIn("BZQ", rl.leveraged_etf_list)
|
||||
self.assertNotIn("URTY", rl.leveraged_etf_list)
|
||||
self.assertNotIn(
|
||||
"BZQ",
|
||||
rl.leveraged_etf_list.current_securities(get_datetime())
|
||||
)
|
||||
self.assertNotIn(
|
||||
"URTY",
|
||||
rl.leveraged_etf_list.current_securities(get_datetime())
|
||||
)
|
||||
|
||||
def test_algo_without_rl_violation_via_check(self):
|
||||
algo = RestrictedAlgoWithCheck(symbol='BZQ',
|
||||
|
||||
@@ -84,7 +84,8 @@ from zipline.finance.slippage import (
|
||||
from zipline.finance.cancel_policy import NeverCancel, CancelPolicy
|
||||
from zipline.finance.restrictions import (
|
||||
NoopRestrictions,
|
||||
StaticRestrictions
|
||||
StaticRestrictions,
|
||||
SecurityListRestrictions,
|
||||
)
|
||||
from zipline.assets import Asset, Future
|
||||
from zipline.gens.tradesimulation import AlgorithmSimulator
|
||||
@@ -2175,6 +2176,8 @@ class TradingAlgorithm(object):
|
||||
|
||||
if isinstance(restricted_list, (list, tuple, set)):
|
||||
restricted_list = StaticRestrictions(restricted_list)
|
||||
elif isinstance(restricted_list, SecurityList):
|
||||
restricted_list = SecurityListRestrictions(restricted_list)
|
||||
|
||||
control = RestrictedListOrder(on_error, restricted_list)
|
||||
self.register_trading_control(control)
|
||||
|
||||
@@ -21,6 +21,7 @@ from .finance.restrictions import (
|
||||
StaticRestrictions,
|
||||
HistoricalRestrictions,
|
||||
RESTRICTION_STATES,
|
||||
SecurityListRestrictions,
|
||||
)
|
||||
from .finance import commission, execution, slippage, cancel_policy
|
||||
from .finance.cancel_policy import (
|
||||
@@ -46,6 +47,7 @@ __all__ = [
|
||||
'StaticRestrictions',
|
||||
'HistoricalRestrictions',
|
||||
'RESTRICTION_STATES',
|
||||
'SecurityListRestrictions',
|
||||
'cancel_policy',
|
||||
'commission',
|
||||
'date_rules',
|
||||
|
||||
@@ -127,3 +127,26 @@ class HistoricalRestrictions(Restrictions):
|
||||
break
|
||||
state = r.state
|
||||
return state == RESTRICTION_STATES.FROZEN
|
||||
|
||||
|
||||
class SecurityListRestrictions(Restrictions):
|
||||
"""
|
||||
Restrictions based on a security list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
restrictions : zipline.utils.security_list.SecurityList
|
||||
The restrictions defined by a SecurityList
|
||||
"""
|
||||
|
||||
def __init__(self, security_list_by_dt):
|
||||
self.current_securities = security_list_by_dt.current_securities
|
||||
|
||||
def is_restricted(self, assets, dt):
|
||||
securities_in_list = self.current_securities(dt)
|
||||
if isinstance(assets, Asset):
|
||||
return assets in securities_in_list
|
||||
return pd.Series(
|
||||
index=pd.Index(assets),
|
||||
data=vectorized_is_element(assets, securities_in_list)
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from os import listdir
|
||||
import os.path
|
||||
@@ -7,6 +8,7 @@ import pytz
|
||||
import zipline
|
||||
|
||||
from zipline.errors import SymbolNotFound
|
||||
from zipline.zipline_warnings import ZiplineDeprecationWarning
|
||||
|
||||
|
||||
DATE_FORMAT = "%Y%m%d"
|
||||
@@ -38,17 +40,26 @@ class SecurityList(object):
|
||||
return knowledge_dates
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.restricted_list)
|
||||
warnings.warn(
|
||||
'Iterating over security_lists is deprecated. Use '
|
||||
'`for sid in <security_list>.current_securities(dt)` instead.',
|
||||
category=ZiplineDeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return iter(self.current_securities(self.current_date()))
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.restricted_list
|
||||
warnings.warn(
|
||||
'Evaluating inclusion in security_lists is deprecated. Use '
|
||||
'`sid in <security_list>.current_securities(dt)` instead.',
|
||||
category=ZiplineDeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return item in self.current_securities(self.current_date())
|
||||
|
||||
@property
|
||||
def restricted_list(self):
|
||||
|
||||
cd = self.current_date()
|
||||
def current_securities(self, dt):
|
||||
for kd in self._knowledge_dates:
|
||||
if cd < kd:
|
||||
if dt < kd:
|
||||
break
|
||||
if kd in self._cache:
|
||||
self._current_set = self._cache[kd]
|
||||
|
||||
Reference in New Issue
Block a user