mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 20:20:55 +08:00
Merge pull request #924 from quantopian/dataset-subclassing
ENH: Make datasets have subclass relationships
This commit is contained in:
@@ -59,6 +59,10 @@ Enhancements
|
||||
:class:`~zipline.pipeline.factors.ExponentialWeightedMovingStdDev`
|
||||
factors. (:issue:`910`).
|
||||
|
||||
* Allow :class:`~zipline.pipeline.data.DataSet` classes to be subclassed where
|
||||
subclasses inherit all of the columns from the parent. These columns will be
|
||||
new sentinels so you can register them with a custom loader (:issue:`924`).
|
||||
|
||||
Experimental Features
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
self.assertEqual(ds.__name__, name)
|
||||
self.assertTrue(issubclass(ds, DataSet))
|
||||
self.assertEqual(
|
||||
{c.name: c.dtype for c in ds._columns},
|
||||
{c.name: c.dtype for c in ds.columns},
|
||||
{'sid': np.int64, 'value': np.float64},
|
||||
)
|
||||
|
||||
|
||||
@@ -28,6 +28,14 @@ class SomeDataSet(DataSet):
|
||||
buzz = Column(float64_dtype)
|
||||
|
||||
|
||||
class SubDataSet(SomeDataSet):
    """Subclass of ``SomeDataSet`` that declares no columns of its own."""
|
||||
|
||||
|
||||
class SubDataSetNewCol(SomeDataSet):
    """Subclass of ``SomeDataSet`` that declares one extra column, ``qux``."""

    qux = Column(float64_dtype)
|
||||
|
||||
|
||||
class SomeFactor(Factor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
@@ -321,3 +329,66 @@ class ObjectIdentityTestCase(TestCase):
|
||||
|
||||
with self.assertRaises(InvalidDType):
|
||||
SomeFactor(dtype=1)
|
||||
|
||||
|
||||
class SubDataSetTestCase(TestCase):
    """
    Tests for subclassing datasets.

    Checks that the columns exposed by a subclass carry the same names and
    dtypes as those of the parent dataset, but are distinct objects.
    """

    def _assert_columns_rebound(self, parent_columns, child_columns):
        """
        Assert each parent column has a distinct, same-dtype child column.

        Parameters
        ----------
        parent_columns : dict
            Maps column name -> column object read from the parent dataset.
        child_columns : dict
            Maps column name -> column object read from the subclass.  Must
            contain every key of ``parent_columns``; a missing key surfaces
            as a ``KeyError``.
        """
        for name, parent_column in parent_columns.items():
            child_column = child_columns[name]
            self.assertIsNot(
                parent_column,
                child_column,
                'subclass column %r should not have the same identity as'
                ' the parent' % name,
            )
            self.assertEqual(
                parent_column.dtype,
                child_column.dtype,
                'subclass column %r should have the same dtype as the parent'
                % name,
            )

    def test_subdataset(self):
        """A subclass with no new columns exposes exactly the parent's."""
        some_dataset_map = {
            column.name: column for column in SomeDataSet.columns
        }
        sub_dataset_map = {
            column.name: column for column in SubDataSet.columns
        }

        # Same set of column names on both sides.
        self.assertEqual(set(some_dataset_map), set(sub_dataset_map))
        self._assert_columns_rebound(some_dataset_map, sub_dataset_map)

    def test_add_column(self):
        """A subclass may add a column on top of the inherited ones."""
        some_dataset_map = {
            column.name: column for column in SomeDataSet.columns
        }
        sub_dataset_new_col_map = {
            column.name: column for column in SubDataSetNewCol.columns
        }
        sub_col_names = set(sub_dataset_new_col_map)

        # check our extra col
        self.assertIn('qux', sub_col_names)
        self.assertEqual(
            sub_dataset_new_col_map['qux'].dtype,
            float64_dtype,
        )

        # Apart from the new column, the names match the parent's.
        self.assertEqual(set(some_dataset_map), sub_col_names - {'qux'})
        self._assert_columns_rebound(
            some_dataset_map,
            sub_dataset_new_col_map,
        )
|
||||
|
||||
@@ -21,11 +21,28 @@ class Column(object):
|
||||
def __init__(self, dtype):
|
||||
self.dtype = dtype
|
||||
|
||||
def bind(self, dataset, name):
|
||||
def bind(self, name):
|
||||
"""
|
||||
Bind a column to a concrete dataset.
|
||||
Bind a `Column` object to its name.
|
||||
"""
|
||||
return BoundColumn(dtype=self.dtype, dataset=dataset, name=name)
|
||||
return _BoundColumnDescr(dtype=self.dtype, name=name)
|
||||
|
||||
|
||||
class _BoundColumnDescr(object):
|
||||
"""
|
||||
Intermediate class that sits on `DataSet` objects and returns memoized
|
||||
`BoundColumn` objects when requested.
|
||||
"""
|
||||
def __init__(self, dtype, name):
|
||||
self.dtype = dtype
|
||||
self.name = name
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return BoundColumn(
|
||||
dtype=self.dtype,
|
||||
dataset=owner,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
|
||||
class BoundColumn(Term):
|
||||
@@ -97,20 +114,26 @@ class DataSetMeta(type):
|
||||
"""
|
||||
|
||||
def __new__(mcls, name, bases, dict_):
|
||||
newtype = type.__new__(mcls, name, bases, dict_)
|
||||
_columns = []
|
||||
newtype = super(DataSetMeta, mcls).__new__(mcls, name, bases, dict_)
|
||||
# collect all of the column names that we inherit from our parents
|
||||
column_names = set().union(
|
||||
*(getattr(base, '_column_names', ()) for base in bases)
|
||||
)
|
||||
for maybe_colname, maybe_column in iteritems(dict_):
|
||||
if isinstance(maybe_column, Column):
|
||||
bound_column = maybe_column.bind(newtype, maybe_colname)
|
||||
setattr(newtype, maybe_colname, bound_column)
|
||||
_columns.append(bound_column)
|
||||
# add column names defined on our class
|
||||
bound_column_descr = maybe_column.bind(maybe_colname)
|
||||
setattr(newtype, maybe_colname, bound_column_descr)
|
||||
column_names.add(maybe_colname)
|
||||
|
||||
newtype._columns = frozenset(_columns)
|
||||
newtype._column_names = frozenset(column_names)
|
||||
return newtype
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self._columns
|
||||
return frozenset(
|
||||
getattr(self, colname) for colname in self._column_names
|
||||
)
|
||||
|
||||
def __lt__(self, other):
|
||||
return id(self) < id(other)
|
||||
@@ -119,5 +142,5 @@ class DataSetMeta(type):
|
||||
return '<DataSet: %r>' % self.__name__
|
||||
|
||||
|
||||
class DataSet(with_metaclass(DataSetMeta)):
|
||||
class DataSet(with_metaclass(DataSetMeta, object)):
|
||||
domain = None
|
||||
|
||||
@@ -6,6 +6,7 @@ from six import iteritems
|
||||
from toolz import valmap
|
||||
|
||||
from .core import TS_FIELD_NAME, SID_FIELD_NAME
|
||||
from zipline.pipeline.data import EarningsCalendar
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
|
||||
|
||||
@@ -87,7 +88,8 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
expr,
|
||||
resources=None,
|
||||
compute_kwargs=None,
|
||||
odo_kwargs=None):
|
||||
odo_kwargs=None,
|
||||
dataset=EarningsCalendar):
|
||||
dshape = expr.dshape
|
||||
|
||||
if not istabular(dshape):
|
||||
@@ -101,6 +103,7 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
resources,
|
||||
)
|
||||
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
|
||||
self._dataset = dataset
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
expr = self._expr
|
||||
@@ -147,4 +150,5 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
return EarningsCalendarLoader(
|
||||
dates,
|
||||
valmap(mkseries, gb.groups),
|
||||
dataset=self._dataset,
|
||||
).load_adjusted_array(columns, dates, assets, mask)
|
||||
|
||||
@@ -44,9 +44,12 @@ class EarningsCalendarLoader(PipelineLoader):
|
||||
Whether to allow passing ``DatetimeIndex`` values in
|
||||
``announcement_dates``.
|
||||
"""
|
||||
def __init__(self, all_dates, announcement_dates, infer_timestamps=False):
|
||||
def __init__(self,
|
||||
all_dates,
|
||||
announcement_dates,
|
||||
infer_timestamps=False,
|
||||
dataset=EarningsCalendar):
|
||||
self.all_dates = all_dates
|
||||
|
||||
self.announcement_dates = announcement_dates = (
|
||||
announcement_dates.copy()
|
||||
)
|
||||
@@ -64,13 +67,14 @@ class EarningsCalendarLoader(PipelineLoader):
|
||||
announcement_dates[k] = pd.Series(
|
||||
v, index=repeat(dates[0], len(v)),
|
||||
)
|
||||
self.dataset = dataset
|
||||
|
||||
def get_loader(self, column):
|
||||
"""Dispatch to the loader for ``column``.
|
||||
"""
|
||||
if column is EarningsCalendar.next_announcement:
|
||||
if column is self.dataset.next_announcement:
|
||||
return self.next_announcement_loader
|
||||
elif column is EarningsCalendar.previous_announcement:
|
||||
elif column is self.dataset.previous_announcement:
|
||||
return self.previous_announcement_loader
|
||||
else:
|
||||
raise ValueError("Don't know how to load column '%s'." % column)
|
||||
@@ -78,7 +82,7 @@ class EarningsCalendarLoader(PipelineLoader):
|
||||
@lazyval
|
||||
def next_announcement_loader(self):
|
||||
return DataFrameLoader(
|
||||
EarningsCalendar.next_announcement,
|
||||
self.dataset.next_announcement,
|
||||
next_earnings_date_frame(
|
||||
self.all_dates,
|
||||
self.announcement_dates,
|
||||
@@ -89,7 +93,7 @@ class EarningsCalendarLoader(PipelineLoader):
|
||||
@lazyval
|
||||
def previous_announcement_loader(self):
|
||||
return DataFrameLoader(
|
||||
EarningsCalendar.previous_announcement,
|
||||
self.dataset.previous_announcement,
|
||||
previous_earnings_date_frame(
|
||||
self.all_dates,
|
||||
self.announcement_dates,
|
||||
|
||||
Reference in New Issue
Block a user