mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-06 05:14:38 +08:00
BUG: Don't crash on dataframes with assets in index.
This commit is contained in:
committed by
dmichalowicz
parent
73c3bd6955
commit
85ae664d8c
+85
-26
@@ -100,7 +100,6 @@ from zipline.utils.events import (
|
||||
TimeRuleFactory,
|
||||
)
|
||||
from zipline.utils.factory import create_simulation_parameters
|
||||
from zipline.utils.functional import unzip
|
||||
from zipline.utils.math_utils import (
|
||||
tolerant_equals,
|
||||
round_if_near_integer
|
||||
@@ -586,7 +585,13 @@ class TradingAlgorithm(object):
|
||||
env=self.trading_environment
|
||||
)
|
||||
|
||||
copy_panel = data.copy()
|
||||
copy_panel = data.rename(
|
||||
# These were the old names for the close/open columns. We
|
||||
# need to make a copy anyway, so swap these for backwards
|
||||
# compat while we're here.
|
||||
minor_axis={'close_price': 'close', 'open_price': 'open'},
|
||||
copy=True,
|
||||
)
|
||||
copy_panel.items = self._write_and_map_id_index_to_sids(
|
||||
copy_panel.items, copy_panel.major_axis[0],
|
||||
)
|
||||
@@ -632,26 +637,44 @@ class TradingAlgorithm(object):
|
||||
def _write_and_map_id_index_to_sids(self, identifiers, as_of_date):
|
||||
# Build new Assets for identifiers that can't be resolved as
|
||||
# sids/Assets
|
||||
identifiers_to_build = set()
|
||||
next_sid = max(self.asset_finder.sids or (0,)) + 1
|
||||
def is_unknown(asset_or_sid):
|
||||
sid = op.index(asset_or_sid)
|
||||
return self.asset_finder.retrieve_asset(
|
||||
sid=sid,
|
||||
default_none=True
|
||||
) is None
|
||||
|
||||
new_assets = set()
|
||||
new_sids = set()
|
||||
new_symbols = set()
|
||||
for identifier in identifiers:
|
||||
asset = None
|
||||
|
||||
if isinstance(identifier, Asset):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier.sid,
|
||||
default_none=True)
|
||||
elif isinstance(identifier, Integral):
|
||||
asset = self.asset_finder.retrieve_asset(sid=identifier,
|
||||
default_none=True)
|
||||
if asset is None:
|
||||
if isinstance(identifier, Asset) and is_unknown(identifier):
|
||||
new_assets.add(identifier)
|
||||
elif isinstance(identifier, Integral) and is_unknown(identifier):
|
||||
new_sids.add(identifier)
|
||||
elif isinstance(identifier, (bytes, unicode)):
|
||||
new_symbols.add(identifier)
|
||||
else:
|
||||
try:
|
||||
sid = op.index(identifier)
|
||||
new_sids.add(op.index(identifier))
|
||||
except TypeError:
|
||||
sid = next_sid
|
||||
next_sid += 1
|
||||
identifiers_to_build.add((identifier, sid))
|
||||
raise TypeError(
|
||||
"Can't convert %s to an asset." % identifier
|
||||
)
|
||||
|
||||
if identifiers_to_build:
|
||||
new_assets = tuple(new_assets)
|
||||
new_sids = tuple(new_sids)
|
||||
new_symbols = tuple(new_symbols)
|
||||
number_of_kinds_of_new_things = (
|
||||
sum((bool(new_assets), bool(new_sids), bool(new_symbols)))
|
||||
)
|
||||
|
||||
# Nothing to insert, bail early.
|
||||
if not number_of_kinds_of_new_things:
|
||||
return self.asset_finder.map_identifier_index_to_sids(
|
||||
identifiers, as_of_date,
|
||||
)
|
||||
elif number_of_kinds_of_new_things == 1:
|
||||
warnings.warn(
|
||||
'writing unknown identifiers into the assets db of the trading'
|
||||
' environment is deprecated; please write this information'
|
||||
@@ -659,16 +682,52 @@ class TradingAlgorithm(object):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
symbols, sids = unzip(identifiers_to_build, 2)
|
||||
self.trading_environment.write_data(
|
||||
equities=make_simple_equity_info(
|
||||
sids,
|
||||
start_date=self.sim_params.period_start,
|
||||
end_date=self.sim_params.period_end,
|
||||
symbols=symbols,
|
||||
),
|
||||
else:
|
||||
raise ValueError(
|
||||
"Mixed types in DataFrame or Panel index.\n"
|
||||
"Asset Count: %d, Sid Count: %d, Symbol Count: %d.\n"
|
||||
"Choose one type and stick with it." % (
|
||||
len(new_assets),
|
||||
len(new_sids),
|
||||
len(new_symbols),
|
||||
)
|
||||
)
|
||||
|
||||
def map_getattr(iterable, attr):
|
||||
return [getattr(i, attr) for i in iterable]
|
||||
|
||||
if new_assets:
|
||||
frame_to_write = pd.DataFrame(
|
||||
data=dict(
|
||||
symbol=map_getattr(new_assets, 'symbol'),
|
||||
start_date=map_getattr(new_assets, 'start_date'),
|
||||
end_date=map_getattr(new_assets, 'end_date'),
|
||||
exchange=map_getattr(new_assets, 'exchange'),
|
||||
),
|
||||
index=map_getattr(new_assets, 'sid'),
|
||||
)
|
||||
elif new_sids:
|
||||
frame_to_write = make_simple_equity_info(
|
||||
new_sids,
|
||||
start_date=self.sim_params.period_start,
|
||||
end_date=self.sim_params.period_end,
|
||||
symbols=map(str, new_sids),
|
||||
)
|
||||
elif new_symbols:
|
||||
existing_sids = self.asset_finder.sids
|
||||
first_sid = max(existing_sids) + 1 if existing_sids else 0
|
||||
fake_sids = range(first_sid, first_sid + len(new_symbols))
|
||||
frame_to_write = make_simple_equity_info(
|
||||
sids=fake_sids,
|
||||
start_date=self.sim_params.period_start,
|
||||
end_date=self.sim_params.period_end,
|
||||
symbols=new_symbols,
|
||||
)
|
||||
else:
|
||||
raise AssertionError("This should never happen.")
|
||||
|
||||
self.trading_environment.write_data(equities=frame_to_write)
|
||||
|
||||
# We need to clear out any cache misses that were stored while trying
|
||||
# to do lookups. The real fix for this problem is to not construct an
|
||||
# AssetFinder until we `run()` when we actually have all the data we
|
||||
|
||||
@@ -27,7 +27,6 @@ from zipline.api import order_target, record, symbol
|
||||
|
||||
def initialize(context):
|
||||
context.sym = symbol('AAPL')
|
||||
|
||||
context.i = 0
|
||||
|
||||
|
||||
@@ -104,15 +103,19 @@ if __name__ == '__main__':
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
from zipline.utils.factory import load_bars_from_yahoo
|
||||
|
||||
# Set the simulation start and end dates.
|
||||
start = datetime(2011, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(2013, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
|
||||
# Load price data from yahoo.
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
|
||||
end=end)
|
||||
data = load_bars_from_yahoo(
|
||||
stocks=['AAPL'],
|
||||
indexes={},
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
# Create and run the algorithm.
|
||||
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data)
|
||||
|
||||
Reference in New Issue
Block a user