mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 00:21:40 +08:00
MAINT: Renamed loader_dispatch to get_loader
Now it raises a KeyError instead of returning None, if loader not found.
This commit is contained in:
@@ -182,7 +182,7 @@ class ClosesOnly(TestCase):
|
||||
initialize=initialize,
|
||||
handle_data=late_attach,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -199,7 +199,7 @@ class ClosesOnly(TestCase):
|
||||
before_trading_start=late_attach,
|
||||
handle_data=barf,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -228,7 +228,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -256,7 +256,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -294,7 +294,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -524,7 +524,7 @@ class PipelineAlgorithmTestCase(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader_dispatch=lambda column: self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.dates[max(window_lengths)],
|
||||
end=self.dates[-1],
|
||||
env=self.env,
|
||||
|
||||
@@ -232,7 +232,7 @@ class TradingAlgorithm(object):
|
||||
self.asset_finder = self.trading_environment.asset_finder
|
||||
|
||||
# Initialize Pipeline API data.
|
||||
self.init_engine(kwargs.pop('pipeline_loader_dispatch', None))
|
||||
self.init_engine(kwargs.pop('get_pipeline_loader', None))
|
||||
self._pipelines = {}
|
||||
# Create an always-expired cache so that we compute the first time data
|
||||
# is requested.
|
||||
@@ -323,15 +323,15 @@ class TradingAlgorithm(object):
|
||||
self.initialize_args = args
|
||||
self.initialize_kwargs = kwargs
|
||||
|
||||
def init_engine(self, loader_dispatch):
|
||||
def init_engine(self, get_loader):
|
||||
"""
|
||||
Construct and store a PipelineEngine from loader.
|
||||
|
||||
If loader is None, constructs a NoOpPipelineEngine.
|
||||
If get_loader is None, constructs a NoOpPipelineEngine.
|
||||
"""
|
||||
if loader_dispatch is not None:
|
||||
if get_loader is not None:
|
||||
self.engine = SimplePipelineEngine(
|
||||
loader_dispatch,
|
||||
get_loader,
|
||||
self.trading_environment.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
|
||||
+11
-13
@@ -83,8 +83,9 @@ class SimplePipelineEngine(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
loader : PipelineLoader
|
||||
A loader to use to retrieve raw data for atomic terms.
|
||||
get_loader : callable
|
||||
A function that is given an atomic term and returns a PipelineLoader
|
||||
to use to retrieve raw data for that term.
|
||||
calendar : DatetimeIndex
|
||||
Array of dates to consider as trading days when computing a range
|
||||
between a fixed start and end.
|
||||
@@ -93,15 +94,15 @@ class SimplePipelineEngine(object):
|
||||
which assets are in the top-level universe at any point in time.
|
||||
"""
|
||||
__slots__ = [
|
||||
'_loader_dispatch',
|
||||
'_get_loader',
|
||||
'_calendar',
|
||||
'_finder',
|
||||
'_root_mask_term',
|
||||
'__weakref__',
|
||||
]
|
||||
|
||||
def __init__(self, loader_dispatch, calendar, asset_finder):
|
||||
self._loader_dispatch = loader_dispatch
|
||||
def __init__(self, get_loader, calendar, asset_finder):
|
||||
self._get_loader = get_loader
|
||||
self._calendar = calendar
|
||||
self._finder = asset_finder
|
||||
self._root_mask_term = AssetExists()
|
||||
@@ -275,17 +276,14 @@ class SimplePipelineEngine(object):
|
||||
out.append(input_data)
|
||||
return out
|
||||
|
||||
def loader_dispatch(self, term):
|
||||
def get_loader(self, term):
|
||||
# AssetExists is one of the atomic terms in the graph, so we look up
|
||||
# a loader here when grouping by loader, but since it's already in the
|
||||
# workspace, we don't actually use that group.
|
||||
if term is AssetExists():
|
||||
return None
|
||||
|
||||
loader = self._loader_dispatch(term)
|
||||
if loader is None:
|
||||
raise ValueError("Couldn't find loader for %s" % term)
|
||||
return loader
|
||||
return self._get_loader(term)
|
||||
|
||||
def compute_chunk(self, graph, dates, assets, initial_workspace):
|
||||
"""
|
||||
@@ -311,14 +309,14 @@ class SimplePipelineEngine(object):
|
||||
Dictionary mapping requested results to outputs.
|
||||
"""
|
||||
self._validate_compute_chunk_params(dates, assets, initial_workspace)
|
||||
loader_dispatch = self.loader_dispatch
|
||||
get_loader = self.get_loader
|
||||
|
||||
# Copy the supplied initial workspace so we don't mutate it in place.
|
||||
workspace = initial_workspace.copy()
|
||||
|
||||
# If atomic terms share the same loader and extra_rows, load them all
|
||||
# together.
|
||||
atomic_group_key = juxt(loader_dispatch, getitem(graph.extra_rows))
|
||||
atomic_group_key = juxt(get_loader, getitem(graph.extra_rows))
|
||||
atomic_groups = groupby(atomic_group_key, graph.atomic_terms)
|
||||
|
||||
for term in graph.ordered():
|
||||
@@ -340,7 +338,7 @@ class SimplePipelineEngine(object):
|
||||
atomic_groups[atomic_group_key(term)],
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = loader_dispatch(term)
|
||||
loader = get_loader(term)
|
||||
loaded = loader.load_adjusted_array(
|
||||
to_load, mask_dates, assets, mask,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user