Files
catalyst/zipline/testing/predicates.py
T
Joe Jevnik 59c8e371a2 ENH: Updates the cli, data bundles and extensions.
Adds the data bundle concept which makes it easy for users to register
loading functions to build out minute and daily data along with an
assets db and adjustments db. By default we have provided a `quandl`
bundle which pulls from the public domain WIKI dataset. Users may
register new bundles by decorating an ingest function with
`zipline.data.bundles.register(<name>)`. This also provides a
`yahoo_equities` function for creating an ingestion function that will
load a static set of assets from yahoo.

The cli is now structured as a couple of subcommands and has been
changed to `python -m zipline`. The old behavior of `run_algo.py` has
been moved to the `run` subcommand. This is almost entirely the same
except that it now takes the name of the data bundle to use, defaulting
to `quandl`.

The next subcommand is `ingest` which takes the name of
a data bundle to ingest. This will run the loading machinery and write
the data to a specified location that `run` can find.

There is also a `clean` subcommand which deletes the data that was
written with `ingest`.

Extensions have also been added to zipline. This is an experimental
feature where users can provide an extra set of python files to run at
the start of the process. These can be used to configure aspects of
zipline. Right now the only thing that is supported in an extension file
is the registration of a new data bundle.
2016-05-03 18:38:24 -04:00

323 lines
8.2 KiB
Python

from functools import partial
import inspect
from nose.tools import ( # noqa
assert_almost_equal,
assert_almost_equals,
assert_dict_contains_subset,
assert_false,
assert_greater,
assert_greater_equal,
assert_in,
assert_is,
assert_is_instance,
assert_is_none,
assert_is_not,
assert_is_not_none,
assert_less,
assert_less_equal,
assert_multi_line_equal,
assert_not_almost_equal,
assert_not_almost_equals,
assert_not_equal,
assert_not_equals,
assert_not_in,
assert_not_is_instance,
assert_raises,
assert_raises_regexp,
assert_regexp_matches,
assert_sequence_equal,
assert_set_equal,
assert_true,
assert_tuple_equal,
)
import numpy as np
import pandas as pd
from pandas.util.testing import assert_frame_equal
from six import iteritems, viewkeys, PY2
from toolz import dissoc, keyfilter
import toolz.curried.operator as op
from zipline.dispatch import dispatch
from zipline.lib.adjustment import Adjustment
from zipline.utils.functional import dzip_exact
from zipline.utils.math_utils import tolerant_equals
def keywords(func):
    """Get the argument names of a function.

    >>> def f(x, y=2):
    ...     pass
    >>> keywords(f)
    ['x', 'y']

    Parameters
    ----------
    func : callable or type
        The callable to inspect. When a class is given, its ``__init__``
        is inspected instead, so the result includes ``self``.

    Returns
    -------
    names : list of str
        The positional argument names of ``func``. Keyword-only names and
        the ``*args`` / ``**kwargs`` catch-alls are not included.

    Notes
    -----
    Taken from odo.utils
    """
    if isinstance(func, type):
        return keywords(func.__init__)
    # ``inspect.getargspec`` was removed in Python 3.11 and, on earlier
    # Python 3 releases, raises ``ValueError`` for signatures that use
    # annotations or keyword-only arguments. Prefer ``getfullargspec`` and
    # only fall back to ``getargspec`` where it does not exist (Python 2).
    argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    return argspec(func).args
def filter_kwargs(f, kwargs):
    """Keep only the entries of ``kwargs`` whose keys ``f`` accepts by name.

    Parameters
    ----------
    f : callable or type
        The callable whose argument names decide which keys are kept.
    kwargs : dict
        The candidate keyword arguments.

    Returns
    -------
    filtered : dict
        The items of ``kwargs`` whose key is one of the argument names
        reported by ``keywords(f)``.

    Examples
    --------
    >>> def f(a, b=1, c=2):
    ...     return a + b + c
    ...
    >>> raw_kwargs = dict(a=1, b=3, d=4)
    >>> f(**raw_kwargs)
    Traceback (most recent call last):
    ...
    TypeError: f() got an unexpected keyword argument 'd'
    >>> kwargs = filter_kwargs(f, raw_kwargs)
    >>> f(**kwargs)
    6

    Notes
    -----
    Taken from odo.utils
    """
    accepted_names = keywords(f)
    # Curried ``contains``: a predicate that is true for keys found in
    # ``accepted_names``.
    is_accepted = op.contains(accepted_names)
    return keyfilter(is_accepted, kwargs)
def _s(word, seq, suffix='s'):
"""Adds a suffix to ``word`` if some sequence has anything other than
exactly one element.
word : str
The string to add the suffix to.
seq : sequence
The sequence to check the length of.
suffix : str, optional.
The suffix to add to ``word``
Returns
-------
maybe_plural : str
``word`` with ``suffix`` added if ``len(seq) != 1``.
"""
return word + (suffix if len(seq) != 1 else '')
def _fmt_path(path):
"""Format the path for final display.
Parameters
----------
path : iterable of str
The path to the values that are not equal.
Returns
-------
fmtd : str
The formatted path to put into the error message.
"""
if not path:
return ''
return 'path: _' + ''.join(path)
def _fmt_msg(msg):
"""Format the message for final display.
Parameters
----------
msg : str
The message to show to the user to provide additional context.
returns
-------
fmtd : str
The formatted message to put into the error message.
"""
if not msg:
return ''
return msg + '\n'
@dispatch(object, object)
def assert_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two objects are equal using the ``==`` operator.

    This is the generic fallback; more specific implementations are
    registered on it for particular pairs of types.

    Parameters
    ----------
    result : object
        The result that came from the function under test.
    expected : object
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is shown
        in the error message.
    msg : str, optional
        Extra context to put at the start of the error message.
    **kwargs
        Not used by this implementation; accepted so that options meant
        for other implementations can be passed through.

    Raises
    ------
    AssertionError
        Raised when ``result`` is not equal to ``expected``.
    """
    if result == expected:
        return
    # Raise explicitly instead of using an ``assert`` statement so that the
    # check is still performed when Python runs with ``-O``.
    raise AssertionError(
        '%s%s != %s\n%s' % (
            _fmt_msg(msg),
            result,
            expected,
            _fmt_path(path),
        ),
    )
@assert_equal.register(float, float)
def assert_float_equal(result,
                       expected,
                       path=(),
                       msg='',
                       float_rtol=10e-7,
                       float_atol=10e-7,
                       float_equal_nan=True,
                       **kwargs):
    """Assert that two floats are equal within a tolerance.

    The comparison is delegated to ``tolerant_equals``.

    Parameters
    ----------
    result : float
        The result that came from the function under test.
    expected : float
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is shown
        in the error message.
    msg : str, optional
        Extra context to put at the start of the error message.
    float_rtol : float, optional
        Passed to ``tolerant_equals`` as ``rtol``.
    float_atol : float, optional
        Passed to ``tolerant_equals`` as ``atol``.
    float_equal_nan : bool, optional
        Passed to ``tolerant_equals`` as ``equal_nan``. When false the
        error message notes that nan != nan.
    **kwargs
        Not used by this implementation; accepted so that options meant
        for other implementations can be passed through.

    Raises
    ------
    AssertionError
        Raised when ``tolerant_equals`` reports the values as different.
    """
    if tolerant_equals(
        result,
        expected,
        rtol=float_rtol,
        atol=float_atol,
        equal_nan=float_equal_nan,
    ):
        return
    # Raise explicitly instead of using an ``assert`` statement so that the
    # check is still performed when Python runs with ``-O``.
    raise AssertionError(
        '%s%s != %s with rtol=%s and atol=%s%s\n%s' % (
            _fmt_msg(msg),
            result,
            expected,
            float_rtol,
            float_atol,
            (' (with nan != nan)' if not float_equal_nan else ''),
            _fmt_path(path),
        ),
    )
@assert_equal.register(dict, dict)
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two dicts have the same keys and equal values.

    The key sets are compared first. When they match, every pair of
    values is compared with ``assert_equal`` and all value failures are
    collected and reported together.

    Parameters
    ----------
    result : dict
        The result that came from the function under test.
    expected : dict
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is shown
        in the error message and extended for each value compared.
    msg : str, optional
        Extra context to put at the start of the error message.
    **kwargs
        Forwarded to ``assert_equal`` for each pair of values.

    Raises
    ------
    AssertionError
        Raised when the key sets differ or when any pair of values is
        not equal.
    """
    if path is None:
        path = ()

    result_keys = viewkeys(result)
    expected_keys = viewkeys(expected)
    if result_keys != expected_keys:
        # Build the key diagnostics in their own variable: reusing ``msg``
        # here would throw away the context supplied by the caller.
        if result_keys > expected_keys:
            diff = result_keys - expected_keys
            detail = 'extra %s in result: %r' % (_s('key', diff), diff)
        elif result_keys < expected_keys:
            diff = expected_keys - result_keys
            detail = 'result is missing %s: %r' % (_s('key', diff), diff)
        else:
            # Neither key set contains the other; report both sides.
            in_result = result_keys - expected_keys
            in_expected = expected_keys - result_keys
            detail = '%s only in result: %s\n%s only in expected: %s' % (
                _s('key', in_result),
                in_result,
                _s('key', in_expected),
                in_expected,
            )
        raise AssertionError(
            '%s%sdict keys do not match\n%s' % (
                _fmt_msg(msg),
                _fmt_msg(detail),
                _fmt_path(path + ('.%s()' % ('viewkeys' if PY2 else 'keys'),)),
            ),
        )

    # Compare every value so that a single run reports all mismatches.
    failures = []
    for k, (resultv, expectedv) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(
                resultv,
                expectedv,
                path=path + ('[%r]' % k,),
                msg=msg,
                **kwargs
            )
        except AssertionError as e:
            failures.append(str(e))

    if failures:
        raise AssertionError('\n'.join(failures))
@assert_equal.register(list, list)  # noqa
def assert_list_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two lists have the same length and equal elements.

    Elements are compared pairwise, in order, with ``assert_equal``; the
    first failing element raises.

    Parameters
    ----------
    result : list
        The result that came from the function under test.
    expected : list
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is shown
        in the error message and extended with the index of each element.
    msg : str, optional
        Extra context to put at the start of the error message.
    **kwargs
        Forwarded to ``assert_equal`` for each pair of elements.

    Raises
    ------
    AssertionError
        Raised when the lengths differ or when a pair of elements is not
        equal.
    """
    result_len = len(result)
    expected_len = len(expected)
    if result_len != expected_len:
        # Raise explicitly instead of using an ``assert`` statement so that
        # the check is still performed when Python runs with ``-O``.
        raise AssertionError(
            '%slist lengths do not match: %d != %d\n%s' % (
                _fmt_msg(msg),
                result_len,
                expected_len,
                _fmt_path(path),
            )
        )

    for n, (resultv, expectedv) in enumerate(zip(result, expected)):
        assert_equal(
            resultv,
            expectedv,
            path=path + ('[%d]' % n,),
            msg=msg,
            **kwargs
        )
@assert_equal.register(np.ndarray, np.ndarray)
def assert_array_equal(result,
                       expected,
                       path=(),
                       msg='',
                       array_verbose=True,
                       array_decimal=None,
                       **kwargs):
    """Assert that two numpy arrays are equal.

    Parameters
    ----------
    result : np.ndarray
        The result that came from the function under test.
    expected : np.ndarray
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is
        appended to the error message.
    msg : str, optional
        Passed to the numpy comparison as ``err_msg``.
    array_verbose : bool, optional
        Passed to the numpy comparison as ``verbose``.
    array_decimal : int, optional
        When ``None`` the arrays are compared with
        ``np.testing.assert_array_equal``; otherwise they are compared
        with ``np.testing.assert_array_almost_equal`` using this value as
        ``decimal``.
    **kwargs
        Not used by this implementation.

    Raises
    ------
    AssertionError
        Raised when the numpy comparison fails.
    """
    if array_decimal is None:
        compare = np.testing.assert_array_equal
    else:
        compare = partial(
            np.testing.assert_array_almost_equal,
            decimal=array_decimal,
        )

    try:
        compare(result, expected, verbose=array_verbose, err_msg=msg)
    except AssertionError as exc:
        # Re-raise with the path appended to numpy's own report.
        raise AssertionError(str(exc) + '\n' + _fmt_path(path))
@assert_equal.register(pd.DataFrame, pd.DataFrame)
def assert_dataframe_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two pandas DataFrames are equal.

    The comparison is delegated to pandas' ``assert_frame_equal``.

    Parameters
    ----------
    result : pd.DataFrame
        The result that came from the function under test.
    expected : pd.DataFrame
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is
        appended to the error message.
    msg : str, optional
        Extra context to put at the start of the error message.
    **kwargs
        Only the entries that ``assert_frame_equal`` accepts by name are
        forwarded to it; the rest are dropped.

    Raises
    ------
    AssertionError
        Raised when ``assert_frame_equal`` fails.
    """
    try:
        frame_kwargs = filter_kwargs(assert_frame_equal, kwargs)
        assert_frame_equal(result, expected, **frame_kwargs)
    except AssertionError as exc:
        # Prefix the caller's context and append the path to pandas' report.
        raise AssertionError(
            _fmt_msg(msg) + str(exc) + '\n' + _fmt_path(path),
        )
@assert_equal.register(Adjustment, Adjustment)
def assert_adjustment_equal(result, expected, path=(), **kwargs):
    """Assert that two adjustments are equal, attribute by attribute.

    Parameters
    ----------
    result : Adjustment
        The result that came from the function under test.
    expected : Adjustment
        The expected result.
    path : tuple of str, optional
        The path to ``result`` inside an enclosing structure; it is
        extended with the name of each attribute compared.
    **kwargs
        Forwarded to ``assert_equal`` for each attribute.

    Raises
    ------
    AssertionError
        Raised by ``assert_equal`` for the first attribute that differs.
    """
    compared_attrs = (
        'first_row',
        'last_row',
        'first_col',
        'last_col',
        'value',
    )
    for name in compared_attrs:
        result_value = getattr(result, name)
        expected_value = getattr(expected, name)
        assert_equal(
            result_value,
            expected_value,
            path=path + ('.' + name,),
            **kwargs
        )
# Optional integration with datashape: when it can be imported, reuse the
# comparison implementations registered on its ``assert_dshape_equal``
# (presumably a dispatcher like ``assert_equal`` -- confirm against
# datashape).
try:
    # pull the dshape cases in
    from datashape.util.testing import assert_dshape_equal
except ImportError:
    # datashape is not installed; keep only the implementations defined in
    # this module.
    pass
else:
    # Merge datashape's implementations into ``assert_equal``, leaving out
    # its ``(object, object)`` entry so that the generic fallback registered
    # in this module is not replaced.
    assert_equal.funcs.update(
        dissoc(assert_dshape_equal.funcs, (object, object)),
    )