diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 7cd064e3..1a34141e 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -1,5 +1,7 @@ from contextlib import contextmanager +from functools import wraps from itertools import ( + count, product, ) import operator @@ -12,9 +14,7 @@ from logbook import FileHandler from mock import patch from numpy.testing import assert_allclose, assert_array_equal import pandas as pd -from six import ( - itervalues, -) +from six import itervalues from six.moves import filter from sqlalchemy import create_engine @@ -402,3 +402,86 @@ class tmp_asset_finder(tmp_assets_db): """ def __enter__(self): return AssetFinder(super(tmp_asset_finder, self).__enter__()) + + +class SubTestFailures(AssertionError): + def __init__(self, *failures): + self.failures = failures + + def __str__(self): + return 'failures:\n %s' % '\n '.join( + '\n '.join(( + ', '.join('%s=%r' % item for item in scope.items()), + '%s: %s' % (type(exc).__name__, exc), + )) for scope, exc in self.failures, + ) + + +def subtest(iterator, *_names): + """Construct a subtest in a unittest. + + This works by decorating a function as a subtest. The test will be run + by iterating over the ``iterator`` and *unpacking the values into the + function. If any of the runs fail, the result will be put into a set and + the rest of the tests will be run. Finally, if any failed, all of the + results will be dumped as one failure. + + Paramaters + ---------- + iterator : iterable[iterable] + The iterator of arguments to pass to the function. + *name : iterator[str] + The names to use for each element of ``iterator``. These will be used + to print the scope when a test fails. If not provided, it will use the + integer index of the value as the name. + + Examples + -------- + + :: + + class MyTest(TestCase): + def test_thing(self): + # Example usage inside another test. + @subtest(([n] for n in range(100000)), 'n') + def subtest(n): + self.assertEqual(n % 2, 0, 'n what not even') + subtest() + + @subtest(([n] for n in range(100000)), 'n') + def test_decorated_function(self, n): + # Example usage to paramaterize an entire function. + self.assertEqual(n % 2, 1, 'n what not odd') + + Notes + ----- + We use this when we: + + * Will never want to run each parameter individually. + * Have a large parameter space we are testing + (see tests/utils/test_events.py). + + ``nose_paramaterized.expand`` will create a test for each parameter + combination which bloats the test output and makes the travis pages slow. + + We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and + nose2 do not support ``addSubTest``. + """ + def dec(f): + @wraps(f) + def wrapped(*args, **kwargs): + names = _names + failures = [] + for scope in iterator: + scope = tuple(scope) + try: + f(*args + scope, **kwargs) + except Exception as e: + if not names: + names = count() + failures.append((dict(zip(names, scope)), e)) + if failures: + raise SubTestFailures(*failures) + + return wrapped + return dec