mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 18:59:19 +08:00
261 lines
6.2 KiB
Python
261 lines
6.2 KiB
Python
"""
|
|
Utilities for working with numpy arrays.
|
|
"""
|
|
from datetime import datetime
|
|
from warnings import (
|
|
catch_warnings,
|
|
filterwarnings,
|
|
)
|
|
|
|
from numpy import (
|
|
broadcast,
|
|
busday_count,
|
|
datetime64,
|
|
dtype,
|
|
empty,
|
|
nan,
|
|
where
|
|
)
|
|
from numpy.lib.stride_tricks import as_strided
|
|
from toolz import flip
|
|
|
|
# Shared numpy dtype instances, created once so callers can compare against
# them instead of re-parsing the dtype strings.

# Small integer / boolean dtypes.
uint8_dtype = dtype('uint8')
bool_dtype = dtype('bool')

# Integer dtype.
int64_dtype = dtype('int64')

# Floating point dtypes.
float32_dtype = dtype('float32')
float64_dtype = dtype('float64')

# Complex dtype.
complex128_dtype = dtype('complex128')

# Datetime dtypes at day and nanosecond resolution.
datetime64D_dtype = dtype('datetime64[D]')
datetime64ns_dtype = dtype('datetime64[ns]')
|
|
|
|
# Single-argument constructors for datetime64 scalars with a fixed unit.
# ``flip`` swaps the argument order, so ``make_datetime64D(x)`` evaluates
# ``datetime64(x, 'D')``.
make_datetime64ns = flip(datetime64, 'ns')
make_datetime64D = flip(datetime64, 'D')

# Maps each supported datetime64 dtype to the NaT scalar of the same unit.
NaTmap = {
    dtype('datetime64[%s]' % time_unit): datetime64('NaT', time_unit)
    for time_unit in ('ns', 'us', 'ms', 's', 'm', 'D')
}
# Lookup function: raises KeyError for dtypes that are not in ``NaTmap``.
NaT_for_dtype = NaTmap.__getitem__

# NaT scalars for the two resolutions used most often in this module.
NaTns = NaTmap[datetime64ns_dtype]
NaTD = NaTmap[datetime64D_dtype]
|
|
|
|
|
|
# Default "missing" fill value for each dtype that has one registered.
# Dtypes that are absent from this mapping have no default; see
# ``default_missing_value_for_dtype``.
_FILLVALUE_DEFAULTS = {
    # Booleans have no NaN-like value, so missing entries are False.
    bool_dtype: False,
    # Floats use NaN.
    float32_dtype: nan,
    float64_dtype: nan,
    # Nanosecond datetimes use NaT.
    datetime64ns_dtype: NaTns,
}
|
|
|
|
|
|
class NoDefaultMissingValue(Exception):
    """
    Raised when a default missing value is requested for a dtype that has
    none registered.
    """
|
|
|
|
|
|
def make_kind_check(python_types, numpy_kind):
    """
    Build a predicate that tests whether a scalar or array is of a given kind
    (e.g. float, int, datetime, timedelta).

    Parameters
    ----------
    python_types : type or tuple of types
        Types accepted (via ``isinstance``) for values without a ``dtype``.
    numpy_kind : str
        The ``dtype.kind`` code accepted for values that have a ``dtype``.

    Returns
    -------
    check : callable
        Function of one argument returning True or False.
    """
    def check(value):
        # Plain Python objects carry no dtype; fall back to isinstance.
        if not hasattr(value, 'dtype'):
            return isinstance(value, python_types)
        # Anything exposing a dtype is judged only by its kind code.
        return value.dtype.kind == numpy_kind

    return check
|
|
|
|
|
|
# Kind predicates built from ``make_kind_check``: each pairs the Python type
# accepted for plain scalars with the numpy ``dtype.kind`` code accepted for
# values that carry a dtype.
is_float, is_int, is_datetime = (
    make_kind_check(python_type, kind_code)
    for python_type, kind_code in (
        (float, 'f'),
        (int, 'i'),
        (datetime, 'M'),
    )
)
|
|
|
|
|
|
def coerce_to_dtype(dtype, value):
    """
    Convert `value` to a scalar of the given numpy dtype.

    Only datetime64[ns] and datetime64[D] are supported for datetime dtypes.

    Parameters
    ----------
    dtype : np.dtype
        The target dtype.
    value : object
        The value to convert.

    Returns
    -------
    coerced : numpy scalar
        `value` converted with the dtype's scalar type, or with the
        unit-specific datetime64 constructor for datetime dtypes.

    Raises
    ------
    TypeError
        If `dtype` is a datetime64 dtype other than [ns] or [D].
    """
    dtype_name = dtype.name

    # Non-datetime dtypes can be built directly from their scalar type.
    if not dtype_name.startswith('datetime64'):
        return dtype.type(value)

    # Datetime dtypes need the unit supplied explicitly.
    if dtype_name == 'datetime64[D]':
        return make_datetime64D(value)
    if dtype_name == 'datetime64[ns]':
        return make_datetime64ns(value)

    raise TypeError(
        "Don't know how to coerce values of dtype %s" % dtype
    )
|
|
|
|
|
|
def default_missing_value_for_dtype(dtype):
    """
    Look up the default fill value registered for `dtype`.

    Parameters
    ----------
    dtype : np.dtype
        The dtype whose default missing value is wanted.

    Returns
    -------
    fill_value : object
        The entry stored for `dtype` in ``_FILLVALUE_DEFAULTS``.

    Raises
    ------
    NoDefaultMissingValue
        If no default is registered for `dtype`.
    """
    try:
        fill_value = _FILLVALUE_DEFAULTS[dtype]
    except KeyError:
        # Translate the lookup failure into the module's own exception type.
        raise NoDefaultMissingValue(
            "No default value registered for dtype %s." % dtype
        )
    else:
        return fill_value
|
|
|
|
|
|
def repeat_first_axis(array, count):
    """
    Return a view of `array` repeated `count` times along a new first axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape (count,) + array.shape, composed of `array` repeated
        `count` times along the first axis.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_first_axis(a, 2)
    array([[0, 1, 2],
           [0, 1, 2]])
    >>> repeat_first_axis(a, 4)
    array([[0, 1, 2],
           [0, 1, 2],
           [0, 1, 2],
           [0, 1, 2]])

    Notes
    -----
    The resulting array will share memory with `array`.  If you need to assign
    to the input or output, you should probably make a copy first.

    See Also
    --------
    repeat_last_axis
    """
    repeated_shape = (count,) + array.shape
    # A stride of 0 on the new leading axis makes every repetition read the
    # same underlying data rather than a copy.
    repeated_strides = (0,) + array.strides
    return as_strided(array, repeated_shape, repeated_strides)
|
|
|
|
|
|
def repeat_last_axis(array, count):
    """
    Return a view of `array` repeated `count` times along a new last axis.

    Parameters
    ----------
    array : np.array
        The array to restride.
    count : int
        Number of times to repeat `array`.

    Returns
    -------
    result : array
        Array of shape array.shape + (count,) composed of `array` repeated
        `count` times along the last axis.

    Example
    -------
    >>> from numpy import arange
    >>> a = arange(3); a
    array([0, 1, 2])
    >>> repeat_last_axis(a, 2)
    array([[0, 0],
           [1, 1],
           [2, 2]])
    >>> repeat_last_axis(a, 4)
    array([[0, 0, 0, 0],
           [1, 1, 1, 1],
           [2, 2, 2, 2]])

    Notes
    -----
    The resulting array will share memory with `array`.  If you need to assign
    to the input or output, you should probably make a copy first.

    See Also
    --------
    repeat_first_axis
    """
    repeated_shape = array.shape + (count,)
    # A stride of 0 on the new trailing axis makes every repetition read the
    # same underlying element rather than a copy.
    repeated_strides = array.strides + (0,)
    return as_strided(array, repeated_shape, repeated_strides)
|
|
|
|
|
|
# Sentinel value that isn't NaT.  It is substituted for NaT inputs so that
# ``busday_count`` never sees a NaT; the affected outputs are overwritten
# with nan afterwards.  ``datetime64(0, 'D')`` is the same value that
# ``make_datetime64D(0)`` produces.
_notNaT = datetime64(0, 'D')


def busday_count_mask_NaT(begindates, enddates, out=None):
    """
    Simple wrapper around numpy.busday_count that returns `float` arrays
    rather than int arrays, and handles `NaT`s by returning `NaN`s where the
    inputs were `NaT`.

    Doesn't support custom weekdays or calendars, but probably should in the
    future.

    Parameters
    ----------
    begindates : np.array[datetime64[D]]
        First dates for counting.
    enddates : np.array[datetime64[D]]
        End dates for counting; broadcast against `begindates`.
    out : np.array[float], optional
        Array to fill with the result.  If omitted, a float array with the
        broadcast shape of the inputs is allocated.

    Returns
    -------
    out : np.array[float]
        Business day counts, with nan wherever either input was NaT.

    See Also
    --------
    np.busday_count
    """
    # Imported here so that this function does not depend on the module's
    # import block.  NaT must be detected with ``isnat``: an equality
    # comparison against NaT evaluates to False on current numpy, which
    # would leave the masks empty.
    from numpy import isnat

    if out is None:
        out = empty(broadcast(begindates, enddates).shape, dtype=float)

    beginmask = isnat(begindates)
    endmask = isnat(enddates)

    out = busday_count(
        # Temporarily fill in non-NaT values.
        where(beginmask, _notNaT, begindates),
        where(endmask, _notNaT, enddates),
        out=out,
    )

    # Fill in entries where either input was NaT with nan in the output.
    out[beginmask | endmask] = nan
    return out
|
|
|
|
|
|
class WarningContext(object):
    """
    Re-usable contextmanager for contextually managing warnings.

    Parameters
    ----------
    *warning_specs : tuple of (args, kwargs)
        Each spec is an ``(args, kwargs)`` pair that is forwarded to
        ``warnings.filterwarnings`` every time the context is entered.

    Notes
    -----
    The same instance may be entered repeatedly, including while already
    active: each entry pushes its own ``catch_warnings`` onto an internal
    stack and each exit pops the most recent one.
    """
    def __init__(self, *warning_specs):
        self._warning_specs = warning_specs
        # Stack of the ``catch_warnings`` instances currently entered
        # through this object, innermost last.
        self._catchers = []

    def __enter__(self):
        catcher = catch_warnings()
        catcher.__enter__()
        try:
            for args, kwargs in self._warning_specs:
                filterwarnings(*args, **kwargs)
        except BaseException:
            # A bad spec must not leave the warning filters modified:
            # ``__exit__`` is never called when ``__enter__`` raises, so the
            # saved state is restored here before propagating.
            catcher.__exit__(None, None, None)
            raise
        # Only record the catcher once the context is fully set up.
        self._catchers.append(catcher)
        return self

    def __exit__(self, *exc_info):
        catcher = self._catchers.pop()
        return catcher.__exit__(*exc_info)
|
|
|
|
|
|
def ignore_nanwarnings():
    """
    Build a WarningContext that ignores warnings from numpy's nanfunctions.

    Returns
    -------
    context : WarningContext
        Context manager that installs one 'ignore' filter for RuntimeWarning
        restricted to the module pattern 'numpy.lib.nanfunctions'.
    """
    # NOTE(review): the filter only applies to warnings attributed to a
    # module matching this pattern; confirm that the numpy version in use
    # still attributes nanfunction warnings to that module.
    filter_args = ('ignore',)
    filter_kwargs = {
        'category': RuntimeWarning,
        'module': 'numpy.lib.nanfunctions',
    }
    return WarningContext((filter_args, filter_kwargs))
|