diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 41c86e9a..b275bbaf 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -749,6 +749,8 @@ class tmp_assets_db(object): Parameters ---------- + url : string + The URL for the database connection. **frames The frames to pass to the AssetDBWriter. By default this maps equities: @@ -761,7 +763,11 @@ class tmp_assets_db(object): """ _default_equities = sentinel('_default_equities') - def __init__(self, equities=_default_equities, **frames): + def __init__(self, + url='sqlite:///:memory:', + equities=_default_equities, + **frames): + self._url = url self._eng = None if equities is self._default_equities: equities = make_simple_equity_info( @@ -775,7 +781,7 @@ class tmp_assets_db(object): self._eng = None # set in enter and exit def __enter__(self): - self._eng = eng = create_engine('sqlite://') + self._eng = eng = create_engine(self._url) AssetDBWriter(eng).write(**self._frames) return eng @@ -800,6 +806,8 @@ class tmp_asset_finder(tmp_assets_db): Parameters ---------- + url : string + The URL for the database connection. finder_cls : type, optional The type of asset finder to create from the assets db. **frames @@ -809,9 +817,12 @@ class tmp_asset_finder(tmp_assets_db): -------- tmp_assets_db """ - def __init__(self, finder_cls=AssetFinder, **frames): + def __init__(self, + url='sqlite:///:memory:', + finder_cls=AssetFinder, + **frames): self._finder_cls = finder_cls - super(tmp_asset_finder, self).__init__(**frames) + super(tmp_asset_finder, self).__init__(url=url, **frames) def __enter__(self): return self._finder_cls(super(tmp_asset_finder, self).__enter__()) diff --git a/zipline/testing/fixtures.py b/zipline/testing/fixtures.py index 69c9851f..a7851f5f 100644 --- a/zipline/testing/fixtures.py +++ b/zipline/testing/fixtures.py @@ -304,6 +304,9 @@ class WithAssetFinder(WithDefaultDateBounds): A class method which constructs the dataframe of root symbols information to write to the class's assets db. By default this is empty. + make_asset_finder_db_url() -> string + A class method which returns the URL at which to create the SQLAlchemy + engine. By default provides a URL for an in-memory database. make_asset_finder() -> pd.DataFrame A class method which constructs the actual asset finder object to use for the class. If this method is overridden then the ``make_*_info`` @@ -341,9 +344,14 @@ class WithAssetFinder(WithDefaultDateBounds): cls.ASSET_FINDER_EQUITY_SYMBOLS, ) + @classmethod + def make_asset_finder_db_url(cls): + return 'sqlite:///:memory:' + @classmethod def make_asset_finder(cls): return cls.enter_class_context(tmp_asset_finder( + url=cls.make_asset_finder_db_url(), equities=cls.make_equity_info(), futures=cls.make_futures_info(), exchanges=cls.make_exchanges_info(), @@ -618,6 +626,10 @@ class WithBcolzDailyBarReader(WithTradingEnvironment, WithTmpDir): ``BcolzDailyBarReader`` will read from. By default this creates some simple sythetic data with :func:`~zipline.testing.create_daily_bar_data` + make_bcolz_daily_bar_rootdir_path() -> string + A class method that returns the path for the rootdir of the daily + bars ctable. By default this is a subdirectory BCOLZ_DAILY_BAR_PATH in + the shared temp directory. See Also -------- @@ -693,12 +705,14 @@ class WithBcolzDailyBarReader(WithTradingEnvironment, WithTmpDir): cls.asset_finder.sids, ) + @classmethod + def make_bcolz_daily_bar_rootdir_path(cls): + return cls.tmpdir.makedir(cls.BCOLZ_DAILY_BAR_PATH) + @classmethod def init_class_fixtures(cls): super(WithBcolzDailyBarReader, cls).init_class_fixtures() - cls.bcolz_daily_bar_path = p = cls.tmpdir.makedir( - cls.BCOLZ_DAILY_BAR_PATH, - ) + cls.bcolz_daily_bar_path = p = cls.make_bcolz_daily_bar_rootdir_path() if cls.BCOLZ_DAILY_BAR_USE_FULL_CALENDAR: days = cls.trading_schedule.all_execution_days else: @@ -766,6 +780,10 @@ class WithBcolzMinuteBarReader(WithTradingEnvironment, WithTmpDir): ``BcolzMinuteBarReader`` will read from. By default this creates some simple sythetic data with :func:`~zipline.testing.create_minute_bar_data` + make_bcolz_minute_bar_rootdir_path() -> string + A class method that returns the path for the directory that contains + the minute bar ctables. By default this is a subdirectory + BCOLZ_MINUTE_BAR_PATH in the shared temp directory. See Also -------- @@ -789,12 +807,15 @@ class WithBcolzMinuteBarReader(WithTradingEnvironment, WithTmpDir): cls.asset_finder.sids, ) + @classmethod + def make_bcolz_minute_bar_rootdir_path(cls): + return cls.tmpdir.makedir(cls.BCOLZ_MINUTE_BAR_PATH) + @classmethod def init_class_fixtures(cls): super(WithBcolzMinuteBarReader, cls).init_class_fixtures() - cls.bcolz_minute_bar_path = p = cls.tmpdir.makedir( - cls.BCOLZ_MINUTE_BAR_PATH, - ) + cls.bcolz_minute_bar_path = p = \ + cls.make_bcolz_minute_bar_rootdir_path() if cls.BCOLZ_MINUTE_BAR_USE_FULL_CALENDAR: days = cls.trading_schedule.all_execution_days else: @@ -845,6 +866,10 @@ class WithAdjustmentReader(WithBcolzDailyBarReader): make_stock_dividends_data() -> pd.DataFrame A class method that returns a dataframe of stock dividends data to write to the class's adjustment db. By default this is empty. + make_adjustment_db_conn_str() -> string + A class method that returns the sqlite3 connection string for the + database in to which the adjustments will be written. By default this + is an in-memory database. make_adjustment_writer_daily_bar_reader() -> pd.DataFrame A class method that returns the daily bar reader to use for the class's adjustment writer. By default this is the class's actual @@ -883,10 +908,14 @@ class WithAdjustmentReader(WithBcolzDailyBarReader): def make_adjustment_writer_daily_bar_reader(cls): return cls.bcolz_daily_bar_reader + @classmethod + def make_adjustment_db_conn_str(cls): + return ':memory:' + @classmethod def init_class_fixtures(cls): super(WithAdjustmentReader, cls).init_class_fixtures() - conn = sqlite3.connect(':memory:') + conn = sqlite3.connect(cls.make_adjustment_db_conn_str()) cls.make_adjustment_writer(conn).write( splits=cls.make_splits_data(), mergers=cls.make_mergers_data(),