mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 19:05:11 +08:00
Merge branch 'history_support_daily_input'
This commit is contained in:
@@ -246,7 +246,6 @@ HISTORY_CONTAINER_TEST_CASES = {
|
||||
],
|
||||
},
|
||||
},
|
||||
|
||||
'test illiquid prices': {
|
||||
|
||||
# A list of HistorySpec objects.
|
||||
|
||||
+54
-3
@@ -18,6 +18,7 @@ from unittest import TestCase
|
||||
from nose_parameterized import parameterized
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from zipline.history import history
|
||||
from zipline.history.history_container import HistoryContainer
|
||||
@@ -26,7 +27,7 @@ import zipline.utils.factory as factory
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.finance.trading import SimulationParameters, TradingEnvironment
|
||||
|
||||
from zipline.sources import RandomWalkSource
|
||||
from zipline.sources import RandomWalkSource, DataFrameSource
|
||||
|
||||
from .history_cases import (
|
||||
HISTORY_CONTAINER_TEST_CASES,
|
||||
@@ -128,7 +129,8 @@ def get_index_at_dt(case_input):
|
||||
case_input['bar_count'],
|
||||
case_input['frequency'],
|
||||
None,
|
||||
False
|
||||
False,
|
||||
daily_at_midnight=False
|
||||
)
|
||||
return history.index_at_dt(history_spec, case_input['algo_dt'])
|
||||
|
||||
@@ -217,7 +219,8 @@ class TestHistoryContainer(TestCase):
|
||||
bar_count=3,
|
||||
frequency='1d',
|
||||
field='price',
|
||||
ffill=True
|
||||
ffill=True,
|
||||
daily_at_midnight=False
|
||||
)
|
||||
specs = {spec.key_str: spec}
|
||||
initial_sids = [1, ]
|
||||
@@ -342,6 +345,54 @@ class TestHistoryAlgo(TestCase):
|
||||
def setUp(self):
|
||||
np.random.seed(123)
|
||||
|
||||
def test_history_daily(self):
|
||||
bar_count = 3
|
||||
algo_text = """
|
||||
from zipline.api import history, add_history
|
||||
from copy import deepcopy
|
||||
|
||||
def initialize(context):
|
||||
add_history(bar_count={bar_count}, frequency='1d', field='price')
|
||||
context.history_trace = []
|
||||
|
||||
def handle_data(context, data):
|
||||
prices = history(bar_count={bar_count}, frequency='1d', field='price')
|
||||
context.history_trace.append(deepcopy(prices))
|
||||
""".format(bar_count=bar_count).strip()
|
||||
|
||||
# March 2006
|
||||
# Su Mo Tu We Th Fr Sa
|
||||
# 1 2 3 4
|
||||
# 5 6 7 8 9 10 11
|
||||
# 12 13 14 15 16 17 18
|
||||
# 19 20 21 22 23 24 25
|
||||
# 26 27 28 29 30 31
|
||||
|
||||
start = pd.Timestamp('2006-03-20', tz='UTC')
|
||||
end = pd.Timestamp('2006-03-30', tz='UTC')
|
||||
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=start, end=end)
|
||||
|
||||
_, df = factory.create_test_df_source(sim_params)
|
||||
df = df.astype(np.float64)
|
||||
source = DataFrameSource(df, sids=[0])
|
||||
|
||||
test_algo = TradingAlgorithm(
|
||||
script=algo_text,
|
||||
data_frequency='daily',
|
||||
sim_params=sim_params
|
||||
)
|
||||
|
||||
output = test_algo.run(source)
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
history_trace = test_algo.history_trace
|
||||
|
||||
for i, received in enumerate(history_trace[bar_count - 1:]):
|
||||
expected = df.iloc[i:i + bar_count]
|
||||
assert_frame_equal(expected, received)
|
||||
|
||||
def test_basic_history(self):
|
||||
algo_text = """
|
||||
from zipline.api import history, add_history
|
||||
|
||||
@@ -819,7 +819,10 @@ class TradingAlgorithm(object):
|
||||
@api_method
|
||||
def add_history(self, bar_count, frequency, field,
|
||||
ffill=True):
|
||||
history_spec = HistorySpec(bar_count, frequency, field, ffill)
|
||||
daily_at_midnight = (self.sim_params.data_frequency == 'daily')
|
||||
|
||||
history_spec = HistorySpec(bar_count, frequency, field, ffill,
|
||||
daily_at_midnight=daily_at_midnight)
|
||||
self.history_specs[history_spec.key_str] = history_spec
|
||||
|
||||
@api_method
|
||||
|
||||
+29
-14
@@ -16,6 +16,7 @@
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import re
|
||||
|
||||
from zipline.finance import trading
|
||||
@@ -41,7 +42,7 @@ class Frequency(object):
|
||||
SUPPORTED_FREQUENCIES = frozenset({'1d', '1m'})
|
||||
MAX_MINUTES = {'m': 1, 'd': 390}
|
||||
|
||||
def __init__(self, freq_str):
|
||||
def __init__(self, freq_str, daily_at_midnight=False):
|
||||
|
||||
if freq_str not in self.SUPPORTED_FREQUENCIES:
|
||||
raise ValueError(
|
||||
@@ -56,25 +57,31 @@ class Frequency(object):
|
||||
# unit_str - The unit type, e.g. 'd'
|
||||
self.num, self.unit_str = parse_freq_str(freq_str)
|
||||
|
||||
self.daily_at_midnight = daily_at_midnight
|
||||
|
||||
def next_window_start(self, previous_window_close):
|
||||
"""
|
||||
Get the first minute of the window starting after a window that
|
||||
finished on @previous_window_close.
|
||||
"""
|
||||
if self.unit_str == 'd':
|
||||
return self.next_day_window_start(previous_window_close)
|
||||
return self.next_day_window_start(previous_window_close,
|
||||
self.daily_at_midnight)
|
||||
elif self.unit_str == 'm':
|
||||
return self.next_minute_window_start(previous_window_close)
|
||||
|
||||
@staticmethod
|
||||
def next_day_window_start(previous_window_close):
|
||||
def next_day_window_start(previous_window_close, daily_at_midnight=False):
|
||||
"""
|
||||
Get the next day window start after @previous_window_close. This is
|
||||
defined as the first market open strictly greater than
|
||||
@previous_window_close.
|
||||
"""
|
||||
env = trading.environment
|
||||
next_open, _ = env.next_open_and_close(previous_window_close)
|
||||
if daily_at_midnight:
|
||||
next_open = env.next_trading_day(previous_window_close)
|
||||
else:
|
||||
next_open, _ = env.next_open_and_close(previous_window_close)
|
||||
return next_open
|
||||
|
||||
@staticmethod
|
||||
@@ -107,8 +114,7 @@ class Frequency(object):
|
||||
elif self.unit_str == 'm':
|
||||
return self.minute_window_close(window_start, self.num)
|
||||
|
||||
@staticmethod
|
||||
def day_window_open(window_close, num_days):
|
||||
def day_window_open(self, window_close, num_days):
|
||||
"""
|
||||
Get the first minute for a daily window of length @num_days with last
|
||||
minute @window_close. This is calculated by searching backward until
|
||||
@@ -120,10 +126,13 @@ class Frequency(object):
|
||||
1,
|
||||
offset=-(num_days - 1)
|
||||
).market_open.iloc[0]
|
||||
|
||||
if self.daily_at_midnight:
|
||||
open_ = pd.tslib.normalize_date(open_)
|
||||
|
||||
return open_
|
||||
|
||||
@staticmethod
|
||||
def minute_window_open(window_close, num_minutes):
|
||||
def minute_window_open(self, window_close, num_minutes):
|
||||
"""
|
||||
Get the first minute for a minutely window of length @num_minutes with
|
||||
last minute @window_close.
|
||||
@@ -138,8 +147,7 @@ class Frequency(object):
|
||||
env = trading.environment
|
||||
return env.market_minute_window(window_close, count=-num_minutes)[-1]
|
||||
|
||||
@staticmethod
|
||||
def day_window_close(window_start, num_days):
|
||||
def day_window_close(self, window_start, num_days):
|
||||
"""
|
||||
Get the last minute for a daily window of length @num_days with first
|
||||
minute @window_start. This is calculated by searching forward until
|
||||
@@ -174,10 +182,13 @@ class Frequency(object):
|
||||
1,
|
||||
offset=num_days - 1
|
||||
).market_close.iloc[0]
|
||||
|
||||
if self.daily_at_midnight:
|
||||
close = pd.tslib.normalize_date(close)
|
||||
|
||||
return close
|
||||
|
||||
@staticmethod
|
||||
def minute_window_close(window_start, num_minutes):
|
||||
def minute_window_close(self, window_start, num_minutes):
|
||||
"""
|
||||
Get the last minute for a minutely window of length @num_minutes with
|
||||
first minute @window_start.
|
||||
@@ -229,11 +240,12 @@ class HistorySpec(object):
|
||||
return "{0}:{1}:{2}:{3}".format(
|
||||
bar_count, freq_str, field, ffill)
|
||||
|
||||
def __init__(self, bar_count, frequency, field, ffill):
|
||||
def __init__(self, bar_count, frequency, field, ffill,
|
||||
daily_at_midnight=False):
|
||||
# Number of bars to look back.
|
||||
self.bar_count = bar_count
|
||||
if isinstance(frequency, str):
|
||||
frequency = Frequency(frequency)
|
||||
frequency = Frequency(frequency, daily_at_midnight)
|
||||
# The frequency at which the data is sampled.
|
||||
self.frequency = frequency
|
||||
# The field, e.g. 'price', 'volume', etc.
|
||||
@@ -272,6 +284,9 @@ def days_index_at_dt(history_spec, algo_dt):
|
||||
step=history_spec.frequency.num,
|
||||
).market_close
|
||||
|
||||
if history_spec.frequency.daily_at_midnight:
|
||||
market_closes = market_closes.apply(pd.tslib.normalize_date)
|
||||
|
||||
# Append the current algo_dt as the last index value.
|
||||
# Using the 'rawer' numpy array values here because of a bottleneck
|
||||
# that appeared when using DatetimeIndex
|
||||
|
||||
@@ -246,6 +246,13 @@ class HistoryContainer(object):
|
||||
)
|
||||
return rp
|
||||
|
||||
def convert_columns(self, values):
|
||||
"""
|
||||
If columns have a specific type you want to enforce, overwrite this
|
||||
method and return the transformed values.
|
||||
"""
|
||||
return values
|
||||
|
||||
def create_return_frames(self, algo_dt):
|
||||
"""
|
||||
Populates the return frame cache.
|
||||
@@ -257,7 +264,8 @@ class HistoryContainer(object):
|
||||
index = pd.to_datetime(index_at_dt(history_spec, algo_dt))
|
||||
frame = pd.DataFrame(
|
||||
index=index,
|
||||
columns=map(int, self.buffer_panel.minor_axis.values),
|
||||
columns=self.convert_columns(
|
||||
self.buffer_panel.minor_axis.values),
|
||||
dtype=np.float64)
|
||||
self.return_frames[spec_key] = frame
|
||||
|
||||
|
||||
Reference in New Issue
Block a user