# Mirror of https://github.com/wassname/catalyst.git
# (synced 2026-06-29 11:49:18 +08:00; 846 lines, 24 KiB, Python)
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from unittest import TestCase
|
|
|
|
from nose_parameterized import parameterized
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pandas.util.testing import assert_frame_equal
|
|
|
|
from zipline.history import history
|
|
from zipline.history.history_container import HistoryContainer
|
|
from zipline.protocol import BarData
|
|
import zipline.utils.factory as factory
|
|
from zipline import TradingAlgorithm
|
|
from zipline.finance.trading import SimulationParameters, TradingEnvironment
|
|
|
|
from zipline.sources import RandomWalkSource, DataFrameSource
|
|
|
|
from .history_cases import (
|
|
HISTORY_CONTAINER_TEST_CASES,
|
|
)
|
|
|
|
# Cases are over the July 4th holiday, to ensure use of trading calendar.
|
|
|
|
# March 2013
# Su Mo Tu We Th Fr Sa
#                 1  2
#  3  4  5  6  7  8  9
# 10 11 12 13 14 15 16
# 17 18 19 20 21 22 23
# 24 25 26 27 28 29 30
# 31
#
# April 2013
# Su Mo Tu We Th Fr Sa
#     1  2  3  4  5  6
#  7  8  9 10 11 12 13
# 14 15 16 17 18 19 20
# 21 22 23 24 25 26 27
# 28 29 30
#
# May 2013
# Su Mo Tu We Th Fr Sa
#           1  2  3  4
#  5  6  7  8  9 10 11
# 12 13 14 15 16 17 18
# 19 20 21 22 23 24 25
# 26 27 28 29 30 31
#
# June 2013
# Su Mo Tu We Th Fr Sa
#                    1
#  2  3  4  5  6  7  8
#  9 10 11 12 13 14 15
# 16 17 18 19 20 21 22
# 23 24 25 26 27 28 29
# 30
#
# July 2013
# Su Mo Tu We Th Fr Sa
#     1  2  3  4  5  6
#  7  8  9 10 11 12 13
# 14 15 16 17 18 19 20
# 21 22 23 24 25 26 27
# 28 29 30 31
#
|
|
# Times to be converted via:
#     pd.Timestamp('2013-07-05 9:31', tz='US/Eastern').tz_convert('UTC')
|
|
|
|
# Raw cases for history.index_at_dt, keyed by human-readable case name.
# Every datetime is a US/Eastern wall-clock string; convert_cases turns them
# into UTC Timestamps before the tests use them.
INDEX_TEST_CASES_RAW = {
    'week of daily data': dict(
        input=dict(bar_count=5, frequency='1d', algo_dt='2013-07-05 9:31AM'),
        expected=[
            '2013-06-28 4:00PM',
            '2013-07-01 4:00PM',
            '2013-07-02 4:00PM',
            '2013-07-03 1:00PM',
            '2013-07-05 9:31AM',
        ],
    ),
    'five minutes on july 5th open': dict(
        input=dict(bar_count=5, frequency='1m', algo_dt='2013-07-05 9:31AM'),
        expected=[
            '2013-07-03 12:57PM',
            '2013-07-03 12:58PM',
            '2013-07-03 12:59PM',
            '2013-07-03 1:00PM',
            '2013-07-05 9:31AM',
        ],
    ),
}
|
|
|
|
|
|
def to_timestamp(dt_str):
    """Interpret ``dt_str`` as US/Eastern wall-clock time and return it in UTC."""
    eastern_ts = pd.Timestamp(dt_str, tz='US/Eastern')
    return eastern_ts.tz_convert('UTC')
|
|
|
|
|
|
def convert_cases(cases):
    """
    Convert raw strings to values comparable with system data.

    Returns a new mapping with the same case names. In each converted case,
    ``input['algo_dt']`` is turned into a UTC Timestamp and ``expected`` into
    a DatetimeIndex of UTC Timestamps, both via ``to_timestamp``. Every other
    key is carried over unchanged.

    The argument (including its nested per-case dicts) is not modified.
    """
    converted = {}
    for name, case in cases.items():
        # Copy the nested dicts: a shallow ``cases.copy()`` would share them,
        # and the conversion would then overwrite the caller's raw cases.
        case_input = dict(case['input'])
        case_input['algo_dt'] = to_timestamp(case_input['algo_dt'])

        new_case = dict(case)
        new_case['input'] = case_input
        new_case['expected'] = pd.DatetimeIndex(
            [to_timestamp(dt_str) for dt_str in case['expected']])
        converted[name] = new_case
    return converted
|
|
|
|
# INDEX_TEST_CASES_RAW with its datetime strings converted by convert_cases;
# used to parameterize TestHistoryIndex.test_index_at_dt.
INDEX_TEST_CASES = convert_cases(INDEX_TEST_CASES_RAW)
|
|
|
|
|
|
def get_index_at_dt(case_input):
    """
    Return the index that ``history.index_at_dt`` builds for one test case.

    ``case_input`` must provide ``bar_count``, ``frequency`` and ``algo_dt``.
    The spec is built with no field, no forward-fill and
    ``daily_at_midnight=False``.
    """
    spec = history.HistorySpec(
        bar_count=case_input['bar_count'],
        frequency=case_input['frequency'],
        field=None,
        ffill=False,
        daily_at_midnight=False,
    )
    return history.index_at_dt(spec, case_input['algo_dt'])
|
|
|
|
|
|
class TestHistoryIndex(TestCase):
    """Compare history.index_at_dt output with hand-written expected indexes."""

    @classmethod
    def setUpClass(cls):
        # NOTE(review): presumably this initializes the shared trading
        # environment (and its calendar) that index_at_dt relies on.
        cls.environment = TradingEnvironment.instance()

    @parameterized.expand([
        (case_name, case_data['input'], case_data['expected'])
        for case_name, case_data in INDEX_TEST_CASES.items()
    ])
    def test_index_at_dt(self, name, case_input, expected):
        # Wrap both indexes in empty Series so that assert_series_equal
        # performs the index comparison.
        pd.util.testing.assert_series_equal(
            pd.Series(index=get_index_at_dt(case_input)),
            pd.Series(index=expected),
        )
|
|
|
|
|
|
class TestHistoryContainer(TestCase):
    """Unit tests for HistoryContainer updates, daily rolls and forward-fill."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment.instance()

    def bar_data_dt(self, bar_data, require_unique=True):
        """
        Get a dt to associate with the given BarData object.

        If require_unique == True, throw an error if multiple unique dt's are
        encountered. Otherwise, return the earliest dt encountered. Fails the
        test if ``bar_data`` contains no dts at all.
        """
        dts = {sid_data['dt'] for sid_data in bar_data.values()}
        if not dts:
            # Fail with a readable message rather than a bare IndexError.
            self.fail("No dts in {0}".format(bar_data))
        if require_unique and len(dts) > 1:
            self.fail("Multiple unique dts ({0}) in {1}".format(dts, bar_data))

        return min(dts)

    @parameterized.expand(
        [(name,
          case['specs'],
          case['sids'],
          case['dt'],
          case['updates'],
          case['expected'])
         for name, case in HISTORY_CONTAINER_TEST_CASES.items()]
    )
    def test_history_container(self,
                               name,
                               specs,
                               sids,
                               dt,
                               updates,
                               expected):

        for spec in specs:
            # Sanity check on test input.
            self.assertEqual(len(expected[spec.key_str]), len(updates))

        container = HistoryContainer(
            {spec.key_str: spec for spec in specs}, sids, dt
        )

        for update_count, update in enumerate(updates):

            bar_dt = self.bar_data_dt(update)
            container.update(update, bar_dt)

            # After every update, each spec's history must match the
            # expected frame for this step.
            for spec in specs:
                pd.util.testing.assert_frame_equal(
                    container.get_history(spec, bar_dt),
                    expected[spec.key_str][update_count],
                    check_dtype=False,
                    check_column_type=True,
                    check_index_type=True,
                    check_frame_type=True,
                )

    def test_container_nans_and_daily_roll(self):
        """
        Feed sparse bars for sid 1 across several days and check the NaN
        leading values, forward-filled closes and the rolling daily window.
        """

        spec = history.HistorySpec(
            bar_count=3,
            frequency='1d',
            field='price',
            ffill=True,
            daily_at_midnight=False
        )
        specs = {spec.key_str: spec}
        initial_sids = [1, ]
        initial_dt = to_timestamp('2013-06-28 9:31AM')

        container = HistoryContainer(
            specs, initial_sids, initial_dt)

        bar_data = BarData()
        container.update(bar_data, initial_dt)
        # Since there was no backfill because of no db.
        # And no first bar of data, so all values should be nans.
        prices = container.get_history(spec, initial_dt)
        nan_values = np.isnan(prices[1])
        self.assertTrue(all(nan_values), nan_values)

        # Add data on bar two of first day.
        second_bar_dt = to_timestamp('2013-06-28 9:32AM')

        bar_data[1] = {
            'price': 10,
            'dt': second_bar_dt
        }
        container.update(bar_data, second_bar_dt)

        prices = container.get_history(spec, second_bar_dt)
        # Prices should be
        #                               1
        # 2013-06-26 20:00:00+00:00   NaN
        # 2013-06-27 20:00:00+00:00   NaN
        # 2013-06-28 13:32:00+00:00    10

        self.assertTrue(np.isnan(prices[1].iloc[0]))
        self.assertTrue(np.isnan(prices[1].iloc[1]))
        self.assertEqual(prices[1].iloc[2], 10)

        third_bar_dt = to_timestamp('2013-06-28 9:33AM')

        del bar_data[1]

        container.update(bar_data, third_bar_dt)

        prices = container.get_history(spec, third_bar_dt)
        # The one should be forward filled

        # Prices should be
        #                               1
        # 2013-06-26 20:00:00+00:00   NaN
        # 2013-06-27 20:00:00+00:00   NaN
        # 2013-06-28 13:33:00+00:00    10

        self.assertEqual(prices[1][third_bar_dt], 10)

        # Note that we did not fill in data at the close.
        # There was a bug where a nan was being introduced because of the
        # last value of 'raw' data was used, instead of a ffilled close price.

        day_two_first_bar_dt = to_timestamp('2013-07-01 9:31AM')

        bar_data[1] = {
            'price': 20,
            'dt': day_two_first_bar_dt
        }

        container.update(bar_data, day_two_first_bar_dt)

        prices = container.get_history(spec, day_two_first_bar_dt)

        # Prices Should Be

        #                               1
        # 2013-06-27 20:00:00+00:00   nan
        # 2013-06-28 20:00:00+00:00    10
        # 2013-07-01 13:31:00+00:00    20

        self.assertTrue(np.isnan(prices[1].iloc[0]))
        self.assertEqual(prices[1].iloc[1], 10)
        self.assertEqual(prices[1].iloc[2], 20)

        # Clear out the bar data

        del bar_data[1]

        day_three_first_bar_dt = to_timestamp('2013-07-02 9:31AM')

        container.update(bar_data, day_three_first_bar_dt)

        prices = container.get_history(spec, day_three_first_bar_dt)

        #                               1
        # 2013-06-28 20:00:00+00:00    10
        # 2013-07-01 20:00:00+00:00    20
        # 2013-07-02 13:31:00+00:00    20

        # assertEqual, not assertTrue: assertTrue(value, expected) treats
        # ``expected`` as the failure message and only checks truthiness.
        self.assertEqual(prices[1].iloc[0], 10)
        self.assertEqual(prices[1].iloc[1], 20)
        self.assertEqual(prices[1].iloc[2], 20)

        day_four_first_bar_dt = to_timestamp('2013-07-03 9:31AM')

        container.update(bar_data, day_four_first_bar_dt)

        prices = container.get_history(spec, day_four_first_bar_dt)

        #                               1
        # 2013-07-01 20:00:00+00:00    20
        # 2013-07-02 20:00:00+00:00    20
        # 2013-07-03 13:31:00+00:00    20

        self.assertEqual(prices[1].iloc[0], 20)
        self.assertEqual(prices[1].iloc[1], 20)
        self.assertEqual(prices[1].iloc[2], 20)
|
|
|
|
|
|
class TestHistoryAlgo(TestCase):
    """End-to-end tests of ``history``/``add_history`` inside algorithms."""

    def setUp(self):
        # Several tests below assert exact values that depend on this seed.
        np.random.seed(123)

    @staticmethod
    def _minute_sim_params(start, end, emission_rate='minute'):
        """Build minute-data SimulationParameters with a 1e5 capital base."""
        return SimulationParameters(
            period_start=start,
            period_end=end,
            capital_base=float("1.0e5"),
            data_frequency='minute',
            emission_rate=emission_rate,
        )

    @staticmethod
    def _run_random_walk(algo_text, sim_params, start, end):
        """
        Run ``algo_text`` on minute data from a RandomWalkSource(start, end).

        Returns ``(test_algo, output)`` so callers can inspect both the
        attributes the script stored on ``context`` and the run's output.
        The algorithm is constructed before the source, as in the original
        inline code.
        """
        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='minute',
            sim_params=sim_params,
        )
        source = RandomWalkSource(start=start, end=end)
        return test_algo, test_algo.run(source)

    def test_history_daily(self):
        bar_count = 3
        algo_text = """
from zipline.api import history, add_history
from copy import deepcopy

def initialize(context):
    add_history(bar_count={bar_count}, frequency='1d', field='price')
    context.history_trace = []

def handle_data(context, data):
    prices = history(bar_count={bar_count}, frequency='1d', field='price')
    context.history_trace.append(deepcopy(prices))
""".format(bar_count=bar_count).strip()

        # March 2006
        # Su Mo Tu We Th Fr Sa
        #           1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-30', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        _, df = factory.create_test_df_source(sim_params)
        df = df.astype(np.float64)
        source = DataFrameSource(df, sids=[0])

        test_algo = TradingAlgorithm(
            script=algo_text,
            data_frequency='daily',
            sim_params=sim_params
        )

        output = test_algo.run(source)
        self.assertIsNotNone(output)

        history_trace = test_algo.history_trace

        # Without this the loop below could compare nothing and pass.
        self.assertGreaterEqual(len(history_trace), bar_count)

        # Once the window is full, each snapshot must equal the matching
        # bar_count-row slice of the source frame.
        for i, received in enumerate(history_trace[bar_count - 1:]):
            expected = df.iloc[i:i + bar_count]
            assert_frame_equal(expected, received)

    def test_basic_history(self):
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count=2, frequency='1d', field='price')

def handle_data(context, data):
    prices = history(bar_count=2, frequency='1d', field='price')
    prices['prices_times_two'] = prices[1] * 2
    context.last_prices = prices
""".strip()

        # March 2006
        # Su Mo Tu We Th Fr Sa
        #           1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31
        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo, output = self._run_random_walk(
            algo_text, sim_params, start, end)
        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        oldest_dt = to_timestamp('2006-03-20 4:00 PM')
        newest_dt = to_timestamp('2006-03-21 4:00 PM')

        self.assertEqual(oldest_dt, last_prices.index[0])
        self.assertEqual(newest_dt, last_prices.index[-1])

        # Random, depends on seed
        self.assertEqual(139.36946942498648, last_prices[oldest_dt])
        self.assertEqual(180.15661995395106, last_prices[newest_dt])

    def test_basic_history_one_day(self):
        algo_text = """
from zipline.api import history, add_history

def initialize(context):
    add_history(bar_count=1, frequency='1d', field='price')

def handle_data(context, data):
    prices = history(bar_count=1, frequency='1d', field='price')
    context.last_prices = prices
""".strip()

        # March 2006
        # Su Mo Tu We Th Fr Sa
        #           1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo, output = self._run_random_walk(
            algo_text, sim_params, start, end)

        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        # With a single bar, the oldest and newest entries are the same bar.
        last_close = to_timestamp('2006-03-21 4:00 PM')

        self.assertEqual(last_close, last_prices.index[0])
        self.assertEqual(last_close, last_prices.index[-1])

        # Random, depends on seed
        self.assertEqual(180.15661995395106, last_prices[last_close])

    def test_basic_history_positional_args(self):
        """
        Ensure that positional args work.
        """
        algo_text = """
import copy
from zipline.api import history, add_history

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):

    prices = history(2, '1d', 'price')
    context.last_prices = copy.deepcopy(prices)
""".strip()

        # March 2006
        # Su Mo Tu We Th Fr Sa
        #           1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        start = pd.Timestamp('2006-03-20', tz='UTC')
        end = pd.Timestamp('2006-03-21', tz='UTC')

        sim_params = factory.create_simulation_parameters(
            start=start, end=end)

        test_algo, output = self._run_random_walk(
            algo_text, sim_params, start, end)
        self.assertIsNotNone(output)

        last_prices = test_algo.last_prices[0]
        oldest_dt = to_timestamp('2006-03-20 4:00 PM')
        newest_dt = to_timestamp('2006-03-21 4:00 PM')

        self.assertEqual(oldest_dt, last_prices.index[0])
        self.assertEqual(newest_dt, last_prices.index[-1])

        self.assertEqual(139.36946942498648, last_prices[oldest_dt])
        self.assertEqual(180.15661995395106, last_prices[newest_dt])

    def test_history_with_volume(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'volume')

def handle_data(context, data):
    volume = history(3, '1d', 'volume')

    record(current_volume=volume[0].ix[-1])
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        _, output = self._run_random_walk(
            algo_text, self._minute_sim_params(start, end), start, end)

        np.testing.assert_equal(output['current_volume'].iloc[0],
                                212218404.0)

    def test_history_with_high(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'high')

def handle_data(context, data):
    highs = history(3, '1d', 'high')

    record(current_high=highs[0].ix[-1])
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        _, output = self._run_random_walk(
            algo_text, self._minute_sim_params(start, end), start, end)

        np.testing.assert_equal(output['current_high'].iloc[0],
                                139.5370641791925)

    def test_history_with_low(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'low')

def handle_data(context, data):
    lows = history(3, '1d', 'low')

    record(current_low=lows[0].ix[-1])
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        _, output = self._run_random_walk(
            algo_text, self._minute_sim_params(start, end), start, end)

        np.testing.assert_equal(output['current_low'].iloc[0],
                                99.891436939669944)

    def test_history_with_open(self):
        algo_text = """
from zipline.api import history, add_history, record

def initialize(context):
    add_history(3, '1d', 'open_price')

def handle_data(context, data):
    opens = history(3, '1d', 'open_price')

    record(current_open=opens[0].ix[-1])
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        _, output = self._run_random_walk(
            algo_text, self._minute_sim_params(start, end), start, end)

        np.testing.assert_equal(output['current_open'].iloc[0],
                                99.991436939669939)

    def test_history_passed_to_func(self):
        """
        Had an issue where MagicMock was causing errors during validation
        with rolling mean.
        """
        algo_text = """
from zipline.api import history, add_history
import pandas as pd

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):
    prices = history(2, '1d', 'price')

    pd.rolling_mean(prices, 2)
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        start = pd.Timestamp('2007-04-10', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        _, output = self._run_random_walk(
            algo_text, self._minute_sim_params(start, end), start, end)

        # At this point, just ensure that there is no crash.
        self.assertIsNotNone(output)

    def test_history_passed_to_talib(self):
        """
        Had an issue where MagicMock was causing errors during validation
        with talib.

        We don't officially support a talib integration, yet.
        But using talib directly should work.
        """
        algo_text = """
import talib
import numpy as np

from zipline.api import history, add_history, record

def initialize(context):
    add_history(2, '1d', 'price')

def handle_data(context, data):
    prices = history(2, '1d', 'price')

    ma_result = talib.MA(np.asarray(prices[0]), timeperiod=2)
    record(ma=ma_result[-1])
""".strip()

        # April 2007
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        # Eddie: this was set to 04-10 but I don't see how that makes
        # sense as it does not generate enough data to get at -2 index
        # below.
        start = pd.Timestamp('2007-04-05', tz='UTC')
        end = pd.Timestamp('2007-04-10', tz='UTC')

        sim_params = self._minute_sim_params(start, end, emission_rate='daily')

        _, output = self._run_random_walk(algo_text, sim_params, start, end)
        # At this point, just ensure that there is no crash.
        self.assertIsNotNone(output)

        recorded_ma = output['ma'].iloc[-2]

        self.assertFalse(pd.isnull(recorded_ma))
        # Depends on seed
        np.testing.assert_almost_equal(recorded_ma,
                                       159.76304468946876)
|