From 258b5ea2ca0ab2cecaa3b5b5e0ff92f7a61a6066 Mon Sep 17 00:00:00 2001 From: jfkirk Date: Fri, 26 Jun 2015 16:41:38 -0400 Subject: [PATCH] API: DataFrame/Panel sources expect integer sids, not identifiers This commit modifies the DataFrameSource and DataPanelSource to accept only Int64Indexes on the incoming data and moves the burden of mapping user identifiers to TradingAlgorithm.run(). --- tests/test_assets.py | 22 ++++++++++++++++++++++ tests/test_history.py | 2 +- tests/test_sources.py | 16 +++++++++++----- tests/test_transforms_talib.py | 2 +- zipline/algorithm.py | 19 +++++++++++++++---- zipline/assets/assets.py | 25 +++++++++++++++++++++++++ zipline/examples/pairtrade.py | 4 +--- zipline/sources/data_frame_source.py | 26 +++++++------------------- 8 files changed, 83 insertions(+), 33 deletions(-) diff --git a/tests/test_assets.py b/tests/test_assets.py index 66f8248b..94569080 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -718,3 +718,25 @@ class AssetFinderTestCase(TestCase): # No contracts exist after 12/14/2015, so we should get none self.assertIsNone(finder.lookup_future_by_expiration('AD', dt, jan_16)) + + def test_map_identifier_list_to_sids(self): + + # Build an empty finder and some Assets + dt = pd.Timestamp('2014-01-01', tz='UTC') + finder = AssetFinder() + asset1 = Equity(1, symbol="AAPL") + asset2 = Equity(2, symbol="GOOG") + asset200 = Future(200, symbol="CLK15") + asset201 = Future(201, symbol="CLM15") + + # Check for correct mapping and types + pre_map = [asset1, asset2, asset200, asset201] + post_map = finder.map_identifier_list_to_sids(pre_map, dt) + self.assertListEqual([1, 2, 200, 201], post_map) + for sid in post_map: + self.assertIsInstance(sid, int) + + # Change order and check mapping again + pre_map = [asset201, asset2, asset200, asset1] + post_map = finder.map_identifier_list_to_sids(pre_map, dt) + self.assertListEqual([201, 2, 200, 1], post_map) diff --git a/tests/test_history.py b/tests/test_history.py index 
a29ef469..5a515e23 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -437,7 +437,7 @@ def handle_data(context, data): _, df = factory.create_test_df_source(sim_params) df = df.astype(np.float64) - source = DataFrameSource(df, sids=[0]) + source = DataFrameSource(df) test_algo = TradingAlgorithm( script=algo_text, diff --git a/tests/test_sources.py b/tests/test_sources.py index 3664acdd..2e2c7c23 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -27,6 +27,7 @@ from zipline.sources import (DataFrameSource, RandomWalkSource) from zipline.utils import tradingcalendar as calendar_nyse from zipline.finance.trading import with_environment +from zipline.assets import AssetFinder class TestDataFrameSource(TestCase): @@ -43,7 +44,7 @@ class TestDataFrameSource(TestCase): def test_df_sid_filtering(self): _, df = factory.create_test_df_source() - source = DataFrameSource(df, sids=[0]) + source = DataFrameSource(df) assert 1 not in [event.sid for event in source], \ "DataFrameSource should only stream selected sid 0, not sid 1." 
@@ -63,8 +64,8 @@ class TestDataFrameSource(TestCase): self.assertTrue(isinstance(event['volume'], int)) self.assertTrue(isinstance(event['arbitrary'], float)) - @with_environment() - def test_yahoo_bars_to_panel_source(self, env=None): + def test_yahoo_bars_to_panel_source(self): + finder = AssetFinder() stocks = ['AAPL', 'GE'] start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc) end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) @@ -75,10 +76,15 @@ class TestDataFrameSource(TestCase): check_fields = ['sid', 'open', 'high', 'low', 'close', 'volume', 'price'] - source = DataPanelSource(data) + + copy_panel = data.copy() + copy_panel.items = finder.map_identifier_list_to_sids( + data.items, data.major_axis[0] + ) + source = DataPanelSource(copy_panel) sids = [ asset.sid for asset in - [env.asset_finder.lookup_symbol(symbol, as_of_date=end) + [finder.lookup_symbol(symbol, as_of_date=end) for symbol in stocks] ] stocks_iter = cycle(sids) diff --git a/tests/test_transforms_talib.py b/tests/test_transforms_talib.py index 5b9f50a1..203d0190 100644 --- a/tests/test_transforms_talib.py +++ b/tests/test_transforms_talib.py @@ -98,7 +98,7 @@ class TestTALIB(TestCase): zipline_transforms = [ta.MA(timeperiod=10), ta.MA(timeperiod=25)] talib_fn = talib.abstract.MA - algo = TALIBAlgorithm(talib=zipline_transforms) + algo = TALIBAlgorithm(talib=zipline_transforms, identifiers=[0]) algo.run(self.source) # Test if computed values match those computed by pandas rolling mean. 
sid = 0 diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 375793fc..41a4aa65 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -39,7 +39,6 @@ from zipline.errors import ( UnsupportedCommissionModel, UnsupportedOrderParameters, UnsupportedSlippageModel, - SidNotFound, ) from zipline.finance.trading import TradingEnvironment @@ -461,10 +460,22 @@ class TradingAlgorithm(object): __init__().""", UserWarning) overwrite_sim_params = False elif isinstance(source, pd.DataFrame): - # if DataFrame provided, wrap in DataFrameSource - source = DataFrameSource(source) + # if DataFrame provided, map columns to sids and wrap + # in DataFrameSource + copy_frame = source.copy() + copy_frame.columns = self.asset_finder.map_identifier_list_to_sids( + source.columns, source.index[0] + ) + source = DataFrameSource(copy_frame) + elif isinstance(source, pd.Panel): - source = DataPanelSource(source) + # If Panel provided, map items to sids and wrap + # in DataPanelSource + copy_panel = source.copy() + copy_panel.items = self.asset_finder.map_identifier_list_to_sids( + source.items, source.major_axis[0] + ) + source = DataPanelSource(copy_panel) if isinstance(source, list): self.set_sources(source) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index c14c93f5..51d592cb 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -18,6 +18,7 @@ from itertools import chain from numbers import Integral import numpy as np import operator +import warnings from logbook import Logger import pandas as pd @@ -531,6 +532,30 @@ class AssetFinder(object): self._lookup_generic_scalar(obj, as_of_date, matches, missing) return matches, missing + def map_identifier_list_to_sids(self, index, as_of_date): + """ + This method is for use in sanitizing a user's DataFrame inputs. + Takes the given indices, inserts them into the AssetFinder as + identifiers, rebuilds the Assets, and returns a new object that + contains the sids of the given identifiers.
+ + :param index: The index to be mapped + :return: The new index of sids + """ + # Populate the caches with the given indices + self.consume_identifiers(index) + self.populate_cache() + + # Find all of the newly built assets corresponding to the indices + found, missing = self.lookup_generic(index, as_of_date) + + # Handle missing assets + if len(missing) > 0: + warnings.warn("Missing assets for identifiers: " + missing) + + # Return a list of the sids of the found assets + return [asset.sid for asset in found] + def insert_metadata(self, identifier, **kwargs): """ Inserts the given metadata kwargs to the entry for the given diff --git a/zipline/examples/pairtrade.py b/zipline/examples/pairtrade.py index ca71d8f9..ffb37d53 100755 --- a/zipline/examples/pairtrade.py +++ b/zipline/examples/pairtrade.py @@ -23,7 +23,6 @@ import pytz from zipline.algorithm import TradingAlgorithm from zipline.transforms import batch_transform from zipline.utils.factory import load_from_yahoo -from zipline.sources.data_frame_source import DataFrameSource @batch_transform @@ -121,10 +120,9 @@ if __name__ == '__main__': end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={}, start=start, end=end) - source = DataFrameSource(data) pairtrade = Pairtrade() - results = pairtrade.run(source) + results = pairtrade.run(data) data['spreads'] = np.nan ax1 = plt.subplot(211) diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index d42fdf5a..77241812 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -18,7 +18,6 @@ Tools to generate data sources. 
""" import numpy as np import pandas as pd -import datetime from zipline.gens.utils import hash_args @@ -41,19 +40,13 @@ class DataFrameSource(DataSource): @with_environment() def __init__(self, data, env=None, **kwargs): assert isinstance(data.index, pd.tseries.index.DatetimeIndex) - + assert isinstance(data.columns, pd.Int64Index) + # TODO is ffilling correct/necessary? + # Forward fill prices self.data = data.fillna(method='ffill') # Unpack config dictionary with default values. self.start = kwargs.get('start', self.data.index[0]) self.end = kwargs.get('end', self.data.index[-1]) - - # Remap sids based on the trading environment - env.update_asset_finder( - identifiers=kwargs.get('sids', self.data.columns) - ) - self.data.columns, _ = env.asset_finder.lookup_generic( - self.data.columns, datetime.datetime.now() - ) self.sids = self.data.columns # Hash_value for downstream sorting. @@ -118,20 +111,15 @@ class DataPanelSource(DataSource): @with_environment() def __init__(self, data, env=None, **kwargs): assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex) - + # Only accept integer SIDs as the items of the Panel + assert isinstance(data.items, pd.Int64Index) + # TODO is ffilling correct/necessary? + # forward fill with volumes of 0 self.data = data.fillna(value={'volume': 0}) self.data = self.data.fillna(method='ffill') # Unpack config dictionary with default values. self.start = kwargs.get('start', self.data.major_axis[0]) self.end = kwargs.get('end', self.data.major_axis[-1]) - - # Remap sids based on the trading environment - env.update_asset_finder( - identifiers=kwargs.get('sids', self.data.items) - ) - self.data.items, _ = env.asset_finder.lookup_generic( - self.data.items, datetime.datetime.now() - ) self.sids = self.data.items # Hash_value for downstream sorting.