Files
catalyst/zipline/testing/predicates.py
T
2016-08-10 15:32:36 -04:00

537 lines
13 KiB
Python

from contextlib import contextmanager
import datetime
from functools import partial
import inspect
import re
from nose.tools import ( # noqa
assert_almost_equal,
assert_almost_equals,
assert_dict_contains_subset,
assert_false,
assert_greater,
assert_greater_equal,
assert_in,
assert_is,
assert_is_instance,
assert_is_none,
assert_is_not,
assert_is_not_none,
assert_less,
assert_less_equal,
assert_multi_line_equal,
assert_not_almost_equal,
assert_not_almost_equals,
assert_not_equal,
assert_not_equals,
assert_not_in,
assert_not_is_instance,
assert_raises,
assert_raises_regexp,
assert_regexp_matches,
assert_sequence_equal,
assert_true,
assert_tuple_equal,
)
import numpy as np
import pandas as pd
from pandas.util.testing import (
assert_frame_equal,
assert_panel_equal,
assert_series_equal,
)
from six import iteritems, viewkeys, PY2
from toolz import dissoc, keyfilter
import toolz.curried.operator as op
from zipline.testing.core import ensure_doctest
from zipline.dispatch import dispatch
from zipline.lib.adjustment import Adjustment
from zipline.utils.functional import dzip_exact, instance
from zipline.utils.math_utils import tolerant_equals
@instance
@ensure_doctest
class wildcard(object):
    """A value whose equality comparison succeeds against anything.

    Handy as a placeholder inside a large nested structure passed to
    :func:`~zipline.testing.predicates.assert_equal` when some fields should
    be ignored by the comparison.

    Examples
    --------
    >>> wildcard == 5
    True
    >>> wildcard == 'ayy'
    True

    # reflected
    >>> 5 == wildcard
    True
    >>> 'ayy' == wildcard
    True
    """
    # NOTE(review): ``instance`` presumably rebinds this name to a single
    # instance of the class (the doctests above use it as a value) -- confirm
    # against zipline.utils.functional.

    @staticmethod
    def __eq__(other):
        # Equal to everything, regardless of what ``other`` is.
        return True

    @staticmethod
    def __ne__(other):
        # Never unequal; kept in sync with ``__eq__`` explicitly.
        return False

    def __repr__(self):
        cls_name = type(self).__name__
        return '<%s>' % cls_name

    __str__ = __repr__
def keywords(func):
    """Get the argument names of a function.

    Parameters
    ----------
    func : callable or type
        The function to inspect. When a class is given, its ``__init__`` is
        inspected instead.

    Returns
    -------
    names : list[str]
        The names of the positional-or-keyword arguments followed by the
        names of any keyword-only arguments.

    Examples
    --------
    >>> def f(x, y=2):
    ...     pass
    >>> keywords(f)
    ['x', 'y']

    Notes
    -----
    Taken from odo.utils
    """
    if isinstance(func, type):
        return keywords(func.__init__)

    # ``inspect.getargspec`` was removed in Python 3.11 and refuses functions
    # with keyword-only arguments or annotations; prefer ``getfullargspec``
    # whenever the running interpreter provides it.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
        # Python 2: keyword-only arguments do not exist here.
        return inspect.getargspec(func).args

    spec = getfullargspec(func)
    # Keyword-only names are valid keywords as well, so report them too.
    return spec.args + spec.kwonlyargs
def filter_kwargs(f, kwargs):
    """Keep only the entries of ``kwargs`` whose key is an argument of ``f``.

    Parameters
    ----------
    f : callable
        The function whose argument names decide what is kept.
    kwargs : dict
        The candidate keyword arguments.

    Returns
    -------
    valid : dict
        A new dict holding the subset of ``kwargs`` that ``f`` names as
        arguments.

    Examples
    --------
    >>> def f(a, b=1, c=2):
    ...     return a + b + c
    ...
    >>> raw_kwargs = dict(a=1, b=3, d=4)
    >>> f(**raw_kwargs)
    Traceback (most recent call last):
       ...
    TypeError: f() got an unexpected keyword argument 'd'
    >>> kwargs = filter_kwargs(f, raw_kwargs)
    >>> f(**kwargs)
    6

    Notes
    -----
    Taken from odo.utils
    """
    # Look the argument names up once, then test each key against them.
    accepted = keywords(f)
    return {
        name: value
        for name, value in kwargs.items()
        if name in accepted
    }
def _s(word, seq, suffix='s'):
    """Pluralize ``word`` unless ``seq`` holds exactly one element.

    Parameters
    ----------
    word : str
        The string to add the suffix to.
    seq : sequence
        The sequence whose length decides whether the suffix is added.
    suffix : str, optional
        The suffix to add to ``word``.

    Returns
    -------
    maybe_plural : str
        ``word`` with ``suffix`` appended if ``len(seq) != 1``, otherwise
        ``word`` unchanged.
    """
    if len(seq) == 1:
        # Exactly one element: singular form.
        return word
    return word + suffix
def _fmt_path(path):
    """Render the path to a mismatching value for an error message.

    Parameters
    ----------
    path : iterable of str
        The path components leading to the values that are not equal.

    Returns
    -------
    fmtd : str
        ``'path: _'`` followed by the concatenated components, or the empty
        string when ``path`` is empty.
    """
    # ``_`` stands in for the root object that the components are applied to.
    return ('path: _' + ''.join(path)) if path else ''
def _fmt_msg(msg):
    """Render the user supplied message for an error message.

    Parameters
    ----------
    msg : str
        The message to show to the user to provide additional context.

    Returns
    -------
    fmtd : str
        ``msg`` terminated by a newline, or the empty string when ``msg`` is
        empty.
    """
    return (msg + '\n') if msg else ''
def _safe_cls_name(cls):
    """Return ``cls.__name__``, or ``repr(cls)`` if it has no such attribute.

    Parameters
    ----------
    cls : object
        Usually a type, but any object is accepted.

    Returns
    -------
    name : str
        A displayable name for ``cls``.
    """
    try:
        name = cls.__name__
    except AttributeError:
        # Not a class (or otherwise unnamed): fall back to its repr.
        name = repr(cls)
    return name
def assert_is_subclass(subcls, cls, msg=''):
    """Check that ``subcls`` is a subclass of ``cls``.

    Parameters
    ----------
    subcls : type
        The type to check.
    cls : type
        The type to check ``subcls`` against.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when ``issubclass(subcls, cls)`` is false.
    """
    # The message is only built when the check fails.
    assert issubclass(subcls, cls), '%s is not a subclass of %s\n%s' % (
        _safe_cls_name(subcls), _safe_cls_name(cls), msg,
    )
def assert_regex(result, expected, msg=''):
    """Check that the pattern ``expected`` is found somewhere in ``result``.

    Parameters
    ----------
    result : str
        The string to search.
    expected : str or compiled regex
        The pattern to search for in ``result``.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when ``re.search(expected, result)`` finds no match.
    """
    assert re.search(expected, result), '%s%r not found in %r' % (
        _fmt_msg(msg), expected, result,
    )
@contextmanager
def assert_raises_regex(exc, pattern, msg=''):
    """Context manager checking that the body raises a matching exception.

    Parameters
    ----------
    exc : type or tuple[type]
        The exception type or types to expect.
    pattern : str or compiled regex
        The pattern to search for in the str of the raised exception.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when the body does not raise ``exc``, or when it does but
        ``pattern`` is not found in ``str`` of the exception.
    """
    try:
        yield
    except exc as raised:
        assert re.search(pattern, str(raised)), '%s%r not found in %r' % (
            _fmt_msg(msg), pattern, str(raised),
        )
        # The expected exception arrived and matched: swallow it.
        return

    # Only reached when the body finished without raising ``exc``.
    raise AssertionError('%s%s was not raised' % (_fmt_msg(msg), exc))
@dispatch(object, object)
def assert_equal(result, expected, path=(), msg='', **kwargs):
    """Fallback equality check comparing two objects with ``==``.

    This is the ``(object, object)`` dispatch; more specific pairs of types
    are registered separately with ``assert_equal.register``.

    Parameters
    ----------
    result : object
        The result that came from the function under test.
    expected : object
        The expected result.
    path : tuple of str, optional
        The path to these values inside the structure being compared; shown
        in the failure message.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Accepted and ignored here.

    Raises
    ------
    AssertionError
        Raised when ``result`` is not equal to ``expected``.
    """
    assert result == expected, '%s%s != %s\n%s' % (
        _fmt_msg(msg), result, expected, _fmt_path(path),
    )
@assert_equal.register(float, float)
def assert_float_equal(result,
                       expected,
                       path=(),
                       msg='',
                       float_rtol=10e-7,
                       float_atol=10e-7,
                       float_equal_nan=True,
                       **kwargs):
    """``assert_equal`` dispatch for two floats, compared with a tolerance.

    Parameters
    ----------
    result : float
        The result that came from the function under test.
    expected : float
        The expected result.
    path : tuple of str, optional
        The path to these values inside the structure being compared.
    msg : str, optional
        An extra assertion message to print if this fails.
    float_rtol : float, optional
        Forwarded to ``tolerant_equals`` as ``rtol``.
    float_atol : float, optional
        Forwarded to ``tolerant_equals`` as ``atol``.
    float_equal_nan : bool, optional
        Forwarded to ``tolerant_equals`` as ``equal_nan``.
    **kwargs
        Accepted and ignored here.

    Raises
    ------
    AssertionError
        Raised when ``tolerant_equals`` reports the values as different.
    """
    # Only mention nan handling in the message when it deviates from the
    # default.
    nan_note = '' if float_equal_nan else ' (with nan != nan)'
    assert tolerant_equals(
        result,
        expected,
        rtol=float_rtol,
        atol=float_atol,
        equal_nan=float_equal_nan,
    ), '%s%s != %s with rtol=%s and atol=%s%s\n%s' % (
        _fmt_msg(msg),
        result,
        expected,
        float_rtol,
        float_atol,
        nan_note,
        _fmt_path(path),
    )
def _check_sets(result, expected, msg, path, type_):
    """Compare two sets. This is used to check dictionary keys and sets.

    Parameters
    ----------
    result : set
        The set that came from the function under test.
    expected : set
        The expected set.
    msg : str
        An extra assertion message to print if this fails.
    path : tuple
        The path to these sets inside the structure being compared.
    type_ : str
        The type of an element. For dict we use ``'key'`` and for set we use
        ``'element'``.

    Raises
    ------
    AssertionError
        Raised when ``result != expected``. The message contains ``msg`` (if
        any), a description of the differing elements, and the path.
    """
    if result == expected:
        return

    # Describe the difference in a separate variable so that the caller's
    # ``msg`` is preserved instead of being overwritten by the description.
    if result > expected:
        diff = result - expected
        detail = 'extra %s in result: %r' % (_s(type_, diff), diff)
    elif result < expected:
        diff = expected - result
        detail = 'result is missing %s: %r' % (_s(type_, diff), diff)
    else:
        # Neither side contains the other: report both one-sided differences.
        in_result = result - expected
        in_expected = expected - result
        detail = '%s only in result: %s\n%s only in expected: %s' % (
            _s(type_, in_result),
            in_result,
            _s(type_, in_expected),
            in_expected,
        )

    raise AssertionError(
        '%s%s%ss do not match\n%s' % (
            _fmt_msg(msg),
            _fmt_msg(detail),
            type_,
            _fmt_path(path),
        ),
    )
@assert_equal.register(dict, dict)
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """``assert_equal`` dispatch for two dicts.

    The key sets are compared first; if they match, every pair of values is
    compared recursively with ``assert_equal``.

    Parameters
    ----------
    result : dict
        The result that came from the function under test.
    expected : dict
        The expected result.
    path : tuple of str, optional
        The path to these dicts inside the structure being compared.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Forwarded to the recursive ``assert_equal`` calls.

    Raises
    ------
    AssertionError
        Raised when the keys differ, or when any values differ. Value
        failures are collected so that all of them are reported together.
    """
    _check_sets(
        viewkeys(result),
        viewkeys(expected),
        msg,
        path + ('.%s()' % ('viewkeys' if PY2 else 'keys'),),
        'key',
    )

    failures = []
    for k, (resultv, expectedv) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(
                resultv,
                expectedv,
                # Wrap the key in a 1-tuple: a bare tuple key would otherwise
                # be unpacked by ``%`` and raise a TypeError.
                path=path + ('[%r]' % (k,),),
                msg=msg,
                **kwargs
            )
        except AssertionError as e:
            # Keep going so that every mismatching value gets reported.
            failures.append(str(e))

    if failures:
        raise AssertionError('\n'.join(failures))
@assert_equal.register(list, list)
def assert_list_equal(result, expected, path=(), msg='', **kwargs):
    """``assert_equal`` dispatch for two lists.

    The lengths must match, then the elements are compared pairwise, in
    order, with ``assert_equal``. The first failing element aborts the check.

    Parameters
    ----------
    result : list
        The result that came from the function under test.
    expected : list
        The expected result.
    path : tuple of str, optional
        The path to these lists inside the structure being compared.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Forwarded to the recursive ``assert_equal`` calls.

    Raises
    ------
    AssertionError
        Raised when the lengths differ or an element pair is not equal.
    """
    n_result, n_expected = len(result), len(expected)
    assert n_result == n_expected, (
        '%slist lengths do not match: %d != %d\n%s' % (
            _fmt_msg(msg), n_result, n_expected, _fmt_path(path),
        )
    )

    for index, pair in enumerate(zip(result, expected)):
        actual, wanted = pair
        assert_equal(
            actual,
            wanted,
            path=path + ('[%d]' % index,),
            msg=msg,
            **kwargs
        )
@assert_equal.register(set, set)
def assert_set_equal(result, expected, path=(), msg='', **kwargs):
    """``assert_equal`` dispatch for two sets.

    Parameters
    ----------
    result : set
        The result that came from the function under test.
    expected : set
        The expected result.
    path : tuple of str, optional
        The path to these sets inside the structure being compared.
    msg : str, optional
        An extra assertion message to print if this fails.
    **kwargs
        Accepted and ignored here.

    Raises
    ------
    AssertionError
        Raised by ``_check_sets`` when the sets differ.
    """
    _check_sets(result, expected, msg, path, 'element')
@assert_equal.register(np.ndarray, np.ndarray)
def assert_array_equal(result,
                       expected,
                       path=(),
                       msg='',
                       array_verbose=True,
                       array_decimal=None,
                       **kwargs):
    """``assert_equal`` dispatch for two numpy arrays.

    Parameters
    ----------
    result : np.ndarray
        The result that came from the function under test.
    expected : np.ndarray
        The expected result.
    path : tuple of str, optional
        The path to these arrays inside the structure being compared.
    msg : str, optional
        Passed to the numpy checker as ``err_msg``.
    array_verbose : bool, optional
        Passed to the numpy checker as ``verbose``.
    array_decimal : int or None, optional
        When ``None``, ``np.testing.assert_array_equal`` is used. Otherwise
        ``np.testing.assert_array_almost_equal`` is used with this value as
        ``decimal``.
    **kwargs
        Accepted and ignored here.

    Raises
    ------
    AssertionError
        Raised when the numpy checker fails; the formatted path is appended
        to numpy's own message.
    """
    if array_decimal is None:
        check = np.testing.assert_array_equal
    else:
        check = partial(
            np.testing.assert_array_almost_equal,
            decimal=array_decimal,
        )

    try:
        check(result, expected, verbose=array_verbose, err_msg=msg)
    except AssertionError as exc:
        # Re-raise with the path appended so nested failures can be located.
        raise AssertionError('\n'.join((str(exc), _fmt_path(path))))
def _register_assert_ndframe_equal(type_, assert_eq):
    """Register a new check for an ndframe object.

    Parameters
    ----------
    type_ : type
        The class to register an ``assert_equal`` dispatch for.
    assert_eq : callable[type_, type_]
        The function which checks that the two ndframes are equal.

    Returns
    -------
    assert_ndframe_equal : callable[type_, type_]
        The wrapped function registered with ``assert_equal``.
    """
    @assert_equal.register(type_, type_)
    def assert_ndframe_equal(result, expected, path=(), msg='', **kwargs):
        try:
            assert_eq(
                result,
                expected,
                # Filter against ``assert_eq`` itself. The module-level name
                # ``assert_frame_equal`` is rebound to one of these wrappers
                # after registration, so inspecting it would discard every
                # option meant for the underlying checker.
                **filter_kwargs(assert_eq, kwargs)
            )
        except AssertionError as e:
            # Prefix the user message and append the path to the checker's
            # own failure text.
            raise AssertionError(
                _fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
            )

    return assert_ndframe_equal
# Wrap each pandas checker imported at the top of the module and register the
# wrapper on ``assert_equal`` for the matching pandas type. Every name on the
# left is rebound, so from this point on ``assert_frame_equal``,
# ``assert_panel_equal`` and ``assert_series_equal`` refer to the wrappers
# (which additionally accept ``path`` and ``msg``), not to the raw pandas
# functions.
assert_frame_equal = _register_assert_ndframe_equal(
    pd.DataFrame,
    assert_frame_equal,
)
assert_panel_equal = _register_assert_ndframe_equal(
    pd.Panel,
    assert_panel_equal,
)
assert_series_equal = _register_assert_ndframe_equal(
    pd.Series,
    assert_series_equal,
)
@assert_equal.register(Adjustment, Adjustment)
def assert_adjustment_equal(result, expected, path=(), **kwargs):
    """``assert_equal`` dispatch for two ``Adjustment`` objects.

    The attributes ``first_row``, ``last_row``, ``first_col``, ``last_col``
    and ``value`` are compared one by one, in that order, with
    ``assert_equal``.

    Parameters
    ----------
    result : Adjustment
        The result that came from the function under test.
    expected : Adjustment
        The expected result.
    path : tuple of str, optional
        The path to these objects inside the structure being compared.
    **kwargs
        Forwarded to the per-attribute ``assert_equal`` calls.

    Raises
    ------
    AssertionError
        Raised by the first attribute comparison that fails.
    """
    compared = ('first_row', 'last_row', 'first_col', 'last_col', 'value')
    for name in compared:
        assert_equal(
            getattr(result, name),
            getattr(expected, name),
            path=path + ('.' + name,),
            **kwargs
        )
@assert_equal.register(
    (datetime.datetime, np.datetime64),
    (datetime.datetime, np.datetime64),
)
def assert_timestamp_and_datetime_equal(result,
                                        expected,
                                        path=(),
                                        msg='',
                                        allow_datetime_coercions=False,
                                        compare_nat_equal=True,
                                        **kwargs):
    """
    Branch for comparing python datetime (which includes pandas Timestamp) and
    np.datetime64 as equal.

    Raises unless the two values have the same type or
    ``allow_datetime_coercions`` is passed as True.

    Parameters
    ----------
    result : datetime.datetime or np.datetime64
        The result that came from the function under test.
    expected : datetime.datetime or np.datetime64
        The expected result.
    path : tuple of str, optional
        The path to these values inside the structure being compared.
    msg : str, optional
        An extra assertion message to print if this fails.
    allow_datetime_coercions : bool, optional
        When False (the default), ``result`` and ``expected`` must be of
        exactly the same type.
    compare_nat_equal : bool, optional
        When True (the default), two null values (as reported by
        ``pd.isnull``) are treated as equal.
    **kwargs
        Forwarded to the ``(object, object)`` comparison.

    Raises
    ------
    AssertionError
        Raised when the types differ without coercions being allowed, or when
        the values are not equal after conversion to ``pd.Timestamp``.
    """
    assert allow_datetime_coercions or type(result) is type(expected), (
        "%sdatetime types (%s, %s) don't match and "
        "allow_datetime_coercions was not set.\n%s" % (
            _fmt_msg(msg),
            type(result),
            type(expected),
            _fmt_path(path),
        )
    )

    # Normalize both sides to the same type before comparing. Each side must
    # be converted from its own value; converting ``result`` twice would make
    # the comparison below always succeed.
    result = pd.Timestamp(result)
    expected = pd.Timestamp(expected)
    if compare_nat_equal and pd.isnull(result) and pd.isnull(expected):
        # NaT != NaT under ``==``, so handle the both-null case explicitly.
        return

    # Call the generic ``==`` based check directly; dispatching normally on
    # two datetimes would land back in this function.
    assert_equal.dispatch(object, object)(
        result,
        expected,
        path=path,
        msg=msg,
        **kwargs
    )
def assert_isidentical(result, expected, msg=''):
    """Check that ``result.isidentical(expected)`` is true.

    Parameters
    ----------
    result : object
        An object providing an ``isidentical`` method.
    expected : object
        The value passed to ``result.isidentical``.
    msg : str, optional
        An extra assertion message to print if this fails.

    Raises
    ------
    AssertionError
        Raised when ``result.isidentical(expected)`` is false.
    """
    assert result.isidentical(expected), '%s%s is not identical to %s' % (
        _fmt_msg(msg), result, expected,
    )
# datashape is an optional dependency: when it can be imported, merge its
# ``assert_dshape_equal`` dispatches into ``assert_equal``; when it cannot,
# carry on without them.
try:
    # pull the dshape cases in
    from datashape.util.testing import assert_dshape_equal
except ImportError:
    pass
else:
    # Copy every registered implementation except the generic
    # ``(object, object)`` one, so the fallback defined in this module is not
    # replaced.
    assert_equal.funcs.update(
        dissoc(assert_dshape_equal.funcs, (object, object)),
    )