mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 16:11:42 +08:00
Merge pull request #1368 from quantopian/lots-of-symbols
Lots of symbols
This commit is contained in:
+25
-1
@@ -30,7 +30,7 @@ from nose_parameterized import parameterized
|
||||
from numpy import full, int32, int64
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from six import PY2
|
||||
from six import PY2, viewkeys
|
||||
import sqlalchemy as sa
|
||||
|
||||
from zipline.assets import (
|
||||
@@ -57,6 +57,7 @@ from zipline.assets.asset_writer import (
|
||||
check_version_info,
|
||||
write_version_info,
|
||||
_futures_defaults,
|
||||
SQLITE_MAX_VARIABLE_NUMBER,
|
||||
)
|
||||
from zipline.assets.asset_db_schema import ASSET_DB_VERSION
|
||||
from zipline.assets.asset_db_migrations import (
|
||||
@@ -83,6 +84,7 @@ from zipline.testing.fixtures import (
|
||||
ZiplineTestCase,
|
||||
WithTradingCalendar,
|
||||
)
|
||||
from zipline.utils.range import range
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -407,6 +409,28 @@ class AssetFinderTestCase(WithTradingCalendar, ZiplineTestCase):
|
||||
self._asset_writer = AssetDBWriter(conn)
|
||||
self.asset_finder = self.asset_finder_type(conn)
|
||||
|
||||
def test_blocked_lookup_symbol_query(self):
|
||||
# we will try to query for more variables than sqlite supports
|
||||
# to make sure we are properly chunking on the client side
|
||||
as_of = pd.Timestamp('2013-01-01', tz='UTC')
|
||||
# we need more sids than we can query from sqlite
|
||||
nsids = SQLITE_MAX_VARIABLE_NUMBER + 10
|
||||
sids = range(nsids)
|
||||
frame = pd.DataFrame.from_records(
|
||||
[
|
||||
{
|
||||
'sid': sid,
|
||||
'symbol': 'TEST.%d' % sid,
|
||||
'start_date': as_of.value,
|
||||
'end_date': as_of.value,
|
||||
}
|
||||
for sid in sids
|
||||
]
|
||||
)
|
||||
self.write_assets(equities=frame)
|
||||
assets = self.asset_finder.retrieve_equities(sids)
|
||||
assert_equal(viewkeys(assets), set(sids))
|
||||
|
||||
def test_lookup_symbol_delimited(self):
|
||||
as_of = pd.Timestamp('2013-01-01', tz='UTC')
|
||||
frame = pd.DataFrame.from_records(
|
||||
|
||||
+29
-14
@@ -23,7 +23,16 @@ import pandas as pd
|
||||
from pandas import isnull
|
||||
from six import with_metaclass, string_types, viewkeys, iteritems
|
||||
import sqlalchemy as sa
|
||||
from toolz import merge, compose, valmap, sliding_window, concatv, curry
|
||||
from toolz import (
|
||||
compose,
|
||||
concat,
|
||||
concatv,
|
||||
curry,
|
||||
merge,
|
||||
partition_all,
|
||||
sliding_window,
|
||||
valmap,
|
||||
)
|
||||
from toolz.curried import operator as op
|
||||
|
||||
from zipline.errors import (
|
||||
@@ -43,6 +52,7 @@ from .asset_writer import (
|
||||
split_delimited_symbol,
|
||||
asset_db_table_names,
|
||||
symbol_columns,
|
||||
SQLITE_MAX_VARIABLE_NUMBER,
|
||||
)
|
||||
from .asset_db_schema import (
|
||||
ASSET_DB_VERSION
|
||||
@@ -432,21 +442,26 @@ class AssetFinder(object):
|
||||
|
||||
def _lookup_most_recent_symbols(self, sids):
|
||||
symbol_cols = self.equity_symbol_mappings.c
|
||||
|
||||
symbols = {
|
||||
row.sid: {c: row[c] for c in symbol_columns}
|
||||
for row in self.engine.execute(
|
||||
sa.select(
|
||||
(symbol_cols.sid,) +
|
||||
tuple(map(op.getitem(symbol_cols), symbol_columns)),
|
||||
).where(
|
||||
symbol_cols.sid.in_(map(int, sids)),
|
||||
).order_by(
|
||||
symbol_cols.end_date.desc(),
|
||||
).group_by(
|
||||
symbol_cols.sid,
|
||||
)
|
||||
).fetchall()
|
||||
for row in concat(
|
||||
self.engine.execute(
|
||||
sa.select(
|
||||
(symbol_cols.sid,) +
|
||||
tuple(map(op.getitem(symbol_cols), symbol_columns)),
|
||||
).where(
|
||||
symbol_cols.sid.in_(map(int, sid_group)),
|
||||
).order_by(
|
||||
symbol_cols.end_date.desc(),
|
||||
).group_by(
|
||||
symbol_cols.sid,
|
||||
)
|
||||
).fetchall()
|
||||
for sid_group in partition_all(
|
||||
SQLITE_MAX_VARIABLE_NUMBER,
|
||||
sids
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
if len(symbols) != len(sids):
|
||||
|
||||
+57
-1
@@ -32,11 +32,31 @@ if PY2:
|
||||
except IndexError:
|
||||
self.step = 1
|
||||
|
||||
if self.step == 0:
|
||||
raise ValueError('range step must not be zero')
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Examples
|
||||
--------
|
||||
>>> list(range(1))
|
||||
[0]
|
||||
>>> list(range(5))
|
||||
[0, 1, 2, 3, 4]
|
||||
>>> list(range(1, 5))
|
||||
[1, 2, 3, 4]
|
||||
>>> list(range(0, 5, 2))
|
||||
[0, 2, 4]
|
||||
>>> list(range(5, 0, -1))
|
||||
[5, 4, 3, 2, 1]
|
||||
>>> list(range(5, 0, 1))
|
||||
[]
|
||||
"""
|
||||
n = self.start
|
||||
stop = self.stop
|
||||
step = self.step
|
||||
while n < stop:
|
||||
cmp_ = op.lt if step > 0 else op.gt
|
||||
while cmp_(n, stop):
|
||||
yield n
|
||||
n += step
|
||||
|
||||
@@ -46,6 +66,8 @@ if PY2:
|
||||
)
|
||||
|
||||
def __contains__(self, other, _ops=_ops):
|
||||
# Algorithm taken from CPython
|
||||
# Objects/rangeobject.c:range_contains_long
|
||||
start = self.start
|
||||
step = self.step
|
||||
cmp_start, cmp_stop = _ops[step > 0]
|
||||
@@ -57,6 +79,40 @@ if PY2:
|
||||
|
||||
del _ops
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Examples
|
||||
--------
|
||||
>>> len(range(1))
|
||||
1
|
||||
>>> len(range(5))
|
||||
5
|
||||
>>> len(range(1, 5))
|
||||
4
|
||||
>>> len(range(0, 5, 2))
|
||||
3
|
||||
>>> len(range(5, 0, -1))
|
||||
5
|
||||
>>> len(range(5, 0, 1))
|
||||
0
|
||||
"""
|
||||
# Algorithm taken from CPython
|
||||
# rangeobject.c:compute_range_length
|
||||
step = self.step
|
||||
|
||||
if step > 0:
|
||||
low = self.start
|
||||
high = self.stop
|
||||
else:
|
||||
low = self.stop
|
||||
high = self.start
|
||||
step = -step
|
||||
|
||||
if low >= high:
|
||||
return 0
|
||||
|
||||
return (high - low - 1) // step + 1
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s%s)' % (
|
||||
type(self).__name__,
|
||||
|
||||
Reference in New Issue
Block a user