Merge pull request #830 from quantopian/equity_caching

Equity caching
This commit is contained in:
Maya Tydykov
2015-11-12 14:01:29 -05:00
4 changed files with 338 additions and 132 deletions
+4 -1
View File
@@ -38,7 +38,10 @@ Bug Fixes
Performance
~~~~~~~~~~~
None
* Speeds up `AssetFinder.lookup_symbol` by adding an extension,
`AssetFinderCachedEquities`, that loads equities into dictionaries and
then directs `lookup_symbol` to these dictionaries to find matching equities
(:issue:`830`).
Maintenance and Refactorings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+109 -70
View File
@@ -16,12 +16,11 @@
"""
Tests for the zipline.assets package
"""
import sys
from unittest import TestCase
from contextlib import contextmanager
from datetime import datetime, timedelta
import pickle
import sys
from unittest import TestCase
import uuid
import warnings
@@ -29,14 +28,20 @@ import pandas as pd
from pandas.tseries.tools import normalize_date
from pandas.util.testing import assert_frame_equal
from nose_parameterized import parameterized
from numpy import full
from zipline.assets import Asset, Equity, Future, AssetFinder
from zipline.assets import (
Asset,
Equity,
Future,
AssetFinder,
AssetFinderCachedEquities
)
from zipline.assets.futures import (
cme_code_to_month,
FutureChain,
month_to_cme_code,
month_to_cme_code
)
from zipline.errors import (
SymbolNotFound,
@@ -48,12 +53,15 @@ from zipline.finance.trading import TradingEnvironment, noop_load
from zipline.utils.test_utils import (
all_subindices,
make_rotating_asset_info,
tmp_assets_db
)
def build_lookup_generic_cases():
@contextmanager
def build_lookup_generic_cases(asset_finder_type):
"""
Generate test cases for AssetFinder test_lookup_generic.
Generate test cases for the type of asset finder specific by
asset_finder_type for test_lookup_generic.
"""
unique_start = pd.Timestamp('2013-01-01', tz='UTC')
@@ -90,54 +98,52 @@ def build_lookup_generic_cases():
},
],
index='sid')
env = TradingEnvironment()
env.write_data(equities_df=frame)
finder = env.asset_finder
dupe_0, dupe_1, unique = assets = [
finder.retrieve_asset(i)
for i in range(3)
]
with tmp_assets_db(frame) as assets_db:
finder = asset_finder_type(assets_db)
dupe_0, dupe_1, unique = assets = [
finder.retrieve_asset(i)
for i in range(3)
]
dupe_0_start = dupe_0.start_date
dupe_1_start = dupe_1.start_date
cases = [
##
# Scalars
dupe_0_start = dupe_0.start_date
dupe_1_start = dupe_1.start_date
yield (
##
# Scalars
# Asset object
(finder, assets[0], None, assets[0]),
(finder, assets[1], None, assets[1]),
(finder, assets[2], None, assets[2]),
# int
(finder, 0, None, assets[0]),
(finder, 1, None, assets[1]),
(finder, 2, None, assets[2]),
# Duplicated symbol with resolution date
(finder, 'DUPLICATED', dupe_0_start, dupe_0),
(finder, 'DUPLICATED', dupe_1_start, dupe_1),
# Unique symbol, with or without resolution date.
(finder, 'UNIQUE', unique_start, unique),
(finder, 'UNIQUE', None, unique),
# Asset object
(finder, assets[0], None, assets[0]),
(finder, assets[1], None, assets[1]),
(finder, assets[2], None, assets[2]),
# int
(finder, 0, None, assets[0]),
(finder, 1, None, assets[1]),
(finder, 2, None, assets[2]),
# Duplicated symbol with resolution date
(finder, 'DUPLICATED', dupe_0_start, dupe_0),
(finder, 'DUPLICATED', dupe_1_start, dupe_1),
# Unique symbol, with or without resolution date.
(finder, 'UNIQUE', unique_start, unique),
(finder, 'UNIQUE', None, unique),
##
# Iterables
##
# Iterables
# Iterables of Asset objects.
(finder, assets, None, assets),
(finder, iter(assets), None, assets),
# Iterables of ints
(finder, (0, 1), None, assets[:-1]),
(finder, iter((0, 1)), None, assets[:-1]),
# Iterables of symbols.
(finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]),
(finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]),
# Mixed types
(finder,
('DUPLICATED', 2, 'UNIQUE', 1, dupe_1),
dupe_0_start,
[dupe_0, assets[2], unique, assets[1], dupe_1]),
]
return cases
# Iterables of Asset objects.
(finder, assets, None, assets),
(finder, iter(assets), None, assets),
# Iterables of ints
(finder, (0, 1), None, assets[:-1]),
(finder, iter((0, 1)), None, assets[:-1]),
# Iterables of symbols.
(finder, ('DUPLICATED', 'UNIQUE'), dupe_0_start, [dupe_0, unique]),
(finder, ('DUPLICATED', 'UNIQUE'), dupe_1_start, [dupe_1, unique]),
# Mixed types
(finder,
('DUPLICATED', 2, 'UNIQUE', 1, dupe_1),
dupe_0_start,
[dupe_0, assets[2], unique, assets[1], dupe_1]),
)
class AssetTestCase(TestCase):
@@ -339,6 +345,7 @@ class AssetFinderTestCase(TestCase):
def setUp(self):
self.env = TradingEnvironment(load=noop_load)
self.asset_finder_type = AssetFinder
def test_lookup_symbol_delimited(self):
as_of = pd.Timestamp('2013-01-01', tz='UTC')
@@ -356,7 +363,7 @@ class AssetFinderTestCase(TestCase):
]
)
self.env.write_data(equities_df=frame)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
asset_0, asset_1, asset_2 = (
finder.retrieve_asset(i) for i in range(3)
)
@@ -435,7 +442,7 @@ class AssetFinderTestCase(TestCase):
]
)
self.env.write_data(equities_df=df)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
for _ in range(2): # Run checks twice to test for caching bugs.
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('NON_EXISTING', dates[0])
@@ -450,16 +457,41 @@ class AssetFinderTestCase(TestCase):
self.assertEqual(result.symbol, 'EXISTING')
self.assertEqual(result.sid, i)
@parameterized.expand(
build_lookup_generic_cases()
)
def test_lookup_generic(self, finder, symbols, reference_date, expected):
def test_lookup_symbol_from_multiple_valid(self):
df = pd.DataFrame.from_records(
[
{
'sid': 1,
'symbol': 'multiple',
'start_date': pd.Timestamp('2010-01-01'),
'end_date': pd.Timestamp('2013-01-01'),
'exchange': 'NYSE'
},
{
'sid': 2,
'symbol': 'multiple',
'start_date': pd.Timestamp('2012-01-01'),
'end_date': pd.Timestamp('2014-01-01'),
'exchange': 'NYSE'
}
]
)
self.env.write_data(equities_df=df)
finder = self.asset_finder_type(self.env.engine)
result = finder.lookup_symbol('MULTIPLE', pd.Timestamp('2012-05-05'))
self.assertEqual(result.symbol, 'MULTIPLE')
self.assertEqual(result.sid, 2)
def test_lookup_generic(self):
"""
Ensure that lookup_generic works with various permutations of inputs.
"""
results, missing = finder.lookup_generic(symbols, reference_date)
self.assertEqual(results, expected)
self.assertEqual(missing, [])
with build_lookup_generic_cases(self.asset_finder_type) as cases:
for finder, symbols, reference_date, expected in cases:
results, missing = finder.lookup_generic(symbols,
reference_date)
self.assertEqual(results, expected)
self.assertEqual(missing, [])
def test_lookup_generic_handle_missing(self):
data = pd.DataFrame.from_records(
@@ -499,7 +531,7 @@ class AssetFinderTestCase(TestCase):
]
)
self.env.write_data(equities_df=data)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
results, missing = finder.lookup_generic(
['REAL', 1, 'FAKE', 'REAL_BUT_OLD', 'REAL_BUT_IN_THE_FUTURE'],
pd.Timestamp('2013-02-01', tz='UTC'),
@@ -523,7 +555,7 @@ class AssetFinderTestCase(TestCase):
'symbol': "PLAY",
'foo_data': "FOO"}}
self.env.write_data(equities_data=data)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
# Test proper insertion
equity = finder.retrieve_asset(0)
self.assertIsInstance(equity, Equity)
@@ -541,7 +573,7 @@ class AssetFinderTestCase(TestCase):
dict_to_consume = {0: {'symbol': 'PLAY'},
1: {'symbol': 'MSFT'}}
self.env.write_data(equities_data=dict_to_consume)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
equity = finder.retrieve_asset(0)
self.assertIsInstance(equity, Equity)
@@ -555,7 +587,7 @@ class AssetFinderTestCase(TestCase):
df['exchange'][1] = "NYSE"
self.env = TradingEnvironment(load=noop_load)
self.env.write_data(equities_df=df)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
self.assertEqual('NASDAQ', finder.retrieve_asset(0).exchange)
self.assertEqual('Microsoft', finder.retrieve_asset(1).asset_name)
@@ -571,7 +603,7 @@ class AssetFinderTestCase(TestCase):
# Consume the Assets
self.env.write_data(equities_identifiers=[equity_asset],
futures_identifiers=[future_asset])
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
# Test equality with newly built Assets
self.assertEqual(equity_asset, finder.retrieve_asset(1))
@@ -591,7 +623,7 @@ class AssetFinderTestCase(TestCase):
allow_sid_assignment=True)
# Verify that Assets were built and different sids were assigned
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
play = finder.lookup_symbol('PLAY', today)
msft = finder.lookup_symbol('MSFT', today)
self.assertEqual('PLAY', play.symbol)
@@ -678,7 +710,7 @@ class AssetFinderTestCase(TestCase):
},
}
self.env.write_data(futures_data=metadata)
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
dt = pd.Timestamp('2015-05-14', tz='UTC')
dt_2 = pd.Timestamp('2015-10-14', tz='UTC')
dt_3 = pd.Timestamp('2016-11-17', tz='UTC')
@@ -712,7 +744,7 @@ class AssetFinderTestCase(TestCase):
def test_map_identifier_index_to_sids(self):
# Build an empty finder and some Assets
dt = pd.Timestamp('2014-01-01', tz='UTC')
finder = AssetFinder(self.env.engine)
finder = self.asset_finder_type(self.env.engine)
asset1 = Equity(1, symbol="AAPL")
asset2 = Equity(2, symbol="GOOG")
asset200 = Future(200, symbol="CLK15")
@@ -800,6 +832,13 @@ class AssetFinderTestCase(TestCase):
self.assertTrue(3 in sids)
class AssetFinderCachedEquitiesTestCase(AssetFinderTestCase):
def setUp(self):
self.env = TradingEnvironment(load=noop_load)
self.asset_finder_type = AssetFinderCachedEquities
class TestFutureChain(TestCase):
@classmethod
+3 -1
View File
@@ -22,7 +22,8 @@ from ._assets import (
)
from .assets import (
AssetFinder,
AssetConvertible
AssetConvertible,
AssetFinderCachedEquities
)
__all__ = [
@@ -30,6 +31,7 @@ __all__ = [
'Equity',
'Future',
'AssetFinder',
'AssetFinderCachedEquities',
'AssetConvertible',
'make_asset_array',
'CACHE_FILE_TEMPLATE'
+222 -60
View File
@@ -196,6 +196,87 @@ class AssetFinder(object):
cache[sid] = asset
return asset
def _get_fuzzy_candidates(self, fuzzy_symbol):
candidates = sa.select(
(self.equities.c.sid,)
).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc()
).execute().fetchall()
return candidates
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.fuzzy_symbol == fuzzy_symbol,
self.equities.c.start_date <= ad_value,
self.equities.c.end_date >= ad_value
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_split_candidates_in_range(self,
company_symbol,
share_class_symbol,
ad_value):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol == share_class_symbol,
self.equities.c.start_date <= ad_value,
self.equities.c.end_date >= ad_value
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_split_candidates(self, company_symbol, share_class_symbol):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol == share_class_symbol
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _resolve_no_matching_candidates(self,
company_symbol,
share_class_symbol,
ad_value):
candidates = sa.select((self.equities.c.sid,)).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol ==
share_class_symbol,
self.equities.c.start_date <= ad_value),
).order_by(
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_best_candidate(self, candidates):
return self._retrieve_equity(candidates[0]['sid'])
def _get_equities_from_candidates(self, candidates):
return list(map(
compose(self._retrieve_equity, itemgetter('sid')),
candidates,
))
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
"""
Return matching Equity of name symbol in database.
@@ -206,68 +287,46 @@ class AssetFinder(object):
If no Equity was active at as_of_date raises SymbolNotFound.
"""
# Format inputs
if as_of_date is not None:
as_of_date = pd.Timestamp(normalize_date(as_of_date))
company_symbol, share_class_symbol, fuzzy_symbol = \
split_delimited_symbol(symbol)
equities_cols = self.equities.c
if as_of_date:
# Format inputs
as_of_date = pd.Timestamp(normalize_date(as_of_date))
ad_value = as_of_date.value
if fuzzy:
# Search for a single exact match on the fuzzy column
fuzzy_candidates = sa.select((equities_cols.sid,)).where(
(equities_cols.fuzzy_symbol == fuzzy_symbol) &
(equities_cols.start_date <= ad_value) &
(equities_cols.end_date >= ad_value),
).execute().fetchall()
candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
ad_value)
# If exactly one SID exists for fuzzy_symbol, return that sid
if len(fuzzy_candidates) == 1:
return self._retrieve_equity(fuzzy_candidates[0]['sid'])
if len(candidates) == 1:
return self._get_best_candidate(candidates)
# Search for exact matches of the split-up company_symbol and
# share_class_symbol
candidates = sa.select((equities_cols.sid,)).where(
(equities_cols.company_symbol == company_symbol) &
(equities_cols.share_class_symbol == share_class_symbol) &
(equities_cols.start_date <= ad_value) &
(equities_cols.end_date >= ad_value),
).execute().fetchall()
candidates = self._get_split_candidates_in_range(
company_symbol,
share_class_symbol,
ad_value
)
# If exactly one SID exists for symbol, return that symbol
if len(candidates) == 1:
return self._retrieve_equity(candidates[0]['sid'])
# If multiple SIDs exist for symbol, return latest start_date with
# end_date as a tie-breaker
if candidates:
return self._get_best_candidate(candidates)
# If no SID exists for symbol, return SID with the
# highest-but-not-over end_date
elif not candidates:
sid = sa.select((equities_cols.sid,)).where(
(equities_cols.company_symbol == company_symbol) &
(equities_cols.share_class_symbol == share_class_symbol) &
(equities_cols.start_date <= ad_value),
).order_by(
equities_cols.end_date.desc(),
).scalar()
if sid is not None:
return self._retrieve_equity(sid)
# If multiple SIDs exist for symbol, return latest start_date with
# end_date as a tie-breaker
elif len(candidates) > 1:
sid = sa.select((equities_cols.sid,)).where(
(equities_cols.company_symbol == company_symbol) &
(equities_cols.share_class_symbol == share_class_symbol) &
(equities_cols.start_date <= ad_value),
).order_by(
equities_cols.start_date.desc(),
equities_cols.end_date.desc(),
).scalar()
if sid is not None:
return self._retrieve_equity(sid)
candidates = self._resolve_no_matching_candidates(
company_symbol,
share_class_symbol,
ad_value
)
if candidates:
return self._get_best_candidate(candidates)
raise SymbolNotFound(symbol=symbol)
@@ -275,27 +334,20 @@ class AssetFinder(object):
# If this is a fuzzy look-up, check if there is exactly one match
# for the fuzzy symbol
if fuzzy:
fuzzy_sids = sa.select((equities_cols.sid,)).where(
(equities_cols.fuzzy_symbol == fuzzy_symbol)
).execute().fetchall()
if len(fuzzy_sids) == 1:
return self._retrieve_equity(fuzzy_sids[0]['sid'])
candidates = self._get_fuzzy_candidates(fuzzy_symbol)
if len(candidates) == 1:
return self._get_best_candidate(candidates)
sids = sa.select((equities_cols.sid,)).where(
(equities_cols.company_symbol == company_symbol) &
(equities_cols.share_class_symbol == share_class_symbol)
).execute().fetchall()
if len(sids) == 1:
return self._retrieve_equity(sids[0]['sid'])
elif not sids:
candidates = self._get_split_candidates(company_symbol,
share_class_symbol)
if len(candidates) == 1:
return self._get_best_candidate(candidates)
elif not candidates:
raise SymbolNotFound(symbol=symbol)
else:
raise MultipleSymbolsFound(
symbol=symbol,
options=list(map(
compose(self._retrieve_equity, itemgetter('sid')),
sids,
))
options=self._get_equities_from_candidates(candidates)
)
def lookup_future_symbol(self, symbol):
@@ -678,3 +730,113 @@ for _type in string_types:
class NotAssetConvertible(ValueError):
pass
class AssetFinderCachedEquities(AssetFinder):
"""
An extension to AssetFinder that loads all equities from equities table
into memory and overrides the methods that lookup_symbol uses to look up
those equities.
"""
def __init__(self, engine):
super(AssetFinderCachedEquities, self).__init__(engine)
self.fuzzy_symbol_hashed_equities = {}
self.company_share_class_hashed_equities = {}
self.hashed_equities = sa.select(self.equities.c).execute().fetchall()
self._load_hashed_equities()
def _load_hashed_equities(self):
"""
Populates two maps - fuzzy symbol to list of equities having that
fuzzy symbol and company symbol/share class symbol to list of
equities having that combination of company symbol/share class symbol.
"""
for equity in self.hashed_equities:
company_symbol = equity['company_symbol']
share_class_symbol = equity['share_class_symbol']
fuzzy_symbol = equity['fuzzy_symbol']
asset = self._convert_row_to_equity(equity)
self.company_share_class_hashed_equities.setdefault(
(company_symbol, share_class_symbol),
[]
).append(asset)
self.fuzzy_symbol_hashed_equities.setdefault(
fuzzy_symbol, []
).append(asset)
def _convert_row_to_equity(self, equity):
"""
Converts a SQLAlchemy equity row to an Equity object.
"""
data = dict(equity.items())
_convert_asset_timestamp_fields(data)
asset = Equity(**data)
return asset
def _get_fuzzy_candidates(self, fuzzy_symbol):
if fuzzy_symbol in self.fuzzy_symbol_hashed_equities:
return self.fuzzy_symbol_hashed_equities[fuzzy_symbol]
return []
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
equities = self._get_fuzzy_candidates(fuzzy_symbol)
fuzzy_candidates = []
for equity in equities:
if (equity.start_date.value <=
ad_value <=
equity.end_date.value):
fuzzy_candidates.append(equity)
return fuzzy_candidates
def _get_split_candidates(self, company_symbol, share_class_symbol):
if (company_symbol, share_class_symbol) in \
self.company_share_class_hashed_equities:
return self.company_share_class_hashed_equities[(
company_symbol, share_class_symbol)]
return []
def _get_split_candidates_in_range(self,
company_symbol,
share_class_symbol,
ad_value):
equities = self._get_split_candidates(
company_symbol, share_class_symbol
)
best_candidates = []
for equity in equities:
if (equity.start_date.value <=
ad_value <=
equity.end_date.value):
best_candidates.append(equity)
if best_candidates:
best_candidates = sorted(
best_candidates,
key=lambda x: (x.start_date, x.end_date),
reverse=True
)
return best_candidates
def _resolve_no_matching_candidates(self,
company_symbol,
share_class_symbol,
ad_value):
equities = self._get_split_candidates(
company_symbol, share_class_symbol
)
partial_candidates = []
for equity in equities:
if equity.start_date.value <= ad_value:
partial_candidates.append(equity)
if partial_candidates:
partial_candidates = sorted(
partial_candidates,
key=lambda x: x.end_date,
reverse=True
)
return partial_candidates
def _get_best_candidate(self, candidates):
return candidates[0]
def _get_equities_from_candidates(self, candidates):
return candidates