Merge pull request #1302 from quantopian/point-in-time-asset-db

Point in time asset db
This commit is contained in:
Joe Jevnik
2016-07-26 14:23:13 -04:00
committed by GitHub
19 changed files with 1321 additions and 613 deletions
+1
View File
@@ -74,6 +74,7 @@ EQUITY_INFO = DataFrame(
index=arange(1, 7),
columns=['start_date', 'end_date'],
).astype(datetime64)
EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
TEST_QUERY_ASSETS = EQUITY_INFO.index
@@ -91,6 +91,7 @@ EQUITY_INFO = DataFrame(
index=arange(1, 7),
columns=['start_date', 'end_date'],
).astype(datetime64)
EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
TEST_QUERY_ASSETS = EQUITY_INFO.index
Binary file not shown.
+9 -1
View File
@@ -777,7 +777,10 @@ class TestTransformAlgorithm(WithLogger,
@classmethod
def make_futures_info(cls):
return pd.DataFrame.from_dict({3: {'multiplier': 10}}, 'index')
return pd.DataFrame.from_dict(
{3: {'multiplier': 10, 'symbol': 'F'}},
orient='index',
)
@classmethod
def make_equity_daily_bar_data(cls):
@@ -985,6 +988,7 @@ def before_trading_start(context, data):
'start_date': start_session,
'end_date': period_end + timedelta(days=1)
}] * 2)
equities['symbol'] = ['A', 'B']
with TempDirectory() as tempdir, \
tmp_trading_env(equities=equities) as env:
sim_params = SimulationParameters(
@@ -2813,6 +2817,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
metadata = pd.DataFrame.from_dict(
{
1: {
'symbol': 'SYM',
'start_date': start,
'end_date': start + timedelta(days=6)
},
@@ -2940,6 +2945,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
def test_asset_date_bounds(self):
metadata = pd.DataFrame([{
'symbol': 'SYM',
'start_date': self.sim_params.start_session,
'end_date': '2020-01-01',
}])
@@ -2959,6 +2965,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
algo.run(data_portal)
metadata = pd.DataFrame([{
'symbol': 'SYM',
'start_date': '1989-01-01',
'end_date': '1990-01-01',
}])
@@ -2979,6 +2986,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
algo.run(data_portal)
metadata = pd.DataFrame([{
'symbol': 'SYM',
'start_date': '2020-01-01',
'end_date': '2021-01-01',
}])
+122 -59
View File
@@ -18,6 +18,7 @@ Tests for the zipline.assets package
"""
from contextlib import contextmanager
from datetime import datetime, timedelta
from functools import partial
import pickle
import sys
from types import GetSetDescriptorType
@@ -25,7 +26,6 @@ from unittest import TestCase
import uuid
import warnings
from nose.tools import raises
from nose_parameterized import parameterized
from numpy import full, int32, int64
import pandas as pd
@@ -39,7 +39,6 @@ from zipline.assets import (
Future,
AssetDBWriter,
AssetFinder,
AssetFinderCachedEquities,
)
from zipline.assets.synthetic import (
make_commodity_future_info,
@@ -341,7 +340,6 @@ class TestFuture(WithAssetFinder, ZiplineTestCase):
self.assertIn("tick_size=0.01", reprd)
self.assertIn("multiplier=500", reprd)
@raises(AssertionError)
def test_reduce(self):
assert_equal(
pickle.loads(pickle.dumps(self.future)).to_dict(),
@@ -485,6 +483,97 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))
def test_lookup_symbol_change_ticker(self):
T = partial(pd.Timestamp, tz='utc')
metadata = pd.DataFrame.from_records(
[
# sid 0
{
'symbol': 'A',
'start_date': T('2014-01-01'),
'end_date': T('2014-01-05'),
},
{
'symbol': 'B',
'start_date': T('2014-01-06'),
'end_date': T('2014-01-10'),
},
# sid 1
{
'symbol': 'C',
'start_date': T('2014-01-01'),
'end_date': T('2014-01-05'),
},
{
'symbol': 'A', # claiming the unused symbol 'A'
'start_date': T('2014-01-06'),
'end_date': T('2014-01-10'),
},
],
index=[0, 0, 1, 1],
)
self.write_assets(equities=metadata)
finder = self.asset_finder
# note: these assertions walk forward in time, starting at assertions
# about ownership before the start_date and ending with assertions
# after the end_date; new assertions should be inserted in the correct
# locations
# no one held 'A' before 01
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('A', T('2013-12-31'))
# no one held 'C' before 01
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('C', T('2013-12-31'))
for asof in pd.date_range('2014-01-01', '2014-01-05', tz='utc'):
# from 01 through 05 sid 0 held 'A'
assert_equal(
finder.lookup_symbol('A', asof),
finder.retrieve_asset(0),
msg=str(asof),
)
# from 01 through 05 sid 1 held 'C'
assert_equal(
finder.lookup_symbol('C', asof),
finder.retrieve_asset(1),
msg=str(asof),
)
# no one held 'B' before 06
with self.assertRaises(SymbolNotFound):
finder.lookup_symbol('B', T('2014-01-05'))
# no one held 'C' after 06, however, no one has claimed it yet
# so it still maps to sid 1
assert_equal(
finder.lookup_symbol('C', T('2014-01-07')),
finder.retrieve_asset(1),
)
for asof in pd.date_range('2014-01-06', '2014-01-11', tz='utc'):
# from 06 through 10 sid 0 held 'B'
# we test through the 11th because sid 1 is the last to hold 'B'
# so it should ffill
assert_equal(
finder.lookup_symbol('B', asof),
finder.retrieve_asset(0),
msg=str(asof),
)
# from 06 through 10 sid 1 held 'A'
# we test through the 11th because sid 1 is the last to hold 'A'
# so it should ffill
assert_equal(
finder.lookup_symbol('A', asof),
finder.retrieve_asset(1),
msg=str(asof),
)
def test_lookup_symbol(self):
# Incrementing by two so that start and end dates for each
@@ -519,27 +608,7 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
self.assertEqual(result.symbol, 'EXISTING')
self.assertEqual(result.sid, i)
def test_lookup_symbol_from_multiple_valid(self):
# This test asserts that we resolve conflicts in accordance with the
# following rules when we have multiple assets holding the same symbol
# at the same time:
# If multiple SIDs exist for symbol S at time T, return the candidate
# SID whose start_date is highest. (200 cases)
# If multiple SIDs exist for symbol S at time T, the best candidate
# SIDs share the highest start_date, return the SID with the highest
# end_date. (34 cases)
# It is the opinion of the author (ssanderson) that we should consider
# this malformed input and fail here. But this is the current indended
# behavior of the code, and I accidentally broke it while refactoring.
# These will serve as regression tests until the time comes that we
# decide to enforce this as an error.
# See https://github.com/quantopian/zipline/issues/837 for more
# details.
def test_fail_to_write_overlapping_data(self):
df = pd.DataFrame.from_records(
[
{
@@ -568,22 +637,16 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
]
)
self.write_assets(equities=df)
with self.assertRaises(ValueError) as e:
self.write_assets(equities=df)
def check(expected_sid, date):
result = self.asset_finder.lookup_symbol(
'MULTIPLE', date,
)
self.assertEqual(result.symbol, 'MULTIPLE')
self.assertEqual(result.sid, expected_sid)
# Sids 1 and 2 are eligible here. We should get asset 2 because it
# has the later end_date.
check(2, pd.Timestamp('2010-12-31'))
# Sids 1, 2, and 3 are eligible here. We should get sid 3 because
# it has a later start_date
check(3, pd.Timestamp('2011-01-01'))
self.assertEqual(
str(e.exception),
"Ambiguous ownership of 'MULTIPLE', multiple companies held this"
" ticker over the following ranges:\n"
"[('2010-01-01 00:00:00', '2012-01-01 00:00:00'),"
" ('2011-01-01 00:00:00', '2012-01-01 00:00:00')]",
)
def test_lookup_generic(self):
"""
@@ -1000,14 +1063,6 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
)
class AssetFinderCachedEquitiesTestCase(AssetFinderTestCase):
asset_finder_type = AssetFinderCachedEquities
def write_assets(self, **kwargs):
super(AssetFinderCachedEquitiesTestCase, self).write_assets(**kwargs)
self.asset_finder.rehash_equities()
class TestFutureChain(WithAssetFinder, ZiplineTestCase):
@classmethod
def make_futures_info(cls):
@@ -1259,15 +1314,23 @@ class TestAssetDBVersioning(ZiplineTestCase):
version_table = self.metadata.tables['version_info']
# This should not raise an error
check_version_info(version_table, ASSET_DB_VERSION)
check_version_info(self.engine, version_table, ASSET_DB_VERSION)
# This should fail because the version is too low
with self.assertRaises(AssetDBVersionError):
check_version_info(version_table, ASSET_DB_VERSION - 1)
check_version_info(
self.engine,
version_table,
ASSET_DB_VERSION - 1,
)
# This should fail because the version is too high
with self.assertRaises(AssetDBVersionError):
check_version_info(version_table, ASSET_DB_VERSION + 1)
check_version_info(
self.engine,
version_table,
ASSET_DB_VERSION + 1,
)
def test_write_version(self):
version_table = self.metadata.tables['version_info']
@@ -1279,24 +1342,24 @@ class TestAssetDBVersioning(ZiplineTestCase):
# This should fail because the table has no version info and is,
# therefore, consdered v0
with self.assertRaises(AssetDBVersionError):
check_version_info(version_table, -2)
check_version_info(self.engine, version_table, -2)
# This should not raise an error because the version has been written
write_version_info(version_table, -2)
check_version_info(version_table, -2)
write_version_info(self.engine, version_table, -2)
check_version_info(self.engine, version_table, -2)
# Assert that the version is in the table and correct
self.assertEqual(sa.select((version_table.c.version,)).scalar(), -2)
# Assert that trying to overwrite the version fails
with self.assertRaises(sa.exc.IntegrityError):
write_version_info(version_table, -3)
write_version_info(self.engine, version_table, -3)
def test_finder_checks_version(self):
version_table = self.metadata.tables['version_info']
version_table.delete().execute()
write_version_info(version_table, -2)
check_version_info(version_table, -2)
write_version_info(self.engine, version_table, -2)
check_version_info(self.engine, version_table, -2)
# Assert that trying to build a finder with a bad db raises an error
with self.assertRaises(AssetDBVersionError):
@@ -1304,8 +1367,8 @@ class TestAssetDBVersioning(ZiplineTestCase):
# Change the version number of the db to the correct version
version_table.delete().execute()
write_version_info(version_table, ASSET_DB_VERSION)
check_version_info(version_table, ASSET_DB_VERSION)
write_version_info(self.engine, version_table, ASSET_DB_VERSION)
check_version_info(self.engine, version_table, ASSET_DB_VERSION)
# Now that the versions match, this Finder should succeed
AssetFinder(engine=self.engine)
@@ -1319,7 +1382,7 @@ class TestAssetDBVersioning(ZiplineTestCase):
metadata = sa.MetaData(conn)
metadata.reflect(bind=self.engine)
version_table = metadata.tables['version_info']
check_version_info(version_table, 0)
check_version_info(self.engine, version_table, 0)
# Check some of the v1-to-v0 downgrades
self.assertTrue('futures_contracts' in metadata.tables)
+12 -8
View File
@@ -45,20 +45,24 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendar,
return pd.DataFrame.from_dict(
{
1: {
"start_date": cls.START_DATE,
"end_date": cls.END_DATE + pd.Timedelta(days=1)
'symbol': 'A',
'start_date': cls.START_DATE,
'end_date': cls.END_DATE + pd.Timedelta(days=1)
},
2: {
"start_date": cls.START_DATE,
"end_date": cls.END_DATE + pd.Timedelta(days=1)
'symbol': 'B',
'start_date': cls.START_DATE,
'end_date': cls.END_DATE + pd.Timedelta(days=1)
},
3: {
"start_date": pd.Timestamp('2006-05-26', tz='utc'),
"end_date": pd.Timestamp('2006-08-09', tz='utc')
'symbol': 'C',
'start_date': pd.Timestamp('2006-05-26', tz='utc'),
'end_date': pd.Timestamp('2006-08-09', tz='utc')
},
4: {
"start_date": cls.START_DATE,
"end_date": cls.END_DATE + pd.Timedelta(days=1)
'symbol': 'D',
'start_date': cls.START_DATE,
'end_date': cls.END_DATE + pd.Timedelta(days=1)
},
},
orient='index',
+3
View File
@@ -348,6 +348,9 @@ def bundles():
"""List all of the available data bundles.
"""
for bundle in sorted(bundles_module.bundles.keys()):
if bundle.startswith('.'):
# hide the test data
continue
try:
ingestions = sorted(
(str(bundles_module.from_bundle_ingest_dirname(ing))
-2
View File
@@ -23,7 +23,6 @@ from ._assets import (
from .assets import (
AssetFinder,
AssetConvertible,
AssetFinderCachedEquities
)
from .asset_db_schema import ASSET_DB_VERSION
from .asset_writer import AssetDBWriter
@@ -35,7 +34,6 @@ __all__ = [
'Equity',
'Future',
'AssetFinder',
'AssetFinderCachedEquities',
'AssetConvertible',
'make_asset_array',
'CACHE_FILE_TEMPLATE'
+74 -51
View File
@@ -58,26 +58,35 @@ cdef class Asset:
cdef readonly object exchange
def __cinit__(self,
int sid, # sid is required
object symbol="",
object asset_name="",
object start_date=None,
object end_date=None,
object first_traded=None,
object auto_close_date=None,
object exchange="",
*args,
**kwargs):
_kwargnames = frozenset({
'sid',
'symbol',
'asset_name',
'start_date',
'end_date',
'first_traded',
'auto_close_date',
'exchange',
})
self.sid = sid
self.sid_hash = hash(sid)
self.symbol = symbol
self.asset_name = asset_name
self.exchange = exchange
self.start_date = start_date
self.end_date = end_date
self.first_traded = first_traded
def __init__(self,
int sid, # sid is required
object symbol="",
object asset_name="",
object start_date=None,
object end_date=None,
object first_traded=None,
object auto_close_date=None,
object exchange=""):
self.sid = sid
self.sid_hash = hash(sid)
self.symbol = symbol
self.asset_name = asset_name
self.exchange = exchange
self.start_date = start_date
self.end_date = end_date
self.first_traded = first_traded
self.auto_close_date = auto_close_date
def __int__(self):
@@ -127,9 +136,9 @@ cdef class Asset:
def __str__(self):
if self.symbol:
return 'Asset(%d [%s])' % (self.sid, self.symbol)
return '%s(%d [%s])' % (type(self).__name__, self.sid, self.symbol)
else:
return 'Asset(%d)' % self.sid
return '%s(%d)' % (type(self).__name__, self.sid)
def __repr__(self):
attrs = ('symbol', 'asset_name', 'exchange',
@@ -213,12 +222,6 @@ cdef class Asset:
cdef class Equity(Asset):
def __str__(self):
if self.symbol:
return 'Equity(%d [%s])' % (self.sid, self.symbol)
else:
return 'Equity(%d)' % self.sid
def __repr__(self):
attrs = ('symbol', 'asset_name', 'exchange',
'start_date', 'end_date', 'first_traded', 'auto_close_date')
@@ -270,26 +273,52 @@ cdef class Future(Asset):
cdef readonly object tick_size
cdef readonly float multiplier
def __cinit__(self,
int sid, # sid is required
object symbol="",
object root_symbol="",
object asset_name="",
object start_date=None,
object end_date=None,
object notice_date=None,
object expiration_date=None,
object auto_close_date=None,
object first_traded=None,
object exchange="",
object tick_size="",
float multiplier=1):
_kwargnames = frozenset({
'sid',
'symbol',
'root_symbol',
'asset_name',
'start_date',
'end_date',
'notice_date',
'expiration_date',
'auto_close_date',
'first_traded',
'exchange',
'tick_size',
'multiplier',
})
self.root_symbol = root_symbol
self.notice_date = notice_date
def __init__(self,
int sid, # sid is required
object symbol="",
object root_symbol="",
object asset_name="",
object start_date=None,
object end_date=None,
object notice_date=None,
object expiration_date=None,
object auto_close_date=None,
object first_traded=None,
object exchange="",
object tick_size="",
float multiplier=1.0):
super().__init__(
sid,
symbol=symbol,
asset_name=asset_name,
start_date=start_date,
end_date=end_date,
first_traded=first_traded,
auto_close_date=auto_close_date,
exchange=exchange,
)
self.root_symbol = root_symbol
self.notice_date = notice_date
self.expiration_date = expiration_date
self.tick_size = tick_size
self.multiplier = multiplier
self.tick_size = tick_size
self.multiplier = multiplier
if auto_close_date is None:
if notice_date is None:
@@ -299,12 +328,6 @@ cdef class Future(Asset):
else:
self.auto_close_date = min(notice_date, expiration_date)
def __str__(self):
if self.symbol:
return 'Future(%d [%s])' % (self.sid, self.symbol)
else:
return 'Future(%d)' % self.sid
def __repr__(self):
attrs = ('symbol', 'root_symbol', 'asset_name', 'exchange',
'start_date', 'end_date', 'first_traded', 'notice_date',
+72 -3
View File
@@ -50,7 +50,7 @@ def downgrade(engine, desired_version):
# Execute the downgrades in order
for downgrade_key in downgrade_keys:
_downgrade_methods[downgrade_key](op, version_info_table)
_downgrade_methods[downgrade_key](op, engine, version_info_table)
# Re-enable foreign keys
_pragma_foreign_keys(conn, True)
@@ -96,10 +96,10 @@ def downgrades(src):
@do(op.setitem(_downgrade_methods, destination))
@wraps(f)
def wrapper(op, version_info_table):
def wrapper(op, engine, version_info_table):
version_info_table.delete().execute() # clear the version
f(op)
write_version_info(version_info_table, destination)
write_version_info(engine, version_info_table, destination)
return wrapper
return _
@@ -206,3 +206,72 @@ def _downgrade_v3(op):
'equities',
['fuzzy_symbol'],
)
@downgrades(4)
def _downgrade_v4(op):
op.create_table(
'_new_equities',
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('symbol', sa.Text),
sa.Column('company_symbol', sa.Text, index=True),
sa.Column('share_class_symbol', sa.Text),
sa.Column('fuzzy_symbol', sa.Text, index=True),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer),
sa.Column('auto_close_date', sa.Integer),
sa.Column('exchange', sa.Text),
)
op.execute(
"""
insert into _new_equities
select
equities.sid as sid,
sym.symbol as symbol,
sym.company_symbol as company_symbol,
sym.share_class_symbol as share_class_symbol,
sym.company_symbol || sym.share_class_symbol as fuzzy_symbol,
equities.asset_name as asset_name,
equities.start_date as start_date,
equities.end_date as end_date,
equities.first_traded as first_traded,
equities.auto_close_date as auto_close_date,
equities.exchange as exchange
from
equities
inner join
(select
*
from
equity_symbol_mappings
group by
equity_symbol_mappings.sid
order by
equity_symbol_mappings.end_date desc) sym
on
equities.sid == sym.sid
""",
)
op.drop_table('equity_symbol_mappings')
op.drop_table('equities')
op.rename_table('_new_equities', 'equities')
# we need to make sure the indicies have the proper names after the rename
op.create_index(
'ix_equities_company_symbol',
'equities',
['company_symbol'],
)
op.create_index(
'ix_equities_fuzzy_symbol',
'equities',
['fuzzy_symbol'],
)
+155 -143
View File
@@ -6,26 +6,14 @@ import sqlalchemy as sa
# assets database
# NOTE: When upgrading this remember to add a downgrade in:
# .asset_db_migrations
ASSET_DB_VERSION = 3
def generate_asset_db_metadata(bind=None):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
metadata = sa.MetaData(bind=bind)
_version_table_schema(metadata)
_equities_table_schema(metadata)
_futures_exchanges_schema(metadata)
_futures_root_symbols_schema(metadata)
_futures_contracts_schema(metadata)
_asset_router_schema(metadata)
return metadata
ASSET_DB_VERSION = 4
# A frozenset of the names of all tables in the assets db
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
asset_db_table_names = frozenset({
'asset_router',
'equities',
'equity_symbol_mappings',
'futures_contracts',
'futures_exchanges',
'futures_root_symbols',
@@ -33,139 +21,163 @@ asset_db_table_names = frozenset({
})
def _equities_table_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'equities',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('symbol', sa.Text),
sa.Column('company_symbol', sa.Text, index=True),
sa.Column('share_class_symbol', sa.Text),
sa.Column('fuzzy_symbol', sa.Text, index=True),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer),
sa.Column('auto_close_date', sa.Integer),
sa.Column('exchange', sa.Text),
)
metadata = sa.MetaData()
equities = sa.Table(
'equities',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer),
sa.Column('auto_close_date', sa.Integer),
sa.Column('exchange', sa.Text),
)
def _futures_exchanges_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'futures_exchanges',
metadata,
sa.Column(
'exchange',
sa.Text,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('timezone', sa.Text),
)
equity_symbol_mappings = sa.Table(
'equity_symbol_mappings',
metadata,
sa.Column(
'id',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column(
'sid',
sa.Integer,
sa.ForeignKey(equities.c.sid),
nullable=False,
index=True,
),
sa.Column(
'symbol',
sa.Text,
nullable=False,
),
sa.Column(
'company_symbol',
sa.Text,
index=True,
),
sa.Column(
'share_class_symbol',
sa.Text,
),
sa.Column(
'start_date',
sa.Integer,
nullable=False,
),
sa.Column(
'end_date',
sa.Integer,
nullable=False,
),
)
futures_exchanges = sa.Table(
'futures_exchanges',
metadata,
sa.Column(
'exchange',
sa.Text,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('timezone', sa.Text),
)
def _futures_root_symbols_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'futures_root_symbols',
metadata,
sa.Column(
'root_symbol',
sa.Text,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('root_symbol_id', sa.Integer),
sa.Column('sector', sa.Text),
sa.Column('description', sa.Text),
sa.Column(
'exchange',
sa.Text,
sa.ForeignKey('futures_exchanges.exchange'),
),
)
futures_root_symbols = sa.Table(
'futures_root_symbols',
metadata,
sa.Column(
'root_symbol',
sa.Text,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('root_symbol_id', sa.Integer),
sa.Column('sector', sa.Text),
sa.Column('description', sa.Text),
sa.Column(
'exchange',
sa.Text,
sa.ForeignKey('futures_exchanges.exchange'),
),
)
futures_contracts = sa.Table(
'futures_contracts',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('symbol', sa.Text, unique=True, index=True),
sa.Column(
'root_symbol',
sa.Text,
sa.ForeignKey('futures_root_symbols.root_symbol'),
index=True
),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer),
sa.Column(
'exchange',
sa.Text,
sa.ForeignKey('futures_exchanges.exchange'),
),
sa.Column('notice_date', sa.Integer, nullable=False),
sa.Column('expiration_date', sa.Integer, nullable=False),
sa.Column('auto_close_date', sa.Integer, nullable=False),
sa.Column('multiplier', sa.Float),
sa.Column('tick_size', sa.Float),
)
def _futures_contracts_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'futures_contracts',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column('symbol', sa.Text, unique=True, index=True),
sa.Column(
'root_symbol',
sa.Text,
sa.ForeignKey('futures_root_symbols.root_symbol'),
index=True
),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer),
sa.Column(
'exchange',
sa.Text,
sa.ForeignKey('futures_exchanges.exchange'),
),
sa.Column('notice_date', sa.Integer, nullable=False),
sa.Column('expiration_date', sa.Integer, nullable=False),
sa.Column('auto_close_date', sa.Integer, nullable=False),
sa.Column('multiplier', sa.Float),
sa.Column('tick_size', sa.Float),
)
asset_router = sa.Table(
'asset_router',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True),
sa.Column('asset_type', sa.Text),
)
def _asset_router_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'asset_router',
metadata,
sa.Column(
'sid',
sa.Integer,
unique=True,
nullable=False,
primary_key=True),
sa.Column('asset_type', sa.Text),
)
def _version_table_schema(metadata):
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
return sa.Table(
'version_info',
metadata,
sa.Column(
'id',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column(
'version',
sa.Integer,
unique=True,
nullable=False,
),
# This constraint ensures a single entry in this table
sa.CheckConstraint('id <= 1'),
)
version_info = sa.Table(
'version_info',
metadata,
sa.Column(
'id',
sa.Integer,
unique=True,
nullable=False,
primary_key=True,
),
sa.Column(
'version',
sa.Integer,
unique=True,
nullable=False,
),
# This constraint ensures a single entry in this table
sa.CheckConstraint('id <= 1'),
)
+157 -60
View File
@@ -23,16 +23,40 @@ from toolz import first
from zipline.errors import AssetDBVersionError
from zipline.assets.asset_db_schema import (
generate_asset_db_metadata,
asset_db_table_names,
ASSET_DB_VERSION,
asset_db_table_names,
asset_router,
equities as equities_table,
equity_symbol_mappings,
futures_contracts as futures_contracts_table,
futures_exchanges,
futures_root_symbols,
metadata,
version_info,
)
from zipline.utils.range import from_tuple, intersecting_ranges
# Define a namedtuple for use with the load_data and _load_data methods
AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
AssetData = namedtuple(
'AssetData', (
'equities',
'equities_mappings',
'futures',
'exchanges',
'root_symbols',
),
)
SQLITE_MAX_VARIABLE_NUMBER = 999
symbol_columns = frozenset({
'symbol',
'company_symbol',
'share_class_symbol',
})
mapping_columns = symbol_columns | {'start_date', 'end_date'}
# Default values for the equities DataFrame
_equities_defaults = {
'symbol': None,
@@ -74,7 +98,7 @@ _root_symbols_defaults = {
}
# Fuzzy symbol delimiters that may break up a company symbol and share class
_delimited_symbol_delimiter_regex = r'[./\-_]'
_delimited_symbol_delimiters_regex = re.compile(r'[./\-_]')
_delimited_symbol_default_triggers = frozenset({np.nan, None, ''})
@@ -91,16 +115,22 @@ def split_delimited_symbol(symbol):
Returns
-------
( str, str , str )
A tuple of ( company_symbol, share_class_symbol, fuzzy_symbol)
company_symbol : str
The company part of the symbol.
share_class_symbol : str
The share class part of a symbol.
"""
# return blank strings for any bad fuzzy symbols, like NaN or None
if symbol in _delimited_symbol_default_triggers:
return ('', '', '')
return '', ''
split_list = re.split(pattern=_delimited_symbol_delimiter_regex,
string=symbol,
maxsplit=1)
symbol = symbol.upper()
split_list = re.split(
pattern=_delimited_symbol_delimiters_regex,
string=symbol,
maxsplit=1,
)
# Break the list up in to its two components, the company symbol and the
# share class symbol
@@ -110,12 +140,7 @@ def split_delimited_symbol(symbol):
else:
share_class_symbol = ''
# Strip all fuzzy characters from the symbol to get the fuzzy symbol
fuzzy_symbol = re.sub(pattern=_delimited_symbol_delimiter_regex,
repl='',
string=symbol)
return (company_symbol, share_class_symbol, fuzzy_symbol)
return company_symbol, share_class_symbol
def _generate_output_dataframe(data_subset, defaults):
@@ -151,19 +176,70 @@ def _generate_output_dataframe(data_subset, defaults):
# Get those columns which we need but
# for which no data has been supplied.
need = desired_cols - cols
for col in desired_cols - cols:
# write the default value for any missing columns
data_subset[col] = defaults[col]
# Combine the users supplied data with our required columns.
output = pd.concat(
(data_subset, pd.DataFrame(
{k: defaults[k] for k in need},
data_subset.index,
)),
axis=1,
copy=False
return data_subset
def _check_asset_group(group):
for colname in set(group.columns) - mapping_columns:
col = group[colname]
if len(col.unique()) != 1:
raise ValueError(
'All values must be the same for the %s column' % colname,
)
row = group.iloc[0]
row.start_date = group.start_date.min()
row.end_date = group.end_date.max()
row.drop(list(symbol_columns), inplace=True)
return row
def _format_range(r):
return (
str(pd.Timestamp(r.start, unit='ns')),
str(pd.Timestamp(r.stop, unit='ns')),
)
return output
def _split_symbol_mappings(df):
"""Split out the symbol: sid mappings from the raw data.
Parameters
----------
df : pd.DataFrame
The dataframe with multiple rows for each symbol: sid pair.
Returns
-------
asset_info : pd.DataFrame
The asset info with one row per asset.
symbol_mappings : pd.DataFrame
The dataframe of just symbol: sid mappings. The index will be
the sid, then there will be three columns: symbol, start_date, and
end_date.
"""
mappings = df[list(mapping_columns)]
for symbol in mappings.symbol.unique():
persymbol = mappings[mappings.symbol == symbol]
intersections = list(intersecting_ranges(
map(from_tuple, zip(persymbol.start_date, persymbol.end_date)),
))
if intersections:
raise ValueError(
'Ambiguous ownership of %r, multiple companies held this'
' ticker over the following ranges:\n%s' % (
symbol,
list(map(_format_range, intersections)),
),
)
return (
df.groupby(level=0).apply(_check_asset_group),
df[list(mapping_columns)],
)
def _dt_to_epoch_ns(dt_series):
@@ -187,12 +263,14 @@ def _dt_to_epoch_ns(dt_series):
return index.view(np.int64)
def check_version_info(version_table, expected_version):
def check_version_info(conn, version_table, expected_version):
"""
Checks for a version value in the version table.
Parameters
----------
conn : sa.Connection
The connection to use to perform the check.
version_table : sa.Table
The version table of the asset database
expected_version : int
@@ -205,7 +283,9 @@ def check_version_info(version_table, expected_version):
"""
# Read the version out of the table
version_from_table = sa.select((version_table.c.version,)).scalar()
version_from_table = conn.execute(
sa.select((version_table.c.version,)),
).scalar()
# A db without a version is considered v0
if version_from_table is None:
@@ -217,19 +297,21 @@ def check_version_info(version_table, expected_version):
expected_version=expected_version)
def write_version_info(version_table, version_value):
def write_version_info(conn, version_table, version_value):
"""
Inserts the version value in to the version table.
Parameters
----------
conn : sa.Connection
The connection to use to execute the insert.
version_table : sa.Table
The version table of the asset database
version_value : int
The version to write in to the database
"""
sa.insert(version_table, values={'version': version_value}).execute()
conn.execute(sa.insert(version_table, values={'version': version_value}))
class _empty(object):
@@ -266,9 +348,6 @@ class AssetDBWriter(object):
symbol : str
The ticker symbol for this equity.
fuzzy_symbol : str, optional
The fuzzy symbol for this equity. This is the symbol
without any delimiting characters like '.' or '_'.
asset_name : str
The full name for this asset.
start_date : datetime
@@ -348,7 +427,7 @@ class AssetDBWriter(object):
"""
with self.engine.begin() as txn:
# Create SQL tables if they do not exist.
metadata = self.init_db(txn)
self.init_db(txn)
# Get the data to add to SQL.
data = self._load_data(
@@ -359,51 +438,74 @@ class AssetDBWriter(object):
)
# Write the data to SQL.
self._write_df_to_table(
metadata.tables['futures_exchanges'],
futures_exchanges,
data.exchanges,
txn,
chunk_size,
)
self._write_df_to_table(
metadata.tables['futures_root_symbols'],
futures_root_symbols,
data.root_symbols,
txn,
chunk_size,
)
asset_router = metadata.tables['asset_router']
self._write_assets(
asset_router,
metadata.tables['futures_contracts'],
'future',
data.futures,
txn,
chunk_size,
)
self._write_assets(
asset_router,
metadata.tables['equities'],
'equity',
data.equities,
txn,
chunk_size,
mapping_data=data.equities_mappings,
)
def _write_df_to_table(self, tbl, df, txn, chunk_size):
def _write_df_to_table(self, tbl, df, txn, chunk_size, idx_label=None):
df.to_sql(
tbl.name,
txn.connection,
index_label=first(tbl.primary_key.columns).name,
index_label=(
idx_label
if idx_label is not None else
first(tbl.primary_key.columns).name
),
if_exists='append',
chunksize=chunk_size,
)
def _write_assets(self,
asset_router,
tbl,
asset_type,
assets,
txn,
chunk_size):
chunk_size,
mapping_data=None):
if asset_type == 'future':
tbl = futures_contracts_table
if mapping_data is not None:
raise TypeError('no mapping data expected for futures')
elif asset_type == 'equity':
tbl = equities_table
if mapping_data is None:
raise TypeError('mapping data required for equities')
# write the symbol mapping data.
self._write_df_to_table(
equity_symbol_mappings,
mapping_data,
txn,
chunk_size,
idx_label='sid',
)
else:
raise ValueError(
"asset_type must be in {'future', 'equity'}, got: %s" %
asset_type,
)
self._write_df_to_table(tbl, assets, txn, chunk_size)
pd.DataFrame({
@@ -456,17 +558,14 @@ class AssetDBWriter(object):
txn = stack.enter_context(self.engine.begin())
tables_already_exist = self._all_tables_present(txn)
metadata = generate_asset_db_metadata(bind=txn)
# Create the SQL tables if they do not already exist.
metadata.create_all(checkfirst=True)
metadata.create_all(txn, checkfirst=True)
version_info = metadata.tables['version_info']
if tables_already_exist:
check_version_info(version_info, ASSET_DB_VERSION)
check_version_info(txn, version_info, ASSET_DB_VERSION)
else:
write_version_info(version_info, ASSET_DB_VERSION)
return metadata
write_version_info(txn, version_info, ASSET_DB_VERSION)
def _normalize_equities(self, equities):
# HACK: If 'company_name' is provided, map it to asset_name
@@ -487,16 +586,13 @@ class AssetDBWriter(object):
tuple_series = equities_output['symbol'].apply(split_delimited_symbol)
split_symbols = pd.DataFrame(
tuple_series.tolist(),
columns=['company_symbol', 'share_class_symbol', 'fuzzy_symbol'],
columns=['company_symbol', 'share_class_symbol'],
index=tuple_series.index
)
equities_output = equities_output.join(split_symbols)
equities_output = pd.concat((equities_output, split_symbols), axis=1)
# Upper-case all symbol data
for col in ('symbol',
'company_symbol',
'share_class_symbol',
'fuzzy_symbol'):
for col in symbol_columns:
equities_output[col] = equities_output[col].str.upper()
# Convert date columns to UNIX Epoch integers (nanoseconds)
@@ -506,7 +602,7 @@ class AssetDBWriter(object):
'auto_close_date'):
equities_output[col] = _dt_to_epoch_ns(equities_output[col])
return equities_output
return _split_symbol_mappings(equities_output)
def _normalize_futures(self, futures):
futures_output = _generate_output_dataframe(
@@ -541,7 +637,7 @@ class AssetDBWriter(object):
if id_col in df.columns:
df.set_index(id_col, inplace=True)
equities_output = self._normalize_equities(equities)
equities_output, equities_mappings = self._normalize_equities(equities)
futures_output = self._normalize_futures(futures)
exchanges_output = _generate_output_dataframe(
@@ -556,6 +652,7 @@ class AssetDBWriter(object):
return AssetData(
equities=equities_output,
equities_mappings=equities_mappings,
futures=futures_output,
exchanges=exchanges_output,
root_symbols=root_symbols_output,
+305 -261
View File
@@ -13,16 +13,18 @@
# limitations under the License.
from abc import ABCMeta
from collections import namedtuple
from numbers import Integral
from operator import itemgetter
from operator import itemgetter, attrgetter
from logbook import Logger
import numpy as np
import pandas as pd
from pandas import isnull
from six import with_metaclass, string_types, viewkeys
from six.moves import map as imap
from six import with_metaclass, string_types, viewkeys, iteritems
import sqlalchemy as sa
from toolz import merge, compose, valmap, sliding_window, concatv, curry
from toolz.curried import operator as op
from zipline.errors import (
EquitiesNotFound,
@@ -33,18 +35,20 @@ from zipline.errors import (
SidsNotFound,
SymbolNotFound,
)
from zipline.assets import (
from . import (
Asset, Equity, Future,
)
from zipline.assets.asset_writer import (
from .asset_writer import (
check_version_info,
split_delimited_symbol,
asset_db_table_names,
symbol_columns,
)
from zipline.assets.asset_db_schema import (
from .asset_db_schema import (
ASSET_DB_VERSION
)
from zipline.utils.control_flow import invert
from zipline.utils.memoize import lazyval
from zipline.utils.sqlite_utils import group_into_chunks
log = Logger('assets.py')
@@ -67,12 +71,38 @@ _asset_timestamp_fields = frozenset({
'auto_close_date',
})
SymbolOwnership = namedtuple('SymbolOwnership', 'start end sid symbol')
@curry
def _filter_kwargs(names, dict_):
"""Filter out kwargs from a dictionary.
Parameters
----------
names : set[str]
The names to select from ``dict_``.
dict_ : dict[str, any]
The dictionary to select from.
Returns
-------
kwargs : dict[str, any]
``dict_`` where the keys intersect with ``names`` and the values are
not None.
"""
return {k: v for k, v in dict_.items() if k in names and v is not None}
_filter_future_kwargs = _filter_kwargs(Future._kwargnames)
_filter_equity_kwargs = _filter_kwargs(Equity._kwargnames)
def _convert_asset_timestamp_fields(dict_):
"""
Takes in a dict of Asset init args and converts dates to pd.Timestamps
"""
for key in (_asset_timestamp_fields & viewkeys(dict_)):
for key in _asset_timestamp_fields & viewkeys(dict_):
value = pd.Timestamp(dict_[key], tz='UTC')
dict_[key] = None if isnull(value) else value
return dict_
@@ -101,17 +131,18 @@ class AssetFinder(object):
PERSISTENT_TOKEN = "<AssetFinder>"
def __init__(self, engine):
if isinstance(engine, string_types):
engine = sa.create_engine('sqlite:///' + engine)
self.engine = engine
self.engine = engine = (
sa.create_engine('sqlite:///' + engine)
if isinstance(engine, string_types) else
engine
)
metadata = sa.MetaData(bind=engine)
metadata.reflect(only=asset_db_table_names)
for table_name in asset_db_table_names:
setattr(self, table_name, metadata.tables[table_name])
# Check the version info of the db for compatibility
check_version_info(self.version_info, ASSET_DB_VERSION)
check_version_info(engine, self.version_info, ASSET_DB_VERSION)
# Cache for lookup of assets by sid, the objects in the asset lookup
# may be shared with the results from equity and future lookup caches.
@@ -137,6 +168,79 @@ class AssetFinder(object):
# should be calling this.
for cache in self._caches:
cache.clear()
self.reload_symbol_maps()
def reload_symbol_maps(self):
"""Clear the in memory symbol lookup maps.
This will make any changes to the underlying db available to the
symbol maps.
"""
# clear the lazyval caches, the next access will requery
try:
del type(self).symbol_ownership_map[self]
except KeyError:
pass
try:
del type(self).fuzzy_symbol_ownership_map[self]
except KeyError:
pass
@lazyval
def symbol_ownership_map(self):
rows = sa.select(self.equity_symbol_mappings.c).execute().fetchall()
mappings = {}
for row in rows:
mappings.setdefault(
(row.company_symbol, row.share_class_symbol),
[],
).append(
SymbolOwnership(
pd.Timestamp(row.start_date, unit='ns', tz='utc'),
pd.Timestamp(row.end_date, unit='ns', tz='utc'),
row.sid,
row.symbol,
),
)
return valmap(
lambda v: tuple(
SymbolOwnership(
a.start,
b.start,
a.sid,
a.symbol,
) for a, b in sliding_window(
2,
concatv(
sorted(v),
# concat with a fake ownership object to make the last
# end date be max timestamp
[SymbolOwnership(
pd.Timestamp.max.tz_localize('utc'),
None,
None,
None,
)],
),
)
),
mappings,
factory=lambda: mappings,
)
@lazyval
def fuzzy_symbol_ownership_map(self):
fuzzy_mappings = {}
for (cs, scs), owners in iteritems(self.symbol_ownership_map):
fuzzy_owners = fuzzy_mappings.setdefault(
cs + scs,
[],
)
fuzzy_owners.extend(owners)
fuzzy_owners.sort()
return fuzzy_mappings
def lookup_asset_types(self, sids):
"""
@@ -326,6 +430,50 @@ class AssetFinder(object):
def _select_asset_by_symbol(asset_tbl, symbol):
return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
def _lookup_most_recent_symbols(self, sids):
symbol_cols = self.equity_symbol_mappings.c
symbols = {
row.sid: {c: row[c] for c in symbol_columns}
for row in self.engine.execute(
sa.select(
(symbol_cols.sid,) +
tuple(map(op.getitem(symbol_cols), symbol_columns)),
).where(
symbol_cols.sid.in_(map(int, sids)),
).order_by(
symbol_cols.end_date.desc(),
).group_by(
symbol_cols.sid,
)
).fetchall()
}
if len(symbols) != len(sids):
raise EquitiesNotFound(
sids=set(sids) - set(symbols),
plural=True,
)
return symbols
def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities):
if not sids:
return
if querying_equities:
def mkdict(row,
symbols=self._lookup_most_recent_symbols(sids)):
return merge(row, symbols[row['sid']])
else:
mkdict = dict
for assets in group_into_chunks(sids):
# Load misses from the db.
query = self._select_assets_by_sid(asset_tbl, assets)
for row in query.execute().fetchall():
yield _convert_asset_timestamp_fields(mkdict(row))
def _retrieve_assets(self, sids, asset_tbl, asset_type):
"""
Internal function for loading assets from a table.
@@ -354,14 +502,18 @@ class AssetFinder(object):
cache = self._asset_cache
hits = {}
for assets in group_into_chunks(sids):
# Load misses from the db.
query = self._select_assets_by_sid(asset_tbl, assets)
querying_equities = issubclass(asset_type, Equity)
filter_kwargs = (
_filter_equity_kwargs
if querying_equities else
_filter_future_kwargs
)
for row in imap(dict, query.execute().fetchall()):
asset = asset_type(**_convert_asset_timestamp_fields(row))
sid = asset.sid
hits[sid] = cache[sid] = asset
rows = self._retrieve_asset_dicts(sids, asset_tbl, querying_equities)
for row in rows:
sid = row['sid']
asset = asset_type(**filter_kwargs(row))
hits[sid] = cache[sid] = asset
# If we get here, it means something in our code thought that a
# particular sid was an equity/future and called this function with a
@@ -369,166 +521,152 @@ class AssetFinder(object):
# an error in our code, not a user-input error.
misses = tuple(set(sids) - viewkeys(hits))
if misses:
if asset_type == Equity:
if querying_equities:
raise EquitiesNotFound(sids=misses)
else:
raise FutureContractsNotFound(sids=misses)
return hits
def _get_fuzzy_candidates(self, fuzzy_symbol):
candidates = sa.select(
(self.equities.c.sid,)
).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc()
).execute().fetchall()
return candidates
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.fuzzy_symbol == fuzzy_symbol,
self.equities.c.start_date <= ad_value,
self.equities.c.end_date >= ad_value
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_split_candidates_in_range(self,
company_symbol,
share_class_symbol,
ad_value):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol == share_class_symbol,
self.equities.c.start_date <= ad_value,
self.equities.c.end_date >= ad_value
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_split_candidates(self, company_symbol, share_class_symbol):
candidates = sa.select(
(self.equities.c.sid,)
).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol == share_class_symbol
)
).order_by(
self.equities.c.start_date.desc(),
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _resolve_no_matching_candidates(self,
company_symbol,
share_class_symbol,
ad_value):
candidates = sa.select((self.equities.c.sid,)).where(
sa.and_(
self.equities.c.company_symbol == company_symbol,
self.equities.c.share_class_symbol ==
share_class_symbol,
self.equities.c.start_date <= ad_value),
).order_by(
self.equities.c.end_date.desc(),
).execute().fetchall()
return candidates
def _get_best_candidate(self, candidates):
return self._retrieve_equity(candidates[0]['sid'])
def _get_equities_from_candidates(self, candidates):
sids = map(itemgetter('sid'), candidates)
results = self.retrieve_equities(sids)
return [results[sid] for sid in sids]
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
"""
Return matching Equity of name symbol in database.
If multiple Equities are found and as_of_date is not set,
raises MultipleSymbolsFound.
If no Equity was active at as_of_date raises SymbolNotFound.
"""
company_symbol, share_class_symbol, fuzzy_symbol = \
split_delimited_symbol(symbol)
if as_of_date:
# Format inputs
as_of_date = pd.Timestamp(as_of_date).normalize()
ad_value = as_of_date.value
if fuzzy:
# Search for a single exact match on the fuzzy column
candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
ad_value)
# If exactly one SID exists for fuzzy_symbol, return that sid
if len(candidates) == 1:
return self._get_best_candidate(candidates)
# Search for exact matches of the split-up company_symbol and
# share_class_symbol
candidates = self._get_split_candidates_in_range(
def _lookup_symbol_strict(self, symbol, as_of_date):
# split the symbol into the components, if there are no
# company/share class parts then share_class_symbol will be empty
company_symbol, share_class_symbol = split_delimited_symbol(symbol)
try:
owners = self.symbol_ownership_map[
company_symbol,
share_class_symbol,
ad_value
)
# If exactly one SID exists for symbol, return that symbol
# If multiple SIDs exist for symbol, return latest start_date with
# end_date as a tie-breaker
if candidates:
return self._get_best_candidate(candidates)
# If no SID exists for symbol, return SID with the
# highest-but-not-over end_date
elif not candidates:
candidates = self._resolve_no_matching_candidates(
company_symbol,
share_class_symbol,
ad_value
)
if candidates:
return self._get_best_candidate(candidates)
]
assert owners, 'empty owners list for %r' % symbol
except KeyError:
# no equity has ever held this symbol
raise SymbolNotFound(symbol=symbol)
else:
# If this is a fuzzy look-up, check if there is exactly one match
# for the fuzzy symbol
if fuzzy:
candidates = self._get_fuzzy_candidates(fuzzy_symbol)
if len(candidates) == 1:
return self._get_best_candidate(candidates)
candidates = self._get_split_candidates(company_symbol,
share_class_symbol)
if len(candidates) == 1:
return self._get_best_candidate(candidates)
elif not candidates:
raise SymbolNotFound(symbol=symbol)
else:
if not as_of_date:
if len(owners) > 1:
# more than one equity has held this ticker, this is ambigious
# without the date
raise MultipleSymbolsFound(
symbol=symbol,
options=self._get_equities_from_candidates(candidates)
options=set(map(
compose(self.retrieve_asset, attrgetter('sid')),
owners,
)),
)
# exactly one equity has ever held this symbol, we may resolve
# without the date
return self.retrieve_asset(owners[0].sid)
for start, end, sid, _ in owners:
if start <= as_of_date < end:
# find the equity that owned it on the given asof date
return self.retrieve_asset(sid)
# no equity held the ticker on the given asof date
raise SymbolNotFound(symbol=symbol)
def _lookup_symbol_fuzzy(self, symbol, as_of_date):
symbol = symbol.upper()
company_symbol, share_class_symbol = split_delimited_symbol(symbol)
try:
owners = self.fuzzy_symbol_ownership_map[
company_symbol + share_class_symbol
]
assert owners, 'empty owners list for %r' % symbol
except KeyError:
# no equity has ever held a symbol matching the fuzzy symbol
raise SymbolNotFound(symbol=symbol)
if not as_of_date:
if len(owners) == 1:
# only one valid match
return self.retrieve_asset(owners[0].sid)
options = []
for _, _, sid, sym in owners:
if sym == symbol:
# there are multiple options, look for exact matches
options.append(self.retrieve_asset(sid))
if len(options) == 1:
# there was only one exact match
return options[0]
# there are more than one exact match for this fuzzy symbol
raise MultipleSymbolsFound(
symbol=symbol,
options=set(options),
)
options = []
for start, end, sid, sym in owners:
if start <= as_of_date < end:
# see which fuzzy symbols were owned on the asof date.
options.append((sid, sym))
if not options:
# no equity owned the fuzzy symbol on the date requested
SymbolNotFound(symbol=symbol)
if len(options) == 1:
# there was only one owner, return it
return self.retrieve_asset(options[0][0])
for sid, sym in options:
if sym == symbol:
# look for an exact match on the asof date
return self.retrieve_asset(sid)
# multiple equities held tickers matching the fuzzy ticker but
# there are no exact matches
raise MultipleSymbolsFound(
symbol=symbol,
options=set(map(
compose(self.retrieve_asset, itemgetter(0)),
options,
)),
)
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
"""Lookup an equity by symbol.
Parameters
----------
symbol : str
The ticker symbol to resolve.
as_of_date : datetime or None
Look up the last owner of this symbol as of this datetime.
If ``as_of_date`` is None, then this can only resolve the equity
if exactly one equity has ever owned the ticker.
fuzzy : bool, optional
Should fuzzy symbol matching be used? Fuzzy symbol matching
attempts to resolve differences in representations for
shareclasses. For example, some people may represent the ``A``
shareclass of ``BRK`` as ``BRK.A``, where others could write
``BRK_A``.
Returns
-------
equity : Equity
The equity that held ``symbol`` on the given ``as_of_date``, or the
only equity to hold ``symbol`` if ``as_of_date`` is None.
Raises
------
SymbolNotFound
Raised when no equity has ever held the given symbol.
MultipleSymbolsFound
Raised when no ``as_of_date`` is given and more than one equity
has held ``symbol``. This is also raised when ``fuzzy=True`` and
there are multiple candidates for the given ``symbol`` on the
``as_of_date``.
"""
if fuzzy:
return self._lookup_symbol_fuzzy(symbol, as_of_date)
return self._lookup_symbol_strict(symbol, as_of_date)
def lookup_future_symbol(self, symbol):
""" Return the Future object for a given symbol.
"""Lookup a future contract by symbol.
Parameters
----------
@@ -537,8 +675,8 @@ class AssetFinder(object):
Returns
-------
Future
A Future object.
future : Future
The future contract referenced by ``symbol``.
Raises
------
@@ -946,100 +1084,6 @@ class NotAssetConvertible(ValueError):
pass
class AssetFinderCachedEquities(AssetFinder):
"""
An extension to AssetFinder that preloads all equities from equities table
into memory and does lookups from there.
To have any changes in the underlying assets db reflected by this asset
finder one must manually call the ``rehash_equities`` method.
"""
def __init__(self, engine):
super(AssetFinderCachedEquities, self).__init__(engine)
self._fuzzy_symbol_cache = {}
self._company_share_class_cache = {}
self.rehash_equities()
def rehash_equities(self):
"""Reload the underlying assets db into the in memory cache.
"""
for equity in sa.select(self.equities.c).execute().fetchall():
company_symbol = equity['company_symbol']
share_class_symbol = equity['share_class_symbol']
fuzzy_symbol = equity['fuzzy_symbol']
asset = self._convert_row_to_equity(equity)
self._company_share_class_cache.setdefault(
(company_symbol, share_class_symbol),
[]
).append(asset)
self._fuzzy_symbol_cache.setdefault(
fuzzy_symbol,
[],
).append(asset)
def _convert_row_to_equity(self, row):
"""
Converts a SQLAlchemy equity row to an Equity object.
"""
return Equity(**_convert_asset_timestamp_fields(dict(row)))
def _get_fuzzy_candidates(self, fuzzy_symbol):
return self._fuzzy_symbol_cache.get(fuzzy_symbol, ())
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
return only_active_assets(
ad_value,
self._get_fuzzy_candidates(fuzzy_symbol),
)
def _get_split_candidates(self, company_symbol, share_class_symbol):
return self._company_share_class_cache.get(
(company_symbol, share_class_symbol),
(),
)
def _get_split_candidates_in_range(self,
company_symbol,
share_class_symbol,
ad_value):
return sorted(
only_active_assets(
ad_value,
self._get_split_candidates(company_symbol, share_class_symbol),
),
key=lambda x: (x.start_date, x.end_date),
reverse=True,
)
def _resolve_no_matching_candidates(self,
company_symbol,
share_class_symbol,
ad_value):
equities = self._get_split_candidates(
company_symbol,
share_class_symbol
)
partial_candidates = []
for equity in equities:
if equity.start_date.value <= ad_value:
partial_candidates.append(equity)
if partial_candidates:
partial_candidates = sorted(
partial_candidates,
key=lambda x: x.end_date,
reverse=True
)
return partial_candidates
def _get_best_candidate(self, candidates):
return candidates[0]
def _get_equities_from_candidates(self, candidates):
return candidates
def was_active(reference_date_value, asset):
"""
Whether or not `asset` was active at the time corresponding to
+22
View File
@@ -6,6 +6,7 @@ from pandas_datareader.data import DataReader
import requests
from zipline.utils.cli import maybe_show_progress
from .core import register
def _cachpath(symbol, type_):
@@ -169,3 +170,24 @@ def yahoo_equities(symbols, start=None, end=None):
adjustment_writer.write(splits=splits, dividends=dividends)
return ingest
# bundle used when creating test data
register(
'.test',
yahoo_equities(
(
'AMD',
'CERN',
'COST',
'DELL',
'GPS',
'INTC',
'MMM',
'AAPL',
'MSFT',
),
pd.Timestamp('2004-01-02', tz='utc'),
pd.Timestamp('2015-01-01', tz='utc'),
),
)
+1 -1
View File
@@ -281,7 +281,7 @@ Multiple symbols with the name '{symbol}' found. Use the
as_of_date' argument to to specify when the date symbol-lookup
should be valid.
Possible options:{options}
Possible options: {options}
""".strip()
+48 -13
View File
@@ -36,7 +36,11 @@ from nose.tools import ( # noqa
)
import numpy as np
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.util.testing import (
assert_frame_equal,
assert_panel_equal,
assert_series_equal,
)
from six import iteritems, viewkeys, PY2
from toolz import dissoc, keyfilter
import toolz.curried.operator as op
@@ -393,18 +397,49 @@ def assert_array_equal(result,
raise AssertionError('\n'.join((str(e), _fmt_path(path))))
@assert_equal.register(pd.DataFrame, pd.DataFrame)
def assert_dataframe_equal(result, expected, path=(), msg='', **kwargs):
try:
assert_frame_equal(
result,
expected,
**filter_kwargs(assert_frame_equal, kwargs)
)
except AssertionError as e:
raise AssertionError(
_fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
)
def _register_assert_ndframe_equal(type_, assert_eq):
"""Register a new check for an ndframe object.
Parameters
----------
type_ : type
The class to register an ``assert_equal`` dispatch for.
assert_eq : callable[type_, type_]
The function which checks that if the two ndframes are equal.
Returns
-------
assert_ndframe_equal : callable[type_, type_]
The wrapped function registered with ``assert_equal``.
"""
@assert_equal.register(type_, type_)
def assert_ndframe_equal(result, expected, path=(), msg='', **kwargs):
try:
assert_eq(
result,
expected,
**filter_kwargs(assert_frame_equal, kwargs)
)
except AssertionError as e:
raise AssertionError(
_fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
)
return assert_ndframe_equal
assert_frame_equal = _register_assert_ndframe_equal(
pd.DataFrame,
assert_frame_equal,
)
assert_panel_equal = _register_assert_ndframe_equal(
pd.Panel,
assert_panel_equal,
)
assert_series_equal = _register_assert_ndframe_equal(
pd.Series,
assert_series_equal,
)
@assert_equal.register(Adjustment, Adjustment)
+54 -10
View File
@@ -1,8 +1,9 @@
from functools import reduce
from pprint import pformat
from six import viewkeys
from six.moves import map, zip
from toolz import curry
from toolz import curry, flip
@curry
@@ -332,17 +333,60 @@ with_name = set_attribute('__name__')
with_doc = set_attribute('__doc__')
def let(a):
"""Box a value to be bound in a for binding.
def foldr(f, seq, default=_no_default):
"""Fold a function over a sequence with right associativity.
Parameters
----------
f : callable[any, any]
The function to reduce the sequence with.
The first argument will be the element of the sequence; the second
argument will be the accumulator.
seq : iterable[any]
The sequence to reduce.
default : any, optional
The starting value to reduce with. If not provided, the sequence
cannot be empty, and the last value of the sequence will be used.
Returns
-------
folded : any
The folded value.
Notes
-----
This functions works by reducing the list in a right associative way.
For example, imagine we are folding with ``operator.add`` or ``+``:
Examples
--------
.. code-block:: python
[f(y, y) for x in xs for y in let(g(x)) if p(y)]
foldr(add, seq) -> seq[0] + (seq[1] + (seq[2] + (...seq[-1], default)))
Here, ``y`` is available in both the predicate and the expression
of the comprehension. We can see that this allows us to cache the work
of computing ``g(x)`` even within the expression.
In the more general case with an arbitrary function, ``foldr`` will expand
like so:
.. code-block:: python
foldr(f, seq) -> f(seq[0], f(seq[1], f(seq[2], ...f(seq[-1], default))))
For a more in depth discussion of left and right folds, see:
`https://en.wikipedia.org/wiki/Fold_(higher-order_function)`_
The images in that page are very good for showing the differences between
``foldr`` and ``foldl`` (``reduce``).
.. note::
For performance reasons is is best to pass a strict (non-lazy) sequence,
for example, a list.
See Also
--------
:func:`functools.reduce`
:func:`sum`
"""
return a,
return reduce(
flip(f),
reversed(seq),
*(default,) if default is not _no_default else ()
)
+1 -1
View File
@@ -19,7 +19,7 @@ def hidden(path):
path : str
A filepath.
"""
return path.startswith('.')
return os.path.split(path)[1].startswith('.')
def ensure_directory(path):
+284
View File
@@ -0,0 +1,284 @@
import operator as op
from six import PY2
from toolz import peek
from zipline.utils.functional import foldr
if PY2:
class range(object):
"""Lazy range object with constant time containment check.
The arguments are the same as ``range``.
"""
__slots__ = 'start', 'stop', 'step'
def __init__(self, stop, *args):
if len(args) > 2:
raise TypeError(
'range takes at most 3 arguments (%d given)' % len(args)
)
if not args:
self.start = 0
self.stop = stop
self.step = 1
else:
self.start = stop
self.stop = args[0]
try:
self.step = args[1]
except IndexError:
self.step = 1
def __iter__(self):
n = self.start
stop = self.stop
step = self.step
while n < stop:
yield n
n += step
_ops = (
(op.gt, op.ge),
(op.le, op.lt),
)
def __contains__(self, other, _ops=_ops):
start = self.start
step = self.step
cmp_start, cmp_stop = _ops[step > 0]
return (
cmp_start(start, other) and
cmp_stop(other, self.stop) and
(other - start) % step == 0
)
del _ops
def __repr__(self):
return '%s(%s, %s%s)' % (
type(self).__name__,
self.start,
self.stop,
(', ' + str(self.step)) if self.step != 1 else '',
)
else:
range = range
def from_tuple(tup):
"""Convert a tuple into a range with error handling.
Parameters
----------
tup : tuple (len 2 or 3)
The tuple to turn into a range.
Returns
-------
range : range
The range from the tuple.
Raises
------
ValueError
Raised when the tuple length is not 2 or 3.
"""
if len(tup) not in (2, 3):
raise ValueError(
'tuple must contain 2 or 3 elements, not: %d (%r' % (
len(tup),
tup,
),
)
return range(*tup)
def maybe_from_tuple(tup_or_range):
"""Convert a tuple into a range but pass ranges through silently.
This is useful to ensure that input is a range so that attributes may
be accessed with `.start`, `.stop` or so that containment checks are
constant time.
Parameters
----------
tup_or_range : tuple or range
A tuple to pass to from_tuple or a range to return.
Returns
-------
range : range
The input to convert to a range.
Raises
------
ValueError
Raised when the input is not a tuple or a range. ValueError is also
raised if the input is a tuple whose length is not 2 or 3.
"""
if isinstance(tup_or_range, tuple):
return from_tuple(tup_or_range)
elif isinstance(tup_or_range, range):
return tup_or_range
raise ValueError(
'maybe_from_tuple expects a tuple or range, got %r: %r' % (
type(tup_or_range).__name__,
tup_or_range,
),
)
def _check_steps(a, b):
"""Check that the steps of ``a`` and ``b`` are both 1.
Parameters
----------
a : range
The first range to check.
b : range
The second range to check.
Raises
------
ValueError
Raised when either step is not 1.
"""
if a.step != 1:
raise ValueError('a.step must be equal to 1, got: %s' % a.step)
if b.step != 1:
raise ValueError('b.step must be equal to 1, got: %s' % b.step)
def overlap(a, b):
"""Check if two ranges overlap.
Parameters
----------
a : range
The first range.
b : range
The second range.
Returns
-------
overlaps : bool
Do these ranges overlap.
Notes
-----
This function does not support ranges with step != 1.
"""
_check_steps(a, b)
return a.stop >= b.start and b.stop >= a.start
def merge(a, b):
"""Merge two ranges with step == 1.
Parameters
----------
a : range
The first range.
b : range
The second range.
"""
_check_steps(a, b)
return range(min(a.start, b.start), max(a.stop, b.stop))
def _combine(n, rs):
"""helper for ``_group_ranges``
"""
try:
r, rs = peek(rs)
except StopIteration:
yield n
return
if overlap(n, r):
yield merge(n, r)
next(rs)
for r in rs:
yield r
else:
yield n
for r in rs:
yield r
def group_ranges(ranges):
"""Group any overlapping ranges into a single range.
Parameters
----------
ranges : iterable[ranges]
A sorted sequence of ranges to group.
Returns
-------
grouped : iterable[ranges]
A sorted sequence of ranges with overlapping ranges merged together.
"""
return foldr(_combine, ranges, ())
def sorted_diff(rs, ss):
try:
r, rs = peek(rs)
except StopIteration:
return
try:
s, ss = peek(ss)
except StopIteration:
for r in rs:
yield r
return
rtup = (r.start, r.stop)
stup = (s.start, s.stop)
if rtup == stup:
next(rs)
next(ss)
elif rtup < stup:
yield next(rs)
else:
next(ss)
for t in sorted_diff(rs, ss):
yield t
def intersecting_ranges(ranges):
"""Return any ranges that intersect.
Parameters
----------
ranges : iterable[ranges]
A sequence of ranges to check for intersections.
Returns
-------
intersections : iterable[ranges]
A sequence of all of the ranges that intersected in ``ranges``.
Examples
--------
>>> ranges = [range(0, 1), range(2, 5), range(4, 7)]
>>> list(intersecting_ranges(ranges))
[range(2, 5), range(4, 7)]
>>> ranges = [range(0, 1), range(2, 3)]
>>> list(intersecting_ranges(ranges))
[]
>>> ranges = [range(0, 1), range(1, 2)]
>>> list(intersecting_ranges(ranges))
[range(0, 1), range(1, 2)]
"""
ranges = sorted(ranges, key=op.attrgetter('start'))
return sorted_diff(ranges, group_ranges(ranges))