Files
catalyst/zipline/utils/numpy_utils.py
T
2016-03-25 15:11:18 -04:00

261 lines
6.2 KiB
Python

"""
Utilities for working with numpy arrays.
"""
from datetime import datetime
from warnings import (
catch_warnings,
filterwarnings,
)
from numpy import (
broadcast,
busday_count,
datetime64,
dtype,
empty,
nan,
where
)
from numpy.lib.stride_tricks import as_strided
from toolz import flip
# Shared numpy dtype instances, created once at import time so they can be
# referred to by name (e.g. as keys of the lookup tables defined below).
uint8_dtype = dtype('uint8')
bool_dtype = dtype('bool')
int64_dtype = dtype('int64')
float32_dtype = dtype('float32')
float64_dtype = dtype('float64')
complex128_dtype = dtype('complex128')
# Datetime dtypes at day ('D') and nanosecond ('ns') resolution.
datetime64D_dtype = dtype('datetime64[D]')
datetime64ns_dtype = dtype('datetime64[ns]')
# One-argument constructors for datetime64 scalars with a fixed unit.
# NOTE(review): this relies on toolz's `flip` being curried, so that
# `make_datetime64ns(value)` ends up calling `datetime64(value, 'ns')` --
# confirm against the toolz version in use.
make_datetime64ns = flip(datetime64, 'ns')
make_datetime64D = flip(datetime64, 'D')
# Maps each listed datetime64 dtype to the NaT ("not a time") scalar of the
# same unit.
NaTmap = {
    dtype('datetime64[%s]' % unit): datetime64('NaT', unit)
    for unit in ('ns', 'us', 'ms', 's', 'm', 'D')
}
# Lookup function over NaTmap; being the dict's own __getitem__, it raises
# KeyError for any dtype that is not one of the keys above.
NaT_for_dtype = NaTmap.__getitem__
# NaT scalars for the nanosecond and day resolutions.
NaTns = NaT_for_dtype(datetime64ns_dtype)
NaTD = NaT_for_dtype(datetime64D_dtype)
# Default fill ("missing") value for each dtype that has one.  Dtypes that do
# not appear as a key here have no registered default.
_FILLVALUE_DEFAULTS = {
    bool_dtype: False,
    float32_dtype: nan,
    float64_dtype: nan,
    datetime64ns_dtype: NaTns,
}
class NoDefaultMissingValue(Exception):
    """
    Raised when no default missing value is registered for a dtype.
    """
def make_kind_check(python_types, numpy_kind):
    """
    Build a predicate that reports whether a scalar or array is of a given
    kind (e.g. float, int, datetime, timedelta).

    Parameters
    ----------
    python_types : type or tuple of types
        Python type(s) accepted for values that do not carry a ``dtype``.
    numpy_kind : str
        The ``dtype.kind`` code accepted for values that do carry a ``dtype``.

    Returns
    -------
    check : callable
        One-argument function returning True or False.
    """
    def check(value):
        # Plain Python objects are judged by their type; anything exposing a
        # ``dtype`` is judged solely by that dtype's kind code.
        if not hasattr(value, 'dtype'):
            return isinstance(value, python_types)
        return value.dtype.kind == numpy_kind
    return check


is_float = make_kind_check(float, 'f')
is_int = make_kind_check(int, 'i')
is_datetime = make_kind_check(datetime, 'M')
def coerce_to_dtype(dtype, value):
    """
    Make a value with the specified numpy dtype.

    Only datetime64[ns] and datetime64[D] are supported for datetime dtypes.

    Parameters
    ----------
    dtype : np.dtype
        The dtype the result should have.
    value : object
        The value to convert.

    Raises
    ------
    TypeError
        If `dtype` is a datetime64 dtype with a unit other than 'D' or 'ns'.
    """
    type_name = dtype.name

    # Non-datetime dtypes: the dtype's own scalar type does the conversion.
    if not type_name.startswith('datetime64'):
        return dtype.type(value)

    if type_name == 'datetime64[D]':
        return make_datetime64D(value)
    if type_name == 'datetime64[ns]':
        return make_datetime64ns(value)

    raise TypeError(
        "Don't know how to coerce values of dtype %s" % dtype
    )
def default_missing_value_for_dtype(dtype):
    """
    Look up the default fill value registered for `dtype`.

    Raises
    ------
    NoDefaultMissingValue
        If `dtype` has no entry in the table of default fill values.
    """
    try:
        fill_value = _FILLVALUE_DEFAULTS[dtype]
    except KeyError:
        message = "No default value registered for dtype %s." % dtype
        raise NoDefaultMissingValue(message)
    return fill_value
def repeat_first_axis(array, count):
    """
    Restride `array` to repeat `count` times along the first axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape (count,) + array.shape, composed of `array` repeated
        `count` times along the first axis.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_first_axis(a, 2)
    array([[0, 1, 2],
           [0, 1, 2]])
    >>> repeat_first_axis(a, 4)
    array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])

    Notes
    -----
    The resulting array will share memory with `array`.  If you need to assign
    to the input or output, you should probably make a copy first.

    See Also
    --------
    repeat_last_axis
    """
    repeated_shape = (count,) + array.shape
    # A stride of 0 on the new leading axis makes every index along it
    # address the same underlying data, so nothing is copied.
    repeated_strides = (0,) + array.strides
    return as_strided(array, repeated_shape, repeated_strides)
def repeat_last_axis(array, count):
    """
    Restride `array` to repeat `count` times along the last axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape array.shape + (count,) composed of `array` repeated
        `count` times along the last axis.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_last_axis(a, 2)
    array([[0, 0],
           [1, 1],
           [2, 2]])
    >>> repeat_last_axis(a, 4)
    array([[0, 0, 0, 0],
           [1, 1, 1, 1],
           [2, 2, 2, 2]])

    Notes
    -----
    The resulting array will share memory with `array`.  If you need to assign
    to the input or output, you should probably make a copy first.

    See Also
    --------
    repeat_first_axis
    """
    repeated_shape = array.shape + (count,)
    # A stride of 0 on the new trailing axis makes every index along it
    # address the same underlying element, so nothing is copied.
    repeated_strides = array.strides + (0,)
    return as_strided(array, repeated_shape, repeated_strides)
# Sentinel value that isn't NaT.
_notNaT = datetime64(0, 'D')

# numpy stores NaT as one fixed 64-bit integer regardless of the datetime
# unit; this is that integer.
_iNaT = datetime64('NaT', 'D').view('int64')


def busday_count_mask_NaT(begindates, enddates, out=None):
    """
    Simple version of numpy.busday_count that returns `float` arrays rather
    than int arrays, and handles `NaT`s by returning `NaN`s where the inputs
    were `NaT`.

    Doesn't support custom weekdays or calendars, but probably should in the
    future.

    Parameters
    ----------
    begindates : np.ndarray or np.datetime64
        First dates of the ranges to count (inclusive).  Must be a numpy
        datetime64 array or scalar.
    enddates : np.ndarray or np.datetime64
        End dates of the ranges to count (exclusive).  Must be a numpy
        datetime64 array or scalar, broadcastable against `begindates`.
    out : np.ndarray, optional
        Float array to write the result into.  If omitted, a new array with
        the broadcast shape of the inputs is allocated.

    Returns
    -------
    out : np.ndarray
        Business-day counts as floats, with NaN wherever either input was
        NaT.

    See Also
    --------
    np.busday_count
    """
    if out is None:
        out = empty(broadcast(begindates, enddates).shape, dtype=float)

    # Detect NaT through the underlying integer representation rather than
    # with ``== NaT``: on numpy >= 1.16 every equality comparison against NaT
    # is False (like NaN), which would leave the masks empty and let NaT
    # reach busday_count, which rejects it.
    beginmask = begindates.view('int64') == _iNaT
    endmask = enddates.view('int64') == _iNaT

    out = busday_count(
        # Temporarily fill in non-NaT values.
        where(beginmask, _notNaT, begindates),
        where(endmask, _notNaT, enddates),
        out=out,
    )

    # Fill in entries where either input was NaT with nan in the output.
    out[beginmask | endmask] = nan
    return out
class WarningContext(object):
    """
    Re-usable contextmanager for contextually managing warnings.

    Parameters
    ----------
    *warning_specs : iterable of (tuple, dict)
        Each spec is an ``(args, kwargs)`` pair.  Every time the context is
        entered, ``filterwarnings(*args, **kwargs)`` is called for each spec,
        inside a fresh ``catch_warnings`` scope that is closed again on exit.
    """
    def __init__(self, *warning_specs):
        self._warning_specs = warning_specs
        # Stack of currently-open catch_warnings managers, one per entry that
        # has not been exited yet; this is what allows one instance to be
        # entered repeatedly, including nested.
        self._catchers = []

    def __enter__(self):
        catcher = catch_warnings()
        catcher.__enter__()
        try:
            for args, kwargs in self._warning_specs:
                filterwarnings(*args, **kwargs)
        except BaseException:
            # A ``with`` statement never calls __exit__ when __enter__
            # raises, so restore the saved warning state here; otherwise the
            # filters installed so far would stay active.
            catcher.__exit__(None, None, None)
            raise
        # Only record the catcher once it is fully set up, so a failed entry
        # cannot leave a stale manager on the stack.
        self._catchers.append(catcher)
        return self

    def __exit__(self, *exc_info):
        catcher = self._catchers.pop()
        return catcher.__exit__(*exc_info)
def ignore_nanwarnings():
    """
    Build a WarningContext that ignores RuntimeWarnings attributed to numpy's
    nanfunctions module.
    """
    filter_args = ('ignore',)
    filter_kwargs = {
        'category': RuntimeWarning,
        'module': 'numpy.lib.nanfunctions',
    }
    return WarningContext((filter_args, filter_kwargs))