# Patch summary (commentary only; this text precedes the first "diff --git"
# header and is not part of any hunk).
#
# tests/pipeline/test_blaze.py
#   * Reads the dataset's columns through `ds.columns` rather than
#     `ds._columns`.
#
# tests/pipeline/test_term.py
#   * Adds `SubDataSet(SomeDataSet)` and `SubDataSetTestCase`, which asserts
#     that the subclass exposes the same column names and dtypes as its
#     parent, and that each subclass column is not the same object as the
#     parent's column of that name.
#
# zipline/pipeline/data/dataset.py
#   * `Column.bind` now takes only `name` and returns a `_BoundColumnDescr`.
#   * `_BoundColumnDescr.__get__` builds a `BoundColumn` with `dataset=owner`,
#     i.e. the class the attribute was looked up on.
#   * `DataSetMeta.__new__` stores `_column_names`: the union of the bases'
#     `_column_names` plus the names of `Column` attributes found in the
#     class body. It no longer stores a set of bound columns.
#   * `DataSetMeta.columns` builds a frozenset via `getattr` over
#     `_column_names` on every access.
#   * `DataSet` passes `object` explicitly to `with_metaclass`.
#
# NOTE(review): line breaks and leading whitespace were restored assuming
# conventional 4-space indentation; verify the context lines against the
# target files before applying.
# NOTE(review): the context line `return '' % self.__name__` in the last hunk
# looks truncated (an empty format string cannot consume an argument). It is
# reproduced as found -- confirm against the target file.
diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py
index cb7ab737..d22aa244 100644
--- a/tests/pipeline/test_blaze.py
+++ b/tests/pipeline/test_blaze.py
@@ -90,7 +90,7 @@ class BlazeToPipelineTestCase(TestCase):
         self.assertEqual(ds.__name__, name)
         self.assertTrue(issubclass(ds, DataSet))
         self.assertEqual(
-            {c.name: c.dtype for c in ds._columns},
+            {c.name: c.dtype for c in ds.columns},
             {'sid': np.int64, 'value': np.float64},
         )
 
diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py
index 1a4f94af..94410c5d 100644
--- a/tests/pipeline/test_term.py
+++ b/tests/pipeline/test_term.py
@@ -28,6 +28,10 @@ class SomeDataSet(DataSet):
     buzz = Column(float64_dtype)
 
 
+class SubDataSet(SomeDataSet):
+    pass
+
+
 class SomeFactor(Factor):
     dtype = float64_dtype
     window_length = 5
@@ -321,3 +325,31 @@ class ObjectIdentityTestCase(TestCase):
 
         with self.assertRaises(InvalidDType):
             SomeFactor(dtype=1)
+
+
+class SubDataSetTestCase(TestCase):
+    def test_subdataset(self):
+        some_dataset_map = {
+            column.name: column for column in SomeDataSet.columns
+        }
+        sub_dataset_map = {
+            column.name: column for column in SubDataSet.columns
+        }
+        self.assertEqual(
+            set(some_dataset_map),
+            set(sub_dataset_map),
+        )
+        for k, some_dataset_column in some_dataset_map.items():
+            sub_dataset_column = sub_dataset_map[k]
+            self.assertIsNot(
+                some_dataset_column,
+                sub_dataset_column,
+                'subclass column %r should not have the same identity as'
+                ' the parent' % k,
+            )
+            self.assertEqual(
+                some_dataset_column.dtype,
+                sub_dataset_column.dtype,
+                'subclass column %r should have the same dtype as the parent' %
+                k,
+            )
diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py
index 3e2314e7..3de16220 100644
--- a/zipline/pipeline/data/dataset.py
+++ b/zipline/pipeline/data/dataset.py
@@ -21,11 +21,28 @@ class Column(object):
     def __init__(self, dtype):
         self.dtype = dtype
 
-    def bind(self, dataset, name):
+    def bind(self, name):
         """
-        Bind a column to a concrete dataset.
+        Bind a `Column` object to its name.
         """
-        return BoundColumn(dtype=self.dtype, dataset=dataset, name=name)
+        return _BoundColumnDescr(dtype=self.dtype, name=name)
+
+
+class _BoundColumnDescr(object):
+    """
+    Intermediate class that sits on `DataSet` objects and returns memoized
+    `BoundColumn` objects when requested.
+    """
+    def __init__(self, dtype, name):
+        self.dtype = dtype
+        self.name = name
+
+    def __get__(self, instance, owner):
+        return BoundColumn(
+            dtype=self.dtype,
+            dataset=owner,
+            name=self.name,
+        )
 
 
 class BoundColumn(Term):
@@ -97,20 +114,24 @@ class DataSetMeta(type):
     """
 
     def __new__(mcls, name, bases, dict_):
-        newtype = type.__new__(mcls, name, bases, dict_)
-        _columns = []
+        newtype = super(DataSetMeta, mcls).__new__(mcls, name, bases, dict_)
+        column_names = set().union(
+            *(getattr(base, '_column_names', ()) for base in bases)
+        )
         for maybe_colname, maybe_column in iteritems(dict_):
             if isinstance(maybe_column, Column):
-                bound_column = maybe_column.bind(newtype, maybe_colname)
-                setattr(newtype, maybe_colname, bound_column)
-                _columns.append(bound_column)
+                bound_column_descr = maybe_column.bind(maybe_colname)
+                setattr(newtype, maybe_colname, bound_column_descr)
+                column_names.add(maybe_colname)
 
-        newtype._columns = frozenset(_columns)
+        newtype._column_names = frozenset(column_names)
         return newtype
 
     @property
     def columns(self):
-        return self._columns
+        return frozenset(
+            getattr(self, colname) for colname in self._column_names
+        )
 
     def __lt__(self, other):
         return id(self) < id(other)
@@ -119,5 +140,5 @@ class DataSetMeta(type):
         return '' % self.__name__
 
 
-class DataSet(with_metaclass(DataSetMeta)):
+class DataSet(with_metaclass(DataSetMeta, object)):
     domain = None