mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 18:13:58 +08:00
Merge pull request #962 from quantopian/pipeline-missing-values
Pipeline missing values
This commit is contained in:
@@ -17,6 +17,11 @@ Highlights
|
||||
* :class:`~zipline.assets.assets.AssetFinder` speedups (:issue:`830` and
|
||||
:issue:`817`).
|
||||
|
||||
* Improved support for non-float dtypes in Pipeline. Most notably, we now
|
||||
support ``datetime64`` and ``int64`` dtypes for ``Factor``, and
|
||||
``BoundColumn.latest`` now returns a proper ``Filter`` object when the column
|
||||
is of dtype ``bool``.
|
||||
|
||||
Enhancements
|
||||
~~~~~~~~~~~~
|
||||
|
||||
@@ -83,6 +88,17 @@ Enhancements
|
||||
data that is timestamped on or after ``8:45`` will not be seen on that day in the
|
||||
simulation. The data will be made available on the next day (:issue:`947`).
|
||||
|
||||
* ``BoundColumn.latest`` now returns a
|
||||
:class:`~zipline.pipeline.filters.Filter` for columns of dtype
|
||||
``bool`` (:issue:`962`).
|
||||
|
||||
* Added support for :class:`~zipline.pipeline.factors.Factor` instances with
|
||||
``int64`` dtype. :class:`~zipline.pipeline.data.dataset.Column` now requires
|
||||
a ``missing_value`` when dtype is integral. (:issue:`962`)
|
||||
|
||||
* It is also now possible to specify custom ``missing_value`` values for
|
||||
``float``, ``datetime``, and ``bool`` Pipeline terms. (:issue:`962`)
|
||||
|
||||
Experimental Features
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from numpy import (
|
||||
arange,
|
||||
array,
|
||||
full,
|
||||
where,
|
||||
)
|
||||
from numpy.testing import assert_array_equal
|
||||
from six.moves import zip_longest
|
||||
@@ -23,9 +24,21 @@ from zipline.lib.adjustment import (
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
make_datetime64ns,
|
||||
)
|
||||
from zipline.utils.test_utils import check_arrays, parameter_space
|
||||
|
||||
|
||||
def moving_window(array, nrows):
    """
    Yield successive length-``nrows`` slices of ``array`` along its first axis.

    Slices are produced in order, starting at row 0 and advancing one row per
    step.  The number of slices yielded is whatever
    ``num_windows_of_length_M_on_buffers_of_length_N(nrows, len(array))``
    reports.
    """
    n_windows = num_windows_of_length_M_on_buffers_of_length_N(
        nrows,
        len(array),
    )
    start = 0
    while start < n_windows:
        yield array[start:start + nrows]
        start += 1
|
||||
|
||||
|
||||
def num_windows_of_length_M_on_buffers_of_length_N(M, N):
|
||||
@@ -66,6 +79,7 @@ def _gen_unadjusted_cases(dtype):
|
||||
nrows = 6
|
||||
ncols = 3
|
||||
data = arange(nrows * ncols).astype(dtype).reshape(nrows, ncols)
|
||||
missing_value = default_missing_value_for_dtype(dtype)
|
||||
|
||||
for windowlen in valid_window_lengths(nrows):
|
||||
|
||||
@@ -78,6 +92,7 @@ def _gen_unadjusted_cases(dtype):
|
||||
data,
|
||||
windowlen,
|
||||
{},
|
||||
missing_value,
|
||||
[
|
||||
data[offset:offset + windowlen]
|
||||
for offset in range(num_legal_windows)
|
||||
@@ -230,6 +245,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
|
||||
def _gen_expectations(baseline, adjustments, buffer_as_of, nrows):
|
||||
|
||||
missing_value = default_missing_value_for_dtype(baseline.dtype)
|
||||
for windowlen in valid_window_lengths(nrows):
|
||||
|
||||
num_legal_windows = num_windows_of_length_M_on_buffers_of_length_N(
|
||||
@@ -241,6 +257,7 @@ def _gen_expectations(baseline, adjustments, buffer_as_of, nrows):
|
||||
baseline,
|
||||
windowlen,
|
||||
adjustments,
|
||||
missing_value,
|
||||
[
|
||||
# This is a nasty expression...
|
||||
#
|
||||
@@ -267,9 +284,10 @@ class AdjustedArrayTestCase(TestCase):
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
expected):
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
@@ -282,9 +300,10 @@ class AdjustedArrayTestCase(TestCase):
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
expected):
|
||||
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
@@ -301,18 +320,43 @@ class AdjustedArrayTestCase(TestCase):
|
||||
data,
|
||||
lookback,
|
||||
adjustments,
|
||||
missing_value,
|
||||
expected):
|
||||
array = AdjustedArray(data, NOMASK, adjustments)
|
||||
array = AdjustedArray(data, NOMASK, adjustments, missing_value)
|
||||
for _ in range(2): # Iterate 2x ensure adjusted_arrays are re-usable.
|
||||
window_iter = array.traverse(lookback)
|
||||
for yielded, expected_yield in zip_longest(window_iter, expected):
|
||||
self.assertEqual(yielded.dtype, data.dtype)
|
||||
assert_array_equal(yielded, expected_yield)
|
||||
|
||||
@parameter_space(
    dtype=[float64_dtype, int64_dtype, datetime64ns_dtype],
    missing_value=[0, 10000],
    window_length=[2, 3],
)
def test_masking(self, dtype, missing_value, window_length):
    """
    Check that masked-out cells are replaced by ``missing_value``.

    A 5x3 baseline of ``arange(15)`` is cast to ``dtype`` and paired with a
    mask that is True where the integer value is odd.  Every window yielded
    by ``AdjustedArray.traverse`` must equal the matching window over the
    baseline with all mask-False cells set to ``missing_value``.
    """
    missing_value = value_with_dtype(dtype, missing_value)
    baseline_ints = arange(15).reshape(5, 3)
    baseline = baseline_ints.astype(dtype)
    mask = (baseline_ints % 2).astype(bool)
    # Keep the baseline where the mask is True; fill everything else.
    masked_baseline = where(mask, baseline, missing_value)

    array = AdjustedArray(
        baseline,
        mask,
        adjustments={},
        missing_value=missing_value,
    )

    gen_expected = moving_window(masked_baseline, window_length)
    gen_actual = array.traverse(window_length)
    # zip_longest (rather than zip) so that a traversal yielding too few or
    # too many windows is detected instead of being silently truncated.
    # This matches the other traversal tests in this class.
    for expected, actual in zip_longest(gen_expected, gen_actual):
        self.assertIsNotNone(
            expected,
            "traverse() yielded more windows than expected",
        )
        self.assertIsNotNone(
            actual,
            "traverse() yielded fewer windows than expected",
        )
        check_arrays(expected, actual)
|
||||
|
||||
def test_invalid_lookback(self):
|
||||
|
||||
data = arange(30, dtype=float).reshape(6, 5)
|
||||
adj_array = AdjustedArray(data, NOMASK, {})
|
||||
adj_array = AdjustedArray(data, NOMASK, {}, float('nan'))
|
||||
|
||||
with self.assertRaises(WindowLengthTooLong):
|
||||
adj_array.traverse(7)
|
||||
@@ -326,7 +370,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
def test_array_views_arent_writable(self):
|
||||
|
||||
data = arange(30, dtype=float).reshape(6, 5)
|
||||
adj_array = AdjustedArray(data, NOMASK, {})
|
||||
adj_array = AdjustedArray(data, NOMASK, {}, float('nan'))
|
||||
|
||||
for frame in adj_array.traverse(3):
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -338,7 +382,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
bad_mask = array([[0, 1, 1], [0, 0, 1]], dtype=bool)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
AdjustedArray(data, bad_mask, {})
|
||||
AdjustedArray(data, bad_mask, {}, missing_value=-1)
|
||||
|
||||
def test_inspect(self):
|
||||
data = arange(15, dtype=float).reshape(5, 3)
|
||||
@@ -346,6 +390,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
data,
|
||||
NOMASK,
|
||||
{4: [Float64Multiply(2, 3, 0, 0, 4.0)]},
|
||||
float('nan'),
|
||||
)
|
||||
|
||||
expected = dedent(
|
||||
|
||||
@@ -31,7 +31,11 @@ from zipline.pipeline.loaders.blaze.core import (
|
||||
NonPipelineField,
|
||||
no_deltas_rules,
|
||||
)
|
||||
from zipline.utils.numpy_utils import repeat_last_axis
|
||||
from zipline.utils.numpy_utils import (
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
repeat_last_axis,
|
||||
)
|
||||
from zipline.utils.test_utils import tmp_asset_finder, make_simple_equity_info
|
||||
|
||||
|
||||
@@ -73,7 +77,8 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
cls.sids = sids = ord('A'), ord('B'), ord('C')
|
||||
cls.df = df = pd.DataFrame({
|
||||
'sid': sids * 3,
|
||||
'value': (0, 1, 2, 1, 2, 3, 2, 3, 4),
|
||||
'value': (0., 1., 2., 1., 2., 3., 2., 3., 4.),
|
||||
'int_value': (0, 1, 2, 1, 2, 3, 2, 3, 4),
|
||||
'asof_date': dates,
|
||||
'timestamp': dates,
|
||||
})
|
||||
@@ -81,6 +86,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
var * {
|
||||
sid: ?int64,
|
||||
value: ?float64,
|
||||
int_value: ?int64,
|
||||
asof_date: datetime,
|
||||
timestamp: datetime
|
||||
}
|
||||
@@ -91,6 +97,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
cls.macro_dshape = var * Record(dshape_)
|
||||
|
||||
cls.garbage_loader = BlazeLoader()
|
||||
cls.missing_values = {'int_value': 0}
|
||||
|
||||
def test_tabular(self):
|
||||
name = 'expr'
|
||||
@@ -99,15 +106,20 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
self.assertEqual(ds.__name__, name)
|
||||
self.assertTrue(issubclass(ds, DataSet))
|
||||
self.assertEqual(
|
||||
{c.name: c.dtype for c in ds.columns},
|
||||
{'sid': np.int64, 'value': np.float64},
|
||||
)
|
||||
|
||||
for field in ('timestamp', 'asof_date'):
|
||||
self.assertIs(ds.value.dtype, float64_dtype)
|
||||
self.assertIs(ds.int_value.dtype, int64_dtype)
|
||||
|
||||
self.assertTrue(np.isnan(ds.value.missing_value))
|
||||
self.assertEqual(ds.int_value.missing_value, 0)
|
||||
|
||||
invalid_type_fields = ('asof_date',)
|
||||
|
||||
for field in invalid_type_fields:
|
||||
with self.assertRaises(AttributeError) as e:
|
||||
getattr(ds, field)
|
||||
self.assertIn("'%s'" % field, str(e.exception))
|
||||
@@ -119,6 +131,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
),
|
||||
ds,
|
||||
)
|
||||
@@ -130,10 +143,11 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr.value,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
self.assertEqual(value.name, 'value')
|
||||
self.assertIsInstance(value, BoundColumn)
|
||||
self.assertEqual(value.dtype, np.float64)
|
||||
self.assertIs(value.dtype, float64_dtype)
|
||||
|
||||
# test memoization
|
||||
self.assertIs(
|
||||
@@ -141,6 +155,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr.value,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
),
|
||||
value,
|
||||
)
|
||||
@@ -149,6 +164,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
).value,
|
||||
value,
|
||||
)
|
||||
@@ -159,6 +175,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=self.garbage_loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
),
|
||||
value.dataset,
|
||||
)
|
||||
@@ -195,7 +212,11 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
)),
|
||||
)
|
||||
loader = BlazeLoader()
|
||||
ds = from_blaze(expr.ds, loader=loader)
|
||||
ds = from_blaze(
|
||||
expr.ds,
|
||||
loader=loader,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
self.assertEqual(len(loader), 1)
|
||||
exprdata = loader[ds]
|
||||
self.assertTrue(exprdata.expr.isidentical(expr.ds))
|
||||
@@ -210,6 +231,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule=no_deltas_rules.warn,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
self.assertEqual(len(ws), 1)
|
||||
w = ws[0].message
|
||||
@@ -281,6 +303,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr_with_add,
|
||||
deltas=None,
|
||||
loader=self.garbage_loader,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
@@ -288,6 +311,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr.value + 1, # put an Add in the column
|
||||
deltas=None,
|
||||
loader=self.garbage_loader,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
|
||||
deltas = bz.Data(
|
||||
@@ -299,6 +323,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr_with_add,
|
||||
deltas=deltas,
|
||||
loader=self.garbage_loader,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
@@ -306,6 +331,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr.value + 1,
|
||||
deltas=deltas,
|
||||
loader=self.garbage_loader,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
|
||||
def _test_id(self, df, dshape, expected, finder, add):
|
||||
@@ -315,6 +341,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
p = Pipeline()
|
||||
for a in add:
|
||||
@@ -347,9 +374,11 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expr,
|
||||
loader=loader,
|
||||
no_deltas_rule=no_deltas_rules.ignore,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
p = Pipeline()
|
||||
p.add(ds.value.latest, 'value')
|
||||
p.add(ds.int_value.latest, 'int_value')
|
||||
dates = self.dates
|
||||
|
||||
with tmp_asset_finder() as finder:
|
||||
@@ -405,7 +434,9 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
expected.index.levels[0],
|
||||
finder.retrieve_all(expected.index.levels[1]),
|
||||
))
|
||||
self._test_id(self.df, self.dshape, expected, finder, ('value',))
|
||||
self._test_id(
|
||||
self.df, self.dshape, expected, finder, ('int_value', 'value',)
|
||||
)
|
||||
|
||||
def test_id_ffill_out_of_window(self):
|
||||
"""
|
||||
@@ -512,7 +543,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
var * Record(fields),
|
||||
expected,
|
||||
finder,
|
||||
('value', 'other'),
|
||||
('value', 'int_value', 'other'),
|
||||
)
|
||||
|
||||
def test_id_macro_dataset(self):
|
||||
@@ -782,6 +813,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
deltas,
|
||||
loader=loader,
|
||||
no_deltas_rule=no_deltas_rules.raise_,
|
||||
missing_values=self.missing_values,
|
||||
)
|
||||
p = Pipeline()
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Tests BoundColumn attributes and methods.
|
||||
"""
|
||||
from contextlib2 import ExitStack
|
||||
from unittest import TestCase
|
||||
|
||||
from pandas import date_range, DataFrame
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data.testing import TestingDataSet as TDS
|
||||
from zipline.utils.test_utils import chrange, temp_pipeline_engine
|
||||
|
||||
|
||||
class LatestTestCase(TestCase):
    """
    Checks ``BoundColumn.latest`` for every column of ``TestingDataSet``.

    A temporary pipeline engine is created once per class; its lifetime is
    managed by an ``ExitStack`` that is closed in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        cls._stack = ExitStack()
        cls.calendar = date_range('2014', '2015', freq='D', tz='UTC')
        cls.sids = list(range(5))
        engine_context = temp_pipeline_engine(
            cls.calendar,
            cls.sids,
            random_seed=100,
            symbols=chrange('A', 'E'),
        )
        cls.engine = cls._stack.enter_context(engine_context)
        cls.assets = cls.engine._finder.retrieve_all(cls.sids)

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()

    def expected_latest(self, column, slice_):
        """
        Build the frame ``column.latest`` is expected to produce.

        The values are read directly from the engine's loader for ``column``
        and restricted to the calendar rows selected by ``slice_``.
        """
        loader = self.engine.get_loader(column)
        all_values = loader.values(column.dtype, self.calendar, self.sids)
        return DataFrame(
            all_values[slice_],
            index=self.calendar[slice_],
            columns=self.assets,
        )

    def test_latest(self):
        columns = TDS.columns
        pipe = Pipeline(
            columns=dict((col.name, col.latest) for col in columns),
        )

        cal_slice = slice(20, 40)
        dates_to_test = self.calendar[cal_slice]
        first_date = dates_to_test[0]
        last_date = dates_to_test[-1]
        result = self.engine.run_pipeline(pipe, first_date, last_date)

        # Each pipeline output column is compared against the loader's raw
        # values for the same dates.
        for column in columns:
            actual_frame = result[column.name].unstack()
            expected_frame = self.expected_latest(column, cal_slice)
            assert_frame_equal(actual_frame, expected_frame)
|
||||
@@ -41,7 +41,7 @@ from zipline.data.us_equity_pricing import BcolzDailyBarReader
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
ConstantLoader,
|
||||
PrecomputedLoader,
|
||||
NullAdjustmentReader,
|
||||
SyntheticDailyBarWriter,
|
||||
)
|
||||
@@ -126,16 +126,16 @@ class ColumnArgs(tuple):
|
||||
return hash(frozenset(self))
|
||||
|
||||
|
||||
class RecordingConstantLoader(ConstantLoader):
|
||||
class RecordingPrecomputedLoader(PrecomputedLoader):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RecordingConstantLoader, self).__init__(*args, **kwargs)
|
||||
super(RecordingPrecomputedLoader, self).__init__(*args, **kwargs)
|
||||
|
||||
self.load_calls = []
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
self.load_calls.append(ColumnArgs(*columns))
|
||||
|
||||
return super(RecordingConstantLoader, self).load_adjusted_array(
|
||||
return super(RecordingPrecomputedLoader, self).load_adjusted_array(
|
||||
columns, dates, assets, mask,
|
||||
)
|
||||
|
||||
@@ -159,10 +159,10 @@ class ConstantInputTestCase(TestCase):
|
||||
}
|
||||
self.asset_ids = [1, 2, 3]
|
||||
self.dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
|
||||
self.loader = ConstantLoader(
|
||||
self.loader = PrecomputedLoader(
|
||||
constants=self.constants,
|
||||
dates=self.dates,
|
||||
assets=self.asset_ids,
|
||||
sids=self.asset_ids,
|
||||
)
|
||||
|
||||
self.asset_info = make_simple_equity_info(
|
||||
@@ -364,10 +364,10 @@ class ConstantInputTestCase(TestCase):
|
||||
dates_to_test = self.dates[-30:]
|
||||
|
||||
constants = {open_: 1, close: 2, volume: 3}
|
||||
loader = ConstantLoader(
|
||||
loader = PrecomputedLoader(
|
||||
constants=constants,
|
||||
dates=self.dates,
|
||||
assets=self.asset_ids,
|
||||
sids=self.asset_ids,
|
||||
)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
@@ -415,7 +415,7 @@ class ConstantInputTestCase(TestCase):
|
||||
def test_loader_given_multiple_columns(self):
|
||||
|
||||
class Loader1DataSet1(DataSet):
|
||||
col1 = Column(float32)
|
||||
col1 = Column(float)
|
||||
col2 = Column(float32)
|
||||
|
||||
class Loader1DataSet2(DataSet):
|
||||
@@ -430,14 +430,15 @@ class ConstantInputTestCase(TestCase):
|
||||
Loader1DataSet1.col2: 2,
|
||||
Loader1DataSet2.col1: 3,
|
||||
Loader1DataSet2.col2: 4}
|
||||
loader1 = RecordingConstantLoader(constants=constants1,
|
||||
dates=self.dates,
|
||||
assets=self.asset_ids)
|
||||
|
||||
loader1 = RecordingPrecomputedLoader(constants=constants1,
|
||||
dates=self.dates,
|
||||
sids=self.assets)
|
||||
constants2 = {Loader2DataSet.col1: 5,
|
||||
Loader2DataSet.col2: 6}
|
||||
loader2 = RecordingConstantLoader(constants=constants2,
|
||||
dates=self.dates,
|
||||
assets=self.asset_ids)
|
||||
loader2 = RecordingPrecomputedLoader(constants=constants2,
|
||||
dates=self.dates,
|
||||
sids=self.assets)
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column:
|
||||
|
||||
@@ -8,17 +8,23 @@ from unittest import TestCase
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.pipeline import Factor, TermGraph
|
||||
from zipline.pipeline import Factor, Filter, TermGraph
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
from zipline.pipeline.term import AssetExists, NotSpecified
|
||||
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
complex128_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
NoDefaultMissingValue,
|
||||
)
|
||||
|
||||
|
||||
@@ -328,9 +334,53 @@ class ObjectIdentityTestCase(TestCase):
|
||||
with self.assertRaises(DTypeNotSpecified):
|
||||
SomeFactorNoDType()
|
||||
|
||||
with self.assertRaises(InvalidDType):
|
||||
with self.assertRaises(NotDType):
|
||||
SomeFactor(dtype=1)
|
||||
|
||||
with self.assertRaises(NoDefaultMissingValue):
|
||||
SomeFactor(dtype=int64_dtype)
|
||||
|
||||
with self.assertRaises(UnsupportedDType):
|
||||
SomeFactor(dtype=complex128_dtype)
|
||||
|
||||
def test_latest_on_different_dtypes(self):
    """
    ``column.latest`` is a Filter for bool columns and a Factor for int64,
    float64 and datetime64[ns] columns, and it carries the column's
    ``missing_value``.
    """
    factor_dtypes = (int64_dtype, float64_dtype, datetime64ns_dtype)
    for column in TestingDataSet.columns:
        if column.dtype == bool_dtype:
            expected_type = Filter
        elif column.dtype in factor_dtypes:
            expected_type = Factor
        else:
            self.fail(
                "Unknown dtype %s for column %s" % (column.dtype, column)
            )
        self.assertIsInstance(column.latest, expected_type)
        # Identity rather than equality: the two should be the very same
        # object, and an identity check also behaves correctly for `NaN`.
        self.assertIs(column.missing_value, column.latest.missing_value)
|
||||
|
||||
def test_failure_timing_on_bad_dtypes(self):
    """
    Bad column dtypes are rejected when the DataSet class is created, not
    when the bare ``Column`` is constructed.
    """
    # An int64 Column without a missing_value can be constructed on its own.
    Column(dtype=int64_dtype)

    # The error surfaces once the column is placed on a DataSet.
    with self.assertRaises(NoDefaultMissingValue) as e:
        class BadDataSet(DataSet):
            bad_column = Column(dtype=int64_dtype)
            float_column = Column(dtype=float64_dtype)
            int_column = Column(dtype=int64_dtype, missing_value=3)

    error_text = str(e.exception.args[0])
    expected_prefix = "Failed to create Column with name 'bad_column'"
    self.assertTrue(error_text.startswith(expected_prefix))

    # Same timing for an unsupported dtype.
    Column(dtype=complex128_dtype)
    with self.assertRaises(UnsupportedDType):
        class BadDataSetComplex(DataSet):
            bad_column = Column(dtype=complex128_dtype)
            float_column = Column(dtype=float64_dtype)
            int_column = Column(dtype=int64_dtype, missing_value=3)
|
||||
|
||||
|
||||
class SubDataSetTestCase(TestCase):
|
||||
def test_subdataset(self):
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Tests for our testing utilities.
|
||||
"""
|
||||
from itertools import product
|
||||
from unittest import TestCase
|
||||
from zipline.utils.test_utils import parameter_space
|
||||
|
||||
|
||||
class TestParameterSpace(TestCase):
    """
    Verifies the invocations generated by the ``parameter_space`` decorator.

    The decorated tests only record the arguments they receive; the actual
    check happens in ``tearDownClass`` once every invocation has run.
    """

    x_args = [1, 2]
    y_args = [3, 4]

    @classmethod
    def setUpClass(cls):
        cls.xy_invocations = []
        cls.yx_invocations = []

    @classmethod
    def tearDownClass(cls):
        # This is the only actual test here.
        expected_xy = list(product(cls.x_args, cls.y_args))
        assert cls.xy_invocations == expected_xy
        expected_yx = list(product(cls.y_args, cls.x_args))
        assert cls.yx_invocations == expected_yx

    @parameter_space(x=x_args, y=y_args)
    def test_xy(self, x, y):
        self.xy_invocations.append((x, y))

    @parameter_space(x=x_args, y=y_args)
    def test_yx(self, y, x):
        # The recorded order must follow this function's parameter list
        # (y first, then x), not the keyword order given to the decorator.
        self.yx_invocations.append((y, x))

    def test_nothing(self):
        # Keeps at least one "real" test in the class so that
        # {setUp,tearDown}Class still run if, for example,
        # `parameter_space` returns None.
        pass
|
||||
+12
-1
@@ -396,7 +396,7 @@ class DTypeNotSpecified(ZiplineError):
|
||||
)
|
||||
|
||||
|
||||
class InvalidDType(ZiplineError):
|
||||
class NotDType(ZiplineError):
|
||||
"""
|
||||
Raised when a pipeline Term is constructed with a dtype that isn't a numpy
|
||||
dtype object.
|
||||
@@ -407,6 +407,17 @@ class InvalidDType(ZiplineError):
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedDType(ZiplineError):
    """
    Raised when a pipeline Term is constructed with a dtype that pipeline
    terms do not support.

    The message template expects ``termname`` and ``dtype`` fields.
    """
    msg = (
        "Failed to construct {termname}.\n" +
        "Pipeline terms of dtype {dtype} are not yet supported."
    )
|
||||
|
||||
|
||||
class BadPercentileBounds(ZiplineError):
|
||||
"""
|
||||
Raised by API functions accepting percentile bounds when the passed bounds
|
||||
|
||||
@@ -7,6 +7,8 @@ from numpy import (
|
||||
float64,
|
||||
int32,
|
||||
int64,
|
||||
int16,
|
||||
uint16,
|
||||
ndarray,
|
||||
uint32,
|
||||
uint8,
|
||||
@@ -17,27 +19,45 @@ from zipline.errors import (
|
||||
)
|
||||
from zipline.utils.numpy_utils import (
|
||||
datetime64ns_dtype,
|
||||
default_fillvalue_for_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
uint8_dtype,
|
||||
)
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
# These class names are all the same because of our bootleg templating system.
|
||||
from ._float64window import AdjustedArrayWindow as Float64Window
|
||||
from ._int64window import AdjustedArrayWindow as Int64Window
|
||||
from ._uint8window import AdjustedArrayWindow as UInt8Window
|
||||
|
||||
Infer = sentinel(
|
||||
'Infer',
|
||||
"Sentinel used to say 'infer missing_value from data type.'"
|
||||
)
|
||||
NOMASK = None
|
||||
SUPPORTED_NUMERIC_DTYPES = frozenset(
|
||||
map(dtype, [float32, float64, int32, int64, uint32])
|
||||
BOOL_DTYPES = frozenset(
|
||||
map(dtype, [bool_]),
|
||||
)
|
||||
FLOAT_DTYPES = frozenset(
|
||||
map(dtype, [float32, float64]),
|
||||
)
|
||||
INT_DTYPES = frozenset(
|
||||
# NOTE: uint64 not supported because it can't be safely cast to int64.
|
||||
map(dtype, [int16, uint16, int32, int64, uint32]),
|
||||
)
|
||||
DATETIME_DTYPES = frozenset(
|
||||
map(dtype, ['datetime64[ns]', 'datetime64[D]']),
|
||||
)
|
||||
REPRESENTABLE_DTYPES = BOOL_DTYPES.union(
|
||||
FLOAT_DTYPES,
|
||||
INT_DTYPES,
|
||||
DATETIME_DTYPES
|
||||
)
|
||||
|
||||
|
||||
def can_represent_dtype(dtype):
    """
    Return True if ``dtype`` is one of ``REPRESENTABLE_DTYPES``, i.e. an
    AdjustedArray can be built for a baseline of that dtype.
    """
    is_representable = dtype in REPRESENTABLE_DTYPES
    return is_representable
|
||||
|
||||
|
||||
CONCRETE_WINDOW_TYPES = {
|
||||
float64_dtype: Float64Window,
|
||||
int64_dtype: Int64Window,
|
||||
@@ -51,13 +71,10 @@ def _normalize_array(data):
|
||||
representation, returning the coerced array and a numpy dtype object to use
|
||||
as a view type when providing public view into the data.
|
||||
|
||||
Semantically numerical data (float*, int*, uint*) is coerced to float64 and
|
||||
viewed as float64. We coerce integral data to float so that we can use NaN
|
||||
as a missing value.
|
||||
|
||||
datetime[*] data is coerced to int64 with a viewtype of ``datetime64[ns]``.
|
||||
|
||||
``bool_`` data is coerced to uint8 with a viewtype of ``bool_``
|
||||
- float* data is coerced to float64 with viewtype float64.
|
||||
- int16, uint16, int32, int64, and uint32 are converted to int64 with viewtype int64.
|
||||
- datetime[*] data is coerced to int64 with a viewtype of datetime64[ns].
|
||||
- bool_ data is coerced to uint8 with a viewtype of bool_.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -70,8 +87,10 @@ def _normalize_array(data):
|
||||
data_dtype = data.dtype
|
||||
if data_dtype == bool_:
|
||||
return data.astype(uint8), dtype(bool_)
|
||||
elif data_dtype in SUPPORTED_NUMERIC_DTYPES:
|
||||
elif data_dtype in FLOAT_DTYPES:
|
||||
return data.astype(float64), dtype(float64)
|
||||
elif data_dtype in INT_DTYPES:
|
||||
return data.astype(int64), dtype(int64)
|
||||
elif data_dtype.name.startswith('datetime'):
|
||||
try:
|
||||
outarray = data.astype('datetime64[ns]').view('int64')
|
||||
@@ -105,18 +124,24 @@ class AdjustedArray(object):
|
||||
adjustments : dict[int -> list[Adjustment]]
|
||||
A dict mapping row indices to lists of adjustments to apply when we
|
||||
reach that row.
|
||||
fillvalue : object, optional
|
||||
missing_value : object
|
||||
A value to use to fill missing data in yielded windows.
|
||||
Default behavior is to infer a value based on the dtype of `data`.
|
||||
`NaN` is used for numeric data, and `NaT` is used for datetime data.
|
||||
"""
|
||||
__slots__ = ('_data', '_viewtype', 'adjustments', '__weakref__')
|
||||
Should be a value coercible to `data.dtype`.
|
||||
|
||||
def __init__(self, data, mask, adjustments, fillvalue=Infer):
|
||||
"""
|
||||
__slots__ = (
|
||||
'_data',
|
||||
'_viewtype',
|
||||
'adjustments',
|
||||
'missing_value',
|
||||
'__weakref__',
|
||||
)
|
||||
|
||||
def __init__(self, data, mask, adjustments, missing_value):
|
||||
self._data, self._viewtype = _normalize_array(data)
|
||||
|
||||
self.adjustments = adjustments
|
||||
if fillvalue is Infer:
|
||||
fillvalue = default_fillvalue_for_dtype(self.data.dtype)
|
||||
self.missing_value = missing_value
|
||||
|
||||
if mask is not NOMASK:
|
||||
if mask.dtype != bool_:
|
||||
@@ -126,7 +151,7 @@ class AdjustedArray(object):
|
||||
"Mask shape %s != data shape %s." %
|
||||
(mask.shape, data.shape),
|
||||
)
|
||||
self._data[~mask] = fillvalue
|
||||
self._data[~mask] = self.missing_value
|
||||
|
||||
@lazyval
|
||||
def data(self):
|
||||
|
||||
@@ -58,6 +58,11 @@ def masked_rankdata_2d(ndarray data,
|
||||
# the extra work that apply_along_axis does.
|
||||
result = apply_along_axis(rankdata, 1, data, method=method)
|
||||
|
||||
# On SciPy >= 0.17, rankdata returns integers for any method except
|
||||
# average.
|
||||
if result.dtype.name != 'float64':
|
||||
result = result.astype('float64')
|
||||
|
||||
# rankdata will sort missing values into last place, but we want our nans
|
||||
# to propagate, so explicitly re-apply.
|
||||
result[missing_locations] = nan
|
||||
|
||||
@@ -7,8 +7,16 @@ from six import (
|
||||
with_metaclass,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, AssetExists
|
||||
from zipline.pipeline.term import (
|
||||
Term,
|
||||
AssetExists,
|
||||
NotSpecified,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_dtype
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
NoDefaultMissingValue,
|
||||
)
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
@@ -18,28 +26,62 @@ class Column(object):
|
||||
"""
|
||||
|
||||
@preprocess(dtype=ensure_dtype)
|
||||
def __init__(self, dtype):
|
||||
def __init__(self, dtype, missing_value=NotSpecified):
|
||||
self.dtype = dtype
|
||||
self.missing_value = missing_value
|
||||
|
||||
def bind(self, name):
|
||||
"""
|
||||
Bind a `Column` object to its name.
|
||||
"""
|
||||
return _BoundColumnDescr(dtype=self.dtype, name=name)
|
||||
return _BoundColumnDescr(
|
||||
dtype=self.dtype,
|
||||
missing_value=self.missing_value,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
class _BoundColumnDescr(object):
|
||||
"""
|
||||
Intermediate class that sits on `DataSet` objects and returns memoized
|
||||
`BoundColumn` objects when requested.
|
||||
|
||||
This exists so that subclasses of DataSets don't share columns with their
|
||||
parent classes.
|
||||
"""
|
||||
def __init__(self, dtype, name):
|
||||
self.dtype = dtype
|
||||
def __init__(self, dtype, missing_value, name):
|
||||
# Validating and calculating default missing values here guarantees
|
||||
# that we fail quickly if the user passes an unsupported dtype or fails
|
||||
# to provide a missing value for a dtype that requires one
|
||||
# (e.g. int64), but still enables us to provide an error message that
|
||||
# points to the name of the failing column.
|
||||
try:
|
||||
self.dtype, self.missing_value = Term.validate_dtype(
|
||||
termname="Column(name={name!r})".format(name=name),
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
)
|
||||
except NoDefaultMissingValue:
|
||||
# Re-raise with a more specific message.
|
||||
raise NoDefaultMissingValue(
|
||||
"Failed to create Column with name {name!r} and"
|
||||
" dtype {dtype} because no missing_value was provided\n\n"
|
||||
"Columns with dtype {dtype} require a missing_value.\n"
|
||||
"Please pass missing_value to Column() or use a different"
|
||||
" dtype.".format(dtype=dtype, name=name)
|
||||
)
|
||||
self.name = name
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
"""
|
||||
Produce a concrete BoundColumn object when accessed.
|
||||
|
||||
We don't bind to datasets at class creation time so that subclasses of
|
||||
DataSets produce different BoundColumns.
|
||||
"""
|
||||
return BoundColumn(
|
||||
dtype=self.dtype,
|
||||
missing_value=self.missing_value,
|
||||
dataset=owner,
|
||||
name=self.name,
|
||||
)
|
||||
@@ -53,11 +95,12 @@ class BoundColumn(Term):
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
|
||||
def __new__(cls, dtype, dataset, name):
|
||||
def __new__(cls, dtype, missing_value, dataset, name):
|
||||
return super(BoundColumn, cls).__new__(
|
||||
cls,
|
||||
domain=dataset.domain,
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
dataset=dataset,
|
||||
name=name,
|
||||
)
|
||||
@@ -92,8 +135,15 @@ class BoundColumn(Term):
|
||||
|
||||
@property
|
||||
def latest(self):
|
||||
from zipline.pipeline.factors import Latest
|
||||
return Latest(inputs=(self,), dtype=self.dtype)
|
||||
if self.dtype == bool_dtype:
|
||||
from zipline.pipeline.filters import Latest
|
||||
else:
|
||||
from zipline.pipeline.factors import Latest
|
||||
return Latest(
|
||||
inputs=(self,),
|
||||
dtype=self.dtype,
|
||||
missing_value=self.missing_value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "{qualname}::{dtype}".format(
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Datasets for testing use.
|
||||
|
||||
Loaders for datasets in this file can be found in
|
||||
zipline.pipeline.data.testing.
|
||||
"""
|
||||
from .dataset import Column, DataSet
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
float64_dtype,
|
||||
datetime64ns_dtype,
|
||||
int64_dtype,
|
||||
)
|
||||
|
||||
|
||||
class TestingDataSet(DataSet):
|
||||
# Tell nose this isn't a test case.
|
||||
__test__ = False
|
||||
|
||||
bool_col = Column(dtype=bool_dtype, missing_value=False)
|
||||
bool_col_default_True = Column(dtype=bool_dtype, missing_value=True)
|
||||
float_col = Column(dtype=float64_dtype)
|
||||
datetime_col = Column(dtype=datetime64ns_dtype)
|
||||
int_col = Column(dtype=int64_dtype, missing_value=0)
|
||||
@@ -92,13 +92,13 @@ class SimplePipelineEngine(object):
|
||||
An AssetFinder instance. We depend on the AssetFinder to determine
|
||||
which assets are in the top-level universe at any point in time.
|
||||
"""
|
||||
__slots__ = [
|
||||
__slots__ = (
|
||||
'_get_loader',
|
||||
'_calendar',
|
||||
'_finder',
|
||||
'_root_mask_term',
|
||||
'__weakref__',
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(self, get_loader, calendar, asset_finder):
|
||||
self._get_loader = get_loader
|
||||
|
||||
@@ -38,6 +38,7 @@ from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
)
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
@@ -303,7 +304,7 @@ def function_application(func):
|
||||
return mathfunc
|
||||
|
||||
|
||||
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype])
|
||||
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype, int64_dtype])
|
||||
|
||||
|
||||
class Factor(CompositeTerm):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from .filter import Filter, NumExprFilter, PercentileFilter
|
||||
from .latest import Latest
|
||||
|
||||
__all__ = [
|
||||
'Filter',
|
||||
'Latest',
|
||||
'NumExprFilter',
|
||||
'PercentileFilter',
|
||||
]
|
||||
|
||||
@@ -252,6 +252,52 @@ class PercentileFilter(SingleInputMixin, Filter):
|
||||
|
||||
class CustomFilter(PositiveWindowLengthMixin, CustomTermMixin, Filter):
|
||||
"""
|
||||
Filter analog to ``CustomFactor``.
|
||||
Base class for user-defined Filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : iterable, optional
|
||||
An iterable of `BoundColumn` instances (e.g. USEquityPricing.close),
|
||||
describing the data to load and pass to `self.compute`. If this
|
||||
argument is not passed to the CustomFilter constructor, we look for a
|
||||
class-level attribute named `inputs`.
|
||||
window_length : int, optional
|
||||
Number of rows to pass for each input. If this argument is not passed
|
||||
to the CustomFilter constructor, we look for a class-level attribute
|
||||
named `window_length`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Users implementing their own Filters should subclass CustomFilter and
|
||||
implement a method named `compute` with the following signature:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def compute(self, today, assets, out, *inputs):
|
||||
...
|
||||
|
||||
On each simulation date, ``compute`` will be called with the current date,
|
||||
an array of sids, an output array, and an input array for each expression
|
||||
passed as inputs to the CustomFilter constructor.
|
||||
|
||||
The specific types of the values passed to `compute` are as follows::
|
||||
|
||||
today : np.datetime64[ns]
|
||||
Row label for the last row of all arrays passed as `inputs`.
|
||||
assets : np.array[int64, ndim=1]
|
||||
Column labels for `out` and `inputs`.
|
||||
out : np.array[bool, ndim=1]
|
||||
Output array of the same shape as `assets`. `compute` should write
|
||||
its desired return values into `out`.
|
||||
*inputs : tuple of np.array
|
||||
Raw data arrays corresponding to the values of `self.inputs`.
|
||||
|
||||
See the documentation for
|
||||
:class:`~zipline.pipeline.factors.factor.CustomFactor` for more details on
|
||||
implementing a custom ``compute`` method.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.pipeline.factors.factor.CustomFactor
|
||||
"""
|
||||
ctx = nullctx()
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Filter that produces the most recently-known value of a boolean-valued
|
||||
Column.
|
||||
"""
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
|
||||
from .filter import CustomFilter
|
||||
from ..mixins import SingleInputMixin
|
||||
|
||||
|
||||
class Latest(SingleInputMixin, CustomFilter):
|
||||
"""
|
||||
Filter producing the most recently-known value of `inputs[0]` on each day.
|
||||
"""
|
||||
window_length = 1
|
||||
|
||||
def compute(self, today, assets, out, data):
|
||||
out[:] = data[-1]
|
||||
|
||||
def _validate(self):
|
||||
if self.inputs[0].dtype != bool_dtype:
|
||||
raise TypeError(
|
||||
"{name} expected an input of dtype bool, "
|
||||
"but got {not_bool} instead.".format(
|
||||
name=type(self).__name__,
|
||||
not_bool=self.inputs[0].dtype,
|
||||
)
|
||||
)
|
||||
super(Latest, self)._validate()
|
||||
@@ -165,6 +165,7 @@ from zipline.pipeline.loaders.utils import (
|
||||
normalize_data_query_bounds,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.pipeline.term import NotSpecified
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
from zipline.utils.enum import enum
|
||||
@@ -275,15 +276,21 @@ _new_names = ('BlazeDataSet_%d' % n for n in count())
|
||||
|
||||
|
||||
@memoize
|
||||
def new_dataset(expr, deltas):
|
||||
"""Creates or returns a dataset from a pair of blaze expressions.
|
||||
def new_dataset(expr, deltas, missing_values):
|
||||
"""
|
||||
Creates or returns a dataset from a pair of blaze expressions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The blaze expression representing the first known values.
|
||||
The blaze expression representing the first known values.
|
||||
deltas : Expr
|
||||
The blaze expression representing the deltas to the data.
|
||||
The blaze expression representing the deltas to the data.
|
||||
missing_values : frozenset((name, value) pairs)
|
||||
Association pairs of column name and missing_value for that column.
|
||||
|
||||
This needs to be a frozenset rather than a dict or tuple of tuples
|
||||
because we want a collection that's unordered but still hashable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -295,9 +302,16 @@ def new_dataset(expr, deltas):
|
||||
This function is memoized. Repeated calls with the same inputs will return
|
||||
the same type.
|
||||
"""
|
||||
missing_values = dict(missing_values)
|
||||
columns = {}
|
||||
for name, type_ in expr.dshape.measure.fields:
|
||||
# Don't generate a column for sid or timestamp, since they're
|
||||
# implicitly the labels of the arrays that will be passed to pipeline
|
||||
# Terms.
|
||||
if name in (SID_FIELD_NAME, TS_FIELD_NAME):
|
||||
continue
|
||||
try:
|
||||
# TODO: This should support datetime and bool columns.
|
||||
if promote(type_, float64, promote_option=False) != float64:
|
||||
raise NotPipelineCompatible()
|
||||
if isinstance(type_, Option):
|
||||
@@ -307,7 +321,10 @@ def new_dataset(expr, deltas):
|
||||
except TypeError:
|
||||
col = NonNumpyField(name, type_)
|
||||
else:
|
||||
col = Column(type_.to_numpy_dtype())
|
||||
col = Column(
|
||||
type_.to_numpy_dtype(),
|
||||
missing_values.get(name, NotSpecified),
|
||||
)
|
||||
|
||||
columns[name] = col
|
||||
|
||||
@@ -473,6 +490,7 @@ def from_blaze(expr,
|
||||
loader=None,
|
||||
resources=None,
|
||||
odo_kwargs=None,
|
||||
missing_values=None,
|
||||
no_deltas_rule=no_deltas_rules.warn):
|
||||
"""Create a Pipeline API object from a blaze expression.
|
||||
|
||||
@@ -494,6 +512,9 @@ def from_blaze(expr,
|
||||
scope for ``bz.compute``.
|
||||
odo_kwargs : dict, optional
|
||||
The keyword arguments to pass to odo when evaluating the expressions.
|
||||
missing_values : dict[str -> any], optional
|
||||
A dict mapping column names to missing values for those columns.
|
||||
Missing values are required for integral columns.
|
||||
no_deltas_rule : no_deltas_rule
|
||||
What should happen if ``deltas='auto'`` but no deltas can be found.
|
||||
'warn' says to raise a warning but continue.
|
||||
@@ -583,7 +604,10 @@ def from_blaze(expr,
|
||||
_check_resources('deltas', deltas, resources)
|
||||
|
||||
# Create or retrieve the Pipeline API dataset.
|
||||
ds = new_dataset(dataset_expr, deltas)
|
||||
if missing_values is None:
|
||||
missing_values = {}
|
||||
ds = new_dataset(dataset_expr, deltas, frozenset(missing_values.items()))
|
||||
|
||||
# Register our new dataset with the loader.
|
||||
(loader if loader is not None else global_loader)[ds] = ExprData(
|
||||
bind_expression_to_resources(dataset_expr, resources),
|
||||
@@ -1018,7 +1042,8 @@ class BlazeLoader(dict):
|
||||
column_name,
|
||||
asset_idx,
|
||||
sparse_deltas,
|
||||
)
|
||||
),
|
||||
column.missing_value,
|
||||
)
|
||||
|
||||
global_loader = BlazeLoader.global_instance()
|
||||
|
||||
@@ -81,12 +81,16 @@ class USEquityPricingLoader(PipelineLoader):
|
||||
dates,
|
||||
assets,
|
||||
)
|
||||
adjusted_arrays = [
|
||||
AdjustedArray(raw_array, mask, col_adjustments)
|
||||
for raw_array, col_adjustments in zip(raw_arrays, adjustments)
|
||||
]
|
||||
|
||||
return dict(zip(columns, adjusted_arrays))
|
||||
out = {}
|
||||
for c, c_raw, c_adjs in zip(columns, raw_arrays, adjustments):
|
||||
out[c] = AdjustedArray(
|
||||
c_raw.astype(c.dtype),
|
||||
mask,
|
||||
c_adjs,
|
||||
c.missing_value,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _shift_dates(dates, start_date, end_date, shift):
|
||||
|
||||
@@ -60,7 +60,7 @@ class DataFrameLoader(PipelineLoader):
|
||||
|
||||
def __init__(self, column, baseline, adjustments=None):
|
||||
self.column = column
|
||||
self.baseline = baseline.values
|
||||
self.baseline = baseline.values.astype(self.column.dtype)
|
||||
self.dates = baseline.index
|
||||
self.assets = baseline.columns
|
||||
|
||||
@@ -171,5 +171,6 @@ class DataFrameLoader(PipelineLoader):
|
||||
# Mask out requested columns/rows that didnt match.
|
||||
mask=(good_assets & good_dates[:, None]) & mask,
|
||||
adjustments=self.format_adjustments(dates, assets),
|
||||
missing_value=column.missing_value,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@ from bcolz import ctable
|
||||
from numpy import (
|
||||
arange,
|
||||
array,
|
||||
eye,
|
||||
float64,
|
||||
full,
|
||||
iinfo,
|
||||
uint32,
|
||||
)
|
||||
from numpy.random import RandomState
|
||||
from pandas import DataFrame, Timestamp
|
||||
from six import iteritems
|
||||
from sqlite3 import connect as sqlite3_connect
|
||||
@@ -23,6 +25,12 @@ from zipline.data.us_equity_pricing import (
|
||||
SQLiteAdjustmentWriter,
|
||||
US_EQUITY_PRICING_BCOLZ_COLUMNS,
|
||||
)
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
datetime64ns_dtype,
|
||||
float64_dtype,
|
||||
int64_dtype,
|
||||
)
|
||||
|
||||
|
||||
UINT_32_MAX = iinfo(uint32).max
|
||||
@@ -32,31 +40,34 @@ def nanos_to_seconds(nanos):
|
||||
return nanos / (1000 * 1000 * 1000)
|
||||
|
||||
|
||||
class ConstantLoader(PipelineLoader):
|
||||
class PrecomputedLoader(PipelineLoader):
|
||||
"""
|
||||
Synthetic PipelineLoader that returns a constant value for each column.
|
||||
Synthetic PipelineLoader that uses a pre-computed array for each column.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
constants : dict
|
||||
Map from column to value(s) to use for that column.
|
||||
values : dict
|
||||
Map from column to values to use for that column.
|
||||
Values can be anything that can be passed as the first positional
|
||||
argument to a DataFrame of the same shape as `mask`.
|
||||
mask : pandas.DataFrame
|
||||
Mask indicating when assets existed.
|
||||
Indices of this frame are used to align input queries.
|
||||
argument to a DataFrame whose indices are ``dates`` and ``sids``
|
||||
dates : iterable[datetime-like]
|
||||
Row labels for input data. Can be anything that pd.DataFrame will
|
||||
coerce to a DatetimeIndex.
|
||||
sids : iterable[int-like]
|
||||
Column labels for input data. Can be anything that pd.DataFrame will
|
||||
coerce to an Int64Index.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Adjustments are unsupported with ConstantLoader.
|
||||
Adjustments are unsupported by this loader.
|
||||
"""
|
||||
def __init__(self, constants, dates, assets):
|
||||
def __init__(self, constants, dates, sids):
|
||||
loaders = {}
|
||||
for column, const in iteritems(constants):
|
||||
frame = DataFrame(
|
||||
const,
|
||||
index=dates,
|
||||
columns=assets,
|
||||
columns=sids,
|
||||
dtype=column.dtype,
|
||||
)
|
||||
loaders[column] = DataFrameLoader(
|
||||
@@ -83,6 +94,106 @@ class ConstantLoader(PipelineLoader):
|
||||
return out
|
||||
|
||||
|
||||
class EyeLoader(PrecomputedLoader):
|
||||
"""
|
||||
A PrecomputedLoader that emits arrays containing 1s on the diagonal and 0s
|
||||
elsewhere.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns : list[BoundColumn]
|
||||
Columns that this loader should know about.
|
||||
dates : iterable[datetime-like]
|
||||
Same as PrecomputedLoader.
|
||||
sids : iterable[int-like]
|
||||
Same as PrecomputedLoader
|
||||
"""
|
||||
def __init__(self, columns, dates, sids):
|
||||
shape = (len(dates), len(sids))
|
||||
super(EyeLoader, self).__init__(
|
||||
{column: eye(*shape, dtype=column.dtype) for column in columns},
|
||||
dates,
|
||||
sids,
|
||||
)
|
||||
|
||||
|
||||
class SeededRandomLoader(PrecomputedLoader):
|
||||
"""
|
||||
A PrecomputedLoader that emits arrays randomly-generated with a given seed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seed : int
|
||||
Seed for numpy.random.RandomState.
|
||||
columns : list[BoundColumn]
|
||||
Columns that this loader should know about.
|
||||
dates : iterable[datetime-like]
|
||||
Same as PrecomputedLoader.
|
||||
sids : iterable[int-like]
|
||||
Same as PrecomputedLoader
|
||||
"""
|
||||
|
||||
def __init__(self, seed, columns, dates, sids):
|
||||
self._seed = seed
|
||||
super(SeededRandomLoader, self).__init__(
|
||||
{c: self.values(c.dtype, dates, sids) for c in columns},
|
||||
dates,
|
||||
sids,
|
||||
)
|
||||
|
||||
def values(self, dtype, dates, sids):
|
||||
"""
|
||||
Make a random array of shape (len(dates), len(sids)) with ``dtype``.
|
||||
"""
|
||||
shape = (len(dates), len(sids))
|
||||
return {
|
||||
datetime64ns_dtype: self._datetime_values,
|
||||
float64_dtype: self._float_values,
|
||||
int64_dtype: self._int_values,
|
||||
bool_dtype: self._bool_values,
|
||||
}[dtype](shape)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""
|
||||
Make a new RandomState from our seed.
|
||||
|
||||
This ensures that every call to _*_values produces the same output
|
||||
every time for a given SeededRandomLoader instance.
|
||||
"""
|
||||
return RandomState(self._seed)
|
||||
|
||||
def _float_values(self, shape):
|
||||
"""
|
||||
Return uniformly-distributed floats between 0.0 and 100.0.
|
||||
"""
|
||||
return self.state.uniform(low=0.0, high=100.0, size=shape)
|
||||
|
||||
def _int_values(self, shape):
|
||||
"""
|
||||
Return uniformly-distributed integers between 0 and 100.
|
||||
"""
|
||||
return self.state.random_integers(low=0, high=100, size=shape)
|
||||
|
||||
def _datetime_values(self, shape):
|
||||
"""
|
||||
Return uniformly-distributed dates in 2014.
|
||||
"""
|
||||
start = Timestamp('2014', tz='UTC').asm8
|
||||
offsets = self.state.random_integers(
|
||||
low=0,
|
||||
high=364,
|
||||
size=shape,
|
||||
).astype('timedelta64[D]')
|
||||
return start + offsets
|
||||
|
||||
def _bool_values(self, shape):
|
||||
"""
|
||||
Return uniformly-distributed True/False values.
|
||||
"""
|
||||
return self.state.randn(*shape) < 0
|
||||
|
||||
|
||||
class SyntheticDailyBarWriter(BcolzDailyBarWriter):
|
||||
"""
|
||||
Bcolz writer that creates synthetic data based on asset lifetime metadata.
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Loaders for zipline.pipeline.data.testing datasets.
|
||||
"""
|
||||
from .synthetic import EyeLoader, SeededRandomLoader
|
||||
from ..data.testing import TestingDataSet
|
||||
|
||||
|
||||
def make_eye_loader(dates, sids):
|
||||
"""
|
||||
Make a PipelineLoader that emits np.eye arrays for the columns in
|
||||
``TestingDataSet``.
|
||||
"""
|
||||
return EyeLoader(TestingDataSet.columns, dates, sids)
|
||||
|
||||
|
||||
def make_seeded_random_loader(seed, dates, sids):
|
||||
"""
|
||||
Make a PipelineLoader that emits random arrays seeded with `seed` for the
|
||||
columns in ``TestingDataSet``.
|
||||
"""
|
||||
return SeededRandomLoader(seed, TestingDataSet.columns, dates, sids)
|
||||
@@ -47,6 +47,7 @@ class CustomTermMixin(object):
|
||||
inputs=NotSpecified,
|
||||
window_length=NotSpecified,
|
||||
dtype=NotSpecified,
|
||||
missing_value=NotSpecified,
|
||||
**kwargs):
|
||||
|
||||
unexpected_keys = set(kwargs) - set(cls.params)
|
||||
@@ -64,6 +65,7 @@ class CustomTermMixin(object):
|
||||
inputs=inputs,
|
||||
window_length=window_length,
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
+47
-30
@@ -6,16 +6,20 @@ from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import dtype as dtype_class
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
InvalidDType,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.lib.adjusted_array import can_represent_dtype
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype, default_fillvalue_for_dtype
|
||||
from zipline.utils.numpy_utils import (
|
||||
bool_dtype,
|
||||
default_missing_value_for_dtype,
|
||||
)
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
|
||||
@@ -32,6 +36,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
# These are NotSpecified because a subclass is required to provide them.
|
||||
dtype = NotSpecified
|
||||
domain = NotSpecified
|
||||
missing_value = NotSpecified
|
||||
|
||||
# Subclasses aren't required to provide `params`. The default behavior is
|
||||
# no params.
|
||||
@@ -42,6 +47,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
def __new__(cls,
|
||||
domain=domain,
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
# params is explicitly not allowed to be passed to an instance.
|
||||
*args,
|
||||
**kwargs):
|
||||
@@ -55,18 +61,26 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
Caching previously-constructed Terms is **sane** because terms and
|
||||
their inputs are both conceptually immutable.
|
||||
"""
|
||||
# Class-level attributes can be used to provide defaults for Term
|
||||
# subclasses.
|
||||
|
||||
# Subclasses can override these class-level attributes to provide
|
||||
# default values.
|
||||
if domain is NotSpecified:
|
||||
domain = cls.domain
|
||||
if dtype is NotSpecified:
|
||||
dtype = cls.dtype
|
||||
if missing_value is NotSpecified:
|
||||
missing_value = cls.missing_value
|
||||
|
||||
dtype = cls._validate_dtype(dtype)
|
||||
dtype, missing_value = cls.validate_dtype(
|
||||
cls.__name__,
|
||||
dtype,
|
||||
missing_value,
|
||||
)
|
||||
params = cls._pop_params(kwargs)
|
||||
|
||||
identity = cls.static_identity(
|
||||
domain=domain,
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
params=params,
|
||||
*args, **kwargs
|
||||
)
|
||||
@@ -78,6 +92,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
super(Term, cls).__new__(cls)._init(
|
||||
domain=domain,
|
||||
dtype=dtype,
|
||||
missing_value=missing_value,
|
||||
params=params,
|
||||
*args, **kwargs
|
||||
)
|
||||
@@ -131,40 +146,45 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
)
|
||||
return tuple(zip(cls.params, param_values))
|
||||
|
||||
@classmethod
|
||||
def _validate_dtype(cls, passed_dtype):
|
||||
@staticmethod
|
||||
def validate_dtype(termname, dtype, missing_value):
|
||||
"""
|
||||
Validate a `dtype` passed to Term.__new__.
|
||||
Validate a `dtype` and `missing_value` passed to Term.__new__.
|
||||
|
||||
If passed_dtype is NotSpecified, then we try to fall back to a
|
||||
class-level attribute. If a value is found at that point, we pass it
|
||||
to np.dtype so that users can pass `float` or `bool` and have them
|
||||
coerce to the appropriate numpy types.
|
||||
Ensures that we know how to represent ``dtype``, and that missing_value
|
||||
is specified for types without default missing values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
validated : np.dtype
|
||||
The dtype to use for the new term.
|
||||
validated_dtype, validated_missing_value : np.dtype, any
|
||||
The dtype and missing_value to use for the new term.
|
||||
|
||||
Raises
|
||||
------
|
||||
DTypeNotSpecified
|
||||
When no dtype was passed to the instance, and the class doesn't
|
||||
provide a default.
|
||||
InvalidDType
|
||||
NotDType
|
||||
When either the class or the instance provides a value not
|
||||
coercible to a numpy dtype.
|
||||
NoDefaultMissingValue
|
||||
When dtype requires an explicit missing_value, but
|
||||
``missing_value`` is NotSpecified.
|
||||
"""
|
||||
dtype = passed_dtype
|
||||
if dtype is NotSpecified:
|
||||
dtype = cls.dtype
|
||||
if dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=cls.__name__)
|
||||
raise DTypeNotSpecified(termname=termname)
|
||||
try:
|
||||
dtype = dtype_class(dtype)
|
||||
except TypeError:
|
||||
raise InvalidDType(dtype=dtype, termname=cls.__name__)
|
||||
return dtype
|
||||
raise NotDType(dtype=dtype, termname=termname)
|
||||
|
||||
if not can_represent_dtype(dtype):
|
||||
raise UnsupportedDType(dtype=dtype, termname=termname)
|
||||
|
||||
if missing_value is NotSpecified:
|
||||
missing_value = default_missing_value_for_dtype(dtype)
|
||||
|
||||
return dtype, missing_value
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -183,7 +203,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, domain, dtype, params):
|
||||
def static_identity(cls, domain, dtype, missing_value, params):
|
||||
"""
|
||||
Return the identity of the Term that would be constructed from the
|
||||
given arguments.
|
||||
@@ -195,9 +215,9 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
This is a classmethod so that it can be called from Term.__new__ to
|
||||
determine whether to produce a new instance.
|
||||
"""
|
||||
return (cls, domain, dtype, params)
|
||||
return (cls, domain, dtype, missing_value, params)
|
||||
|
||||
def _init(self, domain, dtype, params):
|
||||
def _init(self, domain, dtype, missing_value, params):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -210,6 +230,7 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
self.domain = domain
|
||||
self.dtype = dtype
|
||||
self.missing_value = missing_value
|
||||
|
||||
for name, value in params:
|
||||
if hasattr(self, name):
|
||||
@@ -268,10 +289,6 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
return not any(dep for dep in self.dependencies
|
||||
if dep is not AssetExists())
|
||||
|
||||
@lazyval
|
||||
def missing_value(self):
|
||||
return default_fillvalue_for_dtype(self.dtype)
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
|
||||
@@ -15,8 +15,14 @@ from toolz import flip
|
||||
|
||||
uint8_dtype = dtype('uint8')
|
||||
bool_dtype = dtype('bool')
|
||||
|
||||
int64_dtype = dtype('int64')
|
||||
|
||||
float32_dtype = dtype('float32')
|
||||
float64_dtype = dtype('float64')
|
||||
|
||||
complex128_dtype = dtype('complex128')
|
||||
|
||||
datetime64D_dtype = dtype('datetime64[D]')
|
||||
datetime64ns_dtype = dtype('datetime64[ns]')
|
||||
|
||||
@@ -33,16 +39,27 @@ NaTD = NaT_for_dtype(datetime64D_dtype)
|
||||
|
||||
|
||||
_FILLVALUE_DEFAULTS = {
|
||||
bool_dtype: False,
|
||||
float32_dtype: nan,
|
||||
float64_dtype: nan,
|
||||
datetime64ns_dtype: NaTns,
|
||||
}
|
||||
|
||||
|
||||
def default_fillvalue_for_dtype(dtype):
|
||||
class NoDefaultMissingValue(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def default_missing_value_for_dtype(dtype):
|
||||
"""
|
||||
Get the default fill value for `dtype`.
|
||||
"""
|
||||
return _FILLVALUE_DEFAULTS[dtype]
|
||||
try:
|
||||
return _FILLVALUE_DEFAULTS[dtype]
|
||||
except KeyError:
|
||||
raise NoDefaultMissingValue(
|
||||
"No default value registered for dtype %s." % dtype
|
||||
)
|
||||
|
||||
|
||||
def repeat_first_axis(array, count):
|
||||
|
||||
+130
-7
@@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from inspect import getargspec
|
||||
from itertools import (
|
||||
combinations,
|
||||
count,
|
||||
@@ -17,7 +18,7 @@ from numpy.testing import assert_allclose, assert_array_equal
|
||||
import pandas as pd
|
||||
from pandas.tseries.offsets import MonthBegin
|
||||
from six import iteritems, itervalues
|
||||
from six.moves import filter
|
||||
from six.moves import filter, map
|
||||
from sqlalchemy import create_engine
|
||||
from toolz import concat
|
||||
|
||||
@@ -25,6 +26,8 @@ from zipline.assets import AssetFinder
|
||||
from zipline.assets.asset_writer import AssetDBWriterFromDataFrame
|
||||
from zipline.assets.futures import CME_CODE_TO_MONTH
|
||||
from zipline.finance.order import ORDER_STATUS
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.tradingcalendar import trading_days
|
||||
|
||||
@@ -238,6 +241,31 @@ def all_subindices(index):
|
||||
)
|
||||
|
||||
|
||||
def chrange(start, stop):
|
||||
"""
|
||||
Construct an iterable of length-1 strings beginning with `start` and ending
|
||||
with `stop`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start : str
|
||||
The first character.
|
||||
stop : str
|
||||
The last character.
|
||||
|
||||
Returns
|
||||
-------
|
||||
chars: iterable[str]
|
||||
Iterable of strings beginning with start and ending with stop.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> chrange('A', 'C')
|
||||
['A', 'B', 'C']
|
||||
"""
|
||||
return list(map(chr, range(ord(start), ord(stop) + 1)))
|
||||
|
||||
|
||||
def make_rotating_equity_info(num_assets,
|
||||
first_start,
|
||||
frequency,
|
||||
@@ -539,13 +567,18 @@ class SubTestFailures(AssertionError):
|
||||
|
||||
|
||||
def subtest(iterator, *_names):
|
||||
"""Construct a subtest in a unittest.
|
||||
"""
|
||||
Construct a subtest in a unittest.
|
||||
|
||||
This works by decorating a function as a subtest. The test will be run
|
||||
by iterating over the ``iterator`` and *unpacking the values into the
|
||||
function. If any of the runs fail, the result will be put into a set and
|
||||
the rest of the tests will be run. Finally, if any failed, all of the
|
||||
results will be dumped as one failure.
|
||||
Consider using ``zipline.utils.test_utils.parameter_space`` when subtests
|
||||
are constructed over a single input or over the cross-product of multiple
|
||||
inputs.
|
||||
|
||||
``subtest`` works by decorating a function as a subtest. The decorated
|
||||
function will be run by iterating over the ``iterator`` and *unpacking the
|
||||
values into the function. If any of the runs fail, the result will be put
|
||||
into a set and the rest of the tests will be run. Finally, if any failed,
|
||||
all of the results will be dumped as one failure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -587,6 +620,10 @@ def subtest(iterator, *_names):
|
||||
|
||||
We cannot use ``unittest2.TestCase.subTest`` because nose, pytest, and
|
||||
nose2 do not support ``addSubTest``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.utils.test_utils.parameter_space
|
||||
"""
|
||||
def dec(f):
|
||||
@wraps(f)
|
||||
@@ -664,3 +701,89 @@ def gen_calendars(start, stop, critical_dates):
|
||||
|
||||
# Also test with the trading calendar.
|
||||
yield (trading_days[trading_days.slice_indexer(start, stop)],)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_pipeline_engine(calendar, sids, random_seed, symbols=None):
|
||||
"""
|
||||
A context manager that yields a SimplePipelineEngine holding a reference to
|
||||
an AssetFinder generated via tmp_asset_finder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendar : pd.DatetimeIndex
|
||||
Calendar to pass to the constructed PipelineEngine.
|
||||
sids : iterable[int]
|
||||
Sids to use for the temp asset finder.
|
||||
random_seed : int
|
||||
Integer used to seed instances of SeededRandomLoader.
|
||||
symbols : iterable[str], optional
|
||||
Symbols for constructed assets. Forwarded to make_simple_equity_info.
|
||||
"""
|
||||
equity_info = make_simple_equity_info(
|
||||
sids=sids,
|
||||
start_date=calendar[0],
|
||||
end_date=calendar[-1],
|
||||
symbols=symbols,
|
||||
)
|
||||
|
||||
loader = make_seeded_random_loader(random_seed, calendar, sids)
|
||||
get_loader = lambda column: loader
|
||||
|
||||
with tmp_asset_finder(equities=equity_info) as finder:
|
||||
yield SimplePipelineEngine(get_loader, calendar, finder)
|
||||
|
||||
|
||||
def parameter_space(**params):
|
||||
"""
|
||||
Wrapper around subtest that allows passing keywords mapping names to
|
||||
iterables of values.
|
||||
|
||||
The decorated test function will be called with the cross-product of all
|
||||
possible inputs
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> from unittest import TestCase
|
||||
>>> class SomeTestCase(TestCase):
|
||||
... @parameter_space(x=[1, 2], y=[2, 3])
|
||||
... def test_some_func(self, x, y):
|
||||
... # Will be called with every possible combination of x and y.
|
||||
... self.assertEqual(somefunc(x, y), expected_result(x, y))
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.utils.test_utils.subtest
|
||||
"""
|
||||
def decorator(f):
|
||||
|
||||
argspec = getargspec(f)
|
||||
if argspec.varargs:
|
||||
raise AssertionError("parameter_space() doesn't support *args")
|
||||
if argspec.keywords:
|
||||
raise AssertionError("parameter_space() doesn't support **kwargs")
|
||||
if argspec.defaults:
|
||||
raise AssertionError("parameter_space() doesn't support defaults.")
|
||||
|
||||
# Skip over implicit self.
|
||||
argnames = argspec.args
|
||||
if argnames[0] == 'self':
|
||||
argnames = argnames[1:]
|
||||
|
||||
extra = set(params) - set(argnames)
|
||||
if extra:
|
||||
raise AssertionError(
|
||||
"Keywords %s supplied to parameter_space() are "
|
||||
"not in function signature." % extra
|
||||
)
|
||||
|
||||
unspecified = set(argnames) - set(params)
|
||||
if unspecified:
|
||||
raise AssertionError(
|
||||
"Function arguments %s were not "
|
||||
"supplied to parameter_space()." % unspecified
|
||||
)
|
||||
|
||||
param_sets = product(*(params[name] for name in argnames))
|
||||
return subtest(param_sets, *argnames)(f)
|
||||
return decorator
|
||||
|
||||
Reference in New Issue
Block a user