mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 13:27:48 +08:00
BUG: Don't fail on integral floats in event rules.
Coerce and warn instead.
This commit is contained in:
@@ -16,12 +16,14 @@ import datetime
|
||||
from inspect import isabstract
|
||||
import random
|
||||
from unittest import TestCase
|
||||
import warnings
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
import pandas as pd
|
||||
from six import iteritems
|
||||
from six.moves import range, map
|
||||
|
||||
from zipline.testing import parameter_space
|
||||
import zipline.utils.events
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.events import (
|
||||
@@ -439,6 +441,26 @@ class StatelessRulesTests(RuleTestCase):
|
||||
self.assertIs(composed.second, rule2)
|
||||
self.assertFalse(any(map(should_trigger, minute)))
|
||||
|
||||
@parameterized.expand([
|
||||
('month_start', NthTradingDayOfMonth),
|
||||
('month_end', NDaysBeforeLastTradingDayOfMonth),
|
||||
('week_start', NthTradingDayOfWeek),
|
||||
('week_end', NthTradingDayOfWeek),
|
||||
])
|
||||
def test_pass_float_to_day_of_period_rule(self, name, rule_type):
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter('always')
|
||||
rule_type(n=3) # Shouldn't trigger a warning.
|
||||
rule_type(n=3.0) # Should trigger a warning about float coercion.
|
||||
|
||||
self.assertEqual(len(raised_warnings), 1)
|
||||
warning_str = raised_warnings[0].message.message
|
||||
|
||||
# We only implicitly convert from float to int when there's no loss of
|
||||
# precision.
|
||||
with self.assertRaises(TypeError):
|
||||
rule_type(3.1)
|
||||
|
||||
|
||||
class StatefulRulesTests(RuleTestCase):
|
||||
CALENDAR_STRING = "NYSE"
|
||||
|
||||
@@ -15,13 +15,17 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple
|
||||
import six
|
||||
import warnings
|
||||
|
||||
import datetime
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytz
|
||||
from toolz import curry
|
||||
|
||||
from zipline.utils.input_validation import preprocess
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
from .context_tricks import nop_context
|
||||
|
||||
|
||||
@@ -148,6 +152,31 @@ def _build_time(time, kwargs):
|
||||
return datetime.time(**kwargs)
|
||||
|
||||
|
||||
@curry
|
||||
def lossless_float_to_int(funcname, func, argname, arg):
|
||||
"""
|
||||
A preprocessor that coerces integral floats to ints.
|
||||
|
||||
Receipt of non-integral floats raises a TypeError.
|
||||
"""
|
||||
if not isinstance(arg, float):
|
||||
return arg
|
||||
|
||||
arg_as_int = int(arg)
|
||||
if arg == arg_as_int:
|
||||
warnings.warn(
|
||||
"{f} expected an int for argument {name!r}, but got float {arg}."
|
||||
" Coercing to int.".format(
|
||||
f=funcname,
|
||||
name=argname,
|
||||
arg=arg,
|
||||
),
|
||||
)
|
||||
return arg_as_int
|
||||
|
||||
raise TypeError(arg)
|
||||
|
||||
|
||||
class EventManager(object):
|
||||
"""Manages a list of Event objects.
|
||||
This manages the logic for checking the rules and dispatching to the
|
||||
@@ -402,8 +431,10 @@ class NotHalfDay(StatelessRule):
|
||||
|
||||
|
||||
class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)):
|
||||
@preprocess(n=lossless_float_to_int('TradingDayOfWeekRule'))
|
||||
def __init__(self, n, invert):
|
||||
if not 0 <= n < MAX_WEEK_RANGE:
|
||||
import nose.tools; nose.tools.set_trace()
|
||||
raise _out_of_range_error(MAX_WEEK_RANGE)
|
||||
|
||||
self.td_delta = (-n - 1) if invert else n
|
||||
@@ -443,6 +474,8 @@ class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule):
|
||||
|
||||
|
||||
class TradingDayOfMonthRule(six.with_metaclass(ABCMeta, StatelessRule)):
|
||||
|
||||
@preprocess(n=lossless_float_to_int('TradingDayOfMonthRule'))
|
||||
def __init__(self, n, invert):
|
||||
if not 0 <= n < MAX_MONTH_RANGE:
|
||||
raise _out_of_range_error(MAX_MONTH_RANGE)
|
||||
|
||||
Reference in New Issue
Block a user