mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 03:02:36 +08:00
DEV: Add tz kwarg to get_datetime.
This commit is contained in:
+57
-1
@@ -15,7 +15,9 @@
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
from mock import MagicMock
|
||||
from nose_parameterized import parameterized
|
||||
from six.moves import range
|
||||
from textwrap import dedent
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
@@ -67,7 +69,11 @@ from zipline.test_algorithms import (
|
||||
record_variables,
|
||||
)
|
||||
|
||||
from zipline.utils.test_utils import drain_zipline, assert_single_position
|
||||
from zipline.utils.test_utils import (
|
||||
assert_single_position,
|
||||
drain_zipline,
|
||||
to_utc,
|
||||
)
|
||||
|
||||
from zipline.sources import (SpecificEquityTrades,
|
||||
DataFrameSource,
|
||||
@@ -729,6 +735,56 @@ def handle_data(context, data):
|
||||
self.assertIsNot(output, None)
|
||||
|
||||
|
||||
class TestGetDatetime(TestCase):
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
('default', None,),
|
||||
('utc', 'UTC',),
|
||||
('us_east', 'US/Eastern',),
|
||||
]
|
||||
)
|
||||
def test_get_datetime(self, name, tz):
|
||||
|
||||
algo = dedent(
|
||||
"""
|
||||
import pandas as pd
|
||||
from zipline.api import get_datetime
|
||||
|
||||
def initialize(context):
|
||||
context.tz = {tz} or 'UTC'
|
||||
context.first_bar = True
|
||||
|
||||
def handle_data(context, data):
|
||||
if context.first_bar:
|
||||
dt = get_datetime({tz})
|
||||
if dt.tz.zone != context.tz:
|
||||
raise ValueError("Mismatched Zone")
|
||||
elif dt.tz_convert("US/Eastern").hour != 9:
|
||||
raise ValueError("Mismatched Hour")
|
||||
elif dt.tz_convert("US/Eastern").minute != 31:
|
||||
raise ValueError("Mismatched Minute")
|
||||
context.first_bar = False
|
||||
""".format(tz=repr(tz))
|
||||
)
|
||||
|
||||
start = to_utc('2014-01-02 9:31')
|
||||
end = to_utc('2014-01-03 9:31')
|
||||
source = RandomWalkSource(
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
data_frequency='minute'
|
||||
)
|
||||
algo = TradingAlgorithm(
|
||||
script=algo,
|
||||
sim_params=sim_params,
|
||||
)
|
||||
algo.run(source)
|
||||
self.assertFalse(algo.first_bar)
|
||||
|
||||
|
||||
class TestTradingControls(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user