mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 14:46:29 +08:00
ENH: Add utilities for checking types generically.
This commit is contained in:
@@ -11,6 +11,7 @@ from zipline.pipeline import (
|
||||
from zipline.utils import (
|
||||
cache,
|
||||
data,
|
||||
functional,
|
||||
input_validation,
|
||||
memoize,
|
||||
numpy_utils,
|
||||
@@ -82,3 +83,6 @@ class DoctestTestCase(TestCase):
|
||||
|
||||
def test_data_docs(self):
|
||||
self._check_docs(data)
|
||||
|
||||
def test_functional_docs(self):
|
||||
self._check_docs(functional)
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Tests for zipline.utils.numpy_utils.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from six import itervalues
|
||||
from unittest import TestCase
|
||||
|
||||
from numpy import (
|
||||
array,
|
||||
float16,
|
||||
float32,
|
||||
float64,
|
||||
int16,
|
||||
int32,
|
||||
int64,
|
||||
)
|
||||
from pandas import Timestamp
|
||||
from toolz import concat, keyfilter
|
||||
from toolz import curry
|
||||
from toolz.curried.operator import ne
|
||||
|
||||
from zipline.utils.functional import mapall as lazy_mapall
|
||||
from zipline.utils.numpy_utils import (
|
||||
is_float,
|
||||
is_int,
|
||||
is_datetime,
|
||||
make_datetime64D,
|
||||
make_datetime64ns,
|
||||
NaTns,
|
||||
NaTD,
|
||||
)
|
||||
|
||||
|
||||
def mapall(*args):
|
||||
"Strict version of mapall."
|
||||
return list(lazy_mapall(*args))
|
||||
|
||||
|
||||
@curry
|
||||
def make_array(dtype, value):
|
||||
return array([value], dtype=dtype)
|
||||
|
||||
|
||||
CASES = {
|
||||
int: mapall(
|
||||
(int, int16, int32, int64, make_array(int)),
|
||||
[0, 1, -1]
|
||||
),
|
||||
float: mapall(
|
||||
(float16, float32, float64, float, make_array(float)),
|
||||
[0., 1., -1., float('nan'), float('inf'), -float('inf')],
|
||||
),
|
||||
datetime: mapall(
|
||||
(
|
||||
make_datetime64D,
|
||||
make_datetime64ns,
|
||||
Timestamp,
|
||||
make_array('datetime64[ns]'),
|
||||
),
|
||||
[0, 1, 2],
|
||||
) + [NaTD, NaTns],
|
||||
}
|
||||
|
||||
|
||||
def everything_but(k, d):
|
||||
"""
|
||||
Return iterator of all values in d except the values in k.
|
||||
"""
|
||||
assert k in d
|
||||
return concat(itervalues(keyfilter(ne(k), d)))
|
||||
|
||||
|
||||
class TypeCheckTestCase(TestCase):
|
||||
|
||||
def test_is_float(self):
|
||||
for good_value in CASES[float]:
|
||||
self.assertTrue(is_float(good_value))
|
||||
|
||||
for bad_value in everything_but(float, CASES):
|
||||
self.assertFalse(is_float(bad_value))
|
||||
|
||||
def test_is_int(self):
|
||||
for good_value in CASES[int]:
|
||||
self.assertTrue(is_int(good_value))
|
||||
|
||||
for bad_value in everything_but(int, CASES):
|
||||
self.assertFalse(is_int(bad_value))
|
||||
|
||||
def test_is_datetime(self):
|
||||
for good_value in CASES[datetime]:
|
||||
self.assertTrue(is_datetime(good_value))
|
||||
|
||||
for bad_value in everything_but(datetime, CASES):
|
||||
self.assertFalse(is_datetime(bad_value))
|
||||
@@ -0,0 +1,22 @@
|
||||
def mapall(funcs, seq):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
funcs : iterable[function]
|
||||
Sequence of functions to map over `seq`.
|
||||
seq : iterable
|
||||
Sequence over which to map funcs.
|
||||
|
||||
Yields
|
||||
------
|
||||
elem : object
|
||||
Concatenated result of mapping each ``func`` over ``seq``.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> list(mapall([lambda x: x + 1, lambda x: x - 1], [1, 2, 3]))
|
||||
[2, 3, 4, 0, 1, 2]
|
||||
"""
|
||||
for func in funcs:
|
||||
for elem in seq:
|
||||
yield func(elem)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Utilities for working with numpy arrays.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from numpy import (
|
||||
broadcast,
|
||||
busday_count,
|
||||
@@ -50,6 +51,42 @@ class NoDefaultMissingValue(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def make_kind_check(python_types, numpy_kind):
|
||||
"""
|
||||
Make a function that checks whether a scalar or array is of a given kind
|
||||
(e.g. float, int, datetime, timedelta).
|
||||
"""
|
||||
def check(value):
|
||||
if hasattr(value, 'dtype'):
|
||||
return value.dtype.kind == numpy_kind
|
||||
return isinstance(value, python_types)
|
||||
return check
|
||||
|
||||
|
||||
is_float = make_kind_check(float, 'f')
|
||||
is_int = make_kind_check(int, 'i')
|
||||
is_datetime = make_kind_check(datetime, 'M')
|
||||
|
||||
|
||||
def coerce_to_dtype(dtype, value):
|
||||
"""
|
||||
Make a value with the specified numpy dtype.
|
||||
|
||||
Only datetime64[ns] and datetime64[D] are supported for datetime dtypes.
|
||||
"""
|
||||
name = dtype.name
|
||||
if name.startswith('datetime64'):
|
||||
if name == 'datetime64[D]':
|
||||
return make_datetime64D(value)
|
||||
elif name == 'datetime64[ns]':
|
||||
return make_datetime64ns(value)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Don't know how to coerce values of dtype %s" % dtype
|
||||
)
|
||||
return dtype.type(value)
|
||||
|
||||
|
||||
def default_missing_value_for_dtype(dtype):
|
||||
"""
|
||||
Get the default fill value for `dtype`.
|
||||
|
||||
Reference in New Issue
Block a user