mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 21:24:16 +08:00
API: DataFrame/Panel sources expect integer sids, not identifiers
This commit modifies the DataFrameSource and DataPanelSource to accept only Int64Indexes on the incoming data and moves the burden of mapping user identifiers to TradingAlgorithm.run().
This commit is contained in:
@@ -718,3 +718,25 @@ class AssetFinderTestCase(TestCase):
|
||||
|
||||
# No contracts exist after 12/14/2015, so we should get none
|
||||
self.assertIsNone(finder.lookup_future_by_expiration('AD', dt, jan_16))
|
||||
|
||||
def test_map_identifier_list_to_sids(self):
    """
    Check that map_identifier_list_to_sids turns a list of Asset objects
    into a list of their integer sids, in the same order as the input.
    """
    # An empty finder plus two equities and two futures to feed through it
    as_of = pd.Timestamp('2014-01-01', tz='UTC')
    finder = AssetFinder()
    aapl = Equity(1, symbol="AAPL")
    goog = Equity(2, symbol="GOOG")
    clk15 = Future(200, symbol="CLK15")
    clm15 = Future(201, symbol="CLM15")

    # Ascending order: the sids must come back as plain ints
    ordered = [aapl, goog, clk15, clm15]
    ordered_sids = finder.map_identifier_list_to_sids(ordered, as_of)
    self.assertListEqual([1, 2, 200, 201], ordered_sids)
    for mapped in ordered_sids:
        self.assertIsInstance(mapped, int)

    # Shuffled order: the output must follow the input ordering
    shuffled = [clm15, goog, clk15, aapl]
    shuffled_sids = finder.map_identifier_list_to_sids(shuffled, as_of)
    self.assertListEqual([201, 2, 200, 1], shuffled_sids)
|
||||
|
||||
@@ -437,7 +437,7 @@ def handle_data(context, data):
|
||||
|
||||
_, df = factory.create_test_df_source(sim_params)
|
||||
df = df.astype(np.float64)
|
||||
source = DataFrameSource(df, sids=[0])
|
||||
source = DataFrameSource(df)
|
||||
|
||||
test_algo = TradingAlgorithm(
|
||||
script=algo_text,
|
||||
|
||||
+11
-5
@@ -27,6 +27,7 @@ from zipline.sources import (DataFrameSource,
|
||||
RandomWalkSource)
|
||||
from zipline.utils import tradingcalendar as calendar_nyse
|
||||
from zipline.finance.trading import with_environment
|
||||
from zipline.assets import AssetFinder
|
||||
|
||||
|
||||
class TestDataFrameSource(TestCase):
|
||||
@@ -43,7 +44,7 @@ class TestDataFrameSource(TestCase):
|
||||
|
||||
def test_df_sid_filtering(self):
|
||||
_, df = factory.create_test_df_source()
|
||||
source = DataFrameSource(df, sids=[0])
|
||||
source = DataFrameSource(df)
|
||||
assert 1 not in [event.sid for event in source], \
|
||||
"DataFrameSource should only stream selected sid 0, not sid 1."
|
||||
|
||||
@@ -63,8 +64,8 @@ class TestDataFrameSource(TestCase):
|
||||
self.assertTrue(isinstance(event['volume'], int))
|
||||
self.assertTrue(isinstance(event['arbitrary'], float))
|
||||
|
||||
@with_environment()
|
||||
def test_yahoo_bars_to_panel_source(self, env=None):
|
||||
def test_yahoo_bars_to_panel_source(self):
|
||||
finder = AssetFinder()
|
||||
stocks = ['AAPL', 'GE']
|
||||
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
@@ -75,10 +76,15 @@ class TestDataFrameSource(TestCase):
|
||||
|
||||
check_fields = ['sid', 'open', 'high', 'low', 'close',
|
||||
'volume', 'price']
|
||||
source = DataPanelSource(data)
|
||||
|
||||
copy_panel = data.copy()
|
||||
copy_panel.items = finder.map_identifier_list_to_sids(
|
||||
data.items, data.major_axis[0]
|
||||
)
|
||||
source = DataPanelSource(copy_panel)
|
||||
sids = [
|
||||
asset.sid for asset in
|
||||
[env.asset_finder.lookup_symbol(symbol, as_of_date=end)
|
||||
[finder.lookup_symbol(symbol, as_of_date=end)
|
||||
for symbol in stocks]
|
||||
]
|
||||
stocks_iter = cycle(sids)
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestTALIB(TestCase):
|
||||
zipline_transforms = [ta.MA(timeperiod=10),
|
||||
ta.MA(timeperiod=25)]
|
||||
talib_fn = talib.abstract.MA
|
||||
algo = TALIBAlgorithm(talib=zipline_transforms)
|
||||
algo = TALIBAlgorithm(talib=zipline_transforms, identifiers=[0])
|
||||
algo.run(self.source)
|
||||
# Test if computed values match those computed by pandas rolling mean.
|
||||
sid = 0
|
||||
|
||||
+15
-4
@@ -39,7 +39,6 @@ from zipline.errors import (
|
||||
UnsupportedCommissionModel,
|
||||
UnsupportedOrderParameters,
|
||||
UnsupportedSlippageModel,
|
||||
SidNotFound,
|
||||
)
|
||||
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
@@ -461,10 +460,22 @@ class TradingAlgorithm(object):
|
||||
__init__().""", UserWarning)
|
||||
overwrite_sim_params = False
|
||||
elif isinstance(source, pd.DataFrame):
|
||||
# if DataFrame provided, wrap in DataFrameSource
|
||||
source = DataFrameSource(source)
|
||||
# if DataFrame provided, map columns to sids and wrap
|
||||
# in DataFrameSource
|
||||
copy_frame = source.copy()
|
||||
copy_frame.columns = self.asset_finder.map_identifier_list_to_sids(
|
||||
source.columns, source.index[0]
|
||||
)
|
||||
source = DataFrameSource(copy_frame)
|
||||
|
||||
elif isinstance(source, pd.Panel):
|
||||
source = DataPanelSource(source)
|
||||
# If Panel provided, map items to sids and wrap
|
||||
# in DataPanelSource
|
||||
copy_panel = source.copy()
|
||||
copy_panel.items = self.asset_finder.map_identifier_list_to_sids(
|
||||
source.items, source.major_axis[0]
|
||||
)
|
||||
source = DataPanelSource(copy_panel)
|
||||
|
||||
if isinstance(source, list):
|
||||
self.set_sources(source)
|
||||
|
||||
@@ -18,6 +18,7 @@ from itertools import chain
|
||||
from numbers import Integral
|
||||
import numpy as np
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
from logbook import Logger
|
||||
import pandas as pd
|
||||
@@ -531,6 +532,30 @@ class AssetFinder(object):
|
||||
self._lookup_generic_scalar(obj, as_of_date, matches, missing)
|
||||
return matches, missing
|
||||
|
||||
def map_identifier_list_to_sids(self, index, as_of_date):
    """
    This method is for use in sanitizing a user's DataFrame inputs.
    Takes the given indices, inserts them in to the AssetFinder as
    identifiers, rebuilds the Assets, and returns a new object that
    contains the sids of the given identifiers.

    Identifiers reported as missing by lookup_generic trigger a warning
    and contribute nothing to the result, so the returned list can be
    shorter than ``index`` in that case.

    :param index: The index to be mapped
    :param as_of_date: The date passed through to lookup_generic when
        resolving each identifier
    :return: The new index of sids
    """
    # Populate the caches with the given indices
    self.consume_identifiers(index)
    self.populate_cache()

    # Find all of the newly built assets corresponding to the indices
    found, missing = self.lookup_generic(index, as_of_date)

    # Handle missing assets. The identifiers are interpolated rather than
    # concatenated: `missing` is a collection, and "str + collection"
    # raises TypeError instead of emitting the intended warning.
    if len(missing) > 0:
        warnings.warn(
            "Missing assets for identifiers: %s" % (list(missing),)
        )

    # Return a list of the sids of the found assets
    return [asset.sid for asset in found]
|
||||
|
||||
def insert_metadata(self, identifier, **kwargs):
|
||||
"""
|
||||
Inserts the given metadata kwargs to the entry for the given
|
||||
|
||||
@@ -23,7 +23,6 @@ import pytz
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import batch_transform
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
from zipline.sources.data_frame_source import DataFrameSource
|
||||
|
||||
|
||||
@batch_transform
|
||||
@@ -121,10 +120,9 @@ if __name__ == '__main__':
|
||||
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
|
||||
start=start, end=end)
|
||||
source = DataFrameSource(data)
|
||||
|
||||
pairtrade = Pairtrade()
|
||||
results = pairtrade.run(source)
|
||||
results = pairtrade.run(data)
|
||||
data['spreads'] = np.nan
|
||||
|
||||
ax1 = plt.subplot(211)
|
||||
|
||||
@@ -18,7 +18,6 @@ Tools to generate data sources.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import datetime
|
||||
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
@@ -41,19 +40,13 @@ class DataFrameSource(DataSource):
|
||||
@with_environment()
|
||||
def __init__(self, data, env=None, **kwargs):
|
||||
assert isinstance(data.index, pd.tseries.index.DatetimeIndex)
|
||||
|
||||
assert isinstance(data.columns, pd.Int64Index)
|
||||
# TODO is ffilling correct/necessary?
|
||||
# Forward fill prices
|
||||
self.data = data.fillna(method='ffill')
|
||||
# Unpack config dictionary with default values.
|
||||
self.start = kwargs.get('start', self.data.index[0])
|
||||
self.end = kwargs.get('end', self.data.index[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
env.update_asset_finder(
|
||||
identifiers=kwargs.get('sids', self.data.columns)
|
||||
)
|
||||
self.data.columns, _ = env.asset_finder.lookup_generic(
|
||||
self.data.columns, datetime.datetime.now()
|
||||
)
|
||||
self.sids = self.data.columns
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
@@ -118,20 +111,15 @@ class DataPanelSource(DataSource):
|
||||
@with_environment()
|
||||
def __init__(self, data, env=None, **kwargs):
|
||||
assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex)
|
||||
|
||||
# Only accept integer SIDs as the items of the Panel
|
||||
assert isinstance(data.items, pd.Int64Index)
|
||||
# TODO is ffilling correct/necessary?
|
||||
# forward fill with volumes of 0
|
||||
self.data = data.fillna(value={'volume': 0})
|
||||
self.data = self.data.fillna(method='ffill')
|
||||
# Unpack config dictionary with default values.
|
||||
self.start = kwargs.get('start', self.data.major_axis[0])
|
||||
self.end = kwargs.get('end', self.data.major_axis[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
env.update_asset_finder(
|
||||
identifiers=kwargs.get('sids', self.data.items)
|
||||
)
|
||||
self.data.items, _ = env.asset_finder.lookup_generic(
|
||||
self.data.items, datetime.datetime.now()
|
||||
)
|
||||
self.sids = self.data.items
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
|
||||
Reference in New Issue
Block a user