mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 02:54:16 +08:00
BUG: Running an algo with a df/panel of Assets was raising SidNotFound
This commit is contained in:
+21
-1
@@ -16,7 +16,7 @@ import datetime
|
||||
from datetime import timedelta
|
||||
from mock import MagicMock
|
||||
from nose_parameterized import parameterized
|
||||
from six.moves import range
|
||||
from six.moves import range, map
|
||||
from textwrap import dedent
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -599,6 +599,26 @@ class TestTransformAlgorithm(TestCase):
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
sids=[0, 1])
|
||||
panel = self.panel.copy()
|
||||
panel.items = pd.Index(map(Equity, panel.items))
|
||||
algo.run(panel)
|
||||
assert isinstance(algo.sources[0], DataPanelSource)
|
||||
|
||||
def test_df_of_assets_as_input(self):
|
||||
algo = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
env=TradingEnvironment(), # new env without assets
|
||||
)
|
||||
df = self.df.copy()
|
||||
df.columns = pd.Index(map(Equity, df.columns))
|
||||
algo.run(df)
|
||||
assert isinstance(algo.sources[0], DataFrameSource)
|
||||
|
||||
def test_panel_of_assets_as_input(self):
|
||||
algo = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
env=TradingEnvironment(), # new env without assets
|
||||
sids=[0, 1])
|
||||
algo.run(self.panel)
|
||||
assert isinstance(algo.sources[0], DataPanelSource)
|
||||
|
||||
|
||||
+29
-36
@@ -513,48 +513,17 @@ class TradingAlgorithm(object):
|
||||
# if DataFrame provided, map columns to sids and wrap
|
||||
# in DataFrameSource
|
||||
copy_frame = source.copy()
|
||||
|
||||
# Build new Assets for identifiers that can't be resolved as
|
||||
# sids/Assets
|
||||
identifiers_to_build = []
|
||||
for identifier in source.columns:
|
||||
if hasattr(identifier, '__int__'):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier,
|
||||
default_none=True)
|
||||
if asset is None:
|
||||
identifiers_to_build.append(identifier)
|
||||
else:
|
||||
identifiers_to_build.append(identifier)
|
||||
|
||||
self.trading_environment.write_data(
|
||||
equities_identifiers=identifiers_to_build)
|
||||
copy_frame.columns = \
|
||||
self.asset_finder.map_identifier_index_to_sids(
|
||||
source.columns, source.index[0]
|
||||
)
|
||||
copy_frame.columns = self._write_and_map_id_index_to_sids(
|
||||
source.columns, source.index[0],
|
||||
)
|
||||
source = DataFrameSource(copy_frame)
|
||||
|
||||
elif isinstance(source, pd.Panel):
|
||||
# If Panel provided, map items to sids and wrap
|
||||
# in DataPanelSource
|
||||
copy_panel = source.copy()
|
||||
|
||||
# Build new Assets for identifiers that can't be resolved as
|
||||
# sids/Assets
|
||||
identifiers_to_build = []
|
||||
for identifier in source.items:
|
||||
if hasattr(identifier, '__int__'):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier,
|
||||
default_none=True)
|
||||
if asset is None:
|
||||
identifiers_to_build.append(identifier)
|
||||
else:
|
||||
identifiers_to_build.append(identifier)
|
||||
|
||||
self.trading_environment.write_data(
|
||||
equities_identifiers=identifiers_to_build)
|
||||
copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
|
||||
source.items, source.major_axis[0]
|
||||
copy_panel.items = self._write_and_map_id_index_to_sids(
|
||||
source.items, source.major_axis[0],
|
||||
)
|
||||
source = DataPanelSource(copy_panel)
|
||||
|
||||
@@ -617,6 +586,30 @@ class TradingAlgorithm(object):
|
||||
|
||||
return daily_stats
|
||||
|
||||
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
|
||||
# Build new Assets for identifiers that can't be resolved as
|
||||
# sids/Assets
|
||||
identifiers_to_build = []
|
||||
for identifier in identifiers:
|
||||
asset = None
|
||||
|
||||
if isinstance(identifier, Asset):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
|
||||
default_none=True)
|
||||
|
||||
elif hasattr(identifier, '__int__'):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier,
|
||||
default_none=True)
|
||||
if asset is None:
|
||||
identifiers_to_build.append(identifier)
|
||||
|
||||
self.trading_environment.write_data(
|
||||
equities_identifiers=identifiers_to_build)
|
||||
|
||||
return self.asset_finder.map_identifier_index_to_sids(
|
||||
identifiers, as_of_date,
|
||||
)
|
||||
|
||||
def _create_daily_stats(self, perfs):
|
||||
# create daily and cumulative stats dataframe
|
||||
daily_perfs = []
|
||||
|
||||
Reference in New Issue
Block a user