DEV: Add tz kwarg to get_datetime.

This commit is contained in:
Scott Sanderson
2014-10-21 22:36:17 -04:00
parent 820115f7be
commit affeb2afbc
2 changed files with 60 additions and 2 deletions
+57 -1
View File
@@ -15,7 +15,9 @@
import datetime
from datetime import timedelta
from mock import MagicMock
from nose_parameterized import parameterized
from six.moves import range
from textwrap import dedent
from unittest import TestCase
import numpy as np
@@ -67,7 +69,11 @@ from zipline.test_algorithms import (
record_variables,
)
from zipline.utils.test_utils import drain_zipline, assert_single_position
from zipline.utils.test_utils import (
assert_single_position,
drain_zipline,
to_utc,
)
from zipline.sources import (SpecificEquityTrades,
DataFrameSource,
@@ -729,6 +735,56 @@ def handle_data(context, data):
self.assertIsNot(output, None)
class TestGetDatetime(TestCase):
@parameterized.expand(
[
('default', None,),
('utc', 'UTC',),
('us_east', 'US/Eastern',),
]
)
def test_get_datetime(self, name, tz):
algo = dedent(
"""
import pandas as pd
from zipline.api import get_datetime
def initialize(context):
context.tz = {tz} or 'UTC'
context.first_bar = True
def handle_data(context, data):
if context.first_bar:
dt = get_datetime({tz})
if dt.tz.zone != context.tz:
raise ValueError("Mismatched Zone")
elif dt.tz_convert("US/Eastern").hour != 9:
raise ValueError("Mismatched Hour")
elif dt.tz_convert("US/Eastern").minute != 31:
raise ValueError("Mismatched Minute")
context.first_bar = False
""".format(tz=repr(tz))
)
start = to_utc('2014-01-02 9:31')
end = to_utc('2014-01-03 9:31')
source = RandomWalkSource(
start=start,
end=end,
)
sim_params = factory.create_simulation_parameters(
data_frequency='minute'
)
algo = TradingAlgorithm(
script=algo,
sim_params=sim_params,
)
algo.run(source)
self.assertFalse(algo.first_bar)
class TestTradingControls(TestCase):
def setUp(self):
+3 -1
View File
@@ -746,13 +746,15 @@ class TradingAlgorithm(object):
self.blotter.set_date(dt)
@api_method
def get_datetime(self):
def get_datetime(self, tz=None):
"""
Returns a copy of the datetime.
"""
date_copy = copy(self.datetime)
assert date_copy.tzinfo == pytz.utc, \
"Algorithm should have a utc datetime"
if tz is not None:
date_copy = date_copy.tz_convert(tz)
return date_copy
def set_transact(self, transact):