DEV: Add parameter_space test decorator.

This commit is contained in:
Scott Sanderson
2016-01-20 22:03:35 -05:00
parent 9c448b5238
commit 94c02c710b
2 changed files with 82 additions and 0 deletions
+31
View File
@@ -0,0 +1,31 @@
"""
Tests for our testing utilities.
"""
from itertools import product
from unittest import TestCase
from zipline.utils.test_utils import parameter_space
class TestParameterSpace(TestCase):
x_args = [1, 2]
y_args = [3, 4]
@classmethod
def setUpClass(cls):
cls.xy_invocations = []
@classmethod
def tearDownClass(cls):
# This is the only actual test here.
assert cls.xy_invocations == list(product(cls.x_args, cls.y_args))
@parameter_space(x=x_args, y=y_args)
def test_xy(self, x, y):
self.xy_invocations.append((x, y))
def test_nothing(self):
# Ensure that there's at least one "real" test in the class, or else
# our {setUp,tearDown}Class won't be called if, for example,
# `parameter_space` returns None.
pass
+51
View File
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from functools import wraps
from inspect import getargspec
from itertools import (
combinations,
count,
@@ -722,3 +723,53 @@ def temp_pipeline_engine(calendar, sids, random_seed, symbols=None):
with tmp_asset_finder(equities=equity_info) as finder:
yield SimplePipelineEngine(get_loader, calendar, finder)
def parameter_space(**params):
"""
Wrapper around subtest that allows passing keywords mapping names to
iterables of values.
The decorated test function will be called with the cross-product of all
possible inputs
Usage
-----
>>> class SomeTestCase(TestCase):
... @parameter_space(x=[1, 2], y=[2, 3])
... def test_some_func(self, x, y):
... # Will be called with every possible combination of x and y.
... self.assertEqual(somefunc(x, y), expected_result(x, y))
"""
def decorator(f):
argspec = getargspec(f)
if argspec.varargs:
raise AssertionError("parameter_space() doesn't support *args")
if argspec.keywords:
raise AssertionError("parameter_space() doesn't support **kwargs")
if argspec.defaults:
raise AssertionError("parameter_space() doesn't support defaults.")
# Skip over implicit self.
argnames = argspec.args
if argnames[0] == 'self':
argnames = argnames[1:]
extra = set(params) - set(argnames)
if extra:
raise AssertionError(
"Keywords %s supplied to parameter_space() are "
"not in function signature." % extra
)
unspecified = set(argnames) - set(params)
if unspecified:
raise AssertionError(
"Function arguments %s were not "
"supplied to parameter_space()." % extra
)
param_sets = product(*(params[name] for name in argnames))
return subtest(param_sets, *argnames)(f)
return decorator