mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 17:19:16 +08:00
Merge pull request #1665 from quantopian/determine_overwrite_type_dynamically
Determine overwrite type dynamically
This commit is contained in:
@@ -25,6 +25,7 @@ from zipline.lib.adjustment import (
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
Float641DArrayOverwrite,
|
||||
Int64Overwrite,
|
||||
ObjectOverwrite,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
@@ -235,6 +236,7 @@ def _gen_overwrite_adjustment_cases(dtype):
|
||||
adjustment_type = {
|
||||
float64_dtype: Float64Overwrite,
|
||||
datetime64ns_dtype: Datetime64Overwrite,
|
||||
int64_dtype: Int64Overwrite,
|
||||
bytes_dtype: ObjectOverwrite,
|
||||
unicode_dtype: ObjectOverwrite,
|
||||
object_dtype: ObjectOverwrite,
|
||||
@@ -585,6 +587,7 @@ class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_overwrite_adjustment_cases(int64_dtype),
|
||||
_gen_overwrite_adjustment_cases(float64_dtype),
|
||||
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(float64_dtype),
|
||||
|
||||
@@ -35,6 +35,21 @@ class AdjustmentTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_make_int_adjustment(self):
|
||||
result = adj.make_adjustment_from_indices(
|
||||
1, 2, 3, 4,
|
||||
adjustment_kind=adj.OVERWRITE,
|
||||
value=1,
|
||||
)
|
||||
expected = adj.Int64Overwrite(
|
||||
first_row=1,
|
||||
last_row=2,
|
||||
first_col=3,
|
||||
last_col=4,
|
||||
value=1,
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_make_datetime_adjustment(self):
|
||||
overwrite_dt = make_datetime64ns(0)
|
||||
result = adj.make_adjustment_from_indices(
|
||||
@@ -51,6 +66,23 @@ class AdjustmentTestCase(TestCase):
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@parameterized.expand([("some text",), ("some text".encode(),), (None,)])
|
||||
def test_make_object_adjustment(self, value):
|
||||
result = adj.make_adjustment_from_indices(
|
||||
1, 2, 3, 4,
|
||||
adjustment_kind=adj.OVERWRITE,
|
||||
value=value,
|
||||
)
|
||||
|
||||
expected = adj.ObjectOverwrite(
|
||||
first_row=1,
|
||||
last_row=2,
|
||||
first_col=3,
|
||||
last_col=4,
|
||||
value=value,
|
||||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_unsupported_type(self):
|
||||
class SomeClass(object):
|
||||
pass
|
||||
|
||||
@@ -1534,7 +1534,7 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase):
|
||||
pd.Timestamp('2014-01-04')
|
||||
])
|
||||
baseline = pd.DataFrame({
|
||||
'value': (0, 1),
|
||||
'value': (0., 1.),
|
||||
'asof_date': base_dates,
|
||||
'timestamp': base_dates,
|
||||
})
|
||||
@@ -1545,7 +1545,6 @@ class BlazeToPipelineTestCase(WithAssetFinder, ZiplineTestCase):
|
||||
value=deltas.value + 10,
|
||||
timestamp=deltas.timestamp + timedelta(days=1),
|
||||
)
|
||||
|
||||
nassets = len(simple_asset_info)
|
||||
expected_views = keymap(pd.Timestamp, {
|
||||
'2014-01-03': np.array([[10.0],
|
||||
|
||||
@@ -4,6 +4,9 @@ from cpython cimport Py_EQ
|
||||
from pandas import isnull, Timestamp
|
||||
from numpy cimport float64_t, uint8_t, int64_t
|
||||
from numpy import asarray, datetime64, float64, int64
|
||||
|
||||
from zipline.utils.compat import unicode
|
||||
|
||||
# Purely for readability. There aren't C-level declarations for these types.
|
||||
ctypedef object Int64Index_t
|
||||
ctypedef object DatetimeIndex_t
|
||||
@@ -29,6 +32,12 @@ cdef dict _float_adjustment_types = {
|
||||
cdef dict _datetime_adjustment_types = {
|
||||
OVERWRITE: Datetime64Overwrite,
|
||||
}
|
||||
cdef dict _object_adjustment_types = {
|
||||
OVERWRITE: ObjectOverwrite,
|
||||
}
|
||||
cdef dict _int_adjustment_types = {
|
||||
OVERWRITE: Int64Overwrite,
|
||||
}
|
||||
|
||||
cdef _is_float(object value):
|
||||
return isinstance(value, (float, float64))
|
||||
@@ -36,6 +45,11 @@ cdef _is_float(object value):
|
||||
def _is_datetime(object value):
|
||||
return isinstance(value, (datetime64, Timestamp))
|
||||
|
||||
def _is_int(object value):
|
||||
return isinstance(value, (int, int64))
|
||||
|
||||
def _is_obj(object value):
|
||||
return isinstance(value, (bytes, unicode, type(None)))
|
||||
|
||||
cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value):
|
||||
"""
|
||||
@@ -67,11 +81,16 @@ cpdef choose_adjustment_type(AdjustmentKind adjustment_kind, object value):
|
||||
return _float_adjustment_types[adjustment_kind]
|
||||
elif _is_datetime(value):
|
||||
return _datetime_adjustment_types[adjustment_kind]
|
||||
elif _is_int(value):
|
||||
return _int_adjustment_types[adjustment_kind]
|
||||
elif _is_obj(value):
|
||||
return _object_adjustment_types[adjustment_kind]
|
||||
else:
|
||||
raise TypeError(
|
||||
"Don't know how to make overwrite "
|
||||
"adjustments for values of type %r." % type(value),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown adjustment type %d." % adjustment_kind)
|
||||
|
||||
@@ -585,6 +604,45 @@ cdef class _Int64Adjustment(Adjustment):
|
||||
)
|
||||
|
||||
|
||||
cdef class Int64Overwrite(_Int64Adjustment):
|
||||
"""
|
||||
An adjustment that overwrites with an int.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np
|
||||
>>> arr = np.arange(9, dtype=int).reshape(3, 3)
|
||||
>>> arr
|
||||
array([[ 0, 1, 2],
|
||||
[ 3, 4, 5],
|
||||
[ 6, 7, 8]])
|
||||
|
||||
>>> adj = Int64Overwrite(
|
||||
... first_row=1,
|
||||
... last_row=2,
|
||||
... first_col=1,
|
||||
... last_col=2,
|
||||
... value=0,
|
||||
... )
|
||||
>>> adj.mutate(arr)
|
||||
>>> arr
|
||||
array([[ 0, 1, 2],
|
||||
[ 3, 0, 0],
|
||||
[ 6, 0, 0]])
|
||||
"""
|
||||
|
||||
cpdef mutate(self, int64_t[:, :] data):
|
||||
cdef Py_ssize_t row, col
|
||||
cdef int64_t value = self.value
|
||||
|
||||
# last_col + 1 because last_col should also be affected.
|
||||
for col in range(self.first_col, self.last_col + 1):
|
||||
# last_row + 1 because last_row should also be affected.
|
||||
for row in range(self.first_row, self.last_row + 1):
|
||||
data[row, col] = value
|
||||
|
||||
|
||||
cdef datetime_to_int(object datetimelike):
|
||||
"""
|
||||
Coerce a datetime-like object to the int format used by AdjustedArrays of
|
||||
|
||||
@@ -182,7 +182,7 @@ from zipline.pipeline.loaders.utils import (
|
||||
)
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.lib.adjusted_array import AdjustedArray, can_represent_dtype
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
from zipline.lib.adjustment import make_adjustment_from_indices, OVERWRITE
|
||||
from zipline.utils.input_validation import (
|
||||
expect_element,
|
||||
ensure_timezone,
|
||||
@@ -760,7 +760,7 @@ def overwrite_novel_deltas(baseline, deltas, dates):
|
||||
|
||||
|
||||
def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
"""Construct a `Float64Overwrite` with the correct
|
||||
"""Construct an Overwrite with the correct
|
||||
start and end date based on the asof date of the delta,
|
||||
the dense_dates, and the dense_dates.
|
||||
|
||||
@@ -775,7 +775,7 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
asset_idx : tuple of int
|
||||
The index of the asset in the block. If this is a tuple, then this
|
||||
is treated as the first and last index to use.
|
||||
value : np.float64
|
||||
value : any
|
||||
The value to overwrite with.
|
||||
|
||||
Returns
|
||||
@@ -815,7 +815,9 @@ def overwrite_from_dates(asof, dense_dates, sparse_dates, asset_idx, value):
|
||||
return
|
||||
|
||||
first, last = asset_idx
|
||||
yield Float64Overwrite(first_row, last_row, first, last, value)
|
||||
yield make_adjustment_from_indices(
|
||||
first_row, last_row, first, last, OVERWRITE, value
|
||||
)
|
||||
|
||||
|
||||
def adjustments_from_deltas_no_sids(dense_dates,
|
||||
|
||||
Reference in New Issue
Block a user