From 2522ca28ae8e49f7e04b5f211f768cf3bcebbc04 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Wed, 7 Sep 2016 21:51:19 -0400 Subject: [PATCH] BUG: Don't fail on integral floats in event rules. Coerce and warn instead. --- tests/events/test_events.py | 22 ++++++++++++++++++++++ zipline/utils/events.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/tests/events/test_events.py b/tests/events/test_events.py index e7c6ca8e..467db3c8 100644 --- a/tests/events/test_events.py +++ b/tests/events/test_events.py @@ -16,12 +16,14 @@ import datetime from inspect import isabstract import random from unittest import TestCase +import warnings from nose_parameterized import parameterized import pandas as pd from six import iteritems from six.moves import range, map +from zipline.testing import parameter_space import zipline.utils.events from zipline.utils.calendars import get_calendar from zipline.utils.events import ( @@ -439,6 +441,26 @@ class StatelessRulesTests(RuleTestCase): self.assertIs(composed.second, rule2) self.assertFalse(any(map(should_trigger, minute))) + @parameterized.expand([ + ('month_start', NthTradingDayOfMonth), + ('month_end', NDaysBeforeLastTradingDayOfMonth), + ('week_start', NthTradingDayOfWeek), + ('week_end', NthTradingDayOfWeek), + ]) + def test_pass_float_to_day_of_period_rule(self, name, rule_type): + with warnings.catch_warnings(record=True) as raised_warnings: + warnings.simplefilter('always') + rule_type(n=3) # Shouldn't trigger a warning. + rule_type(n=3.0) # Should trigger a warning about float coercion. + + self.assertEqual(len(raised_warnings), 1) + warning_str = raised_warnings[0].message.message + + # We only implicitly convert from float to int when there's no loss of + # precision. + with self.assertRaises(TypeError): + rule_type(3.1) + class StatefulRulesTests(RuleTestCase): CALENDAR_STRING = "NYSE" diff --git a/zipline/utils/events.py b/zipline/utils/events.py index cbdfda54..2882d17d 100644 --- a/zipline/utils/events.py +++ b/zipline/utils/events.py @@ -15,13 +15,17 @@ from abc import ABCMeta, abstractmethod from collections import namedtuple import six +import warnings import datetime import numpy as np import pandas as pd import pytz +from toolz import curry +from zipline.utils.input_validation import preprocess from zipline.utils.memoize import lazyval + from .context_tricks import nop_context @@ -148,6 +152,31 @@ def _build_time(time, kwargs): return datetime.time(**kwargs) +@curry +def lossless_float_to_int(funcname, func, argname, arg): + """ + A preprocessor that coerces integral floats to ints. + + Receipt of non-integral floats raises a TypeError. + """ + if not isinstance(arg, float): + return arg + + arg_as_int = int(arg) + if arg == arg_as_int: + warnings.warn( + "{f} expected an int for argument {name!r}, but got float {arg}." + " Coercing to int.".format( + f=funcname, + name=argname, + arg=arg, + ), + ) + return arg_as_int + + raise TypeError(arg) + + class EventManager(object): """Manages a list of Event objects. This manages the logic for checking the rules and dispatching to the @@ -402,8 +431,10 @@ class NotHalfDay(StatelessRule): class TradingDayOfWeekRule(six.with_metaclass(ABCMeta, StatelessRule)): + @preprocess(n=lossless_float_to_int('TradingDayOfWeekRule')) def __init__(self, n, invert): if not 0 <= n < MAX_WEEK_RANGE: + import nose.tools; nose.tools.set_trace() raise _out_of_range_error(MAX_WEEK_RANGE) self.td_delta = (-n - 1) if invert else n @@ -443,6 +474,8 @@ class NDaysBeforeLastTradingDayOfWeek(TradingDayOfWeekRule): class TradingDayOfMonthRule(six.with_metaclass(ABCMeta, StatelessRule)): + + @preprocess(n=lossless_float_to_int('TradingDayOfMonthRule')) def __init__(self, n, invert): if not 0 <= n < MAX_MONTH_RANGE: raise _out_of_range_error(MAX_MONTH_RANGE)