Merge pull request #924 from quantopian/dataset-subclassing

ENH: Make datasets have subclass relationships
This commit is contained in:
Joe Jevnik
2015-12-29 11:43:55 -05:00
6 changed files with 125 additions and 19 deletions
+4
View File
@@ -59,6 +59,10 @@ Enhancements
:class:`~zipline.pipeline.factors.ExponentialWeightedMovingStdDev`
factors. (:issue:`910`).
* Allow :class:`~zipline.pipeline.data.DataSet` classes to be subclassed where
subclasses inherit all of the columns from the parent. These columns will be
new sentinels so you can register them a custom loader (:issue:`924`).
Experimental Features
~~~~~~~~~~~~~~~~~~~~~
+1 -1
View File
@@ -90,7 +90,7 @@ class BlazeToPipelineTestCase(TestCase):
self.assertEqual(ds.__name__, name)
self.assertTrue(issubclass(ds, DataSet))
self.assertEqual(
{c.name: c.dtype for c in ds._columns},
{c.name: c.dtype for c in ds.columns},
{'sid': np.int64, 'value': np.float64},
)
+71
View File
@@ -28,6 +28,14 @@ class SomeDataSet(DataSet):
buzz = Column(float64_dtype)
class SubDataSet(SomeDataSet):
pass
class SubDataSetNewCol(SomeDataSet):
qux = Column(float64_dtype)
class SomeFactor(Factor):
dtype = float64_dtype
window_length = 5
@@ -321,3 +329,66 @@ class ObjectIdentityTestCase(TestCase):
with self.assertRaises(InvalidDType):
SomeFactor(dtype=1)
class SubDataSetTestCase(TestCase):
def test_subdataset(self):
some_dataset_map = {
column.name: column for column in SomeDataSet.columns
}
sub_dataset_map = {
column.name: column for column in SubDataSet.columns
}
self.assertEqual(
{column.name for column in SomeDataSet.columns},
{column.name for column in SubDataSet.columns},
)
for k, some_dataset_column in some_dataset_map.items():
sub_dataset_column = sub_dataset_map[k]
self.assertIsNot(
some_dataset_column,
sub_dataset_column,
'subclass column %r should not have the same identity as'
' the parent' % k,
)
self.assertEqual(
some_dataset_column.dtype,
sub_dataset_column.dtype,
'subclass column %r should have the same dtype as the parent' %
k,
)
def test_add_column(self):
some_dataset_map = {
column.name: column for column in SomeDataSet.columns
}
sub_dataset_new_col_map = {
column.name: column for column in SubDataSetNewCol.columns
}
sub_col_names = {column.name for column in SubDataSetNewCol.columns}
# check our extra col
self.assertIn('qux', sub_col_names)
self.assertEqual(
sub_dataset_new_col_map['qux'].dtype,
float64_dtype,
)
self.assertEqual(
{column.name for column in SomeDataSet.columns},
sub_col_names - {'qux'},
)
for k, some_dataset_column in some_dataset_map.items():
sub_dataset_column = sub_dataset_new_col_map[k]
self.assertIsNot(
some_dataset_column,
sub_dataset_column,
'subclass column %r should not have the same identity as'
' the parent' % k,
)
self.assertEqual(
some_dataset_column.dtype,
sub_dataset_column.dtype,
'subclass column %r should have the same dtype as the parent' %
k,
)
+34 -11
View File
@@ -21,11 +21,28 @@ class Column(object):
def __init__(self, dtype):
self.dtype = dtype
def bind(self, dataset, name):
def bind(self, name):
"""
Bind a column to a concrete dataset.
Bind a `Column` object to its name.
"""
return BoundColumn(dtype=self.dtype, dataset=dataset, name=name)
return _BoundColumnDescr(dtype=self.dtype, name=name)
class _BoundColumnDescr(object):
"""
Intermediate class that sits on `DataSet` objects and returns memoized
`BoundColumn` objects when requested.
"""
def __init__(self, dtype, name):
self.dtype = dtype
self.name = name
def __get__(self, instance, owner):
return BoundColumn(
dtype=self.dtype,
dataset=owner,
name=self.name,
)
class BoundColumn(Term):
@@ -97,20 +114,26 @@ class DataSetMeta(type):
"""
def __new__(mcls, name, bases, dict_):
newtype = type.__new__(mcls, name, bases, dict_)
_columns = []
newtype = super(DataSetMeta, mcls).__new__(mcls, name, bases, dict_)
# collect all of the column names that we inherit from our parents
column_names = set().union(
*(getattr(base, '_column_names', ()) for base in bases)
)
for maybe_colname, maybe_column in iteritems(dict_):
if isinstance(maybe_column, Column):
bound_column = maybe_column.bind(newtype, maybe_colname)
setattr(newtype, maybe_colname, bound_column)
_columns.append(bound_column)
# add column names defined on our class
bound_column_descr = maybe_column.bind(maybe_colname)
setattr(newtype, maybe_colname, bound_column_descr)
column_names.add(maybe_colname)
newtype._columns = frozenset(_columns)
newtype._column_names = frozenset(column_names)
return newtype
@property
def columns(self):
return self._columns
return frozenset(
getattr(self, colname) for colname in self._column_names
)
def __lt__(self, other):
return id(self) < id(other)
@@ -119,5 +142,5 @@ class DataSetMeta(type):
return '<DataSet: %r>' % self.__name__
class DataSet(with_metaclass(DataSetMeta)):
class DataSet(with_metaclass(DataSetMeta, object)):
domain = None
+5 -1
View File
@@ -6,6 +6,7 @@ from six import iteritems
from toolz import valmap
from .core import TS_FIELD_NAME, SID_FIELD_NAME
from zipline.pipeline.data import EarningsCalendar
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
@@ -87,7 +88,8 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
expr,
resources=None,
compute_kwargs=None,
odo_kwargs=None):
odo_kwargs=None,
dataset=EarningsCalendar):
dshape = expr.dshape
if not istabular(dshape):
@@ -101,6 +103,7 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
resources,
)
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
self._dataset = dataset
def load_adjusted_array(self, columns, dates, assets, mask):
expr = self._expr
@@ -147,4 +150,5 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
return EarningsCalendarLoader(
dates,
valmap(mkseries, gb.groups),
dataset=self._dataset,
).load_adjusted_array(columns, dates, assets, mask)
+10 -6
View File
@@ -44,9 +44,12 @@ class EarningsCalendarLoader(PipelineLoader):
Whether to allow passing ``DatetimeIndex`` values in
``announcement_dates``.
"""
def __init__(self, all_dates, announcement_dates, infer_timestamps=False):
def __init__(self,
all_dates,
announcement_dates,
infer_timestamps=False,
dataset=EarningsCalendar):
self.all_dates = all_dates
self.announcement_dates = announcement_dates = (
announcement_dates.copy()
)
@@ -64,13 +67,14 @@ class EarningsCalendarLoader(PipelineLoader):
announcement_dates[k] = pd.Series(
v, index=repeat(dates[0], len(v)),
)
self.dataset = dataset
def get_loader(self, column):
"""Dispatch to the loader for ``column``.
"""
if column is EarningsCalendar.next_announcement:
if column is self.dataset.next_announcement:
return self.next_announcement_loader
elif column is EarningsCalendar.previous_announcement:
elif column is self.dataset.previous_announcement:
return self.previous_announcement_loader
else:
raise ValueError("Don't know how to load column '%s'." % column)
@@ -78,7 +82,7 @@ class EarningsCalendarLoader(PipelineLoader):
@lazyval
def next_announcement_loader(self):
return DataFrameLoader(
EarningsCalendar.next_announcement,
self.dataset.next_announcement,
next_earnings_date_frame(
self.all_dates,
self.announcement_dates,
@@ -89,7 +93,7 @@ class EarningsCalendarLoader(PipelineLoader):
@lazyval
def previous_announcement_loader(self):
return DataFrameLoader(
EarningsCalendar.previous_announcement,
self.dataset.previous_announcement,
previous_earnings_date_frame(
self.all_dates,
self.announcement_dates,