API: DataFrame/Panel sources expect integer sids, not identifiers

This commit modifies DataFrameSource and DataPanelSource to accept only integer sids on the incoming data (an Int64Index as the DataFrame's columns or the Panel's items), and moves the burden of mapping user identifiers to sids into TradingAlgorithm.run().
This commit is contained in:
jfkirk
2015-06-26 16:41:38 -04:00
parent a4ce9712b8
commit 258b5ea2ca
8 changed files with 83 additions and 33 deletions
+22
View File
@@ -718,3 +718,25 @@ class AssetFinderTestCase(TestCase):
# No contracts exist after 12/14/2015, so we should get none
self.assertIsNone(finder.lookup_future_by_expiration('AD', dt, jan_16))
def test_map_identifier_list_to_sids(self):
    """Assets handed to the finder come back as their integer sids,
    in the same order in which they were supplied."""
    # Start from an empty finder and a handful of hand-built Assets
    as_of = pd.Timestamp('2014-01-01', tz='UTC')
    asset_finder = AssetFinder()
    aapl = Equity(1, symbol="AAPL")
    goog = Equity(2, symbol="GOOG")
    clk15 = Future(200, symbol="CLK15")
    clm15 = Future(201, symbol="CLM15")

    # First pass: verify both the mapped values and their types
    mapped = asset_finder.map_identifier_list_to_sids(
        [aapl, goog, clk15, clm15], as_of
    )
    self.assertListEqual([1, 2, 200, 201], mapped)
    for mapped_sid in mapped:
        self.assertIsInstance(mapped_sid, int)

    # Second pass: a different input order must be reflected in the output
    remapped = asset_finder.map_identifier_list_to_sids(
        [clm15, goog, clk15, aapl], as_of
    )
    self.assertListEqual([201, 2, 200, 1], remapped)
+1 -1
View File
@@ -437,7 +437,7 @@ def handle_data(context, data):
_, df = factory.create_test_df_source(sim_params)
df = df.astype(np.float64)
source = DataFrameSource(df, sids=[0])
source = DataFrameSource(df)
test_algo = TradingAlgorithm(
script=algo_text,
+11 -5
View File
@@ -27,6 +27,7 @@ from zipline.sources import (DataFrameSource,
RandomWalkSource)
from zipline.utils import tradingcalendar as calendar_nyse
from zipline.finance.trading import with_environment
from zipline.assets import AssetFinder
class TestDataFrameSource(TestCase):
@@ -43,7 +44,7 @@ class TestDataFrameSource(TestCase):
def test_df_sid_filtering(self):
_, df = factory.create_test_df_source()
source = DataFrameSource(df, sids=[0])
source = DataFrameSource(df)
assert 1 not in [event.sid for event in source], \
"DataFrameSource should only stream selected sid 0, not sid 1."
@@ -63,8 +64,8 @@ class TestDataFrameSource(TestCase):
self.assertTrue(isinstance(event['volume'], int))
self.assertTrue(isinstance(event['arbitrary'], float))
@with_environment()
def test_yahoo_bars_to_panel_source(self, env=None):
def test_yahoo_bars_to_panel_source(self):
finder = AssetFinder()
stocks = ['AAPL', 'GE']
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
@@ -75,10 +76,15 @@ class TestDataFrameSource(TestCase):
check_fields = ['sid', 'open', 'high', 'low', 'close',
'volume', 'price']
source = DataPanelSource(data)
copy_panel = data.copy()
copy_panel.items = finder.map_identifier_list_to_sids(
data.items, data.major_axis[0]
)
source = DataPanelSource(copy_panel)
sids = [
asset.sid for asset in
[env.asset_finder.lookup_symbol(symbol, as_of_date=end)
[finder.lookup_symbol(symbol, as_of_date=end)
for symbol in stocks]
]
stocks_iter = cycle(sids)
+1 -1
View File
@@ -98,7 +98,7 @@ class TestTALIB(TestCase):
zipline_transforms = [ta.MA(timeperiod=10),
ta.MA(timeperiod=25)]
talib_fn = talib.abstract.MA
algo = TALIBAlgorithm(talib=zipline_transforms)
algo = TALIBAlgorithm(talib=zipline_transforms, identifiers=[0])
algo.run(self.source)
# Test if computed values match those computed by pandas rolling mean.
sid = 0
+15 -4
View File
@@ -39,7 +39,6 @@ from zipline.errors import (
UnsupportedCommissionModel,
UnsupportedOrderParameters,
UnsupportedSlippageModel,
SidNotFound,
)
from zipline.finance.trading import TradingEnvironment
@@ -461,10 +460,22 @@ class TradingAlgorithm(object):
__init__().""", UserWarning)
overwrite_sim_params = False
elif isinstance(source, pd.DataFrame):
# if DataFrame provided, wrap in DataFrameSource
source = DataFrameSource(source)
# if DataFrame provided, map columns to sids and wrap
# in DataFrameSource
copy_frame = source.copy()
copy_frame.columns = self.asset_finder.map_identifier_list_to_sids(
source.columns, source.index[0]
)
source = DataFrameSource(copy_frame)
elif isinstance(source, pd.Panel):
source = DataPanelSource(source)
# If Panel provided, map items to sids and wrap
# in DataPanelSource
copy_panel = source.copy()
copy_panel.items = self.asset_finder.map_identifier_list_to_sids(
source.items, source.major_axis[0]
)
source = DataPanelSource(copy_panel)
if isinstance(source, list):
self.set_sources(source)
+25
View File
@@ -18,6 +18,7 @@ from itertools import chain
from numbers import Integral
import numpy as np
import operator
import warnings
from logbook import Logger
import pandas as pd
@@ -531,6 +532,30 @@ class AssetFinder(object):
self._lookup_generic_scalar(obj, as_of_date, matches, missing)
return matches, missing
def map_identifier_list_to_sids(self, index, as_of_date):
    """
    This method is for use in sanitizing a user's DataFrame inputs.
    Takes the given indices, inserts them into the AssetFinder as
    identifiers, rebuilds the Assets, and returns a new list that
    contains the sids of the given identifiers.

    Identifiers that could not be resolved are reported through a
    warning; only the sids of the assets that were found are returned,
    so the result can be shorter than ``index``.

    :param index: The index to be mapped
    :param as_of_date: The date forwarded to lookup_generic when
        resolving the identifiers
    :return: A list of the sids of the found assets
    """
    # Populate the caches with the given indices
    self.consume_identifiers(index)
    self.populate_cache()

    # Find all of the newly built assets corresponding to the indices
    found, missing = self.lookup_generic(index, as_of_date)

    # Handle missing assets. `missing` is a collection, not a string, so
    # it has to be converted explicitly before being appended to the
    # message; concatenating it directly raises TypeError.
    if len(missing) > 0:
        warnings.warn("Missing assets for identifiers: " + str(missing))

    # Return a list of the sids of the found assets
    return [asset.sid for asset in found]
def insert_metadata(self, identifier, **kwargs):
"""
Inserts the given metadata kwargs to the entry for the given
+1 -3
View File
@@ -23,7 +23,6 @@ import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import batch_transform
from zipline.utils.factory import load_from_yahoo
from zipline.sources.data_frame_source import DataFrameSource
@batch_transform
@@ -121,10 +120,9 @@ if __name__ == '__main__':
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
start=start, end=end)
source = DataFrameSource(data)
pairtrade = Pairtrade()
results = pairtrade.run(source)
results = pairtrade.run(data)
data['spreads'] = np.nan
ax1 = plt.subplot(211)
+7 -19
View File
@@ -18,7 +18,6 @@ Tools to generate data sources.
"""
import numpy as np
import pandas as pd
import datetime
from zipline.gens.utils import hash_args
@@ -41,19 +40,13 @@ class DataFrameSource(DataSource):
@with_environment()
def __init__(self, data, env=None, **kwargs):
assert isinstance(data.index, pd.tseries.index.DatetimeIndex)
assert isinstance(data.columns, pd.Int64Index)
# TODO is ffilling correct/necessary?
# Forward fill prices
self.data = data.fillna(method='ffill')
# Unpack config dictionary with default values.
self.start = kwargs.get('start', self.data.index[0])
self.end = kwargs.get('end', self.data.index[-1])
# Remap sids based on the trading environment
env.update_asset_finder(
identifiers=kwargs.get('sids', self.data.columns)
)
self.data.columns, _ = env.asset_finder.lookup_generic(
self.data.columns, datetime.datetime.now()
)
self.sids = self.data.columns
# Hash_value for downstream sorting.
@@ -118,20 +111,15 @@ class DataPanelSource(DataSource):
@with_environment()
def __init__(self, data, env=None, **kwargs):
assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex)
# Only accept integer SIDs as the items of the Panel
assert isinstance(data.items, pd.Int64Index)
# TODO is ffilling correct/necessary?
# forward fill with volumes of 0
self.data = data.fillna(value={'volume': 0})
self.data = self.data.fillna(method='ffill')
# Unpack config dictionary with default values.
self.start = kwargs.get('start', self.data.major_axis[0])
self.end = kwargs.get('end', self.data.major_axis[-1])
# Remap sids based on the trading environment
env.update_asset_finder(
identifiers=kwargs.get('sids', self.data.items)
)
self.data.items, _ = env.asset_finder.lookup_generic(
self.data.items, datetime.datetime.now()
)
self.sids = self.data.items
# Hash_value for downstream sorting.