From 63ef840363f855bedbcee2f8bcb1e2727d6c1651 Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Thu, 21 Jul 2016 14:15:42 -0400 Subject: [PATCH] ENH: Verify params passed to `get_datetime` --- tests/test_algorithm.py | 17 +++++++++++++++++ zipline/algorithm.py | 17 ++++++++--------- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2ce0cb2d..955a6ce3 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -319,6 +319,23 @@ def handle_data(context, data): algo.namespace['assert_equal'] = self.assertEqual algo.run(self.data_portal) + def test_datetime_bad_params(self): + algo_text = """ +from zipline.api import get_datetime +from pytz import timezone + +def initialize(context): + pass + +def handle_data(context, data): + get_datetime(timezone) +""" + with self.assertRaises(TypeError): + algo = TradingAlgorithm(script=algo_text, + sim_params=self.sim_params, + env=self.env) + algo.run(self.data_portal) + def test_get_environment(self): expected_env = { 'arena': 'backtest', diff --git a/zipline/algorithm.py b/zipline/algorithm.py index bf8c5935..bfbea882 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -15,7 +15,7 @@ from copy import copy import operator as op import warnings - +from datetime import tzinfo import logbook import pytz import pandas as pd @@ -95,7 +95,8 @@ from zipline.utils.api_support import ( ZiplineAPI, disallowed_in_before_trading_start) -from zipline.utils.input_validation import ensure_upper_case, error_keywords +from zipline.utils.input_validation import ensure_upper_case, error_keywords, \ + expect_types, optional, coerce_string from zipline.utils.cache import CachedObject, Expired from zipline.utils.calendars import get_calendar @@ -1477,8 +1478,11 @@ class TradingAlgorithm(object): self.performance_needs_update = True @api_method + @preprocess(tz=coerce_string(pytz.timezone)) + @expect_types(tz=optional(tzinfo)) def get_datetime(self, tz=None): - """Returns the current simulation datetime. + """ + Returns the current simulation datetime. Parameters ---------- @@ -1492,14 +1496,9 @@ class TradingAlgorithm(object): """ dt = self.datetime assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime" - if tz is not None: - # Convert to the given timezone passed as a string or tzinfo. - if isinstance(tz, string_types): - tz = pytz.timezone(tz) dt = dt.astimezone(tz) - - return dt # datetime.datetime objects are immutable. + return dt def update_dividends(self, dividend_frame): """