Files
catalyst/tests/test_labelarray.py
T
Scott Sanderson 5f190395ad ENH: Add support for strings in Pipeline.
- Adds a new class, ``LabelArray``, which is a subclass of np.ndarray.
  LabelArray is conceptually similar to pandas.Categorical, in that it
  stores data with many duplicate values as indices into an array of
  unique values.  For string data with many duplicates (e.g. time-series
  of tickers or industry classifications), this provides multiple
  orders of magnitude of improvement when doing string operations,
  especially string comparison/matching operations.

- Adds a new generic object "specialization" for `AdjustedArrayWindow`,
  and a corresponding ObjectOverwrite adjustment.

- Adds a new ``postprocess`` method to ``zipline.pipeline.term.Term``.
  This method is called on the final result of any pipeline expression
  after screen filtering has occurred. The default implementation of
  ``postprocess`` is identity, but Classifier overrides it to coerce
  string columns into pandas.Categoricals before presenting them to the
  user.
2016-05-04 15:50:52 -04:00

152 lines
5.3 KiB
Python

from itertools import product
from operator import eq, ne
import numpy as np
from zipline.lib.labelarray import LabelArray
from zipline.testing import check_arrays, parameter_space, ZiplineTestCase
def rotN(l, N):
    """
    Rotate a list of elements to the left.

    Pulls the first N elements off the front of the list and appends them to
    the end.  A new list is returned; ``l`` itself is not modified.

    >>> rotN(['a', 'b', 'c', 'd'], 2)
    ['c', 'd', 'a', 'b']
    >>> rotN(['a', 'b', 'c', 'd'], 3)
    ['d', 'a', 'b', 'c']

    Raises AssertionError if N is greater than the length of the list.
    """
    if len(l) < N:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O.  The exception type and message
        # are the same as before.
        raise AssertionError("Can't rotate list by longer than its length.")
    return l[N:] + l[:N]
class LabelArrayTestCase(ZiplineTestCase):
    """
    Tests for ``LabelArray``: direct-construction failure, comparison against
    scalar strings and string arrays, slicing, and category inference.
    """

    @classmethod
    def init_class_fixtures(cls):
        # Shared fixtures, built once per class: ``rowvalues`` is one row of
        # labels (with duplicates), and ``strs`` stacks that row rotated by
        # 0, 1 and 2 positions.
        super(LabelArrayTestCase, cls).init_class_fixtures()

        base_row = ['', 'a', 'b', 'ab', 'a', '', 'b', 'ab', 'z']
        cls.rowvalues = base_row
        cls.strs = np.array([rotN(base_row, shift) for shift in range(3)])

    def test_fail_on_direct_construction(self):
        # See http://docs.scipy.org/doc/numpy-1.10.0/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray # noqa
        with self.assertRaises(TypeError) as raised:
            np.ndarray.__new__(LabelArray, (5, 5))

        expected = "Direct construction of LabelArrays is not supported."
        self.assertEqual(str(raised.exception), expected)

    @parameter_space(
        __fail_fast=True,
        s=['', 'a', 'z', 'aa', 'not in the array'],
        shape=[(27,), (9, 3), (3, 9), (3, 3, 3)],
        array_astype=(bytes, unicode, object),
        scalar_astype=(bytes, unicode, object),
    )
    def test_compare_to_str(self, s, shape, array_astype, scalar_astype):
        # Every scalar comparison on a LabelArray must agree with the same
        # comparison done on the plain string array it was built from.
        #
        # NOTE(review): ``scalar_astype`` is part of the parameter space but
        # is never applied to ``s`` in this body -- confirm whether that is
        # intentional.
        raw = self.strs.reshape(shape).astype(array_astype)
        labels = LabelArray(raw, missing_value='')

        check_arrays(raw == s, labels == s)
        check_arrays(raw != s, labels != s)

        # Each LabelArray string method is paired with an element-wise
        # reference predicate evaluated over the raw strings.
        string_checks = (
            ('startswith', lambda elem: elem.startswith(s)),
            ('endswith', lambda elem: elem.endswith(s)),
            ('contains', lambda elem: s in elem),
        )
        for method_name, reference in string_checks:
            check_arrays(
                getattr(labels, method_name)(s),
                np.vectorize(reference)(raw),
            )

    def test_compare_to_str_array(self):
        # Comparing a LabelArray against string arrays (full-shape and
        # broadcastable) must match comparing the raw strings to the scalar.
        strs = self.strs
        nrows, ncols = strs.shape[0], strs.shape[1]
        labels = LabelArray(strs, missing_value='')

        check_arrays(strs == labels, np.full_like(strs, True, dtype=bool))
        check_arrays(strs != labels, np.full_like(strs, False, dtype=bool))

        # NOTE(review): both helpers accept ``dtype`` but build the array
        # with ``strs.dtype``, so the dtype axis of the product below does
        # not change the operands -- confirm whether that is intentional.
        def one_per_row(value, dtype):
            return np.full((nrows, 1), value, dtype=strs.dtype)

        def one_per_col(value, dtype):
            return np.full((1, ncols), value, dtype=strs.dtype)

        cases = product(
            (eq, ne),
            (bytes, unicode, object),
            set(self.rowvalues),
        )
        for comparator, dtype, value in cases:
            # Same shape as the LabelArray.
            check_arrays(
                comparator(labels, np.full_like(strs, value)),
                comparator(strs, value),
            )
            # Shape (nrows, 1), broadcast across columns.
            check_arrays(
                comparator(labels, one_per_row(value, dtype=dtype)),
                comparator(strs, value),
            )
            # Shape (1, ncols), broadcast across rows.
            check_arrays(
                comparator(labels, one_per_col(value, dtype=dtype)),
                comparator(strs, value),
            )

    @parameter_space(
        __fail_fast=True,
        slice_=[
            0, 1, -1,
            slice(None),
            slice(0, 0),
            slice(0, 3),
            slice(1, 4),
            slice(0),
            slice(None, 1),
            slice(0, 4, 2),
            (slice(None), 1),
            (slice(None), slice(None)),
            (slice(None), slice(1, 2)),
        ]
    )
    def test_slicing_preserves_attributes(self, slice_):
        # Indexing must return a LabelArray that shares (by identity) the
        # parent's categories, reverse_categories and missing_value.
        parent = LabelArray(self.strs.reshape((9, 3)), missing_value='')
        child = parent[slice_]

        self.assertIsInstance(child, LabelArray)
        for attr in ('categories', 'reverse_categories', 'missing_value'):
            self.assertIs(getattr(child, attr), getattr(parent, attr))

    def test_infer_categories(self):
        # Categories are inferred from the input strings and the integer
        # codes index into them, for the original shape and reshaped copies.
        labels = LabelArray(self.strs, missing_value='')
        codes = labels.as_int_array()
        self.assertEqual(labels.shape, self.strs.shape)
        self.assertEqual(labels.shape, codes.shape)

        categories = labels.categories
        unique_rowvalues = set(self.rowvalues)

        # There should be an entry in categories for each unique row value,
        # and each integer stored in the data array should be an index into
        # categories.
        self.assertEqual(list(categories), sorted(unique_rowvalues))
        self.assertEqual(
            set(codes.ravel()),
            set(range(len(unique_rowvalues))),
        )
        for code, label in enumerate(categories):
            check_arrays(
                self.strs == label,
                labels.view(type=np.ndarray) == code,
            )

        for shape in ((9, 3), (3, 9), (3, 3, 3)):
            reshaped_strs = self.strs.reshape(shape)
            reshaped_labels = LabelArray(reshaped_strs, missing_value='')
            reshaped_codes = reshaped_labels.as_int_array()

            self.assertEqual(reshaped_labels.shape, shape)
            check_arrays(reshaped_labels.categories, categories)

            for code, label in enumerate(reshaped_labels.categories):
                check_arrays(reshaped_strs == label, reshaped_codes == code)