mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 00:37:43 +08:00
537 lines
13 KiB
Python
537 lines
13 KiB
Python
from contextlib import contextmanager
|
|
import datetime
|
|
from functools import partial
|
|
import inspect
|
|
import re
|
|
|
|
from nose.tools import ( # noqa
|
|
assert_almost_equal,
|
|
assert_almost_equals,
|
|
assert_dict_contains_subset,
|
|
assert_false,
|
|
assert_greater,
|
|
assert_greater_equal,
|
|
assert_in,
|
|
assert_is,
|
|
assert_is_instance,
|
|
assert_is_none,
|
|
assert_is_not,
|
|
assert_is_not_none,
|
|
assert_less,
|
|
assert_less_equal,
|
|
assert_multi_line_equal,
|
|
assert_not_almost_equal,
|
|
assert_not_almost_equals,
|
|
assert_not_equal,
|
|
assert_not_equals,
|
|
assert_not_in,
|
|
assert_not_is_instance,
|
|
assert_raises,
|
|
assert_raises_regexp,
|
|
assert_regexp_matches,
|
|
assert_sequence_equal,
|
|
assert_true,
|
|
assert_tuple_equal,
|
|
)
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pandas.util.testing import (
|
|
assert_frame_equal,
|
|
assert_panel_equal,
|
|
assert_series_equal,
|
|
)
|
|
from six import iteritems, viewkeys, PY2
|
|
from toolz import dissoc, keyfilter
|
|
import toolz.curried.operator as op
|
|
|
|
from zipline.testing.core import ensure_doctest
|
|
from zipline.dispatch import dispatch
|
|
from zipline.lib.adjustment import Adjustment
|
|
from zipline.utils.functional import dzip_exact, instance
|
|
from zipline.utils.math_utils import tolerant_equals
|
|
|
|
|
|
@instance
@ensure_doctest
class wildcard(object):
    """A placeholder object that reports itself equal to anything.

    Put it inside a large nested structure handed to
    :func:`~zipline.testing.predicates.assert_equal` in the positions whose
    values should be ignored by the comparison.

    Examples
    --------
    >>> wildcard == 5
    True
    >>> wildcard == 'ayy'
    True

    # reflected
    >>> 5 == wildcard
    True
    >>> 'ayy' == wildcard
    True
    """
    # Both comparison hooks are static methods: they only ever look at the
    # *other* operand (and ignore even that), so no ``self`` is needed.
    @staticmethod
    def __ne__(other):
        return False

    @staticmethod
    def __eq__(other):
        return True

    def __repr__(self):
        cls_name = type(self).__name__
        return '<%s>' % cls_name

    # ``str()`` and ``repr()`` render identically.
    __str__ = __repr__
|
|
|
|
|
|
def keywords(func):
    """Get the argument names of a function

    >>> def f(x, y=2):
    ...     pass

    >>> keywords(f)
    ['x', 'y']

    Parameters
    ----------
    func : callable or type
        The function to inspect. When a class is given, its ``__init__`` is
        inspected instead (so the result includes ``self``).

    Returns
    -------
    names : list[str]
        The names of the positional-or-keyword arguments of ``func``.

    Notes
    -----
    Taken from odo.utils
    """
    if isinstance(func, type):
        return keywords(func.__init__)

    # ``inspect.getargspec`` was removed in Python 3.11 and, on earlier
    # Python 3 releases, raises for functions with keyword-only arguments or
    # annotations. Prefer ``getfullargspec`` and only fall back to the legacy
    # function where it is the sole option (Python 2). The fallback is looked
    # up lazily because merely touching ``inspect.getargspec`` fails on 3.11+.
    getargspec = getattr(inspect, 'getfullargspec', None)
    if getargspec is None:
        getargspec = inspect.getargspec
    return getargspec(func).args
|
|
|
|
|
|
def filter_kwargs(f, kwargs):
    """Keep only the entries of ``kwargs`` whose key is an argument of ``f``.

    Parameters
    ----------
    f : callable
        The function whose argument names decide which keys survive.
    kwargs : dict
        The candidate keyword arguments.

    Returns
    -------
    valid_kwargs : dict
        The subset of ``kwargs`` keyed by names that ``keywords(f)`` reports.

    Examples
    --------
    >>> def f(a, b=1, c=2):
    ...     return a + b + c
    ...
    >>> raw_kwargs = dict(a=1, b=3, d=4)
    >>> f(**raw_kwargs)
    Traceback (most recent call last):
        ...
    TypeError: f() got an unexpected keyword argument 'd'
    >>> kwargs = filter_kwargs(f, raw_kwargs)
    >>> f(**kwargs)
    6

    Notes
    -----
    Taken from odo.utils
    """
    accepted_names = keywords(f)
    # Curried ``contains``: a one-argument predicate "is this key accepted?".
    is_accepted = op.contains(accepted_names)
    return keyfilter(is_accepted, kwargs)
|
|
|
|
|
|
def _s(word, seq, suffix='s'):
|
|
"""Adds a suffix to ``word`` if some sequence has anything other than
|
|
exactly one element.
|
|
|
|
word : str
|
|
The string to add the suffix to.
|
|
seq : sequence
|
|
The sequence to check the length of.
|
|
suffix : str, optional.
|
|
The suffix to add to ``word``
|
|
|
|
Returns
|
|
-------
|
|
maybe_plural : str
|
|
``word`` with ``suffix`` added if ``len(seq) != 1``.
|
|
"""
|
|
return word + (suffix if len(seq) != 1 else '')
|
|
|
|
|
|
def _fmt_path(path):
|
|
"""Format the path for final display.
|
|
|
|
Parameters
|
|
----------
|
|
path : iterable of str
|
|
The path to the values that are not equal.
|
|
|
|
Returns
|
|
-------
|
|
fmtd : str
|
|
The formatted path to put into the error message.
|
|
"""
|
|
if not path:
|
|
return ''
|
|
return 'path: _' + ''.join(path)
|
|
|
|
|
|
def _fmt_msg(msg):
|
|
"""Format the message for final display.
|
|
|
|
Parameters
|
|
----------
|
|
msg : str
|
|
The message to show to the user to provide additional context.
|
|
|
|
returns
|
|
-------
|
|
fmtd : str
|
|
The formatted message to put into the error message.
|
|
"""
|
|
if not msg:
|
|
return ''
|
|
return msg + '\n'
|
|
|
|
|
|
def _safe_cls_name(cls):
|
|
try:
|
|
return cls.__name__
|
|
except AttributeError:
|
|
return repr(cls)
|
|
|
|
|
|
def assert_is_subclass(subcls, cls, msg=''):
    """Check that ``subcls`` is a subclass of ``cls``.

    Parameters
    ----------
    subcls : type
        The type to check.
    cls : type
        The type to check ``subcls`` against.
    msg : str, optional
        An extra assertion message to print if this fails. It is appended
        after the generated description.

    Raises
    ------
    AssertionError
        Raised when ``issubclass(subcls, cls)`` is false.
    """
    template = '%s is not a subclass of %s\n%s'
    # The message is only built when the check fails.
    assert issubclass(subcls, cls), template % (
        _safe_cls_name(subcls),
        _safe_cls_name(cls),
        msg,
    )
|
|
|
|
|
|
def assert_regex(result, expected, msg=''):
    """Check that the pattern ``expected`` is found somewhere in ``result``.

    Parameters
    ----------
    result : str
        The string to search.
    expected : str or compiled regex
        The pattern to search for in ``result``.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when ``re.search(expected, result)`` finds no match.
    """
    found = re.search(expected, result)
    assert found, '%s%r not found in %r' % (
        _fmt_msg(msg),
        expected,
        result,
    )
|
|
|
|
|
|
@contextmanager
def assert_raises_regex(exc, pattern, msg=''):
    """Context manager checking that the body raises ``exc`` and that the
    text of the raised exception matches ``pattern``.

    Parameters
    ----------
    exc : type or tuple[type]
        The exception type or types to expect.
    pattern : str or compiled regex
        The pattern to search for in the str of the raised exception.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when the body finishes without raising ``exc``, or when the
        raised exception's text does not match ``pattern``. Exceptions of
        other types are not caught here.
    """
    try:
        yield
    except exc as raised:
        text = str(raised)
        assert re.search(pattern, text), (
            '%s%r not found in %r' % (_fmt_msg(msg), pattern, text)
        )
    else:
        # The body completed normally, so nothing was raised at all.
        raise AssertionError('%s%s was not raised' % (_fmt_msg(msg), exc))
|
|
|
|
|
|
@dispatch(object, object)
def assert_equal(result, expected, path=(), msg='', **kwargs):
    """Fallback comparison: check two objects with the ``==`` operator.

    Parameters
    ----------
    result : object
        The result that came from the function under test.
    expected : object
        The expected result.
    path : tuple of str, optional
        The accessors leading to these values inside a larger structure;
        shown in the failure message.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Accepted and ignored by this fallback.

    Raises
    ------
    AssertionError
        Raised when ``result`` is not equal to ``expected``.
    """
    template = '%s%s != %s\n%s'
    assert result == expected, template % (
        _fmt_msg(msg),
        result,
        expected,
        _fmt_path(path),
    )
|
|
|
|
|
|
@assert_equal.register(float, float)
def assert_float_equal(result,
                       expected,
                       path=(),
                       msg='',
                       float_rtol=10e-7,
                       float_atol=10e-7,
                       float_equal_nan=True,
                       **kwargs):
    """Compare two floats within a tolerance via ``tolerant_equals``.

    Parameters
    ----------
    result : float
        The result that came from the function under test.
    expected : float
        The expected result.
    path : tuple of str, optional
        The accessors leading to these values; shown on failure.
    msg : str, optional
        An extra assertion message to print if this fails.
    float_rtol : float, optional
        Passed to ``tolerant_equals`` as ``rtol``.
    float_atol : float, optional
        Passed to ``tolerant_equals`` as ``atol``.
    float_equal_nan : bool, optional
        Passed to ``tolerant_equals`` as ``equal_nan``.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    AssertionError
        Raised when ``tolerant_equals`` reports the values as different.
    """
    close_enough = tolerant_equals(
        result,
        expected,
        rtol=float_rtol,
        atol=float_atol,
        equal_nan=float_equal_nan,
    )
    # Only mention nan handling in the message when it was switched off.
    nan_note = '' if float_equal_nan else ' (with nan != nan)'
    assert close_enough, '%s%s != %s with rtol=%s and atol=%s%s\n%s' % (
        _fmt_msg(msg),
        result,
        expected,
        float_rtol,
        float_atol,
        nan_note,
        _fmt_path(path),
    )
|
|
|
|
|
|
def _check_sets(result, expected, msg, path, type_):
|
|
"""Compare two sets. This is used to check dictionary keys and sets.
|
|
|
|
Parameters
|
|
----------
|
|
result : set
|
|
expected : set
|
|
msg : str
|
|
path : tuple
|
|
type : str
|
|
The type of an element. For dict we use ``'key'`` and for set we use
|
|
``'element'``.
|
|
"""
|
|
if result != expected:
|
|
if result > expected:
|
|
diff = result - expected
|
|
msg = 'extra %s in result: %r' % (_s(type_, diff), diff)
|
|
elif result < expected:
|
|
diff = expected - result
|
|
msg = 'result is missing %s: %r' % (_s(type_, diff), diff)
|
|
else:
|
|
in_result = result - expected
|
|
in_expected = expected - result
|
|
msg = '%s only in result: %s\n%s only in expected: %s' % (
|
|
_s(type_, in_result),
|
|
in_result,
|
|
_s(type_, in_expected),
|
|
in_expected,
|
|
)
|
|
raise AssertionError(
|
|
'%s%ss do not match\n%s' % (
|
|
_fmt_msg(msg),
|
|
type_,
|
|
_fmt_path(path),
|
|
),
|
|
)
|
|
|
|
|
|
@assert_equal.register(dict, dict)
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """Compare two dicts: first their key sets, then each pair of values.

    Value mismatches are collected across all keys and reported together in
    a single ``AssertionError`` rather than stopping at the first one.

    Parameters
    ----------
    result : dict
        The dict produced by the code under test.
    expected : dict
        The expected dict.
    path : tuple of str, optional
        The accessors leading to these dicts; extended with the key for
        every value comparison.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Forwarded to ``assert_equal`` for each value comparison.

    Raises
    ------
    AssertionError
        Raised when the key sets differ or when any pair of values differs.
    """
    keys_accessor = '.%s()' % ('viewkeys' if PY2 else 'keys')
    _check_sets(
        viewkeys(result),
        viewkeys(expected),
        msg,
        path + (keys_accessor,),
        'key',
    )

    problems = []
    paired = dzip_exact(result, expected)
    for key, (result_value, expected_value) in iteritems(paired):
        try:
            assert_equal(
                result_value,
                expected_value,
                path=path + ('[%r]' % key,),
                msg=msg,
                **kwargs
            )
        except AssertionError as err:
            # Remember the failure and keep checking the remaining keys.
            problems.append(str(err))

    if problems:
        raise AssertionError('\n'.join(problems))
|
|
|
|
|
|
@assert_equal.register(list, list)
def assert_list_equal(result, expected, path=(), msg='', **kwargs):
    """Compare two lists: first their lengths, then element by element.

    Parameters
    ----------
    result : list
        The list produced by the code under test.
    expected : list
        The expected list.
    path : tuple of str, optional
        The accessors leading to these lists; extended with the index for
        every element comparison.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Forwarded to ``assert_equal`` for each element comparison.

    Raises
    ------
    AssertionError
        Raised when the lengths differ, or by the first element comparison
        that fails.
    """
    n_result, n_expected = len(result), len(expected)
    assert n_result == n_expected, (
        '%slist lengths do not match: %d != %d\n%s' % (
            _fmt_msg(msg),
            n_result,
            n_expected,
            _fmt_path(path),
        )
    )
    for index, pair in enumerate(zip(result, expected)):
        result_item, expected_item = pair
        assert_equal(
            result_item,
            expected_item,
            path=path + ('[%d]' % index,),
            msg=msg,
            **kwargs
        )
|
|
|
|
|
|
@assert_equal.register(set, set)
def assert_set_equal(result, expected, path=(), msg='', **kwargs):
    """Compare two sets, reporting extra and missing elements on failure.

    This delegates to ``_check_sets`` with ``'element'`` as the noun used in
    the failure message. ``**kwargs`` is accepted and ignored.
    """
    _check_sets(result, expected, msg, path, type_='element')
|
|
|
|
|
|
@assert_equal.register(np.ndarray, np.ndarray)
def assert_array_equal(result,
                       expected,
                       path=(),
                       msg='',
                       array_verbose=True,
                       array_decimal=None,
                       **kwargs):
    """Compare two numpy arrays using the ``np.testing`` helpers.

    Parameters
    ----------
    result : np.ndarray
        The array produced by the code under test.
    expected : np.ndarray
        The expected array.
    path : tuple of str, optional
        The accessors leading to these arrays; appended to the numpy
        failure message.
    msg : str, optional
        Passed to the numpy helper as ``err_msg``.
    array_verbose : bool, optional
        Passed to the numpy helper as ``verbose``.
    array_decimal : int or None, optional
        When ``None``, ``np.testing.assert_array_equal`` is used. Otherwise
        ``np.testing.assert_array_almost_equal`` is used with this value as
        ``decimal``.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    AssertionError
        Raised when the numpy helper reports a mismatch.
    """
    if array_decimal is None:
        check = np.testing.assert_array_equal
    else:
        check = partial(
            np.testing.assert_array_almost_equal,
            decimal=array_decimal,
        )

    try:
        check(result, expected, verbose=array_verbose, err_msg=msg)
    except AssertionError as err:
        # Re-raise with the location of the arrays appended.
        raise AssertionError('\n'.join((str(err), _fmt_path(path))))
|
|
|
|
|
|
def _register_assert_ndframe_equal(type_, assert_eq):
    """Register a new check for an ndframe object.

    Parameters
    ----------
    type_ : type
        The class to register an ``assert_equal`` dispatch for.
    assert_eq : callable[type_, type_]
        The function which checks that if the two ndframes are equal.

    Returns
    -------
    assert_ndframe_equal : callable[type_, type_]
        The wrapped function registered with ``assert_equal``.
    """
    @assert_equal.register(type_, type_)
    def assert_ndframe_equal(result, expected, path=(), msg='', **kwargs):
        # Filter the options against the checker that is actually being
        # wrapped. The module-level name ``assert_frame_equal`` must not be
        # used here: it is rebound to one of these wrappers after
        # registration, and the wrapper's only named arguments are
        # result/expected/path/msg, so every pandas option would be dropped.
        accepted = filter_kwargs(assert_eq, kwargs)
        try:
            assert_eq(
                result,
                expected,
                **accepted
            )
        except AssertionError as e:
            # Prefix the caller's message and append the location.
            raise AssertionError(
                _fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
            )

    return assert_ndframe_equal
|
|
|
|
|
|
# Wrap each pandas checker imported above and register it with
# ``assert_equal``. Each name is rebound to its wrapper, so from here on the
# module-level ``assert_*_equal`` names refer to the wrappers rather than to
# the original pandas functions.
assert_frame_equal = _register_assert_ndframe_equal(
    type_=pd.DataFrame,
    assert_eq=assert_frame_equal,
)
assert_panel_equal = _register_assert_ndframe_equal(
    type_=pd.Panel,
    assert_eq=assert_panel_equal,
)
assert_series_equal = _register_assert_ndframe_equal(
    type_=pd.Series,
    assert_eq=assert_series_equal,
)
|
|
|
|
|
|
@assert_equal.register(Adjustment, Adjustment)
def assert_adjustment_equal(result, expected, path=(), **kwargs):
    """Compare two ``Adjustment`` objects attribute by attribute.

    The attributes ``first_row``, ``last_row``, ``first_col``, ``last_col``
    and ``value`` are compared in that order with ``assert_equal``; the
    attribute name is appended to ``path`` and ``**kwargs`` is forwarded.

    Raises
    ------
    AssertionError
        Raised by the first attribute comparison that fails.
    """
    compared_attrs = (
        'first_row',
        'last_row',
        'first_col',
        'last_col',
        'value',
    )
    for name in compared_attrs:
        result_value = getattr(result, name)
        expected_value = getattr(expected, name)
        assert_equal(
            result_value,
            expected_value,
            path=path + ('.' + name,),
            **kwargs
        )
|
|
|
|
|
|
@assert_equal.register(
    (datetime.datetime, np.datetime64),
    (datetime.datetime, np.datetime64),
)
def assert_timestamp_and_datetime_equal(result,
                                        expected,
                                        path=(),
                                        msg='',
                                        allow_datetime_coercions=False,
                                        compare_nat_equal=True,
                                        **kwargs):
    """
    Branch for comparing python datetime (which includes pandas Timestamp) and
    np.datetime64 as equal.

    Parameters
    ----------
    result : datetime.datetime or np.datetime64
        The result that came from the function under test.
    expected : datetime.datetime or np.datetime64
        The expected result.
    path : tuple of str, optional
        The accessors leading to these values; shown on failure.
    msg : str, optional
        An extra assertion message to print if this fails.
    allow_datetime_coercions : bool, optional
        When false (the default), ``result`` and ``expected`` must be of
        exactly the same type.
    compare_nat_equal : bool, optional
        When true (the default), two null values (as reported by
        ``pd.isnull`` after conversion to ``pd.Timestamp``) compare equal.
    **kwargs
        Forwarded to the generic ``(object, object)`` comparison.

    Raises
    ------
    AssertionError
        Raised when the types differ and ``allow_datetime_coercions`` is
        false, or when the two values differ as ``pd.Timestamp`` objects.
    """
    assert allow_datetime_coercions or type(result) == type(expected), (
        "%sdatetime types (%s, %s) don't match and "
        "allow_datetime_coercions was not set.\n%s" % (
            _fmt_msg(msg),
            type(result),
            type(expected),
            _fmt_path(path),
        )
    )

    # Normalize both sides to ``pd.Timestamp`` so they can be compared with
    # ``==``. ``expected`` must be built from ``expected``: building it from
    # ``result`` would make the comparison below trivially succeed.
    result = pd.Timestamp(result)
    expected = pd.Timestamp(expected)
    if compare_nat_equal and pd.isnull(result) and pd.isnull(expected):
        return

    # Call the generic fallback directly; dispatching on the converted values
    # would come straight back to this function.
    assert_equal.dispatch(object, object)(
        result,
        expected,
        path=path,
        msg=msg,
        **kwargs
    )
|
|
|
|
|
|
def assert_isidentical(result, expected, msg=''):
    """Check that ``result.isidentical(expected)`` is true.

    Parameters
    ----------
    result : object
        The object under test; must provide an ``isidentical`` method.
    expected : object
        The object to compare against.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when ``result.isidentical(expected)`` is falsy.
    """
    identical = result.isidentical(expected)
    assert identical, '%s%s is not identical to %s' % (
        _fmt_msg(msg),
        result,
        expected,
    )
|
|
|
|
|
|
try:
    # pull the dshape cases in
    from datashape.util.testing import assert_dshape_equal
except ImportError:
    # datashape is optional; without it there is nothing extra to register.
    pass
else:
    # Adopt every datashape comparison except its ``(object, object)``
    # entry, so the generic fallback registered in this module is kept.
    assert_equal.funcs.update(
        keyfilter(
            lambda signature: signature != (object, object),
            assert_dshape_equal.funcs,
        ),
    )
|