diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 847bf485..8a907fd2 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -54,16 +54,18 @@ from zipline.data.us_equity_pricing import ( BcolzDailyBarWriter, ) from zipline.errors import ( - OrderDuringInitialize, - RegisterTradingControlPostInit, - TradingControlViolation, AccountControlViolation, - SymbolNotFound, - UnsupportedDatetimeFormat, CannotOrderDelistedAsset, + OrderDuringInitialize, + OrderInBeforeTradingStart, + RegisterTradingControlPostInit, + ScheduleFunctionInvalidCalendar, SetCancelPolicyPostInit, + SymbolNotFound, + TradingControlViolation, UnsupportedCancelPolicy, - OrderInBeforeTradingStart) + UnsupportedDatetimeFormat, +) from zipline.api import ( order, order_value, @@ -437,22 +439,23 @@ def handle_data(context, data): # run a simulation on the CME cal, and schedule a function # using the NYSE cal algotext = """ -from zipline.api import schedule_function, get_datetime, time_rules, date_rules -from zipline.utils.calendars import get_calendar +from zipline.api import ( + schedule_function, get_datetime, time_rules, date_rules, calendars, +) def initialize(context): schedule_function( func=log_nyse_open, date_rule=date_rules.every_day(), time_rule=time_rules.market_open(), - calendar=get_calendar("NYSE") + calendar=calendars.US_EQUITIES, ) schedule_function( func=log_nyse_close, date_rule=date_rules.every_day(), time_rule=time_rules.market_close(), - calendar=get_calendar("NYSE") + calendar=calendars.US_EQUITIES, ) context.nyse_opens = [] @@ -488,6 +491,30 @@ def log_nyse_close(context, data): session_close = nyse.open_and_close_for_session(session_label)[1] self.assertEqual(session_close - timedelta(minutes=1), minute) + # Test that passing an invalid calendar parameter raises an error. + erroring_algotext = dedent( + """ + from zipline.api import schedule_function + from zipline.utils.calendars import get_calendar + + def initialize(context): + schedule_function(func=my_func, calendar=get_calendar('NYSE')) + + def my_func(context, data): + pass + """ + ) + + algo = TradingAlgorithm( + script=erroring_algotext, + sim_params=self.sim_params, + env=self.env, + trading_calendar=get_calendar('CME'), + ) + + with self.assertRaises(ScheduleFunctionInvalidCalendar): + algo.run(self.data_portal) + def test_schedule_function(self): us_eastern = pytz.timezone('US/Eastern') diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 6b438e9e..476dc706 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -47,23 +47,24 @@ from zipline.data.data_portal import DataPortal from zipline.data.us_equity_pricing import PanelBarReader from zipline.errors import ( AttachPipelineAfterInitialize, + CannotOrderDelistedAsset, HistoryInInitialize, NoSuchPipeline, OrderDuringInitialize, + OrderInBeforeTradingStart, PipelineOutputDuringInitialize, RegisterAccountControlPostInit, RegisterTradingControlPostInit, + ScheduleFunctionInvalidCalendar, SetBenchmarkOutsideInitialize, + SetCancelPolicyPostInit, SetCommissionPostInit, SetSlippagePostInit, + UnsupportedCancelPolicy, UnsupportedCommissionModel, UnsupportedDatetimeFormat, UnsupportedOrderParameters, UnsupportedSlippageModel, - CannotOrderDelistedAsset, - UnsupportedCancelPolicy, - SetCancelPolicyPostInit, - OrderInBeforeTradingStart ) from zipline.finance.trading import TradingEnvironment from zipline.finance.blotter import Blotter @@ -122,6 +123,7 @@ from zipline.utils.events import ( make_eventrule, date_rules, time_rules, + calendars, AfterOpen, BeforeClose ) @@ -1081,6 +1083,8 @@ class TradingAlgorithm(object): The rule for the times to execute this function. half_days : bool, optional Should this rule fire on half days? + calendar : Sentinel, optional + Calendar used to reconcile date and time rules. See Also -------- @@ -1106,7 +1110,19 @@ class TradingAlgorithm(object): # Check the type of the algorithm's schedule before pulling calendar # Note that the ExchangeTradingSchedule is currently the only # TradingSchedule class, so this is unlikely to be hit - cal = calendar or self.trading_calendar + if calendar is None: + cal = self.trading_calendar + elif calendar is calendars.US_EQUITIES: + cal = get_calendar('NYSE') + elif calendar is calendars.US_FUTURES: + cal = get_calendar('us_futures') + else: + raise ScheduleFunctionInvalidCalendar( + given_calendar=calendar, + allowed_calendars=( + '[calendars.US_EQUITIES, calendars.US_FUTURES]' + ), + ) self.add_event( make_eventrule(date_rule, time_rule, cal, half_days), diff --git a/zipline/api.py b/zipline/api.py index 68b78714..da4bfc74 100644 --- a/zipline/api.py +++ b/zipline/api.py @@ -33,6 +33,7 @@ from .finance.slippage import ( ) from .utils import math_utils, events from .utils.events import ( + calendars, date_rules, time_rules ) @@ -53,5 +54,6 @@ __all__ = [ 'execution', 'math_utils', 'slippage', - 'time_rules' + 'time_rules', + 'calendars', ] diff --git a/zipline/errors.py b/zipline/errors.py index 17626df3..dfd0a294 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -748,6 +748,16 @@ class ScheduleFunctionWithoutCalendar(ZiplineError): ) +class ScheduleFunctionInvalidCalendar(ZiplineError): + """ + Raised when schedule_function is called with an invalid calendar argument. + """ + msg = ( + "Invalid calendar '{given_calendar}' passed to schedule_function. " + "Allowed options are {allowed_calendars}." + ) + + class UnsupportedPipelineOutput(ZiplineError): """ Raised when a 1D term is added as a column to a pipeline. diff --git a/zipline/utils/events.py b/zipline/utils/events.py index 657d5ace..878f1b00 100644 --- a/zipline/utils/events.py +++ b/zipline/utils/events.py @@ -25,6 +25,7 @@ from toolz import curry from zipline.utils.input_validation import preprocess from zipline.utils.memoize import lazyval +from zipline.utils.sentinel import sentinel from .context_tricks import nop_context @@ -50,6 +51,7 @@ __all__ = [ # Factory API 'date_rules', 'time_rules', + 'calendars', 'make_eventrule', ] @@ -603,6 +605,11 @@ class time_rules(object): every_minute = Always +class calendars(object): + US_EQUITIES = sentinel('US_EQUITIES') + US_FUTURES = sentinel('US_FUTURES') + + def make_eventrule(date_rule, time_rule, cal, half_days=True): """ Constructs an event rule from the factory api.