mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 12:43:47 +08:00
BUG: Fix crash on .latest for integer-typed columns.
Int columns get coerced to float on load, and we don't currently support non-float columns from CustomFactors.
This commit is contained in:
@@ -504,35 +504,50 @@ class MultiColumnLoaderTestCase(TestCase):
|
||||
self.asset_finder = env.asset_finder
|
||||
|
||||
def test_engine_with_multicolumn_loader(self):
|
||||
open_, close = USEquityPricing.open, USEquityPricing.close
|
||||
open_ = USEquityPricing.open
|
||||
close = USEquityPricing.close
|
||||
volume = USEquityPricing.volume
|
||||
|
||||
# Test for thirty days up to the second to last day that we think all
|
||||
# the assets existed. If we test the last day of our calendar, no
|
||||
# assets will be in our output, because their end dates are all
|
||||
dates_to_test = self.dates[-32:-2]
|
||||
|
||||
loader = MultiColumnLoader({
|
||||
open_: ConstantLoader(dates=self.dates,
|
||||
assets=self.assets,
|
||||
constants={open_: 1}),
|
||||
close: ConstantLoader(dates=self.dates,
|
||||
assets=self.assets,
|
||||
constants={close: 2})
|
||||
})
|
||||
|
||||
constants = {open_: 1, close: 2, volume: 3}
|
||||
loader = ConstantLoader(
|
||||
constants=constants,
|
||||
dates=self.dates,
|
||||
assets=self.assets,
|
||||
)
|
||||
engine = SimpleFFCEngine(loader, self.dates, self.asset_finder)
|
||||
|
||||
factor = RollingSumDifference()
|
||||
sumdiff = RollingSumDifference()
|
||||
|
||||
result = engine.factor_matrix({'f': factor},
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1])
|
||||
result = engine.factor_matrix(
|
||||
{
|
||||
'sumdiff': sumdiff,
|
||||
'open': open_.latest,
|
||||
'close': close.latest,
|
||||
'volume': volume.latest,
|
||||
},
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1]
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual({'f'}, set(result.columns))
|
||||
self.assertEqual(
|
||||
{'sumdiff', 'open', 'close', 'volume'},
|
||||
set(result.columns)
|
||||
)
|
||||
|
||||
result_index = self.assets * len(dates_to_test)
|
||||
result_shape = (len(result_index),)
|
||||
check_arrays(
|
||||
result['f'],
|
||||
result['sumdiff'],
|
||||
Series(index=result_index, data=full(result_shape, -3)),
|
||||
)
|
||||
|
||||
for name, const in [('open', 1), ('close', 2), ('volume', 3)]:
|
||||
check_arrays(
|
||||
result[name],
|
||||
Series(index=result_index, data=full(result_shape, const)),
|
||||
)
|
||||
|
||||
@@ -71,7 +71,10 @@ class BoundColumn(Term):
|
||||
|
||||
@property
|
||||
def latest(self):
|
||||
return Latest(inputs=(self,), dtype=self.dtype)
|
||||
# FIXME: Once we support non-float dtypes, this should pass a dtype
|
||||
# along. Right now we're just assuming that inputs will safely coerce
|
||||
# to float.
|
||||
return Latest(inputs=(self,))
|
||||
|
||||
def __repr__(self):
|
||||
return "{qualname}::{dtype}".format(
|
||||
|
||||
Reference in New Issue
Block a user