mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 19:47:43 +08:00
PERF: Cache the result of failed lookups.
Otherwise we'll keep trying to look them up.
This commit is contained in:
@@ -3,7 +3,7 @@ verbosity=2
|
||||
detailed-errors=1
|
||||
with-ignore-docstrings=1
|
||||
with-timer=1
|
||||
timer-top-n=15
|
||||
timer-filter=warning
|
||||
|
||||
[metadata]
|
||||
description-file = README.rst
|
||||
|
||||
+195
-95
@@ -20,15 +20,16 @@ import warnings
|
||||
from logbook import Logger
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas import isnull
|
||||
from pandas.tseries.tools import normalize_date
|
||||
from six import with_metaclass, string_types, viewkeys
|
||||
from six.moves import map as imap
|
||||
import sqlalchemy as sa
|
||||
from toolz import compose
|
||||
|
||||
from zipline.errors import (
|
||||
MultipleSymbolsFound,
|
||||
RootSymbolNotFound,
|
||||
SidNotFound,
|
||||
SidsNotFound,
|
||||
SymbolNotFound,
|
||||
MapAssetIdentifierIndexError,
|
||||
)
|
||||
@@ -41,6 +42,7 @@ from zipline.assets.asset_writer import (
|
||||
ASSET_DB_VERSION,
|
||||
asset_db_table_names,
|
||||
)
|
||||
from zipline.utils.control_flow import invert
|
||||
|
||||
log = Logger('assets.py')
|
||||
|
||||
@@ -63,13 +65,14 @@ _asset_timestamp_fields = frozenset({
|
||||
})
|
||||
|
||||
|
||||
def _convert_asset_timestamp_fields(dict):
|
||||
def _convert_asset_timestamp_fields(dict_):
|
||||
"""
|
||||
Takes in a dict of Asset init args and converts dates to pd.Timestamps
|
||||
"""
|
||||
for key in (_asset_timestamp_fields & viewkeys(dict)):
|
||||
value = pd.Timestamp(dict[key], tz='UTC')
|
||||
dict[key] = None if pd.isnull(value) else value
|
||||
for key in (_asset_timestamp_fields & viewkeys(dict_)):
|
||||
value = pd.Timestamp(dict_[key], tz='UTC')
|
||||
dict_[key] = None if isnull(value) else value
|
||||
return dict_
|
||||
|
||||
|
||||
class AssetFinder(object):
|
||||
@@ -96,105 +99,218 @@ class AssetFinder(object):
|
||||
# routing.
|
||||
#
|
||||
# The caches are read through, i.e. accessing an asset through
|
||||
# retrieve_asset, _retrieve_equity etc. will populate the cache on
|
||||
# first retrieval.
|
||||
# retrieve_asset will populate the cache on first retrieval.
|
||||
self._asset_cache = {}
|
||||
self._equity_cache = {}
|
||||
self._future_cache = {}
|
||||
|
||||
self._asset_type_cache = {}
|
||||
|
||||
# Populated on first call to `lifetimes`.
|
||||
self._asset_lifetimes = None
|
||||
|
||||
def asset_type_by_sid(self, sid):
|
||||
def lookup_asset_types(self, sids):
|
||||
"""
|
||||
Retrieve the asset type of a given sid.
|
||||
Retrieve asset types for a list of sids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sids : list[int]
|
||||
|
||||
Returns
|
||||
-------
|
||||
types : dict[sid -> str or None]
|
||||
Asset types for the provided sids.
|
||||
"""
|
||||
try:
|
||||
return self._asset_type_cache[sid]
|
||||
except KeyError:
|
||||
pass
|
||||
found, missing = {}, set()
|
||||
for sid in sids:
|
||||
try:
|
||||
found[sid] = self._asset_type_cache[sid]
|
||||
except KeyError:
|
||||
missing.add(sid)
|
||||
|
||||
asset_type = sa.select((self.asset_router.c.asset_type,)).where(
|
||||
self.asset_router.c.sid == int(sid),
|
||||
).scalar()
|
||||
if not missing:
|
||||
return found
|
||||
|
||||
if asset_type is not None:
|
||||
self._asset_type_cache[sid] = asset_type
|
||||
return asset_type
|
||||
router_cols = self.asset_router.c
|
||||
query = sa.select((router_cols.sid, router_cols.asset_type)).where(
|
||||
self.asset_router.c.sid.in_(map(int, missing))
|
||||
)
|
||||
for sid, type_ in query.execute().fetchall():
|
||||
missing.remove(sid)
|
||||
found[sid] = self._asset_type_cache[sid] = type_
|
||||
|
||||
for sid in missing:
|
||||
found[sid] = self._asset_type_cache[sid] = None
|
||||
|
||||
return found
|
||||
|
||||
def lookup_single_asset_type(self, sid):
|
||||
"""Retrieve the asset type for a single asset."""
|
||||
return self.lookup_asset_types([sid])[sid]
|
||||
|
||||
def group_by_type(self, sids):
|
||||
"""
|
||||
Group a list of sids by asset type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sids : list[int]
|
||||
|
||||
Returns
|
||||
-------
|
||||
types : dict[str or None -> list[int]]
|
||||
A dict mapping unique asset types to lists of sids drawn from sids.
|
||||
If we fail to look up an asset, we assign it a key of None.
|
||||
"""
|
||||
return invert(self.lookup_asset_types(sids))
|
||||
|
||||
def retrieve_asset(self, sid, default_none=False):
|
||||
"""
|
||||
Retrieve the Asset object of a given sid.
|
||||
Retrieve the Asset for a given sid.
|
||||
"""
|
||||
if isinstance(sid, Asset):
|
||||
return sid
|
||||
|
||||
try:
|
||||
asset = self._asset_cache[sid]
|
||||
except KeyError:
|
||||
asset_type = self.asset_type_by_sid(sid)
|
||||
if asset_type == 'equity':
|
||||
asset = self._retrieve_equity(sid)
|
||||
elif asset_type == 'future':
|
||||
asset = self._retrieve_futures_contract(sid)
|
||||
else:
|
||||
asset = None
|
||||
|
||||
# Cache the asset if it has been retrieved
|
||||
if asset is not None:
|
||||
self._asset_cache[sid] = asset
|
||||
|
||||
if (asset is not None) or default_none:
|
||||
return asset
|
||||
raise SidNotFound(sid=sid)
|
||||
return self.retrieve_all((sid,), default_none=default_none)[0]
|
||||
|
||||
def retrieve_all(self, sids, default_none=False):
|
||||
return [self.retrieve_asset(sid, default_none) for sid in sids]
|
||||
"""
|
||||
Retrieve all assets in `sids`.
|
||||
|
||||
def _retrieve_equity(self, sid):
|
||||
Parameters
|
||||
----------
|
||||
sids : iterable of int
|
||||
Assets to retrieve.
|
||||
default_none : bool
|
||||
If True, return None for failed lookups.
|
||||
If False, raise `SidsNotFound`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
assets : list[Asset or None]
|
||||
A list of the same length as `sids` containing Assets (or Nones)
|
||||
corresponding to the requested sids.
|
||||
|
||||
Raises
|
||||
------
|
||||
SidsNotFound
|
||||
When a requested sid is not found and default_none=False.
|
||||
"""
|
||||
hits, missing, failures = {}, set(), []
|
||||
for sid in sids:
|
||||
try:
|
||||
asset = self._asset_cache[sid]
|
||||
if not default_none and asset is None:
|
||||
# Bail early if we've already cached that we don't know
|
||||
# about an asset.
|
||||
raise SidsNotFound(sids=[sid])
|
||||
hits[sid] = asset
|
||||
except KeyError:
|
||||
missing.add(sid)
|
||||
|
||||
# All requests were cache hits. Return requested sids in order.
|
||||
if not missing:
|
||||
return [hits[sid] for sid in sids]
|
||||
|
||||
update_hits = hits.update
|
||||
|
||||
# Look up cache misses by type.
|
||||
type_to_assets = self.group_by_type(missing)
|
||||
|
||||
# Handle failures
|
||||
failures = {failure: None for failure in type_to_assets.pop(None, ())}
|
||||
update_hits(failures)
|
||||
self._asset_cache.update(failures)
|
||||
|
||||
if failures and not default_none:
|
||||
raise SidsNotFound(sids=list(failures))
|
||||
|
||||
# We don't update the asset cache here because it should already be
|
||||
# updated by `self._retrieve_equities`.
|
||||
update_hits(self._retrieve_equities(type_to_assets.pop('equity', ())))
|
||||
update_hits(
|
||||
self._retrieve_futures_contracts(type_to_assets.pop('future', ()))
|
||||
)
|
||||
|
||||
# We shouldn't know about any other asset types.
|
||||
if type_to_assets:
|
||||
raise AssertionError(
|
||||
"Found asset types: %s" % list(type_to_assets.keys())
|
||||
)
|
||||
|
||||
return [hits[sid] for sid in sids]
|
||||
|
||||
def _retrieve_equities(self, sids):
|
||||
"""
|
||||
Retrieve Equity objects for a list of sids.
|
||||
"""
|
||||
return self._retrieve_asset(
|
||||
sid, self._equity_cache, self.equities, Equity,
|
||||
)
|
||||
return self._retrieve_assets(sids, self.equities, Equity)
|
||||
|
||||
def _retrieve_futures_contract(self, sid):
|
||||
def _retrieve_equity(self, sid):
|
||||
return self._retrieve_equities((sid,))[sid]
|
||||
|
||||
def _retrieve_futures_contracts(self, sids):
|
||||
"""
|
||||
Retrieve Future objects for a list of sids.
|
||||
"""
|
||||
return self._retrieve_asset(
|
||||
sid, self._future_cache, self.futures_contracts, Future,
|
||||
)
|
||||
return self._retrieve_assets(sids, self.futures_contracts, Future)
|
||||
|
||||
def _retrieve_futures_contract(self, sid):
|
||||
return self._retrieve_futures_contracts((sid,))[sid]
|
||||
|
||||
@staticmethod
|
||||
def _select_asset_by_sid(asset_tbl, sid):
|
||||
return sa.select([asset_tbl]).where(asset_tbl.c.sid == int(sid))
|
||||
def _select_assets_by_sid(asset_tbl, sids):
|
||||
return sa.select([asset_tbl]).where(
|
||||
asset_tbl.c.sid.in_(map(int, sids))
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_asset_by_symbol(asset_tbl, symbol):
|
||||
return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
|
||||
|
||||
def _retrieve_asset(self, sid, cache, asset_tbl, asset_type):
|
||||
try:
|
||||
return cache[sid]
|
||||
except KeyError:
|
||||
pass
|
||||
def _retrieve_assets(self, sids, asset_tbl, asset_type):
|
||||
"""
|
||||
Internal function for loading assets from a table.
|
||||
|
||||
data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone()
|
||||
# Convert 'data' from a RowProxy object to a dict, to allow assignment
|
||||
data = dict(data.items())
|
||||
if data:
|
||||
_convert_asset_timestamp_fields(data)
|
||||
This function does not check the asset cache before querying (it only populates it). It is assumed that this will be
|
||||
called at most once with any given sid.
|
||||
|
||||
asset = asset_type(**data)
|
||||
else:
|
||||
asset = None
|
||||
Parameters
|
||||
----------
|
||||
sids : iterable of int
|
||||
Asset ids to look up.
|
||||
asset_tbl : sqlalchemy.Table
|
||||
Table from which to query assets.
|
||||
asset_type : type
|
||||
Type of asset to be constructed.
|
||||
|
||||
cache[sid] = asset
|
||||
return asset
|
||||
Returns
|
||||
-------
|
||||
assets : dict[int -> Asset]
|
||||
Dict mapping requested sids to the retrieved assets.
|
||||
"""
|
||||
# Fastpath for empty request.
|
||||
if not sids:
|
||||
return {}
|
||||
|
||||
cache = self._asset_cache
|
||||
|
||||
hits = {}
|
||||
# Load misses from the db.
|
||||
query = self._select_assets_by_sid(asset_tbl, sids)
|
||||
for row in imap(dict, query.execute().fetchall()):
|
||||
asset = asset_type(**_convert_asset_timestamp_fields(row))
|
||||
sid = asset.sid
|
||||
hits[sid] = cache[sid] = asset
|
||||
|
||||
# If we get here, it means something in our code thought that a
|
||||
# particular sid was an equity/future and called this function with a
|
||||
# concrete type, but we couldn't actually resolve the asset. This is
|
||||
# an error in our code, not a user-input error.
|
||||
misses = tuple(set(sids) - viewkeys(hits))
|
||||
if misses:
|
||||
raise AssertionError(
|
||||
"Couldn't resolve sids {sids} as instances of {type}.".format(
|
||||
sids=misses,
|
||||
type=asset_type,
|
||||
)
|
||||
)
|
||||
return hits
|
||||
|
||||
def _get_fuzzy_candidates(self, fuzzy_symbol):
|
||||
candidates = sa.select(
|
||||
@@ -272,10 +388,9 @@ class AssetFinder(object):
|
||||
return self._retrieve_equity(candidates[0]['sid'])
|
||||
|
||||
def _get_equities_from_candidates(self, candidates):
|
||||
return list(map(
|
||||
compose(self._retrieve_equity, itemgetter('sid')),
|
||||
candidates,
|
||||
))
|
||||
sids = map(itemgetter('sid'), candidates)
|
||||
results = self.retrieve_equities(sids)
|
||||
return [results[sid] for sid in sids]
|
||||
|
||||
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
|
||||
"""
|
||||
@@ -286,7 +401,6 @@ class AssetFinder(object):
|
||||
|
||||
If no Equity was active at as_of_date raises SymbolNotFound.
|
||||
"""
|
||||
|
||||
company_symbol, share_class_symbol, fuzzy_symbol = \
|
||||
split_delimited_symbol(symbol)
|
||||
if as_of_date:
|
||||
@@ -376,22 +490,7 @@ class AssetFinder(object):
|
||||
# If no data found, raise an exception
|
||||
if not data:
|
||||
raise SymbolNotFound(symbol=symbol)
|
||||
|
||||
# If we find a contract, check whether it's been cached
|
||||
try:
|
||||
return self._future_cache[data['sid']]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Build the Future object from its parameters
|
||||
data = dict(data.items())
|
||||
_convert_asset_timestamp_fields(data)
|
||||
future = Future(**data)
|
||||
|
||||
# Cache the Future object.
|
||||
self._future_cache[data['sid']] = future
|
||||
|
||||
return future
|
||||
return self.retrieve_asset(data['sid'])
|
||||
|
||||
def lookup_future_chain(self, root_symbol, as_of_date):
|
||||
""" Return the futures chain for a given root symbol.
|
||||
@@ -487,7 +586,8 @@ class AssetFinder(object):
|
||||
if count == 0:
|
||||
raise RootSymbolNotFound(root_symbol=root_symbol)
|
||||
|
||||
return list(map(self._retrieve_futures_contract, sids))
|
||||
contracts = self._retrieve_futures_contracts(sids)
|
||||
return [contracts[sid] for sid in sids]
|
||||
|
||||
@property
|
||||
def sids(self):
|
||||
@@ -513,7 +613,7 @@ class AssetFinder(object):
|
||||
elif isinstance(asset_convertible, Integral):
|
||||
try:
|
||||
result = self.retrieve_asset(int(asset_convertible))
|
||||
except SidNotFound:
|
||||
except SidsNotFound:
|
||||
missing.append(asset_convertible)
|
||||
return None
|
||||
matches.append(result)
|
||||
@@ -563,7 +663,7 @@ class AssetFinder(object):
|
||||
return matches[0], missing
|
||||
except IndexError:
|
||||
if hasattr(asset_convertible_or_iterable, '__int__'):
|
||||
raise SidNotFound(sid=asset_convertible_or_iterable)
|
||||
raise SidsNotFound(sids=[asset_convertible_or_iterable])
|
||||
else:
|
||||
raise SymbolNotFound(symbol=asset_convertible_or_iterable)
|
||||
|
||||
|
||||
+12
-7
@@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from zipline.utils.memoize import lazyval
|
||||
|
||||
|
||||
class ZiplineError(Exception):
|
||||
msg = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.message = str(self)
|
||||
|
||||
@@ -231,13 +232,17 @@ Root symbol '{root_symbol}' was not found.
|
||||
""".strip()
|
||||
|
||||
|
||||
class SidNotFound(ZiplineError):
|
||||
class SidsNotFound(ZiplineError):
|
||||
"""
|
||||
Raised when a retrieve_asset() call contains a non-existent sid.
|
||||
Raised when a retrieve_asset() or retrieve_all() call contains a
|
||||
non-existent sid.
|
||||
"""
|
||||
msg = """
|
||||
Asset with sid '{sid}' was not found.
|
||||
""".strip()
|
||||
@lazyval
|
||||
def msg(self):
|
||||
sids = self.kwargs['sids']
|
||||
if len(sids) == 1:
|
||||
return "No asset found for sid: {sids[0]}."
|
||||
return "No assets found for sids: {sids}."
|
||||
|
||||
|
||||
class ConsumeAssetMetaDataError(ZiplineError):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Control flow utilities.
|
||||
"""
|
||||
from six import iteritems
|
||||
from warnings import (
|
||||
catch_warnings,
|
||||
filterwarnings,
|
||||
@@ -54,3 +55,19 @@ def ignore_nanwarnings():
|
||||
{'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invert(d):
|
||||
"""
|
||||
Invert a dictionary into a dictionary of lists.
|
||||
|
||||
>>> invert({'a': 1, 'b': 2, 'c': 1})
|
||||
{1: ['a', 'c'], 2: ['b']}
|
||||
"""
|
||||
out = {}
|
||||
for k, v in iteritems(d):
|
||||
try:
|
||||
out[v].append(k)
|
||||
except KeyError:
|
||||
out[v] = [k]
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user