mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 05:23:38 +08:00
Merge pull request #1342 from quantopian/deal-with-bad-getdatetime-params
ENH: Deal with bad parameters to `get_datetime`
This commit is contained in:
@@ -319,6 +319,23 @@ def handle_data(context, data):
|
||||
algo.namespace['assert_equal'] = self.assertEqual
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_datetime_bad_params(self):
|
||||
algo_text = """
|
||||
from zipline.api import get_datetime
|
||||
from pytz import timezone
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
get_datetime(timezone)
|
||||
"""
|
||||
with self.assertRaises(TypeError):
|
||||
algo = TradingAlgorithm(script=algo_text,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_get_environment(self):
|
||||
expected_env = {
|
||||
'arena': 'backtest',
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from copy import copy
|
||||
import operator as op
|
||||
import warnings
|
||||
|
||||
from datetime import tzinfo
|
||||
import logbook
|
||||
import pytz
|
||||
import pandas as pd
|
||||
@@ -95,7 +95,8 @@ from zipline.utils.api_support import (
|
||||
ZiplineAPI,
|
||||
disallowed_in_before_trading_start)
|
||||
|
||||
from zipline.utils.input_validation import ensure_upper_case, error_keywords
|
||||
from zipline.utils.input_validation import ensure_upper_case, error_keywords, \
|
||||
expect_types, optional, coerce_string
|
||||
from zipline.utils.cache import CachedObject, Expired
|
||||
from zipline.utils.calendars import get_calendar
|
||||
|
||||
@@ -1477,8 +1478,11 @@ class TradingAlgorithm(object):
|
||||
self.performance_needs_update = True
|
||||
|
||||
@api_method
|
||||
@preprocess(tz=coerce_string(pytz.timezone))
|
||||
@expect_types(tz=optional(tzinfo))
|
||||
def get_datetime(self, tz=None):
|
||||
"""Returns the current simulation datetime.
|
||||
"""
|
||||
Returns the current simulation datetime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1492,14 +1496,9 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
dt = self.datetime
|
||||
assert dt.tzinfo == pytz.utc, "Algorithm should have a utc datetime"
|
||||
|
||||
if tz is not None:
|
||||
# Convert to the given timezone passed as a string or tzinfo.
|
||||
if isinstance(tz, string_types):
|
||||
tz = pytz.timezone(tz)
|
||||
dt = dt.astimezone(tz)
|
||||
|
||||
return dt # datetime.datetime objects are immutable.
|
||||
return dt
|
||||
|
||||
def update_dividends(self, dividend_frame):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user