diff --git a/tests/test_assets.py b/tests/test_assets.py index 66f8248b..94569080 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -718,3 +718,25 @@ class AssetFinderTestCase(TestCase): # No contracts exist after 12/14/2015, so we should get none self.assertIsNone(finder.lookup_future_by_expiration('AD', dt, jan_16)) + + def test_map_identifier_list_to_sids(self): + + # Build an empty finder and some Assets + dt = pd.Timestamp('2014-01-01', tz='UTC') + finder = AssetFinder() + asset1 = Equity(1, symbol="AAPL") + asset2 = Equity(2, symbol="GOOG") + asset200 = Future(200, symbol="CLK15") + asset201 = Future(201, symbol="CLM15") + + # Check for correct mapping and types + pre_map = [asset1, asset2, asset200, asset201] + post_map = finder.map_identifier_list_to_sids(pre_map, dt) + self.assertListEqual([1, 2, 200, 201], post_map) + for sid in post_map: + self.assertIsInstance(sid, int) + + # Change order and check mapping again + pre_map = [asset201, asset2, asset200, asset1] + post_map = finder.map_identifier_list_to_sids(pre_map, dt) + self.assertListEqual([201, 2, 200, 1], post_map) diff --git a/tests/test_history.py b/tests/test_history.py index a29ef469..5a515e23 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -437,7 +437,7 @@ def handle_data(context, data): _, df = factory.create_test_df_source(sim_params) df = df.astype(np.float64) - source = DataFrameSource(df, sids=[0]) + source = DataFrameSource(df) test_algo = TradingAlgorithm( script=algo_text, diff --git a/tests/test_sources.py b/tests/test_sources.py index 3664acdd..2e2c7c23 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -27,6 +27,7 @@ from zipline.sources import (DataFrameSource, RandomWalkSource) from zipline.utils import tradingcalendar as calendar_nyse from zipline.finance.trading import with_environment +from zipline.assets import AssetFinder class TestDataFrameSource(TestCase): @@ -43,7 +44,7 @@ class TestDataFrameSource(TestCase): def test_df_sid_filtering(self): _, df = factory.create_test_df_source() - source = DataFrameSource(df, sids=[0]) + source = DataFrameSource(df) assert 1 not in [event.sid for event in source], \ "DataFrameSource should only stream selected sid 0, not sid 1." @@ -63,8 +64,8 @@ class TestDataFrameSource(TestCase): self.assertTrue(isinstance(event['volume'], int)) self.assertTrue(isinstance(event['arbitrary'], float)) - @with_environment() - def test_yahoo_bars_to_panel_source(self, env=None): + def test_yahoo_bars_to_panel_source(self): + finder = AssetFinder() stocks = ['AAPL', 'GE'] start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc) end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) @@ -75,10 +76,15 @@ class TestDataFrameSource(TestCase): check_fields = ['sid', 'open', 'high', 'low', 'close', 'volume', 'price'] - source = DataPanelSource(data) + + copy_panel = data.copy() + copy_panel.items = finder.map_identifier_list_to_sids( + data.items, data.major_axis[0] + ) + source = DataPanelSource(copy_panel) sids = [ asset.sid for asset in - [env.asset_finder.lookup_symbol(symbol, as_of_date=end) + [finder.lookup_symbol(symbol, as_of_date=end) for symbol in stocks] ] stocks_iter = cycle(sids) diff --git a/tests/test_transforms_talib.py b/tests/test_transforms_talib.py index 5b9f50a1..203d0190 100644 --- a/tests/test_transforms_talib.py +++ b/tests/test_transforms_talib.py @@ -98,7 +98,7 @@ class TestTALIB(TestCase): zipline_transforms = [ta.MA(timeperiod=10), ta.MA(timeperiod=25)] talib_fn = talib.abstract.MA - algo = TALIBAlgorithm(talib=zipline_transforms) + algo = TALIBAlgorithm(talib=zipline_transforms, identifiers=[0]) algo.run(self.source) # Test if computed values match those computed by pandas rolling mean. sid = 0 diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 375793fc..41a4aa65 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -39,7 +39,6 @@ from zipline.errors import ( UnsupportedCommissionModel, UnsupportedOrderParameters, UnsupportedSlippageModel, - SidNotFound, ) from zipline.finance.trading import TradingEnvironment @@ -461,10 +460,22 @@ class TradingAlgorithm(object): __init__().""", UserWarning) overwrite_sim_params = False elif isinstance(source, pd.DataFrame): - # if DataFrame provided, wrap in DataFrameSource - source = DataFrameSource(source) + # if DataFrame provided, map columns to sids and wrap + # in DataFrameSource + copy_frame = source.copy() + copy_frame.columns = self.asset_finder.map_identifier_list_to_sids( + source.columns, source.index[0] + ) + source = DataFrameSource(copy_frame) + elif isinstance(source, pd.Panel): - source = DataPanelSource(source) + # If Panel provided, map items to sids and wrap + # in DataPanelSource + copy_panel = source.copy() + copy_panel.items = self.asset_finder.map_identifier_list_to_sids( + source.items, source.major_axis[0] + ) + source = DataPanelSource(copy_panel) if isinstance(source, list): self.set_sources(source) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index c14c93f5..51d592cb 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -18,6 +18,7 @@ from itertools import chain from numbers import Integral import numpy as np import operator +import warnings from logbook import Logger import pandas as pd @@ -531,6 +532,30 @@ class AssetFinder(object): self._lookup_generic_scalar(obj, as_of_date, matches, missing) return matches, missing + def map_identifier_list_to_sids(self, index, as_of_date): + """ + This method is for use in sanitizing a user's DataFrame inputs. + Takes the given indices, inserts them in to the AssetFinder as + identifiers, rebuilds the Assets, and returns a new object that + contains the sids of the given identifiers. + + :param index: The index to be mapped + :return: The new index of sids + """ + # Populate the caches with the given indices + self.consume_identifiers(index) + self.populate_cache() + + # Find all of the newly built assets corresponding to the indices + found, missing = self.lookup_generic(index, as_of_date) + + # Handle missing assets + if len(missing) > 0: + warnings.warn("Missing assets for identifiers: " + missing) + + # Return a list of the sids of the found assets + return [asset.sid for asset in found] + def insert_metadata(self, identifier, **kwargs): """ Inserts the given metadata kwargs to the entry for the given diff --git a/zipline/examples/pairtrade.py b/zipline/examples/pairtrade.py index ca71d8f9..ffb37d53 100755 --- a/zipline/examples/pairtrade.py +++ b/zipline/examples/pairtrade.py @@ -23,7 +23,6 @@ import pytz from zipline.algorithm import TradingAlgorithm from zipline.transforms import batch_transform from zipline.utils.factory import load_from_yahoo -from zipline.sources.data_frame_source import DataFrameSource @batch_transform @@ -121,10 +120,9 @@ if __name__ == '__main__': end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={}, start=start, end=end) - source = DataFrameSource(data) pairtrade = Pairtrade() - results = pairtrade.run(source) + results = pairtrade.run(data) data['spreads'] = np.nan ax1 = plt.subplot(211) diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index d42fdf5a..77241812 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -18,7 +18,6 @@ Tools to generate data sources. """ import numpy as np import pandas as pd -import datetime from zipline.gens.utils import hash_args @@ -41,19 +40,13 @@ class DataFrameSource(DataSource): @with_environment() def __init__(self, data, env=None, **kwargs): assert isinstance(data.index, pd.tseries.index.DatetimeIndex) - + assert isinstance(data.columns, pd.Int64Index) + # TODO is ffilling correct/necessary? + # Forward fill prices self.data = data.fillna(method='ffill') # Unpack config dictionary with default values. self.start = kwargs.get('start', self.data.index[0]) self.end = kwargs.get('end', self.data.index[-1]) - - # Remap sids based on the trading environment - env.update_asset_finder( - identifiers=kwargs.get('sids', self.data.columns) - ) - self.data.columns, _ = env.asset_finder.lookup_generic( - self.data.columns, datetime.datetime.now() - ) self.sids = self.data.columns # Hash_value for downstream sorting. @@ -118,20 +111,15 @@ class DataPanelSource(DataSource): @with_environment() def __init__(self, data, env=None, **kwargs): assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex) - + # Only accept integer SIDs as the items of the Panel + assert isinstance(data.items, pd.Int64Index) + # TODO is ffilling correct/necessary? + # forward fill with volumes of 0 self.data = data.fillna(value={'volume': 0}) self.data = self.data.fillna(method='ffill') # Unpack config dictionary with default values. self.start = kwargs.get('start', self.data.major_axis[0]) self.end = kwargs.get('end', self.data.major_axis[-1]) - - # Remap sids based on the trading environment - env.update_asset_finder( - identifiers=kwargs.get('sids', self.data.items) - ) - self.data.items, _ = env.asset_finder.lookup_generic( - self.data.items, datetime.datetime.now() - ) self.sids = self.data.items # Hash_value for downstream sorting.