mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 12:39:51 +08:00
TST: Add unittest for daily history with midnight dt.
This commit is contained in:
+54
-3
@@ -18,6 +18,7 @@ from unittest import TestCase
|
||||
from nose_parameterized import parameterized
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from zipline.history import history
|
||||
from zipline.history.history_container import HistoryContainer
|
||||
@@ -26,7 +27,7 @@ import zipline.utils.factory as factory
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.finance.trading import SimulationParameters, TradingEnvironment
|
||||
|
||||
from zipline.sources import RandomWalkSource
|
||||
from zipline.sources import RandomWalkSource, DataFrameSource
|
||||
|
||||
from .history_cases import (
|
||||
HISTORY_CONTAINER_TEST_CASES,
|
||||
@@ -128,7 +129,8 @@ def get_index_at_dt(case_input):
|
||||
case_input['bar_count'],
|
||||
case_input['frequency'],
|
||||
None,
|
||||
False
|
||||
False,
|
||||
daily_at_midnight=False
|
||||
)
|
||||
return history.index_at_dt(history_spec, case_input['algo_dt'])
|
||||
|
||||
@@ -217,7 +219,8 @@ class TestHistoryContainer(TestCase):
|
||||
bar_count=3,
|
||||
frequency='1d',
|
||||
field='price',
|
||||
ffill=True
|
||||
ffill=True,
|
||||
daily_at_midnight=False
|
||||
)
|
||||
specs = {spec.key_str: spec}
|
||||
initial_sids = [1, ]
|
||||
@@ -342,6 +345,54 @@ class TestHistoryAlgo(TestCase):
|
||||
def setUp(self):
|
||||
np.random.seed(123)
|
||||
|
||||
def test_history_daily(self):
|
||||
bar_count = 3
|
||||
algo_text = """
|
||||
from zipline.api import history, add_history
|
||||
from copy import deepcopy
|
||||
|
||||
def initialize(context):
|
||||
add_history(bar_count={bar_count}, frequency='1d', field='price')
|
||||
context.history_trace = []
|
||||
|
||||
def handle_data(context, data):
|
||||
prices = history(bar_count={bar_count}, frequency='1d', field='price')
|
||||
context.history_trace.append(deepcopy(prices))
|
||||
""".format(bar_count=bar_count).strip()
|
||||
|
||||
# March 2006
|
||||
# Su Mo Tu We Th Fr Sa
|
||||
# 1 2 3 4
|
||||
# 5 6 7 8 9 10 11
|
||||
# 12 13 14 15 16 17 18
|
||||
# 19 20 21 22 23 24 25
|
||||
# 26 27 28 29 30 31
|
||||
|
||||
start = pd.Timestamp('2006-03-20', tz='UTC')
|
||||
end = pd.Timestamp('2006-03-30', tz='UTC')
|
||||
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=start, end=end)
|
||||
|
||||
_, df = factory.create_test_df_source(sim_params)
|
||||
df = df.astype(np.float64)
|
||||
source = DataFrameSource(df, sids=[0])
|
||||
|
||||
test_algo = TradingAlgorithm(
|
||||
script=algo_text,
|
||||
data_frequency='daily',
|
||||
sim_params=sim_params
|
||||
)
|
||||
|
||||
output = test_algo.run(source)
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
history_trace = test_algo.history_trace
|
||||
|
||||
for i, received in enumerate(history_trace[bar_count - 1:]):
|
||||
expected = df.iloc[i:i + bar_count]
|
||||
assert_frame_equal(expected, received)
|
||||
|
||||
def test_basic_history(self):
|
||||
algo_text = """
|
||||
from zipline.api import history, add_history
|
||||
|
||||
Reference in New Issue
Block a user