From 32baac4e4b5e546b6601e1a053dd83049e0ffdf9 Mon Sep 17 00:00:00 2001 From: llllllllll Date: Fri, 18 Dec 2015 14:41:48 -0500 Subject: [PATCH 1/5] ENH: Make datasets have subclass relationships --- tests/pipeline/test_blaze.py | 2 +- tests/pipeline/test_term.py | 32 ++++++++++++++++++++++++ zipline/pipeline/data/dataset.py | 43 ++++++++++++++++++++++++-------- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index cb7ab737..d22aa244 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -90,7 +90,7 @@ class BlazeToPipelineTestCase(TestCase): self.assertEqual(ds.__name__, name) self.assertTrue(issubclass(ds, DataSet)) self.assertEqual( - {c.name: c.dtype for c in ds._columns}, + {c.name: c.dtype for c in ds.columns}, {'sid': np.int64, 'value': np.float64}, ) diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index 1a4f94af..94410c5d 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -28,6 +28,10 @@ class SomeDataSet(DataSet): buzz = Column(float64_dtype) +class SubDataSet(SomeDataSet): + pass + + class SomeFactor(Factor): dtype = float64_dtype window_length = 5 @@ -321,3 +325,31 @@ class ObjectIdentityTestCase(TestCase): with self.assertRaises(InvalidDType): SomeFactor(dtype=1) + + +class SubDataSetTestCase(TestCase): + def test_subdataset(self): + some_dataset_map = { + column.name: column for column in SomeDataSet.columns + } + sub_dataset_map = { + column.name: column for column in SubDataSet.columns + } + self.assertEqual( + set(some_dataset_map), + set(sub_dataset_map), + ) + for k, some_dataset_column in some_dataset_map.items(): + sub_dataset_column = sub_dataset_map[k] + self.assertIsNot( + some_dataset_column, + sub_dataset_column, + 'subclass column %r should not have the same identity as' + ' the parent' % k, + ) + self.assertEqual( + some_dataset_column.dtype, + sub_dataset_column.dtype, + 'subclass column %r should have the same dtype as the parent' % + k, + ) diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py index 3e2314e7..3de16220 100644 --- a/zipline/pipeline/data/dataset.py +++ b/zipline/pipeline/data/dataset.py @@ -21,11 +21,28 @@ class Column(object): def __init__(self, dtype): self.dtype = dtype - def bind(self, dataset, name): + def bind(self, name): """ - Bind a column to a concrete dataset. + Bind a `Column` object to its name. """ - return BoundColumn(dtype=self.dtype, dataset=dataset, name=name) + return _BoundColumnDescr(dtype=self.dtype, name=name) + + +class _BoundColumnDescr(object): + """ + Intermediate class that sits on `DataSet` objects and returns memoized + `BoundColumn` objects when requested. + """ + def __init__(self, dtype, name): + self.dtype = dtype + self.name = name + + def __get__(self, instance, owner): + return BoundColumn( + dtype=self.dtype, + dataset=owner, + name=self.name, + ) class BoundColumn(Term): @@ -97,20 +114,24 @@ class DataSetMeta(type): """ def __new__(mcls, name, bases, dict_): - newtype = type.__new__(mcls, name, bases, dict_) - _columns = [] + newtype = super(DataSetMeta, mcls).__new__(mcls, name, bases, dict_) + column_names = set().union( + *(getattr(base, '_column_names', ()) for base in bases) + ) for maybe_colname, maybe_column in iteritems(dict_): if isinstance(maybe_column, Column): - bound_column = maybe_column.bind(newtype, maybe_colname) - setattr(newtype, maybe_colname, bound_column) - _columns.append(bound_column) + bound_column_descr = maybe_column.bind(maybe_colname) + setattr(newtype, maybe_colname, bound_column_descr) + column_names.add(maybe_colname) - newtype._columns = frozenset(_columns) + newtype._column_names = frozenset(column_names) return newtype @property def columns(self): - return self._columns + return frozenset( + getattr(self, colname) for colname in self._column_names + ) def __lt__(self, other): return id(self) < id(other) @@ -119,5 +140,5 @@ class DataSetMeta(type): return '' % self.__name__ -class DataSet(with_metaclass(DataSetMeta)): +class DataSet(with_metaclass(DataSetMeta, object)): domain = None From f933d6b44e143c2c8a4a87a88655e3eb506e48cc Mon Sep 17 00:00:00 2001 From: llllllllll Date: Fri, 18 Dec 2015 14:54:47 -0500 Subject: [PATCH 2/5] DOC: whatsnew entry --- docs/source/whatsnew/0.8.4.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/whatsnew/0.8.4.txt b/docs/source/whatsnew/0.8.4.txt index 0689149e..0a38a8e2 100644 --- a/docs/source/whatsnew/0.8.4.txt +++ b/docs/source/whatsnew/0.8.4.txt @@ -59,6 +59,10 @@ Enhancements :class:`~zipline.pipeline.factors.ExponentialWeightedMovingStdDev` factors. (:issue:`910`). +* Allow :class:`~zipline.pipeline.data.DataSet` classes to be subclassed where + subclasses inherit all of the columns from the parent. These columns will be + new sentinels so you can register them a custom loader (:issue:`924`). + Experimental Features ~~~~~~~~~~~~~~~~~~~~~ From a3fecd652718827dd0c580f6c199675f93be5af2 Mon Sep 17 00:00:00 2001 From: llllllllll Date: Fri, 18 Dec 2015 21:33:37 -0500 Subject: [PATCH 3/5] ENH: support subclassing in the earningscalendar loader --- zipline/pipeline/loaders/blaze/earnings.py | 6 +++++- zipline/pipeline/loaders/earnings.py | 16 ++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/zipline/pipeline/loaders/blaze/earnings.py b/zipline/pipeline/loaders/blaze/earnings.py index 6c90b16b..6b5c1695 100644 --- a/zipline/pipeline/loaders/blaze/earnings.py +++ b/zipline/pipeline/loaders/blaze/earnings.py @@ -6,6 +6,7 @@ from six import iteritems from toolz import valmap from .core import TS_FIELD_NAME, SID_FIELD_NAME +from zipline.pipeline.data import EarningsCalendar from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.loaders.earnings import EarningsCalendarLoader @@ -87,7 +88,8 @@ class BlazeEarningsCalendarLoader(PipelineLoader): expr, resources=None, compute_kwargs=None, - odo_kwargs=None): + odo_kwargs=None, + dataset=EarningsCalendar): dshape = expr.dshape if not istabular(dshape): @@ -101,6 +103,7 @@ class BlazeEarningsCalendarLoader(PipelineLoader): resources, ) self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {} + self._dataset = dataset def load_adjusted_array(self, columns, dates, assets, mask): expr = self._expr @@ -147,4 +150,5 @@ class BlazeEarningsCalendarLoader(PipelineLoader): return EarningsCalendarLoader( dates, valmap(mkseries, gb.groups), + dataset=self._dataset, ).load_adjusted_array(columns, dates, assets, mask) diff --git a/zipline/pipeline/loaders/earnings.py b/zipline/pipeline/loaders/earnings.py index cb8106b6..b5792a1d 100644 --- a/zipline/pipeline/loaders/earnings.py +++ b/zipline/pipeline/loaders/earnings.py @@ -44,9 +44,12 @@ class EarningsCalendarLoader(PipelineLoader): Whether to allow passing ``DatetimeIndex`` values in ``announcement_dates``. """ - def __init__(self, all_dates, announcement_dates, infer_timestamps=False): + def __init__(self, + all_dates, + announcement_dates, + infer_timestamps=False, + dataset=EarningsCalendar): self.all_dates = all_dates - self.announcement_dates = announcement_dates = ( announcement_dates.copy() ) @@ -64,13 +67,14 @@ class EarningsCalendarLoader(PipelineLoader): announcement_dates[k] = pd.Series( v, index=repeat(dates[0], len(v)), ) + self.dataset = dataset def get_loader(self, column): """Dispatch to the loader for ``column``. """ - if column is EarningsCalendar.next_announcement: + if column is self.dataset.next_announcement: return self.next_announcement_loader - elif column is EarningsCalendar.previous_announcement: + elif column is self.dataset.previous_announcement: return self.previous_announcement_loader else: raise ValueError("Don't know how to load column '%s'." % column) @@ -78,7 +82,7 @@ class EarningsCalendarLoader(PipelineLoader): @lazyval def next_announcement_loader(self): return DataFrameLoader( - EarningsCalendar.next_announcement, + self.dataset.next_announcement, next_earnings_date_frame( self.all_dates, self.announcement_dates, @@ -89,7 +93,7 @@ class EarningsCalendarLoader(PipelineLoader): @lazyval def previous_announcement_loader(self): return DataFrameLoader( - EarningsCalendar.previous_announcement, + self.dataset.previous_announcement, previous_earnings_date_frame( self.all_dates, self.announcement_dates, From 68cf236944eba157827238584dd6e9b5b6f1ec76 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 29 Dec 2015 10:12:39 -0500 Subject: [PATCH 4/5] TST: Add test case for adding columns in subclass --- tests/pipeline/test_term.py | 43 +++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index 94410c5d..f689bc33 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -32,6 +32,10 @@ class SubDataSet(SomeDataSet): pass +class SubDataSetNewCol(SomeDataSet): + qux = Column(float64_dtype) + + class SomeFactor(Factor): dtype = float64_dtype window_length = 5 @@ -336,8 +340,8 @@ class SubDataSetTestCase(TestCase): column.name: column for column in SubDataSet.columns } self.assertEqual( - set(some_dataset_map), - set(sub_dataset_map), + {column.name for column in SomeDataSet.columns}, + {column.name for column in SubDataSet.columns}, ) for k, some_dataset_column in some_dataset_map.items(): sub_dataset_column = sub_dataset_map[k] @@ -353,3 +357,38 @@ class SubDataSetTestCase(TestCase): 'subclass column %r should have the same dtype as the parent' % k, ) + + def test_add_column(self): + some_dataset_map = { + column.name: column for column in SomeDataSet.columns + } + sub_dataset_new_col_map = { + column.name: column for column in SubDataSetNewCol.columns + } + sub_col_names = {column.name for column in SubDataSetNewCol.columns} + + # check our extra col + self.assertIn('qux', sub_col_names) + self.assertEqual( + sub_dataset_new_col_map['qux'].dtype, + float64_dtype, + ) + + self.assertEqual( + {column.name for column in SomeDataSet.columns}, + sub_col_names - {'qux'}, + ) + for k, some_dataset_column in some_dataset_map.items(): + sub_dataset_column = sub_dataset_new_col_map[k] + self.assertIsNot( + some_dataset_column, + sub_dataset_column, + 'subclass column %r should not have the same identity as' + ' the parent' % k, + ) + self.assertEqual( + some_dataset_column.dtype, + sub_dataset_column.dtype, + 'subclass column %r should have the same dtype as the parent' % + k, + ) From 54c58d1205978ad3aebc630d69cdf9ff362c744a Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 29 Dec 2015 10:13:00 -0500 Subject: [PATCH 5/5] DOC: add comments about the column collection in DataSetMeta --- zipline/pipeline/data/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py index 3de16220..e65dca39 100644 --- a/zipline/pipeline/data/dataset.py +++ b/zipline/pipeline/data/dataset.py @@ -115,11 +115,13 @@ class DataSetMeta(type): def __new__(mcls, name, bases, dict_): newtype = super(DataSetMeta, mcls).__new__(mcls, name, bases, dict_) + # collect all of the column names that we inherit from our parents column_names = set().union( *(getattr(base, '_column_names', ()) for base in bases) ) for maybe_colname, maybe_column in iteritems(dict_): if isinstance(maybe_column, Column): + # add column names defined on our class bound_column_descr = maybe_column.bind(maybe_colname) setattr(newtype, maybe_colname, bound_column_descr) column_names.add(maybe_colname)