diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 7622c665..6e61ccdd 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -43,6 +43,10 @@ from zipline import MESSAGES DEFAULT_CAPITAL_BASE = float("1.0e5") +ANNUALIZER = {'daily': 250, + 'hourly': 250 * 6, + 'minutely': 250 * 6 * 60} + class TradingAlgorithm(object): """Base class for trading algorithms. Inherit and overload @@ -110,17 +114,7 @@ class TradingAlgorithm(object): # this is happening after initialize because granularity # could be set in there. if self.annualizer is None: - if self.granularity == 'daily': - self.annualizer = 250 - elif self.granularity == 'hourly': - # trading days * hours - self.annualizer = 250 * 6 - elif self.granularity == 'minutely': - # trading days * hours * minutes - self.annualizer = 250 * 6 * 60 - else: - raise NotImplementedError('{g} is not implemented.\ - '.format(g=self.granularity)) + self.annualizer = ANNUALIZER[self.granularity] def _create_generator(self, environment): """ @@ -323,3 +317,7 @@ class TradingAlgorithm(object): def set_transforms(self, transforms): assert isinstance(transforms, list) self.transforms = transforms + + def set_granuliarity(self, granularity): + assert granularity in ('daily', 'minute') + self.granularity = granularity