mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 14:20:48 +08:00
MAINT: Refactor schedule function rules
Refactor to eliminate unnecessary type coercion. Reduce some code duplication
This commit is contained in:
@@ -66,7 +66,7 @@ import pandas as pd
|
||||
from pandas.tseries.tools import normalize_date
|
||||
|
||||
from zipline.finance.performance.period import PerformancePeriod
|
||||
|
||||
from zipline.errors import NoFurtherDataError
|
||||
import zipline.finance.risk as risk
|
||||
|
||||
from . position_tracker import PositionTracker
|
||||
@@ -391,9 +391,12 @@ class PerformanceTracker(object):
|
||||
|
||||
# Get the next trading day and, if it is past the bounds of this
|
||||
# simulation, return the daily perf packet
|
||||
next_trading_day = self.trading_schedule.next_execution_day(
|
||||
completed_date
|
||||
)
|
||||
try:
|
||||
next_trading_day = self.trading_schedule.next_execution_day(
|
||||
completed_date
|
||||
)
|
||||
except NoFurtherDataError:
|
||||
next_trading_day = None
|
||||
|
||||
# Take a snapshot of our current performance to return to the
|
||||
# browser.
|
||||
|
||||
+81
-124
@@ -21,6 +21,7 @@ import pandas as pd
|
||||
import pytz
|
||||
|
||||
from .context_tricks import nop_context
|
||||
from zipline.errors import NoFurtherDataError
|
||||
|
||||
from zipline.utils.calendars import normalize_date
|
||||
|
||||
@@ -71,29 +72,6 @@ def ensure_utc(time, tz='UTC'):
|
||||
return time.replace(tzinfo=pytz.utc)
|
||||
|
||||
|
||||
def _coerce_datetime(maybe_dt):
|
||||
if isinstance(maybe_dt, datetime.datetime):
|
||||
return maybe_dt
|
||||
elif isinstance(maybe_dt, datetime.date):
|
||||
return datetime.datetime(
|
||||
year=maybe_dt.year,
|
||||
month=maybe_dt.month,
|
||||
day=maybe_dt.day,
|
||||
tzinfo=pytz.utc,
|
||||
)
|
||||
elif isinstance(maybe_dt, (tuple, list)) and len(maybe_dt) == 3:
|
||||
year, month, day = maybe_dt
|
||||
return datetime.datetime(
|
||||
year=year,
|
||||
month=month,
|
||||
day=day,
|
||||
tzinfo=pytz.utc,
|
||||
)
|
||||
else:
|
||||
raise TypeError('Cannot coerce %s into a datetime.datetime'
|
||||
% type(maybe_dt).__name__)
|
||||
|
||||
|
||||
def _out_of_range_error(a, b=None, var='offset'):
|
||||
start = 0
|
||||
if b is None:
|
||||
@@ -352,8 +330,8 @@ class AfterOpen(StatelessRule):
|
||||
# on the fact that our clock only ever ticks forward, since it's
|
||||
# cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means
|
||||
# that we will NOT correctly recognize a new date if we go backwards
|
||||
# in time(which should never happen in a simulation, or in a live
|
||||
# trading environment)
|
||||
# in time(which should never happen in a simulation, or in live
|
||||
# trading)
|
||||
if (
|
||||
self._period_start is None or
|
||||
self._period_close <= dt
|
||||
@@ -396,8 +374,8 @@ class BeforeClose(StatelessRule):
|
||||
# on the fact that our clock only ever ticks forward, since it's
|
||||
# cheaper to do dt1 <= dt2 than dt1.date() != dt2.date(). This means
|
||||
# that we will NOT correctly recognize a new date if we go backwards
|
||||
# in time(which should never happen in a simulation, or in a live
|
||||
# trading environment)
|
||||
# in time(which should never happen in a simulation, or in live
|
||||
# trading)
|
||||
if (
|
||||
self._period_start is None or
|
||||
self._period_close <= dt
|
||||
@@ -416,38 +394,31 @@ class NotHalfDay(StatelessRule):
|
||||
|
||||
|
||||
class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)):
|
||||
def __init__(self, n=0):
|
||||
if not 0 <= abs(n) < MAX_WEEK_RANGE:
|
||||
def __init__(self, n, invert):
|
||||
if not 0 <= n < MAX_WEEK_RANGE:
|
||||
raise _out_of_range_error(MAX_WEEK_RANGE)
|
||||
|
||||
self.td_delta = n
|
||||
|
||||
self.next_date_start = None
|
||||
self.next_date_end = None
|
||||
self.next_midnight_timestamp = None
|
||||
self.td_delta = -n if invert else n
|
||||
|
||||
@abstractmethod
|
||||
def date_func(self, dt, cal):
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_start_and_end(self, dt):
|
||||
next_trading_day = _coerce_datetime(
|
||||
self.cal.add_trading_days(
|
||||
while True:
|
||||
next_trading_day = self.cal.add_trading_days(
|
||||
self.td_delta,
|
||||
self.date_func(dt, self.cal),
|
||||
)
|
||||
)
|
||||
|
||||
# If after applying the offset to the start/end day of the week, we get
|
||||
# day in a different week, skip this week and go on to the next
|
||||
while next_trading_day.isocalendar()[1] != dt.isocalendar()[1]:
|
||||
dt += datetime.timedelta(days=7)
|
||||
next_trading_day = _coerce_datetime(
|
||||
self.cal.add_trading_days(
|
||||
self.td_delta,
|
||||
self.date_func(dt, self.cal),
|
||||
)
|
||||
)
|
||||
# If after applying the offset to the start/end day of the week, we
|
||||
# get day in a different week, skip this week and go on to the next
|
||||
if next_trading_day.isocalendar()[1] == dt.isocalendar()[1]:
|
||||
break
|
||||
else:
|
||||
dt += datetime.timedelta(days=7)
|
||||
|
||||
next_open, next_close = self.cal.open_and_close(next_trading_day)
|
||||
self.next_date_start = next_open
|
||||
@@ -479,28 +450,25 @@ class NthTradingDayOfWeek(TradingDayOfWeekRule):
|
||||
A rule that triggers on the nth trading day of the week.
|
||||
This is zero-indexed, n=0 is the first trading day of the week.
|
||||
"""
|
||||
def __init__(self, n):
|
||||
super(NthTradingDayOfWeek, self).__init__(n, invert=False)
|
||||
|
||||
@staticmethod
|
||||
def get_first_trading_day_of_week(dt, cal):
|
||||
prev = dt
|
||||
dt = cal.previous_trading_day(dt)
|
||||
# If we're on the first trading day of the TradingEnvironment,
|
||||
# calling previous_trading_day on it will return None, which
|
||||
# will blow up when we try and call .date() on it. The first
|
||||
# trading day of the env is also the first trading day of the
|
||||
# week(in the TradingEnvironment, at least), so just return
|
||||
# that date.
|
||||
if dt is None:
|
||||
return prev
|
||||
while dt.date().weekday() < prev.date().weekday():
|
||||
prev = None
|
||||
# Traverse backward until we hit a week border, then jump back to the
|
||||
# previous trading day.
|
||||
try:
|
||||
while not prev or dt.weekday() < prev.weekday():
|
||||
prev = dt
|
||||
dt = cal.previous_trading_day(dt)
|
||||
except NoFurtherDataError:
|
||||
prev = dt
|
||||
dt = cal.previous_trading_day(dt)
|
||||
if dt is None:
|
||||
return prev
|
||||
|
||||
if cal.is_open_on_day(prev):
|
||||
return prev.date()
|
||||
return prev
|
||||
else:
|
||||
return cal.next_trading_day(prev).date()
|
||||
return cal.next_trading_day(prev)
|
||||
|
||||
date_func = get_first_trading_day_of_week
|
||||
|
||||
@@ -510,93 +478,80 @@ class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule):
|
||||
A rule that triggers n days before the last trading day of the week.
|
||||
"""
|
||||
def __init__(self, n):
|
||||
super(NDaysBeforeLastTradingDayOfWeek, self).__init__(-n)
|
||||
super(NDaysBeforeLastTradingDayOfWeek, self).__init__(n, invert=True)
|
||||
|
||||
@staticmethod
|
||||
def get_last_trading_day_of_week(dt, cal):
|
||||
prev = dt
|
||||
dt = cal.next_trading_day(dt)
|
||||
prev = None
|
||||
# Traverse forward until we hit a week border, then jump back to the
|
||||
# previous trading day.
|
||||
while dt.date().weekday() > prev.date().weekday():
|
||||
try:
|
||||
while not prev or dt.weekday() > prev.weekday():
|
||||
prev = dt
|
||||
dt = cal.next_trading_day(dt)
|
||||
except NoFurtherDataError:
|
||||
prev = dt
|
||||
dt = cal.next_trading_day(dt)
|
||||
|
||||
if cal.is_open_on_day(prev):
|
||||
return prev.date()
|
||||
return prev
|
||||
else:
|
||||
return cal.previous_trading_day(prev).date()
|
||||
return cal.previous_trading_day(prev)
|
||||
|
||||
date_func = get_last_trading_day_of_week
|
||||
|
||||
|
||||
class NthTradingDayOfMonth(StatelessRule):
|
||||
class TradingDayOfMonthRule(six.with_metaclass(ABCMeta, StatelessRule)):
|
||||
def __init__(self, n, invert):
|
||||
if not 0 <= n < MAX_MONTH_RANGE:
|
||||
raise _out_of_range_error(MAX_MONTH_RANGE)
|
||||
self.month = None
|
||||
self.date = None
|
||||
self.td_delta = -n if invert else n
|
||||
|
||||
def should_trigger(self, dt):
|
||||
return self.get_trigger_day_of_month(dt) == normalize_date(dt)
|
||||
|
||||
@abstractmethod
|
||||
def date_func(self, dt):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_trigger_day_of_month(self, dt):
|
||||
if self.month == dt.month:
|
||||
# We already computed the day for this month.
|
||||
return self.date
|
||||
|
||||
self.date = self.date_func(dt)
|
||||
if self.td_delta:
|
||||
self.date = self.cal.add_trading_days(self.td_delta, self.date)
|
||||
|
||||
return self.date
|
||||
|
||||
|
||||
class NthTradingDayOfMonth(TradingDayOfMonthRule):
|
||||
"""
|
||||
A rule that triggers on the nth trading day of the month.
|
||||
This is zero-indexed, n=0 is the first trading day of the month.
|
||||
"""
|
||||
def __init__(self, n=0):
|
||||
if not 0 <= n < MAX_MONTH_RANGE:
|
||||
raise _out_of_range_error(MAX_MONTH_RANGE)
|
||||
self.td_delta = n
|
||||
self.month = None
|
||||
self.day = None
|
||||
|
||||
def should_trigger(self, dt):
|
||||
return self.get_nth_trading_day_of_month(dt) == dt.date()
|
||||
|
||||
def get_nth_trading_day_of_month(self, dt):
|
||||
if self.month == dt.month:
|
||||
# We already computed the day for this month.
|
||||
return self.day
|
||||
|
||||
if not self.td_delta:
|
||||
self.day = self.get_first_trading_day_of_month(dt)
|
||||
else:
|
||||
self.day = self.cal.add_trading_days(
|
||||
self.td_delta,
|
||||
self.get_first_trading_day_of_month(dt),
|
||||
).date()
|
||||
|
||||
return self.day
|
||||
def __init__(self, n):
|
||||
super(NthTradingDayOfMonth, self).__init__(n, invert=False)
|
||||
|
||||
def get_first_trading_day_of_month(self, dt):
|
||||
self.month = dt.month
|
||||
|
||||
dt = dt.replace(day=1)
|
||||
self.first_day = (dt if self.cal.is_open_on_day(dt)
|
||||
else self.cal.next_trading_day(dt)).date()
|
||||
return self.first_day
|
||||
first_day = (dt if self.cal.is_open_on_day(dt)
|
||||
else self.cal.next_trading_day(dt))
|
||||
return first_day
|
||||
|
||||
date_func = get_first_trading_day_of_month
|
||||
|
||||
|
||||
class NDaysBeforeLastTradingDayOfMonth(StatelessRule):
|
||||
class NDaysBeforeLastTradingDayOfMonth(TradingDayOfMonthRule):
|
||||
"""
|
||||
A rule that triggers n days before the last trading day of the month.
|
||||
"""
|
||||
def __init__(self, n=0):
|
||||
if not 0 <= n < MAX_MONTH_RANGE:
|
||||
raise _out_of_range_error(MAX_MONTH_RANGE)
|
||||
self.td_delta = -n
|
||||
self.month = None
|
||||
self.day = None
|
||||
|
||||
def should_trigger(self, dt):
|
||||
return self.get_nth_to_last_trading_day_of_month(dt) == dt.date()
|
||||
|
||||
def get_nth_to_last_trading_day_of_month(self, dt):
|
||||
if self.month == dt.month:
|
||||
# We already computed the last day for this month.
|
||||
return self.day
|
||||
|
||||
if not self.td_delta:
|
||||
self.day = self.get_last_trading_day_of_month(dt)
|
||||
else:
|
||||
self.day = self.cal.add_trading_days(
|
||||
self.td_delta,
|
||||
self.get_last_trading_day_of_month(dt),
|
||||
).date()
|
||||
|
||||
return self.day
|
||||
def __init__(self, n):
|
||||
super(NDaysBeforeLastTradingDayOfMonth, self).__init__(n, invert=True)
|
||||
|
||||
def get_last_trading_day_of_month(self, dt):
|
||||
self.month = dt.month
|
||||
@@ -610,10 +565,12 @@ class NDaysBeforeLastTradingDayOfMonth(StatelessRule):
|
||||
year = dt.year
|
||||
month = dt.month + 1
|
||||
|
||||
self.last_day = self.cal.previous_trading_day(
|
||||
last_day = self.cal.previous_trading_day(
|
||||
dt.replace(year=year, month=month, day=1)
|
||||
).date()
|
||||
return self.last_day
|
||||
)
|
||||
return last_day
|
||||
|
||||
date_func = get_last_trading_day_of_month
|
||||
|
||||
|
||||
# Stateful rules
|
||||
|
||||
Reference in New Issue
Block a user