mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 01:56:36 +08:00
Merge pull request #756 from quantopian/batch_load_ffc
Batch load Pipeline columns
This commit is contained in:
@@ -47,7 +47,7 @@ bcolz==0.10.0
|
||||
click==4.0.0
|
||||
|
||||
# FUNctional programming utilities
|
||||
toolz==0.7.2
|
||||
toolz==0.7.4
|
||||
|
||||
# Asset writer and finder
|
||||
sqlalchemy==1.0.8
|
||||
|
||||
@@ -93,7 +93,7 @@ class BasePipelineTestCase(TestCase):
|
||||
Mapping from termname -> computed result.
|
||||
"""
|
||||
engine = SimplePipelineEngine(
|
||||
ExplodingObject(),
|
||||
lambda column: ExplodingObject(),
|
||||
self.__calendar,
|
||||
self.__finder,
|
||||
)
|
||||
|
||||
+213
-78
@@ -2,6 +2,7 @@
|
||||
Tests for SimplePipelineEngine
|
||||
"""
|
||||
from __future__ import division
|
||||
from collections import OrderedDict
|
||||
from unittest import TestCase
|
||||
from itertools import product
|
||||
|
||||
@@ -11,6 +12,8 @@ from numpy import (
|
||||
nan,
|
||||
tile,
|
||||
zeros,
|
||||
float32,
|
||||
concatenate,
|
||||
)
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
@@ -21,19 +24,20 @@ from pandas import (
|
||||
Series,
|
||||
Timestamp,
|
||||
)
|
||||
from pandas.compat.chainmap import ChainMap
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from six import iteritems, itervalues
|
||||
from testfixtures import TempDirectory
|
||||
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
ConstantLoader,
|
||||
MultiColumnLoader,
|
||||
NullAdjustmentReader,
|
||||
SyntheticDailyBarWriter,
|
||||
)
|
||||
from zipline.data.us_equity_pricing import BcolzDailyBarReader
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.data import USEquityPricing, DataSet, Column
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader, MULTIPLY
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
@@ -85,6 +89,49 @@ def assert_multi_index_is_product(testcase, index, *levels):
|
||||
testcase.assertEqual(set(index), set(product(*levels)))
|
||||
|
||||
|
||||
class ColumnArgs(tuple):
|
||||
"""A tuple of Columns that defines equivalence based on the order of the
|
||||
columns' DataSets, instead of the columns themselves. This is used when
|
||||
comparing the columns passed to a loader's load_adjusted_array method,
|
||||
since we want to assert that they are ordered by DataSet.
|
||||
"""
|
||||
def __new__(cls, *cols):
|
||||
return super(ColumnArgs, cls).__new__(cls, cols)
|
||||
|
||||
@classmethod
|
||||
def sorted_by_ds(cls, *cols):
|
||||
return cls(*sorted(cols, key=lambda col: col.dataset))
|
||||
|
||||
def by_ds(self):
|
||||
return tuple(col.dataset for col in self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return set(self) == set(other) and self.by_ds() == other.by_ds()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(frozenset(self))
|
||||
|
||||
|
||||
class RecordingConstantLoader(ConstantLoader):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RecordingConstantLoader, self).__init__(*args, **kwargs)
|
||||
|
||||
self.load_calls = []
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
self.load_calls.append(ColumnArgs(*columns))
|
||||
|
||||
return super(RecordingConstantLoader, self).load_adjusted_array(
|
||||
columns, dates, assets, mask,
|
||||
)
|
||||
|
||||
|
||||
class RollingSumSum(CustomFactor):
|
||||
def compute(self, today, assets, out, *inputs):
|
||||
assert len(self.inputs) == len(inputs)
|
||||
out[:] = sum(inputs).sum(axis=0)
|
||||
|
||||
|
||||
class ConstantInputTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -97,7 +144,7 @@ class ConstantInputTestCase(TestCase):
|
||||
USEquityPricing.high: 4,
|
||||
}
|
||||
self.assets = [1, 2, 3]
|
||||
self.dates = date_range('2014-01-01', '2014-02-01', freq='D', tz='UTC')
|
||||
self.dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
|
||||
self.loader = ConstantLoader(
|
||||
constants=self.constants,
|
||||
dates=self.dates,
|
||||
@@ -115,7 +162,9 @@ class ConstantInputTestCase(TestCase):
|
||||
|
||||
def test_bad_dates(self):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
p = Pipeline()
|
||||
|
||||
@@ -129,7 +178,9 @@ class ConstantInputTestCase(TestCase):
|
||||
loader = self.loader
|
||||
finder = self.asset_finder
|
||||
assets = array(self.assets)
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
num_dates = 5
|
||||
dates = self.dates[10:10 + num_dates]
|
||||
|
||||
@@ -152,7 +203,9 @@ class ConstantInputTestCase(TestCase):
|
||||
loader = self.loader
|
||||
finder = self.asset_finder
|
||||
assets = self.assets
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
result_shape = (num_dates, num_assets) = (5, len(assets))
|
||||
dates = self.dates[10:10 + num_dates]
|
||||
|
||||
@@ -185,7 +238,9 @@ class ConstantInputTestCase(TestCase):
|
||||
loader = self.loader
|
||||
finder = self.asset_finder
|
||||
assets = self.assets
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
shape = num_dates, num_assets = (5, len(assets))
|
||||
dates = self.dates[10:10 + num_dates]
|
||||
|
||||
@@ -228,7 +283,9 @@ class ConstantInputTestCase(TestCase):
|
||||
def test_numeric_factor(self):
|
||||
constants = self.constants
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
num_dates = 5
|
||||
dates = self.dates[10:10 + num_dates]
|
||||
high, low = USEquityPricing.high, USEquityPricing.low
|
||||
@@ -271,6 +328,147 @@ class ConstantInputTestCase(TestCase):
|
||||
DataFrame(expected_avg, index=dates, columns=self.assets),
|
||||
)
|
||||
|
||||
def test_rolling_and_nonrolling(self):
|
||||
open_ = USEquityPricing.open
|
||||
close = USEquityPricing.close
|
||||
volume = USEquityPricing.volume
|
||||
|
||||
# Test for thirty days up to the last day that we think all
|
||||
# the assets existed.
|
||||
dates_to_test = self.dates[-30:]
|
||||
|
||||
constants = {open_: 1, close: 2, volume: 3}
|
||||
loader = ConstantLoader(
|
||||
constants=constants,
|
||||
dates=self.dates,
|
||||
assets=self.assets,
|
||||
)
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
sumdiff = RollingSumDifference()
|
||||
|
||||
result = engine.run_pipeline(
|
||||
Pipeline(
|
||||
columns={
|
||||
'sumdiff': sumdiff,
|
||||
'open': open_.latest,
|
||||
'close': close.latest,
|
||||
'volume': volume.latest,
|
||||
},
|
||||
),
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1]
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(
|
||||
{'sumdiff', 'open', 'close', 'volume'},
|
||||
set(result.columns)
|
||||
)
|
||||
|
||||
result_index = self.assets * len(dates_to_test)
|
||||
result_shape = (len(result_index),)
|
||||
check_arrays(
|
||||
result['sumdiff'],
|
||||
Series(index=result_index, data=full(result_shape, -3)),
|
||||
)
|
||||
|
||||
for name, const in [('open', 1), ('close', 2), ('volume', 3)]:
|
||||
check_arrays(
|
||||
result[name],
|
||||
Series(index=result_index, data=full(result_shape, const)),
|
||||
)
|
||||
|
||||
def test_loader_given_multiple_columns(self):
|
||||
|
||||
class Loader1DataSet1(DataSet):
|
||||
col1 = Column(float32)
|
||||
col2 = Column(float32)
|
||||
|
||||
class Loader1DataSet2(DataSet):
|
||||
col1 = Column(float32)
|
||||
col2 = Column(float32)
|
||||
|
||||
class Loader2DataSet(DataSet):
|
||||
col1 = Column(float32)
|
||||
col2 = Column(float32)
|
||||
|
||||
constants1 = {Loader1DataSet1.col1: 1,
|
||||
Loader1DataSet1.col2: 2,
|
||||
Loader1DataSet2.col1: 3,
|
||||
Loader1DataSet2.col2: 4}
|
||||
loader1 = RecordingConstantLoader(constants=constants1,
|
||||
dates=self.dates,
|
||||
assets=self.assets)
|
||||
constants2 = {Loader2DataSet.col1: 5,
|
||||
Loader2DataSet.col2: 6}
|
||||
loader2 = RecordingConstantLoader(constants=constants2,
|
||||
dates=self.dates,
|
||||
assets=self.assets)
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column:
|
||||
loader2 if column.dataset == Loader2DataSet else loader1,
|
||||
self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
pipe_col1 = RollingSumSum(inputs=[Loader1DataSet1.col1,
|
||||
Loader1DataSet2.col1,
|
||||
Loader2DataSet.col1],
|
||||
window_length=2)
|
||||
|
||||
pipe_col2 = RollingSumSum(inputs=[Loader1DataSet1.col2,
|
||||
Loader1DataSet2.col2,
|
||||
Loader2DataSet.col2],
|
||||
window_length=3)
|
||||
|
||||
pipe_col3 = RollingSumSum(inputs=[Loader2DataSet.col1],
|
||||
window_length=3)
|
||||
|
||||
columns = OrderedDict([
|
||||
('pipe_col1', pipe_col1),
|
||||
('pipe_col2', pipe_col2),
|
||||
('pipe_col3', pipe_col3),
|
||||
])
|
||||
result = engine.run_pipeline(
|
||||
Pipeline(columns=columns),
|
||||
self.dates[2], # index is >= the largest window length - 1
|
||||
self.dates[-1]
|
||||
)
|
||||
min_window = min(pip_col.window_length
|
||||
for pip_col in itervalues(columns))
|
||||
col_to_val = ChainMap(constants1, constants2)
|
||||
vals = {name: (sum(col_to_val[col] for col in pipe_col.inputs)
|
||||
* pipe_col.window_length)
|
||||
for name, pipe_col in iteritems(columns)}
|
||||
|
||||
index = MultiIndex.from_product([self.dates[2:], self.assets])
|
||||
expected = DataFrame(
|
||||
data={col:
|
||||
concatenate((
|
||||
full((columns[col].window_length - min_window)
|
||||
* index.levshape[1],
|
||||
nan),
|
||||
full((index.levshape[0]
|
||||
- (columns[col].window_length - min_window))
|
||||
* index.levshape[1],
|
||||
val)))
|
||||
for col, val in iteritems(vals)},
|
||||
index=index,
|
||||
columns=columns)
|
||||
|
||||
assert_frame_equal(result, expected)
|
||||
|
||||
self.assertEqual(set(loader1.load_calls),
|
||||
{ColumnArgs.sorted_by_ds(Loader1DataSet1.col1,
|
||||
Loader1DataSet2.col1),
|
||||
ColumnArgs.sorted_by_ds(Loader1DataSet1.col2,
|
||||
Loader1DataSet2.col2)})
|
||||
self.assertEqual(set(loader2.load_calls),
|
||||
{ColumnArgs.sorted_by_ds(Loader2DataSet.col1,
|
||||
Loader2DataSet.col2)})
|
||||
|
||||
|
||||
class FrameInputTestCase(TestCase):
|
||||
|
||||
@@ -353,9 +551,12 @@ class FrameInputTestCase(TestCase):
|
||||
high_base.iloc[:apply_idxs[2], 1] /= 5.0
|
||||
|
||||
high_loader = DataFrameLoader(high, high_base, adjustments)
|
||||
loader = MultiColumnLoader({low: low_loader, high: high_loader})
|
||||
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
engine = SimplePipelineEngine(
|
||||
{low: low_loader, high: high_loader}.__getitem__,
|
||||
self.dates,
|
||||
self.asset_finder,
|
||||
)
|
||||
|
||||
for window_length in range(1, 4):
|
||||
low_mavg = SimpleMovingAverage(
|
||||
@@ -465,7 +666,7 @@ class SyntheticBcolzTestCase(TestCase):
|
||||
|
||||
def test_SMA(self):
|
||||
engine = SimplePipelineEngine(
|
||||
self.pipeline_loader,
|
||||
lambda column: self.pipeline_loader,
|
||||
self.env.trading_days,
|
||||
self.finder,
|
||||
)
|
||||
@@ -517,7 +718,7 @@ class SyntheticBcolzTestCase(TestCase):
|
||||
# or zero, but verifying we correctly handle those corner cases is
|
||||
# valuable.
|
||||
engine = SimplePipelineEngine(
|
||||
self.pipeline_loader,
|
||||
lambda column: self.pipeline_loader,
|
||||
self.env.trading_days,
|
||||
self.finder,
|
||||
)
|
||||
@@ -552,69 +753,3 @@ class SyntheticBcolzTestCase(TestCase):
|
||||
result = results['drawdown'].unstack()
|
||||
|
||||
assert_frame_equal(expected, result)
|
||||
|
||||
|
||||
class MultiColumnLoaderTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.assets = [1, 2, 3]
|
||||
self.dates = date_range('2014-01', '2014-03', freq='D', tz='UTC')
|
||||
|
||||
asset_info = make_simple_asset_info(
|
||||
self.assets,
|
||||
start_date=self.dates[0],
|
||||
end_date=self.dates[-1],
|
||||
)
|
||||
env = TradingEnvironment()
|
||||
env.write_data(equities_df=asset_info)
|
||||
self.asset_finder = env.asset_finder
|
||||
|
||||
def test_engine_with_multicolumn_loader(self):
|
||||
open_ = USEquityPricing.open
|
||||
close = USEquityPricing.close
|
||||
volume = USEquityPricing.volume
|
||||
|
||||
# Test for thirty days up to the second to last day that we think all
|
||||
# the assets existed. If we test the last day of our calendar, no
|
||||
# assets will be in our output, because their end dates are all
|
||||
dates_to_test = self.dates[-32:-2]
|
||||
|
||||
constants = {open_: 1, close: 2, volume: 3}
|
||||
loader = ConstantLoader(
|
||||
constants=constants,
|
||||
dates=self.dates,
|
||||
assets=self.assets,
|
||||
)
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
|
||||
sumdiff = RollingSumDifference()
|
||||
|
||||
result = engine.run_pipeline(
|
||||
Pipeline(
|
||||
columns={
|
||||
'sumdiff': sumdiff,
|
||||
'open': open_.latest,
|
||||
'close': close.latest,
|
||||
'volume': volume.latest,
|
||||
},
|
||||
),
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1]
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(
|
||||
{'sumdiff', 'open', 'close', 'volume'},
|
||||
set(result.columns)
|
||||
)
|
||||
|
||||
result_index = self.assets * len(dates_to_test)
|
||||
result_shape = (len(result_index),)
|
||||
check_arrays(
|
||||
result['sumdiff'],
|
||||
Series(index=result_index, data=full(result_shape, -3)),
|
||||
)
|
||||
|
||||
for name, const in [('open', 1), ('close', 2), ('volume', 3)]:
|
||||
check_arrays(
|
||||
result[name],
|
||||
Series(index=result_index, data=full(result_shape, const)),
|
||||
)
|
||||
|
||||
@@ -182,7 +182,7 @@ class ClosesOnly(TestCase):
|
||||
initialize=initialize,
|
||||
handle_data=late_attach,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -199,7 +199,7 @@ class ClosesOnly(TestCase):
|
||||
before_trading_start=late_attach,
|
||||
handle_data=barf,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -228,7 +228,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -256,7 +256,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -294,7 +294,7 @@ class ClosesOnly(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.first_asset_start - trading_day,
|
||||
end=self.last_asset_end + trading_day,
|
||||
env=self.env,
|
||||
@@ -524,7 +524,7 @@ class PipelineAlgorithmTestCase(TestCase):
|
||||
handle_data=handle_data,
|
||||
before_trading_start=before_trading_start,
|
||||
data_frequency='daily',
|
||||
pipeline_loader=self.pipeline_loader,
|
||||
get_pipeline_loader=lambda column: self.pipeline_loader,
|
||||
start=self.dates[max(window_lengths)],
|
||||
end=self.dates[-1],
|
||||
env=self.env,
|
||||
|
||||
+21
-34
@@ -82,11 +82,14 @@ def to_dict(l):
|
||||
|
||||
class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
def check_dependency_order(self, ordered_terms):
|
||||
seen = set()
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
for term in ordered_terms:
|
||||
for dep in term.dependencies:
|
||||
self.assertIn(dep, seen)
|
||||
|
||||
seen.add(term)
|
||||
|
||||
def test_single_factor(self):
|
||||
"""
|
||||
@@ -97,12 +100,12 @@ class DependencyResolutionTestCase(TestCase):
|
||||
resolution_order = list(graph.ordered())
|
||||
|
||||
self.assertEqual(len(resolution_order), 4)
|
||||
self.assertIs(resolution_order[0], AssetExists())
|
||||
self.assertEqual(
|
||||
set([resolution_order[1], resolution_order[2]]),
|
||||
set([SomeDataSet.foo, SomeDataSet.bar]),
|
||||
)
|
||||
self.assertEqual(resolution_order[-1], SomeFactor())
|
||||
self.check_dependency_order(resolution_order)
|
||||
self.assertIn(AssetExists(), resolution_order)
|
||||
self.assertIn(SomeDataSet.foo, resolution_order)
|
||||
self.assertIn(SomeDataSet.bar, resolution_order)
|
||||
self.assertIn(SomeFactor(), resolution_order)
|
||||
|
||||
self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4)
|
||||
self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)
|
||||
|
||||
@@ -121,18 +124,14 @@ class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
# SomeFactor, its inputs, and AssetExists()
|
||||
self.assertEqual(len(resolution_order), 4)
|
||||
|
||||
self.assertIs(resolution_order[0], AssetExists())
|
||||
self.check_dependency_order(resolution_order)
|
||||
self.assertIn(AssetExists(), resolution_order)
|
||||
self.assertEqual(graph.extra_rows[AssetExists()], 4)
|
||||
|
||||
self.assertEqual(
|
||||
set([resolution_order[1], resolution_order[2]]),
|
||||
set([bar, buzz]),
|
||||
)
|
||||
self.assertEqual(
|
||||
resolution_order[-1],
|
||||
SomeFactor([bar, buzz], window_length=5),
|
||||
)
|
||||
self.assertIn(bar, resolution_order)
|
||||
self.assertIn(buzz, resolution_order)
|
||||
self.assertIn(SomeFactor([bar, buzz], window_length=5),
|
||||
resolution_order)
|
||||
self.assertEqual(graph.extra_rows[bar], 4)
|
||||
self.assertEqual(graph.extra_rows[buzz], 4)
|
||||
|
||||
@@ -148,20 +147,8 @@ class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
# bar should only appear once.
|
||||
self.assertEqual(len(resolution_order), 6)
|
||||
indices = {
|
||||
term: resolution_order.index(term)
|
||||
for term in resolution_order
|
||||
}
|
||||
|
||||
self.assertEqual(indices[AssetExists()], 0)
|
||||
|
||||
# Verify that f1's dependencies will be computed before f1.
|
||||
self.assertLess(indices[SomeDataSet.foo], indices[f1])
|
||||
self.assertLess(indices[SomeDataSet.bar], indices[f1])
|
||||
|
||||
# Verify that f2's dependencies will be computed before f2.
|
||||
self.assertLess(indices[SomeDataSet.bar], indices[f2])
|
||||
self.assertLess(indices[SomeDataSet.buzz], indices[f2])
|
||||
self.assertEqual(len(set(resolution_order)), 6)
|
||||
self.check_dependency_order(resolution_order)
|
||||
|
||||
def test_disallow_recursive_lookback(self):
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ class TradingAlgorithm(object):
|
||||
self.asset_finder = self.trading_environment.asset_finder
|
||||
|
||||
# Initialize Pipeline API data.
|
||||
self.init_engine(kwargs.pop('pipeline_loader', None))
|
||||
self.init_engine(kwargs.pop('get_pipeline_loader', None))
|
||||
self._pipelines = {}
|
||||
# Create an always-expired cache so that we compute the first time data
|
||||
# is requested.
|
||||
@@ -323,15 +323,15 @@ class TradingAlgorithm(object):
|
||||
self.initialize_args = args
|
||||
self.initialize_kwargs = kwargs
|
||||
|
||||
def init_engine(self, loader):
|
||||
def init_engine(self, get_loader):
|
||||
"""
|
||||
Construct and store a PipelineEngine from loader.
|
||||
|
||||
If loader is None, constructs a NoOpPipelineEngine.
|
||||
If get_loader is None, constructs a NoOpPipelineEngine.
|
||||
"""
|
||||
if loader is not None:
|
||||
if get_loader is not None:
|
||||
self.engine = SimplePipelineEngine(
|
||||
loader,
|
||||
get_loader,
|
||||
self.trading_environment.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
classifier.py
|
||||
"""
|
||||
|
||||
from zipline.pipeline.term import Term
|
||||
from zipline.pipeline.term import CompositeTerm
|
||||
|
||||
|
||||
class Classifier(Term):
|
||||
class Classifier(CompositeTerm):
|
||||
pass
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""
|
||||
dataset.py
|
||||
"""
|
||||
from functools import total_ordering
|
||||
from six import (
|
||||
iteritems,
|
||||
with_metaclass,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term
|
||||
from zipline.pipeline.term import Term, AssetExists
|
||||
from zipline.pipeline.factors import Latest
|
||||
|
||||
|
||||
@@ -29,12 +30,13 @@ class BoundColumn(Term):
|
||||
"""
|
||||
A Column of data that's been concretely bound to a particular dataset.
|
||||
"""
|
||||
mask = AssetExists()
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
|
||||
def __new__(cls, dtype, dataset, name):
|
||||
return super(BoundColumn, cls).__new__(
|
||||
cls,
|
||||
inputs=(),
|
||||
window_length=0,
|
||||
domain=dataset.domain,
|
||||
dtype=dtype,
|
||||
dataset=dataset,
|
||||
@@ -86,6 +88,7 @@ class BoundColumn(Term):
|
||||
return self.qualname
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DataSetMeta(type):
|
||||
"""
|
||||
Metaclass for DataSets
|
||||
@@ -109,6 +112,9 @@ class DataSetMeta(type):
|
||||
def columns(self):
|
||||
return self._columns
|
||||
|
||||
def __lt__(self, other):
|
||||
return id(self) < id(other)
|
||||
|
||||
|
||||
class DataSet(with_metaclass(DataSetMeta)):
|
||||
domain = None
|
||||
|
||||
+31
-11
@@ -13,12 +13,13 @@ from six import (
|
||||
)
|
||||
from six.moves import zip_longest
|
||||
from numpy import array
|
||||
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
date_range,
|
||||
MultiIndex,
|
||||
)
|
||||
from toolz import groupby, juxt
|
||||
from toolz.curried.operator import getitem
|
||||
|
||||
from zipline.lib.adjusted_array import ensure_ndarray
|
||||
from zipline.errors import NoFurtherDataError
|
||||
@@ -82,8 +83,9 @@ class SimplePipelineEngine(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
loader : PipelineLoader
|
||||
A loader to use to retrieve raw data for atomic terms.
|
||||
get_loader : callable
|
||||
A function that is given an atomic term and returns a PipelineLoader
|
||||
to use to retrieve raw data for that term.
|
||||
calendar : DatetimeIndex
|
||||
Array of dates to consider as trading days when computing a range
|
||||
between a fixed start and end.
|
||||
@@ -92,15 +94,15 @@ class SimplePipelineEngine(object):
|
||||
which assets are in the top-level universe at any point in time.
|
||||
"""
|
||||
__slots__ = [
|
||||
'_loader',
|
||||
'_get_loader',
|
||||
'_calendar',
|
||||
'_finder',
|
||||
'_root_mask_term',
|
||||
'__weakref__',
|
||||
]
|
||||
|
||||
def __init__(self, loader, calendar, asset_finder):
|
||||
self._loader = loader
|
||||
def __init__(self, get_loader, calendar, asset_finder):
|
||||
self._get_loader = get_loader
|
||||
self._calendar = calendar
|
||||
self._finder = asset_finder
|
||||
self._root_mask_term = AssetExists()
|
||||
@@ -240,7 +242,8 @@ class SimplePipelineEngine(object):
|
||||
offset = graph.extra_rows[mask] - graph.extra_rows[term]
|
||||
return workspace[mask][offset:], dates[offset:]
|
||||
|
||||
def _inputs_for_term(self, term, workspace, graph):
|
||||
@staticmethod
|
||||
def _inputs_for_term(term, workspace, graph):
|
||||
"""
|
||||
Compute inputs for the given term.
|
||||
|
||||
@@ -273,6 +276,15 @@ class SimplePipelineEngine(object):
|
||||
out.append(input_data)
|
||||
return out
|
||||
|
||||
def get_loader(self, term):
|
||||
# AssetExists is one of the atomic terms in the graph, so we look up
|
||||
# a loader here when grouping by loader, but since it's already in the
|
||||
# workspace, we don't actually use that group.
|
||||
if term is AssetExists():
|
||||
return None
|
||||
|
||||
return self._get_loader(term)
|
||||
|
||||
def compute_chunk(self, graph, dates, assets, initial_workspace):
|
||||
"""
|
||||
Compute the Pipeline terms in the graph for the requested start and end
|
||||
@@ -297,11 +309,16 @@ class SimplePipelineEngine(object):
|
||||
Dictionary mapping requested results to outputs.
|
||||
"""
|
||||
self._validate_compute_chunk_params(dates, assets, initial_workspace)
|
||||
loader = self._loader
|
||||
get_loader = self.get_loader
|
||||
|
||||
# Copy the supplied initial workspace so we don't mutate it in place.
|
||||
workspace = initial_workspace.copy()
|
||||
|
||||
# If atomic terms share the same loader and extra_rows, load them all
|
||||
# together.
|
||||
atomic_group_key = juxt(get_loader, getitem(graph.extra_rows))
|
||||
atomic_groups = groupby(atomic_group_key, graph.atomic_terms)
|
||||
|
||||
for term in graph.ordered():
|
||||
# `term` may have been supplied in `initial_workspace`, and in the
|
||||
# future we may pre-compute atomic terms coming from the same
|
||||
@@ -315,10 +332,13 @@ class SimplePipelineEngine(object):
|
||||
mask, mask_dates = self._mask_and_dates_for_term(
|
||||
term, workspace, graph, dates
|
||||
)
|
||||
|
||||
if term.atomic:
|
||||
# FUTURE OPTIMIZATION: Scan the resolution order for terms in
|
||||
# the same dataset and load them here as well.
|
||||
to_load = [term]
|
||||
to_load = sorted(
|
||||
atomic_groups[atomic_group_key(term)],
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = get_loader(term)
|
||||
loaded = loader.load_adjusted_array(
|
||||
to_load, mask_dates, assets, mask,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from numpy import (
|
||||
find_common_type,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, NotSpecified
|
||||
from zipline.pipeline.term import Term, NotSpecified, CompositeTerm
|
||||
|
||||
_VARIABLE_NAME_RE = re.compile("^(x_)([0-9]+)$")
|
||||
|
||||
@@ -154,7 +154,7 @@ def is_comparison(op):
|
||||
return op in COMPARISONS
|
||||
|
||||
|
||||
class NumericalExpression(Term):
|
||||
class NumericalExpression(CompositeTerm):
|
||||
"""
|
||||
Term binding to a numexpr expression.
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from zipline.pipeline.term import (
|
||||
NotSpecified,
|
||||
RequiredWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
Term,
|
||||
CompositeTerm,
|
||||
)
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
@@ -184,7 +184,7 @@ def function_application(func):
|
||||
return mathfunc
|
||||
|
||||
|
||||
class Factor(Term):
|
||||
class Factor(CompositeTerm):
|
||||
"""
|
||||
Pipeline API expression producing numerically-valued outputs.
|
||||
"""
|
||||
@@ -450,7 +450,7 @@ class CustomFactor(RequiredWindowLengthMixin, CustomTermMixin, Factor):
|
||||
class-level attribute named `inputs`.
|
||||
window_length : int, optional
|
||||
Number of rows of rows to pass for each input. If this
|
||||
argument is passed to the CustomFactor constructor, we look for a
|
||||
argument is not passed to the CustomFactor constructor, we look for a
|
||||
class-level attribute named `window_length`.
|
||||
|
||||
Notes
|
||||
|
||||
@@ -15,7 +15,7 @@ from zipline.errors import (
|
||||
)
|
||||
from zipline.pipeline.term import (
|
||||
SingleInputMixin,
|
||||
Term,
|
||||
CompositeTerm,
|
||||
)
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
@@ -83,7 +83,7 @@ def binary_operator(op):
|
||||
return binary_operator
|
||||
|
||||
|
||||
class Filter(Term):
|
||||
class Filter(CompositeTerm):
|
||||
"""
|
||||
Pipeline API expression producing boolean-valued outputs.
|
||||
"""
|
||||
|
||||
@@ -102,15 +102,9 @@ class TermGraph(DiGraph):
|
||||
zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term
|
||||
zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term
|
||||
"""
|
||||
out = {}
|
||||
for term in self:
|
||||
extra_input_rows = term.extra_input_rows
|
||||
for input_ in term.inputs:
|
||||
out[term, input_] = self.extra_rows[input_] - extra_input_rows
|
||||
mask = term.mask
|
||||
if term.mask is not None:
|
||||
out[term, mask] = self.extra_rows[mask] - extra_input_rows
|
||||
return out
|
||||
return {(term, dep): self.extra_rows[dep] - term.extra_input_rows
|
||||
for term in self
|
||||
for dep in term.dependencies}
|
||||
|
||||
@lazyval
|
||||
def extra_rows(self):
|
||||
@@ -168,6 +162,10 @@ class TermGraph(DiGraph):
|
||||
"""
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def atomic_terms(self):
|
||||
return tuple(term for term in self if term.atomic)
|
||||
|
||||
def _add_to_graph(self, term, parents, extra_rows):
|
||||
"""
|
||||
Add `term` and all its inputs to the graph.
|
||||
@@ -191,7 +189,7 @@ class TermGraph(DiGraph):
|
||||
dependency_extra_rows = extra_rows + term.extra_input_rows
|
||||
|
||||
# Recursively add dependencies.
|
||||
for dependency in term.inputs:
|
||||
for dependency in term.dependencies:
|
||||
self._add_to_graph(
|
||||
dependency,
|
||||
parents,
|
||||
@@ -199,12 +197,6 @@ class TermGraph(DiGraph):
|
||||
)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
# Add term's mask, which is really just a specially-enumerated input.
|
||||
mask = term.mask
|
||||
if mask is not None:
|
||||
self._add_to_graph(mask, parents, extra_rows=dependency_extra_rows)
|
||||
self.add_edge(mask, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
def _ensure_extra_rows(self, term, N):
|
||||
|
||||
@@ -17,5 +17,5 @@ class PipelineLoader(with_metaclass(ABCMeta)):
|
||||
TODO: DOCUMENT THIS MORE!
|
||||
"""
|
||||
@abstractmethod
|
||||
def load_adjusted_array(self, columns, mask):
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
pass
|
||||
|
||||
@@ -32,33 +32,7 @@ def nanos_to_seconds(nanos):
|
||||
return nanos / (1000 * 1000 * 1000)
|
||||
|
||||
|
||||
class MultiColumnLoader(PipelineLoader):
|
||||
"""
|
||||
PipelineLoader that can delegate to sub-loaders.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
loaders : dict
|
||||
Dictionary mapping columns -> loader
|
||||
"""
|
||||
def __init__(self, loaders):
|
||||
self._loaders = loaders
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
"""
|
||||
Load by delegating to sub-loaders.
|
||||
"""
|
||||
out = []
|
||||
for col in columns:
|
||||
try:
|
||||
loader = self._loaders[col]
|
||||
except KeyError:
|
||||
raise ValueError("Couldn't find loader for %s" % col)
|
||||
out.extend(loader.load_adjusted_array([col], dates, assets, mask))
|
||||
return out
|
||||
|
||||
|
||||
class ConstantLoader(MultiColumnLoader):
|
||||
class ConstantLoader(PipelineLoader):
|
||||
"""
|
||||
Synthetic PipelineLoader that returns a constant value for each column.
|
||||
|
||||
@@ -91,7 +65,20 @@ class ConstantLoader(MultiColumnLoader):
|
||||
adjustments=None,
|
||||
)
|
||||
|
||||
super(ConstantLoader, self).__init__(loaders)
|
||||
self._loaders = loaders
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
"""
|
||||
Load by delegating to sub-loaders.
|
||||
"""
|
||||
out = []
|
||||
for col in columns:
|
||||
try:
|
||||
loader = self._loaders[col]
|
||||
except KeyError:
|
||||
raise ValueError("Couldn't find loader for %s" % col)
|
||||
out.extend(loader.load_adjusted_array([col], dates, assets, mask))
|
||||
return out
|
||||
|
||||
|
||||
class SyntheticDailyBarWriter(BcolzDailyBarWriter):
|
||||
|
||||
+149
-115
@@ -1,9 +1,11 @@
|
||||
"""
|
||||
Base class for Filters, Factors and Classifiers
|
||||
"""
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import bool_, full, nan
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
@@ -38,23 +40,17 @@ class NotSpecified(object):
|
||||
return self
|
||||
|
||||
|
||||
class Term(object):
|
||||
class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
Base class for terms in a Pipeline API compute graph.
|
||||
"""
|
||||
# These are NotSpecified because a subclass is required to provide them.
|
||||
inputs = NotSpecified
|
||||
window_length = NotSpecified
|
||||
dtype = NotSpecified
|
||||
mask = NotSpecified
|
||||
domain = NotSpecified
|
||||
|
||||
_term_cache = WeakValueDictionary()
|
||||
|
||||
def __new__(cls,
|
||||
inputs=NotSpecified,
|
||||
mask=NotSpecified,
|
||||
window_length=NotSpecified,
|
||||
domain=NotSpecified,
|
||||
dtype=NotSpecified,
|
||||
*args,
|
||||
@@ -72,23 +68,6 @@ class Term(object):
|
||||
# Class-level attributes can be used to provide defaults for Term
|
||||
# subclasses.
|
||||
|
||||
if inputs is NotSpecified:
|
||||
inputs = cls.inputs
|
||||
# Having inputs = NotSpecified is an error, but we handle it later
|
||||
# in self._validate rather than here.
|
||||
if inputs is not NotSpecified:
|
||||
# Allow users to specify lists as class-level defaults, but
|
||||
# normalize to a tuple so that inputs is hashable.
|
||||
inputs = tuple(inputs)
|
||||
|
||||
if mask is NotSpecified:
|
||||
mask = cls.mask
|
||||
if mask is NotSpecified:
|
||||
mask = AssetExists()
|
||||
|
||||
if window_length is NotSpecified:
|
||||
window_length = cls.window_length
|
||||
|
||||
if domain is NotSpecified:
|
||||
domain = cls.domain
|
||||
|
||||
@@ -96,9 +75,6 @@ class Term(object):
|
||||
dtype = cls.dtype
|
||||
|
||||
identity = cls.static_identity(
|
||||
inputs=inputs,
|
||||
mask=mask,
|
||||
window_length=window_length,
|
||||
domain=domain,
|
||||
dtype=dtype,
|
||||
*args, **kwargs
|
||||
@@ -109,9 +85,6 @@ class Term(object):
|
||||
except KeyError:
|
||||
new_instance = cls._term_cache[identity] = \
|
||||
super(Term, cls).__new__(cls)._init(
|
||||
inputs=inputs,
|
||||
mask=mask,
|
||||
window_length=window_length,
|
||||
domain=domain,
|
||||
dtype=dtype,
|
||||
*args, **kwargs
|
||||
@@ -134,10 +107,7 @@ class Term(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _init(self, inputs, mask, window_length, domain, dtype):
|
||||
self.inputs = inputs
|
||||
self.mask = mask
|
||||
self.window_length = window_length
|
||||
def _init(self, domain, dtype):
|
||||
self.domain = domain
|
||||
self.dtype = dtype
|
||||
|
||||
@@ -145,7 +115,7 @@ class Term(object):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, inputs, mask, window_length, domain, dtype):
|
||||
def static_identity(cls, domain, dtype):
|
||||
"""
|
||||
Return the identity of the Term that would be constructed from the
|
||||
given arguments.
|
||||
@@ -157,79 +127,64 @@ class Term(object):
|
||||
This is a classmethod so that it can be called from Term.__new__ to
|
||||
determine whether to produce a new instance.
|
||||
"""
|
||||
return (cls, inputs, mask, window_length, domain, dtype)
|
||||
return (cls, domain, dtype)
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Assert that this term is well-formed. This should be called exactly
|
||||
once, at the end of Term._init().
|
||||
"""
|
||||
if self.inputs is NotSpecified:
|
||||
raise TermInputsNotSpecified(termname=type(self).__name__)
|
||||
if self.window_length is NotSpecified:
|
||||
raise WindowLengthNotSpecified(termname=type(self).__name__)
|
||||
if self.dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=type(self).__name__)
|
||||
if self.mask is NotSpecified and not self.atomic:
|
||||
# This isn't user error, this is a bug in our code.
|
||||
raise AssertionError("{term} has no mask".format(term=self))
|
||||
|
||||
if self.window_length:
|
||||
for child in self.inputs:
|
||||
if not child.atomic:
|
||||
raise InputTermNotAtomic(parent=self, child=child)
|
||||
|
||||
@lazyval
|
||||
def atomic(self):
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
"""
|
||||
Whether or not this term has dependencies.
|
||||
|
||||
If term.atomic is truthy, it should have dataset and dtype attributes.
|
||||
"""
|
||||
return len(self.inputs) == 0
|
||||
|
||||
@lazyval
|
||||
def windowed(self):
|
||||
"""
|
||||
Whether or not this term represents a trailing window computation.
|
||||
|
||||
If term.windowed is truthy, its compute_from_windows method will be
|
||||
called with instances of AdjustedArray as inputs.
|
||||
|
||||
If term.windowed is falsey, its compute_from_baseline will be called
|
||||
with instances of np.ndarray as inputs.
|
||||
"""
|
||||
return (
|
||||
self.window_length is not NotSpecified
|
||||
and self.window_length > 0
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def extra_input_rows(self):
|
||||
"""
|
||||
The number of extra rows needed for each of our inputs to compute this
|
||||
term.
|
||||
"""
|
||||
return max(0, self.window_length - 1)
|
||||
|
||||
def _compute(self, inputs, dates, assets, mask):
|
||||
"""
|
||||
Subclasses should implement this to perform actual computation.
|
||||
|
||||
This is `_compute` rather than just `compute` because `compute` is
|
||||
reserved for user-supplied functions in CustomFactor.
|
||||
A tuple of other Terms that this Term requires for computation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractproperty
|
||||
def mask(self):
|
||||
"""
|
||||
A 2D Filter representing asset/date pairs to include while
|
||||
computing this Term. (True means include; False means exclude.)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@lazyval
|
||||
def dependencies(self):
|
||||
return self.inputs + (self.mask,)
|
||||
|
||||
@lazyval
|
||||
def atomic(self):
|
||||
return not any(dep for dep in self.dependencies
|
||||
if dep is not AssetExists())
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
Pseudo-filter describing whether or not an asset existed on a given day.
|
||||
This is the default mask for all terms that haven't been passed a mask
|
||||
explicitly.
|
||||
|
||||
This is morally a Filter, in the sense that it produces a boolean value for
|
||||
every asset on every date. We don't subclass Filter, however, because
|
||||
`AssetExists` is computed directly by the PipelineEngine.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
"""
|
||||
dtype = bool_
|
||||
dataset = None
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
dependencies = ()
|
||||
mask = None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{type}({inputs}, window_length={window_length})"
|
||||
).format(
|
||||
type=type(self).__name__,
|
||||
inputs=self.inputs,
|
||||
window_length=self.window_length,
|
||||
mask=self.mask,
|
||||
)
|
||||
return "AssetExists()"
|
||||
|
||||
|
||||
# TODO: Move mixins to a separate file?
|
||||
@@ -307,28 +262,107 @@ class CustomTermMixin(object):
|
||||
return out
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
Pseudo-filter describing whether or not an asset existed on a given day.
|
||||
This is the default mask for all terms that haven't been passed a mask
|
||||
explicitly.
|
||||
class CompositeTerm(Term):
|
||||
inputs = NotSpecified
|
||||
window_length = NotSpecified
|
||||
mask = NotSpecified
|
||||
|
||||
This is morally a Filter, in the sense that it produces a boolean value for
|
||||
every asset on every date. We don't subclass Filter, however, because
|
||||
`AssetExists` is computed directly by the PipelineEngine.
|
||||
def __new__(cls, inputs=NotSpecified, window_length=NotSpecified,
|
||||
mask=NotSpecified, *args, **kwargs):
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
"""
|
||||
inputs = ()
|
||||
dtype = bool_
|
||||
window_length = 0
|
||||
mask = None
|
||||
if inputs is NotSpecified:
|
||||
inputs = cls.inputs
|
||||
# Having inputs = NotSpecified is an error, but we handle it later
|
||||
# in self._validate rather than here.
|
||||
if inputs is not NotSpecified:
|
||||
# Allow users to specify lists as class-level defaults, but
|
||||
# normalize to a tuple so that inputs is hashable.
|
||||
inputs = tuple(inputs)
|
||||
|
||||
def _compute(self, *args, **kwargs):
|
||||
# TODO: Consider moving the bulk of the logic from
|
||||
# SimplePipelineEngine._compute_root_mask here.
|
||||
raise NotImplementedError(
|
||||
"Direct computation of AssetExists is not supported!"
|
||||
if mask is NotSpecified:
|
||||
mask = cls.mask
|
||||
if mask is NotSpecified:
|
||||
mask = AssetExists()
|
||||
|
||||
if window_length is NotSpecified:
|
||||
window_length = cls.window_length
|
||||
|
||||
return super(CompositeTerm, cls).__new__(cls, inputs=inputs, mask=mask,
|
||||
window_length=window_length,
|
||||
*args, **kwargs)
|
||||
|
||||
def _init(self, inputs, window_length, mask, *args, **kwargs):
|
||||
self.inputs = inputs
|
||||
self.window_length = window_length
|
||||
self.mask = mask
|
||||
return super(CompositeTerm, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
|
||||
return (
|
||||
super(CompositeTerm, cls).static_identity(*args, **kwargs),
|
||||
inputs,
|
||||
window_length,
|
||||
mask,
|
||||
)
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Assert that this term is well-formed. This should be called exactly
|
||||
once, at the end of Term._init().
|
||||
"""
|
||||
if self.inputs is NotSpecified:
|
||||
raise TermInputsNotSpecified(termname=type(self).__name__)
|
||||
if self.window_length is NotSpecified:
|
||||
raise WindowLengthNotSpecified(termname=type(self).__name__)
|
||||
if self.mask is NotSpecified:
|
||||
# This isn't user error, this is a bug in our code.
|
||||
raise AssertionError("{term} has no mask".format(term=self))
|
||||
|
||||
if self.window_length:
|
||||
for child in self.inputs:
|
||||
if not child.atomic:
|
||||
raise InputTermNotAtomic(parent=self, child=child)
|
||||
|
||||
return super(CompositeTerm, self)._validate()
|
||||
|
||||
def _compute(self, inputs, dates, assets, mask):
|
||||
"""
|
||||
Subclasses should implement this to perform actual computation.
|
||||
This is `_compute` rather than just `compute` because `compute` is
|
||||
reserved for user-supplied functions in CustomFactor.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@lazyval
|
||||
def windowed(self):
|
||||
"""
|
||||
Whether or not this term represents a trailing window computation.
|
||||
|
||||
If term.windowed is truthy, its compute_from_windows method will be
|
||||
called with instances of AdjustedArray as inputs.
|
||||
|
||||
If term.windowed is falsey, its compute_from_baseline will be called
|
||||
with instances of np.ndarray as inputs.
|
||||
"""
|
||||
return (
|
||||
self.window_length is not NotSpecified
|
||||
and self.window_length > 0
|
||||
)
|
||||
|
||||
@lazyval
|
||||
def extra_input_rows(self):
|
||||
"""
|
||||
The number of extra rows needed for each of our inputs to compute this
|
||||
term.
|
||||
"""
|
||||
return max(0, self.window_length - 1)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"{type}({inputs}, window_length={window_length})"
|
||||
).format(
|
||||
type=type(self).__name__,
|
||||
inputs=self.inputs,
|
||||
window_length=self.window_length,
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ def _render(g, out, format_, include_asset_exists=False):
|
||||
graph_attrs = {'rankdir': 'TB', 'splines': 'ortho'}
|
||||
cluster_attrs = {'style': 'filled', 'color': 'lightgoldenrod1'}
|
||||
|
||||
in_nodes = list(node for node in g if node.atomic)
|
||||
in_nodes = g.atomic_terms
|
||||
out_nodes = list(g.outputs.values())
|
||||
|
||||
f = BytesIO()
|
||||
|
||||
Reference in New Issue
Block a user