mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 15:25:50 +08:00
Merge pull request #1052 from quantopian/empty-pipeline-bug
BUG: Empty pipeline failed to tz_localize
This commit is contained in:
@@ -566,3 +566,43 @@ class PipelineAlgorithmTestCase(TestCase):
|
||||
# TradingAlgorithm.
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
def test_empty_pipeline(self):
|
||||
|
||||
# For ensuring we call before_trading_start.
|
||||
count = [0]
|
||||
|
||||
def initialize(context):
|
||||
pipeline = attach_pipeline(Pipeline(), 'test')
|
||||
|
||||
vwap = VWAP(window_length=10)
|
||||
pipeline.add(vwap, 'vwap')
|
||||
|
||||
# Nothing should have prices less than 0.
|
||||
pipeline.set_screen(vwap < 0)
|
||||
|
||||
def handle_data(context, data):
|
||||
pass
|
||||
|
||||
def before_trading_start(context, data):
|
||||
context.results = pipeline_output('test')
|
||||
self.assertTrue(context.results.empty)
|
||||
count[0] += 1
|
||||
|
||||
algo = TradingAlgorithm(
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.dates[0],
|
||||
end=self.dates[-1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
source=self.make_source(),
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
self.assertTrue(count[0] > 0)
|
||||
|
||||
@@ -384,6 +384,23 @@ class SimplePipelineEngine(object):
|
||||
If mask[date, asset] is True, then result.loc[(date, asset), colname]
|
||||
will contain the value of data[colname][date, asset].
|
||||
"""
|
||||
if not mask.any():
|
||||
# Manually handle the empty DataFrame case. This is a workaround
|
||||
# to pandas failing to tz_localize an empty dataframe with a
|
||||
# MultiIndex. It also saves us the work of applying a known-empty
|
||||
# mask to each array.
|
||||
#
|
||||
# Slicing `dates` here to preserve pandas metadata.
|
||||
empty_dates = dates[:0]
|
||||
empty_assets = array([], dtype=object)
|
||||
return DataFrame(
|
||||
data={
|
||||
name: array([], dtype=arr.dtype)
|
||||
for name, arr in iteritems(data)
|
||||
},
|
||||
index=MultiIndex.from_arrays([empty_dates, empty_assets]),
|
||||
)
|
||||
|
||||
resolved_assets = array(self._finder.retrieve_all(assets))
|
||||
dates_kept = repeat_last_axis(dates.values, len(assets))[mask]
|
||||
assets_kept = repeat_first_axis(resolved_assets, len(dates))[mask]
|
||||
|
||||
Reference in New Issue
Block a user