From e465f64f912118258a78a5dbeb0e89cfcf109bed Mon Sep 17 00:00:00 2001 From: Andrew Liang Date: Wed, 14 Sep 2016 14:23:40 -0400 Subject: [PATCH] MAINT: Create SecurityListRestrictions that takes a SecurityList The SecurityList implements a non-exposed method `current_securities(dt)` which SecurityListRestrictions calls to determine if an asset is restricted. Deprecate the `__iter__` and `__contains__` methods of security lists in favor of `current_securities(dt)` --- tests/test_restrictions.py | 43 +++++++++++++++++++++++++++++++++ tests/test_security_list.py | 27 +++++++++++++++------ zipline/algorithm.py | 5 +++- zipline/api.py | 2 ++ zipline/finance/restrictions.py | 23 ++++++++++++++++++ zipline/utils/security_list.py | 25 +++++++++++++------ 6 files changed, 110 insertions(+), 15 deletions(-) diff --git a/tests/test_restrictions.py b/tests/test_restrictions.py index c1817b68..74327ce9 100644 --- a/tests/test_restrictions.py +++ b/tests/test_restrictions.py @@ -9,6 +9,7 @@ from zipline.finance.restrictions import ( Restriction, HistoricalRestrictions, StaticRestrictions, + SecurityListRestrictions, NoopRestrictions, ) @@ -212,6 +213,48 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): assert_vectorized_results([True, True, False], dt) + def test_security_list_restrictions(self): + """ + Test single- and multi-asset queries on restrictions defined by + zipline.utils.security_list.SecurityList + """ + + # A mock SecurityList object filled with fake data + class SecurityList(object): + def __init__(self, assets_by_dt): + self.assets_by_dt = assets_by_dt + + def current_securities(self, dt): + return self.assets_by_dt[dt] + + assets_by_dt = { + str_to_ts('2011-01-03'): [self.ASSET1], + str_to_ts('2011-01-04'): [self.ASSET2, self.ASSET3], + str_to_ts('2011-01-05'): [self.ASSET1, self.ASSET2, self.ASSET3], + } + + rl = SecurityListRestrictions(SecurityList(assets_by_dt)) + + assert_not_restricted = partial(self.assert_not_restricted, rl) + assert_is_restricted = partial(self.assert_is_restricted, rl) + assert_vectorized_results = partial(self.assert_vectorized_results, rl) + + assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03')) + assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03')) + assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03')) + assert_vectorized_results( + [True, False, False], str_to_ts('2011-01-03')) + + assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04')) + assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04')) + assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04')) + assert_vectorized_results([False, True, True], str_to_ts('2011-01-04')) + + assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05')) + assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05')) + assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05')) + assert_vectorized_results([True, True, True], str_to_ts('2011-01-05')) + def test_noop_restrictions(self): """ Test single- and multi-asset queries on no-op restrictions diff --git a/tests/test_security_list.py b/tests/test_security_list.py index f0609142..6dd975c9 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -36,7 +36,8 @@ class RestrictedAlgoWithCheck(TradingAlgorithm): def handle_data(self, data): if not self.order_count: if self.sid not in \ - self.rl.leveraged_etf_list: + self.rl.leveraged_etf_list.\ + current_securities(self.get_datetime()): self.order(self.sid, 100) self.order_count += 1 @@ -62,7 +63,8 @@ class IterateRLAlgo(TradingAlgorithm): self.found = False def handle_data(self, data): - for stock in self.rl.leveraged_etf_list: + for stock in self.rl.leveraged_etf_list.\ + current_securities(self.get_datetime()): if stock == self.sid: self.found = True @@ -151,7 +153,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): for symbol in ["BZQ", "URTY", "JFT"]] ] for sid in should_exist: - self.assertIn(sid, rl.leveraged_etf_list) + self.assertIn( + sid, rl.leveraged_etf_list.current_securities(get_datetime())) # assert that a sample of allowed stocks are not in restricted shouldnt_exist = [ @@ -162,7 +165,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): for symbol in ["AAPL", "GOOG"]] ] for sid in shouldnt_exist: - self.assertNotIn(sid, rl.leveraged_etf_list) + self.assertNotIn( + sid, rl.leveraged_etf_list.current_securities(get_datetime())) def test_security_add(self): def get_datetime(): @@ -178,15 +182,24 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): ) for symbol in ["AAPL", "GOOG", "BZQ", "URTY"]] ] for sid in should_exist: - self.assertIn(sid, rl.leveraged_etf_list) + self.assertIn( + sid, + rl.leveraged_etf_list.current_securities(get_datetime()) + ) def test_security_add_delete(self): with security_list_copy(): def get_datetime(): return pd.Timestamp("2015-01-27", tz='UTC') rl = SecurityListSet(get_datetime, self.env.asset_finder) - self.assertNotIn("BZQ", rl.leveraged_etf_list) - self.assertNotIn("URTY", rl.leveraged_etf_list) + self.assertNotIn( + "BZQ", + rl.leveraged_etf_list.current_securities(get_datetime()) + ) + self.assertNotIn( + "URTY", + rl.leveraged_etf_list.current_securities(get_datetime()) + ) def test_algo_without_rl_violation_via_check(self): algo = RestrictedAlgoWithCheck(symbol='BZQ', diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 9b6df86c..d5dc429f 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -84,7 +84,8 @@ from zipline.finance.slippage import ( from zipline.finance.cancel_policy import NeverCancel, CancelPolicy from zipline.finance.restrictions import ( NoopRestrictions, - StaticRestrictions + StaticRestrictions, + SecurityListRestrictions, ) from zipline.assets import Asset, Future from zipline.gens.tradesimulation import AlgorithmSimulator @@ -2175,6 +2176,8 @@ class TradingAlgorithm(object): if isinstance(restricted_list, (list, tuple, set)): restricted_list = StaticRestrictions(restricted_list) + elif isinstance(restricted_list, SecurityList): + restricted_list = SecurityListRestrictions(restricted_list) control = RestrictedListOrder(on_error, restricted_list) self.register_trading_control(control) diff --git a/zipline/api.py b/zipline/api.py index 426b5f26..f3e345ac 100644 --- a/zipline/api.py +++ b/zipline/api.py @@ -21,6 +21,7 @@ from .finance.restrictions import ( StaticRestrictions, HistoricalRestrictions, RESTRICTION_STATES, + SecurityListRestrictions, ) from .finance import commission, execution, slippage, cancel_policy from .finance.cancel_policy import ( @@ -46,6 +47,7 @@ __all__ = [ 'StaticRestrictions', 'HistoricalRestrictions', 'RESTRICTION_STATES', + 'SecurityListRestrictions', 'cancel_policy', 'commission', 'date_rules', diff --git a/zipline/finance/restrictions.py b/zipline/finance/restrictions.py index e977d2d8..52d3b150 100644 --- a/zipline/finance/restrictions.py +++ b/zipline/finance/restrictions.py @@ -127,3 +127,26 @@ class HistoricalRestrictions(Restrictions): break state = r.state return state == RESTRICTION_STATES.FROZEN + + +class SecurityListRestrictions(Restrictions): + """ + Restrictions based on a security list + + Parameters + ---------- + restrictions : zipline.utils.security_list.SecurityList + The restrictions defined by a SecurityList + """ + + def __init__(self, security_list_by_dt): + self.current_securities = security_list_by_dt.current_securities + + def is_restricted(self, assets, dt): + securities_in_list = self.current_securities(dt) + if isinstance(assets, Asset): + return assets in securities_in_list + return pd.Series( + index=pd.Index(assets), + data=vectorized_is_element(assets, securities_in_list) + ) diff --git a/zipline/utils/security_list.py b/zipline/utils/security_list.py index 50368b59..8532d3fb 100644 --- a/zipline/utils/security_list.py +++ b/zipline/utils/security_list.py @@ -1,3 +1,4 @@ +import warnings from datetime import datetime from os import listdir import os.path @@ -7,6 +8,7 @@ import pytz import zipline from zipline.errors import SymbolNotFound +from zipline.zipline_warnings import ZiplineDeprecationWarning DATE_FORMAT = "%Y%m%d" @@ -38,17 +40,26 @@ class SecurityList(object): return knowledge_dates def __iter__(self): - return iter(self.restricted_list) + warnings.warn( + 'Iterating over security_lists is deprecated. Use ' + '`for sid in .current_securities(dt)` instead.', + category=ZiplineDeprecationWarning, + stacklevel=2 + ) + return iter(self.current_securities(self.current_date())) def __contains__(self, item): - return item in self.restricted_list + warnings.warn( + 'Evaluating inclusion in security_lists is deprecated. Use ' + '`sid in .current_securities(dt)` instead.', + category=ZiplineDeprecationWarning, + stacklevel=2 + ) + return item in self.current_securities(self.current_date()) - @property - def restricted_list(self): - - cd = self.current_date() + def current_securities(self, dt): for kd in self._knowledge_dates: - if cd < kd: + if dt < kd: break if kd in self._cache: self._current_set = self._cache[kd]