diff --git a/zipline/testing/fixtures.py b/zipline/testing/fixtures.py index e08ee7b7..78e7e707 100644 --- a/zipline/testing/fixtures.py +++ b/zipline/testing/fixtures.py @@ -1,3 +1,4 @@ +from itertools import repeat import os import sqlite3 from unittest import TestCase @@ -50,11 +51,14 @@ import zipline from zipline.assets import Equity, Future from zipline.finance.asset_restrictions import NoRestrictions from zipline.pipeline import SimplePipelineEngine +from zipline.pipeline.data import USEquityPricing +from zipline.pipeline.loaders import USEquityPricingLoader from zipline.pipeline.loaders.testing import make_seeded_random_loader from zipline.protocol import BarData from zipline.utils.calendars import ( get_calendar, register_calendar) +from zipline.utils.paths import ensure_directory zipline_dir = os.path.dirname(zipline.__file__) @@ -1307,6 +1311,47 @@ class WithAdjustmentReader(WithBcolzEquityDailyBarReader): cls.adjustment_reader = SQLiteAdjustmentReader(conn) +class WithEquityPricingPipelineEngine(WithAdjustmentReader, + WithTradingSessions): + """ + Mixin providing the following as a class-level fixtures. + - cls.data_root_dir + - cls.findata_dir + - cls.pipeline_engine + - cls.adjustments_db_path + + """ + @classmethod + def init_class_fixtures(cls): + cls.data_root_dir = cls.enter_class_context(tmp_dir()) + cls.findata_dir = cls.data_root_dir.makedir('findata') + super(WithEquityPricingPipelineEngine, cls).init_class_fixtures() + + loader = USEquityPricingLoader( + cls.bcolz_equity_daily_bar_reader, + SQLiteAdjustmentReader(cls.adjustments_db_path), + ) + dispatcher = dict( + zip(USEquityPricing.columns, repeat(loader)) + ).__getitem__ + + cls.pipeline_engine = SimplePipelineEngine( + get_loader=dispatcher, + calendar=cls.nyse_sessions, + asset_finder=cls.asset_finder, + ) + + @classmethod + def make_adjustment_db_conn_str(cls): + cls.adjustments_db_path = os.path.join( + cls.findata_dir, + 'adjustments', + cls.END_DATE.strftime("%Y-%m-%d-adjustments.db") + ) + ensure_directory(os.path.dirname(cls.adjustments_db_path)) + return cls.adjustments_db_path + + class WithSeededRandomPipelineEngine(WithTradingSessions, WithAssetFinder): """ ZiplineTestCase mixin providing class-level fixtures for running pipelines