mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 14:47:08 +08:00
BUG: RSI wasn't even close to working.
Fixed and added tests.
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
"""
|
||||
Tests for Factor terms.
|
||||
"""
|
||||
from numpy import array, eye, nan, ones
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from numpy import arange, array, empty, eye, nan, ones, datetime64
|
||||
from numpy.random import randn, seed
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.pipeline import Factor, Filter, TermGraph
|
||||
from zipline.utils.test_utils import check_arrays
|
||||
from zipline.pipeline.factors import RSI
|
||||
from zipline.utils.test_utils import check_allclose, check_arrays
|
||||
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
@@ -190,3 +195,30 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
)
|
||||
for method in results:
|
||||
check_arrays(expected[method], results[method])
|
||||
|
||||
@parameterized.expand([
|
||||
# Test cases computed by doing:
|
||||
# from numpy.random import seed, randn
|
||||
# from talib import RSI
|
||||
# seed(seed_value)
|
||||
# data = abs(randn(15, 3))
|
||||
# expected = [RSI(data[:, i])[-1] for i in range(3)]
|
||||
(100, array([41.032913785966, 51.553585468393, 51.022005016446])),
|
||||
(101, array([43.506969935466, 46.145367530182, 50.57407044197])),
|
||||
(102, array([46.610102205934, 47.646892444315, 52.13182788538])),
|
||||
])
|
||||
def test_rsi(self, seed_value, expected):
|
||||
|
||||
rsi = RSI()
|
||||
|
||||
today = datetime64(1, 'ns')
|
||||
assets = arange(3)
|
||||
out = empty((3,), dtype=float)
|
||||
|
||||
seed(seed_value) # Seed so we get deterministic results.
|
||||
test_data = abs(randn(15, 3))
|
||||
|
||||
out = empty((3,), dtype=float)
|
||||
rsi.compute(today, assets, out, test_data)
|
||||
|
||||
check_allclose(expected, out)
|
||||
|
||||
@@ -9,6 +9,7 @@ from bottleneck import (
|
||||
nansum,
|
||||
)
|
||||
from numpy import (
|
||||
abs,
|
||||
clip,
|
||||
diff,
|
||||
fmax,
|
||||
@@ -30,19 +31,19 @@ class RSI(CustomFactor, SingleInputMixin):
|
||||
|
||||
**Default Inputs**: [USEquityPricing.close]
|
||||
|
||||
**Default Window Length**: 14
|
||||
**Default Window Length**: 15
|
||||
"""
|
||||
window_length = 14
|
||||
window_length = 15
|
||||
inputs = (USEquityPricing.close,)
|
||||
|
||||
def compute(self, today, assets, out, closes):
|
||||
diffs = diff(closes)
|
||||
diffs = diff(closes, axis=0)
|
||||
ups = nanmean(clip(diffs, 0, inf), axis=0)
|
||||
downs = nanmean(clip(diffs, -inf, 0), axis=0)
|
||||
downs = abs(nanmean(clip(diffs, -inf, 0), axis=0))
|
||||
return evaluate(
|
||||
"100 - (100 / (1 + (ups / downs)))",
|
||||
locals_dict={'ups': ups, 'downs': downs},
|
||||
globals_dict={},
|
||||
local_dict={'ups': ups, 'downs': downs},
|
||||
global_dict={},
|
||||
out=out,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from itertools import (
|
||||
)
|
||||
from logbook import FileHandler
|
||||
from mock import patch
|
||||
from numpy.testing import assert_array_equal
|
||||
from numpy.testing import assert_allclose, assert_array_equal
|
||||
import operator
|
||||
from zipline.finance.blotter import ORDER_STATUS
|
||||
from zipline.utils import security_list
|
||||
@@ -310,18 +310,37 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None):
|
||||
)
|
||||
|
||||
|
||||
def check_arrays(left, right, err_msg='', verbose=True):
|
||||
def check_allclose(actual,
|
||||
desired,
|
||||
rtol=1e-07,
|
||||
atol=0,
|
||||
err_msg='',
|
||||
verbose=True):
|
||||
"""
|
||||
Wrapper around np.assert_array_equal that also verifies that inputs are
|
||||
ndarrays.
|
||||
Wrapper around np.testing.assert_allclose that also verifies that inputs
|
||||
are ndarrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
np.assert_allclose
|
||||
"""
|
||||
if type(actual) != type(desired):
|
||||
raise AssertionError("%s != %s" % (type(actual), type(desired)))
|
||||
return assert_allclose(actual, desired, err_msg=err_msg, verbose=True)
|
||||
|
||||
|
||||
def check_arrays(x, y, err_msg='', verbose=True):
|
||||
"""
|
||||
Wrapper around np.testing.assert_array_equal that also verifies that inputs
|
||||
are ndarrays.
|
||||
|
||||
See Also
|
||||
--------
|
||||
np.assert_array_equal
|
||||
"""
|
||||
if type(left) != type(right):
|
||||
raise AssertionError("%s != %s" % (type(left), type(right)))
|
||||
return assert_array_equal(left, right, err_msg=err_msg, verbose=True)
|
||||
if type(x) != type(y):
|
||||
raise AssertionError("%s != %s" % (type(x), type(y)))
|
||||
return assert_array_equal(x, y, err_msg=err_msg, verbose=True)
|
||||
|
||||
|
||||
class UnexpectedAttributeAccess(Exception):
|
||||
|
||||
Reference in New Issue
Block a user