mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 06:28:42 +08:00
ENH: add Int64Overwrite and dispatching for it
BUG: column value should be float DOC: update docs
This commit is contained in:
@@ -25,6 +25,7 @@ from zipline.lib.adjustment import (
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
Float641DArrayOverwrite,
|
||||
Int64Overwrite,
|
||||
ObjectOverwrite,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
@@ -235,6 +236,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
adjustment_type = {
|
||||
float64_dtype: Float64Overwrite,
|
||||
datetime64ns_dtype: Datetime64Overwrite,
|
||||
int64_dtype: Int64Overwrite,
|
||||
bytes_dtype: ObjectOverwrite,
|
||||
unicode_dtype: ObjectOverwrite,
|
||||
object_dtype: ObjectOverwrite,
|
||||
@@ -585,6 +587,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_overwrite_adjustment_cases(int64_dtype),
|
||||
_gen_overwrite_adjustment_cases(float64_dtype),
|
||||
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(float64_dtype),
|
||||
|
||||
@@ -1534,7 +1534,7 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase):
|
||||
pd.Timestamp('2014-01-04')
|
||||
])
|
||||
baseline = pd.DataFrame({
|
||||
'value': (0, 1),
|
||||
'value': (0., 1.),
|
||||
'asof_date': base_dates,
|
||||
'timestamp': base_dates,
|
||||
})
|
||||
@@ -1545,7 +1545,6 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase):
|
||||
value=deltas.value + 10,
|
||||
timestamp=deltas.timestamp + timedelta(days=1),
|
||||
)
|
||||
|
||||
nassets = len(simple_asset_info)
|
||||
expected_views = keymap(pd.Timestamp, {
|
||||
'2014-01-03': np.array([[10.0],
|
||||
|
||||
@@ -32,6 +32,9 @@ cdef dict _datetime_adjustment_types = {
|
||||
cdef dict _object_adjustment_types = {
|
||||
OVERWRITE: ObjectOverwrite,
|
||||
}
|
||||
cdef dict _int_adjustment_types = {
|
||||
OVERWRITE: Int64Overwrite,
|
||||
}
|
||||
|
||||
cdef _is_float(object value):
|
||||
return isinstance(value, (float, float64))
|
||||
@@ -39,6 +42,8 @@ cdef _is_float(object value):
|
||||
def _is_datetime(object value):
|
||||
return isinstance(value, (datetime64, Timestamp))
|
||||
|
||||
def _is_int(object value):
|
||||
return isinstance(value, (int, int64))
|
||||
|
||||
cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value):
|
||||
"""
|
||||
@@ -70,6 +75,8 @@ cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value):
|
||||
return _float_adjustment_types[adjustment_kind]
|
||||
elif _is_datetime(value):
|
||||
return _datetime_adjustment_types[adjustment_kind]
|
||||
elif _is_int(value):
|
||||
return _int_adjustment_types[adjustment_kind]
|
||||
else:
|
||||
return _object_adjustment_types[adjustment_kind]
|
||||
else:
|
||||
@@ -585,6 +592,45 @@ cdef class _Int64Adjustment(Adjustment):
|
||||
)
|
||||
|
||||
|
||||
cdef class Int64Overwrite(_Int64Adjustment):
|
||||
"""
|
||||
An adjustment that overwrites with an int.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np
|
||||
>>> arr = np.arange(9, dtype=int).reshape(3, 3)
|
||||
>>> arr
|
||||
array([[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
[ 6, 7, 8]])
|
||||
|
||||
>>> adj = Int64Overwrite(
|
||||
... first_row=1,
|
||||
... last_row=2,
|
||||
... first_col=1,
|
||||
... last_col=2,
|
||||
... value=0,
|
||||
... )
|
||||
>>> adj.mutate(arr)
|
||||
>>> arr
|
||||
array([[ 0, 1, 2],
|
||||
[ 3, 0, 0],
|
||||
[ 6, 0, 0]])
|
||||
"""
|
||||
|
||||
cpdef mutate(self, int64_t[:, :] data):
|
||||
cdef Py_ssize_t row, col
|
||||
cdef int64_t value = self.value
|
||||
|
||||
# last_col + 1 because last_col should also be affected.
|
||||
for col in range(self.first_col, self.last_col + 1):
|
||||
# last_row + 1 because last_row should also be affected.
|
||||
for row in range(self.first_row, self.last_row + 1):
|
||||
data[row, col] = value
|
||||
|
||||
|
||||
cdef datetime_to_int(object datetimelike):
|
||||
"""
|
||||
Coerce a datetime-like object to the int format used by AdjustedArrays of
|
||||
|
||||
@@ -760,7 +760,7 @@ def overwrite_novel_deltas(baseline, deltas, dates):
|
||||
|
||||
|
||||
def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
"""Construct a `Float64Overwrite` with the correct
|
||||
"""Construct an Overwrite with the correct
|
||||
start and end date based on the asof date of the delta,
|
||||
the dense_dates, and the dense_dates.
|
||||
|
||||
@@ -775,7 +775,7 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
asset_idx : tuple of int
|
||||
The index of the asset in the block. If this is a tuple, then this
|
||||
is treated as the first and last index to use.
|
||||
value : np.float64
|
||||
value : any
|
||||
The value to overwrite with.
|
||||
|
||||
Returns
|
||||
|
||||
Reference in New Issue
Block a user