diff --git a/zipline/lib/labelarray.py b/zipline/lib/labelarray.py
index b92faf06..a3e8e199 100644
--- a/zipline/lib/labelarray.py
+++ b/zipline/lib/labelarray.py
@@ -197,6 +197,29 @@ class LabelArray(ndarray):
         ret._missing_value = missing_value
         return ret
 
+    @classmethod
+    def from_categorical(cls, categorical, missing_value=None):
+        """
+        Create a LabelArray from a pandas categorical.
+
+        Parameters
+        ----------
+        categorical : pd.Categorical
+            The categorical object to convert.
+        missing_value : bytes, unicode, or None, optional
+            The missing value to use for this LabelArray.
+
+        Returns
+        -------
+        la : LabelArray
+            The LabelArray representation of this categorical.
+        """
+        return LabelArray(
+            categorical,
+            missing_value,
+            categorical.categories,
+        )
+
     @property
     def categories(self):
         # This is a property because it should be immutable.
diff --git a/zipline/pipeline/classifiers/classifier.py b/zipline/pipeline/classifiers/classifier.py
index 73bf5a31..36b0a18a 100644
--- a/zipline/pipeline/classifiers/classifier.py
+++ b/zipline/pipeline/classifiers/classifier.py
@@ -6,6 +6,7 @@ import operator
 import re
 
 from numpy import where, isnan, nan, zeros
+import pandas as pd
 
 from zipline.lib.labelarray import LabelArray
 from zipline.lib.quantiles import quantiles
@@ -303,6 +304,51 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
             raise AssertionError("Expected a LabelArray, got %s." % type(data))
         return data.as_categorical()
 
+    def to_workspace_value(self, result, assets):
+        """
+        Called with the result of a pipeline. This needs to return an object
+        which can be put into the workspace to continue doing computations.
+
+        This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
+
+        Parameters
+        ----------
+        result : pd.Series
+            The pipeline output for this term (assumed to be a Series whose
+            ``unstack()`` yields one column per asset).
+        assets : pd.Index or array-like
+            The labels to use as the columns of the returned array.
+
+        Returns
+        -------
+        workspace_value : np.ndarray or LabelArray
+            A plain array for int64 classifiers, otherwise a LabelArray.
+        """
+        if self.dtype == int64_dtype:
+            return super(Classifier, self).to_workspace_value(result, assets)
+
+        # The base implementation returns a plain ndarray, so the categorical
+        # check must be made on the incoming result, not on its output.
+        values = result.values
+        if not isinstance(values, pd.Categorical):
+            raise AssertionError(
+                'Expected a Categorical, got %r.' % type(values).__name__
+            )
+
+        # Filling a categorical with a value outside its categories fails, so
+        # register the missing value before the base class fills with it.
+        with_missing = pd.Series(
+            data=pd.Categorical(
+                values,
+                values.categories.union([self.missing_value]),
+            ),
+            index=result.index,
+        )
+        return LabelArray(
+            super(Classifier, self).to_workspace_value(with_missing, assets),
+            self.missing_value,
+        )
+
     @classlazyval
     def _downsampled_type(self):
         return DownsampledMixin.make_downsampled_type(Classifier)
diff --git a/zipline/pipeline/term.py b/zipline/pipeline/term.py
index 6cecd147..37ac214c 100644
--- a/zipline/pipeline/term.py
+++ b/zipline/pipeline/term.py
@@ -590,6 +590,32 @@ class ComputableTerm(Term):
         """
         return data
 
+    def to_workspace_value(self, result, assets):
+        """
+        Called with the result of a pipeline. This needs to return an object
+        which can be put into the workspace to continue doing computations.
+
+        This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
+
+        Parameters
+        ----------
+        result : pd.Series
+            The pipeline output for this term (assumed to be a Series whose
+            ``unstack()`` yields one column per asset).
+        assets : pd.Index or array-like
+            The columns of the returned array; labels absent from ``result``
+            are filled with ``self.missing_value``.
+
+        Returns
+        -------
+        workspace_value : np.ndarray
+            The unstacked values with missing entries filled.
+        """
+        return result.unstack().fillna(self.missing_value).reindex(
+            columns=assets,
+            fill_value=self.missing_value,
+        ).values
+
     def _downsampled_type(self):
         """
         The expression type to return from self.downsample().