diff --git a/tests/pipeline/test_factor.py b/tests/pipeline/test_factor.py index d43dfc7a..53de9839 100644 --- a/tests/pipeline/test_factor.py +++ b/tests/pipeline/test_factor.py @@ -1,10 +1,15 @@ """ Tests for Factor terms. """ -from numpy import array, eye, nan, ones +from nose_parameterized import parameterized + +from numpy import arange, array, empty, eye, nan, ones, datetime64 +from numpy.random import randn, seed + from zipline.errors import UnknownRankMethod from zipline.pipeline import Factor, Filter, TermGraph -from zipline.utils.test_utils import check_arrays +from zipline.pipeline.factors import RSI +from zipline.utils.test_utils import check_allclose, check_arrays from .base import BasePipelineTestCase @@ -190,3 +195,30 @@ class FactorTestCase(BasePipelineTestCase): ) for method in results: check_arrays(expected[method], results[method]) + + @parameterized.expand([ + # Test cases computed by doing: + # from numpy.random import seed, randn + # from talib import RSI + # seed(seed_value) + # data = abs(randn(15, 3)) + # expected = [RSI(data[:, i])[-1] for i in range(3)] + (100, array([41.032913785966, 51.553585468393, 51.022005016446])), + (101, array([43.506969935466, 46.145367530182, 50.57407044197])), + (102, array([46.610102205934, 47.646892444315, 52.13182788538])), + ]) + def test_rsi(self, seed_value, expected): + + rsi = RSI() + + today = datetime64(1, 'ns') + assets = arange(3) + out = empty((3,), dtype=float) + + seed(seed_value) # Seed so we get deterministic results. + test_data = abs(randn(15, 3)) + + out = empty((3,), dtype=float) + rsi.compute(today, assets, out, test_data) + + check_allclose(expected, out) diff --git a/zipline/pipeline/factors/technical.py b/zipline/pipeline/factors/technical.py index 32d2d7a2..ba58dd58 100644 --- a/zipline/pipeline/factors/technical.py +++ b/zipline/pipeline/factors/technical.py @@ -9,6 +9,7 @@ from bottleneck import ( nansum, ) from numpy import ( + abs, clip, diff, fmax, @@ -30,19 +31,19 @@ class RSI(CustomFactor, SingleInputMixin): **Default Inputs**: [USEquityPricing.close] - **Default Window Length**: 14 + **Default Window Length**: 15 """ - window_length = 14 + window_length = 15 inputs = (USEquityPricing.close,) def compute(self, today, assets, out, closes): - diffs = diff(closes) + diffs = diff(closes, axis=0) ups = nanmean(clip(diffs, 0, inf), axis=0) - downs = nanmean(clip(diffs, -inf, 0), axis=0) + downs = abs(nanmean(clip(diffs, -inf, 0), axis=0)) return evaluate( "100 - (100 / (1 + (ups / downs)))", - locals_dict={'ups': ups, 'downs': downs}, - globals_dict={}, + local_dict={'ups': ups, 'downs': downs}, + global_dict={}, out=out, ) diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index e444ff09..40b164a6 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -4,7 +4,7 @@ from itertools import ( ) from logbook import FileHandler from mock import patch -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal import operator from zipline.finance.blotter import ORDER_STATUS from zipline.utils import security_list @@ -310,18 +310,37 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None): ) -def check_arrays(left, right, err_msg='', verbose=True): +def check_allclose(actual, + desired, + rtol=1e-07, + atol=0, + err_msg='', + verbose=True): """ - Wrapper around np.assert_array_equal that also verifies that inputs are - ndarrays. + Wrapper around np.testing.assert_allclose that also verifies that inputs + are ndarrays. + + See Also + -------- + np.assert_allclose + """ + if type(actual) != type(desired): + raise AssertionError("%s != %s" % (type(actual), type(desired))) + return assert_allclose(actual, desired, err_msg=err_msg, verbose=True) + + +def check_arrays(x, y, err_msg='', verbose=True): + """ + Wrapper around np.testing.assert_array_equal that also verifies that inputs + are ndarrays. See Also -------- np.assert_array_equal """ - if type(left) != type(right): - raise AssertionError("%s != %s" % (type(left), type(right))) - return assert_array_equal(left, right, err_msg=err_msg, verbose=True) + if type(x) != type(y): + raise AssertionError("%s != %s" % (type(x), type(y))) + return assert_array_equal(x, y, err_msg=err_msg, verbose=True) class UnexpectedAttributeAccess(Exception):