diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 6485ccce..b22151ff 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -95,11 +95,11 @@ class TestTransformAlgorithm(TestCase): self.assertEqual(algo.granularity, 'daily') self.assertEqual(algo.annualizer, 250) - algo = TestRegisterTransformAlgorithm(granularity='minutely') - self.assertEqual(algo.granularity, 'minutely') + algo = TestRegisterTransformAlgorithm(granularity='minute') + self.assertEqual(algo.granularity, 'minute') self.assertEqual(algo.annualizer, 250 * 6 * 60) - algo = TestRegisterTransformAlgorithm(granularity='minutely', + algo = TestRegisterTransformAlgorithm(granularity='minute', annualizer=10) - self.assertEqual(algo.granularity, 'minutely') + self.assertEqual(algo.granularity, 'minute') self.assertEqual(algo.annualizer, 10) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 6e61ccdd..f3e8ce1d 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -45,7 +45,7 @@ DEFAULT_CAPITAL_BASE = float("1.0e5") ANNUALIZER = {'daily': 250, 'hourly': 250 * 6, - 'minutely': 250 * 6 * 60} + 'minute': 250 * 6 * 60} class TradingAlgorithm(object): @@ -96,9 +96,14 @@ class TradingAlgorithm(object): self.slippage = VolumeShareSlippage() self.commission = PerShare() - self.granularity = kwargs.get('granularity', 'daily') - # annualizer is used for e.g. risk calculations - self.annualizer = kwargs.get('annualizer', None) + if 'granularity' in kwargs: + self.set_granularity(kwargs.pop('granularity')) + else: + self.granularity = None + + # Override annualizer if set + if 'annualizer' in kwargs: + self.annualizer = kwargs['annualizer'] # set the capital base self.capital_base = kwargs.get('capital_base', DEFAULT_CAPITAL_BASE) @@ -110,12 +115,6 @@ class TradingAlgorithm(object): # call to user-defined constructor method self.initialize(*args, **kwargs) - # set annualizer according to granularity - # this is happening after initialize because granularity - # could be set in there. - if self.annualizer is None: - self.annualizer = ANNUALIZER[self.granularity] - def _create_generator(self, environment): """ Create a basic generator setup using the sources and @@ -318,6 +317,7 @@ class TradingAlgorithm(object): assert isinstance(transforms, list) self.transforms = transforms - def set_granuliarity(self, granularity): + def set_granularity(self, granularity): assert granularity in ('daily', 'minute') self.granularity = granularity + self.annualizer = ANNUALIZER[self.granularity]