mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 05:43:26 +08:00
Merge pull request #1302 from quantopian/point-in-time-asset-db
Point in time asset db
This commit is contained in:
@@ -74,6 +74,7 @@ EQUITY_INFO = DataFrame(
|
||||
index=arange(1, 7),
|
||||
columns=['start_date', 'end_date'],
|
||||
).astype(datetime64)
|
||||
EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
|
||||
|
||||
TEST_QUERY_ASSETS = EQUITY_INFO.index
|
||||
|
||||
|
||||
@@ -91,6 +91,7 @@ EQUITY_INFO = DataFrame(
|
||||
index=arange(1, 7),
|
||||
columns=['start_date', 'end_date'],
|
||||
).astype(datetime64)
|
||||
EQUITY_INFO['symbol'] = [chr(ord('A') + n) for n in range(len(EQUITY_INFO))]
|
||||
|
||||
TEST_QUERY_ASSETS = EQUITY_INFO.index
|
||||
|
||||
|
||||
Binary file not shown.
@@ -777,7 +777,10 @@ class TestTransformAlgorithm(WithLogger,
|
||||
|
||||
@classmethod
|
||||
def make_futures_info(cls):
|
||||
return pd.DataFrame.from_dict({3: {'multiplier': 10}}, 'index')
|
||||
return pd.DataFrame.from_dict(
|
||||
{3: {'multiplier': 10, 'symbol': 'F'}},
|
||||
orient='index',
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def make_equity_daily_bar_data(cls):
|
||||
@@ -985,6 +988,7 @@ def before_trading_start(context, data):
|
||||
'start_date': start_session,
|
||||
'end_date': period_end + timedelta(days=1)
|
||||
}] * 2)
|
||||
equities['symbol'] = ['A', 'B']
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
sim_params = SimulationParameters(
|
||||
@@ -2813,6 +2817,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
metadata = pd.DataFrame.from_dict(
|
||||
{
|
||||
1: {
|
||||
'symbol': 'SYM',
|
||||
'start_date': start,
|
||||
'end_date': start + timedelta(days=6)
|
||||
},
|
||||
@@ -2940,6 +2945,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
|
||||
def test_asset_date_bounds(self):
|
||||
metadata = pd.DataFrame([{
|
||||
'symbol': 'SYM',
|
||||
'start_date': self.sim_params.start_session,
|
||||
'end_date': '2020-01-01',
|
||||
}])
|
||||
@@ -2959,6 +2965,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
algo.run(data_portal)
|
||||
|
||||
metadata = pd.DataFrame([{
|
||||
'symbol': 'SYM',
|
||||
'start_date': '1989-01-01',
|
||||
'end_date': '1990-01-01',
|
||||
}])
|
||||
@@ -2979,6 +2986,7 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
algo.run(data_portal)
|
||||
|
||||
metadata = pd.DataFrame([{
|
||||
'symbol': 'SYM',
|
||||
'start_date': '2020-01-01',
|
||||
'end_date': '2021-01-01',
|
||||
}])
|
||||
|
||||
+122
-59
@@ -18,6 +18,7 @@ Tests for the zipline.assets package
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import pickle
|
||||
import sys
|
||||
from types import GetSetDescriptorType
|
||||
@@ -25,7 +26,6 @@ from unittest import TestCase
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
from nose.tools import raises
|
||||
from nose_parameterized import parameterized
|
||||
from numpy import full, int32, int64
|
||||
import pandas as pd
|
||||
@@ -39,7 +39,6 @@ from zipline.assets import (
|
||||
Future,
|
||||
AssetDBWriter,
|
||||
AssetFinder,
|
||||
AssetFinderCachedEquities,
|
||||
)
|
||||
from zipline.assets.synthetic import (
|
||||
make_commodity_future_info,
|
||||
@@ -341,7 +340,6 @@ class TestFuture(WithAssetFinder, ZiplineTestCase):
|
||||
self.assertIn("tick_size=0.01", reprd)
|
||||
self.assertIn("multiplier=500", reprd)
|
||||
|
||||
@raises(AssertionError)
|
||||
def test_reduce(self):
|
||||
assert_equal(
|
||||
pickle.loads(pickle.dumps(self.future)).to_dict(),
|
||||
@@ -485,6 +483,97 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
|
||||
self.assertEqual(2, finder.lookup_symbol('BRK_A', None, fuzzy=True))
|
||||
self.assertEqual(2, finder.lookup_symbol('BRK_A', dt, fuzzy=True))
|
||||
|
||||
def test_lookup_symbol_change_ticker(self):
|
||||
T = partial(pd.Timestamp, tz='utc')
|
||||
metadata = pd.DataFrame.from_records(
|
||||
[
|
||||
# sid 0
|
||||
{
|
||||
'symbol': 'A',
|
||||
'start_date': T('2014-01-01'),
|
||||
'end_date': T('2014-01-05'),
|
||||
},
|
||||
{
|
||||
'symbol': 'B',
|
||||
'start_date': T('2014-01-06'),
|
||||
'end_date': T('2014-01-10'),
|
||||
},
|
||||
|
||||
# sid 1
|
||||
{
|
||||
'symbol': 'C',
|
||||
'start_date': T('2014-01-01'),
|
||||
'end_date': T('2014-01-05'),
|
||||
},
|
||||
{
|
||||
'symbol': 'A', # claiming the unused symbol 'A'
|
||||
'start_date': T('2014-01-06'),
|
||||
'end_date': T('2014-01-10'),
|
||||
},
|
||||
],
|
||||
index=[0, 0, 1, 1],
|
||||
)
|
||||
self.write_assets(equities=metadata)
|
||||
finder = self.asset_finder
|
||||
|
||||
# note: these assertions walk forward in time, starting at assertions
|
||||
# about ownership before the start_date and ending with assertions
|
||||
# after the end_date; new assertions should be inserted in the correct
|
||||
# locations
|
||||
|
||||
# no one held 'A' before 01
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('A', T('2013-12-31'))
|
||||
|
||||
# no one held 'C' before 01
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('C', T('2013-12-31'))
|
||||
|
||||
for asof in pd.date_range('2014-01-01', '2014-01-05', tz='utc'):
|
||||
# from 01 through 05 sid 0 held 'A'
|
||||
assert_equal(
|
||||
finder.lookup_symbol('A', asof),
|
||||
finder.retrieve_asset(0),
|
||||
msg=str(asof),
|
||||
)
|
||||
|
||||
# from 01 through 05 sid 1 held 'C'
|
||||
assert_equal(
|
||||
finder.lookup_symbol('C', asof),
|
||||
finder.retrieve_asset(1),
|
||||
msg=str(asof),
|
||||
)
|
||||
|
||||
# no one held 'B' before 06
|
||||
with self.assertRaises(SymbolNotFound):
|
||||
finder.lookup_symbol('B', T('2014-01-05'))
|
||||
|
||||
# no one held 'C' after 06, however, no one has claimed it yet
|
||||
# so it still maps to sid 1
|
||||
assert_equal(
|
||||
finder.lookup_symbol('C', T('2014-01-07')),
|
||||
finder.retrieve_asset(1),
|
||||
)
|
||||
|
||||
for asof in pd.date_range('2014-01-06', '2014-01-11', tz='utc'):
|
||||
# from 06 through 10 sid 0 held 'B'
|
||||
# we test through the 11th because sid 1 is the last to hold 'B'
|
||||
# so it should ffill
|
||||
assert_equal(
|
||||
finder.lookup_symbol('B', asof),
|
||||
finder.retrieve_asset(0),
|
||||
msg=str(asof),
|
||||
)
|
||||
|
||||
# from 06 through 10 sid 1 held 'A'
|
||||
# we test through the 11th because sid 1 is the last to hold 'A'
|
||||
# so it should ffill
|
||||
assert_equal(
|
||||
finder.lookup_symbol('A', asof),
|
||||
finder.retrieve_asset(1),
|
||||
msg=str(asof),
|
||||
)
|
||||
|
||||
def test_lookup_symbol(self):
|
||||
|
||||
# Incrementing by two so that start and end dates for each
|
||||
@@ -519,27 +608,7 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
|
||||
self.assertEqual(result.symbol, 'EXISTING')
|
||||
self.assertEqual(result.sid, i)
|
||||
|
||||
def test_lookup_symbol_from_multiple_valid(self):
|
||||
# This test asserts that we resolve conflicts in accordance with the
|
||||
# following rules when we have multiple assets holding the same symbol
|
||||
# at the same time:
|
||||
|
||||
# If multiple SIDs exist for symbol S at time T, return the candidate
|
||||
# SID whose start_date is highest. (200 cases)
|
||||
|
||||
# If multiple SIDs exist for symbol S at time T, the best candidate
|
||||
# SIDs share the highest start_date, return the SID with the highest
|
||||
# end_date. (34 cases)
|
||||
|
||||
# It is the opinion of the author (ssanderson) that we should consider
|
||||
# this malformed input and fail here. But this is the current indended
|
||||
# behavior of the code, and I accidentally broke it while refactoring.
|
||||
# These will serve as regression tests until the time comes that we
|
||||
# decide to enforce this as an error.
|
||||
|
||||
# See https://github.com/quantopian/zipline/issues/837 for more
|
||||
# details.
|
||||
|
||||
def test_fail_to_write_overlapping_data(self):
|
||||
df = pd.DataFrame.from_records(
|
||||
[
|
||||
{
|
||||
@@ -568,22 +637,16 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
|
||||
]
|
||||
)
|
||||
|
||||
self.write_assets(equities=df)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
self.write_assets(equities=df)
|
||||
|
||||
def check(expected_sid, date):
|
||||
result = self.asset_finder.lookup_symbol(
|
||||
'MULTIPLE', date,
|
||||
)
|
||||
self.assertEqual(result.symbol, 'MULTIPLE')
|
||||
self.assertEqual(result.sid, expected_sid)
|
||||
|
||||
# Sids 1 and 2 are eligible here. We should get asset 2 because it
|
||||
# has the later end_date.
|
||||
check(2, pd.Timestamp('2010-12-31'))
|
||||
|
||||
# Sids 1, 2, and 3 are eligible here. We should get sid 3 because
|
||||
# it has a later start_date
|
||||
check(3, pd.Timestamp('2011-01-01'))
|
||||
self.assertEqual(
|
||||
str(e.exception),
|
||||
"Ambiguous ownership of 'MULTIPLE', multiple companies held this"
|
||||
" ticker over the following ranges:\n"
|
||||
"[('2010-01-01 00:00:00', '2012-01-01 00:00:00'),"
|
||||
" ('2011-01-01 00:00:00', '2012-01-01 00:00:00')]",
|
||||
)
|
||||
|
||||
def test_lookup_generic(self):
|
||||
"""
|
||||
@@ -1000,14 +1063,6 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
|
||||
)
|
||||
|
||||
|
||||
class AssetFinderCachedEquitiesTestCase(AssetFinderTestCase):
|
||||
asset_finder_type = AssetFinderCachedEquities
|
||||
|
||||
def write_assets(self, **kwargs):
|
||||
super(AssetFinderCachedEquitiesTestCase, self).write_assets(**kwargs)
|
||||
self.asset_finder.rehash_equities()
|
||||
|
||||
|
||||
class TestFutureChain(WithAssetFinder, ZiplineTestCase):
|
||||
@classmethod
|
||||
def make_futures_info(cls):
|
||||
@@ -1259,15 +1314,23 @@ class TestAssetDBVersioning(ZiplineTestCase):
|
||||
version_table = self.metadata.tables['version_info']
|
||||
|
||||
# This should not raise an error
|
||||
check_version_info(version_table, ASSET_DB_VERSION)
|
||||
check_version_info(self.engine, version_table, ASSET_DB_VERSION)
|
||||
|
||||
# This should fail because the version is too low
|
||||
with self.assertRaises(AssetDBVersionError):
|
||||
check_version_info(version_table, ASSET_DB_VERSION - 1)
|
||||
check_version_info(
|
||||
self.engine,
|
||||
version_table,
|
||||
ASSET_DB_VERSION - 1,
|
||||
)
|
||||
|
||||
# This should fail because the version is too high
|
||||
with self.assertRaises(AssetDBVersionError):
|
||||
check_version_info(version_table, ASSET_DB_VERSION + 1)
|
||||
check_version_info(
|
||||
self.engine,
|
||||
version_table,
|
||||
ASSET_DB_VERSION + 1,
|
||||
)
|
||||
|
||||
def test_write_version(self):
|
||||
version_table = self.metadata.tables['version_info']
|
||||
@@ -1279,24 +1342,24 @@ class TestAssetDBVersioning(ZiplineTestCase):
|
||||
# This should fail because the table has no version info and is,
|
||||
# therefore, consdered v0
|
||||
with self.assertRaises(AssetDBVersionError):
|
||||
check_version_info(version_table, -2)
|
||||
check_version_info(self.engine, version_table, -2)
|
||||
|
||||
# This should not raise an error because the version has been written
|
||||
write_version_info(version_table, -2)
|
||||
check_version_info(version_table, -2)
|
||||
write_version_info(self.engine, version_table, -2)
|
||||
check_version_info(self.engine, version_table, -2)
|
||||
|
||||
# Assert that the version is in the table and correct
|
||||
self.assertEqual(sa.select((version_table.c.version,)).scalar(), -2)
|
||||
|
||||
# Assert that trying to overwrite the version fails
|
||||
with self.assertRaises(sa.exc.IntegrityError):
|
||||
write_version_info(version_table, -3)
|
||||
write_version_info(self.engine, version_table, -3)
|
||||
|
||||
def test_finder_checks_version(self):
|
||||
version_table = self.metadata.tables['version_info']
|
||||
version_table.delete().execute()
|
||||
write_version_info(version_table, -2)
|
||||
check_version_info(version_table, -2)
|
||||
write_version_info(self.engine, version_table, -2)
|
||||
check_version_info(self.engine, version_table, -2)
|
||||
|
||||
# Assert that trying to build a finder with a bad db raises an error
|
||||
with self.assertRaises(AssetDBVersionError):
|
||||
@@ -1304,8 +1367,8 @@ class TestAssetDBVersioning(ZiplineTestCase):
|
||||
|
||||
# Change the version number of the db to the correct version
|
||||
version_table.delete().execute()
|
||||
write_version_info(version_table, ASSET_DB_VERSION)
|
||||
check_version_info(version_table, ASSET_DB_VERSION)
|
||||
write_version_info(self.engine, version_table, ASSET_DB_VERSION)
|
||||
check_version_info(self.engine, version_table, ASSET_DB_VERSION)
|
||||
|
||||
# Now that the versions match, this Finder should succeed
|
||||
AssetFinder(engine=self.engine)
|
||||
@@ -1319,7 +1382,7 @@ class TestAssetDBVersioning(ZiplineTestCase):
|
||||
metadata = sa.MetaData(conn)
|
||||
metadata.reflect(bind=self.engine)
|
||||
version_table = metadata.tables['version_info']
|
||||
check_version_info(version_table, 0)
|
||||
check_version_info(self.engine, version_table, 0)
|
||||
|
||||
# Check some of the v1-to-v0 downgrades
|
||||
self.assertTrue('futures_contracts' in metadata.tables)
|
||||
|
||||
+12
-8
@@ -45,20 +45,24 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendar,
|
||||
return pd.DataFrame.from_dict(
|
||||
{
|
||||
1: {
|
||||
"start_date": cls.START_DATE,
|
||||
"end_date": cls.END_DATE + pd.Timedelta(days=1)
|
||||
'symbol': 'A',
|
||||
'start_date': cls.START_DATE,
|
||||
'end_date': cls.END_DATE + pd.Timedelta(days=1)
|
||||
},
|
||||
2: {
|
||||
"start_date": cls.START_DATE,
|
||||
"end_date": cls.END_DATE + pd.Timedelta(days=1)
|
||||
'symbol': 'B',
|
||||
'start_date': cls.START_DATE,
|
||||
'end_date': cls.END_DATE + pd.Timedelta(days=1)
|
||||
},
|
||||
3: {
|
||||
"start_date": pd.Timestamp('2006-05-26', tz='utc'),
|
||||
"end_date": pd.Timestamp('2006-08-09', tz='utc')
|
||||
'symbol': 'C',
|
||||
'start_date': pd.Timestamp('2006-05-26', tz='utc'),
|
||||
'end_date': pd.Timestamp('2006-08-09', tz='utc')
|
||||
},
|
||||
4: {
|
||||
"start_date": cls.START_DATE,
|
||||
"end_date": cls.END_DATE + pd.Timedelta(days=1)
|
||||
'symbol': 'D',
|
||||
'start_date': cls.START_DATE,
|
||||
'end_date': cls.END_DATE + pd.Timedelta(days=1)
|
||||
},
|
||||
},
|
||||
orient='index',
|
||||
|
||||
@@ -348,6 +348,9 @@ def bundles():
|
||||
"""List all of the available data bundles.
|
||||
"""
|
||||
for bundle in sorted(bundles_module.bundles.keys()):
|
||||
if bundle.startswith('.'):
|
||||
# hide the test data
|
||||
continue
|
||||
try:
|
||||
ingestions = sorted(
|
||||
(str(bundles_module.from_bundle_ingest_dirname(ing))
|
||||
|
||||
@@ -23,7 +23,6 @@ from ._assets import (
|
||||
from .assets import (
|
||||
AssetFinder,
|
||||
AssetConvertible,
|
||||
AssetFinderCachedEquities
|
||||
)
|
||||
from .asset_db_schema import ASSET_DB_VERSION
|
||||
from .asset_writer import AssetDBWriter
|
||||
@@ -35,7 +34,6 @@ __all__ = [
|
||||
'Equity',
|
||||
'Future',
|
||||
'AssetFinder',
|
||||
'AssetFinderCachedEquities',
|
||||
'AssetConvertible',
|
||||
'make_asset_array',
|
||||
'CACHE_FILE_TEMPLATE'
|
||||
|
||||
+74
-51
@@ -58,26 +58,35 @@ cdef class Asset:
|
||||
|
||||
cdef readonly object exchange
|
||||
|
||||
def __cinit__(self,
|
||||
int sid, # sid is required
|
||||
object symbol="",
|
||||
object asset_name="",
|
||||
object start_date=None,
|
||||
object end_date=None,
|
||||
object first_traded=None,
|
||||
object auto_close_date=None,
|
||||
object exchange="",
|
||||
*args,
|
||||
**kwargs):
|
||||
_kwargnames = frozenset({
|
||||
'sid',
|
||||
'symbol',
|
||||
'asset_name',
|
||||
'start_date',
|
||||
'end_date',
|
||||
'first_traded',
|
||||
'auto_close_date',
|
||||
'exchange',
|
||||
})
|
||||
|
||||
self.sid = sid
|
||||
self.sid_hash = hash(sid)
|
||||
self.symbol = symbol
|
||||
self.asset_name = asset_name
|
||||
self.exchange = exchange
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.first_traded = first_traded
|
||||
def __init__(self,
|
||||
int sid, # sid is required
|
||||
object symbol="",
|
||||
object asset_name="",
|
||||
object start_date=None,
|
||||
object end_date=None,
|
||||
object first_traded=None,
|
||||
object auto_close_date=None,
|
||||
object exchange=""):
|
||||
|
||||
self.sid = sid
|
||||
self.sid_hash = hash(sid)
|
||||
self.symbol = symbol
|
||||
self.asset_name = asset_name
|
||||
self.exchange = exchange
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.first_traded = first_traded
|
||||
self.auto_close_date = auto_close_date
|
||||
|
||||
def __int__(self):
|
||||
@@ -127,9 +136,9 @@ cdef class Asset:
|
||||
|
||||
def __str__(self):
|
||||
if self.symbol:
|
||||
return 'Asset(%d [%s])' % (self.sid, self.symbol)
|
||||
return '%s(%d [%s])' % (type(self).__name__, self.sid, self.symbol)
|
||||
else:
|
||||
return 'Asset(%d)' % self.sid
|
||||
return '%s(%d)' % (type(self).__name__, self.sid)
|
||||
|
||||
def __repr__(self):
|
||||
attrs = ('symbol', 'asset_name', 'exchange',
|
||||
@@ -213,12 +222,6 @@ cdef class Asset:
|
||||
|
||||
cdef class Equity(Asset):
|
||||
|
||||
def __str__(self):
|
||||
if self.symbol:
|
||||
return 'Equity(%d [%s])' % (self.sid, self.symbol)
|
||||
else:
|
||||
return 'Equity(%d)' % self.sid
|
||||
|
||||
def __repr__(self):
|
||||
attrs = ('symbol', 'asset_name', 'exchange',
|
||||
'start_date', 'end_date', 'first_traded', 'auto_close_date')
|
||||
@@ -270,26 +273,52 @@ cdef class Future(Asset):
|
||||
cdef readonly object tick_size
|
||||
cdef readonly float multiplier
|
||||
|
||||
def __cinit__(self,
|
||||
int sid, # sid is required
|
||||
object symbol="",
|
||||
object root_symbol="",
|
||||
object asset_name="",
|
||||
object start_date=None,
|
||||
object end_date=None,
|
||||
object notice_date=None,
|
||||
object expiration_date=None,
|
||||
object auto_close_date=None,
|
||||
object first_traded=None,
|
||||
object exchange="",
|
||||
object tick_size="",
|
||||
float multiplier=1):
|
||||
_kwargnames = frozenset({
|
||||
'sid',
|
||||
'symbol',
|
||||
'root_symbol',
|
||||
'asset_name',
|
||||
'start_date',
|
||||
'end_date',
|
||||
'notice_date',
|
||||
'expiration_date',
|
||||
'auto_close_date',
|
||||
'first_traded',
|
||||
'exchange',
|
||||
'tick_size',
|
||||
'multiplier',
|
||||
})
|
||||
|
||||
self.root_symbol = root_symbol
|
||||
self.notice_date = notice_date
|
||||
def __init__(self,
|
||||
int sid, # sid is required
|
||||
object symbol="",
|
||||
object root_symbol="",
|
||||
object asset_name="",
|
||||
object start_date=None,
|
||||
object end_date=None,
|
||||
object notice_date=None,
|
||||
object expiration_date=None,
|
||||
object auto_close_date=None,
|
||||
object first_traded=None,
|
||||
object exchange="",
|
||||
object tick_size="",
|
||||
float multiplier=1.0):
|
||||
|
||||
super().__init__(
|
||||
sid,
|
||||
symbol=symbol,
|
||||
asset_name=asset_name,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
first_traded=first_traded,
|
||||
auto_close_date=auto_close_date,
|
||||
exchange=exchange,
|
||||
)
|
||||
self.root_symbol = root_symbol
|
||||
self.notice_date = notice_date
|
||||
self.expiration_date = expiration_date
|
||||
self.tick_size = tick_size
|
||||
self.multiplier = multiplier
|
||||
self.tick_size = tick_size
|
||||
self.multiplier = multiplier
|
||||
|
||||
if auto_close_date is None:
|
||||
if notice_date is None:
|
||||
@@ -299,12 +328,6 @@ cdef class Future(Asset):
|
||||
else:
|
||||
self.auto_close_date = min(notice_date, expiration_date)
|
||||
|
||||
def __str__(self):
|
||||
if self.symbol:
|
||||
return 'Future(%d [%s])' % (self.sid, self.symbol)
|
||||
else:
|
||||
return 'Future(%d)' % self.sid
|
||||
|
||||
def __repr__(self):
|
||||
attrs = ('symbol', 'root_symbol', 'asset_name', 'exchange',
|
||||
'start_date', 'end_date', 'first_traded', 'notice_date',
|
||||
|
||||
@@ -50,7 +50,7 @@ def downgrade(engine, desired_version):
|
||||
|
||||
# Execute the downgrades in order
|
||||
for downgrade_key in downgrade_keys:
|
||||
_downgrade_methods[downgrade_key](op, version_info_table)
|
||||
_downgrade_methods[downgrade_key](op, engine, version_info_table)
|
||||
|
||||
# Re-enable foreign keys
|
||||
_pragma_foreign_keys(conn, True)
|
||||
@@ -96,10 +96,10 @@ def downgrades(src):
|
||||
|
||||
@do(op.setitem(_downgrade_methods, destination))
|
||||
@wraps(f)
|
||||
def wrapper(op, version_info_table):
|
||||
def wrapper(op, engine, version_info_table):
|
||||
version_info_table.delete().execute() # clear the version
|
||||
f(op)
|
||||
write_version_info(version_info_table, destination)
|
||||
write_version_info(engine, version_info_table, destination)
|
||||
|
||||
return wrapper
|
||||
return _
|
||||
@@ -206,3 +206,72 @@ def _downgrade_v3(op):
|
||||
'equities',
|
||||
['fuzzy_symbol'],
|
||||
)
|
||||
|
||||
|
||||
@downgrades(4)
|
||||
def _downgrade_v4(op):
|
||||
op.create_table(
|
||||
'_new_equities',
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text),
|
||||
sa.Column('company_symbol', sa.Text, index=True),
|
||||
sa.Column('share_class_symbol', sa.Text),
|
||||
sa.Column('fuzzy_symbol', sa.Text, index=True),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column('auto_close_date', sa.Integer),
|
||||
sa.Column('exchange', sa.Text),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
insert into _new_equities
|
||||
select
|
||||
equities.sid as sid,
|
||||
sym.symbol as symbol,
|
||||
sym.company_symbol as company_symbol,
|
||||
sym.share_class_symbol as share_class_symbol,
|
||||
sym.company_symbol || sym.share_class_symbol as fuzzy_symbol,
|
||||
equities.asset_name as asset_name,
|
||||
equities.start_date as start_date,
|
||||
equities.end_date as end_date,
|
||||
equities.first_traded as first_traded,
|
||||
equities.auto_close_date as auto_close_date,
|
||||
equities.exchange as exchange
|
||||
from
|
||||
equities
|
||||
inner join
|
||||
(select
|
||||
*
|
||||
from
|
||||
equity_symbol_mappings
|
||||
group by
|
||||
equity_symbol_mappings.sid
|
||||
order by
|
||||
equity_symbol_mappings.end_date desc) sym
|
||||
on
|
||||
equities.sid == sym.sid
|
||||
""",
|
||||
)
|
||||
op.drop_table('equity_symbol_mappings')
|
||||
op.drop_table('equities')
|
||||
op.rename_table('_new_equities', 'equities')
|
||||
# we need to make sure the indicies have the proper names after the rename
|
||||
op.create_index(
|
||||
'ix_equities_company_symbol',
|
||||
'equities',
|
||||
['company_symbol'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_equities_fuzzy_symbol',
|
||||
'equities',
|
||||
['fuzzy_symbol'],
|
||||
)
|
||||
|
||||
+155
-143
@@ -6,26 +6,14 @@ import sqlalchemy as sa
|
||||
# assets database
|
||||
# NOTE: When upgrading this remember to add a downgrade in:
|
||||
# .asset_db_migrations
|
||||
ASSET_DB_VERSION = 3
|
||||
|
||||
|
||||
def generate_asset_db_metadata(bind=None):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
metadata = sa.MetaData(bind=bind)
|
||||
_version_table_schema(metadata)
|
||||
_equities_table_schema(metadata)
|
||||
_futures_exchanges_schema(metadata)
|
||||
_futures_root_symbols_schema(metadata)
|
||||
_futures_contracts_schema(metadata)
|
||||
_asset_router_schema(metadata)
|
||||
return metadata
|
||||
|
||||
ASSET_DB_VERSION = 4
|
||||
|
||||
# A frozenset of the names of all tables in the assets db
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
asset_db_table_names = frozenset({
|
||||
'asset_router',
|
||||
'equities',
|
||||
'equity_symbol_mappings',
|
||||
'futures_contracts',
|
||||
'futures_exchanges',
|
||||
'futures_root_symbols',
|
||||
@@ -33,139 +21,163 @@ asset_db_table_names = frozenset({
|
||||
})
|
||||
|
||||
|
||||
def _equities_table_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'equities',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text),
|
||||
sa.Column('company_symbol', sa.Text, index=True),
|
||||
sa.Column('share_class_symbol', sa.Text),
|
||||
sa.Column('fuzzy_symbol', sa.Text, index=True),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column('auto_close_date', sa.Integer),
|
||||
sa.Column('exchange', sa.Text),
|
||||
)
|
||||
metadata = sa.MetaData()
|
||||
|
||||
equities = sa.Table(
|
||||
'equities',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column('auto_close_date', sa.Integer),
|
||||
sa.Column('exchange', sa.Text),
|
||||
)
|
||||
|
||||
def _futures_exchanges_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'futures_exchanges',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('timezone', sa.Text),
|
||||
)
|
||||
equity_symbol_mappings = sa.Table(
|
||||
'equity_symbol_mappings',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
sa.ForeignKey(equities.c.sid),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
'symbol',
|
||||
sa.Text,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'company_symbol',
|
||||
sa.Text,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
'share_class_symbol',
|
||||
sa.Text,
|
||||
),
|
||||
sa.Column(
|
||||
'start_date',
|
||||
sa.Integer,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'end_date',
|
||||
sa.Integer,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
futures_exchanges = sa.Table(
|
||||
'futures_exchanges',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('timezone', sa.Text),
|
||||
)
|
||||
|
||||
def _futures_root_symbols_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'futures_root_symbols',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'root_symbol',
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('root_symbol_id', sa.Integer),
|
||||
sa.Column('sector', sa.Text),
|
||||
sa.Column('description', sa.Text),
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_exchanges.exchange'),
|
||||
),
|
||||
)
|
||||
futures_root_symbols = sa.Table(
|
||||
'futures_root_symbols',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'root_symbol',
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('root_symbol_id', sa.Integer),
|
||||
sa.Column('sector', sa.Text),
|
||||
sa.Column('description', sa.Text),
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_exchanges.exchange'),
|
||||
),
|
||||
)
|
||||
|
||||
futures_contracts = sa.Table(
|
||||
'futures_contracts',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text, unique=True, index=True),
|
||||
sa.Column(
|
||||
'root_symbol',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_root_symbols.root_symbol'),
|
||||
index=True
|
||||
),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_exchanges.exchange'),
|
||||
),
|
||||
sa.Column('notice_date', sa.Integer, nullable=False),
|
||||
sa.Column('expiration_date', sa.Integer, nullable=False),
|
||||
sa.Column('auto_close_date', sa.Integer, nullable=False),
|
||||
sa.Column('multiplier', sa.Float),
|
||||
sa.Column('tick_size', sa.Float),
|
||||
)
|
||||
|
||||
def _futures_contracts_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'futures_contracts',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text, unique=True, index=True),
|
||||
sa.Column(
|
||||
'root_symbol',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_root_symbols.root_symbol'),
|
||||
index=True
|
||||
),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
sa.ForeignKey('futures_exchanges.exchange'),
|
||||
),
|
||||
sa.Column('notice_date', sa.Integer, nullable=False),
|
||||
sa.Column('expiration_date', sa.Integer, nullable=False),
|
||||
sa.Column('auto_close_date', sa.Integer, nullable=False),
|
||||
sa.Column('multiplier', sa.Float),
|
||||
sa.Column('tick_size', sa.Float),
|
||||
)
|
||||
asset_router = sa.Table(
|
||||
'asset_router',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True),
|
||||
sa.Column('asset_type', sa.Text),
|
||||
)
|
||||
|
||||
|
||||
def _asset_router_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'asset_router',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'sid',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True),
|
||||
sa.Column('asset_type', sa.Text),
|
||||
)
|
||||
|
||||
|
||||
def _version_table_schema(metadata):
|
||||
# NOTE: When modifying this schema, update the ASSET_DB_VERSION value
|
||||
return sa.Table(
|
||||
'version_info',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
'version',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
),
|
||||
# This constraint ensures a single entry in this table
|
||||
sa.CheckConstraint('id <= 1'),
|
||||
)
|
||||
version_info = sa.Table(
|
||||
'version_info',
|
||||
metadata,
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
'version',
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
),
|
||||
# This constraint ensures a single entry in this table
|
||||
sa.CheckConstraint('id <= 1'),
|
||||
)
|
||||
|
||||
+157
-60
@@ -23,16 +23,40 @@ from toolz import first
|
||||
|
||||
from zipline.errors import AssetDBVersionError
|
||||
from zipline.assets.asset_db_schema import (
|
||||
generate_asset_db_metadata,
|
||||
asset_db_table_names,
|
||||
ASSET_DB_VERSION,
|
||||
asset_db_table_names,
|
||||
asset_router,
|
||||
equities as equities_table,
|
||||
equity_symbol_mappings,
|
||||
futures_contracts as futures_contracts_table,
|
||||
futures_exchanges,
|
||||
futures_root_symbols,
|
||||
metadata,
|
||||
version_info,
|
||||
)
|
||||
|
||||
from zipline.utils.range import from_tuple, intersecting_ranges
|
||||
|
||||
# Define a namedtuple for use with the load_data and _load_data methods
|
||||
AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
|
||||
AssetData = namedtuple(
|
||||
'AssetData', (
|
||||
'equities',
|
||||
'equities_mappings',
|
||||
'futures',
|
||||
'exchanges',
|
||||
'root_symbols',
|
||||
),
|
||||
)
|
||||
|
||||
SQLITE_MAX_VARIABLE_NUMBER = 999
|
||||
|
||||
symbol_columns = frozenset({
|
||||
'symbol',
|
||||
'company_symbol',
|
||||
'share_class_symbol',
|
||||
})
|
||||
mapping_columns = symbol_columns | {'start_date', 'end_date'}
|
||||
|
||||
# Default values for the equities DataFrame
|
||||
_equities_defaults = {
|
||||
'symbol': None,
|
||||
@@ -74,7 +98,7 @@ _root_symbols_defaults = {
|
||||
}
|
||||
|
||||
# Fuzzy symbol delimiters that may break up a company symbol and share class
|
||||
_delimited_symbol_delimiter_regex = r'[./\-_]'
|
||||
_delimited_symbol_delimiters_regex = re.compile(r'[./\-_]')
|
||||
_delimited_symbol_default_triggers = frozenset({np.nan, None, ''})
|
||||
|
||||
|
||||
@@ -91,16 +115,22 @@ def split_delimited_symbol(symbol):
|
||||
|
||||
Returns
|
||||
-------
|
||||
( str, str , str )
|
||||
A tuple of ( company_symbol, share_class_symbol, fuzzy_symbol)
|
||||
company_symbol : str
|
||||
The company part of the symbol.
|
||||
share_class_symbol : str
|
||||
The share class part of a symbol.
|
||||
"""
|
||||
# return blank strings for any bad fuzzy symbols, like NaN or None
|
||||
if symbol in _delimited_symbol_default_triggers:
|
||||
return ('', '', '')
|
||||
return '', ''
|
||||
|
||||
split_list = re.split(pattern=_delimited_symbol_delimiter_regex,
|
||||
string=symbol,
|
||||
maxsplit=1)
|
||||
symbol = symbol.upper()
|
||||
|
||||
split_list = re.split(
|
||||
pattern=_delimited_symbol_delimiters_regex,
|
||||
string=symbol,
|
||||
maxsplit=1,
|
||||
)
|
||||
|
||||
# Break the list up in to its two components, the company symbol and the
|
||||
# share class symbol
|
||||
@@ -110,12 +140,7 @@ def split_delimited_symbol(symbol):
|
||||
else:
|
||||
share_class_symbol = ''
|
||||
|
||||
# Strip all fuzzy characters from the symbol to get the fuzzy symbol
|
||||
fuzzy_symbol = re.sub(pattern=_delimited_symbol_delimiter_regex,
|
||||
repl='',
|
||||
string=symbol)
|
||||
|
||||
return (company_symbol, share_class_symbol, fuzzy_symbol)
|
||||
return company_symbol, share_class_symbol
|
||||
|
||||
|
||||
def _generate_output_dataframe(data_subset, defaults):
|
||||
@@ -151,19 +176,70 @@ def _generate_output_dataframe(data_subset, defaults):
|
||||
|
||||
# Get those columns which we need but
|
||||
# for which no data has been supplied.
|
||||
need = desired_cols - cols
|
||||
for col in desired_cols - cols:
|
||||
# write the default value for any missing columns
|
||||
data_subset[col] = defaults[col]
|
||||
|
||||
# Combine the users supplied data with our required columns.
|
||||
output = pd.concat(
|
||||
(data_subset, pd.DataFrame(
|
||||
{k: defaults[k] for k in need},
|
||||
data_subset.index,
|
||||
)),
|
||||
axis=1,
|
||||
copy=False
|
||||
return data_subset
|
||||
|
||||
|
||||
def _check_asset_group(group):
|
||||
for colname in set(group.columns) - mapping_columns:
|
||||
col = group[colname]
|
||||
if len(col.unique()) != 1:
|
||||
raise ValueError(
|
||||
'All values must be the same for the %s column' % colname,
|
||||
)
|
||||
|
||||
row = group.iloc[0]
|
||||
row.start_date = group.start_date.min()
|
||||
row.end_date = group.end_date.max()
|
||||
row.drop(list(symbol_columns), inplace=True)
|
||||
return row
|
||||
|
||||
|
||||
def _format_range(r):
|
||||
return (
|
||||
str(pd.Timestamp(r.start, unit='ns')),
|
||||
str(pd.Timestamp(r.stop, unit='ns')),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _split_symbol_mappings(df):
|
||||
"""Split out the symbol: sid mappings from the raw data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The dataframe with multiple rows for each symbol: sid pair.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asset_info : pd.DataFrame
|
||||
The asset info with one row per asset.
|
||||
symbol_mappings : pd.DataFrame
|
||||
The dataframe of just symbol: sid mappings. The index will be
|
||||
the sid, then there will be three columns: symbol, start_date, and
|
||||
end_date.
|
||||
"""
|
||||
mappings = df[list(mapping_columns)]
|
||||
for symbol in mappings.symbol.unique():
|
||||
persymbol = mappings[mappings.symbol == symbol]
|
||||
intersections = list(intersecting_ranges(
|
||||
map(from_tuple, zip(persymbol.start_date, persymbol.end_date)),
|
||||
))
|
||||
if intersections:
|
||||
raise ValueError(
|
||||
'Ambiguous ownership of %r, multiple companies held this'
|
||||
' ticker over the following ranges:\n%s' % (
|
||||
symbol,
|
||||
list(map(_format_range, intersections)),
|
||||
),
|
||||
)
|
||||
return (
|
||||
df.groupby(level=0).apply(_check_asset_group),
|
||||
df[list(mapping_columns)],
|
||||
)
|
||||
|
||||
|
||||
def _dt_to_epoch_ns(dt_series):
|
||||
@@ -187,12 +263,14 @@ def _dt_to_epoch_ns(dt_series):
|
||||
return index.view(np.int64)
|
||||
|
||||
|
||||
def check_version_info(version_table, expected_version):
|
||||
def check_version_info(conn, version_table, expected_version):
|
||||
"""
|
||||
Checks for a version value in the version table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
conn : sa.Connection
|
||||
The connection to use to perform the check.
|
||||
version_table : sa.Table
|
||||
The version table of the asset database
|
||||
expected_version : int
|
||||
@@ -205,7 +283,9 @@ def check_version_info(version_table, expected_version):
|
||||
"""
|
||||
|
||||
# Read the version out of the table
|
||||
version_from_table = sa.select((version_table.c.version,)).scalar()
|
||||
version_from_table = conn.execute(
|
||||
sa.select((version_table.c.version,)),
|
||||
).scalar()
|
||||
|
||||
# A db without a version is considered v0
|
||||
if version_from_table is None:
|
||||
@@ -217,19 +297,21 @@ def check_version_info(version_table, expected_version):
|
||||
expected_version=expected_version)
|
||||
|
||||
|
||||
def write_version_info(version_table, version_value):
|
||||
def write_version_info(conn, version_table, version_value):
|
||||
"""
|
||||
Inserts the version value in to the version table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
conn : sa.Connection
|
||||
The connection to use to execute the insert.
|
||||
version_table : sa.Table
|
||||
The version table of the asset database
|
||||
version_value : int
|
||||
The version to write in to the database
|
||||
|
||||
"""
|
||||
sa.insert(version_table, values={'version': version_value}).execute()
|
||||
conn.execute(sa.insert(version_table, values={'version': version_value}))
|
||||
|
||||
|
||||
class _empty(object):
|
||||
@@ -266,9 +348,6 @@ class AssetDBWriter(object):
|
||||
|
||||
symbol : str
|
||||
The ticker symbol for this equity.
|
||||
fuzzy_symbol : str, optional
|
||||
The fuzzy symbol for this equity. This is the symbol
|
||||
without any delimiting characters like '.' or '_'.
|
||||
asset_name : str
|
||||
The full name for this asset.
|
||||
start_date : datetime
|
||||
@@ -348,7 +427,7 @@ class AssetDBWriter(object):
|
||||
"""
|
||||
with self.engine.begin() as txn:
|
||||
# Create SQL tables if they do not exist.
|
||||
metadata = self.init_db(txn)
|
||||
self.init_db(txn)
|
||||
|
||||
# Get the data to add to SQL.
|
||||
data = self._load_data(
|
||||
@@ -359,51 +438,74 @@ class AssetDBWriter(object):
|
||||
)
|
||||
# Write the data to SQL.
|
||||
self._write_df_to_table(
|
||||
metadata.tables['futures_exchanges'],
|
||||
futures_exchanges,
|
||||
data.exchanges,
|
||||
txn,
|
||||
chunk_size,
|
||||
)
|
||||
self._write_df_to_table(
|
||||
metadata.tables['futures_root_symbols'],
|
||||
futures_root_symbols,
|
||||
data.root_symbols,
|
||||
txn,
|
||||
chunk_size,
|
||||
)
|
||||
asset_router = metadata.tables['asset_router']
|
||||
self._write_assets(
|
||||
asset_router,
|
||||
metadata.tables['futures_contracts'],
|
||||
'future',
|
||||
data.futures,
|
||||
txn,
|
||||
chunk_size,
|
||||
)
|
||||
self._write_assets(
|
||||
asset_router,
|
||||
metadata.tables['equities'],
|
||||
'equity',
|
||||
data.equities,
|
||||
txn,
|
||||
chunk_size,
|
||||
mapping_data=data.equities_mappings,
|
||||
)
|
||||
|
||||
def _write_df_to_table(self, tbl, df, txn, chunk_size):
|
||||
def _write_df_to_table(self, tbl, df, txn, chunk_size, idx_label=None):
|
||||
df.to_sql(
|
||||
tbl.name,
|
||||
txn.connection,
|
||||
index_label=first(tbl.primary_key.columns).name,
|
||||
index_label=(
|
||||
idx_label
|
||||
if idx_label is not None else
|
||||
first(tbl.primary_key.columns).name
|
||||
),
|
||||
if_exists='append',
|
||||
chunksize=chunk_size,
|
||||
)
|
||||
|
||||
def _write_assets(self,
|
||||
asset_router,
|
||||
tbl,
|
||||
asset_type,
|
||||
assets,
|
||||
txn,
|
||||
chunk_size):
|
||||
chunk_size,
|
||||
mapping_data=None):
|
||||
if asset_type == 'future':
|
||||
tbl = futures_contracts_table
|
||||
if mapping_data is not None:
|
||||
raise TypeError('no mapping data expected for futures')
|
||||
|
||||
elif asset_type == 'equity':
|
||||
tbl = equities_table
|
||||
if mapping_data is None:
|
||||
raise TypeError('mapping data required for equities')
|
||||
# write the symbol mapping data.
|
||||
self._write_df_to_table(
|
||||
equity_symbol_mappings,
|
||||
mapping_data,
|
||||
txn,
|
||||
chunk_size,
|
||||
idx_label='sid',
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"asset_type must be in {'future', 'equity'}, got: %s" %
|
||||
asset_type,
|
||||
)
|
||||
|
||||
self._write_df_to_table(tbl, assets, txn, chunk_size)
|
||||
|
||||
pd.DataFrame({
|
||||
@@ -456,17 +558,14 @@ class AssetDBWriter(object):
|
||||
txn = stack.enter_context(self.engine.begin())
|
||||
|
||||
tables_already_exist = self._all_tables_present(txn)
|
||||
metadata = generate_asset_db_metadata(bind=txn)
|
||||
|
||||
# Create the SQL tables if they do not already exist.
|
||||
metadata.create_all(checkfirst=True)
|
||||
metadata.create_all(txn, checkfirst=True)
|
||||
|
||||
version_info = metadata.tables['version_info']
|
||||
if tables_already_exist:
|
||||
check_version_info(version_info, ASSET_DB_VERSION)
|
||||
check_version_info(txn, version_info, ASSET_DB_VERSION)
|
||||
else:
|
||||
write_version_info(version_info, ASSET_DB_VERSION)
|
||||
return metadata
|
||||
write_version_info(txn, version_info, ASSET_DB_VERSION)
|
||||
|
||||
def _normalize_equities(self, equities):
|
||||
# HACK: If 'company_name' is provided, map it to asset_name
|
||||
@@ -487,16 +586,13 @@ class AssetDBWriter(object):
|
||||
tuple_series = equities_output['symbol'].apply(split_delimited_symbol)
|
||||
split_symbols = pd.DataFrame(
|
||||
tuple_series.tolist(),
|
||||
columns=['company_symbol', 'share_class_symbol', 'fuzzy_symbol'],
|
||||
columns=['company_symbol', 'share_class_symbol'],
|
||||
index=tuple_series.index
|
||||
)
|
||||
equities_output = equities_output.join(split_symbols)
|
||||
equities_output = pd.concat((equities_output, split_symbols), axis=1)
|
||||
|
||||
# Upper-case all symbol data
|
||||
for col in ('symbol',
|
||||
'company_symbol',
|
||||
'share_class_symbol',
|
||||
'fuzzy_symbol'):
|
||||
for col in symbol_columns:
|
||||
equities_output[col] = equities_output[col].str.upper()
|
||||
|
||||
# Convert date columns to UNIX Epoch integers (nanoseconds)
|
||||
@@ -506,7 +602,7 @@ class AssetDBWriter(object):
|
||||
'auto_close_date'):
|
||||
equities_output[col] = _dt_to_epoch_ns(equities_output[col])
|
||||
|
||||
return equities_output
|
||||
return _split_symbol_mappings(equities_output)
|
||||
|
||||
def _normalize_futures(self, futures):
|
||||
futures_output = _generate_output_dataframe(
|
||||
@@ -541,7 +637,7 @@ class AssetDBWriter(object):
|
||||
if id_col in df.columns:
|
||||
df.set_index(id_col, inplace=True)
|
||||
|
||||
equities_output = self._normalize_equities(equities)
|
||||
equities_output, equities_mappings = self._normalize_equities(equities)
|
||||
futures_output = self._normalize_futures(futures)
|
||||
|
||||
exchanges_output = _generate_output_dataframe(
|
||||
@@ -556,6 +652,7 @@ class AssetDBWriter(object):
|
||||
|
||||
return AssetData(
|
||||
equities=equities_output,
|
||||
equities_mappings=equities_mappings,
|
||||
futures=futures_output,
|
||||
exchanges=exchanges_output,
|
||||
root_symbols=root_symbols_output,
|
||||
|
||||
+305
-261
@@ -13,16 +13,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABCMeta
|
||||
from collections import namedtuple
|
||||
from numbers import Integral
|
||||
from operator import itemgetter
|
||||
from operator import itemgetter, attrgetter
|
||||
|
||||
from logbook import Logger
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas import isnull
|
||||
from six import with_metaclass, string_types, viewkeys
|
||||
from six.moves import map as imap
|
||||
from six import with_metaclass, string_types, viewkeys, iteritems
|
||||
import sqlalchemy as sa
|
||||
from toolz import merge, compose, valmap, sliding_window, concatv, curry
|
||||
from toolz.curried import operator as op
|
||||
|
||||
from zipline.errors import (
|
||||
EquitiesNotFound,
|
||||
@@ -33,18 +35,20 @@ from zipline.errors import (
|
||||
SidsNotFound,
|
||||
SymbolNotFound,
|
||||
)
|
||||
from zipline.assets import (
|
||||
from . import (
|
||||
Asset, Equity, Future,
|
||||
)
|
||||
from zipline.assets.asset_writer import (
|
||||
from .asset_writer import (
|
||||
check_version_info,
|
||||
split_delimited_symbol,
|
||||
asset_db_table_names,
|
||||
symbol_columns,
|
||||
)
|
||||
from zipline.assets.asset_db_schema import (
|
||||
from .asset_db_schema import (
|
||||
ASSET_DB_VERSION
|
||||
)
|
||||
from zipline.utils.control_flow import invert
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.sqlite_utils import group_into_chunks
|
||||
|
||||
log = Logger('assets.py')
|
||||
@@ -67,12 +71,38 @@ _asset_timestamp_fields = frozenset({
|
||||
'auto_close_date',
|
||||
})
|
||||
|
||||
SymbolOwnership = namedtuple('SymbolOwnership', 'start end sid symbol')
|
||||
|
||||
|
||||
@curry
|
||||
def _filter_kwargs(names, dict_):
|
||||
"""Filter out kwargs from a dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
names : set[str]
|
||||
The names to select from ``dict_``.
|
||||
dict_ : dict[str, any]
|
||||
The dictionary to select from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
kwargs : dict[str, any]
|
||||
``dict_`` where the keys intersect with ``names`` and the values are
|
||||
not None.
|
||||
"""
|
||||
return {k: v for k, v in dict_.items() if k in names and v is not None}
|
||||
|
||||
|
||||
_filter_future_kwargs = _filter_kwargs(Future._kwargnames)
|
||||
_filter_equity_kwargs = _filter_kwargs(Equity._kwargnames)
|
||||
|
||||
|
||||
def _convert_asset_timestamp_fields(dict_):
|
||||
"""
|
||||
Takes in a dict of Asset init args and converts dates to pd.Timestamps
|
||||
"""
|
||||
for key in (_asset_timestamp_fields & viewkeys(dict_)):
|
||||
for key in _asset_timestamp_fields & viewkeys(dict_):
|
||||
value = pd.Timestamp(dict_[key], tz='UTC')
|
||||
dict_[key] = None if isnull(value) else value
|
||||
return dict_
|
||||
@@ -101,17 +131,18 @@ class AssetFinder(object):
|
||||
PERSISTENT_TOKEN = "<AssetFinder>"
|
||||
|
||||
def __init__(self, engine):
|
||||
if isinstance(engine, string_types):
|
||||
engine = sa.create_engine('sqlite:///' + engine)
|
||||
|
||||
self.engine = engine
|
||||
self.engine = engine = (
|
||||
sa.create_engine('sqlite:///' + engine)
|
||||
if isinstance(engine, string_types) else
|
||||
engine
|
||||
)
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
metadata.reflect(only=asset_db_table_names)
|
||||
for table_name in asset_db_table_names:
|
||||
setattr(self, table_name, metadata.tables[table_name])
|
||||
|
||||
# Check the version info of the db for compatibility
|
||||
check_version_info(self.version_info, ASSET_DB_VERSION)
|
||||
check_version_info(engine, self.version_info, ASSET_DB_VERSION)
|
||||
|
||||
# Cache for lookup of assets by sid, the objects in the asset lookup
|
||||
# may be shared with the results from equity and future lookup caches.
|
||||
@@ -137,6 +168,79 @@ class AssetFinder(object):
|
||||
# should be calling this.
|
||||
for cache in self._caches:
|
||||
cache.clear()
|
||||
self.reload_symbol_maps()
|
||||
|
||||
def reload_symbol_maps(self):
|
||||
"""Clear the in memory symbol lookup maps.
|
||||
|
||||
This will make any changes to the underlying db available to the
|
||||
symbol maps.
|
||||
"""
|
||||
# clear the lazyval caches, the next access will requery
|
||||
try:
|
||||
del type(self).symbol_ownership_map[self]
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del type(self).fuzzy_symbol_ownership_map[self]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@lazyval
|
||||
def symbol_ownership_map(self):
|
||||
rows = sa.select(self.equity_symbol_mappings.c).execute().fetchall()
|
||||
|
||||
mappings = {}
|
||||
for row in rows:
|
||||
mappings.setdefault(
|
||||
(row.company_symbol, row.share_class_symbol),
|
||||
[],
|
||||
).append(
|
||||
SymbolOwnership(
|
||||
pd.Timestamp(row.start_date, unit='ns', tz='utc'),
|
||||
pd.Timestamp(row.end_date, unit='ns', tz='utc'),
|
||||
row.sid,
|
||||
row.symbol,
|
||||
),
|
||||
)
|
||||
|
||||
return valmap(
|
||||
lambda v: tuple(
|
||||
SymbolOwnership(
|
||||
a.start,
|
||||
b.start,
|
||||
a.sid,
|
||||
a.symbol,
|
||||
) for a, b in sliding_window(
|
||||
2,
|
||||
concatv(
|
||||
sorted(v),
|
||||
# concat with a fake ownership object to make the last
|
||||
# end date be max timestamp
|
||||
[SymbolOwnership(
|
||||
pd.Timestamp.max.tz_localize('utc'),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)],
|
||||
),
|
||||
)
|
||||
),
|
||||
mappings,
|
||||
factory=lambda: mappings,
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def fuzzy_symbol_ownership_map(self):
|
||||
fuzzy_mappings = {}
|
||||
for (cs, scs), owners in iteritems(self.symbol_ownership_map):
|
||||
fuzzy_owners = fuzzy_mappings.setdefault(
|
||||
cs + scs,
|
||||
[],
|
||||
)
|
||||
fuzzy_owners.extend(owners)
|
||||
fuzzy_owners.sort()
|
||||
return fuzzy_mappings
|
||||
|
||||
def lookup_asset_types(self, sids):
|
||||
"""
|
||||
@@ -326,6 +430,50 @@ class AssetFinder(object):
|
||||
def _select_asset_by_symbol(asset_tbl, symbol):
|
||||
return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol)
|
||||
|
||||
def _lookup_most_recent_symbols(self, sids):
|
||||
symbol_cols = self.equity_symbol_mappings.c
|
||||
|
||||
symbols = {
|
||||
row.sid: {c: row[c] for c in symbol_columns}
|
||||
for row in self.engine.execute(
|
||||
sa.select(
|
||||
(symbol_cols.sid,) +
|
||||
tuple(map(op.getitem(symbol_cols), symbol_columns)),
|
||||
).where(
|
||||
symbol_cols.sid.in_(map(int, sids)),
|
||||
).order_by(
|
||||
symbol_cols.end_date.desc(),
|
||||
).group_by(
|
||||
symbol_cols.sid,
|
||||
)
|
||||
).fetchall()
|
||||
}
|
||||
|
||||
if len(symbols) != len(sids):
|
||||
raise EquitiesNotFound(
|
||||
sids=set(sids) - set(symbols),
|
||||
plural=True,
|
||||
)
|
||||
return symbols
|
||||
|
||||
def _retrieve_asset_dicts(self, sids, asset_tbl, querying_equities):
|
||||
if not sids:
|
||||
return
|
||||
|
||||
if querying_equities:
|
||||
def mkdict(row,
|
||||
symbols=self._lookup_most_recent_symbols(sids)):
|
||||
return merge(row, symbols[row['sid']])
|
||||
else:
|
||||
mkdict = dict
|
||||
|
||||
for assets in group_into_chunks(sids):
|
||||
# Load misses from the db.
|
||||
query = self._select_assets_by_sid(asset_tbl, assets)
|
||||
|
||||
for row in query.execute().fetchall():
|
||||
yield _convert_asset_timestamp_fields(mkdict(row))
|
||||
|
||||
def _retrieve_assets(self, sids, asset_tbl, asset_type):
|
||||
"""
|
||||
Internal function for loading assets from a table.
|
||||
@@ -354,14 +502,18 @@ class AssetFinder(object):
|
||||
cache = self._asset_cache
|
||||
hits = {}
|
||||
|
||||
for assets in group_into_chunks(sids):
|
||||
# Load misses from the db.
|
||||
query = self._select_assets_by_sid(asset_tbl, assets)
|
||||
querying_equities = issubclass(asset_type, Equity)
|
||||
filter_kwargs = (
|
||||
_filter_equity_kwargs
|
||||
if querying_equities else
|
||||
_filter_future_kwargs
|
||||
)
|
||||
|
||||
for row in imap(dict, query.execute().fetchall()):
|
||||
asset = asset_type(**_convert_asset_timestamp_fields(row))
|
||||
sid = asset.sid
|
||||
hits[sid] = cache[sid] = asset
|
||||
rows = self._retrieve_asset_dicts(sids, asset_tbl, querying_equities)
|
||||
for row in rows:
|
||||
sid = row['sid']
|
||||
asset = asset_type(**filter_kwargs(row))
|
||||
hits[sid] = cache[sid] = asset
|
||||
|
||||
# If we get here, it means something in our code thought that a
|
||||
# particular sid was an equity/future and called this function with a
|
||||
@@ -369,166 +521,152 @@ class AssetFinder(object):
|
||||
# an error in our code, not a user-input error.
|
||||
misses = tuple(set(sids) - viewkeys(hits))
|
||||
if misses:
|
||||
if asset_type == Equity:
|
||||
if querying_equities:
|
||||
raise EquitiesNotFound(sids=misses)
|
||||
else:
|
||||
raise FutureContractsNotFound(sids=misses)
|
||||
return hits
|
||||
|
||||
def _get_fuzzy_candidates(self, fuzzy_symbol):
|
||||
candidates = sa.select(
|
||||
(self.equities.c.sid,)
|
||||
).where(self.equities.c.fuzzy_symbol == fuzzy_symbol).order_by(
|
||||
self.equities.c.start_date.desc(),
|
||||
self.equities.c.end_date.desc()
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
|
||||
candidates = sa.select(
|
||||
(self.equities.c.sid,)
|
||||
).where(
|
||||
sa.and_(
|
||||
self.equities.c.fuzzy_symbol == fuzzy_symbol,
|
||||
self.equities.c.start_date <= ad_value,
|
||||
self.equities.c.end_date >= ad_value
|
||||
)
|
||||
).order_by(
|
||||
self.equities.c.start_date.desc(),
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _get_split_candidates_in_range(self,
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value):
|
||||
candidates = sa.select(
|
||||
(self.equities.c.sid,)
|
||||
).where(
|
||||
sa.and_(
|
||||
self.equities.c.company_symbol == company_symbol,
|
||||
self.equities.c.share_class_symbol == share_class_symbol,
|
||||
self.equities.c.start_date <= ad_value,
|
||||
self.equities.c.end_date >= ad_value
|
||||
)
|
||||
).order_by(
|
||||
self.equities.c.start_date.desc(),
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _get_split_candidates(self, company_symbol, share_class_symbol):
|
||||
candidates = sa.select(
|
||||
(self.equities.c.sid,)
|
||||
).where(
|
||||
sa.and_(
|
||||
self.equities.c.company_symbol == company_symbol,
|
||||
self.equities.c.share_class_symbol == share_class_symbol
|
||||
)
|
||||
).order_by(
|
||||
self.equities.c.start_date.desc(),
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _resolve_no_matching_candidates(self,
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value):
|
||||
candidates = sa.select((self.equities.c.sid,)).where(
|
||||
sa.and_(
|
||||
self.equities.c.company_symbol == company_symbol,
|
||||
self.equities.c.share_class_symbol ==
|
||||
share_class_symbol,
|
||||
self.equities.c.start_date <= ad_value),
|
||||
).order_by(
|
||||
self.equities.c.end_date.desc(),
|
||||
).execute().fetchall()
|
||||
return candidates
|
||||
|
||||
def _get_best_candidate(self, candidates):
|
||||
return self._retrieve_equity(candidates[0]['sid'])
|
||||
|
||||
def _get_equities_from_candidates(self, candidates):
|
||||
sids = map(itemgetter('sid'), candidates)
|
||||
results = self.retrieve_equities(sids)
|
||||
return [results[sid] for sid in sids]
|
||||
|
||||
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
|
||||
"""
|
||||
Return matching Equity of name symbol in database.
|
||||
|
||||
If multiple Equities are found and as_of_date is not set,
|
||||
raises MultipleSymbolsFound.
|
||||
|
||||
If no Equity was active at as_of_date raises SymbolNotFound.
|
||||
"""
|
||||
company_symbol, share_class_symbol, fuzzy_symbol = \
|
||||
split_delimited_symbol(symbol)
|
||||
if as_of_date:
|
||||
# Format inputs
|
||||
as_of_date = pd.Timestamp(as_of_date).normalize()
|
||||
ad_value = as_of_date.value
|
||||
|
||||
if fuzzy:
|
||||
# Search for a single exact match on the fuzzy column
|
||||
candidates = self._get_fuzzy_candidates_in_range(fuzzy_symbol,
|
||||
ad_value)
|
||||
|
||||
# If exactly one SID exists for fuzzy_symbol, return that sid
|
||||
if len(candidates) == 1:
|
||||
return self._get_best_candidate(candidates)
|
||||
|
||||
# Search for exact matches of the split-up company_symbol and
|
||||
# share_class_symbol
|
||||
candidates = self._get_split_candidates_in_range(
|
||||
def _lookup_symbol_strict(self, symbol, as_of_date):
|
||||
# split the symbol into the components, if there are no
|
||||
# company/share class parts then share_class_symbol will be empty
|
||||
company_symbol, share_class_symbol = split_delimited_symbol(symbol)
|
||||
try:
|
||||
owners = self.symbol_ownership_map[
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value
|
||||
)
|
||||
|
||||
# If exactly one SID exists for symbol, return that symbol
|
||||
# If multiple SIDs exist for symbol, return latest start_date with
|
||||
# end_date as a tie-breaker
|
||||
if candidates:
|
||||
return self._get_best_candidate(candidates)
|
||||
|
||||
# If no SID exists for symbol, return SID with the
|
||||
# highest-but-not-over end_date
|
||||
elif not candidates:
|
||||
candidates = self._resolve_no_matching_candidates(
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value
|
||||
)
|
||||
if candidates:
|
||||
return self._get_best_candidate(candidates)
|
||||
|
||||
]
|
||||
assert owners, 'empty owners list for %r' % symbol
|
||||
except KeyError:
|
||||
# no equity has ever held this symbol
|
||||
raise SymbolNotFound(symbol=symbol)
|
||||
|
||||
else:
|
||||
# If this is a fuzzy look-up, check if there is exactly one match
|
||||
# for the fuzzy symbol
|
||||
if fuzzy:
|
||||
candidates = self._get_fuzzy_candidates(fuzzy_symbol)
|
||||
if len(candidates) == 1:
|
||||
return self._get_best_candidate(candidates)
|
||||
|
||||
candidates = self._get_split_candidates(company_symbol,
|
||||
share_class_symbol)
|
||||
if len(candidates) == 1:
|
||||
return self._get_best_candidate(candidates)
|
||||
elif not candidates:
|
||||
raise SymbolNotFound(symbol=symbol)
|
||||
else:
|
||||
if not as_of_date:
|
||||
if len(owners) > 1:
|
||||
# more than one equity has held this ticker, this is ambigious
|
||||
# without the date
|
||||
raise MultipleSymbolsFound(
|
||||
symbol=symbol,
|
||||
options=self._get_equities_from_candidates(candidates)
|
||||
options=set(map(
|
||||
compose(self.retrieve_asset, attrgetter('sid')),
|
||||
owners,
|
||||
)),
|
||||
)
|
||||
|
||||
# exactly one equity has ever held this symbol, we may resolve
|
||||
# without the date
|
||||
return self.retrieve_asset(owners[0].sid)
|
||||
|
||||
for start, end, sid, _ in owners:
|
||||
if start <= as_of_date < end:
|
||||
# find the equity that owned it on the given asof date
|
||||
return self.retrieve_asset(sid)
|
||||
|
||||
# no equity held the ticker on the given asof date
|
||||
raise SymbolNotFound(symbol=symbol)
|
||||
|
||||
def _lookup_symbol_fuzzy(self, symbol, as_of_date):
|
||||
symbol = symbol.upper()
|
||||
company_symbol, share_class_symbol = split_delimited_symbol(symbol)
|
||||
try:
|
||||
owners = self.fuzzy_symbol_ownership_map[
|
||||
company_symbol + share_class_symbol
|
||||
]
|
||||
assert owners, 'empty owners list for %r' % symbol
|
||||
except KeyError:
|
||||
# no equity has ever held a symbol matching the fuzzy symbol
|
||||
raise SymbolNotFound(symbol=symbol)
|
||||
|
||||
if not as_of_date:
|
||||
if len(owners) == 1:
|
||||
# only one valid match
|
||||
return self.retrieve_asset(owners[0].sid)
|
||||
|
||||
options = []
|
||||
for _, _, sid, sym in owners:
|
||||
if sym == symbol:
|
||||
# there are multiple options, look for exact matches
|
||||
options.append(self.retrieve_asset(sid))
|
||||
|
||||
if len(options) == 1:
|
||||
# there was only one exact match
|
||||
return options[0]
|
||||
|
||||
# there are more than one exact match for this fuzzy symbol
|
||||
raise MultipleSymbolsFound(
|
||||
symbol=symbol,
|
||||
options=set(options),
|
||||
)
|
||||
|
||||
options = []
|
||||
for start, end, sid, sym in owners:
|
||||
if start <= as_of_date < end:
|
||||
# see which fuzzy symbols were owned on the asof date.
|
||||
options.append((sid, sym))
|
||||
|
||||
if not options:
|
||||
# no equity owned the fuzzy symbol on the date requested
|
||||
SymbolNotFound(symbol=symbol)
|
||||
|
||||
if len(options) == 1:
|
||||
# there was only one owner, return it
|
||||
return self.retrieve_asset(options[0][0])
|
||||
|
||||
for sid, sym in options:
|
||||
if sym == symbol:
|
||||
# look for an exact match on the asof date
|
||||
return self.retrieve_asset(sid)
|
||||
|
||||
# multiple equities held tickers matching the fuzzy ticker but
|
||||
# there are no exact matches
|
||||
raise MultipleSymbolsFound(
|
||||
symbol=symbol,
|
||||
options=set(map(
|
||||
compose(self.retrieve_asset, itemgetter(0)),
|
||||
options,
|
||||
)),
|
||||
)
|
||||
|
||||
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
|
||||
"""Lookup an equity by symbol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol : str
|
||||
The ticker symbol to resolve.
|
||||
as_of_date : datetime or None
|
||||
Look up the last owner of this symbol as of this datetime.
|
||||
If ``as_of_date`` is None, then this can only resolve the equity
|
||||
if exactly one equity has ever owned the ticker.
|
||||
fuzzy : bool, optional
|
||||
Should fuzzy symbol matching be used? Fuzzy symbol matching
|
||||
attempts to resolve differences in representations for
|
||||
shareclasses. For example, some people may represent the ``A``
|
||||
shareclass of ``BRK`` as ``BRK.A``, where others could write
|
||||
``BRK_A``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
equity : Equity
|
||||
The equity that held ``symbol`` on the given ``as_of_date``, or the
|
||||
only equity to hold ``symbol`` if ``as_of_date`` is None.
|
||||
|
||||
Raises
|
||||
------
|
||||
SymbolNotFound
|
||||
Raised when no equity has ever held the given symbol.
|
||||
MultipleSymbolsFound
|
||||
Raised when no ``as_of_date`` is given and more than one equity
|
||||
has held ``symbol``. This is also raised when ``fuzzy=True`` and
|
||||
there are multiple candidates for the given ``symbol`` on the
|
||||
``as_of_date``.
|
||||
"""
|
||||
if fuzzy:
|
||||
return self._lookup_symbol_fuzzy(symbol, as_of_date)
|
||||
return self._lookup_symbol_strict(symbol, as_of_date)
|
||||
|
||||
def lookup_future_symbol(self, symbol):
|
||||
""" Return the Future object for a given symbol.
|
||||
"""Lookup a future contract by symbol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -537,8 +675,8 @@ class AssetFinder(object):
|
||||
|
||||
Returns
|
||||
-------
|
||||
Future
|
||||
A Future object.
|
||||
future : Future
|
||||
The future contract referenced by ``symbol``.
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -946,100 +1084,6 @@ class NotAssetConvertible(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class AssetFinderCachedEquities(AssetFinder):
|
||||
"""
|
||||
An extension to AssetFinder that preloads all equities from equities table
|
||||
into memory and does lookups from there.
|
||||
|
||||
To have any changes in the underlying assets db reflected by this asset
|
||||
finder one must manually call the ``rehash_equities`` method.
|
||||
"""
|
||||
|
||||
def __init__(self, engine):
|
||||
super(AssetFinderCachedEquities, self).__init__(engine)
|
||||
self._fuzzy_symbol_cache = {}
|
||||
self._company_share_class_cache = {}
|
||||
|
||||
self.rehash_equities()
|
||||
|
||||
def rehash_equities(self):
|
||||
"""Reload the underlying assets db into the in memory cache.
|
||||
"""
|
||||
for equity in sa.select(self.equities.c).execute().fetchall():
|
||||
company_symbol = equity['company_symbol']
|
||||
share_class_symbol = equity['share_class_symbol']
|
||||
fuzzy_symbol = equity['fuzzy_symbol']
|
||||
asset = self._convert_row_to_equity(equity)
|
||||
self._company_share_class_cache.setdefault(
|
||||
(company_symbol, share_class_symbol),
|
||||
[]
|
||||
).append(asset)
|
||||
self._fuzzy_symbol_cache.setdefault(
|
||||
fuzzy_symbol,
|
||||
[],
|
||||
).append(asset)
|
||||
|
||||
def _convert_row_to_equity(self, row):
|
||||
"""
|
||||
Converts a SQLAlchemy equity row to an Equity object.
|
||||
"""
|
||||
return Equity(**_convert_asset_timestamp_fields(dict(row)))
|
||||
|
||||
def _get_fuzzy_candidates(self, fuzzy_symbol):
|
||||
return self._fuzzy_symbol_cache.get(fuzzy_symbol, ())
|
||||
|
||||
def _get_fuzzy_candidates_in_range(self, fuzzy_symbol, ad_value):
|
||||
return only_active_assets(
|
||||
ad_value,
|
||||
self._get_fuzzy_candidates(fuzzy_symbol),
|
||||
)
|
||||
|
||||
def _get_split_candidates(self, company_symbol, share_class_symbol):
|
||||
return self._company_share_class_cache.get(
|
||||
(company_symbol, share_class_symbol),
|
||||
(),
|
||||
)
|
||||
|
||||
def _get_split_candidates_in_range(self,
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value):
|
||||
return sorted(
|
||||
only_active_assets(
|
||||
ad_value,
|
||||
self._get_split_candidates(company_symbol, share_class_symbol),
|
||||
),
|
||||
key=lambda x: (x.start_date, x.end_date),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def _resolve_no_matching_candidates(self,
|
||||
company_symbol,
|
||||
share_class_symbol,
|
||||
ad_value):
|
||||
equities = self._get_split_candidates(
|
||||
company_symbol,
|
||||
share_class_symbol
|
||||
)
|
||||
partial_candidates = []
|
||||
for equity in equities:
|
||||
if equity.start_date.value <= ad_value:
|
||||
partial_candidates.append(equity)
|
||||
if partial_candidates:
|
||||
partial_candidates = sorted(
|
||||
partial_candidates,
|
||||
key=lambda x: x.end_date,
|
||||
reverse=True
|
||||
)
|
||||
return partial_candidates
|
||||
|
||||
def _get_best_candidate(self, candidates):
|
||||
return candidates[0]
|
||||
|
||||
def _get_equities_from_candidates(self, candidates):
|
||||
return candidates
|
||||
|
||||
|
||||
def was_active(reference_date_value, asset):
|
||||
"""
|
||||
Whether or not `asset` was active at the time corresponding to
|
||||
|
||||
@@ -6,6 +6,7 @@ from pandas_datareader.data import DataReader
|
||||
import requests
|
||||
|
||||
from zipline.utils.cli import maybe_show_progress
|
||||
from .core import register
|
||||
|
||||
|
||||
def _cachpath(symbol, type_):
|
||||
@@ -169,3 +170,24 @@ def yahoo_equities(symbols, start=None, end=None):
|
||||
adjustment_writer.write(splits=splits, dividends=dividends)
|
||||
|
||||
return ingest
|
||||
|
||||
|
||||
# bundle used when creating test data
|
||||
register(
|
||||
'.test',
|
||||
yahoo_equities(
|
||||
(
|
||||
'AMD',
|
||||
'CERN',
|
||||
'COST',
|
||||
'DELL',
|
||||
'GPS',
|
||||
'INTC',
|
||||
'MMM',
|
||||
'AAPL',
|
||||
'MSFT',
|
||||
),
|
||||
pd.Timestamp('2004-01-02', tz='utc'),
|
||||
pd.Timestamp('2015-01-01', tz='utc'),
|
||||
),
|
||||
)
|
||||
|
||||
+1
-1
@@ -281,7 +281,7 @@ Multiple symbols with the name '{symbol}' found. Use the
|
||||
as_of_date' argument to to specify when the date symbol-lookup
|
||||
should be valid.
|
||||
|
||||
Possible options:{options}
|
||||
Possible options: {options}
|
||||
""".strip()
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,11 @@ from nose.tools import ( # noqa
|
||||
)
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.util.testing import (
|
||||
assert_frame_equal,
|
||||
assert_panel_equal,
|
||||
assert_series_equal,
|
||||
)
|
||||
from six import iteritems, viewkeys, PY2
|
||||
from toolz import dissoc, keyfilter
|
||||
import toolz.curried.operator as op
|
||||
@@ -393,18 +397,49 @@ def assert_array_equal(result,
|
||||
raise AssertionError('\n'.join((str(e), _fmt_path(path))))
|
||||
|
||||
|
||||
@assert_equal.register(pd.DataFrame, pd.DataFrame)
|
||||
def assert_dataframe_equal(result, expected, path=(), msg='', **kwargs):
|
||||
try:
|
||||
assert_frame_equal(
|
||||
result,
|
||||
expected,
|
||||
**filter_kwargs(assert_frame_equal, kwargs)
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise AssertionError(
|
||||
_fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
|
||||
)
|
||||
def _register_assert_ndframe_equal(type_, assert_eq):
|
||||
"""Register a new check for an ndframe object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type_ : type
|
||||
The class to register an ``assert_equal`` dispatch for.
|
||||
assert_eq : callable[type_, type_]
|
||||
The function which checks that if the two ndframes are equal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
assert_ndframe_equal : callable[type_, type_]
|
||||
The wrapped function registered with ``assert_equal``.
|
||||
"""
|
||||
@assert_equal.register(type_, type_)
|
||||
def assert_ndframe_equal(result, expected, path=(), msg='', **kwargs):
|
||||
try:
|
||||
assert_eq(
|
||||
result,
|
||||
expected,
|
||||
**filter_kwargs(assert_frame_equal, kwargs)
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise AssertionError(
|
||||
_fmt_msg(msg) + '\n'.join((str(e), _fmt_path(path))),
|
||||
)
|
||||
|
||||
return assert_ndframe_equal
|
||||
|
||||
|
||||
assert_frame_equal = _register_assert_ndframe_equal(
|
||||
pd.DataFrame,
|
||||
assert_frame_equal,
|
||||
)
|
||||
assert_panel_equal = _register_assert_ndframe_equal(
|
||||
pd.Panel,
|
||||
assert_panel_equal,
|
||||
)
|
||||
assert_series_equal = _register_assert_ndframe_equal(
|
||||
pd.Series,
|
||||
assert_series_equal,
|
||||
)
|
||||
|
||||
|
||||
@assert_equal.register(Adjustment, Adjustment)
|
||||
|
||||
+54
-10
@@ -1,8 +1,9 @@
|
||||
from functools import reduce
|
||||
from pprint import pformat
|
||||
|
||||
from six import viewkeys
|
||||
from six.moves import map, zip
|
||||
from toolz import curry
|
||||
from toolz import curry, flip
|
||||
|
||||
|
||||
@curry
|
||||
@@ -332,17 +333,60 @@ with_name = set_attribute('__name__')
|
||||
with_doc = set_attribute('__doc__')
|
||||
|
||||
|
||||
def let(a):
|
||||
"""Box a value to be bound in a for binding.
|
||||
def foldr(f, seq, default=_no_default):
|
||||
"""Fold a function over a sequence with right associativity.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
f : callable[any, any]
|
||||
The function to reduce the sequence with.
|
||||
The first argument will be the element of the sequence; the second
|
||||
argument will be the accumulator.
|
||||
seq : iterable[any]
|
||||
The sequence to reduce.
|
||||
default : any, optional
|
||||
The starting value to reduce with. If not provided, the sequence
|
||||
cannot be empty, and the last value of the sequence will be used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
folded : any
|
||||
The folded value.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This functions works by reducing the list in a right associative way.
|
||||
|
||||
For example, imagine we are folding with ``operator.add`` or ``+``:
|
||||
|
||||
Examples
|
||||
--------
|
||||
.. code-block:: python
|
||||
|
||||
[f(y, y) for x in xs for y in let(g(x)) if p(y)]
|
||||
foldr(add, seq) -> seq[0] + (seq[1] + (seq[2] + (...seq[-1], default)))
|
||||
|
||||
Here, ``y`` is available in both the predicate and the expression
|
||||
of the comprehension. We can see that this allows us to cache the work
|
||||
of computing ``g(x)`` even within the expression.
|
||||
In the more general case with an arbitrary function, ``foldr`` will expand
|
||||
like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
foldr(f, seq) -> f(seq[0], f(seq[1], f(seq[2], ...f(seq[-1], default))))
|
||||
|
||||
For a more in depth discussion of left and right folds, see:
|
||||
`https://en.wikipedia.org/wiki/Fold_(higher-order_function)`_
|
||||
The images in that page are very good for showing the differences between
|
||||
``foldr`` and ``foldl`` (``reduce``).
|
||||
|
||||
.. note::
|
||||
|
||||
For performance reasons is is best to pass a strict (non-lazy) sequence,
|
||||
for example, a list.
|
||||
|
||||
See Also
|
||||
--------
|
||||
:func:`functools.reduce`
|
||||
:func:`sum`
|
||||
"""
|
||||
return a,
|
||||
return reduce(
|
||||
flip(f),
|
||||
reversed(seq),
|
||||
*(default,) if default is not _no_default else ()
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ def hidden(path):
|
||||
path : str
|
||||
A filepath.
|
||||
"""
|
||||
return path.startswith('.')
|
||||
return os.path.split(path)[1].startswith('.')
|
||||
|
||||
|
||||
def ensure_directory(path):
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
import operator as op
|
||||
|
||||
from six import PY2
|
||||
from toolz import peek
|
||||
|
||||
from zipline.utils.functional import foldr
|
||||
|
||||
|
||||
if PY2:
|
||||
class range(object):
|
||||
"""Lazy range object with constant time containment check.
|
||||
|
||||
The arguments are the same as ``range``.
|
||||
"""
|
||||
__slots__ = 'start', 'stop', 'step'
|
||||
|
||||
def __init__(self, stop, *args):
|
||||
if len(args) > 2:
|
||||
raise TypeError(
|
||||
'range takes at most 3 arguments (%d given)' % len(args)
|
||||
)
|
||||
|
||||
if not args:
|
||||
self.start = 0
|
||||
self.stop = stop
|
||||
self.step = 1
|
||||
else:
|
||||
self.start = stop
|
||||
self.stop = args[0]
|
||||
try:
|
||||
self.step = args[1]
|
||||
except IndexError:
|
||||
self.step = 1
|
||||
|
||||
def __iter__(self):
|
||||
n = self.start
|
||||
stop = self.stop
|
||||
step = self.step
|
||||
while n < stop:
|
||||
yield n
|
||||
n += step
|
||||
|
||||
_ops = (
|
||||
(op.gt, op.ge),
|
||||
(op.le, op.lt),
|
||||
)
|
||||
|
||||
def __contains__(self, other, _ops=_ops):
|
||||
start = self.start
|
||||
step = self.step
|
||||
cmp_start, cmp_stop = _ops[step > 0]
|
||||
return (
|
||||
cmp_start(start, other) and
|
||||
cmp_stop(other, self.stop) and
|
||||
(other - start) % step == 0
|
||||
)
|
||||
|
||||
del _ops
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s%s)' % (
|
||||
type(self).__name__,
|
||||
self.start,
|
||||
self.stop,
|
||||
(', ' + str(self.step)) if self.step != 1 else '',
|
||||
)
|
||||
else:
|
||||
range = range
|
||||
|
||||
|
||||
def from_tuple(tup):
|
||||
"""Convert a tuple into a range with error handling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tup : tuple (len 2 or 3)
|
||||
The tuple to turn into a range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
range : range
|
||||
The range from the tuple.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raised when the tuple length is not 2 or 3.
|
||||
"""
|
||||
if len(tup) not in (2, 3):
|
||||
raise ValueError(
|
||||
'tuple must contain 2 or 3 elements, not: %d (%r' % (
|
||||
len(tup),
|
||||
tup,
|
||||
),
|
||||
)
|
||||
return range(*tup)
|
||||
|
||||
|
||||
def maybe_from_tuple(tup_or_range):
|
||||
"""Convert a tuple into a range but pass ranges through silently.
|
||||
|
||||
This is useful to ensure that input is a range so that attributes may
|
||||
be accessed with `.start`, `.stop` or so that containment checks are
|
||||
constant time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tup_or_range : tuple or range
|
||||
A tuple to pass to from_tuple or a range to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
range : range
|
||||
The input to convert to a range.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raised when the input is not a tuple or a range. ValueError is also
|
||||
raised if the input is a tuple whose length is not 2 or 3.
|
||||
"""
|
||||
if isinstance(tup_or_range, tuple):
|
||||
return from_tuple(tup_or_range)
|
||||
elif isinstance(tup_or_range, range):
|
||||
return tup_or_range
|
||||
|
||||
raise ValueError(
|
||||
'maybe_from_tuple expects a tuple or range, got %r: %r' % (
|
||||
type(tup_or_range).__name__,
|
||||
tup_or_range,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _check_steps(a, b):
|
||||
"""Check that the steps of ``a`` and ``b`` are both 1.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : range
|
||||
The first range to check.
|
||||
b : range
|
||||
The second range to check.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raised when either step is not 1.
|
||||
"""
|
||||
if a.step != 1:
|
||||
raise ValueError('a.step must be equal to 1, got: %s' % a.step)
|
||||
if b.step != 1:
|
||||
raise ValueError('b.step must be equal to 1, got: %s' % b.step)
|
||||
|
||||
|
||||
def overlap(a, b):
|
||||
"""Check if two ranges overlap.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : range
|
||||
The first range.
|
||||
b : range
|
||||
The second range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
overlaps : bool
|
||||
Do these ranges overlap.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function does not support ranges with step != 1.
|
||||
"""
|
||||
_check_steps(a, b)
|
||||
return a.stop >= b.start and b.stop >= a.start
|
||||
|
||||
|
||||
def merge(a, b):
|
||||
"""Merge two ranges with step == 1.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : range
|
||||
The first range.
|
||||
b : range
|
||||
The second range.
|
||||
"""
|
||||
_check_steps(a, b)
|
||||
return range(min(a.start, b.start), max(a.stop, b.stop))
|
||||
|
||||
|
||||
def _combine(n, rs):
|
||||
"""helper for ``_group_ranges``
|
||||
"""
|
||||
try:
|
||||
r, rs = peek(rs)
|
||||
except StopIteration:
|
||||
yield n
|
||||
return
|
||||
|
||||
if overlap(n, r):
|
||||
yield merge(n, r)
|
||||
next(rs)
|
||||
for r in rs:
|
||||
yield r
|
||||
else:
|
||||
yield n
|
||||
for r in rs:
|
||||
yield r
|
||||
|
||||
|
||||
def group_ranges(ranges):
|
||||
"""Group any overlapping ranges into a single range.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ranges : iterable[ranges]
|
||||
A sorted sequence of ranges to group.
|
||||
|
||||
Returns
|
||||
-------
|
||||
grouped : iterable[ranges]
|
||||
A sorted sequence of ranges with overlapping ranges merged together.
|
||||
"""
|
||||
return foldr(_combine, ranges, ())
|
||||
|
||||
|
||||
def sorted_diff(rs, ss):
|
||||
try:
|
||||
r, rs = peek(rs)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
try:
|
||||
s, ss = peek(ss)
|
||||
except StopIteration:
|
||||
for r in rs:
|
||||
yield r
|
||||
return
|
||||
|
||||
rtup = (r.start, r.stop)
|
||||
stup = (s.start, s.stop)
|
||||
if rtup == stup:
|
||||
next(rs)
|
||||
next(ss)
|
||||
elif rtup < stup:
|
||||
yield next(rs)
|
||||
else:
|
||||
next(ss)
|
||||
|
||||
for t in sorted_diff(rs, ss):
|
||||
yield t
|
||||
|
||||
|
||||
def intersecting_ranges(ranges):
|
||||
"""Return any ranges that intersect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ranges : iterable[ranges]
|
||||
A sequence of ranges to check for intersections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
intersections : iterable[ranges]
|
||||
A sequence of all of the ranges that intersected in ``ranges``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> ranges = [range(0, 1), range(2, 5), range(4, 7)]
|
||||
>>> list(intersecting_ranges(ranges))
|
||||
[range(2, 5), range(4, 7)]
|
||||
|
||||
>>> ranges = [range(0, 1), range(2, 3)]
|
||||
>>> list(intersecting_ranges(ranges))
|
||||
[]
|
||||
|
||||
>>> ranges = [range(0, 1), range(1, 2)]
|
||||
>>> list(intersecting_ranges(ranges))
|
||||
[range(0, 1), range(1, 2)]
|
||||
"""
|
||||
ranges = sorted(ranges, key=op.attrgetter('start'))
|
||||
return sorted_diff(ranges, group_ranges(ranges))
|
||||
Reference in New Issue
Block a user