BUG: RSI wasn't even close to working.

Fixed and added tests.
This commit is contained in:
Scott Sanderson
2015-10-09 20:10:30 -04:00
parent 4a9cd76dab
commit 1336dfc181
3 changed files with 67 additions and 15 deletions
+34 -2
View File
@@ -1,10 +1,15 @@
"""
Tests for Factor terms.
"""
from numpy import array, eye, nan, ones
from nose_parameterized import parameterized
from numpy import arange, array, empty, eye, nan, ones, datetime64
from numpy.random import randn, seed
from zipline.errors import UnknownRankMethod
from zipline.pipeline import Factor, Filter, TermGraph
from zipline.utils.test_utils import check_arrays
from zipline.pipeline.factors import RSI
from zipline.utils.test_utils import check_allclose, check_arrays
from .base import BasePipelineTestCase
@@ -190,3 +195,30 @@ class FactorTestCase(BasePipelineTestCase):
)
for method in results:
check_arrays(expected[method], results[method])
@parameterized.expand([
# Test cases computed by doing:
# from numpy.random import seed, randn
# from talib import RSI
# seed(seed_value)
# data = abs(randn(15, 3))
# expected = [RSI(data[:, i])[-1] for i in range(3)]
(100, array([41.032913785966, 51.553585468393, 51.022005016446])),
(101, array([43.506969935466, 46.145367530182, 50.57407044197])),
(102, array([46.610102205934, 47.646892444315, 52.13182788538])),
])
def test_rsi(self, seed_value, expected):
rsi = RSI()
today = datetime64(1, 'ns')
assets = arange(3)
out = empty((3,), dtype=float)
seed(seed_value) # Seed so we get deterministic results.
test_data = abs(randn(15, 3))
out = empty((3,), dtype=float)
rsi.compute(today, assets, out, test_data)
check_allclose(expected, out)
+7 -6
View File
@@ -9,6 +9,7 @@ from bottleneck import (
nansum,
)
from numpy import (
abs,
clip,
diff,
fmax,
@@ -30,19 +31,19 @@ class RSI(CustomFactor, SingleInputMixin):
**Default Inputs**: [USEquityPricing.close]
**Default Window Length**: 14
**Default Window Length**: 15
"""
window_length = 14
window_length = 15
inputs = (USEquityPricing.close,)
def compute(self, today, assets, out, closes):
diffs = diff(closes)
diffs = diff(closes, axis=0)
ups = nanmean(clip(diffs, 0, inf), axis=0)
downs = nanmean(clip(diffs, -inf, 0), axis=0)
downs = abs(nanmean(clip(diffs, -inf, 0), axis=0))
return evaluate(
"100 - (100 / (1 + (ups / downs)))",
locals_dict={'ups': ups, 'downs': downs},
globals_dict={},
local_dict={'ups': ups, 'downs': downs},
global_dict={},
out=out,
)
+26 -7
View File
@@ -4,7 +4,7 @@ from itertools import (
)
from logbook import FileHandler
from mock import patch
from numpy.testing import assert_array_equal
from numpy.testing import assert_allclose, assert_array_equal
import operator
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils import security_list
@@ -310,18 +310,37 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None):
)
def check_arrays(left, right, err_msg='', verbose=True):
def check_allclose(actual,
desired,
rtol=1e-07,
atol=0,
err_msg='',
verbose=True):
"""
Wrapper around np.assert_array_equal that also verifies that inputs are
ndarrays.
Wrapper around np.testing.assert_allclose that also verifies that inputs
are ndarrays.
See Also
--------
np.assert_allclose
"""
if type(actual) != type(desired):
raise AssertionError("%s != %s" % (type(actual), type(desired)))
return assert_allclose(actual, desired, err_msg=err_msg, verbose=True)
def check_arrays(x, y, err_msg='', verbose=True):
"""
Wrapper around np.testing.assert_array_equal that also verifies that inputs
are ndarrays.
See Also
--------
np.assert_array_equal
"""
if type(left) != type(right):
raise AssertionError("%s != %s" % (type(left), type(right)))
return assert_array_equal(left, right, err_msg=err_msg, verbose=True)
if type(x) != type(y):
raise AssertionError("%s != %s" % (type(x), type(y)))
return assert_array_equal(x, y, err_msg=err_msg, verbose=True)
class UnexpectedAttributeAccess(Exception):