mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 05:07:29 +08:00
MAINT: Consolidate coercion to sqlite conn/eng
This commit is contained in:
@@ -7,8 +7,11 @@ from toolz.curried import do, operator as op
|
||||
|
||||
from zipline.assets.asset_writer import write_version_info
|
||||
from zipline.errors import AssetDBImpossibleDowngrade
|
||||
from zipline.utils.preprocess import preprocess
|
||||
from zipline.utils.sqlite_utils import coerce_string_to_eng
|
||||
|
||||
|
||||
@preprocess(engine=coerce_string_to_eng)
|
||||
def downgrade(engine, desired_version):
|
||||
"""Downgrades the assets db at the given engine to the desired version.
|
||||
|
||||
|
||||
@@ -35,7 +35,9 @@ from zipline.assets.asset_db_schema import (
|
||||
version_info,
|
||||
)
|
||||
|
||||
from zipline.utils.preprocess import preprocess
|
||||
from zipline.utils.range import from_tuple, intersecting_ranges
|
||||
from zipline.utils.sqlite_utils import coerce_string_to_eng
|
||||
|
||||
# Define a namedtuple for use with the load_data and _load_data methods
|
||||
AssetData = namedtuple(
|
||||
@@ -344,9 +346,8 @@ class AssetDBWriter(object):
|
||||
"""
|
||||
DEFAULT_CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER
|
||||
|
||||
@preprocess(engine=coerce_string_to_eng)
|
||||
def __init__(self, engine):
|
||||
if isinstance(engine, str):
|
||||
engine = sa.create_engine('sqlite:///' + engine)
|
||||
self.engine = engine
|
||||
|
||||
def write(self,
|
||||
|
||||
@@ -60,7 +60,8 @@ from .asset_db_schema import (
|
||||
from zipline.utils.control_flow import invert
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import as_column
|
||||
from zipline.utils.sqlite_utils import group_into_chunks
|
||||
from zipline.utils.preprocess import preprocess
|
||||
from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_eng
|
||||
|
||||
log = Logger('assets.py')
|
||||
|
||||
@@ -141,12 +142,9 @@ class AssetFinder(object):
|
||||
# reference to an AssetFinder.
|
||||
PERSISTENT_TOKEN = "<AssetFinder>"
|
||||
|
||||
@preprocess(engine=coerce_string_to_eng)
|
||||
def __init__(self, engine):
|
||||
self.engine = engine = (
|
||||
sa.create_engine('sqlite:///' + engine)
|
||||
if isinstance(engine, string_types) else
|
||||
engine
|
||||
)
|
||||
self.engine = engine
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
metadata.reflect(only=asset_db_table_names)
|
||||
for table_name in asset_db_table_names:
|
||||
|
||||
@@ -54,12 +54,11 @@ from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.functional import apply
|
||||
from zipline.utils.preprocess import call
|
||||
from zipline.utils.input_validation import (
|
||||
coerce_string,
|
||||
preprocess,
|
||||
expect_element,
|
||||
verify_indices_all_unique,
|
||||
)
|
||||
from zipline.utils.sqlite_utils import group_into_chunks
|
||||
from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_conn
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.cli import maybe_show_progress
|
||||
from ._equities import _compute_row_slices, _read_bcolz_data
|
||||
@@ -1228,7 +1227,7 @@ class SQLiteAdjustmentReader(object):
|
||||
:class:`zipline.data.us_equity_pricing.SQLiteAdjustmentWriter`
|
||||
"""
|
||||
|
||||
@preprocess(conn=coerce_string(sqlite3.connect))
|
||||
@preprocess(conn=coerce_string_to_conn)
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ class TradingEnvironment(object):
|
||||
self.exchange_tz = exchange_tz
|
||||
|
||||
if isinstance(asset_db_path, string_types):
|
||||
asset_db_path = 'sqlite:///%s' % asset_db_path
|
||||
asset_db_path = 'sqlite:///' + asset_db_path
|
||||
self.engine = engine = create_engine(asset_db_path)
|
||||
else:
|
||||
self.engine = engine = asset_db_path
|
||||
|
||||
@@ -35,9 +35,6 @@ def engine_from_files(daily_bar_path,
|
||||
memory consumption. Default is False
|
||||
"""
|
||||
loader = USEquityPricingLoader.from_files(daily_bar_path, adjustments_path)
|
||||
|
||||
if not asset_db_path.startswith("sqlite:"):
|
||||
asset_db_path = "sqlite:///" + asset_db_path
|
||||
asset_finder = AssetFinder(asset_db_path)
|
||||
if warmup_assets:
|
||||
results = asset_finder.retrieve_all(asset_finder.sids)
|
||||
|
||||
@@ -12,9 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sqlite3
|
||||
|
||||
import sqlalchemy as sa
|
||||
from six.moves import range
|
||||
|
||||
from .input_validation import coerce_string
|
||||
|
||||
SQLITE_MAX_VARIABLE_NUMBER = 998
|
||||
|
||||
|
||||
@@ -22,3 +26,9 @@ def group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
|
||||
items = list(items)
|
||||
return [items[x:x+chunk_size]
|
||||
for x in range(0, len(items), chunk_size)]
|
||||
|
||||
|
||||
coerce_string_to_conn = coerce_string(sqlite3.connect)
|
||||
coerce_string_to_eng = coerce_string(
|
||||
lambda s: sa.create_engine('sqlite:///' + s)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user