diff --git a/tests/pipeline/test_adjusted_array.py b/tests/pipeline/test_adjusted_array.py index 57afb44b..f7c94a70 100644 --- a/tests/pipeline/test_adjusted_array.py +++ b/tests/pipeline/test_adjusted_array.py @@ -25,6 +25,7 @@ from zipline.lib.adjustment import ( Float64Multiply, Float64Overwrite, Float641DArrayOverwrite, + Int64Overwrite, ObjectOverwrite, ) from zipline.lib.adjusted_array import AdjustedArray, NOMASK @@ -235,6 +236,7 @@ def _gen_overwrite_adjustment_cases(dtype): adjustment_type = { float64_dtype: Float64Overwrite, datetime64ns_dtype: Datetime64Overwrite, + int64_dtype: Int64Overwrite, bytes_dtype: ObjectOverwrite, unicode_dtype: ObjectOverwrite, object_dtype: ObjectOverwrite, @@ -585,6 +587,7 @@ class AdjustedArrayTestCase(TestCase): @parameterized.expand( chain( + _gen_overwrite_adjustment_cases(int64_dtype), _gen_overwrite_adjustment_cases(float64_dtype), _gen_overwrite_adjustment_cases(datetime64ns_dtype), _gen_overwrite_1d_array_adjustment_case(float64_dtype), diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index 68843a33..f8e1f750 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -1534,7 +1534,7 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): pd.Timestamp('2014-01-04') ]) baseline = pd.DataFrame({ - 'value': (0, 1), + 'value': (0., 1.), 'asof_date': base_dates, 'timestamp': base_dates, }) @@ -1545,7 +1545,6 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase): value=deltas.value + 10, timestamp=deltas.timestamp + timedelta(days=1), ) - nassets = len(simple_asset_info) expected_views = keymap(pd.Timestamp, { '2014-01-03': np.array([[10.0], diff --git a/zipline/lib/adjustment.pyx b/zipline/lib/adjustment.pyx index ebf0d4f2..6df8c770 100644 --- a/zipline/lib/adjustment.pyx +++ b/zipline/lib/adjustment.pyx @@ -32,6 +32,9 @@ cdef dict _datetime_adjustment_types = { cdef dict _object_adjustment_types = { OVERWRITE: ObjectOverwrite, } +cdef dict _int_adjustment_types = { + OVERWRITE: Int64Overwrite, +} cdef _is_float(object value): return isinstance(value, (float, float64)) @@ -39,6 +42,8 @@ cdef _is_float(object value): def _is_datetime(object value): return isinstance(value, (datetime64, Timestamp)) +def _is_int(object value): + return isinstance(value, (int, int64)) cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value): """ @@ -70,6 +75,8 @@ cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value): return _float_adjustment_types[adjustment_kind] elif _is_datetime(value): return _datetime_adjustment_types[adjustment_kind] + elif _is_int(value): + return _int_adjustment_types[adjustment_kind] else: return _object_adjustment_types[adjustment_kind] else: @@ -585,6 +592,45 @@ cdef class _Int64Adjustment(Adjustment): ) +cdef class Int64Overwrite(_Int64Adjustment): + """ + An adjustment that overwrites with an int. + + Example + ------- + + >>> import numpy as np + >>> arr = np.arange(9, dtype=int).reshape(3, 3) + >>> arr + array([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8]]) + + >>> adj = Int64Overwrite( + ... first_row=1, + ... last_row=2, + ... first_col=1, + ... last_col=2, + ... value=0, + ... ) + >>> adj.mutate(arr) + >>> arr + array([[ 0, 1, 2], + [ 3, 0, 0], + [ 6, 0, 0]]) + """ + + cpdef mutate(self, int64_t[:, :] data): + cdef Py_ssize_t row, col + cdef int64_t value = self.value + + # last_col + 1 because last_col should also be affected. + for col in range(self.first_col, self.last_col + 1): + # last_row + 1 because last_row should also be affected. + for row in range(self.first_row, self.last_row + 1): + data[row, col] = value + + cdef datetime_to_int(object datetimelike): """ Coerce a datetime-like object to the int format used by AdjustedArrays of diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 7c4dcf80..e725e5f8 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -760,7 +760,7 @@ def overwrite_novel_deltas(baseline, deltas, dates): def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value): - """Construct a `Float64Overwrite` with the correct + """Construct an Overwrite with the correct start and end date based on the asof date of the delta, the dense_dates, and the dense_dates. @@ -775,7 +775,7 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value): asset_idx : tuple of int The index of the asset in the block. If this is a tuple, then this is treated as the first and last index to use. - value : np.float64 + value : any The value to overwrite with. Returns