From e810f26097fa4735ce32a9b84e2d905b4261937b Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Mon, 7 Mar 2016 15:17:03 -0500 Subject: [PATCH] ENH: Add utilities for checking types generically. --- tests/test_doctests.py | 4 ++ tests/utils/test_numpy_utils.py | 94 +++++++++++++++++++++++++++++++++ zipline/utils/functional.py | 22 ++++++++ zipline/utils/numpy_utils.py | 37 +++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 tests/utils/test_numpy_utils.py create mode 100644 zipline/utils/functional.py diff --git a/tests/test_doctests.py b/tests/test_doctests.py index c08659af..89d7f3cf 100644 --- a/tests/test_doctests.py +++ b/tests/test_doctests.py @@ -11,6 +11,7 @@ from zipline.pipeline import ( from zipline.utils import ( cache, data, + functional, input_validation, memoize, numpy_utils, @@ -82,3 +83,6 @@ class DoctestTestCase(TestCase): def test_data_docs(self): self._check_docs(data) + + def test_functional_docs(self): + self._check_docs(functional) diff --git a/tests/utils/test_numpy_utils.py b/tests/utils/test_numpy_utils.py new file mode 100644 index 00000000..6d7fe65e --- /dev/null +++ b/tests/utils/test_numpy_utils.py @@ -0,0 +1,94 @@ +""" +Tests for zipline.utils.numpy_utils. +""" +from datetime import datetime +from six import itervalues +from unittest import TestCase + +from numpy import ( + array, + float16, + float32, + float64, + int16, + int32, + int64, +) +from pandas import Timestamp +from toolz import concat, keyfilter +from toolz import curry +from toolz.curried.operator import ne + +from zipline.utils.functional import mapall as lazy_mapall +from zipline.utils.numpy_utils import ( + is_float, + is_int, + is_datetime, + make_datetime64D, + make_datetime64ns, + NaTns, + NaTD, +) + + +def mapall(*args): + "Strict version of mapall." + return list(lazy_mapall(*args)) + + +@curry +def make_array(dtype, value): + return array([value], dtype=dtype) + + +CASES = { + int: mapall( + (int, int16, int32, int64, make_array(int)), + [0, 1, -1] + ), + float: mapall( + (float16, float32, float64, float, make_array(float)), + [0., 1., -1., float('nan'), float('inf'), -float('inf')], + ), + datetime: mapall( + ( + make_datetime64D, + make_datetime64ns, + Timestamp, + make_array('datetime64[ns]'), + ), + [0, 1, 2], + ) + [NaTD, NaTns], +} + + +def everything_but(k, d): + """ + Return iterator of all values in d except the values in k. + """ + assert k in d + return concat(itervalues(keyfilter(ne(k), d))) + + +class TypeCheckTestCase(TestCase): + + def test_is_float(self): + for good_value in CASES[float]: + self.assertTrue(is_float(good_value)) + + for bad_value in everything_but(float, CASES): + self.assertFalse(is_float(bad_value)) + + def test_is_int(self): + for good_value in CASES[int]: + self.assertTrue(is_int(good_value)) + + for bad_value in everything_but(int, CASES): + self.assertFalse(is_int(bad_value)) + + def test_is_datetime(self): + for good_value in CASES[datetime]: + self.assertTrue(is_datetime(good_value)) + + for bad_value in everything_but(datetime, CASES): + self.assertFalse(is_datetime(bad_value)) diff --git a/zipline/utils/functional.py b/zipline/utils/functional.py new file mode 100644 index 00000000..420cd604 --- /dev/null +++ b/zipline/utils/functional.py @@ -0,0 +1,22 @@ +def mapall(funcs, seq): + """ + Parameters + ---------- + funcs : iterable[function] + Sequence of functions to map over `seq`. + seq : iterable + Sequence over which to map funcs. + + Yields + ------ + elem : object + Concatenated result of mapping each ``func`` over ``seq``. + + Example + ------- + >>> list(mapall([lambda x: x + 1, lambda x: x - 1], [1, 2, 3])) + [2, 3, 4, 0, 1, 2] + """ + for func in funcs: + for elem in seq: + yield func(elem) diff --git a/zipline/utils/numpy_utils.py b/zipline/utils/numpy_utils.py index 56315cf7..79bca175 100644 --- a/zipline/utils/numpy_utils.py +++ b/zipline/utils/numpy_utils.py @@ -1,6 +1,7 @@ """ Utilities for working with numpy arrays. """ +from datetime import datetime from numpy import ( broadcast, busday_count, @@ -50,6 +51,42 @@ class NoDefaultMissingValue(Exception): pass +def make_kind_check(python_types, numpy_kind): + """ + Make a function that checks whether a scalar or array is of a given kind + (e.g. float, int, datetime, timedelta). + """ + def check(value): + if hasattr(value, 'dtype'): + return value.dtype.kind == numpy_kind + return isinstance(value, python_types) + return check + + +is_float = make_kind_check(float, 'f') +is_int = make_kind_check(int, 'i') +is_datetime = make_kind_check(datetime, 'M') + + +def coerce_to_dtype(dtype, value): + """ + Make a value with the specified numpy dtype. + + Only datetime64[ns] and datetime64[D] are supported for datetime dtypes. + """ + name = dtype.name + if name.startswith('datetime64'): + if name == 'datetime64[D]': + return make_datetime64D(value) + elif name == 'datetime64[ns]': + return make_datetime64ns(value) + else: + raise TypeError( + "Don't know how to coerce values of dtype %s" % dtype + ) + return dtype.type(value) + + def default_missing_value_for_dtype(dtype): """ Get the default fill value for `dtype`.