Merge pull request #1342 from quantopian/deal-with-bad-getdatetime-params

ENH: Deal with bad parameters to `get_datetime`
This commit is contained in:
Jean Bredeche
2016-07-21 22:14:02 -04:00
committed by GitHub
2 changed files with 25 additions and 9 deletions
+17
View File
@@ -319,6 +319,23 @@ def handle_data(context, data):
algo.namespace['assert_equal'] = self.assertEqual
algo.run(self.data_portal)
def test_datetime_bad_params(self):
algo_text = """
from zipline.api import get_datetime
from pytz import timezone
def initialize(context):
pass
def handle_data(context, data):
get_datetime(timezone)
"""
with self.assertRaises(TypeError):
algo = TradingAlgorithm(script=algo_text,
sim_params=self.sim_params,
env=self.env)
algo.run(self.data_portal)
def test_get_environment(self):
expected_env = {
'arena': 'backtest',
+8 -9
View File
@@ -15,7 +15,7 @@
from copy import copy
import operator as op
import warnings
from datetime import tzinfo
import logbook
import pytz
import pandas as pd
@@ -95,7 +95,8 @@ from zipline.utils.api_support import (
ZiplineAPI,
disallowed_in_before_trading_start)
from zipline.utils.input_validation import ensure_upper_case, error_keywords
from zipline.utils.input_validation import ensure_upper_case, error_keywords, \
expect_types, optional, coerce_string
from zipline.utils.cache import CachedObject, Expired
from zipline.utils.calendars import get_calendar
@@ -1477,8 +1478,11 @@ class TradingAlgorithm(object):
self.performance_needs_update = True
@api_method
@preprocess(tz=coerce_string(pytz.timezone))
@expect_types(tz=optional(tzinfo))
def get_datetime(self, tz=None):
"""Returns the current simulation datetime.
"""
Returns the current simulation datetime.
Parameters
----------
@@ -1492,14 +1496,9 @@ class TradingAlgorithm(object):
"""
dt = self.datetime
assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"
if tz is not None:
# Convert to the given timezone passed as a string or tzinfo.
if isinstance(tz, string_types):
tz = pytz.timezone(tz)
dt = dt.astimezone(tz)
return dt # datetime.datetime objects are immutable.
return dt
def update_dividends(self, dividend_frame):
"""