MAINT: Consolidate coercion to sqlite conn/eng

This commit is contained in:
Richard Frank
2016-08-23 10:17:33 -04:00
parent 6d473ce23f
commit bc87ea4efb
7 changed files with 23 additions and 15 deletions
+3
View File
@@ -7,8 +7,11 @@ from toolz.curried import do, operator as op
from zipline.assets.asset_writer import write_version_info
from zipline.errors import AssetDBImpossibleDowngrade
from zipline.utils.preprocess import preprocess
from zipline.utils.sqlite_utils import coerce_string_to_eng
@preprocess(engine=coerce_string_to_eng)
def downgrade(engine, desired_version):
"""Downgrades the assets db at the given engine to the desired version.
+3 -2
View File
@@ -35,7 +35,9 @@ from zipline.assets.asset_db_schema import (
version_info,
)
from zipline.utils.preprocess import preprocess
from zipline.utils.range import from_tuple, intersecting_ranges
from zipline.utils.sqlite_utils import coerce_string_to_eng
# Define a namedtuple for use with the load_data and _load_data methods
AssetData = namedtuple(
@@ -344,9 +346,8 @@ class AssetDBWriter(object):
"""
DEFAULT_CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER
@preprocess(engine=coerce_string_to_eng)
def __init__(self, engine):
if isinstance(engine, str):
engine = sa.create_engine('sqlite:///' + engine)
self.engine = engine
def write(self,
+4 -6
View File
@@ -60,7 +60,8 @@ from .asset_db_schema import (
from zipline.utils.control_flow import invert
from zipline.utils.memoize import lazyval
from zipline.utils.numpy_utils import as_column
from zipline.utils.sqlite_utils import group_into_chunks
from zipline.utils.preprocess import preprocess
from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_eng
log = Logger('assets.py')
@@ -141,12 +142,9 @@ class AssetFinder(object):
# reference to an AssetFinder.
PERSISTENT_TOKEN = "<AssetFinder>"
@preprocess(engine=coerce_string_to_eng)
def __init__(self, engine):
self.engine = engine = (
sa.create_engine('sqlite:///' + engine)
if isinstance(engine, string_types) else
engine
)
self.engine = engine
metadata = sa.MetaData(bind=engine)
metadata.reflect(only=asset_db_table_names)
for table_name in asset_db_table_names:
+2 -3
View File
@@ -54,12 +54,11 @@ from zipline.utils.calendars import get_calendar
from zipline.utils.functional import apply
from zipline.utils.preprocess import call
from zipline.utils.input_validation import (
coerce_string,
preprocess,
expect_element,
verify_indices_all_unique,
)
from zipline.utils.sqlite_utils import group_into_chunks
from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_conn
from zipline.utils.memoize import lazyval
from zipline.utils.cli import maybe_show_progress
from ._equities import _compute_row_slices, _read_bcolz_data
@@ -1228,7 +1227,7 @@ class SQLiteAdjustmentReader(object):
:class:`zipline.data.us_equity_pricing.SQLiteAdjustmentWriter`
"""
@preprocess(conn=coerce_string(sqlite3.connect))
@preprocess(conn=coerce_string_to_conn)
def __init__(self, conn):
self.conn = conn
+1 -1
View File
@@ -99,7 +99,7 @@ class TradingEnvironment(object):
self.exchange_tz = exchange_tz
if isinstance(asset_db_path, string_types):
asset_db_path = 'sqlite:///%s' % asset_db_path
asset_db_path = 'sqlite:///' + asset_db_path
self.engine = engine = create_engine(asset_db_path)
else:
self.engine = engine = asset_db_path
-3
View File
@@ -35,9 +35,6 @@ def engine_from_files(daily_bar_path,
memory consumption. Default is False
"""
loader = USEquityPricingLoader.from_files(daily_bar_path, adjustments_path)
if not asset_db_path.startswith("sqlite:"):
asset_db_path = "sqlite:///" + asset_db_path
asset_finder = AssetFinder(asset_db_path)
if warmup_assets:
results = asset_finder.retrieve_all(asset_finder.sids)
+10
View File
@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import sqlalchemy as sa
from six.moves import range
from .input_validation import coerce_string
SQLITE_MAX_VARIABLE_NUMBER = 998
@@ -22,3 +26,9 @@ def group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
items = list(items)
return [items[x:x+chunk_size]
for x in range(0, len(items), chunk_size)]
coerce_string_to_conn = coerce_string(sqlite3.connect)
coerce_string_to_eng = coerce_string(
lambda s: sa.create_engine('sqlite:///' + s)
)