diff --git a/tests/pipeline/test_adjusted_array.py b/tests/pipeline/test_adjusted_array.py index 57afb44b..f7c94a70 100644 --- a/tests/pipeline/test_adjusted_array.py +++ b/tests/pipeline/test_adjusted_array.py @@ -25,6 +25,7 @@ from zipline.lib.adjustment import ( Float64Multiply, Float64Overwrite, Float641DArrayOverwrite, + Int64Overwrite, ObjectOverwrite, ) from zipline.lib.adjusted_array import AdjustedArray, NOMASK @@ -235,6 +236,7 @@ def _gen_overwrite_adjustment_cases(dtype): adjustment_type = { float64_dtype: Float64Overwrite, datetime64ns_dtype: Datetime64Overwrite, + int64_dtype: Int64Overwrite, bytes_dtype: ObjectOverwrite, unicode_dtype: ObjectOverwrite, object_dtype: ObjectOverwrite, @@ -585,6 +587,7 @@ class AdjustedArrayTestCase(TestCase): @parameterized.expand( chain( + _gen_overwrite_adjustment_cases(int64_dtype), _gen_overwrite_adjustment_cases(float64_dtype), _gen_overwrite_adjustment_cases(datetime64ns_dtype), _gen_overwrite_1d_array_adjustment_case(float64_dtype), diff --git a/tests/pipeline/test_adjustment.py b/tests/pipeline/test_adjustment.py index d5c99c0e..9e31041e 100644 --- a/tests/pipeline/test_adjustment.py +++ b/tests/pipeline/test_adjustment.py @@ -35,6 +35,21 @@ class AdjustmentTestCase(TestCase): ) self.assertEqual(result, expected) + def test_make_int_adjustment(self): + result = adj.make_adjustment_from_indices( + 1, 2, 3, 4, + adjustment_kind=adj.OVERWRITE, + value=1, + ) + expected = adj.Int64Overwrite( + first_row=1, + last_row=2, + first_col=3, + last_col=4, + value=1, + ) + self.assertEqual(result, expected) + def test_make_datetime_adjustment(self): overwrite_dt = make_datetime64ns(0) result = adj.make_adjustment_from_indices( @@ -51,6 +66,23 @@ class AdjustmentTestCase(TestCase): ) self.assertEqual(result, expected) + @parameterized.expand([("some text",), ("some text".encode(),), (None,)]) + def test_make_object_adjustment(self, value): + result = adj.make_adjustment_from_indices( + 1, 2, 3, 4, + adjustment_kind=adj.OVERWRITE, + value=value, + ) + + expected = adj.ObjectOverwrite( + first_row=1, + last_row=2, + first_col=3, + last_col=4, + value=value, + ) + self.assertEqual(result, expected) + def test_unsupported_type(self): class SomeClass(object): pass diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index 68843a33..f8e1f750 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -1534,7 +1534,7 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): pd.Timestamp('2014-01-04') ]) baseline = pd.DataFrame({ - 'value': (0, 1), + 'value': (0., 1.), 'asof_date': base_dates, 'timestamp': base_dates, }) @@ -1545,7 +1545,6 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): value=deltas.value + 10, timestamp=deltas.timestamp + timedelta(days=1), ) - nassets = len(simple_asset_info) expected_views = keymap(pd.Timestamp, { '2014-01-03': np.array([[10.0], diff --git a/zipline/lib/adjustment.pyx b/zipline/lib/adjustment.pyx index aea8df5c..13c01280 100644 --- a/zipline/lib/adjustment.pyx +++ b/zipline/lib/adjustment.pyx @@ -4,6 +4,9 @@ from cpython cimport Py_EQ from pandas import isnull, Timestamp from numpy cimport float64_t, uint8_t, int64_t from numpy import asarray, datetime64, float64, int64 + +from zipline.utils.compat import unicode + # Purely for readability. There aren't C-level declarations for these types. ctypedef object Int64Index_t ctypedef object DatetimeIndex_t @@ -29,6 +32,12 @@ cdef dict _float_adjustment_types = { cdef dict _datetime_adjustment_types = { OVERWRITE: Datetime64Overwrite, } +cdef dict _object_adjustment_types = { + OVERWRITE: ObjectOverwrite, +} +cdef dict _int_adjustment_types = { + OVERWRITE: Int64Overwrite, +} cdef _is_float(object value): return isinstance(value, (float, float64)) @@ -36,6 +45,11 @@ cdef _is_float(object value): def _is_datetime(object value): return isinstance(value, (datetime64, Timestamp)) +def _is_int(object value): + return isinstance(value, (int, int64)) + +def _is_obj(object value): + return isinstance(value, (bytes, unicode, type(None))) cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value): """ @@ -67,11 +81,16 @@ cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value): return _float_adjustment_types[adjustment_kind] elif _is_datetime(value): return _datetime_adjustment_types[adjustment_kind] + elif _is_int(value): + return _int_adjustment_types[adjustment_kind] + elif _is_obj(value): + return _object_adjustment_types[adjustment_kind] else: raise TypeError( "Don't know how to make overwrite " "adjustments for values of type %r." % type(value), ) + else: raise ValueError("Unknown adjustment type %d." % adjustment_kind) @@ -585,6 +604,45 @@ cdef class _Int64Adjustment(Adjustment): ) +cdef class Int64Overwrite(_Int64Adjustment): + """ + An adjustment that overwrites with an int. + + Example + ------- + + >>> import numpy as np + >>> arr = np.arange(9, dtype=int).reshape(3, 3) + >>> arr + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8]]) + + >>> adj = Int64Overwrite( + ... first_row=1, + ... last_row=2, + ... first_col=1, + ... last_col=2, + ... value=0, + ... ) + >>> adj.mutate(arr) + >>> arr + array([[ 0, 1, 2], + [ 3, 0, 0], + [ 6, 0, 0]]) + """ + + cpdef mutate(self, int64_t[:, :] data): + cdef Py_ssize_t row, col + cdef int64_t value = self.value + + # last_col + 1 because last_col should also be affected. + for col in range(self.first_col, self.last_col + 1): + # last_row + 1 because last_row should also be affected. + for row in range(self.first_row, self.last_row + 1): + data[row, col] = value + + cdef datetime_to_int(object datetimelike): """ Coerce a datetime-like object to the int format used by AdjustedArrays of diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 00755078..e725e5f8 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -182,7 +182,7 @@ from zipline.pipeline.loaders.utils import ( ) from zipline.pipeline.sentinels import NotSpecified from zipline.lib.adjusted_array import AdjustedArray, can_represent_dtype -from zipline.lib.adjustment import Float64Overwrite +from zipline.lib.adjustment import make_adjustment_from_indices, OVERWRITE from zipline.utils.input_validation import ( expect_element, ensure_timezone, @@ -760,7 +760,7 @@ def overwrite_novel_deltas(baseline, deltas, dates): def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value): - """Construct a `Float64Overwrite` with the correct + """Construct an Overwrite with the correct start and end date based on the asof date of the delta, the dense_dates, and the dense_dates. @@ -775,7 +775,7 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value): asset_idx : tuple of int The index of the asset in the block. If this is a tuple, then this is treated as the first and last index to use. - value : np.float64 + value : any The value to overwrite with. Returns @@ -815,7 +815,9 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value): return first, last = asset_idx - yield Float64Overwrite(first_row, last_row, first, last, value) + yield make_adjustment_from_indices( + first_row, last_row, first, last, OVERWRITE, value + ) def adjustments_from_deltas_no_sids(dense_dates,