mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 03:33:01 +08:00
16fd6681a6
More documentation to follow in release notes. Based on the lazy-mainline branch; see that branch for more details. Also-By: Jean Bredeche <jean@quantopian.com> Also-By: Andrew Liang <aliang@quantopian.com> Also-By: Abhijeet Kalyan <akalyan@quantopian.com>
580 lines
20 KiB
Python
580 lines
20 KiB
Python
import warnings
|
|
from unittest import TestCase
|
|
from mock import patch
|
|
import pandas as pd
|
|
import numpy as np
|
|
from testfixtures import TempDirectory
|
|
|
|
from zipline import TradingAlgorithm
|
|
from zipline.data.data_portal import DataPortal
|
|
from zipline.data.minute_bars import BcolzMinuteBarWriter, \
|
|
US_EQUITIES_MINUTES_PER_DAY, BcolzMinuteBarReader
|
|
from zipline.data.us_equity_pricing import BcolzDailyBarReader, \
|
|
SQLiteAdjustmentReader, SQLiteAdjustmentWriter
|
|
from zipline.finance.trading import TradingEnvironment, SimulationParameters
|
|
from zipline.protocol import BarData
|
|
from zipline.testing.core import write_minute_data_for_asset, \
|
|
create_daily_df_for_asset, DailyBarWriterFromDataFrames, MockDailyBarReader
|
|
from zipline.testing import str_to_seconds
|
|
from zipline.zipline_warnings import ZiplineDeprecationWarning
|
|
|
|
# Legacy-API script that checks membership in, takes len() of and iterates
# over `data`. TestAPIShim.test_manipulation asserts that the deprecation
# warnings come from lines 7-10 of this script, so keep its layout stable.
simple_algo = """
from zipline.api import sid, order
def initialize(context):
    pass

def handle_data(context, data):
    assert sid(1) in data
    assert sid(2) in data
    assert len(data) == 3
    for asset in data:
        pass
"""
|
|
|
|
# Legacy-API script calling the deprecated top-level `history` function on
# every bar. TestAPIShim.test_history asserts the warning comes from line 8.
history_algo = """
from zipline.api import sid, history

def initialize(context):
    context.sid1 = sid(1)

def handle_data(context, data):
    context.history_window = history(5, "1m", "volume")
"""
|
|
|
|
# Legacy-API script that records a 5-bar minute volume history from the
# second call to before_trading_start.
history_bts_algo = """
from zipline.api import sid, history, record

def initialize(context):
    context.sid3 = sid(3)
    context.num_bts = 0

def before_trading_start(context, data):
    context.num_bts += 1

    # Get history at the second BTS (beginning of second day)
    if context.num_bts == 2:
        record(history=history(5, "1m", "volume"))

def handle_data(context, data):
    pass
"""
|
|
|
|
# Legacy-API script exercising the deprecated per-sid transforms on the third
# bar. TestAPIShim.test_simple_transforms asserts the warnings come from
# lines 8-11 of this script, so keep its layout stable.
simple_transforms_algo = """
from zipline.api import sid
def initialize(context):
    context.count = 0

def handle_data(context, data):
    if context.count == 2:
        context.mavg = data[sid(1)].mavg(5)
        context.vwap = data[sid(1)].vwap(5)
        context.stddev = data[sid(1)].stddev(5)
        context.returns = data[sid(1)].returns()

    context.count += 1
"""
|
|
|
|
# Legacy-API script checking len() and keys() of `data`. It imports `sid`
# explicitly like the other scripts in this module, since the script body
# calls sid(). NOTE(review): no test in TestAPIShim currently runs it.
manipulation_algo = """
from zipline.api import sid

def initialize(context):
    context.asset1 = sid(1)
    context.asset2 = sid(2)

def handle_data(context, data):
    assert len(data) == 2
    assert len(data.keys()) == 2
    assert context.asset1 in data.keys()
    assert context.asset2 in data.keys()
"""
|
|
|
|
# Legacy-API script reading the sid of a `data[sid(N)]` view both as an
# attribute and as an item.
sid_accessor_algo = """
from zipline.api import sid

def initialize(context):
    context.asset1 = sid(1)

def handle_data(context,data):
    assert data[sid(1)].sid == context.asset1
    assert data[sid(1)]["sid"] == context.asset1
"""
|
|
|
|
# Legacy-API script comparing data.iteritems() against data.items().
data_items_algo = """
from zipline.api import sid

def initialize(context):
    context.asset1 = sid(1)
    context.asset2 = sid(2)

def handle_data(context, data):
    iter_list = list(data.iteritems())
    items_list = data.items()
    assert iter_list == items_list
"""
|
|
|
|
|
|
class TestAPIShim(TestCase):
    """
    Backwards-compatibility tests for the legacy algorithm data API.

    The shared fixture holds three equities (sids 1, 2 and 3) with minute
    bars, daily bars and a single 0.5 split on sid 3 effective 2016-01-06.
    Each test runs one of the module-level algorithm scripts and checks the
    values it produces and/or the ZiplineDeprecationWarnings it triggers.
    """

    @classmethod
    def setUpClass(cls):
        """
        Build the shared trading environment, assets and on-disk bar data.

        Registers sids 1-3, writes their minute bars into a temporary
        directory, writes the adjustments database and builds the default
        minute-frequency SimulationParameters covering the fixture's trading
        days (2016-01-05 through 2016-01-28).
        """
        cls.env = TradingEnvironment()
        cls.tempdir = TempDirectory()

        try:
            cls.trading_days = cls.env.days_in_range(
                start=pd.Timestamp("2016-01-05", tz='UTC'),
                end=pd.Timestamp("2016-01-28", tz='UTC')
            )
            first_day = cls.trading_days[0]
            last_day = cls.trading_days[-1]
            sids = [1, 2, 3]

            end_date = cls.env.next_trading_day(last_day)
            equities_data = {
                sid: {
                    "start_date": first_day,
                    "end_date": end_date,
                    "symbol": "ASSET{0}".format(sid),
                }
                for sid in sids
            }
            cls.env.write_data(equities_data=equities_data)

            cls.asset1 = cls.env.asset_finder.retrieve_asset(1)
            cls.asset2 = cls.env.asset_finder.retrieve_asset(2)
            cls.asset3 = cls.env.asset_finder.retrieve_asset(3)

            market_opens = cls.env.open_and_closes.market_open.loc[
                cls.trading_days]
            market_closes = cls.env.open_and_closes.market_close.loc[
                cls.trading_days]

            minute_writer = BcolzMinuteBarWriter(
                first_day,
                cls.tempdir.path,
                market_opens,
                market_closes,
                US_EQUITIES_MINUTES_PER_DAY
            )
            for sid in sids:
                write_minute_data_for_asset(
                    cls.env, minute_writer, first_day, last_day, sid
                )

            # create_adjustments_reader reads cls.asset3, so it has to run
            # after the assets above have been retrieved.
            cls.adj_reader = cls.create_adjustments_reader()

            cls.sim_params = SimulationParameters(
                period_start=first_day,
                period_end=last_day,
                data_frequency="minute",
                env=cls.env
            )
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so clean
            # up the temporary directory here rather than leaking it.
            cls.tempdir.cleanup()
            raise

    @classmethod
    def build_daily_data(cls):
        """
        Write daily bars for sids 1-3 and return a reader over them.

        Each sid gets the frame built by create_daily_df_for_asset over the
        fixture's trading days. The table is written to the same
        ``testdaily.bcolz`` path in the temporary directory on every call.
        """
        path = cls.tempdir.getpath("testdaily.bcolz")

        start, end = cls.trading_days[0], cls.trading_days[-1]
        dfs = {
            sid: create_daily_df_for_asset(cls.env, start, end)
            for sid in [1, 2, 3]
        }

        daily_writer = DailyBarWriterFromDataFrames(dfs)
        daily_writer.write(path, cls.trading_days, dfs)

        return BcolzDailyBarReader(path)

    @classmethod
    def create_adjustments_reader(cls):
        """
        Write an adjustments database and return a reader over it.

        The only adjustment is a 0.5 split on ``cls.asset3`` effective
        2016-01-06. Empty, explicitly typed merger and dividend frames are
        written too, because the writer takes all three tables.
        """
        path = cls.tempdir.getpath("test_adjustments.db")

        adj_writer = SQLiteAdjustmentWriter(
            path,
            cls.env.trading_days,
            MockDailyBarReader()
        )

        splits = pd.DataFrame([
            {
                'effective_date': str_to_seconds("2016-01-06"),
                'ratio': 0.5,
                'sid': cls.asset3.sid
            }
        ])

        # Mergers and Dividends are not tested, but we need to have these
        # anyway
        mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid'])
        mergers.effective_date = mergers.effective_date.astype(int)
        mergers.ratio = mergers.ratio.astype(float)
        mergers.sid = mergers.sid.astype(int)

        dividends = pd.DataFrame({}, columns=['ex_date', 'record_date',
                                              'declared_date', 'pay_date',
                                              'amount', 'sid'])
        dividends.amount = dividends.amount.astype(float)
        dividends.sid = dividends.sid.astype(int)

        adj_writer.write(splits, mergers, dividends)

        return SQLiteAdjustmentReader(path)

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory holding the fixture data."""
        cls.tempdir.cleanup()

    def setUp(self):
        """
        Give each test a fresh DataPortal over the fixture data.

        The daily bar table is rewritten via build_daily_data on every call.
        """
        self.data_portal = DataPortal(
            self.env,
            equity_minute_reader=BcolzMinuteBarReader(self.tempdir.path),
            equity_daily_reader=self.build_daily_data(),
            adjustment_reader=self.adj_reader
        )

    @classmethod
    def create_algo(cls, code, filename=None, sim_params=None):
        """
        Build a TradingAlgorithm from the script source ``code``.

        ``filename`` is forwarded as ``algo_filename``; ``sim_params``
        defaults to the class-wide minute-frequency parameters.
        """
        if sim_params is None:
            sim_params = cls.sim_params

        return TradingAlgorithm(
            script=code,
            sim_params=sim_params,
            env=cls.env,
            algo_filename=filename
        )

    def test_old_new_data_api_paths(self):
        """
        Test that the new and old data APIs hit the same code paths.

        We want to ensure that the old data API (data[sid(N)].field and
        similar) and the new data API (data.current(sid(N), field) and
        similar) hit the same code paths on the DataPortal.
        """
        first_day_minutes = self.env.market_minutes_for_day(
            self.trading_days[0]
        )
        test_start_minute = first_day_minutes[1]
        test_end_minute = first_day_minutes[-1]
        bar_data = BarData(
            self.data_portal,
            lambda: test_end_minute, "minute"
        )
        # One field per line, each with a trailing comma: a missing comma
        # would silently concatenate neighbours (e.g. "low" "close" ->
        # "lowclose") and leave both real fields untested.
        ohlcvp_fields = [
            "open",
            "high",
            "low",
            "close",
            "volume",
            "price",
        ]
        spot_value_meth = 'zipline.data.data_portal.DataPortal.get_spot_value'

        def assert_get_spot_value_called(fun, field):
            """
            Assert that get_spot_value was called during the execution of fun.

            Takes in a function fun and a string field.
            """
            with patch(spot_value_meth) as gsv:
                fun()
                gsv.assert_called_with(
                    self.asset1,
                    field,
                    test_end_minute,
                    'minute'
                )

        # Ensure that data.current(sid(n), field) has the same behaviour as
        # data[sid(n)].field.
        for field in ohlcvp_fields:
            assert_get_spot_value_called(
                lambda: getattr(bar_data[self.asset1], field),
                field,
            )
            assert_get_spot_value_called(
                lambda: bar_data.current(self.asset1, field),
                field,
            )

        history_meth = 'zipline.data.data_portal.DataPortal.get_history_window'
        all_assets = [self.asset1, self.asset2, self.asset3]

        def assert_get_history_window_called(fun, is_legacy):
            """
            Assert that get_history_window was called during fun().

            Takes in a function fun and a boolean is_legacy.
            """
            expected_args = [all_assets, test_end_minute, 5, "1m", "volume"]
            # Slightly hacky, but done to get around the fact that
            # history() explicitly passes an ffill param as the last arg,
            # while data.history doesn't.
            if is_legacy:
                expected_args.append(True)

            with patch(history_meth) as ghw:
                fun()
                ghw.assert_called_with(*expected_args)

        test_sim_params = SimulationParameters(
            period_start=test_start_minute,
            period_end=test_end_minute,
            data_frequency="minute",
            env=self.env
        )

        history_algorithm = self.create_algo(
            history_algo,
            sim_params=test_sim_params
        )
        assert_get_history_window_called(
            lambda: history_algorithm.run(self.data_portal),
            is_legacy=True
        )
        assert_get_history_window_called(
            lambda: bar_data.history(all_assets, "volume", 5, "1m"),
            is_legacy=False
        )

    def test_sid_accessor(self):
        """
        Test that we maintain backwards compat for sid access on a data object.

        We want to support both data[sid(24)].sid, as well as
        data[sid(24)]["sid"]. Since these are deprecated and will eventually
        cease to be supported, we also want to assert that we're seeing a
        deprecation warning.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)
            algo = self.create_algo(sid_accessor_algo)
            algo.run(self.data_portal)

            # Since we're already raising a warning on doing data[sid(x)],
            # we don't want to raise an extra warning on data[sid(x)].sid.
            self.assertEqual(2, len(w))

            # Check that both the warnings raised were in fact
            # ZiplineDeprecationWarnings
            for warning in w:
                self.assertEqual(
                    ZiplineDeprecationWarning,
                    warning.category
                )
                self.assertEqual(
                    "`data[sid(N)]` is deprecated. Use `data.current`.",
                    str(warning.message)
                )

    def test_data_items(self):
        """
        Test that we maintain backwards compat for data.[items | iteritems].

        We also want to assert that we warn that iterating over the assets
        in `data` is deprecated.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)
            algo = self.create_algo(data_items_algo)
            algo.run(self.data_portal)

            self.assertEqual(4, len(w))

            # The warnings are expected to alternate: an iteration warning
            # followed by a data[sid(N)] warning.
            for idx, warning in enumerate(w):
                self.assertEqual(
                    ZiplineDeprecationWarning,
                    warning.category
                )
                if idx % 2 == 0:
                    self.assertEqual(
                        "Iterating over the assets in `data` is deprecated.",
                        str(warning.message)
                    )
                else:
                    self.assertEqual(
                        "`data[sid(N)]` is deprecated. Use `data.current`.",
                        str(warning.message)
                    )

    def test_iterate_data(self):
        """
        Test the deprecation warnings raised while running simple_algo.

        simple_algo checks membership in `data`, takes its len() and
        iterates over it; each warning must point at the script line that
        triggered it.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)

            algo = self.create_algo(simple_algo)
            algo.run(self.data_portal)

            self.assertEqual(4, len(w))

            line_nos = [warning.lineno for warning in w]
            self.assertEqual(4, len(set(line_nos)))

            for idx, warning in enumerate(w):
                self.assertEqual(ZiplineDeprecationWarning,
                                 warning.category)

                self.assertEqual("<string>", warning.filename)
                # Compare against the concrete script lines (7-10 of
                # simple_algo) rather than against line_nos, which is built
                # from these same warnings and so could never disagree.
                self.assertEqual(7 + idx, warning.lineno)

                if idx < 2:
                    self.assertEqual(
                        "Checking whether an asset is in data is deprecated.",
                        str(warning.message)
                    )
                else:
                    self.assertEqual(
                        "Iterating over the assets in `data` is deprecated.",
                        str(warning.message)
                    )

    def test_history(self):
        """
        Test that the top-level `history` function emits exactly one
        deprecation warning, pointing at line 8 of history_algo.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)

            sim_params = SimulationParameters(
                period_start=self.trading_days[1],
                period_end=self.sim_params.period_end,
                capital_base=self.sim_params.capital_base,
                data_frequency=self.sim_params.data_frequency,
                emission_rate=self.sim_params.emission_rate,
                env=self.env,
            )

            algo = self.create_algo(history_algo,
                                    sim_params=sim_params)
            algo.run(self.data_portal)

            self.assertEqual(1, len(w))
            self.assertEqual(ZiplineDeprecationWarning, w[0].category)
            self.assertEqual("<string>", w[0].filename)
            self.assertEqual(8, w[0].lineno)
            self.assertEqual("The `history` method is deprecated. Use "
                             "`data.history` instead.", str(w[0].message))

    def test_old_new_history_bts_paths(self):
        """
        Tests that calling history in before_trading_start gets us the correct
        values, which involves 1) calling data_portal.get_history_window as of
        the previous market minute, 2) getting adjustments between the previous
        market minute and the current time, and 3) applying those adjustments.
        """
        algo = self.create_algo(history_bts_algo)
        algo.run(self.data_portal)

        expected_vol_without_split = np.arange(386, 391) * 100
        expected_vol_with_split = np.arange(386, 391) * 200

        window = algo.recorded_vars['history']
        np.testing.assert_array_equal(window[self.asset1].values,
                                      expected_vol_without_split)
        np.testing.assert_array_equal(window[self.asset2].values,
                                      expected_vol_without_split)
        np.testing.assert_array_equal(window[self.asset3].values,
                                      expected_vol_with_split)

    def test_simple_transforms(self):
        """
        Test the deprecated mavg/vwap/stddev/returns transforms: each call
        emits two warnings from its own script line, and the values match
        hand-computed results over the fixture data.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)

            sim_params = SimulationParameters(
                period_start=self.trading_days[8],
                period_end=self.trading_days[-1],
                data_frequency="minute",
                env=self.env
            )

            algo = self.create_algo(simple_transforms_algo,
                                    sim_params=sim_params)
            algo.run(self.data_portal)

            self.assertEqual(8, len(w))
            transforms = ["mavg", "vwap", "stddev", "returns"]

            for idx, line_no in enumerate(range(8, 12)):
                warning1 = w[idx * 2]
                warning2 = w[(idx * 2) + 1]

                self.assertEqual("<string>", warning1.filename)
                self.assertEqual("<string>", warning2.filename)

                self.assertEqual(line_no, warning1.lineno)
                self.assertEqual(line_no, warning2.lineno)

                self.assertEqual("`data[sid(N)]` is deprecated. Use "
                                 "`data.current`.",
                                 str(warning1.message))
                self.assertEqual("The `{0}` method is "
                                 "deprecated.".format(transforms[idx]),
                                 str(warning2.message))

        # now verify the transform values
        # minute price
        # 2016-01-11 14:31:00+00:00    1561
        # ...
        # 2016-01-14 20:59:00+00:00    3119
        # 2016-01-14 21:00:00+00:00    3120
        # 2016-01-15 14:31:00+00:00    3121
        # 2016-01-15 14:32:00+00:00    3122
        # 2016-01-15 14:33:00+00:00    3123

        # volume
        # 2016-01-11 14:31:00+00:00    156100
        # ...
        # 2016-01-14 20:59:00+00:00    311900
        # 2016-01-14 21:00:00+00:00    312000
        # 2016-01-15 14:31:00+00:00    312100
        # 2016-01-15 14:32:00+00:00    312200
        # 2016-01-15 14:33:00+00:00    312300

        # daily price (last day built with minute data)
        # 2016-01-14 00:00:00+00:00       9
        # 2016-01-15 00:00:00+00:00    3123

        # mavg = average of all the prices = (1561 + 3123) / 2 = 2342
        # vwap = sum(price * volume) / sum(volumes)
        #      = 889119531400.0 / 366054600.0
        #      = 2428.9259891830343
        # stddev = stddev(price, ddof=1) = 451.3435498597493
        # returns = (todayprice - yesterdayprice) / yesterdayprice
        #         = (3123 - 9) / 9 = 346
        self.assertEqual(2342, algo.mavg)
        self.assertAlmostEqual(2428.92599, algo.vwap, places=5)
        self.assertAlmostEqual(451.34355, algo.stddev, places=5)
        self.assertAlmostEqual(346, algo.returns)

    def test_manipulation(self):
        """
        Test that the deprecation warnings raised by simple_algo come from
        lines 7-10 of the script, in order, with the expected messages.

        NOTE(review): this runs simple_algo, not manipulation_algo; confirm
        whether that is intentional.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("default", ZiplineDeprecationWarning)

            algo = self.create_algo(simple_algo)
            algo.run(self.data_portal)

            self.assertEqual(4, len(w))

            for idx, warning in enumerate(w):
                self.assertEqual("<string>", warning.filename)
                self.assertEqual(7 + idx, warning.lineno)

                if idx < 2:
                    self.assertEqual("Checking whether an asset is in data is "
                                     "deprecated.",
                                     str(warning.message))
                else:
                    self.assertEqual("Iterating over the assets in `data` is "
                                     "deprecated.",
                                     str(warning.message))