From affeb2afbcfdbacc19814fc65b7a1dccd45e7597 Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Tue, 21 Oct 2014 22:36:17 -0400 Subject: [PATCH] DEV: Add tz kwarg to get_datetime. --- tests/test_algorithm.py | 58 ++++++++++++++++++++++++++++++++++++++++- zipline/algorithm.py | 4 ++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 6193de5b..9c1075d9 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -15,7 +15,9 @@ import datetime from datetime import timedelta from mock import MagicMock +from nose_parameterized import parameterized from six.moves import range +from textwrap import dedent from unittest import TestCase import numpy as np @@ -67,7 +69,11 @@ from zipline.test_algorithms import ( record_variables, ) -from zipline.utils.test_utils import drain_zipline, assert_single_position +from zipline.utils.test_utils import ( + assert_single_position, + drain_zipline, + to_utc, +) from zipline.sources import (SpecificEquityTrades, DataFrameSource, @@ -729,6 +735,56 @@ def handle_data(context, data): self.assertIsNot(output, None) +class TestGetDatetime(TestCase): + + @parameterized.expand( + [ + ('default', None,), + ('utc', 'UTC',), + ('us_east', 'US/Eastern',), + ] + ) + def test_get_datetime(self, name, tz): + + algo = dedent( + """ + import pandas as pd + from zipline.api import get_datetime + + def initialize(context): + context.tz = {tz} or 'UTC' + context.first_bar = True + + def handle_data(context, data): + if context.first_bar: + dt = get_datetime({tz}) + if dt.tz.zone != context.tz: + raise ValueError("Mismatched Zone") + elif dt.tz_convert("US/Eastern").hour != 9: + raise ValueError("Mismatched Hour") + elif dt.tz_convert("US/Eastern").minute != 31: + raise ValueError("Mismatched Minute") + context.first_bar = False + """.format(tz=repr(tz)) + ) + + start = to_utc('2014-01-02 9:31') + end = to_utc('2014-01-03 9:31') + source = RandomWalkSource( + start=start, + end=end, + ) + sim_params = factory.create_simulation_parameters( + data_frequency='minute' + ) + algo = TradingAlgorithm( + script=algo, + sim_params=sim_params, + ) + algo.run(source) + self.assertFalse(algo.first_bar) + + class TestTradingControls(TestCase): def setUp(self): diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 5cfc307d..95c74a0a 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -746,13 +746,15 @@ class TradingAlgorithm(object): self.blotter.set_date(dt) @api_method - def get_datetime(self): + def get_datetime(self, tz=None): """ Returns a copy of the datetime. """ date_copy = copy(self.datetime) assert date_copy.tzinfo == pytz.utc, \ "Algorithm should have a utc datetime" + if tz is not None: + date_copy = date_copy.tz_convert(tz) return date_copy def set_transact(self, transact):