BUG: Don't crash on dataframes with assets in index.

This commit is contained in:
Scott Sanderson
2016-04-28 12:34:30 -04:00
committed by dmichalowicz
parent 73c3bd6955
commit 85ae664d8c
2 changed files with 92 additions and 30 deletions
+85 -26
View File
@@ -100,7 +100,6 @@ from zipline.utils.events import (
TimeRuleFactory,
)
from zipline.utils.factory import create_simulation_parameters
from zipline.utils.functional import unzip
from zipline.utils.math_utils import (
tolerant_equals,
round_if_near_integer
@@ -586,7 +585,13 @@ class TradingAlgorithm(object):
env=self.trading_environment
)
copy_panel = data.copy()
copy_panel = data.rename(
# These were the old names for the close/open columns. We
# need to make a copy anyway, so swap these for backwards
# compat while we're here.
minor_axis={'close_price': 'close', 'open_price': 'open'},
copy=True,
)
copy_panel.items = self._write_and_map_id_index_to_sids(
copy_panel.items, copy_panel.major_axis[0],
)
@@ -632,26 +637,44 @@ class TradingAlgorithm(object):
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
# Build new Assets for identifiers that can't be resolved as
# sids/Assets
identifiers_to_build = set()
next_sid = max(self.asset_finder.sids or (0,)) + 1
def is_unknown(asset_or_sid):
sid = op.index(asset_or_sid)
return self.asset_finder.retrieve_asset(
sid=sid,
default_none=True
) is None
new_assets = set()
new_sids = set()
new_symbols = set()
for identifier in identifiers:
asset = None
if isinstance(identifier, Asset):
asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
default_none=True)
elif isinstance(identifier, Integral):
asset = self.asset_finder.retrieve_asset(sid=identifier,
default_none=True)
if asset is None:
if isinstance(identifier, Asset) and is_unknown(identifier):
new_assets.add(identifier)
elif isinstance(identifier, Integral) and is_unknown(identifier):
new_sids.add(identifier)
elif isinstance(identifier, (bytes, unicode)):
new_symbols.add(identifier)
else:
try:
sid = op.index(identifier)
new_sids.add(op.index(identifier))
except TypeError:
sid = next_sid
next_sid += 1
identifiers_to_build.add((identifier, sid))
raise TypeError(
"Can't convert %s to an asset." % identifier
)
if identifiers_to_build:
new_assets = tuple(new_assets)
new_sids = tuple(new_sids)
new_symbols = tuple(new_symbols)
number_of_kinds_of_new_things = (
sum((bool(new_assets), bool(new_sids), bool(new_symbols)))
)
# Nothing to insert, bail early.
if not number_of_kinds_of_new_things:
return self.asset_finder.map_identifier_index_to_sids(
identifiers, as_of_date,
)
elif number_of_kinds_of_new_things == 1:
warnings.warn(
'writing unknown identifiers into the assets db of the trading'
' environment is deprecated; please write this information'
@@ -659,16 +682,52 @@ class TradingAlgorithm(object):
DeprecationWarning,
stacklevel=2,
)
symbols, sids = unzip(identifiers_to_build, 2)
self.trading_environment.write_data(
equities=make_simple_equity_info(
sids,
start_date=self.sim_params.period_start,
end_date=self.sim_params.period_end,
symbols=symbols,
),
else:
raise ValueError(
"Mixed types in DataFrame or Panel index.\n"
"Asset Count: %d, Sid Count: %d, Symbol Count: %d.\n"
"Choose one type and stick with it." % (
len(new_assets),
len(new_sids),
len(new_symbols),
)
)
def map_getattr(iterable, attr):
return [getattr(i, attr) for i in iterable]
if new_assets:
frame_to_write = pd.DataFrame(
data=dict(
symbol=map_getattr(new_assets, 'symbol'),
start_date=map_getattr(new_assets, 'start_date'),
end_date=map_getattr(new_assets, 'end_date'),
exchange=map_getattr(new_assets, 'exchange'),
),
index=map_getattr(new_assets, 'sid'),
)
elif new_sids:
frame_to_write = make_simple_equity_info(
new_sids,
start_date=self.sim_params.period_start,
end_date=self.sim_params.period_end,
symbols=map(str, new_sids),
)
elif new_symbols:
existing_sids = self.asset_finder.sids
first_sid = max(existing_sids) + 1 if existing_sids else 0
fake_sids = range(first_sid, first_sid + len(new_symbols))
frame_to_write = make_simple_equity_info(
sids=fake_sids,
start_date=self.sim_params.period_start,
end_date=self.sim_params.period_end,
symbols=new_symbols,
)
else:
raise AssertionError("This should never happen.")
self.trading_environment.write_data(equities=frame_to_write)
# We need to clear out any cache misses that were stored while trying
# to do lookups. The real fix for this problem is to not construct an
# AssetFinder until we `run()` when we actually have all the data we
+7 -4
View File
@@ -27,7 +27,6 @@ from zipline.api import order_target, record, symbol
def initialize(context):
context.sym = symbol('AAPL')
context.i = 0
@@ -104,15 +103,19 @@ if __name__ == '__main__':
from datetime import datetime
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.utils.factory import load_from_yahoo
from zipline.utils.factory import load_bars_from_yahoo
# Set the simulation start and end dates.
start = datetime(2011, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2013, 1, 1, 0, 0, 0, 0, pytz.utc)
# Load price data from yahoo.
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
data = load_bars_from_yahoo(
stocks=['AAPL'],
indexes={},
start=start,
end=end,
)
# Create and run the algorithm.
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data)