diff --git a/tests/pipeline/test_pipeline_algo.py b/tests/pipeline/test_pipeline_algo.py index 09ac4aed..f947d25b 100644 --- a/tests/pipeline/test_pipeline_algo.py +++ b/tests/pipeline/test_pipeline_algo.py @@ -182,7 +182,7 @@ class ClosesOnly(TestCase): initialize=initialize, handle_data=late_attach, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.first_asset_start - trading_day, end=self.last_asset_end + trading_day, env=self.env, @@ -199,7 +199,7 @@ class ClosesOnly(TestCase): before_trading_start=late_attach, handle_data=barf, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.first_asset_start - trading_day, end=self.last_asset_end + trading_day, env=self.env, @@ -228,7 +228,7 @@ class ClosesOnly(TestCase): handle_data=handle_data, before_trading_start=before_trading_start, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.first_asset_start - trading_day, end=self.last_asset_end + trading_day, env=self.env, @@ -256,7 +256,7 @@ class ClosesOnly(TestCase): handle_data=handle_data, before_trading_start=before_trading_start, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.first_asset_start - trading_day, end=self.last_asset_end + trading_day, env=self.env, @@ -294,7 +294,7 @@ class ClosesOnly(TestCase): handle_data=handle_data, before_trading_start=before_trading_start, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.first_asset_start - trading_day, end=self.last_asset_end + trading_day, env=self.env, @@ -524,7 +524,7 @@ class PipelineAlgorithmTestCase(TestCase): handle_data=handle_data, before_trading_start=before_trading_start, data_frequency='daily', - pipeline_loader_dispatch=lambda column: self.pipeline_loader, + get_pipeline_loader=lambda column: self.pipeline_loader, start=self.dates[max(window_lengths)], end=self.dates[-1], env=self.env, diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 83685601..9c43d4df 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -232,7 +232,7 @@ class TradingAlgorithm(object): self.asset_finder = self.trading_environment.asset_finder # Initialize Pipeline API data. - self.init_engine(kwargs.pop('pipeline_loader_dispatch', None)) + self.init_engine(kwargs.pop('get_pipeline_loader', None)) self._pipelines = {} # Create an always-expired cache so that we compute the first time data # is requested. @@ -323,15 +323,15 @@ class TradingAlgorithm(object): self.initialize_args = args self.initialize_kwargs = kwargs - def init_engine(self, loader_dispatch): + def init_engine(self, get_loader): """ Construct and store a PipelineEngine from loader. - If loader is None, constructs a NoOpPipelineEngine. + If get_loader is None, constructs a NoOpPipelineEngine. """ - if loader_dispatch is not None: + if get_loader is not None: self.engine = SimplePipelineEngine( - loader_dispatch, + get_loader, self.trading_environment.trading_days, self.asset_finder, ) diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index 64cc1306..2e4f2e0f 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -83,8 +83,9 @@ class SimplePipelineEngine(object): Parameters ---------- - loader : PipelineLoader - A loader to use to retrieve raw data for atomic terms. + get_loader : callable + A function that is given an atomic term and returns a PipelineLoader + to use to retrieve raw data for that term. calendar : DatetimeIndex Array of dates to consider as trading days when computing a range between a fixed start and end. @@ -93,15 +94,15 @@ class SimplePipelineEngine(object): which assets are in the top-level universe at any point in time. """ __slots__ = [ - '_loader_dispatch', + '_get_loader', '_calendar', '_finder', '_root_mask_term', '__weakref__', ] - def __init__(self, loader_dispatch, calendar, asset_finder): - self._loader_dispatch = loader_dispatch + def __init__(self, get_loader, calendar, asset_finder): + self._get_loader = get_loader self._calendar = calendar self._finder = asset_finder self._root_mask_term = AssetExists() @@ -275,17 +276,14 @@ class SimplePipelineEngine(object): out.append(input_data) return out - def loader_dispatch(self, term): + def get_loader(self, term): # AssetExists is one of the atomic terms in the graph, so we look up # a loader here when grouping by loader, but since it's already in the # workspace, we don't actually use that group. if term is AssetExists(): return None - loader = self._loader_dispatch(term) - if loader is None: - raise ValueError("Couldn't find loader for %s" % term) - return loader + return self._get_loader(term) def compute_chunk(self, graph, dates, assets, initial_workspace): """ @@ -311,14 +309,14 @@ class SimplePipelineEngine(object): Dictionary mapping requested results to outputs. """ self._validate_compute_chunk_params(dates, assets, initial_workspace) - loader_dispatch = self.loader_dispatch + get_loader = self.get_loader # Copy the supplied initial workspace so we don't mutate it in place. workspace = initial_workspace.copy() # If atomic terms share the same loader and extra_rows, load them all # together. - atomic_group_key = juxt(loader_dispatch, getitem(graph.extra_rows)) + atomic_group_key = juxt(get_loader, getitem(graph.extra_rows)) atomic_groups = groupby(atomic_group_key, graph.atomic_terms) for term in graph.ordered(): @@ -340,7 +338,7 @@ class SimplePipelineEngine(object): atomic_groups[atomic_group_key(term)], key=lambda t: t.dataset ) - loader = loader_dispatch(term) + loader = get_loader(term) loaded = loader.load_adjusted_array( to_load, mask_dates, assets, mask, )