diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2a479867..00b3dc9a 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -16,7 +16,7 @@ import datetime from datetime import timedelta from mock import MagicMock from nose_parameterized import parameterized -from six.moves import range +from six.moves import range, map from textwrap import dedent from unittest import TestCase @@ -599,6 +599,26 @@ class TestTransformAlgorithm(TestCase): sim_params=self.sim_params, env=self.env, sids=[0, 1]) + panel = self.panel.copy() + panel.items = pd.Index(map(Equity, panel.items)) + algo.run(panel) + assert isinstance(algo.sources[0], DataPanelSource) + + def test_df_of_assets_as_input(self): + algo = TestRegisterTransformAlgorithm( + sim_params=self.sim_params, + env=TradingEnvironment(), # new env without assets + ) + df = self.df.copy() + df.columns = pd.Index(map(Equity, df.columns)) + algo.run(df) + assert isinstance(algo.sources[0], DataFrameSource) + + def test_panel_of_assets_as_input(self): + algo = TestRegisterTransformAlgorithm( + sim_params=self.sim_params, + env=TradingEnvironment(), # new env without assets + sids=[0, 1]) algo.run(self.panel) assert isinstance(algo.sources[0], DataPanelSource) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 274643d8..e10db5a0 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -513,48 +513,17 @@ class TradingAlgorithm(object): # if DataFrame provided, map columns to sids and wrap # in DataFrameSource copy_frame = source.copy() - - # Build new Assets for identifiers that can't be resolved as - # sids/Assets - identifiers_to_build = [] - for identifier in source.columns: - if hasattr(identifier, '__int__'): - asset = self.asset_finder.retrieve_asset(sid=identifier, - default_none=True) - if asset is None: - identifiers_to_build.append(identifier) - else: - identifiers_to_build.append(identifier) - - self.trading_environment.write_data( - equities_identifiers=identifiers_to_build) - copy_frame.columns = \ - self.asset_finder.map_identifier_index_to_sids( - source.columns, source.index[0] - ) + copy_frame.columns = self._write_and_map_id_index_to_sids( + source.columns, source.index[0], + ) source = DataFrameSource(copy_frame) elif isinstance(source, pd.Panel): # If Panel provided, map items to sids and wrap # in DataPanelSource copy_panel = source.copy() - - # Build new Assets for identifiers that can't be resolved as - # sids/Assets - identifiers_to_build = [] - for identifier in source.items: - if hasattr(identifier, '__int__'): - asset = self.asset_finder.retrieve_asset(sid=identifier, - default_none=True) - if asset is None: - identifiers_to_build.append(identifier) - else: - identifiers_to_build.append(identifier) - - self.trading_environment.write_data( - equities_identifiers=identifiers_to_build) - copy_panel.items = self.asset_finder.map_identifier_index_to_sids( - source.items, source.major_axis[0] + copy_panel.items = self._write_and_map_id_index_to_sids( + source.items, source.major_axis[0], ) source = DataPanelSource(copy_panel) @@ -617,6 +586,30 @@ class TradingAlgorithm(object): return daily_stats + def _write_and_map_id_index_to_sids(self, identifiers, as_of_date): + # Build new Assets for identifiers that can't be resolved as + # sids/Assets + identifiers_to_build = [] + for identifier in identifiers: + asset = None + + if isinstance(identifier, Asset): + asset = self.asset_finder.retrieve_asset(sid=identifier.sid, + default_none=True) + + elif hasattr(identifier, '__int__'): + asset = self.asset_finder.retrieve_asset(sid=identifier, + default_none=True) + if asset is None: + identifiers_to_build.append(identifier) + + self.trading_environment.write_data( + equities_identifiers=identifiers_to_build) + + return self.asset_finder.map_identifier_index_to_sids( + identifiers, as_of_date, + ) + def _create_daily_stats(self, perfs): # create daily and cumulative stats dataframe daily_perfs = []