mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 01:09:56 +08:00
Merge pull request #905 from quantopian/refactor-adjusted-array
Adds support for different typed adjusted arrays and adds an EarningsCalendar loader
This commit is contained in:
@@ -45,6 +45,12 @@ Pipeline API
|
||||
.. autoclass:: zipline.pipeline.factors.RSI
|
||||
:members:
|
||||
|
||||
.. autoclass:: zipline.pipeline.factors.BusinessDaysUntilNextEarnings
|
||||
:members:
|
||||
|
||||
.. autoclass:: zipline.pipeline.factors.BusinessDaysSincePreviousEarnings
|
||||
:members:
|
||||
|
||||
.. autoclass:: zipline.pipeline.factors.SimpleMovingAverage
|
||||
:members:
|
||||
|
||||
@@ -58,6 +64,15 @@ Pipeline API
|
||||
:members: __and__, __or__
|
||||
:exclude-members: dtype
|
||||
|
||||
.. autoclass:: zipline.pipeline.data.EarningsCalendar
|
||||
:members: next_announcement, previous_announcement
|
||||
:undoc-members:
|
||||
|
||||
.. autoclass:: zipline.pipeline.data.USEquityPricing
|
||||
:members: open, high, low, close, volume
|
||||
:undoc-members:
|
||||
|
||||
|
||||
Asset Metadata
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -11,7 +11,12 @@ Development
|
||||
Highlights
|
||||
~~~~~~~~~~
|
||||
|
||||
None
|
||||
* Added a new :class:`~zipline.pipeline.data.EarningsCalendar` dataset
|
||||
for use in the Pipeline API. (:issue:`905`).
|
||||
|
||||
* :class:`~zipline.assets.assets.AssetFinder` speedups (:issue:`830` and
|
||||
:issue:`817`).
|
||||
|
||||
|
||||
Enhancements
|
||||
~~~~~~~~~~~~
|
||||
@@ -23,13 +28,30 @@ Enhancements
|
||||
passed to :class:`~zipline.algorithm.TradingAlgorithm` by the keyword argument
|
||||
``create_event_context`` (:issue:`828`).
|
||||
|
||||
* Added support for :class:`zipline.pipeline.factors.Factor` instances with
|
||||
``datetime64[ns]`` dtypes. (:issue:`905`)
|
||||
|
||||
* Added a new :class:`~zipline.pipeline.data.earnings.EarningsCalendar` dataset
|
||||
for use in the Pipeline API. This dataset provides an abstract interface for
|
||||
adding earnings announcement data to a new algorithm. A pandas-based
|
||||
reference implementation for this dataset can be found in
|
||||
:mod:`zipline.pipeline.loaders.earnings`, and an experimental blaze-based
|
||||
implementation can be found in
|
||||
:mod:`zipline.pipeline.loaders.blaze.earnings`. (:issue:`905`).
|
||||
|
||||
* Added new built-in factors,
|
||||
:class:`zipline.pipeline.factors.BusinessDaysUntilNextEarnings` and
|
||||
:class:`zipline.pipeline.factors.BusinessDaysSincePreviousEarnings`. These
|
||||
factors use the new ``EarningsCalendar`` dataset. (:issue:`905`).
|
||||
|
||||
* Added :meth:`~zipline.pipeline.factors.Factor.isnan`,
|
||||
:meth:`~zipline.pipeline.factors.Factor.notnan` and
|
||||
:meth:`~zipline.pipeline.factors.Factor.isfinite` methods to
|
||||
:class:`zipline.pipeline.factors.Factor` (:issue:`861`).
|
||||
|
||||
* Added :class:`zipline.pipeline.factors.Returns`, a built-in factor which
|
||||
calculates the percent change in close price over the given window_length.
|
||||
calculates the percent change in close price over the given
|
||||
window_length. (:issue:`884`).
|
||||
|
||||
Experimental Features
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -61,6 +83,10 @@ Performance
|
||||
:meth:`~zipline.assets.assets.AssetFinder.lookup_symbol` to these dictionaries
|
||||
to find matching equities (:issue:`830`).
|
||||
|
||||
* Improved performance of
|
||||
:meth:`~zipline.assets.assets.AssetFinder.lookup_symbol` by performing
|
||||
batched queries. (:issue:`817`).
|
||||
|
||||
Maintenance and Refactorings
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
-e git://github.com/quantopian/blaze.git@787e79850a5e1ed3072a15f3d6d1acb6308af051#egg=blaze-dev
|
||||
-e git://github.com/quantopian/odo.git@a79cef896ad13afe356fd104714d74a5019eb761#egg=odo-dev
|
||||
-e git://github.com/quantopian/datashape.git@8896ab63fee76100769404abb6b2676dc8cab3f1#egg=datashape-dev
|
||||
-e git://github.com/quantopian/blaze.git@43d2f7e00a228106cea038a53322497831539559#egg=blaze-dev
|
||||
-e git://github.com/quantopian/odo.git@4f7f45fb039d89ea101803b95da21fc055901d66#egg=odo-dev
|
||||
-e git://github.com/quantopian/datashape.git@9bd8fb970a0fc55e866a0b46b5101c9aa47e24ed#egg=datashape-dev
|
||||
|
||||
@@ -62,8 +62,10 @@ class LazyCythonizingList(list):
|
||||
|
||||
ext_modules = LazyCythonizingList([
|
||||
('zipline.assets._assets', ['zipline/assets/_assets.pyx']),
|
||||
('zipline.lib.adjusted_array', ['zipline/lib/adjusted_array.pyx']),
|
||||
('zipline.lib.adjustment', ['zipline/lib/adjustment.pyx']),
|
||||
('zipline.lib._float64window', ['zipline/lib/_float64window.pyx']),
|
||||
('zipline.lib._int64window', ['zipline/lib/_int64window.pyx']),
|
||||
('zipline.lib._uint8window', ['zipline/lib/_uint8window.pyx']),
|
||||
('zipline.lib.rank', ['zipline/lib/rank.pyx']),
|
||||
(
|
||||
'zipline.data._equities',
|
||||
|
||||
+21
-18
@@ -8,11 +8,14 @@ from numpy import arange, prod
|
||||
from pandas import date_range, Int64Index, DataFrame
|
||||
from six import iteritems
|
||||
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.term import AssetExists
|
||||
from zipline.utils.pandas_utils import explode
|
||||
from zipline.utils.test_utils import make_simple_equity_info, ExplodingObject
|
||||
from zipline.utils.test_utils import (
|
||||
ExplodingObject,
|
||||
make_simple_equity_info,
|
||||
tmp_asset_finder,
|
||||
)
|
||||
from zipline.utils.tradingcalendar import trading_day
|
||||
|
||||
|
||||
@@ -45,27 +48,27 @@ with_default_shape = with_defaults(shape=lambda self: self.default_shape)
|
||||
|
||||
class BasePipelineTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.__calendar = date_range('2014', '2015', freq=trading_day)
|
||||
self.__assets = assets = Int64Index(arange(1, 20))
|
||||
|
||||
# Set up env for test
|
||||
env = TradingEnvironment()
|
||||
env.write_data(
|
||||
equities_df=make_simple_equity_info(
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.__calendar = date_range('2014', '2015', freq=trading_day)
|
||||
cls.__assets = assets = Int64Index(arange(1, 20))
|
||||
cls.__tmp_finder_ctx = tmp_asset_finder(
|
||||
equities=make_simple_equity_info(
|
||||
assets,
|
||||
self.__calendar[0],
|
||||
self.__calendar[-1],
|
||||
),
|
||||
cls.__calendar[0],
|
||||
cls.__calendar[-1],
|
||||
)
|
||||
)
|
||||
self.__finder = env.asset_finder
|
||||
|
||||
# Use a 30-day period at the end of the year by default.
|
||||
self.__mask = self.__finder.lifetimes(
|
||||
self.__calendar[-30:],
|
||||
cls.__finder = cls.__tmp_finder_ctx.__enter__()
|
||||
cls.__mask = cls.__finder.lifetimes(
|
||||
cls.__calendar[-30:],
|
||||
include_start_date=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.__tmp_finder_ctx.__exit__()
|
||||
|
||||
@property
|
||||
def default_shape(self):
|
||||
"""Default shape for methods that build test data."""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Tests for chunked adjustments.
|
||||
"""
|
||||
from itertools import chain
|
||||
from textwrap import dedent
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -13,17 +14,17 @@ from numpy import (
|
||||
from numpy.testing import assert_array_equal
|
||||
from six.moves import zip_longest
|
||||
|
||||
from zipline.errors import WindowLengthNotPositive, WindowLengthTooLong
|
||||
from zipline.lib.adjustment import (
|
||||
Datetime64Overwrite,
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
)
|
||||
from zipline.lib.adjusted_array import (
|
||||
adjusted_array,
|
||||
NOMASK,
|
||||
)
|
||||
from zipline.errors import (
|
||||
WindowLengthNotPositive,
|
||||
WindowLengthTooLong,
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
make_datetime64ns,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,11 +49,23 @@ def valid_window_lengths(underlying_buffer_length):
|
||||
return iter(range(1, underlying_buffer_length + 1))
|
||||
|
||||
|
||||
def value_with_dtype(dtype, value):
|
||||
"""
|
||||
Make a value with the specified numpy dtype.
|
||||
"""
|
||||
name = dtype.name
|
||||
if name.startswith('datetime64'):
|
||||
if name != 'datetime64[ns]':
|
||||
raise TypeError("Expected datetime64[ns], but got %s." % name)
|
||||
return make_datetime64ns(value)
|
||||
return dtype.type(value)
|
||||
|
||||
|
||||
def _gen_unadjusted_cases(dtype):
|
||||
|
||||
nrows = 6
|
||||
ncols = 3
|
||||
data = arange(nrows * ncols, dtype=dtype).reshape(nrows, ncols)
|
||||
data = arange(nrows * ncols).astype(dtype).reshape(nrows, ncols)
|
||||
|
||||
for windowlen in valid_window_lengths(nrows):
|
||||
|
||||
@@ -61,7 +74,7 @@ def _gen_unadjusted_cases(dtype):
|
||||
)
|
||||
|
||||
yield (
|
||||
"length_%d" % windowlen,
|
||||
"dtype_%s_length_%d" % (dtype, windowlen),
|
||||
data,
|
||||
windowlen,
|
||||
{},
|
||||
@@ -86,7 +99,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
We then build all legal windows over these buffers.
|
||||
"""
|
||||
adjustment_type = {
|
||||
float: Float64Multiply,
|
||||
float64_dtype: Float64Multiply,
|
||||
}[dtype]
|
||||
|
||||
nrows, ncols = 6, 3
|
||||
@@ -96,7 +109,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, dtype(2)),
|
||||
adjustment_type(0, 0, 0, 0, value_with_dtype(dtype, 2)),
|
||||
]
|
||||
buffer_as_of[1] = array([[2, 1, 1],
|
||||
[1, 1, 1],
|
||||
@@ -109,8 +122,8 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
adjustments[3] = [
|
||||
adjustment_type(1, 2, 1, 1, dtype(3)),
|
||||
adjustment_type(0, 1, 0, 0, dtype(4)),
|
||||
adjustment_type(1, 2, 1, 1, value_with_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, value_with_dtype(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = array([[8, 1, 1],
|
||||
[4, 3, 1],
|
||||
@@ -120,7 +133,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
[1, 1, 1]], dtype=dtype)
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, dtype(5))
|
||||
adjustment_type(0, 3, 2, 2, value_with_dtype(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = array([[8, 1, 5],
|
||||
[4, 3, 5],
|
||||
@@ -130,8 +143,8 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
[1, 1, 1]], dtype=dtype)
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, dtype(6)),
|
||||
adjustment_type(2, 2, 2, 2, dtype(7)),
|
||||
adjustment_type(0, 4, 1, 1, value_with_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, value_with_dtype(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = array([[8, 6, 5],
|
||||
[4, 18, 5],
|
||||
@@ -151,9 +164,9 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
multiplicative adjustments. The only difference is the semantics of how
|
||||
the adjustments are expected to modify the arrays.
|
||||
"""
|
||||
|
||||
adjustment_type = {
|
||||
float: Float64Overwrite,
|
||||
float64_dtype: Float64Overwrite,
|
||||
datetime64ns_dtype: Datetime64Overwrite,
|
||||
}[dtype]
|
||||
|
||||
nrows, ncols = 6, 3
|
||||
@@ -163,7 +176,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, dtype(1)),
|
||||
adjustment_type(0, 0, 0, 0, value_with_dtype(dtype, 1)),
|
||||
]
|
||||
buffer_as_of[1] = array([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
@@ -176,8 +189,8 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
adjustments[3] = [
|
||||
adjustment_type(1, 2, 1, 1, dtype(3)),
|
||||
adjustment_type(0, 1, 0, 0, dtype(4)),
|
||||
adjustment_type(1, 2, 1, 1, value_with_dtype(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, value_with_dtype(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = array([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
@@ -187,7 +200,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
[2, 2, 2]], dtype=dtype)
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, dtype(5))
|
||||
adjustment_type(0, 3, 2, 2, value_with_dtype(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = array([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
@@ -197,8 +210,8 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
[2, 2, 2]], dtype=dtype)
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, dtype(6)),
|
||||
adjustment_type(2, 2, 2, 2, dtype(7)),
|
||||
adjustment_type(0, 4, 1, 1, value_with_dtype(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, value_with_dtype(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = array([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
@@ -224,7 +237,7 @@ def _gen_expectations(baseline, adjustments, buffer_as_of, nrows):
|
||||
)
|
||||
|
||||
yield (
|
||||
"length_%d" % windowlen,
|
||||
"dtype_%s_length_%d" % (baseline.dtype, windowlen),
|
||||
baseline,
|
||||
windowlen,
|
||||
adjustments,
|
||||
@@ -243,61 +256,63 @@ def _gen_expectations(baseline, adjustments, buffer_as_of, nrows):
|
||||
|
||||
class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
@parameterized.expand(_gen_unadjusted_cases(float))
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_unadjusted_cases(float64_dtype),
|
||||
_gen_unadjusted_cases(datetime64ns_dtype),
|
||||
)
|
||||
)
|
||||
def test_no_adjustments(self,
|
||||
name,
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
expected):
|
||||
array = adjusted_array(
|
||||
data,
|
||||
NOMASK,
|
||||
adjustments,
|
||||
)
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
self.assertEqual(yielded.dtype, data.dtype)
|
||||
assert_array_equal(yielded, expected_yield)
|
||||
|
||||
@parameterized.expand(_gen_multiplicative_adjustment_cases(float))
|
||||
@parameterized.expand(_gen_multiplicative_adjustment_cases(float64_dtype))
|
||||
def test_multiplicative_adjustments(self,
|
||||
name,
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
expected):
|
||||
array = adjusted_array(
|
||||
data,
|
||||
NOMASK,
|
||||
adjustments,
|
||||
)
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
assert_array_equal(yielded, expected_yield)
|
||||
|
||||
@parameterized.expand(_gen_overwrite_adjustment_cases(float))
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_overwrite_adjustment_cases(float64_dtype),
|
||||
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
|
||||
)
|
||||
)
|
||||
def test_overwrite_adjustment_cases(self,
|
||||
name,
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
expected):
|
||||
array = adjusted_array(
|
||||
data,
|
||||
NOMASK,
|
||||
adjustments,
|
||||
)
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
self.assertEqual(yielded.dtype, data.dtype)
|
||||
assert_array_equal(yielded, expected_yield)
|
||||
|
||||
def test_invalid_lookback(self):
|
||||
|
||||
data = arange(30, dtype=float).reshape(6, 5)
|
||||
adj_array = adjusted_array(data, NOMASK, {})
|
||||
adj_array = AdjustedArray(data, NOMASK, {})
|
||||
|
||||
with self.assertRaises(WindowLengthTooLong):
|
||||
adj_array.traverse(7)
|
||||
@@ -311,7 +326,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
def test_array_views_arent_writable(self):
|
||||
|
||||
data = arange(30, dtype=float).reshape(6, 5)
|
||||
adj_array = adjusted_array(data, NOMASK, {})
|
||||
adj_array = AdjustedArray(data, NOMASK, {})
|
||||
|
||||
for frame in adj_array.traverse(3):
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -323,11 +338,11 @@ class AdjustedArrayTestCase(TestCase):
|
||||
bad_mask = array([[0, 1, 1], [0, 0, 1]], dtype=bool)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
adjusted_array(data, bad_mask, {})
|
||||
AdjustedArray(data, bad_mask, {})
|
||||
|
||||
def test_inspect(self):
|
||||
data = arange(15, dtype=float).reshape(5, 3)
|
||||
adj_array = adjusted_array(
|
||||
adj_array = AdjustedArray(
|
||||
data,
|
||||
NOMASK,
|
||||
{4: [Float64Multiply(2, 3, 0, 0, 4.0)]},
|
||||
@@ -335,7 +350,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
expected = dedent(
|
||||
"""\
|
||||
Adjusted Array:
|
||||
Adjusted Array (float64):
|
||||
|
||||
Data:
|
||||
array([[ 0., 1., 2.],
|
||||
@@ -349,4 +364,5 @@ class AdjustedArrayTestCase(TestCase):
|
||||
last_col=0, value=4.000000)]}
|
||||
"""
|
||||
)
|
||||
self.assertEqual(expected, adj_array.inspect())
|
||||
got = adj_array.inspect()
|
||||
self.assertEqual(expected, got)
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Tests for zipline.lib.adjustment
|
||||
"""
|
||||
from unittest import TestCase
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from zipline.lib import adjustment as adj
|
||||
from zipline.utils.numpy_utils import make_datetime64ns
|
||||
|
||||
|
||||
class AdjustmentTestCase(TestCase):
|
||||
|
||||
@parameterized.expand([
|
||||
('add', adj.ADD),
|
||||
('multiply', adj.MULTIPLY),
|
||||
('overwrite', adj.OVERWRITE),
|
||||
])
|
||||
def test_make_float_adjustment(self, name, adj_type):
|
||||
expected_types = {
|
||||
'add': adj.Float64Add,
|
||||
'multiply': adj.Float64Multiply,
|
||||
'overwrite': adj.Float64Overwrite,
|
||||
}
|
||||
result = adj.make_adjustment_from_indices(
|
||||
1, 2, 3, 4,
|
||||
adjustment_kind=adj_type,
|
||||
value=0.5,
|
||||
)
|
||||
expected = expected_types[name](
|
||||
first_row=1,
|
||||
last_row=2,
|
||||
first_col=3,
|
||||
last_col=4,
|
||||
value=0.5,
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_make_datetime_adjustment(self):
|
||||
overwrite_dt = make_datetime64ns(0)
|
||||
result = adj.make_adjustment_from_indices(
|
||||
1, 2, 3, 4,
|
||||
adjustment_kind=adj.OVERWRITE,
|
||||
value=overwrite_dt,
|
||||
)
|
||||
expected = adj.Datetime64Overwrite(
|
||||
first_row=1,
|
||||
last_row=2,
|
||||
first_col=3,
|
||||
last_col=4,
|
||||
value=overwrite_dt,
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_unsupported_type(self):
|
||||
class SomeClass(object):
|
||||
pass
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
adj.make_adjustment_from_indices(
|
||||
1, 2, 3, 4,
|
||||
adjustment_kind=adj.OVERWRITE,
|
||||
value=SomeClass(),
|
||||
)
|
||||
|
||||
exc = e.exception
|
||||
expected_msg = (
|
||||
"Don't know how to make overwrite adjustments for values of type "
|
||||
"%r." % SomeClass
|
||||
)
|
||||
self.assertEqual(str(exc), expected_msg)
|
||||
@@ -25,8 +25,11 @@ from zipline.pipeline.loaders.blaze import (
|
||||
from_blaze,
|
||||
BlazeLoader,
|
||||
NoDeltasWarning,
|
||||
)
|
||||
from zipline.pipeline.loaders.blaze.core import (
|
||||
NonNumpyField,
|
||||
NonPipelineField,
|
||||
no_deltas_rules,
|
||||
)
|
||||
from zipline.utils.numpy_utils import repeat_last_axis
|
||||
from zipline.utils.test_utils import tmp_asset_finder, make_simple_equity_info
|
||||
@@ -82,7 +85,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
ds = from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
self.assertEqual(ds.__name__, name)
|
||||
self.assertTrue(issubclass(ds, DataSet))
|
||||
@@ -102,7 +105,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
),
|
||||
ds,
|
||||
)
|
||||
@@ -113,7 +116,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
value = from_blaze(
|
||||
expr.value,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
self.assertEqual(value.name, 'value')
|
||||
self.assertIsInstance(value, BoundColumn)
|
||||
@@ -124,7 +127,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr.value,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
),
|
||||
value,
|
||||
)
|
||||
@@ -132,7 +135,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
).value,
|
||||
value,
|
||||
)
|
||||
@@ -142,7 +145,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
),
|
||||
value.dataset,
|
||||
)
|
||||
@@ -164,7 +167,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
self.assertIn("'asof_date'", str(e.exception))
|
||||
self.assertIn(repr(str(expr.dshape.measure)), str(e.exception))
|
||||
@@ -193,7 +196,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule='warn',
|
||||
no_deltas_rule=no_deltas_rules.warn,
|
||||
)
|
||||
self.assertEqual(len(ws), 1)
|
||||
w = ws[0].message
|
||||
@@ -207,7 +210,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
from_blaze(
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule='raise',
|
||||
no_deltas_rule=no_deltas_rules.raise_,
|
||||
)
|
||||
self.assertIn(str(expr), str(e.exception))
|
||||
|
||||
@@ -224,7 +227,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
ds = from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
with self.assertRaises(AttributeError):
|
||||
ds.a
|
||||
@@ -246,7 +249,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
ds = from_blaze(
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
with self.assertRaises(AttributeError):
|
||||
ds.a
|
||||
@@ -298,7 +301,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
ds = from_blaze(
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
p = Pipeline()
|
||||
p.add(ds.value.latest, 'value')
|
||||
@@ -326,7 +329,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
ds = from_blaze(
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule='ignore',
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
)
|
||||
p = Pipeline()
|
||||
p.add(ds.value.latest, 'value')
|
||||
@@ -367,7 +370,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
deltas,
|
||||
loader=loader,
|
||||
no_deltas_rule='raise',
|
||||
no_deltas_rule=no_deltas_rules.raise_,
|
||||
)
|
||||
p = Pipeline()
|
||||
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Tests for the reference loader for EarningsCalendar.
|
||||
"""
|
||||
from unittest import TestCase
|
||||
|
||||
import blaze as bz
|
||||
from blaze.compute.core import swap_resources_into_scope
|
||||
from contextlib2 import ExitStack
|
||||
from nose_parameterized import parameterized
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pandas.util.testing import assert_series_equal
|
||||
from six import iteritems
|
||||
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data import EarningsCalendar
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.factors.events import (
|
||||
BusinessDaysUntilNextEarnings,
|
||||
BusinessDaysSincePreviousEarnings,
|
||||
)
|
||||
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
|
||||
from zipline.pipeline.loaders.blaze import (
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
BlazeEarningsCalendarLoader,
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.utils.numpy_utils import make_datetime64D, np_NaT
|
||||
from zipline.utils.tradingcalendar import trading_days
|
||||
from zipline.utils.test_utils import (
|
||||
make_simple_equity_info,
|
||||
powerset,
|
||||
tmp_asset_finder,
|
||||
)
|
||||
|
||||
|
||||
def _to_series(knowledge_dates, earning_dates):
|
||||
"""
|
||||
Helper for converting a dict of strings to a Series of datetimes.
|
||||
|
||||
This is just for making the test cases more readable.
|
||||
"""
|
||||
return pd.Series(
|
||||
index=pd.to_datetime(knowledge_dates),
|
||||
data=pd.to_datetime(earning_dates),
|
||||
)
|
||||
|
||||
|
||||
def num_days_in_range(dates, start, end):
|
||||
"""
|
||||
Return the number of days in `dates` between start and end, inclusive.
|
||||
"""
|
||||
start_idx, stop_idx = dates.slice_locs(start, end)
|
||||
return stop_idx - start_idx
|
||||
|
||||
|
||||
def gen_calendars():
|
||||
"""
|
||||
Generate calendars to use as inputs to test_compute_latest.
|
||||
"""
|
||||
start, stop = '2014-01-01', '2014-01-31'
|
||||
all_dates = pd.date_range(start, stop, tz='utc')
|
||||
|
||||
# These dates are the points where announcements or knowledge dates happen.
|
||||
# Test every combination of them being absent.
|
||||
critical_dates = pd.to_datetime([
|
||||
'2014-01-05',
|
||||
'2014-01-10',
|
||||
'2014-01-15',
|
||||
'2014-01-20',
|
||||
])
|
||||
for to_drop in map(list, powerset(critical_dates)):
|
||||
# Have to yield tuples.
|
||||
yield (all_dates.drop(to_drop),)
|
||||
|
||||
# Also test with the trading calendar.
|
||||
yield (trading_days[trading_days.slice_indexer(start, stop)],)
|
||||
|
||||
|
||||
class EarningsCalendarLoaderTestCase(TestCase):
|
||||
"""
|
||||
Tests for loading the earnings announcement data.
|
||||
"""
|
||||
loader_type = EarningsCalendarLoader
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._cleanup_stack = stack = ExitStack()
|
||||
cls.sids = A, B, C, D, E = range(5)
|
||||
equity_info = make_simple_equity_info(
|
||||
cls.sids,
|
||||
start_date=pd.Timestamp('2013-01-01', tz='UTC'),
|
||||
end_date=pd.Timestamp('2015-01-01', tz='UTC'),
|
||||
)
|
||||
cls.finder = stack.enter_context(
|
||||
tmp_asset_finder(equities=equity_info),
|
||||
)
|
||||
|
||||
cls.earnings_dates = {
|
||||
# K1--K2--E1--E2.
|
||||
A: _to_series(
|
||||
knowledge_dates=['2014-01-05', '2014-01-10'],
|
||||
earning_dates=['2014-01-15', '2014-01-20'],
|
||||
),
|
||||
# K1--K2--E2--E1.
|
||||
B: _to_series(
|
||||
knowledge_dates=['2014-01-05', '2014-01-10'],
|
||||
earning_dates=['2014-01-20', '2014-01-15']
|
||||
),
|
||||
# K1--E1--K2--E2.
|
||||
C: _to_series(
|
||||
knowledge_dates=['2014-01-05', '2014-01-15'],
|
||||
earning_dates=['2014-01-10', '2014-01-20']
|
||||
),
|
||||
# K1 == K2.
|
||||
D: _to_series(
|
||||
knowledge_dates=['2014-01-05'] * 2,
|
||||
earning_dates=['2014-01-10', '2014-01-15'],
|
||||
),
|
||||
E: pd.Series(
|
||||
data=[],
|
||||
index=pd.DatetimeIndex([]),
|
||||
dtype='datetime64[ns]',
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._cleanup_stack.close()
|
||||
|
||||
def loader_args(self, dates):
|
||||
"""Construct the base earnings announcements object to pass to the
|
||||
loader.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : pd.DatetimeIndex
|
||||
The dates we can serve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
args : tuple[any]
|
||||
The arguments to forward to the loader positionally.
|
||||
"""
|
||||
return dates, self.earnings_dates
|
||||
|
||||
def setup(self, dates):
|
||||
"""
|
||||
Make a PipelineEngine and expectation functions for the given dates
|
||||
calendar.
|
||||
|
||||
This exists to make it easy to test our various cases with critical
|
||||
dates missing from the calendar.
|
||||
"""
|
||||
A, B, C, D, E = self.sids
|
||||
|
||||
def num_days_between(start_date, end_date):
|
||||
return num_days_in_range(dates, start_date, end_date)
|
||||
|
||||
def zip_with_dates(dts):
|
||||
return pd.Series(pd.to_datetime(dts), index=dates)
|
||||
|
||||
_expected_next_announce = pd.DataFrame({
|
||||
A: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-04') +
|
||||
['2014-01-15'] * num_days_between('2014-01-05', '2014-01-15') +
|
||||
['2014-01-20'] * num_days_between('2014-01-16', '2014-01-20') +
|
||||
['NaT'] * num_days_between('2014-01-21', None)
|
||||
),
|
||||
B: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-04') +
|
||||
['2014-01-20'] * num_days_between('2014-01-05', '2014-01-09') +
|
||||
['2014-01-15'] * num_days_between('2014-01-10', '2014-01-15') +
|
||||
['2014-01-20'] * num_days_between('2014-01-16', '2014-01-20') +
|
||||
['NaT'] * num_days_between('2014-01-21', None)
|
||||
),
|
||||
C: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-04') +
|
||||
['2014-01-10'] * num_days_between('2014-01-05', '2014-01-10') +
|
||||
['NaT'] * num_days_between('2014-01-11', '2014-01-14') +
|
||||
['2014-01-20'] * num_days_between('2014-01-15', '2014-01-20') +
|
||||
['NaT'] * num_days_between('2014-01-21', None)
|
||||
),
|
||||
D: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-04') +
|
||||
['2014-01-10'] * num_days_between('2014-01-05', '2014-01-10') +
|
||||
['2014-01-15'] * num_days_between('2014-01-11', '2014-01-15') +
|
||||
['NaT'] * num_days_between('2014-01-16', None)
|
||||
),
|
||||
E: zip_with_dates(['NaT'] * len(dates)),
|
||||
}, index=dates)
|
||||
|
||||
_expected_previous_announce = pd.DataFrame({
|
||||
A: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-14') +
|
||||
['2014-01-15'] * num_days_between('2014-01-15', '2014-01-19') +
|
||||
['2014-01-20'] * num_days_between('2014-01-20', None)
|
||||
),
|
||||
B: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-14') +
|
||||
['2014-01-15'] * num_days_between('2014-01-15', '2014-01-19') +
|
||||
['2014-01-20'] * num_days_between('2014-01-20', None)
|
||||
),
|
||||
C: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-09') +
|
||||
['2014-01-10'] * num_days_between('2014-01-10', '2014-01-19') +
|
||||
['2014-01-20'] * num_days_between('2014-01-20', None)
|
||||
),
|
||||
D: zip_with_dates(
|
||||
['NaT'] * num_days_between(None, '2014-01-09') +
|
||||
['2014-01-10'] * num_days_between('2014-01-10', '2014-01-14') +
|
||||
['2014-01-15'] * num_days_between('2014-01-15', None)
|
||||
),
|
||||
E: zip_with_dates(['NaT'] * len(dates)),
|
||||
}, index=dates)
|
||||
|
||||
_expected_next_busday_offsets = self._compute_busday_offsets(
|
||||
_expected_next_announce
|
||||
)
|
||||
_expected_previous_busday_offsets = self._compute_busday_offsets(
|
||||
_expected_previous_announce
|
||||
)
|
||||
|
||||
def expected_next_announce(sid):
|
||||
"""
|
||||
Return the expected next announcement dates for ``sid``.
|
||||
"""
|
||||
return _expected_next_announce[sid]
|
||||
|
||||
def expected_next_busday_offset(sid):
|
||||
"""
|
||||
Return the expected number of days to the next announcement for
|
||||
``sid``.
|
||||
"""
|
||||
return _expected_next_busday_offsets[sid]
|
||||
|
||||
def expected_previous_announce(sid):
|
||||
"""
|
||||
Return the expected previous announcement dates for ``sid``.
|
||||
"""
|
||||
return _expected_previous_announce[sid]
|
||||
|
||||
def expected_previous_busday_offset(sid):
|
||||
"""
|
||||
Return the expected number of days to the next announcement for
|
||||
``sid``.
|
||||
"""
|
||||
return _expected_previous_busday_offsets[sid]
|
||||
|
||||
loader = self.loader_type(*self.loader_args(dates))
|
||||
engine = SimplePipelineEngine(lambda _: loader, dates, self.finder)
|
||||
return (
|
||||
engine,
|
||||
expected_next_announce,
|
||||
expected_next_busday_offset,
|
||||
expected_previous_announce,
|
||||
expected_previous_busday_offset,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_busday_offsets(announcement_dates):
|
||||
"""
|
||||
Compute expected business day offsets from a DataFrame of announcement
|
||||
dates.
|
||||
"""
|
||||
# Column-vector of dates on which factor `compute` will be called.
|
||||
raw_call_dates = announcement_dates.index.values.astype(
|
||||
'datetime64[D]'
|
||||
)[:, None]
|
||||
|
||||
# 2D array of dates containining expected nexg announcement.
|
||||
raw_announce_dates = (
|
||||
announcement_dates.values.astype('datetime64[D]')
|
||||
)
|
||||
|
||||
# Set NaTs to 0 temporarily because busday_count doesn't support NaT.
|
||||
# We fill these entries with NaNs later.
|
||||
whereNaT = raw_announce_dates == np_NaT
|
||||
raw_announce_dates[whereNaT] = make_datetime64D(0)
|
||||
|
||||
# The abs call here makes it so that we can use this function to
|
||||
# compute offsets for both next and previous earnings (previous
|
||||
# earnings offsets come back negative).
|
||||
expected = abs(np.busday_count(
|
||||
raw_call_dates,
|
||||
raw_announce_dates
|
||||
).astype(float))
|
||||
|
||||
expected[whereNaT] = np.nan
|
||||
return pd.DataFrame(
|
||||
data=expected,
|
||||
columns=announcement_dates.columns,
|
||||
index=announcement_dates.index,
|
||||
)
|
||||
|
||||
@parameterized.expand(gen_calendars())
|
||||
def test_compute_earnings(self, dates):
|
||||
|
||||
(
|
||||
engine,
|
||||
expected_next,
|
||||
expected_next_busday_offset,
|
||||
expected_previous,
|
||||
expected_previous_busday_offset,
|
||||
) = self.setup(dates)
|
||||
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'next': EarningsCalendar.next_announcement.latest,
|
||||
'previous': EarningsCalendar.previous_announcement.latest,
|
||||
'days_to_next': BusinessDaysUntilNextEarnings(),
|
||||
'days_since_prev': BusinessDaysSincePreviousEarnings(),
|
||||
}
|
||||
)
|
||||
|
||||
result = engine.run_pipeline(
|
||||
pipe,
|
||||
start_date=dates[0],
|
||||
end_date=dates[-1],
|
||||
)
|
||||
|
||||
computed_next = result['next']
|
||||
computed_previous = result['previous']
|
||||
computed_next_busday_offset = result['days_to_next']
|
||||
computed_previous_busday_offset = result['days_since_prev']
|
||||
|
||||
# NaTs in next/prev should correspond to NaNs in offsets.
|
||||
assert_series_equal(
|
||||
computed_next.isnull(),
|
||||
computed_next_busday_offset.isnull(),
|
||||
)
|
||||
assert_series_equal(
|
||||
computed_previous.isnull(),
|
||||
computed_previous_busday_offset.isnull(),
|
||||
)
|
||||
|
||||
for sid in self.sids:
|
||||
|
||||
assert_series_equal(
|
||||
computed_next.xs(sid, level=1),
|
||||
expected_next(sid),
|
||||
sid,
|
||||
)
|
||||
|
||||
assert_series_equal(
|
||||
computed_previous.xs(sid, level=1),
|
||||
expected_previous(sid),
|
||||
sid,
|
||||
)
|
||||
|
||||
assert_series_equal(
|
||||
computed_next_busday_offset.xs(sid, level=1),
|
||||
expected_next_busday_offset(sid),
|
||||
sid,
|
||||
)
|
||||
|
||||
assert_series_equal(
|
||||
computed_previous_busday_offset.xs(sid, level=1),
|
||||
expected_previous_busday_offset(sid),
|
||||
sid,
|
||||
)
|
||||
|
||||
|
||||
class BlazeEarningsCalendarLoaderTestCase(EarningsCalendarLoaderTestCase):
|
||||
loader_type = BlazeEarningsCalendarLoader
|
||||
|
||||
def loader_args(self, dates):
|
||||
_, mapping = super(
|
||||
BlazeEarningsCalendarLoaderTestCase,
|
||||
self,
|
||||
).loader_args(dates)
|
||||
return (bz.Data(pd.concat(
|
||||
pd.DataFrame({
|
||||
ANNOUNCEMENT_FIELD_NAME: earning_dates,
|
||||
TS_FIELD_NAME: earning_dates.index,
|
||||
SID_FIELD_NAME: sid,
|
||||
})
|
||||
for sid, earning_dates in iteritems(mapping)
|
||||
).reset_index(drop=True)),)
|
||||
|
||||
|
||||
class BlazeEarningsCalendarLoaderNotInteractiveTestCase(
|
||||
BlazeEarningsCalendarLoaderTestCase):
|
||||
"""Test case for passing a non-interactive symbol and a dict of resources.
|
||||
"""
|
||||
def loader_args(self, dates):
|
||||
(bound_expr,) = super(
|
||||
BlazeEarningsCalendarLoaderNotInteractiveTestCase,
|
||||
self,
|
||||
).loader_args(dates)
|
||||
return swap_resources_into_scope(bound_expr, {})
|
||||
|
||||
|
||||
class EarningsCalendarLoaderInferTimestampTestCase(TestCase):
|
||||
def test_infer_timestamp(self):
|
||||
dtx = pd.date_range('2014-01-01', '2014-01-10')
|
||||
announcement_dates = {
|
||||
0: dtx,
|
||||
1: pd.Series(dtx, dtx),
|
||||
}
|
||||
loader = EarningsCalendarLoader(
|
||||
dtx,
|
||||
announcement_dates,
|
||||
infer_timestamps=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
loader.announcement_dates.keys(),
|
||||
announcement_dates.keys(),
|
||||
)
|
||||
assert_series_equal(
|
||||
loader.announcement_dates[0],
|
||||
pd.Series(index=[dtx[0]] * 10, data=dtx),
|
||||
)
|
||||
assert_series_equal(
|
||||
loader.announcement_dates[1],
|
||||
announcement_dates[1],
|
||||
)
|
||||
@@ -29,16 +29,17 @@ from pandas.util.testing import assert_frame_equal
|
||||
from six import iteritems, itervalues
|
||||
from testfixtures import TempDirectory
|
||||
|
||||
from zipline.data.us_equity_pricing import BcolzDailyBarReader
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
ConstantLoader,
|
||||
NullAdjustmentReader,
|
||||
SyntheticDailyBarWriter,
|
||||
)
|
||||
from zipline.data.us_equity_pricing import BcolzDailyBarReader
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing, DataSet, Column
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader, MULTIPLY
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
|
||||
@@ -1,20 +1,32 @@
|
||||
"""
|
||||
Tests for Factor terms.
|
||||
"""
|
||||
from itertools import product
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from numpy import arange, array, empty, eye, nan, ones, datetime64
|
||||
from numpy import (
|
||||
arange,
|
||||
array,
|
||||
datetime64,
|
||||
empty,
|
||||
eye,
|
||||
nan,
|
||||
ones,
|
||||
)
|
||||
from numpy.random import randn, seed
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.rank import masked_rankdata_2d
|
||||
from zipline.pipeline import Factor, Filter, TermGraph
|
||||
from zipline.pipeline.factors import RSI, Returns
|
||||
from zipline.utils.test_utils import check_allclose, check_arrays
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype, np_NaT
|
||||
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
|
||||
class F(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
@@ -24,6 +36,12 @@ class Mask(Filter):
|
||||
window_length = 0
|
||||
|
||||
|
||||
for_each_factor_dtype = parameterized.expand([
|
||||
('datetime64[ns]', datetime64ns_dtype),
|
||||
('float', float64_dtype),
|
||||
])
|
||||
|
||||
|
||||
class FactorTestCase(BasePipelineTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -34,7 +52,10 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
with self.assertRaises(UnknownRankMethod):
|
||||
self.f.rank("not a real rank method")
|
||||
|
||||
def test_rank_ascending(self):
|
||||
@for_each_factor_dtype
|
||||
def test_rank_ascending(self, name, factor_dtype):
|
||||
|
||||
f = F(dtype=factor_dtype)
|
||||
|
||||
# Generated with:
|
||||
# data = arange(25).reshape(5, 5).transpose() % 4
|
||||
@@ -42,7 +63,8 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[1, 2, 3, 0, 1],
|
||||
[2, 3, 0, 1, 2],
|
||||
[3, 0, 1, 2, 3],
|
||||
[0, 1, 2, 3, 0]], dtype=float)
|
||||
[0, 1, 2, 3, 0]], dtype=factor_dtype)
|
||||
|
||||
expected_ranks = {
|
||||
'ordinal': array([[1., 3., 4., 5., 2.],
|
||||
[2., 4., 5., 1., 3.],
|
||||
@@ -75,22 +97,25 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: data},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({meth: self.f.rank(method=meth) for meth in expected_ranks})
|
||||
check({meth: f.rank(method=meth) for meth in expected_ranks})
|
||||
check({
|
||||
meth: self.f.rank(method=meth, ascending=True)
|
||||
meth: f.rank(method=meth, ascending=True)
|
||||
for meth in expected_ranks
|
||||
})
|
||||
# Not passing a method should default to ordinal.
|
||||
check({'ordinal': self.f.rank()})
|
||||
check({'ordinal': self.f.rank(ascending=True)})
|
||||
check({'ordinal': f.rank()})
|
||||
check({'ordinal': f.rank(ascending=True)})
|
||||
|
||||
def test_rank_descending(self):
|
||||
@for_each_factor_dtype
|
||||
def test_rank_descending(self, name, factor_dtype):
|
||||
|
||||
f = F(dtype=factor_dtype)
|
||||
|
||||
# Generated with:
|
||||
# data = arange(25).reshape(5, 5).transpose() % 4
|
||||
@@ -98,7 +123,7 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
[1, 2, 3, 0, 1],
|
||||
[2, 3, 0, 1, 2],
|
||||
[3, 0, 1, 2, 3],
|
||||
[0, 1, 2, 3, 0]], dtype=float)
|
||||
[0, 1, 2, 3, 0]], dtype=factor_dtype)
|
||||
expected_ranks = {
|
||||
'ordinal': array([[4., 3., 2., 1., 5.],
|
||||
[3., 2., 1., 5., 4.],
|
||||
@@ -131,35 +156,38 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
graph = TermGraph(terms)
|
||||
results = self.run_graph(
|
||||
graph,
|
||||
initial_workspace={self.f: data},
|
||||
initial_workspace={f: data},
|
||||
mask=self.build_mask(ones((5, 5))),
|
||||
)
|
||||
for method in terms:
|
||||
check_arrays(results[method], expected_ranks[method])
|
||||
|
||||
check({
|
||||
meth: self.f.rank(method=meth, ascending=False)
|
||||
meth: f.rank(method=meth, ascending=False)
|
||||
for meth in expected_ranks
|
||||
})
|
||||
# Not passing a method should default to ordinal.
|
||||
check({'ordinal': self.f.rank(ascending=False)})
|
||||
check({'ordinal': f.rank(ascending=False)})
|
||||
|
||||
def test_rank_after_mask(self):
|
||||
@for_each_factor_dtype
|
||||
def test_rank_after_mask(self, name, factor_dtype):
|
||||
|
||||
f = F(dtype=factor_dtype)
|
||||
# data = arange(25).reshape(5, 5).transpose() % 4
|
||||
data = array([[0, 1, 2, 3, 0],
|
||||
[1, 2, 3, 0, 1],
|
||||
[2, 3, 0, 1, 2],
|
||||
[3, 0, 1, 2, 3],
|
||||
[0, 1, 2, 3, 0]], dtype=float)
|
||||
[0, 1, 2, 3, 0]], dtype=factor_dtype)
|
||||
mask_data = ~eye(5, dtype=bool)
|
||||
initial_workspace = {self.f: data, Mask(): mask_data}
|
||||
initial_workspace = {f: data, Mask(): mask_data}
|
||||
|
||||
graph = TermGraph(
|
||||
{
|
||||
"ascending_nomask": self.f.rank(ascending=True),
|
||||
"ascending_mask": self.f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": self.f.rank(ascending=False),
|
||||
"descending_mask": self.f.rank(ascending=False, mask=Mask()),
|
||||
"ascending_nomask": f.rank(ascending=True),
|
||||
"ascending_mask": f.rank(ascending=True, mask=Mask()),
|
||||
"descending_nomask": f.rank(ascending=False),
|
||||
"descending_mask": f.rank(ascending=False, mask=Mask()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -246,3 +274,53 @@ class FactorTestCase(BasePipelineTestCase):
|
||||
returns.compute(today, assets, out, test_data)
|
||||
|
||||
check_allclose(expected, out)
|
||||
|
||||
def gen_ranking_cases():
|
||||
seeds = range(int(1e4), int(1e5), int(1e4))
|
||||
methods = ('ordinal', 'average')
|
||||
use_mask_values = (True, False)
|
||||
set_missing_values = (True, False)
|
||||
ascending_values = (True, False)
|
||||
return product(
|
||||
seeds,
|
||||
methods,
|
||||
use_mask_values,
|
||||
set_missing_values,
|
||||
ascending_values,
|
||||
)
|
||||
|
||||
@parameterized.expand(gen_ranking_cases())
|
||||
def test_masked_rankdata_2d(self,
|
||||
seed_value,
|
||||
method,
|
||||
use_mask,
|
||||
set_missing,
|
||||
ascending):
|
||||
eyemask = ~eye(5, dtype=bool)
|
||||
nomask = ones((5, 5), dtype=bool)
|
||||
|
||||
seed(seed_value)
|
||||
asfloat = (randn(5, 5) * seed_value)
|
||||
asdatetime = (asfloat).copy().view('datetime64[ns]')
|
||||
|
||||
mask = eyemask if use_mask else nomask
|
||||
if set_missing:
|
||||
asfloat[:, 2] = nan
|
||||
asdatetime[:, 2] = np_NaT
|
||||
|
||||
float_result = masked_rankdata_2d(
|
||||
data=asfloat,
|
||||
mask=mask,
|
||||
missing_value=nan,
|
||||
method=method,
|
||||
ascending=True,
|
||||
)
|
||||
datetime_result = masked_rankdata_2d(
|
||||
data=asdatetime,
|
||||
mask=mask,
|
||||
missing_value=np_NaT,
|
||||
method=method,
|
||||
ascending=True,
|
||||
)
|
||||
|
||||
check_arrays(float_result, datetime_result)
|
||||
|
||||
@@ -24,6 +24,7 @@ from numpy.random import randn, seed as random_seed
|
||||
from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, TermGraph
|
||||
from zipline.utils.test_utils import check_arrays
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
from .base import BasePipelineTestCase, with_default_shape
|
||||
|
||||
@@ -57,11 +58,13 @@ def rowwise_rank(array, mask=None):
|
||||
|
||||
|
||||
class SomeFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class SomeOtherFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
@@ -13,16 +13,16 @@ from pandas import (
|
||||
)
|
||||
|
||||
from zipline.lib.adjustment import (
|
||||
ADD,
|
||||
Float64Add,
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
MULTIPLY,
|
||||
OVERWRITE,
|
||||
)
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.loaders.frame import (
|
||||
ADD,
|
||||
DataFrameLoader,
|
||||
MULTIPLY,
|
||||
OVERWRITE,
|
||||
)
|
||||
from zipline.utils.tradingcalendar import trading_day
|
||||
|
||||
@@ -226,7 +226,7 @@ class DataFrameLoaderTestCase(TestCase):
|
||||
self.assertEqual(formatted_adjustments, expected_formatted_adjustments)
|
||||
|
||||
mask = self.mask[dates_slice, sids_slice]
|
||||
with patch('zipline.pipeline.loaders.frame.adjusted_array') as m:
|
||||
with patch('zipline.pipeline.loaders.frame.AdjustedArray') as m:
|
||||
loader.load_adjusted_array(
|
||||
columns=[USEquityPricing.close],
|
||||
dates=self.dates[dates_slice],
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from operator import (
|
||||
add,
|
||||
and_,
|
||||
ge,
|
||||
gt,
|
||||
le,
|
||||
lt,
|
||||
methodcaller,
|
||||
mul,
|
||||
ne,
|
||||
or_,
|
||||
)
|
||||
@@ -14,6 +16,7 @@ import numpy
|
||||
from numpy import (
|
||||
arange,
|
||||
eye,
|
||||
float64,
|
||||
full,
|
||||
isnan,
|
||||
zeros,
|
||||
@@ -30,20 +33,30 @@ from zipline.pipeline.expression import (
|
||||
NUMEXPR_MATH_FUNCS,
|
||||
)
|
||||
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
|
||||
from zipline.utils.test_utils import check_arrays
|
||||
|
||||
|
||||
class F(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class G(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class H(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class DateFactor(Factor):
|
||||
dtype = datetime64ns_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
@@ -56,10 +69,12 @@ class NumericalExpressionTestCase(TestCase):
|
||||
self.f = F()
|
||||
self.g = G()
|
||||
self.h = H()
|
||||
self.d = DateFactor()
|
||||
self.fake_raw_data = {
|
||||
self.f: full((5, 5), 3),
|
||||
self.g: full((5, 5), 2),
|
||||
self.h: full((5, 5), 1),
|
||||
self.d: full((5, 5), 0, dtype='datetime64[ns]'),
|
||||
}
|
||||
self.mask = DataFrame(True, index=self.dates, columns=self.assets)
|
||||
|
||||
@@ -80,39 +95,39 @@ class NumericalExpressionTestCase(TestCase):
|
||||
f = self.f
|
||||
g = self.g
|
||||
|
||||
NumericalExpression("x_0", (f,))
|
||||
NumericalExpression("x_0 ", (f,))
|
||||
NumericalExpression("x_0 + x_0", (f,))
|
||||
NumericalExpression("x_0 + 2", (f,))
|
||||
NumericalExpression("2 * x_0", (f,))
|
||||
NumericalExpression("x_0 + x_1", (f, g))
|
||||
NumericalExpression("x_0 + x_1 + x_0", (f, g))
|
||||
NumericalExpression("x_0 + 1 + x_1", (f, g))
|
||||
NumericalExpression("x_0", (f,), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 ", (f,), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 + x_0", (f,), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 + 2", (f,), dtype=float64_dtype)
|
||||
NumericalExpression("2 * x_0", (f,), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 + x_1", (f, g), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 + x_1 + x_0", (f, g), dtype=float64_dtype)
|
||||
NumericalExpression("x_0 + 1 + x_1", (f, g), dtype=float64_dtype)
|
||||
|
||||
def test_validate_bad(self):
|
||||
f, g, h = F(), G(), H()
|
||||
f, g, h = self.f, self.g, self.h
|
||||
|
||||
# Too few inputs.
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0", ())
|
||||
NumericalExpression("x_0", (), dtype=float64_dtype)
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0 + x_1", (f,))
|
||||
NumericalExpression("x_0 + x_1", (f,), dtype=float64_dtype)
|
||||
|
||||
# Too many inputs.
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0", (f, g))
|
||||
NumericalExpression("x_0", (f, g), dtype=float64_dtype)
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0 + x_1", (f, g, h))
|
||||
NumericalExpression("x_0 + x_1", (f, g, h), dtype=float64_dtype)
|
||||
|
||||
# Invalid variable name.
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0x_1", (f,))
|
||||
NumericalExpression("x_0x_1", (f,), dtype=float64_dtype)
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_0x_1", (f, g))
|
||||
NumericalExpression("x_0x_1", (f, g), dtype=float64_dtype)
|
||||
|
||||
# Variable index must start at 0.
|
||||
with self.assertRaises(ValueError):
|
||||
NumericalExpression("x_1", (f,))
|
||||
NumericalExpression("x_1", (f,), dtype=float64_dtype)
|
||||
|
||||
# Scalar operands must be numeric.
|
||||
with self.assertRaises(TypeError):
|
||||
@@ -128,6 +143,64 @@ class NumericalExpressionTestCase(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
(f > f) > f
|
||||
|
||||
def test_combine_datetimes(self):
|
||||
with self.assertRaises(TypeError) as e:
|
||||
self.d + self.d
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] + datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of dtype "
|
||||
"'float64'."
|
||||
)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
# Confirm that * shows up in the error instead of +.
|
||||
with self.assertRaises(TypeError) as e:
|
||||
self.d * self.d
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] * datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of dtype "
|
||||
"'float64'."
|
||||
)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
def test_combine_datetime_with_float(self):
|
||||
# Test with both float-type factors and numeric values.
|
||||
for float_value in (self.f, float64(1.0), 1.0):
|
||||
for op, sym in ((add, '+'), (mul, '*')):
|
||||
with self.assertRaises(TypeError) as e:
|
||||
op(self.f, self.d)
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute float64 {sym} datetime64[ns].\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"dtype 'float64'."
|
||||
).format(sym=sym)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
op(self.d, self.f)
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Don't know how to compute datetime64[ns] {sym} float64.\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"dtype 'float64'."
|
||||
).format(sym=sym)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
def test_negate_datetime(self):
|
||||
with self.assertRaises(TypeError) as e:
|
||||
-self.d
|
||||
|
||||
message = e.exception.args[0]
|
||||
expected = (
|
||||
"Can't apply unary operator '-' to instance of "
|
||||
"'DateFactor' with dtype 'datetime64[ns]'.\n"
|
||||
"'-' is only supported for Factors of dtype 'float64'."
|
||||
)
|
||||
self.assertEqual(message, expected)
|
||||
|
||||
def test_negate(self):
|
||||
f, g = self.f, self.g
|
||||
|
||||
|
||||
@@ -5,14 +5,17 @@ from unittest import TestCase
|
||||
|
||||
from zipline.pipeline import Factor, Filter, Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
|
||||
class SomeFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [USEquityPricing.close, USEquityPricing.high]
|
||||
|
||||
|
||||
class SomeOtherFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [USEquityPricing.close, USEquityPricing.high]
|
||||
|
||||
|
||||
@@ -49,10 +49,11 @@ from zipline.data.us_equity_pricing import (
|
||||
SQLiteAdjustmentReader,
|
||||
)
|
||||
from zipline.finance import trading
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.factors import VWAP
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader, MULTIPLY
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
|
||||
+39
-21
@@ -4,14 +4,10 @@ Tests for Term.
|
||||
from itertools import product
|
||||
from unittest import TestCase
|
||||
|
||||
from numpy import (
|
||||
float32,
|
||||
uint32,
|
||||
uint8,
|
||||
)
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
TermInputsNotSpecified,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
@@ -19,30 +15,40 @@ from zipline.pipeline import Factor, TermGraph
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.term import AssetExists, NotSpecified
|
||||
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
)
|
||||
|
||||
|
||||
class SomeDataSet(DataSet):
|
||||
|
||||
foo = Column(float32)
|
||||
bar = Column(uint32)
|
||||
buzz = Column(uint8)
|
||||
foo = Column(float64_dtype)
|
||||
bar = Column(float64_dtype)
|
||||
buzz = Column(float64_dtype)
|
||||
|
||||
|
||||
class SomeFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.foo, SomeDataSet.bar]
|
||||
|
||||
|
||||
class NoLookbackFactor(Factor):
|
||||
window_length = 0
|
||||
SomeFactorAlias = SomeFactor
|
||||
|
||||
|
||||
class SomeOtherFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.bar, SomeDataSet.buzz]
|
||||
|
||||
|
||||
SomeFactorAlias = SomeFactor
|
||||
class DateFactor(Factor):
|
||||
dtype = datetime64ns_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.bar, SomeDataSet.buzz]
|
||||
|
||||
|
||||
class NoLookbackFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 0
|
||||
|
||||
|
||||
def gen_equivalent_factors():
|
||||
@@ -172,8 +178,8 @@ class ObjectIdentityTestCase(TestCase):
|
||||
)
|
||||
|
||||
self.assertIs(
|
||||
SomeFactor(dtype=int),
|
||||
SomeFactor(dtype=int),
|
||||
SomeFactor(dtype=float64_dtype),
|
||||
SomeFactor(dtype=float64_dtype),
|
||||
)
|
||||
|
||||
self.assertIs(
|
||||
@@ -194,7 +200,7 @@ class ObjectIdentityTestCase(TestCase):
|
||||
# Different dtype
|
||||
self.assertIsNot(
|
||||
f,
|
||||
SomeFactor(dtype=int)
|
||||
SomeFactor(dtype=datetime64ns_dtype)
|
||||
)
|
||||
|
||||
# Reordering inputs changes semantics.
|
||||
@@ -208,6 +214,7 @@ class ObjectIdentityTestCase(TestCase):
|
||||
orig_foobar_instance = SomeFactorAlias()
|
||||
|
||||
class SomeFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.foo, SomeDataSet.bar]
|
||||
|
||||
@@ -255,14 +262,19 @@ class ObjectIdentityTestCase(TestCase):
|
||||
def test_bad_input(self):
|
||||
|
||||
class SomeFactor(Factor):
|
||||
pass
|
||||
dtype = float64_dtype
|
||||
|
||||
class SomeFactorDefaultInputs(Factor):
|
||||
class SomeFactorDefaultInputs(SomeFactor):
|
||||
inputs = (SomeDataSet.foo, SomeDataSet.bar)
|
||||
|
||||
class SomeFactorDefaultLength(Factor):
|
||||
class SomeFactorDefaultLength(SomeFactor):
|
||||
window_length = 10
|
||||
|
||||
class SomeFactorNoDType(SomeFactor):
|
||||
window_length = 10
|
||||
inputs = (SomeDataSet.foo,)
|
||||
dtype = NotSpecified
|
||||
|
||||
with self.assertRaises(TermInputsNotSpecified):
|
||||
SomeFactor(window_length=1)
|
||||
|
||||
@@ -274,3 +286,9 @@ class ObjectIdentityTestCase(TestCase):
|
||||
|
||||
with self.assertRaises(WindowLengthNotSpecified):
|
||||
SomeFactorDefaultInputs()
|
||||
|
||||
with self.assertRaises(DTypeNotSpecified):
|
||||
SomeFactorNoDType()
|
||||
|
||||
with self.assertRaises(InvalidDType):
|
||||
SomeFactor(dtype=1)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""
|
||||
Tests for zipline.utils.validate.
|
||||
"""
|
||||
from operator import attrgetter
|
||||
from types import FunctionType
|
||||
from unittest import TestCase
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from numpy import arange, dtype
|
||||
from six import PY3
|
||||
|
||||
from zipline.utils.preprocess import call, preprocess
|
||||
from zipline.utils.input_validation import (
|
||||
expect_element,
|
||||
expect_dtypes,
|
||||
expect_types,
|
||||
optional,
|
||||
)
|
||||
@@ -19,6 +24,13 @@ def noop(func, argname, argvalue):
|
||||
return argvalue
|
||||
|
||||
|
||||
if PY3:
|
||||
qualname = attrgetter('__qualname__')
|
||||
else:
|
||||
def qualname(ob):
|
||||
return '.'.join((__name__, ob.__name__))
|
||||
|
||||
|
||||
class PreprocessTestCase(TestCase):
|
||||
|
||||
@parameterized.expand([
|
||||
@@ -180,9 +192,9 @@ class PreprocessTestCase(TestCase):
|
||||
foo(not_int(1), 2, 3)
|
||||
self.assertEqual(
|
||||
e.exception.args[0],
|
||||
"{modname}.foo() expected a value of type "
|
||||
"{qualname}() expected a value of type "
|
||||
"int for argument 'a', but got {t} instead.".format(
|
||||
modname=foo.__module__,
|
||||
qualname=qualname(foo),
|
||||
t=not_int.__name__,
|
||||
)
|
||||
)
|
||||
@@ -203,9 +215,9 @@ class PreprocessTestCase(TestCase):
|
||||
foo('1')
|
||||
|
||||
expected_message = (
|
||||
"{modname}.foo() expected a value of "
|
||||
"{qualname}() expected a value of "
|
||||
"type int or float for argument 'a', but got str instead."
|
||||
).format(modname=foo.__module__)
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_expect_optional_types(self):
|
||||
@@ -225,9 +237,9 @@ class PreprocessTestCase(TestCase):
|
||||
foo('1')
|
||||
|
||||
expected_message = (
|
||||
"{modname}.foo() expected a value of "
|
||||
"{qualname}() expected a value of "
|
||||
"type int or NoneType for argument 'a', but got str instead."
|
||||
).format(modname=foo.__module__)
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_expect_element(self):
|
||||
@@ -244,7 +256,64 @@ class PreprocessTestCase(TestCase):
|
||||
f('c')
|
||||
|
||||
expected_message = (
|
||||
"{modname}.f() expected a value in {set_!r}"
|
||||
"{qualname}() expected a value in {set_!r}"
|
||||
" for argument 'a', but got 'c' instead."
|
||||
).format(set_=set_, modname=f.__module__)
|
||||
).format(set_=set_, qualname=qualname(f))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_expect_dtypes(self):
|
||||
|
||||
@expect_dtypes(a=dtype(float), b=dtype('datetime64[ns]'))
|
||||
def foo(a, b, c):
|
||||
return a, b, c
|
||||
|
||||
good_a = arange(3, dtype=float)
|
||||
good_b = arange(3).astype('datetime64[ns]')
|
||||
good_c = object()
|
||||
|
||||
a_ret, b_ret, c_ret = foo(good_a, good_b, good_c)
|
||||
self.assertIs(a_ret, good_a)
|
||||
self.assertIs(b_ret, good_b)
|
||||
self.assertIs(c_ret, good_c)
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
foo(good_a, arange(3), good_c)
|
||||
|
||||
expected_message = (
|
||||
"{qualname}() expected a value with dtype 'datetime64[ns]'"
|
||||
" for argument 'b', but got 'int64' instead."
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
foo(arange(3, dtype='uint32'), good_c, good_c)
|
||||
|
||||
expected_message = (
|
||||
"{qualname}() expected a value with dtype 'float64'"
|
||||
" for argument 'a', but got 'uint32' instead."
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_expect_dtypes_with_tuple(self):
|
||||
|
||||
allowed_dtypes = (dtype('datetime64[ns]'), dtype('float'))
|
||||
|
||||
@expect_dtypes(a=allowed_dtypes)
|
||||
def foo(a, b):
|
||||
return a, b
|
||||
|
||||
for d in allowed_dtypes:
|
||||
good_a = arange(3).astype(d)
|
||||
good_b = object()
|
||||
ret_a, ret_b = foo(good_a, good_b)
|
||||
self.assertIs(good_a, ret_a)
|
||||
self.assertIs(good_b, ret_b)
|
||||
|
||||
with self.assertRaises(TypeError) as e:
|
||||
foo(arange(3, dtype='uint32'), object())
|
||||
|
||||
expected_message = (
|
||||
"{qualname}() expected a value with dtype 'datetime64[ns]' "
|
||||
"or 'float64' for argument 'a', but got 'uint32' instead."
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
+12
-1
@@ -385,6 +385,17 @@ class DTypeNotSpecified(ZiplineError):
|
||||
)
|
||||
|
||||
|
||||
class InvalidDType(ZiplineError):
|
||||
"""
|
||||
Raised when a pipeline Term is constructed with a dtype that isn't a numpy
|
||||
dtype object.
|
||||
"""
|
||||
msg = (
|
||||
"{termname} expected a numpy dtype "
|
||||
"object for a dtype, but got {dtype} instead."
|
||||
)
|
||||
|
||||
|
||||
class BadPercentileBounds(ZiplineError):
|
||||
"""
|
||||
Raised by API functions accepting percentile bounds when the passed bounds
|
||||
@@ -442,7 +453,7 @@ class UnsupportedDataType(ZiplineError):
|
||||
"""
|
||||
Raised by CustomFactors with unsupported dtypes.
|
||||
"""
|
||||
msg = "CustomFactors with dtype {dtype} are not supported."
|
||||
msg = "{typename} instances with dtype {dtype} are not supported."
|
||||
|
||||
|
||||
class NoFurtherDataError(ZiplineError):
|
||||
|
||||
@@ -31,7 +31,7 @@ from zipline.finance.slippage import (
|
||||
check_order_triggers
|
||||
)
|
||||
from zipline.finance.commission import PerShare
|
||||
from zipline.utils.protocol_utils import Enum
|
||||
from zipline.utils.enum import enum
|
||||
|
||||
from zipline.utils.serialization_utils import (
|
||||
VERSION_LABEL
|
||||
@@ -39,7 +39,7 @@ from zipline.utils.serialization_utils import (
|
||||
|
||||
log = Logger('Blotter')
|
||||
|
||||
ORDER_STATUS = Enum(
|
||||
ORDER_STATUS = enum(
|
||||
'OPEN',
|
||||
'FILLED',
|
||||
'CANCELLED',
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
float specialization of AdjustedArrayWindow
|
||||
"""
|
||||
from numpy cimport float64_t as ctype
|
||||
include "_windowtemplate.pxi"
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
datetime specialization of AdjustedArrayWindow
|
||||
"""
|
||||
from numpy cimport int64_t as ctype
|
||||
include "_windowtemplate.pxi"
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
bool specialization of AdjustedArrayWindow
|
||||
"""
|
||||
from numpy cimport uint8_t as ctype
|
||||
include "_windowtemplate.pxi"
|
||||
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Template for AdjustedArray windowed iterators.
|
||||
|
||||
This file is intended to be used by inserting it via a Cython include into a
|
||||
file that's define a type symbol named `ctype` and string constant named
|
||||
`dtype`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.lib._floatwindow
|
||||
zipline.lib._intwindow
|
||||
zipline.lib._datewindow
|
||||
"""
|
||||
from numpy cimport ndarray
|
||||
from numpy import asarray
|
||||
|
||||
ctypedef ctype[:, :] databuffer
|
||||
|
||||
|
||||
cdef class AdjustedArrayWindow:
|
||||
"""
|
||||
An iterator representing a moving view over an AdjustedArray.
|
||||
|
||||
Concrete subtypes should subclass this and provide a `data` attribute for
|
||||
specific types.
|
||||
|
||||
This object stores a copy of the data from the AdjustedArray over which
|
||||
it's iterating. At each step in the iteration, it mutates its copy to
|
||||
allow us to show different data when looking back over the array.
|
||||
|
||||
The arrays yielded by this iterator are always views over the underlying
|
||||
data.
|
||||
"""
|
||||
cdef:
|
||||
# ctype must be defined by the file into which this is being copied.
|
||||
databuffer data
|
||||
object viewtype
|
||||
readonly Py_ssize_t window_length
|
||||
Py_ssize_t anchor, max_anchor, next_adj
|
||||
dict adjustments
|
||||
list adjustment_indices
|
||||
|
||||
def __cinit__(self,
|
||||
databuffer data not None,
|
||||
object viewtype not None,
|
||||
dict adjustments not None,
|
||||
Py_ssize_t offset,
|
||||
Py_ssize_t window_length):
|
||||
|
||||
self.data = data
|
||||
self.viewtype = viewtype
|
||||
self.adjustments = adjustments
|
||||
self.adjustment_indices = sorted(adjustments, reverse=True)
|
||||
self.window_length = window_length
|
||||
self.anchor = window_length + offset
|
||||
self.max_anchor = data.shape[0]
|
||||
|
||||
self.next_adj = self.pop_next_adj()
|
||||
|
||||
cdef pop_next_adj(self):
|
||||
"""
|
||||
Pop the index of the next adjustment to apply from self.adjustment_indices.
|
||||
"""
|
||||
if len(self.adjustment_indices) > 0:
|
||||
return self.adjustment_indices.pop()
|
||||
else:
|
||||
return self.max_anchor
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
cdef:
|
||||
ndarray out
|
||||
object adjustment
|
||||
Py_ssize_t start, anchor
|
||||
|
||||
anchor = self.anchor
|
||||
if anchor > self.max_anchor:
|
||||
raise StopIteration()
|
||||
|
||||
# Apply any adjustments that occured before our current anchor.
|
||||
# Equivalently, apply any adjustments known **on or before** the date
|
||||
# for which we're calculating a window.
|
||||
while self.next_adj < anchor:
|
||||
|
||||
for adjustment in self.adjustments[self.next_adj]:
|
||||
adjustment.mutate(self.data)
|
||||
|
||||
self.next_adj = self.pop_next_adj()
|
||||
|
||||
start = anchor - self.window_length
|
||||
out = asarray(self.data[start:self.anchor]).view(self.viewtype)
|
||||
out.setflags(write=False)
|
||||
|
||||
self.anchor += 1
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: window_length=%d, anchor=%d, max_anchor=%d, dtype=%r>" % (
|
||||
type(self).__name__,
|
||||
self.window_length,
|
||||
self.anchor,
|
||||
self.max_anchor,
|
||||
self.viewtype,
|
||||
)
|
||||
@@ -0,0 +1,249 @@
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy import (
|
||||
bool_,
|
||||
dtype,
|
||||
float32,
|
||||
float64,
|
||||
int32,
|
||||
int64,
|
||||
ndarray,
|
||||
uint32,
|
||||
uint8,
|
||||
)
|
||||
from zipline.errors import (
|
||||
WindowLengthNotPositive,
|
||||
WindowLengthTooLong,
|
||||
)
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
default_fillvalue_for_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
uint8_dtype,
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
# These class names are all the same because of our bootleg templating system.
|
||||
from ._float64window import AdjustedArrayWindow as Float64Window
|
||||
from ._int64window import AdjustedArrayWindow as Int64Window
|
||||
from ._uint8window import AdjustedArrayWindow as UInt8Window
|
||||
|
||||
Infer = sentinel(
|
||||
'Infer',
|
||||
"Sentinel used to say 'infer missing_value from data type.'"
|
||||
)
|
||||
NOMASK = None
|
||||
SUPPORTED_NUMERIC_DTYPES = frozenset(
|
||||
map(dtype, [float32, float64, int32, int64, uint32])
|
||||
)
|
||||
CONCRETE_WINDOW_TYPES = {
|
||||
float64_dtype: Float64Window,
|
||||
int64_dtype: Int64Window,
|
||||
uint8_dtype: UInt8Window,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_array(data):
|
||||
"""
|
||||
Coerce buffer data for an AdjustedArray into a standard scalar
|
||||
representation, returning the coerced array and a numpy dtype object to use
|
||||
as a view type when providing public view into the data.
|
||||
|
||||
Semantically numerical data (float*, int*, uint*) is coerced to float64 and
|
||||
viewed as float64. We coerce integral data to float so that we can use NaN
|
||||
as a missing value.
|
||||
|
||||
datetime[*] data is coerced to int64 with a viewtype of ``datetime64[ns]``.
|
||||
|
||||
``bool_`` data is coerced to uint8 with a viewtype of ``bool_``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
|
||||
Returns
|
||||
-------
|
||||
coerced, viewtype : (np.ndarray, np.dtype)
|
||||
"""
|
||||
data_dtype = data.dtype
|
||||
if data_dtype == bool_:
|
||||
return data.astype(uint8), dtype(bool_)
|
||||
elif data_dtype in SUPPORTED_NUMERIC_DTYPES:
|
||||
return data.astype(float64), dtype(float64)
|
||||
elif data_dtype.name.startswith('datetime'):
|
||||
try:
|
||||
outarray = data.astype('datetime64[ns]').view('int64')
|
||||
return outarray, datetime64ns_dtype
|
||||
except OverflowError:
|
||||
raise ValueError(
|
||||
"AdjustedArray received a datetime array "
|
||||
"not representable as datetime64[ns].\n"
|
||||
"Min Date: %s\n"
|
||||
"Max Date: %s\n"
|
||||
) % (data.min(), data.max())
|
||||
else:
|
||||
raise TypeError(
|
||||
"Don't know how to construct AdjustedArray "
|
||||
"on data of type %s." % dtype
|
||||
)
|
||||
|
||||
|
||||
class AdjustedArray(object):
|
||||
"""
|
||||
An array that can be iterated with a variable-length window, and which can
|
||||
provide different views on data from different perspectives.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray
|
||||
The baseline data values.
|
||||
mask : np.ndarray[bool]
|
||||
A mask indicating the locations of missing data.
|
||||
adjustments : dict[int -> list[Adjustment]]
|
||||
A dict mapping row indices to lists of adjustments to apply when we
|
||||
reach that row.
|
||||
fillvalue : object, optional
|
||||
A value to use to fill missing data in yielded windows.
|
||||
Default behavior is to infer a value based on the dtype of `data`.
|
||||
`NaN` is used for numeric data, and `NaT` is used for datetime data.
|
||||
"""
|
||||
__slots__ = ('_data', '_viewtype', 'adjustments', '__weakref__')
|
||||
|
||||
def __init__(self, data, mask, adjustments, fillvalue=Infer):
|
||||
self._data, self._viewtype = _normalize_array(data)
|
||||
self.adjustments = adjustments
|
||||
if fillvalue is Infer:
|
||||
fillvalue = default_fillvalue_for_dtype(self.data.dtype)
|
||||
|
||||
if mask is not NOMASK:
|
||||
if mask.dtype != bool_:
|
||||
raise ValueError("Mask must be a bool array.")
|
||||
if data.shape != mask.shape:
|
||||
raise ValueError(
|
||||
"Mask shape %s != data shape %s." %
|
||||
(mask.shape, data.shape),
|
||||
)
|
||||
self._data[~mask] = fillvalue
|
||||
|
||||
@lazyval
|
||||
def data(self):
|
||||
"""
|
||||
The data stored in this array.
|
||||
"""
|
||||
return self._data.view(self._viewtype)
|
||||
|
||||
@lazyval
|
||||
def dtype(self):
|
||||
"""
|
||||
The dtype of the data stored in this array.
|
||||
"""
|
||||
return self._viewtype
|
||||
|
||||
@lazyval
|
||||
def _iterator_type(self):
|
||||
"""
|
||||
The iterator produced when `traverse` is called on this Array.
|
||||
"""
|
||||
return CONCRETE_WINDOW_TYPES[self._data.dtype]
|
||||
|
||||
def traverse(self, window_length, offset=0):
|
||||
"""
|
||||
Produce an iterator rolling windows rows over our data.
|
||||
Each emitted window will have `window_length` rows.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
window_length : int
|
||||
The number of rows in each emitted window.
|
||||
offset : int, optional
|
||||
Number of rows to skip before the first window.
|
||||
"""
|
||||
data = self._data.copy()
|
||||
_check_window_params(data, window_length)
|
||||
return self._iterator_type(
|
||||
data,
|
||||
self._viewtype,
|
||||
self.adjustments,
|
||||
offset,
|
||||
window_length,
|
||||
)
|
||||
|
||||
def inspect(self):
|
||||
"""
|
||||
Return a string representation of the data stored in this array.
|
||||
"""
|
||||
return dedent(
|
||||
"""\
|
||||
Adjusted Array ({dtype}):
|
||||
|
||||
Data:
|
||||
{data!r}
|
||||
|
||||
Adjustments:
|
||||
{adjustments}
|
||||
"""
|
||||
).format(
|
||||
dtype=self.dtype.name,
|
||||
data=self.data,
|
||||
adjustments=self.adjustments,
|
||||
)
|
||||
|
||||
|
||||
def ensure_ndarray(ndarray_or_adjusted_array):
|
||||
"""
|
||||
Return the input as a numpy ndarray.
|
||||
|
||||
This is a no-op if the input is already an ndarray. If the input is an
|
||||
adjusted_array, this extracts a read-only view of its internal data buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ndarray_or_adjusted_array : numpy.ndarray | zipline.data.adjusted_array
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : The input, converted to an ndarray.
|
||||
"""
|
||||
if isinstance(ndarray_or_adjusted_array, ndarray):
|
||||
return ndarray_or_adjusted_array
|
||||
elif isinstance(ndarray_or_adjusted_array, AdjustedArray):
|
||||
return ndarray_or_adjusted_array.data
|
||||
else:
|
||||
raise TypeError(
|
||||
"Can't convert %s to ndarray" %
|
||||
type(ndarray_or_adjusted_array).__name__
|
||||
)
|
||||
|
||||
|
||||
def _check_window_params(data, window_length):
|
||||
"""
|
||||
Check that a window of length `window_length` is well-defined on `data`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : np.ndarray[ndim=2]
|
||||
The array of data to check.
|
||||
window_length : int
|
||||
Length of the desired window.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
Raises
|
||||
------
|
||||
WindowLengthNotPositive
|
||||
If window_length < 1.
|
||||
WindowLengthTooLong
|
||||
If window_length is greater than the number of rows in `data`.
|
||||
"""
|
||||
if window_length < 1:
|
||||
raise WindowLengthNotPositive(window_length=window_length)
|
||||
|
||||
if window_length > data.shape[0]:
|
||||
raise WindowLengthTooLong(
|
||||
nrows=data.shape[0],
|
||||
window_length=window_length,
|
||||
)
|
||||
@@ -1,243 +0,0 @@
|
||||
"""
|
||||
Class capable of yielding adjusted chunks of an ndarray.
|
||||
"""
|
||||
from cpython cimport (
|
||||
Py_EQ,
|
||||
PyObject_RichCompare,
|
||||
)
|
||||
from pprint import pformat
|
||||
|
||||
from numpy import (
|
||||
asarray,
|
||||
bool_,
|
||||
float64,
|
||||
full,
|
||||
uint8,
|
||||
)
|
||||
from numpy cimport (
|
||||
float64_t,
|
||||
ndarray,
|
||||
uint8_t,
|
||||
)
|
||||
|
||||
from zipline.errors import (
|
||||
WindowLengthNotPositive,
|
||||
WindowLengthTooLong,
|
||||
)
|
||||
|
||||
|
||||
cdef double NAN = float64('nan')
|
||||
|
||||
|
||||
NOMASK = None
|
||||
|
||||
|
||||
def ensure_ndarray(ndarray_or_adjusted_array):
|
||||
"""
|
||||
Return the input as a numpy ndarray.
|
||||
|
||||
This is a no-op if the input is already an ndarray. If the input is an
|
||||
adjusted_array, this extracts a read-only view of its internal data buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ndarray_or_adjusted_array : numpy.ndarray | zipline.data.adjusted_array
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : The input, converted to an ndarray.
|
||||
"""
|
||||
if isinstance(ndarray_or_adjusted_array, ndarray):
|
||||
return ndarray_or_adjusted_array
|
||||
elif isinstance(ndarray_or_adjusted_array, AdjustedArray):
|
||||
return ndarray_or_adjusted_array.data
|
||||
else:
|
||||
raise TypeError(
|
||||
"Can't convert %s to ndarray" %
|
||||
type(ndarray_or_adjusted_array).__name__
|
||||
)
|
||||
|
||||
|
||||
cpdef adjusted_array(ndarray data, ndarray mask, dict adjustments):
|
||||
"""
|
||||
Factory function for producing adjusted arrays on inputs of different
|
||||
dtypes.
|
||||
|
||||
If mask is None, the array is assumed to contain all valid data points.
|
||||
Otherwise mask should be an array of uint8 of the same shape
|
||||
as data, containing 0s for valid values and 1s for invalid values.
|
||||
"""
|
||||
if data.dtype != float64:
|
||||
data = data.astype(float64)
|
||||
if mask is not NOMASK:
|
||||
if mask.dtype == bool_:
|
||||
# Cython isn't smart enough to make this coercion even though the
|
||||
# arrays are bools internally.
|
||||
mask = mask.view(uint8)
|
||||
|
||||
return Float64AdjustedArray(data, mask, adjustments)
|
||||
|
||||
|
||||
cdef _check_window_length(object data, int window_length):
|
||||
|
||||
if window_length < 1:
|
||||
raise WindowLengthNotPositive(window_length=window_length)
|
||||
|
||||
if window_length > data.shape[0]:
|
||||
raise WindowLengthTooLong(
|
||||
nrows=data.shape[0],
|
||||
window_length=window_length,
|
||||
)
|
||||
|
||||
|
||||
cdef class AdjustedArray:
|
||||
|
||||
property data:
|
||||
def __get__(self):
|
||||
out = asarray(self._data, dtype=self.dtype)
|
||||
out.setflags(write=False)
|
||||
return out
|
||||
|
||||
|
||||
cdef class Float64AdjustedArray(AdjustedArray):
|
||||
"""
|
||||
Adjusted array of float64.
|
||||
"""
|
||||
cdef:
|
||||
readonly float64_t[:, :] _data
|
||||
dict adjustments
|
||||
|
||||
def __cinit__(self,
|
||||
float64_t[:, :] data not None,
|
||||
uint8_t[:, :] mask, # None is equivalent to all 0s.
|
||||
dict adjustments):
|
||||
cdef Py_ssize_t row, col
|
||||
|
||||
if mask is not NOMASK:
|
||||
if not PyObject_RichCompare(mask.shape, data.shape, Py_EQ):
|
||||
raise ValueError(
|
||||
"Mask shape %s != data shape %s" % (
|
||||
(mask.shape[0], mask.shape[1]),
|
||||
(data.shape[0], data.shape[1]),
|
||||
)
|
||||
)
|
||||
# Fill in NaNs for the mask.
|
||||
for row in range(mask.shape[0]):
|
||||
for col in range(mask.shape[1]):
|
||||
if not mask[row, col]:
|
||||
data[row, col] = NAN
|
||||
|
||||
self._data = data
|
||||
self.adjustments = adjustments
|
||||
|
||||
def inspect(self):
|
||||
return (
|
||||
"Adjusted Array:\n\nData:\n"
|
||||
"{data}\n\nAdjustments:\n{adjustments}\n".format(
|
||||
data=repr(asarray(self._data)),
|
||||
adjustments=pformat(self.adjustments),
|
||||
)
|
||||
)
|
||||
|
||||
property dtype:
|
||||
def __get__(self):
|
||||
return float64
|
||||
|
||||
cpdef traverse(self, Py_ssize_t window_length, Py_ssize_t offset=0):
|
||||
return _Float64AdjustedArrayWindow(
|
||||
self._data.copy(),
|
||||
self.adjustments,
|
||||
window_length,
|
||||
offset,
|
||||
)
|
||||
|
||||
|
||||
cdef class _Float64AdjustedArrayWindow:
|
||||
"""
|
||||
An iterator representing a moving view over an AdjustedArray.
|
||||
|
||||
This object stores a copy of the data from the AdjustedArray over which
|
||||
it's iterating. At each step in the iteration, it mutates its copy to
|
||||
allow us to show different data when looking back over the array.
|
||||
|
||||
The arrays yielded by this iterator are always views over the underlying
|
||||
data.
|
||||
"""
|
||||
|
||||
cdef float64_t[:, :] data
|
||||
cdef readonly Py_ssize_t window_length
|
||||
cdef Py_ssize_t anchor, max_anchor, next_adj
|
||||
cdef dict adjustments
|
||||
cdef list adjustment_indices
|
||||
|
||||
def __cinit__(self,
|
||||
float64_t[:, :] data,
|
||||
dict adjustments,
|
||||
Py_ssize_t window_length,
|
||||
Py_ssize_t offset):
|
||||
|
||||
_check_window_length(data, window_length)
|
||||
|
||||
self.data = data
|
||||
self.window_length = window_length
|
||||
|
||||
# anchor is the index of the row **after** the row from which we're
|
||||
# looking back.
|
||||
self.anchor = window_length + offset
|
||||
self.max_anchor = data.shape[0]
|
||||
|
||||
self.adjustments = adjustments
|
||||
self.adjustment_indices = sorted(adjustments, reverse=True)
|
||||
|
||||
if len(self.adjustment_indices) > 0:
|
||||
self.next_adj = self.adjustment_indices.pop()
|
||||
else:
|
||||
self.next_adj = self.max_anchor
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
cdef:
|
||||
ndarray[float64_t, ndim=2] out
|
||||
object adjustment
|
||||
Py_ssize_t start, anchor
|
||||
|
||||
anchor = self.anchor
|
||||
if anchor > self.max_anchor:
|
||||
raise StopIteration()
|
||||
|
||||
# Apply any adjustments that occured before our current anchor.
|
||||
# Equivalently, apply any adjustments known **on or before** the date
|
||||
# for which we're calculating a window.
|
||||
while self.next_adj < anchor:
|
||||
|
||||
for adjustment in self.adjustments[self.next_adj]:
|
||||
adjustment.mutate(self.data)
|
||||
|
||||
if len(self.adjustment_indices) > 0:
|
||||
self.next_adj = self.adjustment_indices.pop()
|
||||
else:
|
||||
self.next_adj = self.max_anchor
|
||||
|
||||
start = anchor - self.window_length
|
||||
out = asarray(self.data[start:self.anchor])
|
||||
out.setflags(write=False)
|
||||
|
||||
self.anchor += 1
|
||||
return out
|
||||
|
||||
def inspect(self):
|
||||
return (
|
||||
"{type_}\n"
|
||||
"Window Length: {window_length}\n"
|
||||
"Current Buffer:\n"
|
||||
"{data}\n"
|
||||
"Remaining Adjustments:\n"
|
||||
"{adjustments}\n"
|
||||
).format(
|
||||
type_=type(self).__name__,
|
||||
window_length=self.window_length,
|
||||
data=asarray(self.data[self.anchor - self.window_length:self.anchor]),
|
||||
adjustments=pformat(self.adjustments),
|
||||
)
|
||||
+313
-36
@@ -1,12 +1,118 @@
|
||||
# cython: embedsignature=True
|
||||
from cpython cimport Py_EQ
|
||||
|
||||
from pandas import isnull
|
||||
from numpy cimport float64_t, uint8_t
|
||||
from pandas import isnull, Timestamp
|
||||
from numpy cimport float64_t, uint8_t, int64_t
|
||||
from numpy import datetime64, float64
|
||||
# Purely for readability. There aren't C-level declarations for these types.
|
||||
ctypedef object Int64Index_t
|
||||
ctypedef object DatetimeIndex_t
|
||||
ctypedef object Timestamp_t
|
||||
|
||||
# Adjustment kinds.
|
||||
cpdef enum AdjustmentKind:
|
||||
MULTIPLY = 0
|
||||
ADD = 1
|
||||
OVERWRITE = 2
|
||||
|
||||
ADJUSTMENT_KIND_NAMES = {
|
||||
MULTIPLY: 'MULTIPLY',
|
||||
ADD: 'ADD',
|
||||
OVERWRITE: 'OVERWRITE',
|
||||
}
|
||||
|
||||
cdef dict _float_adjustment_types = {
|
||||
ADD: Float64Add,
|
||||
MULTIPLY: Float64Multiply,
|
||||
OVERWRITE: Float64Overwrite,
|
||||
}
|
||||
cdef dict _datetime_adjustment_types = {
|
||||
OVERWRITE: Datetime64Overwrite,
|
||||
}
|
||||
|
||||
cdef _is_float(object value):
|
||||
return isinstance(value, (float, float64))
|
||||
|
||||
def _is_datetime(object value):
|
||||
return isinstance(value, (datetime64, Timestamp))
|
||||
|
||||
|
||||
cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value):
|
||||
"""
|
||||
Make an adjustment object of the type appropriate for the given kind and
|
||||
value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
adjustment_kind : {ADD, MULTIPLY, OVERWRITE}
|
||||
The kind of adjustment to construct.
|
||||
value : object
|
||||
The value parameter to the adjustment. Only floating-point values and
|
||||
datetime-like values are currently supported
|
||||
"""
|
||||
if adjustment_kind in (ADD, MULTIPLY):
|
||||
if not _is_float(value):
|
||||
raise TypeError(
|
||||
"Can't construct %s Adjustment with value of type %r.\n"
|
||||
"ADD and MULTIPLY adjustments are only supported for "
|
||||
"floating point data." % (
|
||||
ADJUSTMENT_KIND_NAMES[adjustment_kind],
|
||||
type(value),
|
||||
)
|
||||
)
|
||||
return _float_adjustment_types[adjustment_kind]
|
||||
|
||||
elif adjustment_kind == OVERWRITE:
|
||||
if _is_float(value):
|
||||
return _float_adjustment_types[adjustment_kind]
|
||||
elif _is_datetime(value):
|
||||
return _datetime_adjustment_types[adjustment_kind]
|
||||
else:
|
||||
raise TypeError(
|
||||
"Don't know how to make overwrite "
|
||||
"adjustments for values of type %r." % type(value),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown adjustment type %d." % adjustment_kind)
|
||||
|
||||
|
||||
cpdef make_adjustment_from_indices(Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_column,
|
||||
Py_ssize_t last_column,
|
||||
AdjustmentKind adjustment_kind,
|
||||
object value):
|
||||
"""
|
||||
Make an Adjustment object from row/column indices into a baseline array.
|
||||
"""
|
||||
cdef type type_ = choose_adjustment_type(adjustment_kind, value)
|
||||
# NOTE_SS: Cython appears to generate incorrect code here if values are
|
||||
# passed by name. This is true even if cython.always_allow_keywords is
|
||||
# enabled. Yay Cython.
|
||||
return type_(first_row, last_row, first_column, last_column, value)
|
||||
|
||||
|
||||
cpdef make_adjustment_from_labels(DatetimeIndex_t dates_index,
|
||||
Int64Index_t assets_index,
|
||||
Timestamp_t start_date,
|
||||
Timestamp_t end_date,
|
||||
int asset_id,
|
||||
AdjustmentKind adjustment_kind,
|
||||
object value):
|
||||
"""
|
||||
Make an Adjustment object from date/asset labels into a labelled baseline
|
||||
array.
|
||||
"""
|
||||
cdef type type_ = choose_adjustment_type(adjustment_kind, value)
|
||||
return type_.from_assets_and_dates(
|
||||
dates_index,
|
||||
assets_index,
|
||||
start_date,
|
||||
end_date,
|
||||
asset_id,
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
cpdef tuple get_adjustment_locs(DatetimeIndex_t dates_index,
|
||||
Int64Index_t assets_index,
|
||||
@@ -87,31 +193,82 @@ cpdef _from_assets_and_dates(cls,
|
||||
end_date,
|
||||
asset_id,
|
||||
)
|
||||
return cls(first_row, last_row, col, col, value)
|
||||
return cls(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=col,
|
||||
last_col=col,
|
||||
value=value,
|
||||
)
|
||||
|
||||
|
||||
cdef class Float64Adjustment:
|
||||
cdef class Adjustment:
|
||||
"""
|
||||
Base class for adjustments that operate on Float64 buffers.
|
||||
Base class for Adjustments.
|
||||
|
||||
Subclasses should inherit and provide a `value` attribute and a `mutate` method.
|
||||
"""
|
||||
cdef:
|
||||
readonly Py_ssize_t first_col, last_col, first_row, last_row
|
||||
readonly float64_t value
|
||||
|
||||
def __cinit__(self,
|
||||
Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_col,
|
||||
Py_ssize_t last_col,
|
||||
object value):
|
||||
def __init__(self,
|
||||
Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_col,
|
||||
Py_ssize_t last_col):
|
||||
assert 0 <= first_row <= last_row
|
||||
assert 0 <= first_col <= last_col
|
||||
|
||||
self.first_row = first_row
|
||||
self.last_row = last_row
|
||||
self.first_col = first_col
|
||||
self.last_col = last_col
|
||||
self.value = float(value)
|
||||
self.first_row = first_row
|
||||
self.last_row = last_row
|
||||
|
||||
from_assets_and_dates = classmethod(_from_assets_and_dates)
|
||||
|
||||
def __richcmp__(self, object other, int op):
|
||||
"""
|
||||
Rich comparison method. Only Equality is defined.
|
||||
"""
|
||||
if op != Py_EQ or type(self) != type(other):
|
||||
return NotImplemented
|
||||
|
||||
return self._key() == other._key()
|
||||
|
||||
cpdef tuple _key(self):
|
||||
"""
|
||||
Comparison key
|
||||
"""
|
||||
return (
|
||||
self.first_row,
|
||||
self.last_row,
|
||||
self.first_col,
|
||||
self.last_col,
|
||||
self.value,
|
||||
)
|
||||
|
||||
|
||||
cdef class Float64Adjustment(Adjustment):
|
||||
"""
|
||||
Base class for adjustments that operate on Float64 data.
|
||||
"""
|
||||
cdef:
|
||||
readonly float64_t value
|
||||
|
||||
def __init__(self,
|
||||
Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_col,
|
||||
Py_ssize_t last_col,
|
||||
float64_t value):
|
||||
|
||||
super(Float64Adjustment, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
self.value = value
|
||||
|
||||
from_assets_and_dates = classmethod(_from_assets_and_dates)
|
||||
|
||||
@@ -128,28 +285,10 @@ cdef class Float64Adjustment:
|
||||
)
|
||||
)
|
||||
|
||||
def __richcmp__(self, object other, int op):
|
||||
"""
|
||||
Rich comparison method. Only Equality is defined.
|
||||
"""
|
||||
if op != Py_EQ or type(self) != type(other):
|
||||
return NotImplemented
|
||||
|
||||
return self._key() == other._key()
|
||||
|
||||
cpdef _key(self):
|
||||
return (
|
||||
self.first_row,
|
||||
self.last_row,
|
||||
self.first_col,
|
||||
self.last_col,
|
||||
self.value,
|
||||
)
|
||||
|
||||
|
||||
cdef class Float64Multiply(Float64Adjustment):
|
||||
"""
|
||||
An adjustment that multiplies by a scalar.
|
||||
An adjustment that multiplies by a float.
|
||||
|
||||
Example
|
||||
-------
|
||||
@@ -187,7 +326,7 @@ cdef class Float64Multiply(Float64Adjustment):
|
||||
|
||||
cdef class Float64Overwrite(Float64Adjustment):
|
||||
"""
|
||||
An adjustment that overwrites with a scalar.
|
||||
An adjustment that overwrites with a float.
|
||||
|
||||
Example
|
||||
-------
|
||||
@@ -225,7 +364,7 @@ cdef class Float64Overwrite(Float64Adjustment):
|
||||
|
||||
cdef class Float64Add(Float64Adjustment):
|
||||
"""
|
||||
An adjustment that adds a scalar.
|
||||
An adjustment that adds a float.
|
||||
|
||||
Example
|
||||
-------
|
||||
@@ -259,3 +398,141 @@ cdef class Float64Add(Float64Adjustment):
|
||||
# last_row + 1 because last_row should also be affected.
|
||||
for row in range(self.first_row, self.last_row + 1):
|
||||
data[row, col] += self.value
|
||||
|
||||
|
||||
cdef class _Int64Adjustment(Adjustment):
|
||||
"""
|
||||
Base class for adjustments that operate on integral data.
|
||||
|
||||
This is private because we never actually operate on integers as data, but
|
||||
we use integer arrays to represent datetime and timedelta data.
|
||||
"""
|
||||
cdef:
|
||||
readonly int64_t value
|
||||
|
||||
def __init__(self,
|
||||
Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_col,
|
||||
Py_ssize_t last_col,
|
||||
int64_t value):
|
||||
super(_Int64Adjustment, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"%s(first_row=%d, last_row=%d,"
|
||||
" first_col=%d, last_col=%d, value=%d)" % (
|
||||
type(self).__name__,
|
||||
self.first_row,
|
||||
self.last_row,
|
||||
self.first_col,
|
||||
self.last_col,
|
||||
self.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
cdef datetime_to_int(object datetimelike):
|
||||
"""
|
||||
Coerce a datetime-like object to the int format used by AdjustedArrays of
|
||||
Datetime64 type.
|
||||
"""
|
||||
if isinstance(datetimelike, Timestamp):
|
||||
return datetimelike.value
|
||||
|
||||
if not isinstance(datetimelike, datetime64):
|
||||
raise TypeError("Expected datetime64, got %s" % type(datetimelike))
|
||||
|
||||
elif datetimelike.dtype.name != 'datetime64[ns]':
|
||||
raise TypeError(
|
||||
"Expected datetime64[ns], got %s",
|
||||
datetimelike.dtype.name,
|
||||
)
|
||||
|
||||
return datetimelike.astype(int)
|
||||
|
||||
|
||||
cdef class Datetime64Adjustment(_Int64Adjustment):
|
||||
"""
|
||||
Base class for adjustments that operate on Datetime64 data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Numpy stores datetime64 values in arrays of type int64. There's no
|
||||
straightforward way to work with statically-typed datetime64 data, so
|
||||
instead we work with int64 values everywhere, and we do validation/coercion
|
||||
at API boundaries.
|
||||
"""
|
||||
def __init__(self,
|
||||
Py_ssize_t first_row,
|
||||
Py_ssize_t last_row,
|
||||
Py_ssize_t first_col,
|
||||
Py_ssize_t last_col,
|
||||
object value):
|
||||
|
||||
super(Datetime64Adjustment, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
value=datetime_to_int(value),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"%s(first_row=%d, last_row=%d,"
|
||||
" first_col=%d, last_col=%d, value=%r)" % (
|
||||
type(self).__name__,
|
||||
self.first_row,
|
||||
self.last_row,
|
||||
self.first_col,
|
||||
self.last_col,
|
||||
datetime64(self.value, 'ns'),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
cdef class Datetime64Overwrite(Datetime64Adjustment):
|
||||
"""
|
||||
An adjustment that overwrites with a datetime.
|
||||
|
||||
This operates on int64 data which should be interpreted as nanoseconds
|
||||
since the epoch.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np; import pandas as pd
|
||||
>>> dts = pd.date_range('2014', freq='D', periods=9, tz='UTC')
|
||||
>>> arr = dts.values.reshape(3, 3)
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
>>> adj = Datetime64Overwrite(
|
||||
... first_row=1,
|
||||
... last_row=2,
|
||||
... first_col=1,
|
||||
... last_col=2,
|
||||
... value=np.datetime64(0, 'ns'),
|
||||
... )
|
||||
>>> adj.mutate(arr.view(np.int64))
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, True, True],
|
||||
[False, True, True]], dtype=bool)
|
||||
"""
|
||||
cpdef mutate(self, int64_t[:, :] data):
|
||||
cdef Py_ssize_t row, col
|
||||
|
||||
# last_col + 1 because last_col should also be affected.
|
||||
for col in range(self.first_col, self.last_col + 1):
|
||||
# last_row + 1 because last_row should also be affected.
|
||||
for row in range(self.first_row, self.last_row + 1):
|
||||
data[row, col] = self.value
|
||||
|
||||
+46
-3
@@ -2,6 +2,7 @@
|
||||
Functions for ranking and sorting.
|
||||
"""
|
||||
cimport cython
|
||||
from cpython cimport bool
|
||||
from numpy cimport (
|
||||
float64_t,
|
||||
import_array,
|
||||
@@ -13,18 +14,60 @@ from numpy cimport (
|
||||
PyArray_DIMS,
|
||||
PyArray_EMPTY,
|
||||
)
|
||||
from numpy import nan
|
||||
from numpy import apply_along_axis, float64, isnan, nan
|
||||
from scipy.stats import rankdata
|
||||
|
||||
|
||||
import_array()
|
||||
|
||||
|
||||
cdef double NAN = nan
|
||||
def masked_rankdata_2d(ndarray data,
|
||||
ndarray mask,
|
||||
object missing_value,
|
||||
str method,
|
||||
bool ascending):
|
||||
"""
|
||||
Compute masked rankdata on data on float64, int64, or datetime64 data.
|
||||
"""
|
||||
cdef str dtype_name = data.dtype.name
|
||||
if dtype_name not in ('float64', 'int64', 'datetime64[ns]'):
|
||||
raise TypeError(
|
||||
"Can't compute rankdata on array of dtype %r." % dtype_name
|
||||
)
|
||||
|
||||
cdef ndarray missing_locations = ~mask
|
||||
# Mask out any entries that are equal to the missing value.
|
||||
if dtype_name == 'float64' and isnan(missing_value):
|
||||
missing_locations |= isnan(data)
|
||||
else:
|
||||
missing_locations |= (data == missing_value)
|
||||
|
||||
# Interpret the bytes of integral data as floats for sorting.
|
||||
data = data.copy().view(float64)
|
||||
data[missing_locations] = nan
|
||||
if not ascending:
|
||||
data = -data
|
||||
|
||||
# OPTIMIZATION: Fast path the default case with our own specialized
|
||||
# Cython implementation.
|
||||
if method == 'ordinal':
|
||||
result = rankdata_2d_ordinal(data)
|
||||
else:
|
||||
# FUTURE OPTIMIZATION:
|
||||
# Write a less general "apply to rows" method that doesn't do all
|
||||
# the extra work that apply_along_axis does.
|
||||
result = apply_along_axis(rankdata, 1, data, method=method)
|
||||
|
||||
# rankdata will sort missing values into last place, but we want our nans
|
||||
# to propagate, so explicitly re-apply.
|
||||
result[missing_locations] = nan
|
||||
return result
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
@cython.embedsignature(True)
|
||||
def rankdata_2d_ordinal(ndarray[float64_t, ndim=2] array):
|
||||
cpdef rankdata_2d_ordinal(ndarray[float64_t, ndim=2] array):
|
||||
"""
|
||||
Equivalent to:
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .earnings import EarningsCalendar
|
||||
from .equity_pricing import USEquityPricing
|
||||
from .dataset import DataSet, Column, BoundColumn
|
||||
|
||||
@@ -5,5 +6,6 @@ __all__ = [
|
||||
'BoundColumn',
|
||||
'Column',
|
||||
'DataSet',
|
||||
'EarningsCalendar',
|
||||
'USEquityPricing',
|
||||
]
|
||||
|
||||
@@ -8,7 +8,8 @@ from six import (
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, AssetExists
|
||||
from zipline.pipeline.factors import Latest
|
||||
from zipline.utils.input_validation import ensure_dtype
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
class Column(object):
|
||||
@@ -16,6 +17,7 @@ class Column(object):
|
||||
An abstract column of data, not yet associated with a dataset.
|
||||
"""
|
||||
|
||||
@preprocess(dtype=ensure_dtype)
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -73,15 +75,13 @@ class BoundColumn(Term):
|
||||
|
||||
@property
|
||||
def latest(self):
|
||||
# FIXME: Once we support non-float dtypes, this should pass a dtype
|
||||
# along. Right now we're just assuming that inputs will safely coerce
|
||||
# to float.
|
||||
return Latest(inputs=(self,))
|
||||
from zipline.pipeline.factors import Latest
|
||||
return Latest(inputs=(self,), dtype=self.dtype)
|
||||
|
||||
def __repr__(self):
|
||||
return "{qualname}::{dtype}".format(
|
||||
qualname=self.qualname,
|
||||
dtype=self.dtype.__name__,
|
||||
dtype=self.dtype.name,
|
||||
)
|
||||
|
||||
def short_repr(self):
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Dataset representing dates of upcoming earnings.
|
||||
"""
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype
|
||||
|
||||
from .dataset import Column, DataSet
|
||||
|
||||
|
||||
class EarningsCalendar(DataSet):
|
||||
"""
|
||||
Dataset representing dates of upcoming or recently announced earnings.
|
||||
"""
|
||||
next_announcement = Column(datetime64ns_dtype)
|
||||
previous_announcement = Column(datetime64ns_dtype)
|
||||
|
||||
# TODO: Provide categorical columns for when during the day the
|
||||
# announcement occurred.
|
||||
@@ -1,12 +1,17 @@
|
||||
from numpy import float64, uint32
|
||||
"""
|
||||
Dataset representing OHLCV data.
|
||||
"""
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
from .dataset import Column, DataSet
|
||||
|
||||
|
||||
class USEquityPricing(DataSet):
|
||||
|
||||
open = Column(float64)
|
||||
high = Column(float64)
|
||||
low = Column(float64)
|
||||
close = Column(float64)
|
||||
volume = Column(uint32)
|
||||
"""
|
||||
Dataset representing daily trading prices and volumes.
|
||||
"""
|
||||
open = Column(float64_dtype)
|
||||
high = Column(float64_dtype)
|
||||
low = Column(float64_dtype)
|
||||
close = Column(float64_dtype)
|
||||
volume = Column(float64_dtype)
|
||||
|
||||
@@ -9,11 +9,11 @@ import numexpr
|
||||
from numexpr.necompiler import getExprNames
|
||||
from numpy import (
|
||||
empty,
|
||||
find_common_type,
|
||||
inf,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, NotSpecified, CompositeTerm
|
||||
from zipline.pipeline.term import Term, CompositeTerm
|
||||
|
||||
|
||||
_VARIABLE_NAME_RE = re.compile("^(x_)([0-9]+)$")
|
||||
|
||||
@@ -54,6 +54,11 @@ _ops_to_commuted_methods = {
|
||||
'>=': '__le__',
|
||||
'>': '__lt__',
|
||||
}
|
||||
_unary_ops_to_methods = {
|
||||
'-': '__neg__',
|
||||
'~': '__invert__',
|
||||
}
|
||||
|
||||
UNARY_OPS = {'-'}
|
||||
MATH_BINOPS = {'+', '-', '*', '/', '**', '%'}
|
||||
FILTER_BINOPS = {'&', '|'} # NumExpr doesn't support xor.
|
||||
@@ -151,6 +156,10 @@ def method_name_for_op(op, commute=False):
|
||||
return _ops_to_methods[op]
|
||||
|
||||
|
||||
def unary_op_name(op):
|
||||
return _unary_ops_to_methods[op]
|
||||
|
||||
|
||||
def is_comparison(op):
|
||||
return op in COMPARISONS
|
||||
|
||||
@@ -162,31 +171,17 @@ class NumericalExpression(CompositeTerm):
|
||||
Parameters
|
||||
----------
|
||||
expr : string
|
||||
A string suitable for passing to numexpr. All variables in 'expr'
|
||||
should be of the form "x_i", where i is the index of the corresponding
|
||||
factor input in 'binds'.
|
||||
A string suitable for passing to numexpr. All variables in 'expr'
|
||||
should be of the form "x_i", where i is the index of the corresponding
|
||||
factor input in 'binds'.
|
||||
binds : tuple
|
||||
A tuple of factors to use as inputs.
|
||||
A tuple of factors to use as inputs.
|
||||
dtype : np.dtype
|
||||
The dtype for the expression.
|
||||
"""
|
||||
window_length = 0
|
||||
|
||||
def __new__(cls, expr, binds):
|
||||
|
||||
# If our class doesn't have an explicit dtype set, infer one from the
|
||||
# inputs.
|
||||
|
||||
# FIXME: This doesn't take into account dtypes of constants, so it will
|
||||
# break if we have something like
|
||||
# factor(int64) + factor(int64) + 2.5.
|
||||
# The real fix for this is probably for the calling context to specify
|
||||
# dtypes.
|
||||
if cls.dtype is not NotSpecified:
|
||||
dtype = cls.dtype
|
||||
else:
|
||||
dtype = find_common_type(
|
||||
[factor.dtype for factor in binds],
|
||||
[],
|
||||
)
|
||||
def __new__(cls, expr, binds, dtype):
|
||||
return super(NumericalExpression, cls).__new__(
|
||||
cls,
|
||||
inputs=binds,
|
||||
|
||||
@@ -3,6 +3,10 @@ from .factor import (
|
||||
CustomFactor,
|
||||
)
|
||||
from .latest import Latest
|
||||
from .events import (
|
||||
BusinessDaysSincePreviousEarnings,
|
||||
BusinessDaysUntilNextEarnings,
|
||||
)
|
||||
from .technical import (
|
||||
MaxDrawdown,
|
||||
RSI,
|
||||
@@ -14,6 +18,8 @@ from .technical import (
|
||||
|
||||
__all__ = [
|
||||
'CustomFactor',
|
||||
'BusinessDaysSincePreviousEarnings',
|
||||
'BusinessDaysUntilNextEarnings',
|
||||
'Factor',
|
||||
'Latest',
|
||||
'MaxDrawdown',
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Factors describing information about event data (e.g. earnings
|
||||
announcements, acquisitions, dividends, etc.).
|
||||
"""
|
||||
from numpy import newaxis
|
||||
from zipline.pipeline.data.earnings import EarningsCalendar
|
||||
from zipline.utils.numpy_utils import (
|
||||
np_NaT,
|
||||
busday_count_mask_NaT,
|
||||
datetime64D_dtype,
|
||||
float64_dtype,
|
||||
)
|
||||
|
||||
from .factor import Factor
|
||||
|
||||
|
||||
class BusinessDaysUntilNextEarnings(Factor):
|
||||
"""
|
||||
Factor returning the number of **business days** (not trading days!) until
|
||||
the next known earnings date for each asset.
|
||||
|
||||
This doesn't use trading days because the trading calendar includes
|
||||
information that may not have been available to the algorithm at the time
|
||||
when `compute` is called.
|
||||
|
||||
For example, the NYSE closings September 11th 2001, would not have been
|
||||
known to the algorithm on September 10th.
|
||||
|
||||
Assets that announced or will announce earnings today will produce a value
|
||||
of 0.0. Assets that will announce earnings on the next upcoming business
|
||||
day will produce a value of 1.0.
|
||||
|
||||
Assets for which `EarningsCalendar.next_announcement` is `NaT` will produce
|
||||
a value of `NaN`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
BusinessDaysSincePreviousEarnings
|
||||
"""
|
||||
inputs = [EarningsCalendar.next_announcement]
|
||||
window_length = 0
|
||||
dtype = float64_dtype
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
|
||||
# Coerce from [ns] to [D] for numpy busday_count.
|
||||
announce_dates = arrays[0].astype(datetime64D_dtype)
|
||||
|
||||
# Set masked values to NaT.
|
||||
announce_dates[~mask] = np_NaT
|
||||
|
||||
# Convert row labels into a column vector for broadcasted comparison.
|
||||
reference_dates = dates.values.astype(datetime64D_dtype)[:, newaxis]
|
||||
return busday_count_mask_NaT(reference_dates, announce_dates)
|
||||
|
||||
|
||||
class BusinessDaysSincePreviousEarnings(Factor):
|
||||
"""
|
||||
Factor returning the number of **business days** (not trading days!) since
|
||||
the most recent earnings date for each asset.
|
||||
|
||||
This doesn't use trading days for symmetry with
|
||||
BusinessDaysUntilNextEarnings.
|
||||
|
||||
Assets which announced or will announce earnings today will produce a value
|
||||
of 0.0. Assets that announced earnings on the previous business day will
|
||||
produce a value of 1.0.
|
||||
|
||||
Assets for which `EarningsCalendar.previous_announcement` is `NaT` will
|
||||
produce a value of `NaN`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
BusinessDaysUntilNextEarnings
|
||||
"""
|
||||
inputs = [EarningsCalendar.previous_announcement]
|
||||
window_length = 0
|
||||
dtype = float64_dtype
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
|
||||
# Coerce from [ns] to [D] for numpy busday_count.
|
||||
announce_dates = arrays[0].astype(datetime64D_dtype)
|
||||
|
||||
# Set masked values to NaT.
|
||||
announce_dates[~mask] = np_NaT
|
||||
|
||||
# Convert row labels into a column vector for broadcasted comparison.
|
||||
reference_dates = dates.values.astype(datetime64D_dtype)[:, newaxis]
|
||||
return busday_count_mask_NaT(announce_dates, reference_dates)
|
||||
@@ -4,19 +4,14 @@ factor.py
|
||||
from operator import attrgetter
|
||||
from numbers import Number
|
||||
|
||||
from numpy import (
|
||||
apply_along_axis,
|
||||
float64,
|
||||
nan,
|
||||
inf,
|
||||
)
|
||||
from scipy.stats import rankdata
|
||||
from numpy import float64, inf
|
||||
from toolz import curry
|
||||
|
||||
from zipline.errors import (
|
||||
UnknownRankMethod,
|
||||
UnsupportedDataType,
|
||||
)
|
||||
from zipline.lib.rank import rankdata_2d_ordinal
|
||||
from zipline.lib.rank import masked_rankdata_2d
|
||||
from zipline.pipeline.term import (
|
||||
CustomTermMixin,
|
||||
NotSpecified,
|
||||
@@ -33,17 +28,67 @@ from zipline.pipeline.expression import (
|
||||
NumericalExpression,
|
||||
NUMEXPR_MATH_FUNCS,
|
||||
UNARY_OPS,
|
||||
unary_op_name,
|
||||
)
|
||||
from zipline.pipeline.filters import (
|
||||
NumExprFilter,
|
||||
PercentileFilter,
|
||||
)
|
||||
from zipline.utils.control_flow import nullctx
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
)
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
_RANK_METHODS = frozenset(['average', 'min', 'max', 'dense', 'ordinal'])
|
||||
|
||||
|
||||
def numbers_to_float64(func, argname, argvalue):
|
||||
"""
|
||||
Preprocessor for converting numerical inputs into floats.
|
||||
|
||||
This is used in the binary operator constructors for Factor so that
|
||||
`2 + Factor()` has the same behavior as `2.0 + Factor()`.
|
||||
"""
|
||||
if isinstance(argvalue, Number):
|
||||
return float64(argvalue)
|
||||
return argvalue
|
||||
|
||||
|
||||
@curry
|
||||
def set_attribute(name, value):
|
||||
"""
|
||||
Decorator factory for setting attributes on a function.
|
||||
|
||||
Doesn't change the behavior of the wrapped function.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> @set_attribute('__name__', 'foo')
|
||||
... def bar():
|
||||
... return 3
|
||||
...
|
||||
>>> bar()
|
||||
3
|
||||
>>> bar.__name__
|
||||
'foo'
|
||||
"""
|
||||
def decorator(f):
|
||||
setattr(f, name, value)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
|
||||
# Decorators for setting the __name__ and __doc__ properties of a decorated
|
||||
# function.
|
||||
# Example:
|
||||
with_name = set_attribute('__name__')
|
||||
with_doc = set_attribute('__doc__')
|
||||
|
||||
|
||||
def binop_return_type(op):
|
||||
if is_comparison(op):
|
||||
return NumExprFilter
|
||||
@@ -51,6 +96,46 @@ def binop_return_type(op):
|
||||
return NumExprFactor
|
||||
|
||||
|
||||
def binop_return_dtype(op, left, right):
|
||||
"""
|
||||
Compute the expected return dtype for the given binary operator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op : str
|
||||
Operator symbol, (e.g. '+', '-', ...).
|
||||
left : numpy.dtype
|
||||
Dtype of left hand side.
|
||||
right : numpy.dtype
|
||||
Dtype of right hand side.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outdtype : numpy.dtype
|
||||
The dtype of the result of `left <op> right`.
|
||||
"""
|
||||
if is_comparison(op):
|
||||
if left != right:
|
||||
raise TypeError(
|
||||
"Don't know how to compute {left} {op} {right}.\n"
|
||||
"Comparisons are only supported between Factors of equal "
|
||||
"dtypes.".format(left=left, op=op, right=right)
|
||||
)
|
||||
return bool_dtype
|
||||
|
||||
elif left != float64_dtype or right != float64_dtype:
|
||||
raise TypeError(
|
||||
"Don't know how to compute {left} {op} {right}.\n"
|
||||
"Arithmetic operators are only supported on Factors of "
|
||||
"dtype 'float64'.".format(
|
||||
left=left.name,
|
||||
op=op,
|
||||
right=right.name,
|
||||
)
|
||||
)
|
||||
return float64_dtype
|
||||
|
||||
|
||||
def binary_operator(op):
|
||||
"""
|
||||
Factory function for making binary operator methods on a Factor subclass.
|
||||
@@ -63,6 +148,9 @@ def binary_operator(op):
|
||||
# NumericalExpression operator.
|
||||
commuted_method_getter = attrgetter(method_name_for_op(op, commute=True))
|
||||
|
||||
@preprocess(other=numbers_to_float64)
|
||||
@with_doc("Binary Operator: '%s'" % op)
|
||||
@with_name(method_name_for_op(op))
|
||||
def binary_operator(self, other):
|
||||
# This can't be hoisted up a scope because the types returned by
|
||||
# binop_return_type aren't defined when the top-level function is
|
||||
@@ -79,6 +167,7 @@ def binary_operator(op):
|
||||
right=other_expr,
|
||||
),
|
||||
new_inputs,
|
||||
dtype=binop_return_dtype(op, self.dtype, other.dtype),
|
||||
)
|
||||
elif isinstance(other, NumExprFactor):
|
||||
# NumericalExpression overrides ops to correctly handle merging of
|
||||
@@ -90,19 +179,22 @@ def binary_operator(op):
|
||||
return return_type(
|
||||
"x_0 {op} x_0".format(op=op),
|
||||
(self,),
|
||||
dtype=binop_return_dtype(op, self.dtype, other.dtype),
|
||||
)
|
||||
return return_type(
|
||||
"x_0 {op} x_1".format(op=op),
|
||||
(self, other),
|
||||
dtype=binop_return_dtype(op, self.dtype, other.dtype),
|
||||
)
|
||||
elif isinstance(other, Number):
|
||||
return return_type(
|
||||
"x_0 {op} ({constant})".format(op=op, constant=other),
|
||||
binds=(self,),
|
||||
# Interpret numeric literals as floats.
|
||||
dtype=binop_return_dtype(op, self.dtype, other.dtype)
|
||||
)
|
||||
raise BadBinaryOperator(op, self, other)
|
||||
|
||||
binary_operator.__doc__ = "Binary Operator: '%s'" % op
|
||||
return binary_operator
|
||||
|
||||
|
||||
@@ -115,6 +207,8 @@ def reflected_binary_operator(op):
|
||||
"""
|
||||
assert not is_comparison(op)
|
||||
|
||||
@preprocess(other=numbers_to_float64)
|
||||
@with_name(method_name_for_op(op, commute=True))
|
||||
def reflected_binary_operator(self, other):
|
||||
|
||||
if isinstance(self, NumericalExpression):
|
||||
@@ -128,6 +222,7 @@ def reflected_binary_operator(op):
|
||||
op=op,
|
||||
),
|
||||
new_inputs,
|
||||
dtype=binop_return_dtype(op, other.dtype, self.dtype)
|
||||
)
|
||||
|
||||
# Only have to handle the numeric case because in all other valid cases
|
||||
@@ -136,6 +231,7 @@ def reflected_binary_operator(op):
|
||||
return NumExprFactor(
|
||||
"{constant} {op} x_0".format(op=op, constant=other),
|
||||
binds=(self,),
|
||||
dtype=binop_return_dtype(op, other.dtype, self.dtype),
|
||||
)
|
||||
raise BadBinaryOperator(op, other, self)
|
||||
return reflected_binary_operator
|
||||
@@ -145,12 +241,26 @@ def unary_operator(op):
|
||||
"""
|
||||
Factory function for making unary operator methods for Factors.
|
||||
"""
|
||||
# Only negate is currently supported for all our possible input types.
|
||||
# Only negate is currently supported.
|
||||
valid_ops = {'-'}
|
||||
if op not in valid_ops:
|
||||
raise ValueError("Invalid unary operator %s." % op)
|
||||
|
||||
@with_doc("Unary Operator: '%s'" % op)
|
||||
@with_name(unary_op_name(op))
|
||||
def unary_operator(self):
|
||||
if self.dtype != float64_dtype:
|
||||
raise TypeError(
|
||||
"Can't apply unary operator {op!r} to instance of "
|
||||
"{typename!r} with dtype {dtypename!r}.\n"
|
||||
"{op!r} is only supported for Factors of dtype "
|
||||
"'float64'.".format(
|
||||
op=op,
|
||||
typename=type(self).__name__,
|
||||
dtypename=self.dtype.name,
|
||||
)
|
||||
)
|
||||
|
||||
# This can't be hoisted up a scope because the types returned by
|
||||
# unary_op_return_type aren't defined when the top-level function is
|
||||
# invoked.
|
||||
@@ -158,11 +268,14 @@ def unary_operator(op):
|
||||
return NumExprFactor(
|
||||
"{op}({expr})".format(op=op, expr=self._expr),
|
||||
self.inputs,
|
||||
dtype=float64_dtype,
|
||||
)
|
||||
else:
|
||||
return NumExprFactor("{op}x_0".format(op=op), (self,))
|
||||
|
||||
unary_operator.__doc__ = "Unary Operator: '%s'" % op
|
||||
return NumExprFactor(
|
||||
"{op}x_0".format(op=op),
|
||||
(self,),
|
||||
dtype=float64_dtype,
|
||||
)
|
||||
return unary_operator
|
||||
|
||||
|
||||
@@ -174,23 +287,30 @@ def function_application(func):
|
||||
if func not in NUMEXPR_MATH_FUNCS:
|
||||
raise ValueError("Unsupported mathematical function '%s'" % func)
|
||||
|
||||
@with_name(func)
|
||||
def mathfunc(self):
|
||||
if isinstance(self, NumericalExpression):
|
||||
return NumExprFactor(
|
||||
"{func}({expr})".format(func=func, expr=self._expr),
|
||||
self.inputs,
|
||||
dtype=float64_dtype,
|
||||
)
|
||||
else:
|
||||
return NumExprFactor("{func}(x_0)".format(func=func), (self,))
|
||||
return NumExprFactor(
|
||||
"{func}(x_0)".format(func=func),
|
||||
(self,),
|
||||
dtype=float64_dtype,
|
||||
)
|
||||
return mathfunc
|
||||
|
||||
|
||||
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype])
|
||||
|
||||
|
||||
class Factor(CompositeTerm):
|
||||
"""
|
||||
Pipeline API expression producing numerically-valued outputs.
|
||||
"""
|
||||
dtype = float64
|
||||
|
||||
# Dynamically add functions for creating NumExprFactor/NumExprFilter
|
||||
# instances.
|
||||
clsdict = locals()
|
||||
@@ -210,10 +330,11 @@ class Factor(CompositeTerm):
|
||||
)
|
||||
clsdict.update(
|
||||
{
|
||||
'__neg__': unary_operator(op)
|
||||
unary_op_name(op): unary_operator(op)
|
||||
for op in UNARY_OPS
|
||||
}
|
||||
)
|
||||
|
||||
clsdict.update(
|
||||
{
|
||||
funcname: function_application(funcname)
|
||||
@@ -226,6 +347,17 @@ class Factor(CompositeTerm):
|
||||
|
||||
eq = binary_operator('==')
|
||||
|
||||
def _validate(self):
|
||||
# Do superclass validation first so that `NotSpecified` dtypes get
|
||||
# handled.
|
||||
retval = super(Factor, self)._validate()
|
||||
if self.dtype not in FACTOR_DTYPES:
|
||||
raise UnsupportedDataType(
|
||||
typename=type(self).__name__,
|
||||
dtype=self.dtype
|
||||
)
|
||||
return retval
|
||||
|
||||
def rank(self, method='ordinal', ascending=True, mask=NotSpecified):
|
||||
"""
|
||||
Construct a new Factor representing the sorted rank of each column
|
||||
@@ -266,7 +398,7 @@ class Factor(CompositeTerm):
|
||||
zipline.lib.rank
|
||||
zipline.pipeline.factors.Rank
|
||||
"""
|
||||
return Rank(self if ascending else -self, method=method, mask=mask)
|
||||
return Rank(self, method=method, ascending=ascending, mask=mask)
|
||||
|
||||
def top(self, N, mask=NotSpecified):
|
||||
"""
|
||||
@@ -347,6 +479,10 @@ class Factor(CompositeTerm):
|
||||
def isnan(self):
|
||||
"""
|
||||
A Filter producing True for all values where this Factor is NaN.
|
||||
|
||||
Returns
|
||||
-------
|
||||
nanfilter : zipline.pipeline.filters.Filter
|
||||
"""
|
||||
return self != self
|
||||
|
||||
@@ -413,25 +549,28 @@ class Rank(SingleInputMixin, Factor):
|
||||
instance of this class.
|
||||
"""
|
||||
window_length = 0
|
||||
dtype = float64
|
||||
dtype = float64_dtype
|
||||
|
||||
def __new__(cls, factor, method, mask):
|
||||
def __new__(cls, factor, method, ascending, mask):
|
||||
return super(Rank, cls).__new__(
|
||||
cls,
|
||||
inputs=(factor,),
|
||||
method=method,
|
||||
ascending=ascending,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
def _init(self, method, *args, **kwargs):
|
||||
def _init(self, method, ascending, *args, **kwargs):
|
||||
self._method = method
|
||||
self._ascending = ascending
|
||||
return super(Rank, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, method, *args, **kwargs):
|
||||
def static_identity(cls, method, ascending, *args, **kwargs):
|
||||
return (
|
||||
super(Rank, cls).static_identity(*args, **kwargs),
|
||||
method,
|
||||
ascending,
|
||||
)
|
||||
|
||||
def _validate(self):
|
||||
@@ -450,23 +589,13 @@ class Rank(SingleInputMixin, Factor):
|
||||
For each row in the input, compute a like-shaped array of per-row
|
||||
ranks.
|
||||
"""
|
||||
inv_mask = ~mask
|
||||
data = arrays[0].copy()
|
||||
data[inv_mask] = nan
|
||||
# OPTIMIZATION: Fast path the default case with our own specialized
|
||||
# Cython implementation.
|
||||
if self._method == 'ordinal':
|
||||
result = rankdata_2d_ordinal(data)
|
||||
else:
|
||||
# FUTURE OPTIMIZATION:
|
||||
# Write a less general "apply to rows" method that doesn't do all
|
||||
# the extra work that apply_along_axis does.
|
||||
result = apply_along_axis(rankdata, 1, data, method=self._method)
|
||||
|
||||
# rankdata will sort nan values into last place, but we want our
|
||||
# nans to propagate, so explicitly re-apply.
|
||||
result[inv_mask] = nan
|
||||
return result
|
||||
return masked_rankdata_2d(
|
||||
arrays[0],
|
||||
mask,
|
||||
self.inputs[0].missing_value,
|
||||
self._method,
|
||||
self._ascending,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "{type}({input_}, method='{method}', mask={mask})".format(
|
||||
@@ -513,7 +642,7 @@ class CustomFactor(RequiredWindowLengthMixin, CustomTermMixin, Factor):
|
||||
Row label for the last row of all arrays passed as `inputs`.
|
||||
assets : np.array[int64, ndim=1]
|
||||
Column labels for `out` and`inputs`.
|
||||
out : np.array[float64, ndim=1]
|
||||
out : np.array[self.dtype, ndim=1]
|
||||
Output array of the same shape as `assets`. `compute` should write
|
||||
its desired return values into `out`.
|
||||
*inputs : tuple of np.array
|
||||
@@ -581,9 +710,5 @@ class CustomFactor(RequiredWindowLengthMixin, CustomTermMixin, Factor):
|
||||
median_close10 = MedianValue([USEquityPricing.close], window_length=10)
|
||||
median_low15 = MedianValue([USEquityPricing.low], window_length=15)
|
||||
'''
|
||||
dtype = float64_dtype
|
||||
ctx = nullctx()
|
||||
|
||||
def _validate(self):
|
||||
if self.dtype != float64:
|
||||
raise UnsupportedDataType(dtype=self.dtype)
|
||||
return super(CustomFactor, self)._validate()
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
filter.py
|
||||
"""
|
||||
from numpy import (
|
||||
bool_,
|
||||
float64,
|
||||
nan,
|
||||
nanpercentile,
|
||||
@@ -12,10 +11,13 @@ from operator import attrgetter
|
||||
|
||||
from zipline.errors import (
|
||||
BadPercentileBounds,
|
||||
UnsupportedDataType,
|
||||
)
|
||||
from zipline.pipeline.term import (
|
||||
SingleInputMixin,
|
||||
CompositeTerm,
|
||||
CustomTermMixin,
|
||||
RequiredWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
@@ -23,6 +25,8 @@ from zipline.pipeline.expression import (
|
||||
method_name_for_op,
|
||||
NumericalExpression,
|
||||
)
|
||||
from zipline.utils.control_flow import nullctx
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
|
||||
|
||||
def concat_tuples(*tuples):
|
||||
@@ -49,7 +53,7 @@ def binary_operator(op):
|
||||
self_expr, other_expr, new_inputs = self.build_binary_op(
|
||||
op, other,
|
||||
)
|
||||
return NumExprFilter(
|
||||
return NumExprFilter.create(
|
||||
"({left}) {op} ({right})".format(
|
||||
left=self_expr,
|
||||
op=op,
|
||||
@@ -64,16 +68,16 @@ def binary_operator(op):
|
||||
return commuted_method_getter(other)(self)
|
||||
elif isinstance(other, Filter):
|
||||
if self is other:
|
||||
return NumExprFilter(
|
||||
return NumExprFilter.create(
|
||||
"x_0 {op} x_0".format(op=op),
|
||||
(self,),
|
||||
)
|
||||
return NumExprFilter(
|
||||
return NumExprFilter.create(
|
||||
"x_0 {op} x_1".format(op=op),
|
||||
(self, other),
|
||||
)
|
||||
elif isinstance(other, int): # Note that this is true for bool as well
|
||||
return NumExprFilter(
|
||||
return NumExprFilter.create(
|
||||
"x_0 {op} ({constant})".format(op=op, constant=int(other)),
|
||||
binds=(self,),
|
||||
)
|
||||
@@ -96,12 +100,12 @@ def unary_operator(op):
|
||||
# unary_op_return_type aren't defined when the top-level function is
|
||||
# invoked.
|
||||
if isinstance(self, NumericalExpression):
|
||||
return NumExprFilter(
|
||||
return NumExprFilter.create(
|
||||
"{op}({expr})".format(op=op, expr=self._expr),
|
||||
self.inputs,
|
||||
)
|
||||
else:
|
||||
return NumExprFilter("{op}x_0".format(op=op), (self,))
|
||||
return NumExprFilter.create("{op}x_0".format(op=op), (self,))
|
||||
|
||||
unary_operator.__doc__ = "Unary Operator: '%s'" % op
|
||||
return unary_operator
|
||||
@@ -111,7 +115,7 @@ class Filter(CompositeTerm):
|
||||
"""
|
||||
Pipeline API expression producing boolean-valued outputs.
|
||||
"""
|
||||
dtype = bool_
|
||||
dtype = bool_dtype
|
||||
|
||||
clsdict = locals()
|
||||
clsdict.update(
|
||||
@@ -122,12 +126,34 @@ class Filter(CompositeTerm):
|
||||
)
|
||||
__invert__ = unary_operator('~')
|
||||
|
||||
def _validate(self):
|
||||
# Run superclass validation first so that we handle `dtype not passed`
|
||||
# before this.
|
||||
retval = super(Filter, self)._validate()
|
||||
if self.dtype != bool_dtype:
|
||||
raise UnsupportedDataType(
|
||||
typename=type(self).__name__,
|
||||
dtype=self.dtype
|
||||
)
|
||||
return retval
|
||||
|
||||
|
||||
class NumExprFilter(NumericalExpression, Filter):
|
||||
"""
|
||||
A Filter computed from a numexpr expression.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, expr, binds):
|
||||
"""
|
||||
Helper for creating new NumExprFactors.
|
||||
|
||||
This is just a wrapper around NumExprFactor.__new__ that always
|
||||
forwards `bool` as the dtype, since Filters can only be of boolean
|
||||
dtype.
|
||||
"""
|
||||
return cls(expr=expr, binds=binds, dtype=bool_dtype)
|
||||
|
||||
def _compute(self, arrays, dates, assets, mask):
|
||||
"""
|
||||
Compute our result with numexpr, then re-apply `mask`.
|
||||
@@ -215,3 +241,10 @@ class PercentileFilter(SingleInputMixin, Filter):
|
||||
keepdims=True,
|
||||
)
|
||||
return (lower_bounds <= data) & (data <= upper_bounds)
|
||||
|
||||
|
||||
class CustomFilter(RequiredWindowLengthMixin, CustomTermMixin, Filter):
|
||||
"""
|
||||
Filter analog to ``CustomFactor``.
|
||||
"""
|
||||
ctx = nullctx()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from .earnings import EarningsCalendarLoader
|
||||
from .equity_pricing_loader import USEquityPricingLoader
|
||||
|
||||
__all__ = ['USEquityPricingLoader']
|
||||
__all__ = [
|
||||
'EarningsCalendarLoader',
|
||||
'USEquityPricingLoader',
|
||||
]
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from .core import (
|
||||
AD_FIELD_NAME,
|
||||
BlazeLoader,
|
||||
NoDeltasWarning,
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
from_blaze,
|
||||
global_loader,
|
||||
)
|
||||
from .earnings import (
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
BlazeEarningsCalendarLoader,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
'AD_FIELD_NAME',
|
||||
'ANNOUNCEMENT_FIELD_NAME',
|
||||
'BlazeEarningsCalendarLoader',
|
||||
'BlazeLoader',
|
||||
'NoDeltasWarning',
|
||||
'SID_FIELD_NAME',
|
||||
'TS_FIELD_NAME',
|
||||
'from_blaze',
|
||||
'global_loader',
|
||||
)
|
||||
@@ -157,9 +157,10 @@ import toolz.curried.operator as op
|
||||
from six import with_metaclass, PY2, itervalues
|
||||
|
||||
|
||||
from ..data.dataset import DataSet, Column
|
||||
from zipline.lib.adjusted_array import adjusted_array
|
||||
from zipline.pipeline.data.dataset import DataSet, Column
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
from zipline.utils.enum import enum
|
||||
from zipline.utils.input_validation import expect_element
|
||||
from zipline.utils.numpy_utils import repeat_last_axis
|
||||
|
||||
@@ -306,7 +307,7 @@ def new_dataset(expr, deltas):
|
||||
except TypeError:
|
||||
col = NonNumpyField(name, type_)
|
||||
else:
|
||||
col = Column(type_.to_numpy_dtype().type)
|
||||
col = Column(type_.to_numpy_dtype())
|
||||
|
||||
columns[name] = col
|
||||
|
||||
@@ -391,10 +392,10 @@ class NoDeltasWarning(UserWarning):
|
||||
return 'No deltas could be inferred from expr: %s' % self._expr
|
||||
|
||||
|
||||
_valid_no_deltas_rules = 'warn', 'raise', 'ignore'
|
||||
no_deltas_rules = enum('warn', 'raise_', 'ignore')
|
||||
|
||||
|
||||
def _get_deltas(expr, deltas, no_deltas_rule):
|
||||
def get_deltas(expr, deltas, no_deltas_rule):
|
||||
"""Find the correct deltas for the expression.
|
||||
|
||||
Parameters
|
||||
@@ -406,7 +407,7 @@ def _get_deltas(expr, deltas, no_deltas_rule):
|
||||
be searched for by walking up the expression tree. If this cannot be
|
||||
reflected, then an action will be taken based on the
|
||||
``no_deltas_rule``.
|
||||
no_deltas_rule : {'warn', 'raise', 'ignore'}
|
||||
no_deltas_rule : no_deltas_rule
|
||||
How to handle the case where deltas='auto' but no deltas could be
|
||||
found.
|
||||
|
||||
@@ -421,11 +422,11 @@ def _get_deltas(expr, deltas, no_deltas_rule):
|
||||
try:
|
||||
return expr._child[(expr._name or '') + '_deltas']
|
||||
except (ValueError, AttributeError):
|
||||
if no_deltas_rule == 'raise':
|
||||
if no_deltas_rule == no_deltas_rules.raise_:
|
||||
raise ValueError(
|
||||
"no deltas table could be reflected for %s" % expr
|
||||
)
|
||||
elif no_deltas_rule == 'warn':
|
||||
elif no_deltas_rule == no_deltas_rules.warn:
|
||||
warnings.warn(NoDeltasWarning(expr))
|
||||
return None
|
||||
|
||||
@@ -466,12 +467,12 @@ def _ensure_timestamp_field(dataset_expr, deltas):
|
||||
return dataset_expr, deltas
|
||||
|
||||
|
||||
@expect_element(no_deltas_rule=_valid_no_deltas_rules)
|
||||
@expect_element(no_deltas_rule=no_deltas_rules)
|
||||
def from_blaze(expr,
|
||||
deltas='auto',
|
||||
loader=None,
|
||||
resources=None,
|
||||
no_deltas_rule=_valid_no_deltas_rules[0]):
|
||||
no_deltas_rule=no_deltas_rules.warn):
|
||||
"""Create a Pipeline API object from a blaze expression.
|
||||
|
||||
Parameters
|
||||
@@ -490,7 +491,7 @@ def from_blaze(expr,
|
||||
resources : dict or any, optional
|
||||
The data to execute the blaze expressions against. This is used as the
|
||||
scope for ``bz.compute``.
|
||||
no_deltas_rule : {'warn', 'raise', 'ignore'}
|
||||
no_deltas_rule : no_deltas_rule
|
||||
What should happen if ``deltas='auto'`` but no deltas can be found.
|
||||
'warn' says to raise a warning but continue.
|
||||
'raise' says to raise an exception if no deltas can be found.
|
||||
@@ -505,7 +506,7 @@ def from_blaze(expr,
|
||||
is passed, a ``BoundColumn`` on the dataset that would be constructed
|
||||
from passing the parent is returned.
|
||||
"""
|
||||
deltas = _get_deltas(expr, deltas, no_deltas_rule)
|
||||
deltas = get_deltas(expr, deltas, no_deltas_rule)
|
||||
if deltas is not None:
|
||||
invalid_nodes = tuple(filter(is_invalid_deltas_node, expr._subterms()))
|
||||
if invalid_nodes:
|
||||
@@ -886,7 +887,7 @@ class BlazeLoader(dict):
|
||||
|
||||
for column_idx, column in enumerate(columns):
|
||||
column_name = column.name
|
||||
yield column, adjusted_array(
|
||||
yield column, AdjustedArray(
|
||||
column_view(
|
||||
dense_output[column_name].values.astype(column.dtype),
|
||||
),
|
||||
@@ -0,0 +1,150 @@
|
||||
import blaze as bz
|
||||
from datashape import istabular
|
||||
from odo import odo
|
||||
import pandas as pd
|
||||
from six import iteritems
|
||||
from toolz import valmap
|
||||
|
||||
from .core import TS_FIELD_NAME, SID_FIELD_NAME
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
|
||||
|
||||
|
||||
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
|
||||
|
||||
|
||||
def bind_expression_to_resources(expr, resources):
|
||||
"""
|
||||
Bind a Blaze expression to resources.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : bz.Expr
|
||||
The expression to which we want to bind resources.
|
||||
resources : dict[bz.Symbol -> any]
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bound_expr : bz.Expr
|
||||
``expr`` with bound resources.
|
||||
"""
|
||||
# bind the resources into the expression
|
||||
if resources is None:
|
||||
resources = {}
|
||||
|
||||
# _subs stands for substitute. It's not actually private, blaze just
|
||||
# prefixes symbol-manipulation methods with underscores to prevent
|
||||
# collisions with data column names.
|
||||
return expr._subs({
|
||||
k: bz.Data(v, dshape=k.dshape) for k, v in iteritems(resources)
|
||||
})
|
||||
|
||||
|
||||
class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
"""A pipeline loader for the ``EarningsCalendar`` dataset that loads
|
||||
data from a blaze expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
resources : dict, optional
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The expression should have a tabular dshape of::
|
||||
|
||||
Dim * {{
|
||||
{SID_FIELD_NAME}: int64,
|
||||
{TS_FIELD_NAME}: datetime64,
|
||||
{ANNOUNCEMENT_FIELD_NAME}: datetime64,
|
||||
}}
|
||||
|
||||
Where each row of the table is a record including the sid to identify the
|
||||
company, the timestamp where we learned about the announcement, and the
|
||||
date when the earnings will be announced.
|
||||
|
||||
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
|
||||
start the backtest with knowledge of all announcements.
|
||||
"""
|
||||
__doc__ = __doc__.format(
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
ANNOUNCEMENT_FIELD_NAME=ANNOUNCEMENT_FIELD_NAME,
|
||||
)
|
||||
|
||||
_expected_fields = frozenset({
|
||||
TS_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
})
|
||||
|
||||
def __init__(self,
|
||||
expr,
|
||||
resources=None,
|
||||
compute_kwargs=None,
|
||||
odo_kwargs=None):
|
||||
dshape = expr.dshape
|
||||
|
||||
if not istabular(dshape):
|
||||
raise ValueError(
|
||||
'expression dshape must be tabular, got: %s' % dshape,
|
||||
)
|
||||
|
||||
expected_fields = self._expected_fields
|
||||
self._expr = bind_expression_to_resources(
|
||||
expr[list(expected_fields)],
|
||||
resources,
|
||||
)
|
||||
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
expr = self._expr
|
||||
filtered = expr[expr[TS_FIELD_NAME] <= dates[0]]
|
||||
lower = odo(
|
||||
bz.by(
|
||||
filtered[SID_FIELD_NAME],
|
||||
timestamp=filtered[TS_FIELD_NAME].max(),
|
||||
).timestamp.min(),
|
||||
pd.Timestamp,
|
||||
**self._odo_kwargs
|
||||
)
|
||||
if pd.isnull(lower):
|
||||
# If there is no lower date, just query for data in the date
|
||||
# range. It must all be null anyways.
|
||||
lower = dates[0]
|
||||
|
||||
raw = odo(
|
||||
expr[
|
||||
(expr[TS_FIELD_NAME] >= lower) &
|
||||
(expr[TS_FIELD_NAME] <= dates[-1])
|
||||
],
|
||||
pd.DataFrame,
|
||||
**self._odo_kwargs
|
||||
)
|
||||
|
||||
sids = raw.loc[:, SID_FIELD_NAME]
|
||||
raw.drop(
|
||||
sids[~(sids.isin(assets) | sids.notnull())].index,
|
||||
inplace=True
|
||||
)
|
||||
|
||||
gb = raw.groupby(SID_FIELD_NAME)
|
||||
|
||||
def mkseries(idx, raw_loc=raw.loc):
|
||||
vs = raw_loc[
|
||||
idx, [TS_FIELD_NAME, ANNOUNCEMENT_FIELD_NAME]
|
||||
].values
|
||||
return pd.Series(
|
||||
index=pd.DatetimeIndex(vs[:, 0]),
|
||||
data=vs[:, 1],
|
||||
)
|
||||
|
||||
return EarningsCalendarLoader(
|
||||
dates,
|
||||
valmap(mkseries, gb.groups),
|
||||
).load_adjusted_array(columns, dates, assets, mask)
|
||||
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Reference implementation for EarningsCalendar loaders.
|
||||
"""
|
||||
from itertools import repeat
|
||||
|
||||
from numpy import full_like, full
|
||||
import pandas as pd
|
||||
from six import iteritems
|
||||
from six.moves import zip
|
||||
from toolz import merge
|
||||
|
||||
from .base import PipelineLoader
|
||||
from .frame import DataFrameLoader
|
||||
from ..data.earnings import EarningsCalendar
|
||||
from zipline.utils.numpy_utils import np_NaT
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
|
||||
class EarningsCalendarLoader(PipelineLoader):
|
||||
"""
|
||||
Reference loader for
|
||||
:class:`zipline.pipeline.data.earnings.EarningsCalendar`.
|
||||
|
||||
Does not currently support adjustments to the dates of known earnings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
all_dates : pd.DatetimeIndex
|
||||
Index of dates for which we can serve queries.
|
||||
announcement_dates : dict[int -> pd.Series or pd.DatetimeIndex]
|
||||
Dict mapping sids to objects representing dates on which earnings
|
||||
occurred.
|
||||
|
||||
If a dict value is a Series, it's interpreted as a mapping from the
|
||||
date on which we learned an announcement was coming to the date on
|
||||
which the announcement was made.
|
||||
|
||||
If a dict value is a DatetimeIndex, it's interpreted as just containing
|
||||
the dates that announcements were made, and we assume we knew about the
|
||||
announcement on all prior dates. This mode is only supported if
|
||||
``infer_timestamp`` is explicitly passed as a truthy value.
|
||||
|
||||
infer_timestamps : bool, optional
|
||||
Whether to allow passing ``DatetimeIndex`` values in
|
||||
``announcement_dates``.
|
||||
"""
|
||||
def __init__(self, all_dates, announcement_dates, infer_timestamps=False):
|
||||
self.all_dates = all_dates
|
||||
|
||||
self.announcement_dates = announcement_dates = (
|
||||
announcement_dates.copy()
|
||||
)
|
||||
dates = self.all_dates.values
|
||||
for k, v in iteritems(announcement_dates):
|
||||
if isinstance(v, pd.DatetimeIndex):
|
||||
if not infer_timestamps:
|
||||
raise ValueError(
|
||||
"Got DatetimeIndex of announcement dates for sid %d.\n"
|
||||
"Pass `infer_timestamps=True` to use the first date in"
|
||||
" `all_dates` as implicit timestamp."
|
||||
)
|
||||
# If we are passed a DatetimeIndex, we always have
|
||||
# knowledge of the announcements.
|
||||
announcement_dates[k] = pd.Series(
|
||||
v, index=repeat(dates[0], len(v)),
|
||||
)
|
||||
|
||||
def get_loader(self, column):
|
||||
"""Dispatch to the loader for ``column``.
|
||||
"""
|
||||
if column is EarningsCalendar.next_announcement:
|
||||
return self.next_announcement_loader
|
||||
elif column is EarningsCalendar.previous_announcement:
|
||||
return self.previous_announcement_loader
|
||||
else:
|
||||
raise ValueError("Don't know how to load column '%s'." % column)
|
||||
|
||||
@lazyval
|
||||
def next_announcement_loader(self):
|
||||
return DataFrameLoader(
|
||||
EarningsCalendar.next_announcement,
|
||||
next_earnings_date_frame(
|
||||
self.all_dates,
|
||||
self.announcement_dates,
|
||||
),
|
||||
adjustments=None,
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def previous_announcement_loader(self):
|
||||
return DataFrameLoader(
|
||||
EarningsCalendar.previous_announcement,
|
||||
previous_earnings_date_frame(
|
||||
self.all_dates,
|
||||
self.announcement_dates,
|
||||
),
|
||||
adjustments=None,
|
||||
)
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
return merge(
|
||||
self.get_loader(column).load_adjusted_array(
|
||||
[column], dates, assets, mask
|
||||
)
|
||||
for column in columns
|
||||
)
|
||||
|
||||
|
||||
def next_earnings_date_frame(dates, announcement_dates):
|
||||
"""
|
||||
Make a DataFrame representing simulated next earnings dates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : pd.DatetimeIndex.
|
||||
The index of the returned DataFrame.
|
||||
announcement_dates : dict[int -> pd.Series]
|
||||
Dict mapping sids to an index of dates on which earnings were announced
|
||||
for that sid.
|
||||
|
||||
Returns
|
||||
-------
|
||||
next_earnings: pd.DataFrame
|
||||
A DataFrame representing, for each (label, date) pair, the first entry
|
||||
in `earnings_calendars[label]` on or after `date`. Entries falling
|
||||
after the last date in a calendar will have `np_NaT` as the result in
|
||||
the output.
|
||||
|
||||
See Also
|
||||
--------
|
||||
previous_earnings_date_frame
|
||||
"""
|
||||
cols = {equity: full_like(dates, np_NaT) for equity in announcement_dates}
|
||||
raw_dates = dates.values
|
||||
for equity, earnings_dates in iteritems(announcement_dates):
|
||||
data = cols[equity]
|
||||
if not earnings_dates.index.is_monotonic_increasing:
|
||||
earnings_dates = earnings_dates.sort_index()
|
||||
|
||||
# Iterate over the raw Series values, since we're comparing against
|
||||
# numpy arrays anyway.
|
||||
iterkv = zip(earnings_dates.index.values, earnings_dates.values)
|
||||
for timestamp, announce_date in iterkv:
|
||||
date_mask = (timestamp <= raw_dates) & (raw_dates <= announce_date)
|
||||
value_mask = (announce_date <= data) | (data == np_NaT)
|
||||
data[date_mask & value_mask] = announce_date
|
||||
|
||||
return pd.DataFrame(index=dates, data=cols)
|
||||
|
||||
|
||||
def previous_earnings_date_frame(dates, announcement_dates):
|
||||
"""
|
||||
Make a DataFrame representing simulated next earnings dates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : DatetimeIndex.
|
||||
The index of the returned DataFrame.
|
||||
announcement_dates : dict[int -> DatetimeIndex]
|
||||
Dict mapping sids to an index of dates on which earnings were announced
|
||||
for that sid.
|
||||
|
||||
Returns
|
||||
-------
|
||||
prev_earnings: pd.DataFrame
|
||||
A DataFrame representing, for (label, date) pair, the first entry in
|
||||
`announcement_dates[label]` strictly before `date`. Entries falling
|
||||
before the first date in a calendar will have `NaT` as the result in
|
||||
the output.
|
||||
|
||||
See Also
|
||||
--------
|
||||
next_earnings_date_frame
|
||||
"""
|
||||
sids = list(announcement_dates)
|
||||
out = full((len(dates), len(sids)), np_NaT, dtype='datetime64[ns]')
|
||||
dn = dates[-1].asm8
|
||||
for col_idx, sid in enumerate(sids):
|
||||
# announcement_dates[sid] is Series mapping knowledge_date to actual
|
||||
# announcement date. We don't care about the knowledge date for
|
||||
# computing previous earnings.
|
||||
values = announcement_dates[sid].values
|
||||
values = values[values <= dn]
|
||||
out[dates.searchsorted(values), col_idx] = values
|
||||
|
||||
frame = pd.DataFrame(out, index=dates, columns=sids)
|
||||
frame.ffill(inplace=True)
|
||||
return frame
|
||||
@@ -20,9 +20,7 @@ from zipline.data.us_equity_pricing import (
|
||||
BcolzDailyBarReader,
|
||||
SQLiteAdjustmentReader,
|
||||
)
|
||||
from zipline.lib.adjusted_array import (
|
||||
adjusted_array,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.errors import NoFurtherDataError
|
||||
|
||||
from .base import PipelineLoader
|
||||
@@ -84,7 +82,7 @@ class USEquityPricingLoader(PipelineLoader):
|
||||
assets,
|
||||
)
|
||||
adjusted_arrays = [
|
||||
adjusted_array(raw_array, mask, col_adjustments)
|
||||
AdjustedArray(raw_array, mask, col_adjustments)
|
||||
for raw_array, col_adjustments in zip(raw_arrays, adjustments)
|
||||
]
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
PipelineLoader accepting a DataFrame as input.
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from numpy import (
|
||||
ix_,
|
||||
zeros,
|
||||
@@ -11,22 +13,10 @@ from pandas import (
|
||||
Index,
|
||||
Int64Index,
|
||||
)
|
||||
from zipline.lib.adjusted_array import adjusted_array
|
||||
from zipline.lib.adjustment import (
|
||||
Float64Add,
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
)
|
||||
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import make_adjustment_from_labels
|
||||
from .base import PipelineLoader
|
||||
|
||||
|
||||
ADD, MULTIPLY, OVERWRITE = range(3)
|
||||
ADJUSTMENT_CONSTRUCTORS = {
|
||||
ADD: Float64Add.from_assets_and_dates,
|
||||
MULTIPLY: Float64Multiply.from_assets_and_dates,
|
||||
OVERWRITE: Float64Overwrite.from_assets_and_dates,
|
||||
}
|
||||
ADJUSTMENT_COLUMNS = Index([
|
||||
'sid',
|
||||
'value',
|
||||
@@ -91,7 +81,7 @@ class DataFrameLoader(PipelineLoader):
|
||||
def format_adjustments(self, dates, assets):
|
||||
"""
|
||||
Build a dict of Adjustment objects in the format expected by
|
||||
adjusted_array.
|
||||
AdjustedArray.
|
||||
|
||||
Returns a dict of the form:
|
||||
{
|
||||
@@ -105,6 +95,8 @@ class DataFrameLoader(PipelineLoader):
|
||||
...
|
||||
}
|
||||
"""
|
||||
make_adjustment = partial(make_adjustment_from_labels, dates, assets)
|
||||
|
||||
min_date, max_date = dates[[0, -1]]
|
||||
# TODO: Consider porting this to Cython.
|
||||
if len(self.adjustments) == 0:
|
||||
@@ -148,14 +140,7 @@ class DataFrameLoader(PipelineLoader):
|
||||
# Look up the approprate Adjustment constructor based on the value
|
||||
# of `kind`.
|
||||
current_date_adjustments.append(
|
||||
ADJUSTMENT_CONSTRUCTORS[kind](
|
||||
dates,
|
||||
assets,
|
||||
start_date,
|
||||
end_date,
|
||||
sid,
|
||||
value,
|
||||
),
|
||||
make_adjustment(start_date, end_date, sid, kind, value)
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -163,11 +148,12 @@ class DataFrameLoader(PipelineLoader):
|
||||
"""
|
||||
Load data from our stored baseline.
|
||||
"""
|
||||
column = self.column
|
||||
if len(columns) != 1:
|
||||
raise ValueError(
|
||||
"Can't load multiple columns with DataFrameLoader"
|
||||
)
|
||||
elif columns[0] != self.column:
|
||||
elif columns[0] != column:
|
||||
raise ValueError("Can't load unknown column %s" % columns[0])
|
||||
|
||||
date_indexer = self.dates.get_indexer(dates)
|
||||
@@ -177,11 +163,12 @@ class DataFrameLoader(PipelineLoader):
|
||||
good_dates = (date_indexer != -1)
|
||||
good_assets = (assets_indexer != -1)
|
||||
|
||||
arrays = [adjusted_array(
|
||||
# Pull out requested columns/rows from our baseline data.
|
||||
data=self.baseline[ix_(date_indexer, assets_indexer)],
|
||||
# Mask out requested columns/rows that didnt match.
|
||||
mask=(good_assets & good_dates[:, None]) & mask,
|
||||
adjustments=self.format_adjustments(dates, assets),
|
||||
)]
|
||||
return dict(zip(columns, arrays))
|
||||
return {
|
||||
column: AdjustedArray(
|
||||
# Pull out requested columns/rows from our baseline data.
|
||||
data=self.baseline[ix_(date_indexer, assets_indexer)],
|
||||
# Mask out requested columns/rows that didnt match.
|
||||
mask=(good_assets & good_dates[:, None]) & mask,
|
||||
adjustments=self.format_adjustments(dates, assets),
|
||||
),
|
||||
}
|
||||
|
||||
+68
-24
@@ -4,17 +4,19 @@ Base class for Filters, Factors and Classifiers
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import bool_, full, nan
|
||||
from numpy import full_like, dtype as dtype_class
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
TermInputsNotSpecified,
|
||||
WindowLengthNotPositive,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype, default_fillvalue_for_dtype
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
|
||||
@@ -54,9 +56,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
|
||||
if domain is NotSpecified:
|
||||
domain = cls.domain
|
||||
|
||||
if dtype is NotSpecified:
|
||||
dtype = cls.dtype
|
||||
dtype = cls._validate_dtype(dtype)
|
||||
|
||||
identity = cls.static_identity(
|
||||
domain=domain,
|
||||
@@ -75,6 +75,41 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
)
|
||||
return new_instance
|
||||
|
||||
@classmethod
|
||||
def _validate_dtype(cls, passed_dtype):
|
||||
"""
|
||||
Validate a `dtype` passed to Term.__new__.
|
||||
|
||||
If passed_dtype is NotSpecified, then we try to fall back to a
|
||||
class-level attribute. If a value is found at that point, we pass it
|
||||
to np.dtype so that users can pass `float` or `bool` and have them
|
||||
coerce to the appropriate numpy types.
|
||||
|
||||
Returns
|
||||
-------
|
||||
validated : np.dtype
|
||||
The dtype to use for the new term.
|
||||
|
||||
Raises
|
||||
------
|
||||
DTypeNotSpecified
|
||||
When no dtype was passed to the instance, and the class doesn't
|
||||
provide a default.
|
||||
InvalidDType
|
||||
When either the class or the instance provides a value not
|
||||
coercible to a numpy dtype.
|
||||
"""
|
||||
dtype = passed_dtype
|
||||
if dtype is NotSpecified:
|
||||
dtype = cls.dtype
|
||||
if dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=cls.__name__)
|
||||
try:
|
||||
dtype = dtype_class(dtype)
|
||||
except TypeError:
|
||||
raise InvalidDType(dtype=dtype, termname=cls.__name__)
|
||||
return dtype
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Noop constructor to play nicely with our caching __new__. Subclasses
|
||||
@@ -91,13 +126,6 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _init(self, domain, dtype):
|
||||
self.domain = domain
|
||||
self.dtype = dtype
|
||||
|
||||
self._validate()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, domain, dtype):
|
||||
"""
|
||||
@@ -113,13 +141,27 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
return (cls, domain, dtype)
|
||||
|
||||
def _init(self, domain, dtype):
|
||||
self.domain = domain
|
||||
self.dtype = dtype
|
||||
|
||||
# Make sure that subclasses call super() in their _validate() methods
|
||||
# by setting this flag. The base class implementation of _validate
|
||||
# should set this flag to True.
|
||||
self._subclass_called_super_validate = False
|
||||
self._validate()
|
||||
del self._subclass_called_super_validate
|
||||
|
||||
return self
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Assert that this term is well-formed. This should be called exactly
|
||||
once, at the end of Term._init().
|
||||
"""
|
||||
if self.dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=type(self).__name__)
|
||||
# mark that we got here to enforce that subclasses overriding _validate
|
||||
# call super().
|
||||
self._subclass_called_super_validate = True
|
||||
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
@@ -145,6 +187,10 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
return not any(dep for dep in self.dependencies
|
||||
if dep is not AssetExists())
|
||||
|
||||
@lazyval
|
||||
def missing_value(self):
|
||||
return default_fillvalue_for_dtype(self.dtype)
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
@@ -160,7 +206,7 @@ class AssetExists(Term):
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
"""
|
||||
dtype = bool_
|
||||
dtype = bool_dtype
|
||||
dataset = None
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
@@ -204,18 +250,15 @@ class CustomTermMixin(object):
|
||||
Used by CustomFactor, CustomFilter, CustomClassifier, etc.
|
||||
"""
|
||||
|
||||
def __new__(cls, inputs=NotSpecified, window_length=NotSpecified):
|
||||
|
||||
def __new__(cls,
|
||||
inputs=NotSpecified,
|
||||
window_length=NotSpecified,
|
||||
dtype=NotSpecified):
|
||||
return super(CustomTermMixin, cls).__new__(
|
||||
cls,
|
||||
inputs=inputs,
|
||||
window_length=window_length,
|
||||
)
|
||||
|
||||
def __init__(self, inputs=NotSpecified, window_length=NotSpecified):
|
||||
return super(CustomTermMixin, self).__init__(
|
||||
inputs=inputs,
|
||||
window_length=window_length,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def compute(self, today, assets, out, *arrays):
|
||||
@@ -231,7 +274,8 @@ class CustomTermMixin(object):
|
||||
"""
|
||||
# TODO: Make mask available to user's `compute`.
|
||||
compute = self.compute
|
||||
out = full(mask.shape, nan, dtype=self.dtype)
|
||||
missing_value = self.missing_value
|
||||
out = full_like(mask, missing_value, dtype=self.dtype)
|
||||
with self.ctx:
|
||||
# TODO: Consider pre-filtering columns that are all-nan at each
|
||||
# time-step?
|
||||
@@ -242,7 +286,7 @@ class CustomTermMixin(object):
|
||||
out[idx],
|
||||
*(next(w) for w in windows)
|
||||
)
|
||||
out[~mask] = nan
|
||||
out[~mask] = missing_value
|
||||
return out
|
||||
|
||||
def short_repr(self):
|
||||
|
||||
+3
-3
@@ -18,8 +18,8 @@ from six import iteritems, iterkeys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from . utils.protocol_utils import Enum
|
||||
from . utils.math_utils import nanstd, nanmean, nansum
|
||||
from .utils.enum import enum
|
||||
from .utils.math_utils import nanstd, nanmean, nansum
|
||||
|
||||
from zipline.utils.algo_instance import get_algo_instance
|
||||
from zipline.utils.serialization_utils import (
|
||||
@@ -28,7 +28,7 @@ from zipline.utils.serialization_utils import (
|
||||
|
||||
# Datasource type should completely determine the other fields of a
|
||||
# message with its type.
|
||||
DATASOURCE_TYPE = Enum(
|
||||
DATASOURCE_TYPE = enum(
|
||||
'AS_TRADED_EQUITY',
|
||||
'MERGER',
|
||||
'SPLIT',
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
#
|
||||
# Copyright 2015 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ctypes import (
|
||||
Structure,
|
||||
c_ubyte,
|
||||
c_uint,
|
||||
c_ulong,
|
||||
c_ulonglong,
|
||||
c_ushort,
|
||||
sizeof,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from six.moves import range
|
||||
|
||||
|
||||
_inttypes_map = {
|
||||
sizeof(t) - 1: t for t in {
|
||||
c_ubyte,
|
||||
c_uint,
|
||||
c_ulong,
|
||||
c_ulonglong,
|
||||
c_ushort
|
||||
}
|
||||
}
|
||||
_inttypes = list(
|
||||
pd.Series(_inttypes_map).reindex(
|
||||
range(max(_inttypes_map.keys())),
|
||||
method='bfill',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def enum(option, *options):
|
||||
"""
|
||||
Construct a new enum object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*options : iterable of str
|
||||
The names of the fields for the enum.
|
||||
|
||||
Returns
|
||||
-------
|
||||
enum
|
||||
A new enum collection.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> e = enum('a', 'b', 'c')
|
||||
>>> e
|
||||
<enum: ('a', 'b', 'c')>
|
||||
>>> e.a
|
||||
0
|
||||
>>> e.b
|
||||
1
|
||||
>>> e.a in e
|
||||
True
|
||||
>>> tuple(e)
|
||||
(0, 1, 2)
|
||||
|
||||
Notes
|
||||
-----
|
||||
Identity checking is not guaranteed to work with enum members, instead
|
||||
equality checks should be used. From CPython's documentation:
|
||||
|
||||
"The current implementation keeps an array of integer objects for all
|
||||
integers between -5 and 256, when you create an int in that range you
|
||||
actually just get back a reference to the existing object. So it should be
|
||||
possible to change the value of 1. I suspect the behaviour of Python in
|
||||
this case is undefined. :-)"
|
||||
"""
|
||||
options = (option,) + options
|
||||
rangeob = range(len(options))
|
||||
|
||||
try:
|
||||
inttype = _inttypes[int(np.log2(len(options) - 1)) // 8]
|
||||
except IndexError:
|
||||
raise OverflowError(
|
||||
'Cannot store enums with more than sys.maxsize elements, got %d' %
|
||||
len(options),
|
||||
)
|
||||
|
||||
class _enum(Structure):
|
||||
_fields_ = [(o, inttype) for o in options]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(rangeob)
|
||||
|
||||
def __contains__(self, value):
|
||||
return 0 <= value < len(options)
|
||||
|
||||
def __repr__(self):
|
||||
return '<enum: %s>' % (
|
||||
('%d fields' % len(options))
|
||||
if len(options) > 10 else
|
||||
repr(options)
|
||||
)
|
||||
|
||||
return _enum(*rangeob)
|
||||
@@ -11,8 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from operator import attrgetter
|
||||
|
||||
from six import iteritems, string_types
|
||||
from numpy import dtype
|
||||
from six import iteritems, string_types, PY3
|
||||
from toolz import valmap, complement, compose
|
||||
import toolz.curried.operator as op
|
||||
|
||||
@@ -30,6 +32,98 @@ def ensure_upper_case(func, argname, arg):
|
||||
)
|
||||
|
||||
|
||||
def ensure_dtype(func, argname, arg):
|
||||
"""
|
||||
Argument preprocessor that converts the input into a numpy dtype.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> import numpy as np
|
||||
>>> from zipline.utils.preprocess import preprocess
|
||||
>>> @preprocess(dtype=ensure_dtype)
|
||||
... def foo(dtype):
|
||||
... return dtype
|
||||
...
|
||||
>>> foo(float)
|
||||
dtype('float64')
|
||||
"""
|
||||
try:
|
||||
return dtype(arg)
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
"{func}() couldn't convert argument "
|
||||
"{argname}={arg!r} to a numpy dtype.".format(
|
||||
func=_qualified_name(func),
|
||||
argname=argname,
|
||||
arg=arg,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def expect_dtypes(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected numpy dtypes.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> from numpy import dtype, arange
|
||||
>>> @expect_dtypes(x=dtype(int))
|
||||
... def foo(x, y):
|
||||
... return x, y
|
||||
...
|
||||
>>> foo(arange(3), 'foo')
|
||||
(array([0, 1, 2]), 'foo')
|
||||
>>> foo(arange(3, dtype=float), 'foo')
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
TypeError: foo() expected an argument with dtype 'int64' for argument 'x', but got dtype 'float64' instead. # noqa
|
||||
"""
|
||||
if _pos:
|
||||
raise TypeError("expect_dtypes() only takes keyword arguments.")
|
||||
|
||||
for name, type_ in iteritems(named):
|
||||
if not isinstance(type_, (dtype, tuple)):
|
||||
raise TypeError(
|
||||
"expect_dtypes() expected a numpy dtype or tuple of dtypes"
|
||||
" for argument {name!r}, but got {dtype} instead.".format(
|
||||
name=name, dtype=dtype,
|
||||
)
|
||||
)
|
||||
return preprocess(**valmap(_expect_dtype, named))
|
||||
|
||||
|
||||
def _expect_dtype(_dtype_or_dtype_tuple):
|
||||
"""
|
||||
Factory for dtype-checking functions that work the @preprocess decorator.
|
||||
"""
|
||||
# Slightly different messages for dtype and tuple of dtypes.
|
||||
if isinstance(_dtype_or_dtype_tuple, tuple):
|
||||
allowed_dtypes = _dtype_or_dtype_tuple
|
||||
else:
|
||||
allowed_dtypes = (_dtype_or_dtype_tuple,)
|
||||
template = (
|
||||
"%(funcname)s() expected a value with dtype {dtype_str} "
|
||||
"for argument '%(argname)s', but got %(actual)r instead."
|
||||
).format(dtype_str=' or '.join(repr(d.name) for d in allowed_dtypes))
|
||||
|
||||
def check_dtype(value):
|
||||
return getattr(value, 'dtype', None) not in allowed_dtypes
|
||||
|
||||
def display_bad_value(value):
|
||||
# If the bad value has a dtype, but it's wrong, show the dtype name.
|
||||
try:
|
||||
return value.dtype.name
|
||||
except AttributeError:
|
||||
return value
|
||||
|
||||
return make_check(
|
||||
exc_type=TypeError,
|
||||
template=template,
|
||||
pred=check_dtype,
|
||||
actual=display_bad_value,
|
||||
)
|
||||
|
||||
|
||||
def expect_types(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected types.
|
||||
@@ -62,20 +156,43 @@ def expect_types(*_pos, **named):
|
||||
return preprocess(**valmap(_expect_type, named))
|
||||
|
||||
|
||||
def _qualified_name(obj):
|
||||
"""
|
||||
Return the fully-qualified name (ignoring inner classes) of a type.
|
||||
"""
|
||||
module = obj.__module__
|
||||
if module in ('__builtin__', '__main__', 'builtins'):
|
||||
return obj.__name__
|
||||
return '.'.join([module, obj.__name__])
|
||||
if PY3:
|
||||
_qualified_name = attrgetter('__qualname__')
|
||||
else:
|
||||
def _qualified_name(obj):
|
||||
"""
|
||||
Return the fully-qualified name (ignoring inner classes) of a type.
|
||||
"""
|
||||
module = obj.__module__
|
||||
if module in ('__builtin__', '__main__', 'builtins'):
|
||||
return obj.__name__
|
||||
return '.'.join([module, obj.__name__])
|
||||
|
||||
|
||||
def _mk_check(exc, template, pred, actual):
|
||||
def make_check(exc_type, template, pred, actual):
|
||||
"""
|
||||
Factory for making preprocessing functions that check a predicate on the
|
||||
input value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
exc_type : Exception
|
||||
The exception type to raise if the predicate fails.
|
||||
template : str
|
||||
A template string to use to create error messages.
|
||||
Should have %-style named template parameters for 'funcname',
|
||||
'argname', and 'actual'.
|
||||
pred : function[object -> bool]
|
||||
A function to call on the argument being preprocessed. If the
|
||||
predicate returns `True`, we raise an instance of `exc_type`.
|
||||
actual : function[object -> object]
|
||||
A function to call on bad values to produce the value to display in the
|
||||
error message.
|
||||
"""
|
||||
|
||||
def _check(func, argname, argvalue):
|
||||
if pred(argvalue):
|
||||
raise exc(
|
||||
raise exc_type(
|
||||
template % {
|
||||
'funcname': _qualified_name(func),
|
||||
'argname': argname,
|
||||
@@ -102,7 +219,7 @@ def _expect_type(type_):
|
||||
else:
|
||||
template = _template.format(type_or_types=_qualified_name(type_))
|
||||
|
||||
return _mk_check(
|
||||
return make_check(
|
||||
TypeError,
|
||||
template,
|
||||
lambda v: not isinstance(v, type_),
|
||||
@@ -171,7 +288,7 @@ def _expect_element(collection):
|
||||
"%(funcname)s() expected a value in {collection} "
|
||||
"for argument '%(argname)s', but got %(actual)s instead."
|
||||
).format(collection=collection)
|
||||
return _mk_check(
|
||||
return make_check(
|
||||
ValueError,
|
||||
template,
|
||||
complement(op.contains(collection)),
|
||||
|
||||
@@ -53,6 +53,9 @@ class lazyval(object):
|
||||
def __set__(self, instance, value):
|
||||
raise AttributeError("Can't set read-only attribute.")
|
||||
|
||||
def __delitem__(self, instance):
|
||||
del self._cache[instance]
|
||||
|
||||
|
||||
def remember_last(f):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,40 @@
|
||||
"""
|
||||
Utilities for working with numpy arrays.
|
||||
"""
|
||||
from numpy import (
|
||||
broadcast,
|
||||
busday_count,
|
||||
datetime64,
|
||||
dtype,
|
||||
empty,
|
||||
nan,
|
||||
where
|
||||
)
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
from toolz import flip
|
||||
|
||||
uint8_dtype = dtype('uint8')
|
||||
bool_dtype = dtype('bool')
|
||||
int64_dtype = dtype('int64')
|
||||
float64_dtype = dtype('float64')
|
||||
datetime64D_dtype = dtype('datetime64[D]')
|
||||
datetime64ns_dtype = dtype('datetime64[ns]')
|
||||
|
||||
make_datetime64ns = flip(datetime64, 'ns')
|
||||
make_datetime64D = flip(datetime64, 'D')
|
||||
np_NaT = make_datetime64ns('NaT')
|
||||
|
||||
_FILLVALUE_DEFAULTS = {
|
||||
float64_dtype: nan,
|
||||
datetime64ns_dtype: np_NaT,
|
||||
}
|
||||
|
||||
|
||||
def default_fillvalue_for_dtype(dtype):
|
||||
"""
|
||||
Get the default fill value for `dtype`.
|
||||
"""
|
||||
return _FILLVALUE_DEFAULTS[dtype]
|
||||
|
||||
|
||||
def repeat_first_axis(array, count):
|
||||
@@ -88,3 +121,37 @@ def repeat_last_axis(array, count):
|
||||
repeat_last_axis
|
||||
"""
|
||||
return as_strided(array, array.shape + (count,), array.strides + (0,))
|
||||
|
||||
|
||||
# Sentinel value that isn't NaT.
|
||||
_notNaT = make_datetime64D(0)
|
||||
|
||||
|
||||
def busday_count_mask_NaT(begindates, enddates, out=None):
|
||||
"""
|
||||
Simple of numpy.busday_count that returns `float` arrays rather than int
|
||||
arrays, and handles `NaT`s by returning `NaN`s where the inputs were `NaT`.
|
||||
|
||||
Doesn't support custom weekdays or calendars, but probably should in the
|
||||
future.
|
||||
|
||||
See Also
|
||||
--------
|
||||
np.busday_count
|
||||
"""
|
||||
if out is None:
|
||||
out = empty(broadcast(begindates, enddates).shape, dtype=float)
|
||||
|
||||
beginmask = (begindates == np_NaT)
|
||||
endmask = (enddates == np_NaT)
|
||||
|
||||
out = busday_count(
|
||||
# Temporarily fill in non-NaT values.
|
||||
where(beginmask, _notNaT, begindates),
|
||||
where(endmask, _notNaT, enddates),
|
||||
out=out,
|
||||
)
|
||||
|
||||
# Fill in entries where either comparison was NaT with nan in the output.
|
||||
out[beginmask | endmask] = nan
|
||||
return out
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
#
|
||||
# Copyright 2012 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ctypes import Structure, c_ubyte
|
||||
|
||||
|
||||
def Enum(*options):
|
||||
"""
|
||||
Fast enums are very important when we want really tight
|
||||
loops. These are probably going to evolve into pure C structs
|
||||
anyways so might as well get going on that.
|
||||
"""
|
||||
class cstruct(Structure):
|
||||
_fields_ = [(o, c_ubyte) for o in options]
|
||||
|
||||
def __iter__(s):
|
||||
return iter(range(len(options)))
|
||||
return cstruct(*range(len(options)))
|
||||
@@ -1,6 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from itertools import (
|
||||
combinations,
|
||||
count,
|
||||
product,
|
||||
)
|
||||
@@ -18,6 +19,7 @@ from pandas.tseries.offsets import MonthBegin
|
||||
from six import iteritems, itervalues
|
||||
from six.moves import filter
|
||||
from sqlalchemy import create_engine
|
||||
from toolz import concat
|
||||
|
||||
from zipline.assets import AssetFinder
|
||||
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
|
||||
@@ -603,3 +605,28 @@ def subtest(iterator, *_names):
|
||||
|
||||
return wrapped
|
||||
return dec
|
||||
|
||||
|
||||
def assert_timestamp_equal(left, right, compare_nat_equal=True, msg=""):
|
||||
"""
|
||||
Assert that two pandas Timestamp objects are the same.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
left, right : pd.Timestamp
|
||||
The values to compare.
|
||||
compare_nat_equal : bool, optional
|
||||
Whether to consider `NaT` values equal. Defaults to True.
|
||||
msg : str, optional
|
||||
A message to forward to `pd.util.testing.assert_equal`.
|
||||
"""
|
||||
if compare_nat_equal and left is pd.NaT and right is pd.NaT:
|
||||
return
|
||||
return pd.util.testing.assert_equal(left, right, msg=msg)
|
||||
|
||||
|
||||
def powerset(values):
|
||||
"""
|
||||
Return the power set (i.e., the set of all subsets) of entries in `values`.
|
||||
"""
|
||||
return concat(combinations(values, i) for i in range(len(values) + 1))
|
||||
|
||||
Reference in New Issue
Block a user