PERF: Cache the result of failed lookups.

Otherwise we'll keep trying to look them up.
This commit is contained in:
Scott Sanderson
2015-10-29 12:06:45 -04:00
parent c818463051
commit ff51510d88
4 changed files with 225 additions and 103 deletions
+1 -1
View File
@@ -3,7 +3,7 @@ verbosity=2
detailed-errors=1
with-ignore-docstrings=1
with-timer=1
timer-top-n=15
timer-filter=warning
[metadata]
description-file = README.rst
+195 -95
View File
@@ -20,15 +20,16 @@ import warnings
from logbook import Logger
import numpy as np
import pandas as pd
from pandas import isnull
from pandas.tseries.tools import normalize_date
from six import with_metaclass, string_types, viewkeys
from six.moves import map as imap
import sqlalchemy as sa
from toolz import compose
from zipline.errors import (
MultipleSymbolsFound,
RootSymbolNotFound,
SidNotFound,
SidsNotFound,
SymbolNotFound,
MapAssetIdentifierIndexError,
)
@@ -41,6 +42,7 @@ from zipline.assets.asset_writer import (
ASSET_DB_VERSION,
asset_db_table_names,
)
from zipline.utils.control_flow import invert
log = Logger('assets.py')
@@ -63,13 +65,14 @@ _asset_timestamp_fields = frozenset({
})
def _convert_asset_timestamp_fields(dict_):
    """
    Takes in a dict of Asset init args and converts dates to pd.Timestamps.

    Only keys that appear in ``_asset_timestamp_fields`` are touched. Each
    such value is passed through ``pd.Timestamp(..., tz='UTC')``; values that
    come out null are stored as None instead.

    Parameters
    ----------
    dict_ : dict
        Asset init args. Modified in place.

    Returns
    -------
    dict_ : dict
        The same dict that was passed in, so the call can be used inline
        (e.g. ``asset_type(**_convert_asset_timestamp_fields(row))``).
    """
    for key in (_asset_timestamp_fields & viewkeys(dict_)):
        value = pd.Timestamp(dict_[key], tz='UTC')
        dict_[key] = None if isnull(value) else value
    return dict_
class AssetFinder(object):
@@ -96,105 +99,218 @@ class AssetFinder(object):
# routing.
#
# The caches are read through, i.e. accessing an asset through
# retrieve_asset will populate the cache on first retrieval.
self._asset_cache = {}
self._equity_cache = {}
self._future_cache = {}
self._asset_type_cache = {}
# Populated on first call to `lifetimes`.
self._asset_lifetimes = None
def lookup_asset_types(self, sids):
    """
    Retrieve asset types for a list of sids.

    Results are served from ``self._asset_type_cache`` where possible. Sids
    that are not cached are looked up in the asset router table with a
    single query. Sids that the router does not return are cached with a
    type of None, so repeated requests for them do not query again.

    Parameters
    ----------
    sids : list[int]

    Returns
    -------
    types : dict[sid -> str or None]
        Asset types for the provided sids.
    """
    type_cache = self._asset_type_cache

    found, missing = {}, set()
    for sid in sids:
        try:
            found[sid] = type_cache[sid]
        except KeyError:
            missing.add(sid)

    if not missing:
        return found

    router_cols = self.asset_router.c

    # Build a concrete list for the IN clause: on Python 3 ``map`` returns a
    # one-shot iterator rather than a sequence.
    query = sa.select((router_cols.sid, router_cols.asset_type)).where(
        router_cols.sid.in_([int(sid) for sid in missing])
    )
    for sid, type_ in query.execute().fetchall():
        missing.remove(sid)
        found[sid] = type_cache[sid] = type_

    # Whatever is still missing was not returned by the router query.
    # Cache the failure so that we don't look it up again.
    for sid in missing:
        found[sid] = type_cache[sid] = None

    return found
def lookup_single_asset_type(self, sid):
    """
    Retrieve the asset type for a single asset.

    Delegates to ``lookup_asset_types`` with a one-element list and returns
    the entry for ``sid``.
    """
    types_by_sid = self.lookup_asset_types([sid])
    return types_by_sid[sid]
def group_by_type(self, sids):
    """
    Group a list of sids by asset type.

    Parameters
    ----------
    sids : list[int]

    Returns
    -------
    types : dict[str or None -> list[int]]
        A dict mapping unique asset types to lists of sids drawn from sids.
        If we fail to look up an asset, we assign it a key of None.

        This is a plain dict, not a defaultdict: an asset type that none of
        the requested sids has is simply absent, so callers should use
        ``.get``/``.pop`` with a default when reading a specific type.
    """
    return invert(self.lookup_asset_types(sids))
def retrieve_asset(self, sid, default_none=False):
    """
    Retrieve the Asset for a given sid.

    Parameters
    ----------
    sid : int or Asset
        Sid to look up. An Asset instance is returned unchanged.
    default_none : bool
        Forwarded to ``retrieve_all``: if True, a failed lookup yields None
        instead of raising.

    Returns
    -------
    asset : Asset or None
        The first (and only) element of ``retrieve_all((sid,), ...)``.

    Raises
    ------
    SidsNotFound
        Propagated from ``retrieve_all`` when the sid is not found and
        default_none=False.
    """
    if isinstance(sid, Asset):
        return sid

    return self.retrieve_all((sid,), default_none=default_none)[0]
def retrieve_all(self, sids, default_none=False):
    """
    Retrieve all assets in `sids`.

    Parameters
    ----------
    sids : iterable of int
        Assets to retrieve.
    default_none : bool
        If True, return None for failed lookups.
        If False, raise `SidsNotFound`.

    Returns
    -------
    assets : list[Asset or None]
        A list of the same length as `sids` containing Assets (or Nones)
        corresponding to the requested sids.

    Raises
    ------
    SidsNotFound
        When a requested sid is not found and default_none=False.
    """
    # `sids` is walked more than once below (cache scan, then building the
    # ordered result), so materialize it first. Otherwise a one-shot
    # iterator would be exhausted by the scan and the result would be empty.
    sids = list(sids)

    hits, missing = {}, set()
    for sid in sids:
        try:
            asset = self._asset_cache[sid]
        except KeyError:
            missing.add(sid)
            continue

        if asset is None and not default_none:
            # Bail early if we've already cached that we don't know
            # about an asset.
            raise SidsNotFound(sids=[sid])
        hits[sid] = asset

    # All requests were cache hits. Return requested sids in order.
    if not missing:
        return [hits[sid] for sid in sids]

    # Look up cache misses by type.
    type_to_assets = self.group_by_type(missing)

    # Handle failures. They are cached as None so that we don't try to look
    # them up again on the next request.
    failures = {failure: None for failure in type_to_assets.pop(None, ())}
    hits.update(failures)
    self._asset_cache.update(failures)

    if failures and not default_none:
        raise SidsNotFound(sids=list(failures))

    # We don't update the asset cache here because it should already be
    # updated by `self._retrieve_equities`.
    hits.update(self._retrieve_equities(type_to_assets.pop('equity', ())))
    hits.update(
        self._retrieve_futures_contracts(type_to_assets.pop('future', ()))
    )

    # We shouldn't know about any other asset types.
    if type_to_assets:
        raise AssertionError(
            "Found asset types: %s" % list(type_to_assets.keys())
        )

    return [hits[sid] for sid in sids]
def _retrieve_equities(self, sids):
    """
    Retrieve Equity objects for a list of sids.

    Parameters
    ----------
    sids : iterable of int
        Sids to load from ``self.equities``.

    Returns
    -------
    equities : dict[int -> Equity]
        Whatever ``self._retrieve_assets`` returns for the equities table;
        callers index the result by sid.
    """
    return self._retrieve_assets(sids, self.equities, Equity)
def _retrieve_futures_contract(self, sid):
def _retrieve_equity(self, sid):
return self._retrieve_equities((sid,))[sid]
def _retrieve_futures_contracts(self, sids):
    """
    Retrieve Future objects for a list of sids.

    Parameters
    ----------
    sids : iterable of int
        Sids to load from ``self.futures_contracts``.

    Returns
    -------
    futures : dict[int -> Future]
        Whatever ``self._retrieve_assets`` returns for the futures contracts
        table; callers index the result by sid.
    """
    return self._retrieve_assets(sids, self.futures_contracts, Future)
def _retrieve_futures_contract(self, sid):
return self._retrieve_futures_contracts((sid,))[sid]
@staticmethod
def _select_assets_by_sid(asset_tbl, sids):
    """
    Build a SELECT over `asset_tbl` for rows whose sid is in `sids`.

    Parameters
    ----------
    asset_tbl : sqlalchemy.Table
        Table to select from; must have a ``sid`` column.
    sids : iterable of int
        Sids to match. Each one is coerced with ``int``.

    Returns
    -------
    query : sqlalchemy selectable
        The unexecuted query.
    """
    # Pass a concrete list to IN: on Python 3 ``map`` yields a one-shot
    # iterator rather than a sequence.
    return sa.select([asset_tbl]).where(
        asset_tbl.c.sid.in_([int(sid) for sid in sids])
    )
@staticmethod
def _select_asset_by_symbol(asset_tbl, symbol):
    """
    Build a SELECT over `asset_tbl` for rows whose symbol equals `symbol`.

    The query is returned unexecuted.
    """
    symbol_matches = asset_tbl.c.symbol == symbol
    return sa.select([asset_tbl]).where(symbol_matches)
def _retrieve_asset(self, sid, cache, asset_tbl, asset_type):
try:
return cache[sid]
except KeyError:
pass
def _retrieve_assets(self, sids, asset_tbl, asset_type):
"""
Internal function for loading assets from a table.
data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone()
# Convert 'data' from a RowProxy object to a dict, to allow assignment
data = dict(data.items())
if data:
_convert_asset_timestamp_fields(data)
This function does not do any caching. It is assumed that this will be
called at most once with any given sid.
asset = asset_type(**data)
else:
asset = None
Parameters
---------
sids : iterable of int
Asset ids to look up.
asset_tbl : sqlalchemy.Table
Table from which to query assets.
asset_type : type
Type of asset to be constructed.
cache[sid] = asset
return asset
Returns
-------
assets : dict[int -> Asset]
Dict mapping requested sids to the retrieved assets.
"""
# Fastpath for empty request.
if not sids:
return {}
cache = self._asset_cache
hits = {}
# Load misses from the db.
query = self._select_assets_by_sid(asset_tbl, sids)
for row in imap(dict, query.execute().fetchall()):
asset = asset_type(**_convert_asset_timestamp_fields(row))
sid = asset.sid
hits[sid] = cache[sid] = asset
# If we get here, it means something in our code thought that a
# particular sid was an equity/future and called this function with a
# concrete type, but we couldn't actually resolve the asset. This is
# an error in our code, not a user-input error.
misses = tuple(set(sids) - viewkeys(hits))
if misses:
raise AssertionError(
"Couldn't resolve sids {sids} as instances of {type}.".format(
sids=misses,
type=asset_type,
)
)
return hits
def _get_fuzzy_candidates(self, fuzzy_symbol):
candidates = sa.select(
@@ -272,10 +388,9 @@ class AssetFinder(object):
return self._retrieve_equity(candidates[0]['sid'])
def _get_equities_from_candidates(self, candidates):
return list(map(
compose(self._retrieve_equity, itemgetter('sid')),
candidates,
))
sids = map(itemgetter('sid'), candidates)
results = self.retrieve_equities(sids)
return [results[sid] for sid in sids]
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
"""
@@ -286,7 +401,6 @@ class AssetFinder(object):
If no Equity was active at as_of_date raises SymbolNotFound.
"""
company_symbol, share_class_symbol, fuzzy_symbol = \
split_delimited_symbol(symbol)
if as_of_date:
@@ -376,22 +490,7 @@ class AssetFinder(object):
# If no data found, raise an exception
if not data:
raise SymbolNotFound(symbol=symbol)
# If we find a contract, check whether it's been cached
try:
return self._future_cache[data['sid']]
except KeyError:
pass
# Build the Future object from its parameters
data = dict(data.items())
_convert_asset_timestamp_fields(data)
future = Future(**data)
# Cache the Future object.
self._future_cache[data['sid']] = future
return future
return self.retrieve_asset(data['sid'])
def lookup_future_chain(self, root_symbol, as_of_date):
""" Return the futures chain for a given root symbol.
@@ -487,7 +586,8 @@ class AssetFinder(object):
if count == 0:
raise RootSymbolNotFound(root_symbol=root_symbol)
return list(map(self._retrieve_futures_contract, sids))
contracts = self._retrieve_futures_contracts(sids)
return [contracts[sid] for sid in sids]
@property
def sids(self):
@@ -513,7 +613,7 @@ class AssetFinder(object):
elif isinstance(asset_convertible, Integral):
try:
result = self.retrieve_asset(int(asset_convertible))
except SidNotFound:
except SidsNotFound:
missing.append(asset_convertible)
return None
matches.append(result)
@@ -563,7 +663,7 @@ class AssetFinder(object):
return matches[0], missing
except IndexError:
if hasattr(asset_convertible_or_iterable, '__int__'):
raise SidNotFound(sid=asset_convertible_or_iterable)
raise SidsNotFound(sids=[asset_convertible_or_iterable])
else:
raise SymbolNotFound(symbol=asset_convertible_or_iterable)
+12 -7
View File
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from zipline.utils.memoize import lazyval
class ZiplineError(Exception):
msg = None
def __init__(self, *args, **kwargs):
self.args = args
def __init__(self, **kwargs):
self.kwargs = kwargs
self.message = str(self)
@@ -231,13 +232,17 @@ Root symbol '{root_symbol}' was not found.
""".strip()
class SidsNotFound(ZiplineError):
    """
    Raised when a retrieve_asset() or retrieve_all() call contains a
    non-existent sid.

    Expects to be constructed with a ``sids`` keyword argument holding the
    sequence of sids that could not be found.
    """
    @lazyval
    def msg(self):
        # The template is chosen by how many sids are missing.
        # NOTE(review): the returned strings still contain ``{sids...}``
        # placeholders; presumably ZiplineError formats ``msg`` with
        # ``self.kwargs`` -- confirm against ZiplineError.__str__.
        sids = self.kwargs['sids']
        if len(sids) == 1:
            return "No asset found for sid: {sids[0]}."
        return "No assets found for sids: {sids}."
class ConsumeAssetMetaDataError(ZiplineError):
+17
View File
@@ -1,6 +1,7 @@
"""
Control flow utilities.
"""
from six import iteritems
from warnings import (
catch_warnings,
filterwarnings,
@@ -54,3 +55,19 @@ def ignore_nanwarnings():
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
)
)
def invert(d):
    """
    Invert a dictionary into a dictionary of lists.

    Each distinct value of `d` becomes a key of the result, mapped to the
    list of keys of `d` that had that value. Values of `d` must therefore be
    hashable. The order of keys within each list follows the iteration order
    of `d`.

    >>> inverted = invert({'a': 1, 'b': 2, 'c': 1})
    >>> sorted((value, sorted(keys)) for value, keys in inverted.items())
    [(1, ['a', 'c']), (2, ['b'])]
    """
    out = {}
    # Iterate keys and index back into `d`: this is lazy on both Python 2
    # and Python 3 without needing a compatibility helper.
    for k in d:
        out.setdefault(d[k], []).append(k)
    return out