From ba3ba053cb57bccf59706fe484aabbbbc0ff6285 Mon Sep 17 00:00:00 2001 From: Andrew Liang Date: Fri, 10 Jun 2016 17:40:45 -0400 Subject: [PATCH] MAINT: Refactor schedule function rules Refactor to eliminate unnecessary type coercion. Reduce some code duplication --- zipline/finance/performance/tracker.py | 11 +- zipline/utils/events.py | 205 ++++++++++--------------- 2 files changed, 88 insertions(+), 128 deletions(-) diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index 5f7dfa41..1d4c3630 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -66,7 +66,7 @@ import pandas as pd from pandas.tseries.tools import normalize_date from zipline.finance.performance.period import PerformancePeriod - +from zipline.errors import NoFurtherDataError import zipline.finance.risk as risk from . position_tracker import PositionTracker @@ -391,9 +391,12 @@ class PerformanceTracker(object): # Get the next trading day and, if it is past the bounds of this # simulation, return the daily perf packet - next_trading_day = self.trading_schedule.next_execution_day( - completed_date - ) + try: + next_trading_day = self.trading_schedule.next_execution_day( + completed_date + ) + except NoFurtherDataError: + next_trading_day = None # Take a snapshot of our current performance to return to the # browser. diff --git a/zipline/utils/events.py b/zipline/utils/events.py index eafaddce..1a7bb068 100644 --- a/zipline/utils/events.py +++ b/zipline/utils/events.py @@ -21,6 +21,7 @@ import pandas as pd import pytz from .context_tricks import nop_context +from zipline.errors import NoFurtherDataError from zipline.utils.calendars import normalize_date @@ -71,29 +72,6 @@ def ensure_utc(time, tz='UTC'): return time.replace(tzinfo=pytz.utc) -def _coerce_datetime(maybe_dt): - if isinstance(maybe_dt, datetime.datetime): - return maybe_dt - elif isinstance(maybe_dt, datetime.date): - return datetime.datetime( - year=maybe_dt.year, - month=maybe_dt.month, - day=maybe_dt.day, - tzinfo=pytz.utc, - ) - elif isinstance(maybe_dt, (tuple, list)) and len(maybe_dt) == 3: - year, month, day = maybe_dt - return datetime.datetime( - year=year, - month=month, - day=day, - tzinfo=pytz.utc, - ) - else: - raise TypeError('Cannot coerce %s into a datetime.datetime' - % type(maybe_dt).__name__) - - def _out_of_range_error(a, b=None, var='offset'): start = 0 if b is None: @@ -352,8 +330,8 @@ class AfterOpen(StatelessRule): # on the fact that our clock only ever ticks forward, since it's # cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means # that we will NOT correctly recognize a new date if we go backwards - # in time(which should never happen in a simulation, or in a live - # trading environment) + # in time(which should never happen in a simulation, or in live + # trading) if ( self._period_start is None or self._period_close <= dt @@ -396,8 +374,8 @@ class BeforeClose(StatelessRule): # on the fact that our clock only ever ticks forward, since it's # cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means # that we will NOT correctly recognize a new date if we go backwards - # in time(which should never happen in a simulation, or in a live - # trading environment) + # in time(which should never happen in a simulation, or in live + # trading) if ( self._period_start is None or self._period_close <= dt @@ -416,38 +394,31 @@ class NotHalfDay(StatelessRule): class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)): - def __init__(self, n=0): - if not 0 <= abs(n) < MAX_WEEK_RANGE: + def __init__(self, n, invert): + if not 0 <= n < MAX_WEEK_RANGE: raise _out_of_range_error(MAX_WEEK_RANGE) - - self.td_delta = n - self.next_date_start = None self.next_date_end = None self.next_midnight_timestamp = None + self.td_delta = -n if invert else n @abstractmethod def date_func(self, dt, cal): raise NotImplementedError def calculate_start_and_end(self, dt): - next_trading_day = _coerce_datetime( - self.cal.add_trading_days( + while True: + next_trading_day = self.cal.add_trading_days( self.td_delta, self.date_func(dt, self.cal), ) - ) - # If after applying the offset to the start/end day of the week, we get - # day in a different week, skip this week and go on to the next - while next_trading_day.isocalendar()[1] != dt.isocalendar()[1]: - dt += datetime.timedelta(days=7) - next_trading_day = _coerce_datetime( - self.cal.add_trading_days( - self.td_delta, - self.date_func(dt, self.cal), - ) - ) + # If after applying the offset to the start/end day of the week, we + # get day in a different week, skip this week and go on to the next + if next_trading_day.isocalendar()[1] == dt.isocalendar()[1]: + break + else: + dt += datetime.timedelta(days=7) next_open, next_close = self.cal.open_and_close(next_trading_day) self.next_date_start = next_open @@ -479,28 +450,25 @@ class NthTradingDayOfWeek(TradingDayOfWeekRule): A rule that triggers on the nth trading day of the week. This is zero-indexed, n=0 is the first trading day of the week. """ + def __init__(self, n): + super(NthTradingDayOfWeek, self).__init__(n, invert=False) + @staticmethod def get_first_trading_day_of_week(dt, cal): - prev = dt - dt = cal.previous_trading_day(dt) - # If we're on the first trading day of the TradingEnvironment, - # calling previous_trading_day on it will return None, which - # will blow up when we try and call .date() on it. The first - # trading day of the env is also the first trading day of the - # week(in the TradingEnvironment, at least), so just return - # that date. - if dt is None: - return prev - while dt.date().weekday() < prev.date().weekday(): + prev = None + # Traverse backward until we hit a week border, then jump back to the + # previous trading day. + try: + while not prev or dt.weekday() < prev.weekday(): + prev = dt + dt = cal.previous_trading_day(dt) + except NoFurtherDataError: prev = dt - dt = cal.previous_trading_day(dt) - if dt is None: - return prev if cal.is_open_on_day(prev): - return prev.date() + return prev else: - return cal.next_trading_day(prev).date() + return cal.next_trading_day(prev) date_func = get_first_trading_day_of_week @@ -510,93 +478,80 @@ class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule): A rule that triggers n days before the last trading day of the week. """ def __init__(self, n): - super(NDaysBeforeLastTradingDayOfWeek, self).__init__(-n) + super(NDaysBeforeLastTradingDayOfWeek, self).__init__(n, invert=True) @staticmethod def get_last_trading_day_of_week(dt, cal): - prev = dt - dt = cal.next_trading_day(dt) + prev = None # Traverse forward until we hit a week border, then jump back to the # previous trading day. - while dt.date().weekday() > prev.date().weekday(): + try: + while not prev or dt.weekday() > prev.weekday(): + prev = dt + dt = cal.next_trading_day(dt) + except NoFurtherDataError: prev = dt - dt = cal.next_trading_day(dt) if cal.is_open_on_day(prev): - return prev.date() + return prev else: - return cal.previous_trading_day(prev).date() + return cal.previous_trading_day(prev) date_func = get_last_trading_day_of_week -class NthTradingDayOfMonth(StatelessRule): +class TradingDayOfMonthRule(six.with_metaclass(ABCMeta, StatelessRule)): + def __init__(self, n, invert): + if not 0 <= n < MAX_MONTH_RANGE: + raise _out_of_range_error(MAX_MONTH_RANGE) + self.month = None + self.date = None + self.td_delta = -n if invert else n + + def should_trigger(self, dt): + return self.get_trigger_day_of_month(dt) == normalize_date(dt) + + @abstractmethod + def date_func(self, dt): + raise NotImplementedError + + def get_trigger_day_of_month(self, dt): + if self.month == dt.month: + # We already computed the day for this month. + return self.date + + self.date = self.date_func(dt) + if self.td_delta: + self.date = self.cal.add_trading_days(self.td_delta, self.date) + + return self.date + + +class NthTradingDayOfMonth(TradingDayOfMonthRule): """ A rule that triggers on the nth trading day of the month. This is zero-indexed, n=0 is the first trading day of the month. """ - def __init__(self, n=0): - if not 0 <= n < MAX_MONTH_RANGE: - raise _out_of_range_error(MAX_MONTH_RANGE) - self.td_delta = n - self.month = None - self.day = None - - def should_trigger(self, dt): - return self.get_nth_trading_day_of_month(dt) == dt.date() - - def get_nth_trading_day_of_month(self, dt): - if self.month == dt.month: - # We already computed the day for this month. - return self.day - - if not self.td_delta: - self.day = self.get_first_trading_day_of_month(dt) - else: - self.day = self.cal.add_trading_days( - self.td_delta, - self.get_first_trading_day_of_month(dt), - ).date() - - return self.day + def __init__(self, n): + super(NthTradingDayOfMonth, self).__init__(n, invert=False) def get_first_trading_day_of_month(self, dt): self.month = dt.month dt = dt.replace(day=1) - self.first_day = (dt if self.cal.is_open_on_day(dt) - else self.cal.next_trading_day(dt)).date() - return self.first_day + first_day = (dt if self.cal.is_open_on_day(dt) + else self.cal.next_trading_day(dt)) + return first_day + + date_func = get_first_trading_day_of_month -class NDaysBeforeLastTradingDayOfMonth(StatelessRule): +class NDaysBeforeLastTradingDayOfMonth(TradingDayOfMonthRule): """ A rule that triggers n days before the last trading day of the month. """ - def __init__(self, n=0): - if not 0 <= n < MAX_MONTH_RANGE: - raise _out_of_range_error(MAX_MONTH_RANGE) - self.td_delta = -n - self.month = None - self.day = None - - def should_trigger(self, dt): - return self.get_nth_to_last_trading_day_of_month(dt) == dt.date() - - def get_nth_to_last_trading_day_of_month(self, dt): - if self.month == dt.month: - # We already computed the last day for this month. - return self.day - - if not self.td_delta: - self.day = self.get_last_trading_day_of_month(dt) - else: - self.day = self.cal.add_trading_days( - self.td_delta, - self.get_last_trading_day_of_month(dt), - ).date() - - return self.day + def __init__(self, n): + super(NDaysBeforeLastTradingDayOfMonth, self).__init__(n, invert=True) def get_last_trading_day_of_month(self, dt): self.month = dt.month @@ -610,10 +565,12 @@ class NDaysBeforeLastTradingDayOfMonth(StatelessRule): year = dt.year month = dt.month + 1 - self.last_day = self.cal.previous_trading_day( + last_day = self.cal.previous_trading_day( dt.replace(year=year, month=month, day=1) - ).date() - return self.last_day + ) + return last_day + + date_func = get_last_trading_day_of_month # Stateful rules