MAINT: Refactor schedule function rules

Refactor to eliminate unnecessary type coercion. Reduce some code
duplication
This commit is contained in:
Andrew Liang
2016-06-10 17:40:45 -04:00
parent 28b1da443e
commit ba3ba053cb
2 changed files with 88 additions and 128 deletions
+7 -4
View File
@@ -66,7 +66,7 @@ import pandas as pd
from pandas.tseries.tools import normalize_date
from zipline.finance.performance.period import PerformancePeriod
from zipline.errors import NoFurtherDataError
import zipline.finance.risk as risk
from . position_tracker import PositionTracker
@@ -391,9 +391,12 @@ class PerformanceTracker(object):
# Get the next trading day and, if it is past the bounds of this
# simulation, return the daily perf packet
next_trading_day = self.trading_schedule.next_execution_day(
completed_date
)
try:
next_trading_day = self.trading_schedule.next_execution_day(
completed_date
)
except NoFurtherDataError:
next_trading_day = None
# Take a snapshot of our current performance to return to the
# browser.
+81 -124
View File
@@ -21,6 +21,7 @@ import pandas as pd
import pytz
from .context_tricks import nop_context
from zipline.errors import NoFurtherDataError
from zipline.utils.calendars import normalize_date
@@ -71,29 +72,6 @@ def ensure_utc(time, tz='UTC'):
return time.replace(tzinfo=pytz.utc)
def _coerce_datetime(maybe_dt):
if isinstance(maybe_dt, datetime.datetime):
return maybe_dt
elif isinstance(maybe_dt, datetime.date):
return datetime.datetime(
year=maybe_dt.year,
month=maybe_dt.month,
day=maybe_dt.day,
tzinfo=pytz.utc,
)
elif isinstance(maybe_dt, (tuple, list)) and len(maybe_dt) == 3:
year, month, day = maybe_dt
return datetime.datetime(
year=year,
month=month,
day=day,
tzinfo=pytz.utc,
)
else:
raise TypeError('Cannot coerce %s into a datetime.datetime'
% type(maybe_dt).__name__)
def _out_of_range_error(a, b=None, var='offset'):
start = 0
if b is None:
@@ -352,8 +330,8 @@ class AfterOpen(StatelessRule):
# on the fact that our clock only ever ticks forward, since it's
# cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means
# that we will NOT correctly recognize a new date if we go backwards
# in time(which should never happen in a simulation, or in a live
# trading environment)
# in time(which should never happen in a simulation, or in live
# trading)
if (
self._period_start is None or
self._period_close <= dt
@@ -396,8 +374,8 @@ class BeforeClose(StatelessRule):
# on the fact that our clock only ever ticks forward, since it's
# cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means
# that we will NOT correctly recognize a new date if we go backwards
# in time(which should never happen in a simulation, or in a live
# trading environment)
# in time(which should never happen in a simulation, or in live
# trading)
if (
self._period_start is None or
self._period_close <= dt
@@ -416,38 +394,31 @@ class NotHalfDay(StatelessRule):
class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)):
def __init__(self, n=0):
if not 0 <= abs(n) < MAX_WEEK_RANGE:
def __init__(self, n, invert):
if not 0 <= n < MAX_WEEK_RANGE:
raise _out_of_range_error(MAX_WEEK_RANGE)
self.td_delta = n
self.next_date_start = None
self.next_date_end = None
self.next_midnight_timestamp = None
self.td_delta = -n if invert else n
@abstractmethod
def date_func(self, dt, cal):
raise NotImplementedError
def calculate_start_and_end(self, dt):
next_trading_day = _coerce_datetime(
self.cal.add_trading_days(
while True:
next_trading_day = self.cal.add_trading_days(
self.td_delta,
self.date_func(dt, self.cal),
)
)
# If after applying the offset to the start/end day of the week, we get
# day in a different week, skip this week and go on to the next
while next_trading_day.isocalendar()[1] != dt.isocalendar()[1]:
dt += datetime.timedelta(days=7)
next_trading_day = _coerce_datetime(
self.cal.add_trading_days(
self.td_delta,
self.date_func(dt, self.cal),
)
)
# If after applying the offset to the start/end day of the week, we
# get day in a different week, skip this week and go on to the next
if next_trading_day.isocalendar()[1] == dt.isocalendar()[1]:
break
else:
dt += datetime.timedelta(days=7)
next_open, next_close = self.cal.open_and_close(next_trading_day)
self.next_date_start = next_open
@@ -479,28 +450,25 @@ class NthTradingDayOfWeek(TradingDayOfWeekRule):
A rule that triggers on the nth trading day of the week.
This is zero-indexed, n=0 is the first trading day of the week.
"""
def __init__(self, n):
super(NthTradingDayOfWeek, self).__init__(n, invert=False)
@staticmethod
def get_first_trading_day_of_week(dt, cal):
prev = dt
dt = cal.previous_trading_day(dt)
# If we're on the first trading day of the TradingEnvironment,
# calling previous_trading_day on it will return None, which
# will blow up when we try and call .date() on it. The first
# trading day of the env is also the first trading day of the
# week(in the TradingEnvironment, at least), so just return
# that date.
if dt is None:
return prev
while dt.date().weekday() < prev.date().weekday():
prev = None
# Traverse backward until we hit a week border, then jump back to the
# previous trading day.
try:
while not prev or dt.weekday() < prev.weekday():
prev = dt
dt = cal.previous_trading_day(dt)
except NoFurtherDataError:
prev = dt
dt = cal.previous_trading_day(dt)
if dt is None:
return prev
if cal.is_open_on_day(prev):
return prev.date()
return prev
else:
return cal.next_trading_day(prev).date()
return cal.next_trading_day(prev)
date_func = get_first_trading_day_of_week
@@ -510,93 +478,80 @@ class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule):
A rule that triggers n days before the last trading day of the week.
"""
def __init__(self, n):
super(NDaysBeforeLastTradingDayOfWeek, self).__init__(-n)
super(NDaysBeforeLastTradingDayOfWeek, self).__init__(n, invert=True)
@staticmethod
def get_last_trading_day_of_week(dt, cal):
prev = dt
dt = cal.next_trading_day(dt)
prev = None
# Traverse forward until we hit a week border, then jump back to the
# previous trading day.
while dt.date().weekday() > prev.date().weekday():
try:
while not prev or dt.weekday() > prev.weekday():
prev = dt
dt = cal.next_trading_day(dt)
except NoFurtherDataError:
prev = dt
dt = cal.next_trading_day(dt)
if cal.is_open_on_day(prev):
return prev.date()
return prev
else:
return cal.previous_trading_day(prev).date()
return cal.previous_trading_day(prev)
date_func = get_last_trading_day_of_week
class NthTradingDayOfMonth(StatelessRule):
class TradingDayOfMonthRule(six.with_metaclass(ABCMeta, StatelessRule)):
def __init__(self, n, invert):
if not 0 <= n < MAX_MONTH_RANGE:
raise _out_of_range_error(MAX_MONTH_RANGE)
self.month = None
self.date = None
self.td_delta = -n if invert else n
def should_trigger(self, dt):
return self.get_trigger_day_of_month(dt) == normalize_date(dt)
@abstractmethod
def date_func(self, dt):
raise NotImplementedError
def get_trigger_day_of_month(self, dt):
if self.month == dt.month:
# We already computed the day for this month.
return self.date
self.date = self.date_func(dt)
if self.td_delta:
self.date = self.cal.add_trading_days(self.td_delta, self.date)
return self.date
class NthTradingDayOfMonth(TradingDayOfMonthRule):
"""
A rule that triggers on the nth trading day of the month.
This is zero-indexed, n=0 is the first trading day of the month.
"""
def __init__(self, n=0):
if not 0 <= n < MAX_MONTH_RANGE:
raise _out_of_range_error(MAX_MONTH_RANGE)
self.td_delta = n
self.month = None
self.day = None
def should_trigger(self, dt):
return self.get_nth_trading_day_of_month(dt) == dt.date()
def get_nth_trading_day_of_month(self, dt):
if self.month == dt.month:
# We already computed the day for this month.
return self.day
if not self.td_delta:
self.day = self.get_first_trading_day_of_month(dt)
else:
self.day = self.cal.add_trading_days(
self.td_delta,
self.get_first_trading_day_of_month(dt),
).date()
return self.day
def __init__(self, n):
super(NthTradingDayOfMonth, self).__init__(n, invert=False)
def get_first_trading_day_of_month(self, dt):
self.month = dt.month
dt = dt.replace(day=1)
self.first_day = (dt if self.cal.is_open_on_day(dt)
else self.cal.next_trading_day(dt)).date()
return self.first_day
first_day = (dt if self.cal.is_open_on_day(dt)
else self.cal.next_trading_day(dt))
return first_day
date_func = get_first_trading_day_of_month
class NDaysBeforeLastTradingDayOfMonth(StatelessRule):
class NDaysBeforeLastTradingDayOfMonth(TradingDayOfMonthRule):
"""
A rule that triggers n days before the last trading day of the month.
"""
def __init__(self, n=0):
if not 0 <= n < MAX_MONTH_RANGE:
raise _out_of_range_error(MAX_MONTH_RANGE)
self.td_delta = -n
self.month = None
self.day = None
def should_trigger(self, dt):
return self.get_nth_to_last_trading_day_of_month(dt) == dt.date()
def get_nth_to_last_trading_day_of_month(self, dt):
if self.month == dt.month:
# We already computed the last day for this month.
return self.day
if not self.td_delta:
self.day = self.get_last_trading_day_of_month(dt)
else:
self.day = self.cal.add_trading_days(
self.td_delta,
self.get_last_trading_day_of_month(dt),
).date()
return self.day
def __init__(self, n):
super(NDaysBeforeLastTradingDayOfMonth, self).__init__(n, invert=True)
def get_last_trading_day_of_month(self, dt):
self.month = dt.month
@@ -610,10 +565,12 @@ class NDaysBeforeLastTradingDayOfMonth(StatelessRule):
year = dt.year
month = dt.month + 1
self.last_day = self.cal.previous_trading_day(
last_day = self.cal.previous_trading_day(
dt.replace(year=year, month=month, day=1)
).date()
return self.last_day
)
return last_day
date_func = get_last_trading_day_of_month
# Stateful rules