diff --git a/zipline/assets/asset_db_migrations.py b/zipline/assets/asset_db_migrations.py index 550e096c..82f9d147 100644 --- a/zipline/assets/asset_db_migrations.py +++ b/zipline/assets/asset_db_migrations.py @@ -7,8 +7,11 @@ from toolz.curried import do, operator as op from zipline.assets.asset_writer import write_version_info from zipline.errors import AssetDBImpossibleDowngrade +from zipline.utils.preprocess import preprocess +from zipline.utils.sqlite_utils import coerce_string_to_eng +@preprocess(engine=coerce_string_to_eng) def downgrade(engine, desired_version): """Downgrades the assets db at the given engine to the desired version. diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py index 17d61b2f..50e9070d 100644 --- a/zipline/assets/asset_writer.py +++ b/zipline/assets/asset_writer.py @@ -35,7 +35,9 @@ from zipline.assets.asset_db_schema import ( version_info, ) +from zipline.utils.preprocess import preprocess from zipline.utils.range import from_tuple, intersecting_ranges +from zipline.utils.sqlite_utils import coerce_string_to_eng # Define a namedtuple for use with the load_data and _load_data methods AssetData = namedtuple( @@ -344,9 +346,8 @@ class AssetDBWriter(object): """ DEFAULT_CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER + @preprocess(engine=coerce_string_to_eng) def __init__(self, engine): - if isinstance(engine, str): - engine = sa.create_engine('sqlite:///' + engine) self.engine = engine def write(self, diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 8ca34629..caf6727c 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -60,7 +60,8 @@ from .asset_db_schema import ( from zipline.utils.control_flow import invert from zipline.utils.memoize import lazyval from zipline.utils.numpy_utils import as_column -from zipline.utils.sqlite_utils import group_into_chunks +from zipline.utils.preprocess import preprocess +from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_eng log = Logger('assets.py') @@ -141,12 +142,9 @@ class AssetFinder(object): # reference to an AssetFinder. PERSISTENT_TOKEN = "" + @preprocess(engine=coerce_string_to_eng) def __init__(self, engine): - self.engine = engine = ( - sa.create_engine('sqlite:///' + engine) - if isinstance(engine, string_types) else - engine - ) + self.engine = engine metadata = sa.MetaData(bind=engine) metadata.reflect(only=asset_db_table_names) for table_name in asset_db_table_names: diff --git a/zipline/data/us_equity_pricing.py b/zipline/data/us_equity_pricing.py index 71579ca5..29acdf0d 100644 --- a/zipline/data/us_equity_pricing.py +++ b/zipline/data/us_equity_pricing.py @@ -54,12 +54,11 @@ from zipline.utils.calendars import get_calendar from zipline.utils.functional import apply from zipline.utils.preprocess import call from zipline.utils.input_validation import ( - coerce_string, preprocess, expect_element, verify_indices_all_unique, ) -from zipline.utils.sqlite_utils import group_into_chunks +from zipline.utils.sqlite_utils import group_into_chunks, coerce_string_to_conn from zipline.utils.memoize import lazyval from zipline.utils.cli import maybe_show_progress from ._equities import _compute_row_slices, _read_bcolz_data @@ -1228,7 +1227,7 @@ class SQLiteAdjustmentReader(object): :class:`zipline.data.us_equity_pricing.SQLiteAdjustmentWriter` """ - @preprocess(conn=coerce_string(sqlite3.connect)) + @preprocess(conn=coerce_string_to_conn) def __init__(self, conn): self.conn = conn diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 79fe5b75..4b13b8e1 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -99,7 +99,7 @@ class TradingEnvironment(object): self.exchange_tz = exchange_tz if isinstance(asset_db_path, string_types): - asset_db_path = 'sqlite:///%s' % asset_db_path + asset_db_path = 'sqlite:///' + asset_db_path self.engine = engine = create_engine(asset_db_path) else: self.engine = engine = asset_db_path diff --git a/zipline/pipeline/__init__.py b/zipline/pipeline/__init__.py index 8d2a412f..a169256b 100644 --- a/zipline/pipeline/__init__.py +++ b/zipline/pipeline/__init__.py @@ -35,9 +35,6 @@ def engine_from_files(daily_bar_path, memory consumption. Default is False """ loader = USEquityPricingLoader.from_files(daily_bar_path, adjustments_path) - - if not asset_db_path.startswith("sqlite:"): - asset_db_path = "sqlite:///" + asset_db_path asset_finder = AssetFinder(asset_db_path) if warmup_assets: results = asset_finder.retrieve_all(asset_finder.sids) diff --git a/zipline/utils/sqlite_utils.py b/zipline/utils/sqlite_utils.py index 439db2f2..17bfc9a5 100644 --- a/zipline/utils/sqlite_utils.py +++ b/zipline/utils/sqlite_utils.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sqlite3 +import sqlalchemy as sa from six.moves import range +from .input_validation import coerce_string + SQLITE_MAX_VARIABLE_NUMBER = 998 @@ -22,3 +26,9 @@ def group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER): items = list(items) return [items[x:x+chunk_size] for x in range(0, len(items), chunk_size)] + + +coerce_string_to_conn = coerce_string(sqlite3.connect) +coerce_string_to_eng = coerce_string( + lambda s: sa.create_engine('sqlite:///' + s) +)