mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 02:29:54 +08:00
59c8e371a2
Adds the data bundle concept, which makes it easy for users to register loading functions that build out minute and daily data along with an assets db and an adjustments db. By default we provide a `quandl` bundle, which pulls from the public-domain WIKI dataset. Users may register new bundles by decorating an ingest function with `zipline.data.bundles.register(<name>)`. This also provides a `yahoo_equities` function for creating an ingestion function that loads a static set of assets from Yahoo. The CLI is now structured as a set of subcommands and is invoked with `python -m zipline`. The old behavior of `run_algo.py` has been moved to the `run` subcommand. This is almost entirely the same, except that it now takes the name of the data bundle to use, defaulting to `quandl`. The next subcommand is `ingest`, which takes the name of a data bundle to ingest. This runs the loading machinery and writes the data to a specified location where `run` can find it. There is also a `clean` subcommand, which deletes the data that was written by `ingest`. Extensions have also been added to zipline. This is an experimental feature where users can provide an extra set of Python files to run at the start of the process. These can be used to configure aspects of zipline. Right now, the only thing supported in an extension file is the registration of a new data bundle.
323 lines
8.2 KiB
Python
323 lines
8.2 KiB
Python
from functools import partial
|
|
import inspect
|
|
|
|
from nose.tools import ( # noqa
|
|
assert_almost_equal,
|
|
assert_almost_equals,
|
|
assert_dict_contains_subset,
|
|
assert_false,
|
|
assert_greater,
|
|
assert_greater_equal,
|
|
assert_in,
|
|
assert_is,
|
|
assert_is_instance,
|
|
assert_is_none,
|
|
assert_is_not,
|
|
assert_is_not_none,
|
|
assert_less,
|
|
assert_less_equal,
|
|
assert_multi_line_equal,
|
|
assert_not_almost_equal,
|
|
assert_not_almost_equals,
|
|
assert_not_equal,
|
|
assert_not_equals,
|
|
assert_not_in,
|
|
assert_not_is_instance,
|
|
assert_raises,
|
|
assert_raises_regexp,
|
|
assert_regexp_matches,
|
|
assert_sequence_equal,
|
|
assert_set_equal,
|
|
assert_true,
|
|
assert_tuple_equal,
|
|
)
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pandas.util.testing import assert_frame_equal
|
|
from six import iteritems, viewkeys, PY2
|
|
from toolz import dissoc, keyfilter
|
|
import toolz.curried.operator as op
|
|
|
|
from zipline.dispatch import dispatch
|
|
from zipline.lib.adjustment import Adjustment
|
|
from zipline.utils.functional import dzip_exact
|
|
from zipline.utils.math_utils import tolerant_equals
|
|
|
|
|
|
def keywords(func):
    """Get the argument names of a function.

    >>> def f(x, y=2):
    ...     pass

    >>> keywords(f)
    ['x', 'y']

    Parameters
    ----------
    func : callable or type
        The function to inspect. If ``func`` is a class, the argument names
        of its ``__init__`` are returned (this includes ``self``).

    Returns
    -------
    names : list of str
        The positional-or-keyword argument names, followed by any
        keyword-only argument names. The names bound to ``*args`` and
        ``**kwargs`` are not included.

    Notes
    -----
    Taken from odo.utils.

    ``inspect.getargspec`` was removed in Python 3.11. On earlier Python 3
    versions it raises ``ValueError`` for functions with annotations or
    keyword-only arguments. ``inspect.getfullargspec`` is therefore used
    whenever it exists, and ``getargspec`` is only used on Python 2.
    """
    if isinstance(func, type):
        return keywords(func.__init__)

    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
        # Python 2 only provides ``getargspec``.
        return inspect.getargspec(func).args

    spec = getfullargspec(func)
    # Keyword-only arguments are valid keyword names too.
    return spec.args + spec.kwonlyargs
|
|
|
|
|
|
def filter_kwargs(f, kwargs):
    """Select the entries of ``kwargs`` whose names ``f`` accepts.

    Parameters
    ----------
    f : callable
        The function whose argument names (as reported by ``keywords``)
        are used as the filter.
    kwargs : dict
        The candidate keyword arguments.

    Returns
    -------
    filtered : dict
        A new dict holding only the items of ``kwargs`` whose key is an
        argument name of ``f``. ``kwargs`` itself is not modified.

    Examples
    --------
    >>> def f(a, b=1, c=2):
    ...     return a + b + c
    ...
    >>> raw_kwargs = dict(a=1, b=3, d=4)
    >>> f(**raw_kwargs)
    Traceback (most recent call last):
       ...
    TypeError: f() got an unexpected keyword argument 'd'
    >>> kwargs = filter_kwargs(f, raw_kwargs)
    >>> f(**kwargs)
    6

    Notes
    -----
    Taken from odo.utils
    """
    accepted = frozenset(keywords(f))
    return {
        name: value
        for name, value in kwargs.items()
        if name in accepted
    }
|
|
|
|
|
|
def _s(word, seq, suffix='s'):
|
|
"""Adds a suffix to ``word`` if some sequence has anything other than
|
|
exactly one element.
|
|
|
|
word : str
|
|
The string to add the suffix to.
|
|
seq : sequence
|
|
The sequence to check the length of.
|
|
suffix : str, optional.
|
|
The suffix to add to ``word``
|
|
|
|
Returns
|
|
-------
|
|
maybe_plural : str
|
|
``word`` with ``suffix`` added if ``len(seq) != 1``.
|
|
"""
|
|
return word + (suffix if len(seq) != 1 else '')
|
|
|
|
|
|
def _fmt_path(path):
|
|
"""Format the path for final display.
|
|
|
|
Parameters
|
|
----------
|
|
path : iterable of str
|
|
The path to the values that are not equal.
|
|
|
|
Returns
|
|
-------
|
|
fmtd : str
|
|
The formatted path to put into the error message.
|
|
"""
|
|
if not path:
|
|
return ''
|
|
return 'path: _' + ''.join(path)
|
|
|
|
|
|
def _fmt_msg(msg):
|
|
"""Format the message for final display.
|
|
|
|
Parameters
|
|
----------
|
|
msg : str
|
|
The message to show to the user to provide additional context.
|
|
|
|
returns
|
|
-------
|
|
fmtd : str
|
|
The formatted message to put into the error message.
|
|
"""
|
|
if not msg:
|
|
return ''
|
|
return msg + '\n'
|
|
|
|
|
|
@dispatch(object, object)
def assert_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two objects are equal using the ``==`` operator.

    This is the generic fallback; type-specific implementations are added
    with ``assert_equal.register``.

    Parameters
    ----------
    result : object
        The result that came from the function under test.
    expected : object
        The expected result.
    path : tuple of str, optional
        Accessor fragments locating these values inside the top-level
        objects being compared; rendered into the failure message.
    msg : str, optional
        Extra context prepended to the failure message.
    **kwargs
        Ignored here. Accepted so that options meant for other registered
        implementations can be passed through uniformly.

    Raises
    ------
    AssertionError
        Raised when ``result`` is not equal to ``expected``.
    """
    # Raise explicitly instead of using an ``assert`` statement: ``assert``
    # is stripped when Python runs with ``-O``, which would make this helper
    # silently pass.
    if not result == expected:
        raise AssertionError('%s%s != %s\n%s' % (
            _fmt_msg(msg),
            result,
            expected,
            _fmt_path(path),
        ))
|
|
|
|
|
|
@assert_equal.register(float, float)
def assert_float_equal(result,
                       expected,
                       path=(),
                       msg='',
                       float_rtol=10e-7,
                       float_atol=10e-7,
                       float_equal_nan=True,
                       **kwargs):
    """Assert that two floats are equal within a tolerance.

    Parameters
    ----------
    result : float
        The result that came from the function under test.
    expected : float
        The expected result.
    path : tuple of str, optional
        Accessor fragments locating these values inside the top-level
        objects being compared; rendered into the failure message.
    msg : str, optional
        Extra context prepended to the failure message.
    float_rtol : float, optional
        Relative tolerance, forwarded to ``tolerant_equals`` as ``rtol``.
        Note that the default ``10e-7`` is ``1e-6``.
    float_atol : float, optional
        Absolute tolerance, forwarded to ``tolerant_equals`` as ``atol``.
        Note that the default ``10e-7`` is ``1e-6``.
    float_equal_nan : bool, optional
        Forwarded to ``tolerant_equals`` as ``equal_nan``. When False, the
        failure message notes that NaN is not considered equal to NaN.
    **kwargs
        Ignored. Accepted so that options meant for other registered
        implementations can be passed through uniformly.

    Raises
    ------
    AssertionError
        Raised when ``tolerant_equals`` reports the values as unequal.
    """
    # Raise explicitly instead of using an ``assert`` statement so the check
    # still runs under ``python -O``.
    if not tolerant_equals(
        result,
        expected,
        rtol=float_rtol,
        atol=float_atol,
        equal_nan=float_equal_nan,
    ):
        raise AssertionError(
            '%s%s != %s with rtol=%s and atol=%s%s\n%s' % (
                _fmt_msg(msg),
                result,
                expected,
                float_rtol,
                float_atol,
                (' (with nan != nan)' if not float_equal_nan else ''),
                _fmt_path(path),
            ),
        )
|
|
|
|
|
|
@assert_equal.register(dict, dict)
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two dicts have the same keys and equal values.

    The key sets are compared first. If they differ, an ``AssertionError``
    describing the mismatch is raised immediately. Otherwise every pair of
    values is compared with ``assert_equal``, and all value failures are
    collected and reported together.

    Parameters
    ----------
    result : dict
        The result that came from the function under test.
    expected : dict
        The expected result.
    path : tuple of str, optional
        Accessor fragments locating these dicts inside the top-level objects
        being compared. ``None`` is treated as ``()``.
    msg : str, optional
        Extra context prepended to the failure message and passed down to
        the per-value comparisons.
    **kwargs
        Passed through to ``assert_equal`` for each value pair.

    Raises
    ------
    AssertionError
        Raised when the keys differ or when any value pair is not equal.
    """
    if path is None:
        path = ()

    result_keys = viewkeys(result)
    expected_keys = viewkeys(expected)
    if result_keys != expected_keys:
        # Build the key-mismatch description in its own variable so the
        # caller's ``msg`` is kept as context instead of being overwritten.
        if result_keys > expected_keys:
            diff = result_keys - expected_keys
            key_msg = 'extra %s in result: %r' % (_s('key', diff), diff)
        elif result_keys < expected_keys:
            diff = expected_keys - result_keys
            key_msg = 'result is missing %s: %r' % (_s('key', diff), diff)
        else:
            sym = result_keys ^ expected_keys
            in_result = sym - expected_keys
            in_expected = sym - result_keys
            key_msg = '%s only in result: %s\n%s only in expected: %s' % (
                _s('key', in_result),
                in_result,
                _s('key', in_expected),
                in_expected,
            )
        raise AssertionError(
            '%s%sdict keys do not match\n%s' % (
                _fmt_msg(msg),
                _fmt_msg(key_msg),
                _fmt_path(path + ('.%s()' % ('viewkeys' if PY2 else 'keys'),)),
            ),
        )

    failures = []
    for k, (resultv, expectedv) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(
                resultv,
                expectedv,
                # Wrap ``k`` in a 1-tuple: a bare tuple key such as (1, 2)
                # would otherwise be unpacked by ``%`` and raise TypeError.
                path=path + ('[%r]' % (k,),),
                msg=msg,
                **kwargs
            )
        except AssertionError as e:
            failures.append(str(e))

    if failures:
        raise AssertionError('\n'.join(failures))
|
|
|
|
|
|
@assert_equal.register(list, list)  # noqa
def assert_list_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two lists have the same length and equal elements.

    Elements are compared pairwise with ``assert_equal``. The first
    mismatching element raises, and later elements are not checked.

    Parameters
    ----------
    result : list
        The result that came from the function under test.
    expected : list
        The expected result.
    path : tuple of str, optional
        Accessor fragments locating these lists inside the top-level objects
        being compared; element indices are appended as ``'[n]'``.
    msg : str, optional
        Extra context prepended to the failure message and passed down to
        the per-element comparisons.
    **kwargs
        Passed through to ``assert_equal`` for each element pair.

    Raises
    ------
    AssertionError
        Raised when the lengths differ or when an element pair is not equal.
    """
    result_len = len(result)
    expected_len = len(expected)
    # Raise explicitly instead of using an ``assert`` statement so the check
    # still runs under ``python -O``.
    if result_len != expected_len:
        raise AssertionError(
            '%slist lengths do not match: %d != %d\n%s' % (
                _fmt_msg(msg),
                result_len,
                expected_len,
                _fmt_path(path),
            ),
        )

    for n, (resultv, expectedv) in enumerate(zip(result, expected)):
        assert_equal(
            resultv,
            expectedv,
            path=path + ('[%d]' % n,),
            msg=msg,
            **kwargs
        )
|
|
|
|
|
|
@assert_equal.register(np.ndarray, np.ndarray)
def assert_array_equal(result,
                       expected,
                       path=(),
                       msg='',
                       array_verbose=True,
                       array_decimal=None,
                       **kwargs):
    """Assert that two numpy arrays are equal, delegating to numpy.testing.

    Parameters
    ----------
    result : np.ndarray
        The array produced by the code under test.
    expected : np.ndarray
        The expected array.
    path : tuple of str, optional
        Accessor fragments locating these arrays inside the top-level objects
        being compared; appended to the failure message.
    msg : str, optional
        Forwarded to numpy as ``err_msg``.
    array_verbose : bool, optional
        Forwarded to numpy as ``verbose``.
    array_decimal : int, optional
        When ``None`` (the default), the arrays are checked with
        ``np.testing.assert_array_equal``. Otherwise they are checked with
        ``np.testing.assert_array_almost_equal`` using this many decimals.
    **kwargs
        Ignored. Accepted so that options meant for other registered
        implementations can be passed through uniformly.

    Raises
    ------
    AssertionError
        Raised with numpy's message followed by the formatted ``path``.
    """
    if array_decimal is None:
        checker = np.testing.assert_array_equal
    else:
        checker = partial(
            np.testing.assert_array_almost_equal,
            decimal=array_decimal,
        )

    try:
        checker(result, expected, verbose=array_verbose, err_msg=msg)
    except AssertionError as exc:
        raise AssertionError(str(exc) + '\n' + _fmt_path(path))
|
|
|
|
|
|
@assert_equal.register(pd.DataFrame, pd.DataFrame)
def assert_dataframe_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two DataFrames are equal using pandas' ``assert_frame_equal``.

    Only the entries of ``kwargs`` whose names are arguments of
    ``assert_frame_equal`` are forwarded to it; the rest are dropped.

    Parameters
    ----------
    result : pd.DataFrame
        The frame produced by the code under test.
    expected : pd.DataFrame
        The expected frame.
    path : tuple of str, optional
        Accessor fragments locating these frames inside the top-level objects
        being compared; appended to the failure message.
    msg : str, optional
        Extra context prepended to the failure message.
    **kwargs
        Candidate options for ``assert_frame_equal``.

    Raises
    ------
    AssertionError
        Raised with ``msg``, pandas' message and the formatted ``path``.
    """
    try:
        frame_options = filter_kwargs(assert_frame_equal, kwargs)
        assert_frame_equal(result, expected, **frame_options)
    except AssertionError as exc:
        raise AssertionError(
            '%s%s\n%s' % (_fmt_msg(msg), str(exc), _fmt_path(path)),
        )
|
|
|
|
|
|
@assert_equal.register(Adjustment, Adjustment)
def assert_adjustment_equal(result, expected, path=(), **kwargs):
    """Assert that two ``Adjustment`` objects are equal.

    The concrete types are compared first. Then the ``first_row``,
    ``last_row``, ``first_col``, ``last_col`` and ``value`` attributes are
    compared in that order with ``assert_equal``, stopping at the first
    mismatch.

    Parameters
    ----------
    result : Adjustment
        The adjustment produced by the code under test.
    expected : Adjustment
        The expected adjustment.
    path : tuple of str, optional
        Accessor fragments locating these adjustments inside the top-level
        objects being compared; ``'.__class__'`` or ``'.<attr>'`` is
        appended for each check.
    **kwargs
        Passed through to ``assert_equal`` (including ``msg``, if given).

    Raises
    ------
    AssertionError
        Raised when the types or any compared attribute differ.
    """
    # This implementation is registered on the base class, so subclasses
    # dispatch here too. Without a type check, two adjustments of different
    # subclasses with identical fields would be reported as equal.
    assert_equal(
        type(result),
        type(expected),
        path=path + ('.__class__',),
        **kwargs
    )
    for attr in ('first_row', 'last_row', 'first_col', 'last_col', 'value'):
        assert_equal(
            getattr(result, attr),
            getattr(expected, attr),
            path=path + ('.' + attr,),
            **kwargs
        )
|
|
|
|
|
|
try:
    # Pull in the datashape implementations of ``assert_equal`` when the
    # optional ``datashape`` package is installed.
    from datashape.util.testing import assert_dshape_equal
except ImportError:
    # Without datashape, only the implementations defined in this module
    # are available.
    pass
else:
    # Copy every datashape signature except the generic (object, object)
    # fallback, so this module's own fallback implementation is kept.
    assert_equal.funcs.update({
        signature: impl
        for signature, impl in assert_dshape_equal.funcs.items()
        if signature != (object, object)
    })
|