Merge pull request #1105 from quantopian/q2

ENH: Rewrite of Zipline to use lazy access pattern
This commit is contained in:
Jean Bredeche
2016-04-04 16:13:14 -04:00
123 changed files with 15208 additions and 325116 deletions
+1
View File
@@ -70,3 +70,4 @@ after_success:
branches:
only:
- master
- lazy-mainline
+8
View File
@@ -0,0 +1,8 @@
"%PYTHON%" setup.py install
if errorlevel 1 exit 1
:: Add more build steps here, if they are necessary.
:: See
:: http://docs.continuum.io/conda/build.html
:: for a list of environment variables that are set during the build process.
+9
View File
@@ -0,0 +1,9 @@
#!/bin/bash
$PYTHON setup.py install
# Add more build steps here, if they are necessary.
# See
# http://docs.continuum.io/conda/build.html
# for a list of environment variables that are set during the build process.
+62
View File
@@ -0,0 +1,62 @@
package:
name: sortedcontainers
version: "1.4.4"
source:
fn: sortedcontainers-1.4.4.tar.gz
url: https://pypi.python.org/packages/source/s/sortedcontainers/sortedcontainers-1.4.4.tar.gz
md5: 456638aea9f8f705bd2e3c891c402023
# patches:
# List any patch files here
# - fix.patch
# build:
# noarch_python: True
# preserve_egg_dir: True
# entry_points:
# Put any entry points (scripts to be generated automatically) here. The
# syntax is module:function. For example
#
# - sortedcontainers = sortedcontainers:main
#
# Would create an entry point called sortedcontainers that calls sortedcontainers.main()
# If this is a new build for the same version, increment the build
# number. If you do not include this key, it defaults to 0.
# number: 1
requirements:
build:
- python
- setuptools
run:
- python
test:
# Python imports
imports:
- sortedcontainers
# commands:
# You can put test commands to be run here. Use this to test that the
# entry points work.
# You can also put a file called run_test.py in the recipe that will be run
# at test time.
requires:
# - tox
# Put any additional test requirements here. For example
# - nose
about:
home: http://www.grantjenks.com/docs/sortedcontainers/
license: Apache Software License
summary: 'Python Sorted Container Types: SortedList, SortedDict, and SortedSet'
# See
# http://docs.continuum.io/conda/build.html for
# more information about meta.yaml
+8
View File
@@ -0,0 +1,8 @@
"%PYTHON%" setup.py install
if errorlevel 1 exit 1
:: Add more build steps here, if they are necessary.
:: See
:: http://docs.continuum.io/conda/build.html
:: for a list of environment variables that are set during the build process.
+9
View File
@@ -0,0 +1,9 @@
#!/bin/bash
$PYTHON setup.py install
# Add more build steps here, if they are necessary.
# See
# http://docs.continuum.io/conda/build.html
# for a list of environment variables that are set during the build process.
+64
View File
@@ -0,0 +1,64 @@
package:
name: intervaltree
version: "2.1.0"
source:
fn: intervaltree-2.1.0.tar.gz
url: https://pypi.python.org/packages/source/i/intervaltree/intervaltree-2.1.0.tar.gz
md5: 33bef3448aaf30b78aa093dc7c315c2c
# patches:
# List any patch files here
# - fix.patch
# build:
# noarch_python: True
# preserve_egg_dir: True
# entry_points:
# Put any entry points (scripts to be generated automatically) here. The
# syntax is module:function. For example
#
# - intervaltree = intervaltree:main
#
# Would create an entry point called intervaltree that calls intervaltree.main()
# If this is a new build for the same version, increment the build
# number. If you do not include this key, it defaults to 0.
# number: 1
requirements:
build:
- python
- setuptools
- sortedcontainers
run:
- python
- sortedcontainers
test:
# Python imports
imports:
- intervaltree
# commands:
# You can put test commands to be run here. Use this to test that the
# entry points work.
# You can also put a file called run_test.py in the recipe that will be run
# at test time.
requires:
- pytest
# Put any additional test requirements here. For example
# - nose
about:
home: https://github.com/chaimleib/intervaltree
license: Apache Software License
summary: 'Editable interval tree data structure for Python 2 and 3'
# See
# http://docs.continuum.io/conda/build.html for
# more information about meta.yaml
+4
View File
@@ -54,3 +54,7 @@ toolz==0.7.4
# Asset writer and finder
sqlalchemy==1.0.8
# for intervaltree
sortedcontainers==1.4.4
intervaltree==2.1.0
+3
View File
@@ -59,5 +59,8 @@ futures==3.0.5
requests-futures==0.9.7
piprot==0.9.6
# For mocking out requests fetches
responses==0.4.0
# For asset db management
alembic==0.7.7
+2
View File
@@ -0,0 +1,2 @@
# Caching and other utilities
functools32==3.2.3.post2
-143
View File
@@ -1,143 +0,0 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logbook
import os
import pickle
import pytz
import sys
sys.path.insert(0, '.') # noqa
from zipline.finance.blotter import Blotter, Order
from zipline.finance.commission import PerShare, PerTrade, PerDollar
from zipline.finance.performance.period import PerformancePeriod
from zipline.finance.performance.position import Position
from zipline.finance.performance.position_tracker import PositionTracker
from zipline.finance.performance.tracker import PerformanceTracker
from zipline.finance.risk.cumulative import RiskMetricsCumulative
from zipline.finance.risk.period import RiskMetricsPeriod
from zipline.finance.risk.report import RiskReport
from zipline.finance.slippage import (
FixedSlippage,
Transaction,
VolumeShareSlippage
)
from zipline.protocol import Account
from zipline.protocol import Portfolio
from zipline.protocol import Position as ProtocolPosition
from zipline.finance.trading import SimulationParameters
from zipline.utils import factory
from zipline.utils.serialization_utils import VERSION_LABEL
# Root directory of the saved-state archive; write_state_to_disk creates one
# sub-directory per class underneath it.
base_state_dir = 'tests/resources/saved_state_archive'
if not os.path.exists(base_state_dir):
os.makedirs(base_state_dir)
# Simulation parameters covering the single day 2013-06-19 with a capital
# base of 10000, one instance per emission rate.
sim_params_daily = SimulationParameters(
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
10000,
emission_rate='daily')
sim_params_minute = SimulationParameters(
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
10000,
emission_rate='minute')
# One-element returns series used to build the risk objects below.
returns = factory.create_returns_from_list(
[1.0], sim_params_daily)
# (class, constructor-args) pairs. Each pair is unpacked into
# generate_object_state(cls, initargs) by the script entry point.
argument_list = [
(Blotter, ()),
(Order, (datetime.datetime(2013, 6, 19), 8554, 100)),
(PerShare, ()),
(PerTrade, ()),
(PerDollar, ()),
(PerformancePeriod, (10000,)),
(Position, (8554,)),
(PositionTracker, ()),
(PerformanceTracker, (sim_params_minute,)),
(RiskMetricsCumulative, (sim_params_minute,)),
(RiskMetricsPeriod, (returns.index[0], returns.index[0], returns)),
(RiskReport, (returns, sim_params_minute)),
(FixedSlippage, ()),
(Transaction, (8554, 10, datetime.datetime(2013, 6, 19), 100, "0000")),
(VolumeShareSlippage, ()),
(Account, ()),
(Portfolio, ()),
(ProtocolPosition, (8554,))
]
def write_state_to_disk(cls, state, emission_rate=None):
    """
    Pickle ``state`` into the saved-state archive directory for ``cls``.

    The file is written to
    ``<base_state_dir>/<module>.<ClassName>/State_Version_<rate><version>``
    where ``<version>`` is ``state['obj_state'][VERSION_LABEL]`` and
    ``<rate>`` is ``emission_rate`` (omitted when it is None).

    Parameters
    ----------
    cls : type
        Class whose state is being archived. Its module and name form the
        name of the sub-directory.
    state : dict
        Mapping to pickle. It must contain an ``'obj_state'`` entry that is
        itself indexable by ``VERSION_LABEL``.
    emission_rate : str, optional
        When given, it is inserted between the ``State_Version_`` prefix
        and the version number.

    Raises
    ------
    KeyError
        If ``state`` has no ``'obj_state'`` entry or that entry has no
        ``VERSION_LABEL`` key.
    """
    full_dir = os.path.join(base_state_dir,
                            cls.__module__ + '.' + cls.__name__)
    if not os.path.exists(full_dir):
        os.makedirs(full_dir)

    version = str(state['obj_state'][VERSION_LABEL])
    if emission_rate is not None:
        name = 'State_Version_' + emission_rate + version
    else:
        name = 'State_Version_' + version

    # Pickle data is bytes: the file must be opened in binary mode (text
    # mode raises on Python 3 and can mangle newlines on Windows).  The
    # context manager closes the handle even if pickle.dump raises.
    with open(os.path.join(full_dir, name), 'wb') as f:
        pickle.dump(state, f)
def generate_object_state(cls, initargs):
    """
    Instantiate ``cls`` and archive its pickling state on disk.

    The object is built as ``cls(*initargs)``.  The archived record holds
    the result of ``__getstate__()`` plus, when the object defines them,
    the results of ``__getinitargs__()`` and ``__getnewargs__()`` (``None``
    otherwise).  The record is handed to ``write_state_to_disk``.

    Parameters
    ----------
    cls : type
        Class to instantiate and archive.
    initargs : tuple
        Positional arguments passed to the constructor.
    """
    instance = cls(*initargs)

    # Entries are evaluated top to bottom, i.e. __getstate__ is called
    # before the optional pickling hooks.
    record = {
        'obj_state': instance.__getstate__(),
        'initargs': (instance.__getinitargs__()
                     if hasattr(instance, '__getinitargs__') else None),
        'newargs': (instance.__getnewargs__()
                    if hasattr(instance, '__getnewargs__') else None),
    }

    write_state_to_disk(cls, record)
# Script entry point: send logbook output to stderr, then archive the state
# of every (class, args) pair listed in argument_list.
if __name__ == "__main__":
logbook.StderrHandler().push_application()
for args in argument_list:
generate_object_state(*args)
+32 -10
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from functools import partial
import os
import re
import sys
@@ -89,6 +90,12 @@ ext_modules = [
Extension('zipline.lib.rank', ['zipline/lib/rank.pyx']),
Extension('zipline.data._equities', ['zipline/data/_equities.pyx']),
Extension('zipline.data._adjustments', ['zipline/data/_adjustments.pyx']),
Extension('zipline._protocol', ['zipline/_protocol.pyx']),
Extension('zipline.gens.sim_engine', ['zipline/gens/sim_engine.pyx']),
Extension(
'zipline.data._minute_bar_internal',
['zipline/data/_minute_bar_internal.pyx']
)
]
@@ -156,19 +163,27 @@ def _with_bounds(req):
REQ_PATTERN = re.compile("([^=<>]+)([<=>]{1,2})(.*)")
def _conda_format(req):
def _conda_format(req, selector=None):
match = REQ_PATTERN.match(req)
if match and match.group(1).lower() == 'numpy':
return 'numpy x.x'
line = 'numpy x.x'
else:
line = REQ_PATTERN.sub(
lambda m: '%s %s%s' % (m.group(1).lower(), m.group(2), m.group(3)),
req,
1,
)
return REQ_PATTERN.sub(
lambda m: '%s %s%s' % (m.group(1).lower(), m.group(2), m.group(3)),
req,
1,
)
if selector is not None:
line += ' # [%s]' % selector
return line
def read_requirements(path, strict_bounds, conda_format=False):
def read_requirements(path,
strict_bounds,
conda_format=False,
conda_selector=None):
"""
Read a requirements.txt file, expressed as a path relative to Zipline root.
@@ -183,15 +198,22 @@ def read_requirements(path, strict_bounds, conda_format=False):
reqs = map(_with_bounds, reqs)
if conda_format:
reqs = map(_conda_format, reqs)
reqs = map(partial(_conda_format, selector=conda_selector), reqs)
return list(reqs)
def install_requires(strict_bounds=False, conda_format=False):
return read_requirements('etc/requirements.txt',
reqs = read_requirements('etc/requirements.txt',
strict_bounds=strict_bounds,
conda_format=conda_format)
if sys.version_info.major == 2 or conda_format:
reqs += read_requirements('etc/requirements_py2.txt',
strict_bounds=strict_bounds,
conda_format=conda_format,
conda_selector='py2k')
return reqs
def extras_requires(conda_format=False):
+147 -1
View File
@@ -47,7 +47,7 @@ from zipline.finance.trading import TradingEnvironment
TEST_CALENDAR_START = Timestamp('2015-06-01', tz='UTC')
TEST_CALENDAR_STOP = Timestamp('2015-06-30', tz='UTC')
TEST_CALENDAR_STOP = Timestamp('2015-12-31', tz='UTC')
class BcolzMinuteBarTestCase(TestCase):
@@ -637,3 +637,149 @@ class BcolzMinuteBarTestCase(TestCase):
for i, col in enumerate(columns):
for j, sid in enumerate(sids):
assert_almost_equal(data[sid][col], arrays[i][j])
def test_unadjusted_minutes_early_close(self):
    """
    Test unadjusted minute window, ensuring that early closes are filtered
    out.

    Three bars are written for each of two sids: two minutes before the
    close on 2015-11-25, one minute before the close on 2015-12-24 and one
    minute after the open on 2015-12-28.  The window read back between the
    first and last of those minutes must hold the written values at the
    matching market-minute offsets.
    """
    one_minute = Timedelta('1 min')
    closes = self.market_closes
    opens = self.market_opens

    minutes = [
        closes[Timestamp('2015-11-25', tz='UTC')] - Timedelta('2 min'),
        closes[Timestamp('2015-12-24', tz='UTC')] - one_minute,
        opens[Timestamp('2015-12-28', tz='UTC')] + one_minute,
    ]

    sids = [1, 2]
    frames = {
        1: DataFrame(
            data={
                'open': [15.0, 15.1, 15.2],
                'high': [17.0, 17.1, 17.2],
                'low': [11.0, 11.1, 11.3],
                'close': [14.0, 14.1, 14.2],
                'volume': [1000, 1001, 1002],
            },
            index=minutes),
        2: DataFrame(
            data={
                'open': [25.0, 25.1, 25.2],
                'high': [27.0, 27.1, 27.2],
                'low': [21.0, 21.1, 21.2],
                'close': [24.0, 24.1, 24.2],
                'volume': [2000, 2001, 2002],
            },
            index=minutes),
    }
    for sid in sids:
        self.writer.write(sid, frames[sid])

    columns = ['open', 'high', 'low', 'close', 'volume']
    reader = BcolzMinuteBarReader(self.dest)
    arrays = reader.unadjusted_window(
        columns, minutes[0], minutes[-1], sids)

    # Positions of the written minutes relative to the start of the
    # window, measured in market minutes.
    market_minutes = self.env.market_minutes
    window_start = market_minutes.get_loc(minutes[0])
    offsets = [market_minutes.get_loc(minute) - window_start
               for minute in minutes]

    for col_idx, col in enumerate(columns):
        for sid_idx, sid in enumerate(sids):
            assert_almost_equal(frames[sid].loc[minutes, col],
                                arrays[col_idx][sid_idx][offsets])
def test_adjust_non_trading_minutes(self):
    """
    Read values at timestamps that fall on and outside of market minutes.

    Every column holds 1..780 over the minutes of 2015-06-01 and
    2015-06-02.  The expectations below pin the value returned for the
    20:00 minute of each day, for midnight of the second day, and for
    20:01 on the second day.
    """
    sid = 1
    first_session = Timestamp('2015-06-01', tz='UTC')
    last_session = Timestamp('2015-06-02', tz='UTC')

    dts = array(
        self.env.minutes_for_days_in_range(first_session, last_session))
    cols = {
        field: arange(1, 781)
        for field in ('open', 'high', 'low', 'close', 'volume')
    }
    self.writer.write_cols(sid, dts, cols)

    # (timestamp queried, expected 'open' value)
    expectations = [
        ('2015-06-01 20:00:00', 390),
        ('2015-06-02 20:00:00', 780),
        ('2015-06-02', 390),
        ('2015-06-02 20:01:00', 780),
    ]
    for when, expected_open in expectations:
        self.assertEqual(
            self.reader.get_value(sid, Timestamp(when, tz='UTC'), 'open'),
            expected_open)
def test_adjust_non_trading_minutes_half_days(self):
    """
    Same check as the full-day variant, over a range starting on a half day.

    Every column holds 1..600 over the minutes of 2015-11-27 through
    2015-11-30.  The expectations below pin the value returned at 18:00
    and 18:01 on the first day, at midnight of the last day, and at 21:00
    and 21:01 on the last day.
    """
    sid = 1
    # half day
    first_session = Timestamp('2015-11-27', tz='UTC')
    last_session = Timestamp('2015-11-30', tz='UTC')

    dts = array(
        self.env.minutes_for_days_in_range(first_session, last_session))
    cols = {
        field: arange(1, 601)
        for field in ('open', 'high', 'low', 'close', 'volume')
    }
    self.writer.write_cols(sid, dts, cols)

    # (timestamp queried, expected 'open' value)
    expectations = [
        ('2015-11-27 18:00:00', 210),
        ('2015-11-30 21:00:00', 600),
        ('2015-11-27 18:01:00', 210),
        ('2015-11-30', 210),
        ('2015-11-30 21:01:00', 600),
    ]
    for when, expected_open in expectations:
        self.assertEqual(
            self.reader.get_value(sid, Timestamp(when, tz='UTC'), 'open'),
            expected_open)
+34
View File
@@ -0,0 +1,34 @@
#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from zipline.finance.cancel_policy import NeverCancel, EODCancel
from zipline.gens.sim_engine import (
BAR,
DAY_END
)
class CancelPolicyTestCase(TestCase):
    """Checks ``should_cancel`` of the cancel policies for each event."""

    def _check_policy(self, policy, expectations):
        # expectations: iterable of (event, expected truthiness) pairs.
        for event, expected in expectations:
            self.assertEqual(bool(policy.should_cancel(event)), expected)

    def test_eod_cancel(self):
        """``EODCancel`` cancels on ``DAY_END`` but not on ``BAR``."""
        self._check_policy(EODCancel(), [(DAY_END, True), (BAR, False)])

    def test_never_cancel(self):
        """``NeverCancel`` cancels on neither ``DAY_END`` nor ``BAR``."""
        self._check_policy(NeverCancel(), [(DAY_END, False), (BAR, False)])
+375 -218
View File
@@ -24,132 +24,264 @@ from unittest import TestCase
from nose_parameterized import parameterized
import numpy as np
import pandas as pd
from pandas.tslib import normalize_date
from testfixtures import TempDirectory
from zipline.finance.slippage import VolumeShareSlippage
from zipline.finance.trading import TradingEnvironment, SimulationParameters
from zipline.protocol import Event, DATASOURCE_TYPE
from zipline.protocol import DATASOURCE_TYPE
from zipline.finance.blotter import Order
from zipline.data.minute_bars import BcolzMinuteBarReader
from zipline.data.data_portal import DataPortal
from zipline.protocol import BarData
from zipline.testing.core import write_bcolz_minute_data
class SlippageTestCase(TestCase):
def test_volume_share_slippage(self):
event = Event(
{'volume': 200,
'type': 4,
'price': 3.0,
'datetime': datetime.datetime(
2006, 1, 5, 14, 31, tzinfo=pytz.utc),
'high': 3.15,
'low': 2.85,
'sid': 133,
'source_id': 'test_source',
'close': 3.0,
'dt':
datetime.datetime(2006, 1, 5, 14, 31, tzinfo=pytz.utc),
'open': 3.0}
@classmethod
def setUpClass(cls):
cls.tempdir = TempDirectory()
cls.env = TradingEnvironment()
cls.sim_params = SimulationParameters(
period_start=pd.Timestamp("2006-01-05 14:31", tz="utc"),
period_end=pd.Timestamp("2006-01-05 14:36", tz="utc"),
capital_base=1.0e5,
data_frequency="minute",
emission_rate='daily',
env=cls.env
)
slippage_model = VolumeShareSlippage()
cls.sids = [133]
open_orders = [
Order(dt=datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
amount=100,
filled=0,
sid=133)
]
cls.minutes = pd.DatetimeIndex(
start=pd.Timestamp("2006-01-05 14:31", tz="utc"),
end=pd.Timestamp("2006-01-05 14:35", tz="utc"),
freq="1min"
)
orders_txns = list(slippage_model.simulate(
event,
open_orders
))
self.assertEquals(len(orders_txns), 1)
_, txn = orders_txns[0]
expected_txn = {
'price': float(3.01875),
'dt': datetime.datetime(
2006, 1, 5, 14, 31, tzinfo=pytz.utc),
'amount': int(50),
'sid': int(133),
'commission': None,
'type': DATASOURCE_TYPE.TRANSACTION,
'order_id': open_orders[0].id
assets = {
133: pd.DataFrame({
"open": np.array([3.0, 3.0, 3.5, 4.0, 3.5]),
"high": np.array([3.15, 3.15, 3.15, 3.15, 3.15]),
"low": np.array([2.85, 2.85, 2.85, 2.85, 2.85]),
"close": np.array([3.0, 3.5, 4.0, 3.5, 3.0]),
"volume": [2000, 2000, 2000, 2000, 2000],
"dt": cls.minutes
}).set_index("dt")
}
self.assertIsNotNone(txn)
write_bcolz_minute_data(
cls.env,
pd.date_range(
start=normalize_date(cls.minutes[0]),
end=normalize_date(cls.minutes[-1])
),
cls.tempdir.path,
assets
)
# TODO: Make expected_txn an Transaction object and ensure there
# is a __eq__ for that class.
self.assertEquals(expected_txn, txn.__dict__)
cls.env.write_data(equities_data={
133: {
"start_date": pd.Timestamp("2006-01-05", tz='utc'),
"end_date": pd.Timestamp("2006-01-07", tz='utc')
}
})
cls.ASSET133 = cls.env.asset_finder.retrieve_asset(133)
cls.data_portal = DataPortal(
cls.env,
equity_minute_reader=BcolzMinuteBarReader(cls.tempdir.path),
)
# Class-level teardown: remove the temporary directory held in cls.tempdir
# and drop the class's reference to the trading environment.
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
del cls.env
def test_volume_share_slippage(self):
    """
    Simulate a 100-share order against a single 200-volume minute bar.

    At the minute of the bar exactly one transaction (5 shares) is
    expected; one minute later, where no bar was written, no transaction
    is expected.
    """
    tempdir = TempDirectory()
    try:
        first_minute = self.minutes[0]

        # One bar only, at the first minute of the test range.
        single_bar = pd.DataFrame({
            "open": [3.00],
            "high": [3.15],
            "low": [2.85],
            "close": [3.00],
            "volume": [200],
            "dt": [first_minute]
        }).set_index("dt")

        days = pd.date_range(
            start=normalize_date(first_minute),
            end=normalize_date(self.minutes[-1])
        )
        write_bcolz_minute_data(
            self.env,
            days,
            tempdir.path,
            {133: single_bar}
        )

        data_portal = DataPortal(
            self.env,
            equity_minute_reader=BcolzMinuteBarReader(tempdir.path),
        )
        slippage_model = VolumeShareSlippage()

        def new_open_orders():
            # A fresh, unfilled 100-share order placed at 14:30.
            return [
                Order(
                    dt=datetime.datetime(
                        2006, 1, 5, 14, 30, tzinfo=pytz.utc),
                    amount=100,
                    filled=0,
                    sid=self.ASSET133
                )
            ]

        def simulate_at(minute, orders):
            # Run the slippage model with the simulation clock at `minute`.
            bar_data = BarData(data_portal, lambda: minute, 'minute')
            return list(slippage_model.simulate(
                bar_data,
                self.ASSET133,
                orders,
            ))

        open_orders = new_open_orders()
        orders_txns = simulate_at(first_minute, open_orders)

        self.assertEquals(len(orders_txns), 1)
        _, txn = orders_txns[0]

        expected_txn = {
            'price': float(3.0001875),
            'dt': datetime.datetime(
                2006, 1, 5, 14, 31, tzinfo=pytz.utc),
            'amount': int(5),
            'sid': int(133),
            'commission': None,
            'type': DATASOURCE_TYPE.TRANSACTION,
            'order_id': open_orders[0].id
        }

        self.assertIsNotNone(txn)

        # TODO: Make expected_txn a Transaction object and ensure there
        # is an __eq__ for that class.
        self.assertEquals(expected_txn, txn.__dict__)

        # Set bar_data to be a minute ahead of last trade.
        # Volume share slippage should not execute when there is no trade.
        orders_txns = simulate_at(self.minutes[1], new_open_orders())

        self.assertEquals(len(orders_txns), 0)
    finally:
        tempdir.cleanup()
def test_orders_limit(self):
events = self.gen_trades()
slippage_model = VolumeShareSlippage()
slippage_model.data_portal = self.data_portal
# long, does not trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.5})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# long, does not trade - impacted price worse than limit price
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.5})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# long, does trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.6})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 1)
txn = orders_txns[0][1]
expected_txn = {
'price': float(3.500875),
'price': float(3.50021875),
'dt': datetime.datetime(
2006, 1, 5, 14, 34, tzinfo=pytz.utc),
'amount': int(100),
# we ordered 100 shares, but default volume slippage only allows
# for 2.5% of the volume. 2.5% * 2000 = 50 shares
'amount': int(50),
'sid': int(133),
'order_id': open_orders[0].id
}
@@ -160,67 +292,77 @@ class SlippageTestCase(TestCase):
self.assertEquals(value, txn[key])
# short, does not trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.5})
]
orders_txns = list(slippage_model.simulate(
events[0],
open_orders
))
bar_data = BarData(self.data_portal,
lambda: self.minutes[0],
self.sim_params.data_frequency)
expected_txn = {}
orders_txns = list(slippage_model.simulate(
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# short, does not trade - impacted price worse than limit price
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.5})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[0],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[1],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# short, does trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'limit': 3.4})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[1],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[1],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 1)
_, txn = orders_txns[0]
expected_txn = {
'price': float(3.499125),
'price': float(3.49978125),
'dt': datetime.datetime(
2006, 1, 5, 14, 32, tzinfo=pytz.utc),
'amount': int(-100),
'amount': int(-50),
'sid': int(133)
}
@@ -275,9 +417,9 @@ class SlippageTestCase(TestCase):
},
'expected': {
'transaction': {
'price': 4.001,
'price': 4.00025,
'dt': pd.Timestamp('2006-01-05 14:31', tz='UTC'),
'amount': 100,
'amount': 50,
'sid': 133,
}
}
@@ -346,9 +488,9 @@ class SlippageTestCase(TestCase):
},
'expected': {
'transaction': {
'price': 2.99925,
'price': 2.9998125,
'dt': pd.Timestamp('2006-01-05 14:31', tz='UTC'),
'amount': -100,
'amount': -50,
'sid': 133,
}
}
@@ -360,113 +502,181 @@ class SlippageTestCase(TestCase):
for name, case in STOP_ORDER_CASES.items()
])
def test_orders_stop(self, name, order_data, event_data, expected):
order = Order(**order_data)
event = Event(initial_values=event_data)
slippage_model = VolumeShareSlippage()
tempdir = TempDirectory()
try:
_, txn = next(slippage_model.simulate(event, [order]))
except StopIteration:
txn = None
data = order_data
data['sid'] = self.ASSET133
if expected['transaction'] is None:
self.assertIsNone(txn)
else:
self.assertIsNotNone(txn)
order = Order(**data)
for key, value in expected['transaction'].items():
self.assertEquals(value, txn[key])
assets = {
133: pd.DataFrame({
"open": [event_data["open"]],
"high": [event_data["high"]],
"low": [event_data["low"]],
"close": [event_data["close"]],
"volume": [event_data["volume"]],
"dt": [pd.Timestamp('2006-01-05 14:31', tz='UTC')]
}).set_index("dt")
}
write_bcolz_minute_data(
self.env,
pd.date_range(
start=normalize_date(self.minutes[0]),
end=normalize_date(self.minutes[-1])
),
tempdir.path,
assets
)
equity_minute_reader = BcolzMinuteBarReader(tempdir.path)
data_portal = DataPortal(
self.env,
equity_minute_reader=equity_minute_reader,
)
slippage_model = VolumeShareSlippage()
try:
dt = pd.Timestamp('2006-01-05 14:31', tz='UTC')
bar_data = BarData(data_portal,
lambda: dt,
'minute')
_, txn = next(slippage_model.simulate(
bar_data,
self.ASSET133,
[order],
))
except StopIteration:
txn = None
if expected['transaction'] is None:
self.assertIsNone(txn)
else:
self.assertIsNotNone(txn)
for key, value in expected['transaction'].items():
self.assertEquals(value, txn[key])
finally:
tempdir.cleanup()
def test_orders_stop_limit(self):
events = self.gen_trades()
slippage_model = VolumeShareSlippage()
slippage_model.data_portal = self.data_portal
# long, does not trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 4.0,
'limit': 3.0})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[2],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[2],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# long, does not trade - impacted price worse than limit price
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 4.0,
'limit': 3.5})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[2],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[2],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# long, does trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': 100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 4.0,
'limit': 3.6})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[2],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[2],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[3],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[3],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 1)
_, txn = orders_txns[0]
expected_txn = {
'price': float(3.500875),
'price': float(3.50021875),
'dt': datetime.datetime(
2006, 1, 5, 14, 34, tzinfo=pytz.utc),
'amount': int(100),
'amount': int(50),
'sid': int(133)
}
@@ -480,166 +690,113 @@ class SlippageTestCase(TestCase):
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 3.0,
'limit': 4.0})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[0],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[0],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[1],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[1],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# short, does not trade - impacted price worse than limit price
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 3.0,
'limit': 3.5})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[0],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[0],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[1],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[1],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
# short, does trade
open_orders = [
Order(**{
'dt': datetime.datetime(2006, 1, 5, 14, 30, tzinfo=pytz.utc),
'amount': -100,
'filled': 0,
'sid': 133,
'sid': self.ASSET133,
'stop': 3.0,
'limit': 3.4})
]
bar_data = BarData(self.data_portal,
lambda: self.minutes[0],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[0],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 0)
bar_data = BarData(self.data_portal,
lambda: self.minutes[1],
self.sim_params.data_frequency)
orders_txns = list(slippage_model.simulate(
events[1],
open_orders
bar_data,
self.ASSET133,
open_orders,
))
self.assertEquals(len(orders_txns), 1)
_, txn = orders_txns[0]
expected_txn = {
'price': float(3.499125),
'price': float(3.49978125),
'dt': datetime.datetime(
2006, 1, 5, 14, 32, tzinfo=pytz.utc),
'amount': int(-100),
'amount': int(-50),
'sid': int(133)
}
for key, value in expected_txn.items():
self.assertEquals(value, txn[key])
def gen_trades(self):
    """
    Build a sequence of five one-minute trade events for sid 133.

    The events cover 14:31 through 14:35 UTC on 2006-01-05.  All share the
    same volume, high, low, type and source; only the timestamp, the
    price/close and the open differ.
    """
    # (minute past 14:00, price and close, open)
    bars = [
        (31, 3.0, 3.0),
        (32, 3.5, 3.0),
        (33, 4.0, 3.5),
        (34, 3.5, 4.0),
        (35, 3.0, 3.5),
    ]

    events = []
    for minute, price, open_price in bars:
        stamp = datetime.datetime(2006, 1, 5, 14, minute, tzinfo=pytz.utc)
        events.append(Event({
            'volume': 2000,
            'type': 4,
            'price': price,
            'datetime': stamp,
            'high': 3.15,
            'low': 2.85,
            'sid': 133,
            'source_id': 'test_source',
            'close': price,
            'dt': stamp,
            'open': open_price
        }))
    return events
+166
View File
@@ -0,0 +1,166 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import pandas as pd
from zipline.finance.trading import TradingEnvironment
from zipline.data.us_equity_minutes import BcolzMinuteBarWriter
def generate_daily_test_data(first_day,
                             last_day,
                             starting_open,
                             starting_volume,
                             multipliers_list,
                             path):
    """
    Utility method to generate fake daily-level CSV data.

    :param first_day: first trading day
    :param last_day: last trading day
    :param starting_open: first open value; it is multiplied by 1000 before
    being used as the seed for the generated opens.
    :param starting_volume: first volume value, raw value.
    :param multipliers_list: ordered sequence of entries where entry[0] is
    looked up in the trading days with ``searchsorted`` and entry[1] is the
    multiplier applied to every row from the previous boundary up to (not
    including) that position.
    :param path: path to save the CSV
    :return: None
    """
    days = TradingEnvironment.instance().days_in_range(first_day, last_day)
    days_count = len(days)

    # All columns are stored as unsigned 32-bit integers, so the fractional
    # part of every generated value is dropped when it is stored.
    # NOTE(review): lows and volumes are not clamped at zero; a negative
    # intermediate value would not fit the unsigned dtype -- confirm callers
    # always pass large enough starting values.
    opens = np.zeros(days_count, dtype=np.uint32)
    highs = np.zeros(days_count, dtype=np.uint32)
    lows = np.zeros(days_count, dtype=np.uint32)
    closes = np.zeros(days_count, dtype=np.uint32)
    volumes = np.zeros(days_count, dtype=np.uint32)

    last_open = starting_open * 1000
    last_volume = starting_volume

    # The order of the random calls is kept fixed so that a seeded run
    # keeps producing the same data.
    for idx in range(days_count):
        new_open = last_open + round((random.random() * 5), 2)

        opens[idx] = new_open
        highs[idx] = new_open + round((random.random() * 10000), 2)
        lows[idx] = new_open - round((random.random() * 10000), 2)
        closes[idx] = (highs[idx] + lows[idx]) / 2
        volumes[idx] = int(last_volume + (random.randrange(-10, 10) * 1e4))

        last_open = opens[idx]
        last_volume = volumes[idx]

    # now deal with multipliers
    range_start = 0
    for multiplier_info in multipliers_list:
        range_end = days.searchsorted(multiplier_info[0])
        multiplier = multiplier_info[1]
        window = slice(range_start, range_end)

        # dividing by the multiplier because we're going backwards
        # and generating the original data that will then be adjusted.
        #
        # The result is computed first and then assigned back: an in-place
        # ``/=`` or ``*=`` with a float on an integer array is rejected by
        # NumPy's casting rules, whereas assignment truncates to uint32.
        for prices in (opens, highs, lows, closes):
            prices[window] = prices[window] / multiplier
        volumes[window] = volumes[window] * multiplier

        range_start = range_end

    df = pd.DataFrame({
        "open": opens,
        "high": highs,
        "low": lows,
        "close": closes,
        "volume": volumes
    }, columns=[
        "open",
        "high",
        "low",
        "close",
        "volume"
    ], index=days)

    df.to_csv(path, index_label="day")
def generate_minute_test_data(first_day,
                              last_day,
                              starting_open,
                              starting_volume,
                              multipliers_list,
                              path):
    """
    Utility method to generate fake minute-level CSV data.

    :param first_day: first trading day
    :param last_day: last trading day
    :param starting_open: first open value; it is multiplied by 1000 before
    being used as the seed for the generated opens.
    :param starting_volume: first volume value, raw value.
    :param multipliers_list: ordered list of pd.Timestamp -> float, one per day
    in the range; only the float (entry[1]) is read here, and the n-th entry
    is applied to the n-th block of 390 rows.
    :param path: path to save the CSV
    :return: None
    """
    full_minutes = BcolzMinuteBarWriter.full_minutes_for_days(
        first_day, last_day)
    minutes_count = len(full_minutes)

    minutes = TradingEnvironment.instance().minutes_for_days_in_range(
        first_day, last_day)

    # All columns are stored as unsigned 32-bit integers, so the fractional
    # part of every generated value is dropped when it is stored. Rows that
    # never get visited by the loop below stay at zero.
    opens = np.zeros(minutes_count, dtype=np.uint32)
    highs = np.zeros(minutes_count, dtype=np.uint32)
    lows = np.zeros(minutes_count, dtype=np.uint32)
    closes = np.zeros(minutes_count, dtype=np.uint32)
    volumes = np.zeros(minutes_count, dtype=np.uint32)

    last_open = starting_open * 1000
    last_volume = starting_volume

    # The order of the random calls is kept fixed so that a seeded run
    # keeps producing the same data.
    for minute in minutes:
        # ugly, but works
        idx = full_minutes.searchsorted(minute)

        new_open = last_open + round((random.random() * 5), 2)

        opens[idx] = new_open
        highs[idx] = new_open + round((random.random() * 10000), 2)
        lows[idx] = new_open - round((random.random() * 10000), 2)
        closes[idx] = (highs[idx] + lows[idx]) / 2
        volumes[idx] = int(last_volume + (random.randrange(-10, 10) * 1e4))

        last_open = opens[idx]
        last_volume = volumes[idx]

    # now deal with multipliers
    for day_idx, multiplier_info in enumerate(multipliers_list):
        start_idx = day_idx * 390
        window = slice(start_idx, start_idx + 390)
        multiplier = multiplier_info[1]

        # dividing by the multiplier because we're going backwards
        # and generating the original data that will then be adjusted.
        #
        # The result is computed first and then assigned back: an in-place
        # ``/=`` or ``*=`` with a float on an integer array is rejected by
        # NumPy's casting rules, whereas assignment truncates to uint32.
        for prices in (opens, highs, lows, closes):
            prices[window] = prices[window] / multiplier
        volumes[window] = volumes[window] * multiplier

    # NOTE(review): the columns have len(full_minutes) rows while the index
    # is `minutes`; the DataFrame can only be built when both have the same
    # length -- confirm that holds for ranges containing shortened sessions.
    df = pd.DataFrame({
        "open": opens,
        "high": highs,
        "low": lows,
        "close": closes,
        "volume": volumes
    }, columns=[
        "open",
        "high",
        "low",
        "close",
        "volume"
    ], index=minutes)

    df.to_csv(path, index_label="minute")
-647
View File
@@ -1,647 +0,0 @@
"""
Test case definitions for history tests.
"""
import pandas as pd
import numpy as np
from zipline.finance.trading import TradingEnvironment, noop_load
from zipline.history.history import HistorySpec
from zipline.protocol import BarData
from zipline.testing import to_utc
# TradingEnvironment shared by the helper functions and HistorySpec
# definitions in this module; it is constructed with load=noop_load.
_cases_env = TradingEnvironment(load=noop_load)
def mixed_frequency_expected_index(count, frequency):
    """
    Return the two-entry index expected at step `count` of
    test_mixed_frequency.

    The second entry is always MIXED_FREQUENCY_MINUTES[count]. The first
    entry comes from the shared environment: element [1] of
    previous_open_and_close() for '1d' (presumably the previous close --
    confirm), or previous_market_minute() for '1m'. Any other frequency
    yields None.
    """
    current = MIXED_FREQUENCY_MINUTES[count]

    if frequency == '1d':
        prior = _cases_env.previous_open_and_close(current)[1]
    elif frequency == '1m':
        prior = _cases_env.previous_market_minute(current)
    else:
        return None

    return [prior, current]
def mixed_frequency_expected_data(count, frequency):
    """
    Return the two values expected at step `count` of test_mixed_frequency.

    The second value is always `count`. The first value is the expected
    previous bar for the given frequency, or NaN when there is none yet.
    Any frequency other than '1d' or '1m' yields None.
    """
    if frequency == '1d':
        # First day of this test is July 3rd, which is a half day, so the
        # previous daily value stays NaN until step 210 and is 209 after.
        previous = np.nan if count < 210 else 209
        return [previous, count]

    if frequency == '1m':
        # Only the very first minute has no preceding value.
        previous = count - 1 if count != 0 else np.nan
        return [previous, count]
# Minutes iterated by the 'test mixed frequencies' case; obtained from
# market_minute_window with a start of 2013-07-03 9:31AM and a count of 600.
MIXED_FREQUENCY_MINUTES = _cases_env.market_minute_window(
to_utc('2013-07-03 9:31AM'), 600,
)
# NOTE(review): the HistorySpec positional arguments below appear to be
# (bar count, frequency, field, forward-fill flag, environment) -- confirm
# against the HistorySpec signature. Every spec uses data_frequency='minute'.

# Single one-bar minutely price spec, used by 'test one minute price only'.
ONE_MINUTE_PRICE_ONLY_SPECS = [
HistorySpec(1, '1m', 'price', True, _cases_env, data_frequency='minute'),
]
# Three-bar daily specs for open_price ([0]) and close_price ([1]), used by
# 'test daily open close'.
DAILY_OPEN_CLOSE_SPECS = [
HistorySpec(3, '1d', 'open_price', False, _cases_env,
data_frequency='minute'),
HistorySpec(3, '1d', 'close_price', False, _cases_env,
data_frequency='minute'),
]
# Minutely price specs used by 'test illiquid prices'; they differ in bar
# count (3 vs. 5) and in the fourth argument (False vs. True).
ILLIQUID_PRICES_SPECS = [
HistorySpec(3, '1m', 'price', False, _cases_env, data_frequency='minute'),
HistorySpec(5, '1m', 'price', True, _cases_env, data_frequency='minute'),
]
# Price specs mixing '1m' and '1d' frequencies, used by
# 'test mixed frequencies'.
MIXED_FREQUENCY_SPECS = [
HistorySpec(1, '1m', 'price', False, _cases_env, data_frequency='minute'),
HistorySpec(2, '1m', 'price', False, _cases_env, data_frequency='minute'),
HistorySpec(2, '1d', 'price', False, _cases_env, data_frequency='minute'),
]
# One three-bar minutely spec per field, used by
# 'test multiple fields and sids'.
MIXED_FIELDS_SPECS = [
HistorySpec(3, '1m', 'price', True, _cases_env, data_frequency='minute'),
HistorySpec(3, '1m', 'open_price', True, _cases_env,
data_frequency='minute'),
HistorySpec(3, '1m', 'close_price', True, _cases_env,
data_frequency='minute'),
HistorySpec(3, '1m', 'high', True, _cases_env, data_frequency='minute'),
HistorySpec(3, '1m', 'low', True, _cases_env, data_frequency='minute'),
HistorySpec(3, '1m', 'volume', True, _cases_env, data_frequency='minute'),
]
HISTORY_CONTAINER_TEST_CASES = {
# June 2013
# Su Mo Tu We Th Fr Sa
# 1
# 2 3 4 5 6 7 8
# 9 10 11 12 13 14 15
# 16 17 18 19 20 21 22
# 23 24 25 26 27 28 29
# 30
'test one minute price only': {
# A list of HistorySpec objects.
'specs': ONE_MINUTE_PRICE_ONLY_SPECS,
# Sids for the test.
'sids': [1],
# Start date for test.
'dt': to_utc('2013-06-21 9:31AM'),
# Sequence of updates to the container
'updates': [
BarData(
{
1: {
'price': 5,
'dt': to_utc('2013-06-21 9:31AM'),
},
},
),
BarData(
{
1: {
'price': 6,
'dt': to_utc('2013-06-21 9:32AM'),
},
},
),
],
# Expected results
'expected': {
ONE_MINUTE_PRICE_ONLY_SPECS[0].key_str: [
pd.DataFrame(
data={
1: [5],
},
index=[
to_utc('2013-06-21 9:31AM'),
],
),
pd.DataFrame(
data={
1: [6],
},
index=[
to_utc('2013-06-21 9:32AM'),
],
),
],
},
},
'test daily open close': {
# A list of HistorySpec objects.
'specs': DAILY_OPEN_CLOSE_SPECS,
# Sids for the test.
'sids': [1],
# Start date for test.
'dt': to_utc('2013-06-21 9:31AM'),
# Sequence of updates to the container
'updates': [
BarData(
{
1: {
'open_price': 10,
'close_price': 11,
'dt': to_utc('2013-06-21 10:00AM'),
},
},
),
BarData(
{
1: {
'open_price': 12,
'close_price': 13,
'dt': to_utc('2013-06-21 3:30PM'),
},
},
),
BarData(
{
1: {
'open_price': 14,
'close_price': 15,
# Wait a full market day before the next bar.
# We should end up with nans for Monday the 24th.
'dt': to_utc('2013-06-25 9:31AM'),
},
},
),
],
# Dictionary mapping spec_key -> list of expected outputs
'expected': {
# open
DAILY_OPEN_CLOSE_SPECS[0].key_str: [
pd.DataFrame(
data={
1: [np.nan, np.nan, 10]
},
index=[
to_utc('2013-06-19 4:00PM'),
to_utc('2013-06-20 4:00PM'),
to_utc('2013-06-21 10:00AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, np.nan, 10]
},
index=[
to_utc('2013-06-19 4:00PM'),
to_utc('2013-06-20 4:00PM'),
to_utc('2013-06-21 3:30PM'),
],
),
pd.DataFrame(
data={
1: [10, np.nan, 14]
},
index=[
to_utc('2013-06-21 4:00PM'),
to_utc('2013-06-24 4:00PM'),
to_utc('2013-06-25 9:31AM'),
],
),
],
# close
DAILY_OPEN_CLOSE_SPECS[1].key_str: [
pd.DataFrame(
data={
1: [np.nan, np.nan, 11]
},
index=[
to_utc('2013-06-19 4:00PM'),
to_utc('2013-06-20 4:00PM'),
to_utc('2013-06-21 10:00AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, np.nan, 13]
},
index=[
to_utc('2013-06-19 4:00PM'),
to_utc('2013-06-20 4:00PM'),
to_utc('2013-06-21 3:30PM'),
],
),
pd.DataFrame(
data={
1: [13, np.nan, 15]
},
index=[
to_utc('2013-06-21 4:00PM'),
to_utc('2013-06-24 4:00PM'),
to_utc('2013-06-25 9:31AM'),
],
),
],
},
},
'test illiquid prices': {
# A list of HistorySpec objects.
'specs': ILLIQUID_PRICES_SPECS,
# Sids for the test.
'sids': [1],
# Start date for test.
'dt': to_utc('2013-06-28 9:31AM'),
# Sequence of updates to the container
'updates': [
BarData(
{
1: {
'price': 10,
'dt': to_utc('2013-06-28 9:31AM'),
},
},
),
BarData(
{
1: {
'price': 11,
'dt': to_utc('2013-06-28 9:32AM'),
},
},
),
BarData(
{
1: {
'price': 12,
'dt': to_utc('2013-06-28 9:33AM'),
},
},
),
BarData(
{
1: {
'price': 13,
# Note: Skipping 9:34 to simulate illiquid bar/missing
# data.
'dt': to_utc('2013-06-28 9:35AM'),
},
},
),
],
# Dictionary mapping spec_key -> list of expected outputs
'expected': {
ILLIQUID_PRICES_SPECS[0].key_str: [
pd.DataFrame(
data={
1: [np.nan, np.nan, 10],
},
index=[
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, 10, 11],
},
index=[
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
],
),
pd.DataFrame(
data={
1: [10, 11, 12],
},
index=[
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
],
),
# Since there's no update for 9:34, this is called at 9:35.
pd.DataFrame(
data={
1: [12, np.nan, 13],
},
index=[
to_utc('2013-06-28 9:33AM'),
to_utc('2013-06-28 9:34AM'),
to_utc('2013-06-28 9:35AM'),
],
),
],
ILLIQUID_PRICES_SPECS[1].key_str: [
pd.DataFrame(
data={
1: [np.nan, np.nan, np.nan, np.nan, 10],
},
index=[
to_utc('2013-06-27 3:57PM'),
to_utc('2013-06-27 3:58PM'),
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, np.nan, np.nan, 10, 11],
},
index=[
to_utc('2013-06-27 3:58PM'),
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, np.nan, 10, 11, 12],
},
index=[
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
],
),
# Since there's no update for 9:34, this is called at 9:35.
# The 12 value from 9:33 should be forward-filled.
pd.DataFrame(
data={
1: [10, 11, 12, 12, 13],
},
index=[
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
to_utc('2013-06-28 9:34AM'),
to_utc('2013-06-28 9:35AM'),
],
),
],
},
},
'test mixed frequencies': {
# A list of HistorySpec objects.
'specs': MIXED_FREQUENCY_SPECS,
# Sids for the test.
'sids': [1],
# Start date for test.
# July 2013
# Su Mo Tu We Th Fr Sa
# 1 2 3 4 5 6
# 7 8 9 10 11 12 13
# 14 15 16 17 18 19 20
# 21 22 23 24 25 26 27
# 28 29 30 31
'dt': to_utc('2013-07-03 9:31AM'),
# Sequence of updates to the container
'updates': [
BarData(
{
1: {
'price': count,
'dt': dt,
}
}
)
for count, dt in enumerate(MIXED_FREQUENCY_MINUTES)
],
# Dictionary mapping spec_key -> list of expected outputs.
'expected': {
MIXED_FREQUENCY_SPECS[0].key_str: [
pd.DataFrame(
data={
1: [count],
},
index=[minute],
)
for count, minute in enumerate(MIXED_FREQUENCY_MINUTES)
],
MIXED_FREQUENCY_SPECS[1].key_str: [
pd.DataFrame(
data={
1: mixed_frequency_expected_data(count, '1m'),
},
index=mixed_frequency_expected_index(count, '1m'),
)
for count in range(len(MIXED_FREQUENCY_MINUTES))
],
MIXED_FREQUENCY_SPECS[2].key_str: [
pd.DataFrame(
data={
1: mixed_frequency_expected_data(count, '1d'),
},
index=mixed_frequency_expected_index(count, '1d'),
)
for count in range(len(MIXED_FREQUENCY_MINUTES))
]
},
},
'test multiple fields and sids': {
# A list of HistorySpec objects.
'specs': MIXED_FIELDS_SPECS,
# Sids for the test.
'sids': [1, 10],
# Start date for test.
'dt': to_utc('2013-06-28 9:31AM'),
# Sequence of updates to the container
'updates': [
BarData(
{
1: {
'dt': dt,
'price': count,
'open_price': count,
'close_price': count,
'high': count,
'low': count,
'volume': count,
},
10: {
'dt': dt,
'price': count * 10,
'open_price': count * 10,
'close_price': count * 10,
'high': count * 10,
'low': count * 10,
'volume': count * 10,
},
},
)
for count, dt in enumerate([
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
# NOTE: No update for 9:34
to_utc('2013-06-28 9:35AM'),
])
],
# Dictionary mapping spec_key -> list of expected outputs
'expected': dict(
# Build a dict from a list of tuples. Doing it this way because
# there are two distinct cases we want to test: forward-fillable
# fields and non-forward-fillable fields.
[
(
# Non forward-fill fields
key,
[
pd.DataFrame(
data={
1: [np.nan, np.nan, 0],
10: [np.nan, np.nan, 0],
},
index=[
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, 0, 1],
10: [np.nan, 0, 10],
},
index=[
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
],
),
pd.DataFrame(
data={
1: [0, 1, 2],
10: [0, 10, 20],
},
index=[
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
],
),
pd.DataFrame(
data={
1: [2, np.nan, 3],
10: [20, np.nan, 30],
},
index=[
to_utc('2013-06-28 9:33AM'),
to_utc('2013-06-28 9:34AM'),
to_utc('2013-06-28 9:35AM'),
],
# For volume, when we are missing data, we replace
# it with 0s to show that no trades occurred.
).fillna(0 if 'volume' in key else np.nan),
],
)
for key in [spec.key_str for spec in MIXED_FIELDS_SPECS
if spec.field not in HistorySpec.FORWARD_FILLABLE]
] +
# Concatenate the expected results for non-ffillable with
# expected result for ffillable.
[
(
# Forward-fillable fields
key,
[
pd.DataFrame(
data={
1: [np.nan, np.nan, 0],
10: [np.nan, np.nan, 0],
},
index=[
to_utc('2013-06-27 3:59PM'),
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
],
),
pd.DataFrame(
data={
1: [np.nan, 0, 1],
10: [np.nan, 0, 10],
},
index=[
to_utc('2013-06-27 4:00PM'),
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
],
),
pd.DataFrame(
data={
1: [0, 1, 2],
10: [0, 10, 20],
},
index=[
to_utc('2013-06-28 9:31AM'),
to_utc('2013-06-28 9:32AM'),
to_utc('2013-06-28 9:33AM'),
],
),
pd.DataFrame(
data={
1: [2, 2, 3],
10: [20, 20, 30],
},
index=[
to_utc('2013-06-28 9:33AM'),
to_utc('2013-06-28 9:34AM'),
to_utc('2013-06-28 9:35AM'),
],
),
],
)
for key in [spec.key_str for spec in MIXED_FIELDS_SPECS
if spec.field in HistorySpec.FORWARD_FILLABLE]
]
),
},
}
+48 -25
View File
@@ -1,6 +1,7 @@
"""
Tests for Algorithms using the Pipeline API.
"""
import os
from unittest import TestCase
from os.path import (
dirname,
@@ -22,8 +23,6 @@ from pandas import (
concat,
DataFrame,
date_range,
DatetimeIndex,
Panel,
read_csv,
Series,
Timestamp,
@@ -37,6 +36,7 @@ from zipline.api import (
pipeline_output,
get_datetime,
)
from zipline.data.data_portal import DataPortal
from zipline.errors import (
AttachPipelineAfterInitialize,
PipelineOutputDuringInitialize,
@@ -59,8 +59,10 @@ from zipline.pipeline.loaders.equity_pricing_loader import (
)
from zipline.testing import (
make_simple_equity_info,
str_to_seconds,
str_to_seconds
)
from zipline.testing.core import DailyBarWriterFromDataFrames, \
create_empty_splits_mergers_frame, FakeDataPortal
from zipline.utils.tradingcalendar import (
trading_day,
trading_days,
@@ -89,6 +91,14 @@ def rolling_vwap(df, length):
class ClosesOnly(TestCase):
@classmethod
def setUpClass(cls):
cls.tempdir = TempDirectory()
@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()
def setUp(self):
self.env = env = trading.TradingEnvironment()
self.dates = date_range(
@@ -132,6 +142,33 @@ class ClosesOnly(TestCase):
dtype=float,
)
# Create a data portal holding the data in self.closes
data = {}
for sid in sids:
data[sid] = DataFrame({
"open": self.closes[sid].values,
"high": self.closes[sid].values,
"low": self.closes[sid].values,
"close": self.closes[sid].values,
"volume": self.closes[sid].values,
"day": [day.value for day in self.dates]
})
path = os.path.join(self.tempdir.path, "testdaily.bcolz")
DailyBarWriterFromDataFrames(data).write(
path,
self.dates,
data
)
daily_bar_reader = BcolzDailyBarReader(path)
self.data_portal = DataPortal(
self.env,
equity_daily_reader=daily_bar_reader,
)
# Add a split for 'A' on its second date.
self.split_asset = self.assets[0]
self.split_date = self.split_asset.start_date + trading_day
@@ -189,7 +226,7 @@ class ClosesOnly(TestCase):
)
with self.assertRaises(AttachPipelineAfterInitialize):
algo.run(source=self.closes)
algo.run(self.data_portal)
def barf(context, data):
raise AssertionError("Shouldn't make it past before_trading_start")
@@ -206,7 +243,7 @@ class ClosesOnly(TestCase):
)
with self.assertRaises(AttachPipelineAfterInitialize):
algo.run(source=self.closes)
algo.run(self.data_portal)
def test_pipeline_output_after_initialize(self):
"""
@@ -235,7 +272,7 @@ class ClosesOnly(TestCase):
)
with self.assertRaises(PipelineOutputDuringInitialize):
algo.run(source=self.closes)
algo.run(self.data_portal)
def test_get_output_nonexistent_pipeline(self):
"""
@@ -263,7 +300,7 @@ class ClosesOnly(TestCase):
)
with self.assertRaises(NoSuchPipeline):
algo.run(source=self.closes)
algo.run(self.data_portal)
@parameterized.expand([('default', None),
('day', 1),
@@ -313,8 +350,7 @@ class ClosesOnly(TestCase):
)
# Run for a week in the middle of our data.
algo.run(source=self.closes.loc[self.first_asset_start:
self.last_asset_end])
algo.run(self.data_portal)
class MockDailyBarSpotReader(object):
@@ -395,16 +431,7 @@ class PipelineAlgorithmTestCase(TestCase):
'sid': cls.AAPL,
}
])
mergers = DataFrame(
{
# Hackery to make the dtypes correct on an empty frame.
'effective_date': array([], dtype=int),
'ratio': array([], dtype=float),
'sid': array([], dtype=int),
},
index=DatetimeIndex([]),
columns=['effective_date', 'ratio', 'sid'],
)
mergers = create_empty_splits_mergers_frame()
dividends = DataFrame({
'sid': array([], dtype=uint32),
'amount': array([], dtype=float64),
@@ -416,9 +443,6 @@ class PipelineAlgorithmTestCase(TestCase):
writer.write(splits, mergers, dividends)
return SQLiteAdjustmentReader(dbpath)
def make_source(self):
return Panel(self.raw_data).tz_localize('UTC', axis=1)
def compute_expected_vwaps(self, window_lengths):
AAPL, MSFT, BRK_A = self.AAPL, self.MSFT, self.BRK_A
@@ -532,7 +556,6 @@ class PipelineAlgorithmTestCase(TestCase):
BRK_A: True,
}
for asset in assets:
should_pass_filter = expect_over_300[asset]
if set_screen and not should_pass_filter:
self.assertNotIn(asset, results.index)
@@ -562,7 +585,7 @@ class PipelineAlgorithmTestCase(TestCase):
)
algo.run(
source=self.make_source(),
FakeDataPortal(),
# Yes, I really do want to use the start and end dates I passed to
# TradingAlgorithm.
overwrite_sim_params=False,
@@ -602,7 +625,7 @@ class PipelineAlgorithmTestCase(TestCase):
)
algo.run(
source=self.make_source(),
FakeDataPortal(),
overwrite_sim_params=False,
)
View File
@@ -0,0 +1,715 @@
MULTI_SIGNAL_CSV_DATA = """
symbol,date,signal
ibm,1/1/06,1
ibm,2/1/06,0
ibm,3/1/06,0
ibm,4/1/06,0
ibm,5/1/06,1
ibm,6/1/06,1
ibm,7/1/06,1
ibm,8/1/06,1
ibm,9/1/06,0
ibm,10/1/06,1
ibm,11/1/06,1
ibm,12/1/06,5
ibm,1/1/07,1
ibm,2/1/07,0
ibm,3/1/07,1
ibm,4/1/07,0
ibm,5/1/07,1
dell,1/1/06,1
dell,2/1/06,0
dell,3/1/06,0
dell,4/1/06,0
dell,5/1/06,1
dell,6/1/06,1
dell,7/1/06,1
dell,8/1/06,1
dell,9/1/06,0
dell,10/1/06,1
dell,11/1/06,1
dell,12/1/06,5
dell,1/1/07,1
dell,2/1/07,0
dell,3/1/07,1
dell,4/1/07,0
dell,5/1/07,1
""".strip()
AAPL_CSV_DATA = """
symbol,date,signal
aapl,1/1/06,1
aapl,2/1/06,0
aapl,3/1/06,0
aapl,4/1/06,0
aapl,5/1/06,1
aapl,6/1/06,1
aapl,7/1/06,1
aapl,8/1/06,1
aapl,9/1/06,0
aapl,10/1/06,1
aapl,11/1/06,1
aapl,12/1/06,5
aapl,1/1/07,1
aapl,2/1/07,0
aapl,3/1/07,1
aapl,4/1/07,0
aapl,5/1/07,1
""".strip()
# times are expected in UTC
AAPL_MINUTE_CSV_DATA = """
symbol,date,signal
aapl,1/4/06 5:31AM, 1
aapl,1/4/06 11:30AM, 2
aapl,1/5/06 5:31AM, 1
aapl,1/5/06 11:30AM, 3
aapl,1/9/06 5:31AM, 1
aapl,1/9/06 11:30AM, 4
""".strip()
IBM_CSV_DATA = """
symbol,date,signal
ibm,1/1/06,1
ibm,2/1/06,0
ibm,3/1/06,0
ibm,4/1/06,0
ibm,5/1/06,1
ibm,6/1/06,1
ibm,7/1/06,1
ibm,8/1/06,1
ibm,9/1/06,0
ibm,10/1/06,1
ibm,11/1/06,1
ibm,12/1/06,5
ibm,1/1/07,1
ibm,2/1/07,0
ibm,3/1/07,1
ibm,4/1/07,0
ibm,5/1/07,1
""".strip()
ANNUAL_AAPL_CSV_DATA = """
symbol,date,signal
aapl,1/2/06,1
aapl,1/15/06,1
aapl,3/1/06,1
aapl,3/15/06,1
aapl,6/1/06,1
aapl,6/15/06,1
aapl,9/1/06,1
aapl,9/15/06,1
aapl,12/1/06,1
aapl,12/15/06,1
""".strip()
AAPL_IBM_CSV_DATA = """
symbol,date,signal
aapl,1/1/06,1
aapl,2/1/06,0
aapl,3/1/06,0
aapl,4/1/06,0
aapl,5/1/06,1
aapl,6/1/06,1
aapl,7/1/06,1
aapl,8/1/06,1
aapl,9/1/06,0
aapl,10/1/06,1
aapl,11/1/06,1
aapl,12/1/06,5
aapl,1/1/07,1
aapl,2/1/07,0
aapl,3/1/07,1
aapl,4/1/07,0
aapl,5/1/07,1
ibm,1/1/06,1
ibm,2/1/06,0
ibm,3/1/06,0
ibm,4/1/06,0
ibm,5/1/06,1
ibm,6/1/06,1
ibm,7/1/06,1
ibm,8/1/06,1
ibm,9/1/06,0
ibm,10/1/06,1
ibm,11/1/06,1
ibm,12/1/06,5
ibm,1/1/07,1
ibm,2/1/07,0
ibm,3/1/07,1
ibm,4/1/07,0
ibm,5/1/07,1
""".strip()
CPIAUCSL_DATA = """
Date,Value
2007-12-01,211.445
2007-11-01,210.834
2007-10-01,209.19
2007-09-01,208.547
2007-08-01,207.667
2007-07-01,207.603
2007-06-01,207.234
2007-05-01,206.755
2007-04-01,205.904
2007-03-01,205.288
2007-02-01,204.226
2007-01-01,203.437
2006-12-01,203.1
2006-11-01,202.0
2006-10-01,201.9
2006-09-01,202.8
2006-08-01,203.8
2006-07-01,202.9
2006-06-01,201.8
2006-05-01,201.3
2006-04-01,200.7
2006-03-01,199.7
2006-02-01,199.4
2006-01-01,199.3
""".strip()
PALLADIUM_DATA = """
Date,Hong Kong 8:30,Hong Kong 14:00,London 08:00,New York 9:30,New York 15:00
2007-12-31,367.0,367.0,368.0,368.0,368.0
2007-12-28,366.0,366.0,365.0,368.0,368.0
2007-12-27,367.0,367.0,366.0,363.0,367.0
2007-12-26,,,,365.0,365.0
2007-12-24,351.0,357.0,357.0,357.0,365.0
2007-12-21,356.0,356.0,354.0,357.0,357.0
2007-12-20,357.0,356.0,354.0,356.0,356.0
2007-12-19,359.0,359.0,359.0,356.0,358.0
2007-12-18,357.0,356.0,356.0,359.0,359.0
2007-12-17,353.0,353.0,351.0,354.0,360.0
2007-12-14,347.0,347.0,347.0,347.0,355.0
2007-12-13,349.0,349.0,349.0,349.0,347.0
2007-12-12,348.0,349.0,349.0,351.0,349.0
2007-12-11,346.0,346.0,346.0,348.0,350.0
2007-12-10,346.0,346.0,346.0,348.0,348.0
2007-12-07,348.0,348.0,348.0,346.0,346.0
2007-12-06,350.0,350.0,352.0,348.0,348.0
2007-12-05,350.0,350.0,352.0,351.0,351.0
2007-12-04,349.0,349.0,352.0,351.0,351.0
2007-12-03,350.0,350.0,354.0,350.0,350.0
2007-11-30,345.0,345.0,347.0,353.0,350.0
2007-11-29,348.0,348.0,348.0,347.0,345.0
2007-11-28,350.0,347.0,347.0,348.0,348.0
2007-11-27,356.0,356.0,358.0,354.0,350.0
2007-11-26,357.0,357.0,360.0,360.0,360.0
2007-11-23,353.0,354.0,357.0,355.0,
2007-11-22,359.0,359.0,359.0,358.0,
2007-11-21,364.0,364.0,366.0,365.0,359.0
2007-11-20,360.0,359.0,362.0,364.0,364.0
2007-11-19,366.0,365.0,365.0,365.0,361.0
2007-11-16,368.0,366.0,368.0,369.0,366.0
2007-11-15,373.0,372.0,372.0,368.0,368.0
2007-11-14,372.0,372.0,372.0,373.0,373.0
2007-11-13,365.0,365.0,368.0,372.0,372.0
2007-11-12,373.0,370.0,370.0,366.0,366.0
2007-11-09,376.0,375.0,373.0,373.0,373.0
2007-11-08,376.0,376.0,373.0,376.0,376.0
2007-11-07,379.0,379.0,383.0,378.0,378.0
2007-11-06,374.0,374.0,374.0,379.0,379.0
2007-11-05,376.0,376.0,376.0,376.0,374.0
2007-11-02,372.0,371.0,371.0,371.0,376.0
2007-11-01,374.0,374.0,374.0,374.0,374.0
2007-10-31,369.0,369.0,371.0,372.0,372.0
2007-10-30,373.0,372.0,373.0,371.0,371.0
2007-10-29,373.0,375.0,375.0,376.0,373.0
2007-10-26,364.0,368.0,370.0,373.0,373.0
2007-10-25,360.0,360.0,360.0,364.0,368.0
2007-10-24,364.0,364.0,364.0,360.0,360.0
2007-10-23,361.0,361.0,364.0,366.0,366.0
2007-10-22,367.0,362.0,361.0,361.0,361.0
2007-10-19,,,374.0,372.0,370.0
2007-10-18,373.0,373.0,374.0,373.0,373.0
2007-10-17,372.0,372.0,370.0,373.0,373.0
2007-10-16,375.0,375.0,375.0,372.0,372.0
2007-10-15,379.0,379.0,380.0,382.0,375.0
2007-10-12,378.0,378.0,378.0,379.0,379.0
2007-10-11,375.0,375.0,376.0,381.0,384.0
2007-10-10,365.0,365.0,367.0,377.0,377.0
2007-10-09,365.0,363.0,362.0,362.0,365.0
2007-10-08,369.0,369.0,367.0,366.0,365.0
2007-10-05,369.0,369.0,371.0,369.0,369.0
2007-10-04,359.0,359.0,360.0,362.0,369.0
2007-10-03,352.0,350.0,352.0,352.0,359.0
2007-10-02,358.0,357.0,356.0,352.0,352.0
2007-10-01,,,349.0,355.0,360.0
2007-09-28,345.0,345.0,345.0,346.0,348.0
2007-09-27,342.0,342.0,342.0,343.0,345.0
2007-09-26,,,341.0,340.0,343.0
2007-09-25,342.0,341.0,343.0,341.0,341.0
2007-09-24,340.0,340.0,342.0,342.0,342.0
2007-09-21,341.0,341.0,342.0,342.0,340.0
2007-09-20,335.0,335.0,335.0,338.0,341.0
2007-09-19,333.0,333.0,335.0,335.0,335.0
2007-09-18,333.0,333.0,334.0,333.0,333.0
2007-09-17,331.0,331.0,331.0,333.0,333.0
2007-09-14,334.0,333.0,333.0,333.0,331.0
2007-09-13,336.0,336.0,336.0,334.0,334.0
2007-09-12,336.0,336.0,336.0,336.0,336.0
2007-09-11,333.0,335.0,335.0,336.0,336.0
2007-09-10,337.0,337.0,337.0,336.0,333.0
2007-09-07,336.0,336.0,338.0,337.0,337.0
2007-09-06,333.0,333.0,336.0,336.0,336.0
2007-09-05,334.0,334.0,334.0,336.0,333.0
2007-09-04,333.0,333.0,334.0,334.0,334.0
2007-09-03,334.0,334.0,335.0,334.0,
2007-08-31,331.0,333.0,334.0,333.0,333.0
2007-08-30,331.0,331.0,332.0,331.0,331.0
2007-08-29,329.0,327.0,329.0,329.0,331.0
2007-08-28,331.0,331.0,334.0,331.0,331.0
2007-08-27,330.0,331.0,331.0,331.0,331.0
2007-08-24,326.0,326.0,327.0,325.0,330.0
2007-08-23,322.0,322.0,326.0,330.0,326.0
2007-08-22,321.0,319.0,319.0,322.0,322.0
2007-08-21,331.0,331.0,329.0,328.0,325.0
2007-08-20,331.0,331.0,331.0,331.0,331.0
2007-08-17,334.0,334.0,334.0,335.0,331.0
2007-08-16,348.0,346.0,345.0,338.0,329.0
2007-08-15,354.0,354.0,352.0,348.0,348.0
2007-08-14,357.0,357.0,356.0,351.0,354.0
2007-08-13,355.0,355.0,354.0,356.0,358.0
2007-08-10,361.0,357.0,357.0,350.0,358.0
2007-08-09,364.0,364.0,364.0,361.0,361.0
2007-08-08,362.0,362.0,362.0,364.0,364.0
2007-08-07,365.0,365.0,363.0,360.0,363.0
2007-08-06,365.0,365.0,365.0,365.0,365.0
2007-08-03,366.0,366.0,365.0,365.0,367.0
2007-08-02,365.0,365.0,365.0,368.0,366.0
2007-08-01,367.0,366.0,366.0,365.0,367.0
2007-07-31,367.0,367.0,365.0,367.0,367.0
2007-07-30,363.0,362.0,361.0,365.0,367.0
2007-07-27,365.0,365.0,364.0,363.0,363.0
2007-07-26,366.0,366.0,365.0,365.0,365.0
2007-07-25,368.0,368.0,368.0,366.0,366.0
2007-07-24,372.0,372.0,372.0,370.0,368.0
2007-07-23,372.0,372.0,372.0,372.0,372.0
2007-07-20,372.0,372.0,372.0,372.0,372.0
2007-07-19,370.0,369.0,369.0,370.0,372.0
2007-07-18,368.0,368.0,367.0,367.0,370.0
2007-07-17,368.0,368.0,368.0,368.0,365.0
2007-07-16,369.0,369.0,368.0,368.0,368.0
2007-07-13,370.0,370.0,370.0,369.0,369.0
2007-07-12,369.0,369.0,368.0,370.0,370.0
2007-07-11,369.0,369.0,369.0,369.0,369.0
2007-07-10,369.0,369.0,369.0,369.0,367.0
2007-07-09,367.0,367.0,366.0,370.0,369.0
2007-07-06,366.0,366.0,365.0,365.0,367.0
2007-07-05,366.0,366.0,366.0,367.0,366.0
2007-07-04,366.0,368.0,368.0,366.0,
2007-07-03,368.0,370.0,370.0,368.0,366.0
2007-07-02,,,369.0,368.0,368.0
2007-06-29,368.0,368.0,368.0,368.0,368.0
2007-06-28,367.0,367.0,368.0,368.0,368.0
2007-06-27,366.0,366.0,366.0,368.0,364.0
2007-06-26,372.0,372.0,370.0,368.0,366.0
2007-06-25,377.0,377.0,376.0,373.0,372.0
2007-06-22,376.0,376.0,375.0,377.0,377.0
2007-06-21,375.0,375.0,374.0,376.0,376.0
2007-06-20,373.0,373.0,371.0,375.0,377.0
2007-06-19,,,372.0,371.0,371.0
2007-06-18,370.0,371.0,373.0,373.0,373.0
2007-06-15,370.0,369.0,369.0,369.0,372.0
2007-06-14,367.0,367.0,369.0,369.0,369.0
2007-06-13,369.0,369.0,367.0,365.0,369.0
2007-06-12,368.0,368.0,371.0,369.0,369.0
2007-06-11,367.0,367.0,367.0,368.0,368.0
2007-06-08,369.0,368.0,368.0,371.0,369.0
2007-06-07,370.0,370.0,370.0,369.0,371.0
2007-06-06,370.0,370.0,370.0,368.0,368.0
2007-06-05,372.0,372.0,372.0,372.0,368.0
2007-06-04,376.0,374.0,374.0,372.0,372.0
2007-06-01,370.0,370.0,370.0,373.0,373.0
2007-05-31,368.0,368.0,368.0,370.0,370.0
2007-05-30,370.0,369.0,369.0,367.0,367.0
2007-05-29,370.0,369.0,369.0,371.0,368.0
2007-05-28,368.0,368.0,368.0,,
2007-05-25,368.0,368.0,368.0,367.0,367.0
2007-05-24,,,376.0,376.0,368.0
2007-05-23,375.0,375.0,378.0,376.0,376.0
2007-05-22,374.0,374.0,374.0,378.0,378.0
2007-05-21,364.0,364.0,365.0,368.0,374.0
2007-05-18,362.0,361.0,361.0,364.0,364.0
2007-05-17,359.0,359.0,359.0,359.0,362.0
2007-05-16,363.0,363.0,362.0,362.0,359.0
2007-05-15,362.0,362.0,362.0,358.0,362.0
2007-05-14,368.0,368.0,368.0,364.0,362.0
2007-05-11,361.0,363.0,362.0,364.0,367.0
2007-05-10,370.0,370.0,366.0,363.0,363.0
2007-05-09,376.0,376.0,373.0,372.0,370.0
2007-05-08,378.0,378.0,378.0,376.0,376.0
2007-05-07,378.0,378.0,381.0,381.0,381.0
2007-05-04,376.0,374.0,374.0,376.0,376.0
2007-05-03,373.0,373.0,373.0,376.0,376.0
2007-05-02,373.0,373.0,373.0,372.0,375.0
2007-05-01,,,371.0,369.0,374.0
2007-04-30,373.0,373.0,373.0,373.0,373.0
2007-04-27,373.0,372.0,372.0,374.0,374.0
2007-04-26,380.0,380.0,380.0,376.0,373.0
2007-04-25,377.0,377.0,377.0,380.0,380.0
2007-04-24,384.0,384.0,384.0,383.0,379.0
2007-04-23,386.0,386.0,386.0,382.0,386.0
2007-04-20,378.0,378.0,378.0,385.0,387.0
2007-04-19,383.0,382.0,377.0,377.0,377.0
2007-04-18,377.0,377.0,378.0,377.0,382.0
2007-04-17,376.0,376.0,376.0,376.0,379.0
2007-04-16,380.0,381.0,381.0,376.0,376.0
2007-04-13,371.0,371.0,371.0,374.0,380.0
2007-04-12,367.0,367.0,369.0,371.0,371.0
2007-04-11,360.0,360.0,363.0,366.0,369.0
2007-04-10,358.0,358.0,360.0,360.0,360.0
2007-04-09,,,,355.0,355.0
2007-04-05,,,355.0,353.0,355.0
2007-04-04,354.0,354.0,353.0,355.0,355.0
2007-04-03,353.0,353.0,354.0,354.0,354.0
2007-04-02,355.0,355.0,355.0,353.0,355.0
2007-03-30,354.0,354.0,356.0,355.0,355.0
2007-03-29,355.0,356.0,356.0,355.0,355.0
2007-03-28,355.0,356.0,356.0,356.0,356.0
2007-03-27,355.0,355.0,357.0,355.0,355.0
2007-03-26,354.0,354.0,355.0,355.0,357.0
2007-03-23,355.0,355.0,355.0,355.0,358.0
2007-03-22,354.0,354.0,353.0,356.0,356.0
2007-03-21,352.0,352.0,352.0,352.0,350.0
2007-03-20,352.0,352.0,352.0,352.0,352.0
2007-03-19,352.0,352.0,352.0,352.0,352.0
2007-03-16,352.0,352.0,352.0,352.0,352.0
2007-03-15,349.0,349.0,349.0,352.0,352.0
2007-03-14,351.0,349.0,348.0,349.0,349.0
2007-03-13,352.0,352.0,352.0,351.0,351.0
2007-03-12,353.0,353.0,353.0,352.0,352.0
2007-03-09,353.0,351.0,353.0,353.0,353.0
2007-03-08,349.0,349.0,349.0,353.0,355.0
2007-03-07,349.0,348.0,348.0,348.0,348.0
2007-03-06,342.0,343.0,345.0,345.0,350.0
2007-03-05,344.0,342.0,340.0,340.0,345.0
2007-03-02,351.0,351.0,351.0,349.0,349.0
2007-03-01,351.0,354.0,352.0,355.0,351.0
2007-02-28,347.0,348.0,348.0,350.0,350.0
2007-02-27,357.0,356.0,356.0,351.0,356.0
2007-02-26,358.0,359.0,359.0,357.0,357.0
2007-02-23,347.0,348.0,348.0,355.0,360.0
2007-02-22,346.0,346.0,346.0,350.0,350.0
2007-02-21,339.0,339.0,340.0,339.0,346.0
2007-02-20,,,342.0,337.0,337.0
2007-02-19,,,343.0,342.0,342.0
2007-02-16,344.0,343.0,343.0,340.0,343.0
2007-02-15,345.0,343.0,343.0,344.0,344.0
2007-02-14,343.0,343.0,343.0,345.0,347.0
2007-02-13,340.0,339.0,339.0,339.0,343.0
2007-02-12,338.0,338.0,340.0,338.0,340.0
2007-02-09,343.0,343.0,343.0,338.0,342.0
2007-02-08,344.0,344.0,344.0,339.0,342.0
2007-02-07,344.0,346.0,345.0,346.0,346.0
2007-02-06,340.0,340.0,342.0,344.0,344.0
2007-02-05,337.0,336.0,336.0,340.0,343.0
2007-02-02,344.0,344.0,343.0,341.0,341.0
2007-02-01,341.0,341.0,341.0,344.0,344.0
2007-01-31,341.0,340.0,340.0,334.0,341.0
2007-01-30,343.0,341.0,343.0,336.0,342.0
2007-01-29,349.0,349.0,350.0,342.0,346.0
2007-01-26,353.0,352.0,351.0,351.0,351.0
2007-01-25,350.0,350.0,350.0,353.0,353.0
2007-01-24,351.0,350.0,350.0,348.0,348.0
2007-01-23,345.0,345.0,347.0,350.0,350.0
2007-01-22,343.0,343.0,343.0,344.0,347.0
2007-01-19,340.0,340.0,341.0,341.0,344.0
2007-01-18,340.0,342.0,342.0,342.0,342.0
2007-01-17,335.0,335.0,333.0,334.0,343.0
2007-01-16,332.0,332.0,332.0,334.0,337.0
2007-01-15,334.0,336.0,335.0,332.0,332.0
2007-01-12,331.0,331.0,331.0,331.0,335.0
2007-01-11,331.0,331.0,331.0,333.0,333.0
2007-01-10,333.0,333.0,334.0,331.0,331.0
2007-01-09,333.0,333.0,336.0,329.0,329.0
2007-01-08,335.0,335.0,335.0,333.0,333.0
2007-01-05,340.0,340.0,340.0,342.0,336.0
2007-01-04,337.0,337.0,337.0,340.0,343.0
2007-01-03,338.0,336.0,336.0,342.0,342.0
2007-01-02,337.0,337.0,334.0,336.0,336.0
2006-12-29,327.0,327.0,327.0,327.0,337.0
2006-12-28,326.0,326.0,328.0,327.0,326.0
2006-12-27,326.0,328.0,328.0,328.0,326.0
2006-12-26,,,,327.0,327.0
2006-12-22,325.0,325.0,327.0,327.0,327.0
2006-12-21,326.0,326.0,327.0,325.0,325.0
2006-12-20,328.0,328.0,328.0,326.0,326.0
2006-12-19,324.0,324.0,325.0,322.0,326.0
2006-12-18,325.0,325.0,326.0,324.0,324.0
2006-12-15,330.0,329.0,329.0,327.0,325.0
2006-12-14,328.0,328.0,328.0,330.0,330.0
2006-12-13,329.0,329.0,330.0,328.0,328.0
2006-12-12,332.0,332.0,332.0,329.0,329.0
2006-12-11,329.0,329.0,329.0,329.0,329.0
2006-12-08,330.0,329.0,329.0,332.0,336.0
2006-12-07,328.0,326.0,326.0,328.0,328.0
2006-12-06,333.0,331.0,331.0,328.0,328.0
2006-12-05,330.0,330.0,329.0,333.0,333.0
2006-12-04,330.0,330.0,330.0,330.0,330.0
2006-12-01,330.0,330.0,330.0,328.0,328.0
2006-11-30,324.0,323.0,323.0,330.0,330.0
2006-11-29,326.0,326.0,328.0,321.0,321.0
2006-11-28,329.0,328.0,328.0,326.0,326.0
2006-11-27,330.0,329.0,329.0,329.0,329.0
2006-11-24,326.0,326.0,326.0,330.0,
2006-11-23,328.0,328.0,327.0,326.0,
2006-11-22,330.0,330.0,328.0,328.0,328.0
2006-11-21,323.0,327.0,327.0,330.0,330.0
2006-11-20,320.0,320.0,322.0,323.0,323.0
2006-11-17,321.0,321.0,321.0,318.0,320.0
2006-11-16,320.0,320.0,322.0,323.0,323.0
2006-11-15,321.0,321.0,321.0,317.0,320.0
2006-11-14,326.0,325.0,324.0,324.0,321.0
2006-11-13,333.0,333.0,333.0,326.0,326.0
2006-11-10,338.0,338.0,338.0,335.0,333.0
2006-11-09,329.0,329.0,328.0,331.0,338.0
2006-11-08,333.0,333.0,334.0,327.0,327.0
2006-11-07,334.0,332.0,332.0,335.0,335.0
2006-11-06,340.0,340.0,340.0,330.0,335.0
2006-11-03,326.0,326.0,325.0,330.0,333.0
2006-11-02,327.0,326.0,326.0,324.0,326.0
2006-11-01,323.0,323.0,324.0,326.0,326.0
2006-10-31,325.0,325.0,325.0,318.0,323.0
2006-10-30,,,325.0,325.0,325.0
2006-10-27,324.0,324.0,324.0,321.0,323.0
2006-10-26,325.0,324.0,324.0,323.0,326.0
2006-10-25,322.0,322.0,322.0,319.0,319.0
2006-10-24,319.0,318.0,318.0,320.0,323.0
2006-10-23,326.0,326.0,326.0,319.0,319.0
2006-10-20,337.0,337.0,334.0,329.0,329.0
2006-10-19,331.0,331.0,331.0,330.0,337.0
2006-10-18,320.0,320.0,320.0,326.0,334.0
2006-10-17,324.0,326.0,326.0,321.0,321.0
2006-10-16,318.0,321.0,320.0,324.0,324.0
2006-10-13,309.0,309.0,309.0,316.0,316.0
2006-10-12,305.0,308.0,308.0,310.0,310.0
2006-10-11,299.0,299.0,301.0,305.0,309.0
2006-10-10,304.0,308.0,308.0,299.0,299.0
2006-10-09,302.0,302.0,304.0,304.0,304.0
2006-10-06,301.0,301.0,301.0,297.0,297.0
2006-10-05,297.0,299.0,299.0,301.0,301.0
2006-10-04,300.0,298.0,298.0,302.0,297.0
2006-10-03,315.0,315.0,314.0,305.0,305.0
2006-10-02,,,322.0,315.0,315.0
2006-09-29,321.0,323.0,323.0,318.0,318.0
2006-09-28,320.0,323.0,323.0,323.0,323.0
2006-09-27,318.0,318.0,320.0,317.0,320.0
2006-09-26,318.0,318.0,319.0,318.0,318.0
2006-09-25,319.0,318.0,319.0,316.0,316.0
2006-09-22,310.0,310.0,313.0,325.0,322.0
2006-09-21,308.0,308.0,308.0,309.0,309.0
2006-09-20,307.0,307.0,308.0,311.0,311.0
2006-09-19,317.0,316.0,316.0,319.0,310.0
2006-09-18,313.0,313.0,313.0,306.0,312.0
2006-09-15,311.0,311.0,314.0,315.0,315.0
2006-09-14,317.0,317.0,317.0,332.0,326.0
2006-09-13,310.0,310.0,310.0,321.0,318.0
2006-09-12,311.0,323.0,322.0,320.0,314.0
2006-09-11,330.0,322.0,321.0,317.0,317.0
2006-09-08,347.0,345.0,345.0,323.0,330.0
2006-09-07,350.0,350.0,353.0,348.0,348.0
2006-09-06,351.0,351.0,351.0,351.0,356.0
2006-09-05,347.0,347.0,347.0,351.0,351.0
2006-09-04,346.0,346.0,347.0,346.0,
2006-09-01,348.0,345.0,346.0,346.0,346.0
2006-08-31,340.0,340.0,342.0,343.0,343.0
2006-08-30,339.0,341.0,340.0,339.0,340.0
2006-08-29,341.0,343.0,342.0,338.0,340.0
2006-08-28,345.0,345.0,345.0,345.0,345.0
2006-08-25,345.0,345.0,345.0,346.0,346.0
2006-08-24,345.0,345.0,347.0,348.0,348.0
2006-08-23,340.0,340.0,340.0,345.0,345.0
2006-08-22,347.0,347.0,346.0,340.0,340.0
2006-08-21,335.0,338.0,338.0,341.0,347.0
2006-08-18,332.0,334.0,333.0,335.0,335.0
2006-08-17,333.0,337.0,338.0,341.0,337.0
2006-08-16,326.0,325.0,324.0,334.0,337.0
2006-08-15,317.0,320.0,319.0,322.0,327.0
2006-08-14,320.0,320.0,320.0,314.0,319.0
2006-08-11,320.0,320.0,322.0,324.0,324.0
2006-08-10,326.0,326.0,327.0,326.0,324.0
2006-08-09,320.0,320.0,320.0,324.0,327.0
2006-08-08,327.0,325.0,324.0,320.0,320.0
2006-08-07,327.0,327.0,328.0,324.0,324.0
2006-08-04,324.0,324.0,324.0,327.0,327.0
2006-08-03,330.0,326.0,327.0,324.0,324.0
2006-08-02,319.0,319.0,322.0,325.0,330.0
2006-08-01,316.0,316.0,316.0,319.0,319.0
2006-07-31,315.0,315.0,317.0,313.0,316.0
2006-07-28,320.0,318.0,318.0,315.0,315.0
2006-07-27,315.0,315.0,318.0,320.0,320.0
2006-07-26,315.0,315.0,315.0,315.0,315.0
2006-07-25,314.0,314.0,315.0,314.0,317.0
2006-07-24,309.0,309.0,309.0,309.0,314.0
2006-07-21,308.0,311.0,310.0,310.0,310.0
2006-07-20,317.0,315.0,316.0,315.0,315.0
2006-07-19,308.0,308.0,311.0,311.0,318.0
2006-07-18,320.0,320.0,319.0,318.0,316.0
2006-07-17,333.0,333.0,333.0,321.0,321.0
2006-07-14,331.0,331.0,331.0,331.0,331.0
2006-07-13,330.0,328.0,328.0,331.0,331.0
2006-07-12,330.0,330.0,330.0,330.0,330.0
2006-07-11,318.0,320.0,323.0,326.0,330.0
2006-07-10,325.0,323.0,323.0,320.0,320.0
2006-07-07,329.0,329.0,329.0,327.0,327.0
2006-07-06,328.0,324.0,326.0,323.0,329.0
2006-07-05,328.0,328.0,330.0,328.0,328.0
2006-07-04,325.0,328.0,327.0,326.0,
2006-07-03,322.0,326.0,326.0,329.0,
2006-06-30,320.0,320.0,320.0,316.0,322.0
2006-06-29,309.0,309.0,307.0,314.0,314.0
2006-06-28,310.0,310.0,313.0,314.0,314.0
2006-06-27,318.0,320.0,320.0,318.0,318.0
2006-06-26,308.0,305.0,309.0,320.0,320.0
2006-06-23,310.0,304.0,305.0,306.0,310.0
2006-06-22,315.0,318.0,320.0,320.0,316.0
2006-06-21,303.0,306.0,308.0,311.0,315.0
2006-06-20,292.0,297.0,296.0,301.0,305.0
2006-06-19,307.0,304.0,303.0,302.0,297.0
2006-06-16,300.0,306.0,305.0,310.0,307.0
2006-06-15,290.0,290.0,292.0,300.0,300.0
2006-06-14,277.0,274.0,275.0,288.0,293.0
2006-06-13,313.0,308.0,307.0,286.0,277.0
2006-06-12,320.0,320.0,316.0,321.0,316.0
2006-06-09,317.0,313.0,313.0,327.0,327.0
2006-06-08,342.0,336.0,333.0,331.0,320.0
2006-06-07,348.0,346.0,346.0,335.0,343.0
2006-06-06,359.0,359.0,359.0,350.0,350.0
2006-06-05,356.0,356.0,358.0,363.0,363.0
2006-06-02,340.0,343.0,342.0,351.0,356.0
2006-06-01,347.0,345.0,345.0,340.0,340.0
2006-05-31,,,358.0,358.0,345.0
2006-05-30,352.0,350.0,355.0,359.0,358.0
2006-05-29,357.0,352.0,350.0,,
2006-05-26,355.0,353.0,354.0,354.0,354.0
2006-05-25,348.0,348.0,350.0,350.0,350.0
2006-05-24,358.0,362.0,365.0,352.0,352.0
2006-05-23,343.0,342.0,343.0,355.0,362.0
2006-05-22,350.0,345.0,345.0,340.0,340.0
2006-05-19,366.0,369.0,373.0,347.0,352.0
2006-05-18,372.0,375.0,376.0,380.0,375.0
2006-05-17,379.0,379.0,382.0,390.0,380.0
2006-05-16,368.0,370.0,366.0,379.0,379.0
2006-05-15,395.0,395.0,397.0,370.0,375.0
2006-05-12,400.0,396.0,398.0,407.0,399.0
2006-05-11,390.0,397.0,395.0,400.0,400.0
2006-05-10,394.0,397.0,398.0,390.0,390.0
2006-05-09,375.0,375.0,378.0,384.0,394.0
2006-05-08,380.0,380.0,381.0,377.0,375.0
2006-05-05,,,383.0,382.0,382.0
2006-05-04,379.0,379.0,378.0,379.0,379.0
2006-05-03,386.0,386.0,388.0,384.0,379.0
2006-05-02,377.0,377.0,380.0,380.0,384.0
2006-05-01,,,,380.0,380.0
2006-04-28,360.0,363.0,363.0,364.0,377.0
2006-04-27,368.0,365.0,367.0,364.0,364.0
2006-04-26,366.0,366.0,367.0,361.0,368.0
2006-04-25,356.0,355.0,355.0,362.0,362.0
2006-04-24,359.0,359.0,363.0,360.0,360.0
2006-04-21,344.0,348.0,347.0,352.0,359.0
2006-04-20,368.0,372.0,374.0,365.0,349.0
2006-04-19,366.0,364.0,364.0,371.0,374.0
2006-04-18,364.0,360.0,360.0,361.0,361.0
2006-04-17,,,,358.0,358.0
2006-04-13,347.0,342.0,341.0,346.0,349.0
2006-04-12,340.0,344.0,343.0,347.0,347.0
2006-04-11,359.0,359.0,360.0,359.0,345.0
2006-04-10,351.0,354.0,355.0,359.0,359.0
2006-04-07,352.0,352.0,354.0,351.0,351.0
2006-04-06,341.0,341.0,344.0,352.0,352.0
2006-04-05,,,336.0,341.0,341.0
2006-04-04,342.0,339.0,337.0,338.0,342.0
2006-04-03,332.0,337.0,338.0,341.0,345.0
2006-03-31,349.0,349.0,348.0,332.0,332.0
2006-03-30,338.0,341.0,343.0,349.0,349.0
2006-03-29,340.0,337.0,337.0,333.0,338.0
2006-03-28,340.0,344.0,345.0,340.0,340.0
2006-03-27,333.0,333.0,334.0,341.0,341.0
2006-03-24,321.0,321.0,320.0,326.0,333.0
2006-03-23,323.0,321.0,321.0,317.0,322.0
2006-03-22,317.0,318.0,322.0,320.0,324.0
2006-03-21,320.0,318.0,316.0,315.0,318.0
2006-03-20,318.0,318.0,319.0,317.0,317.0
2006-03-17,316.0,316.0,315.0,318.0,318.0
2006-03-16,315.0,314.0,314.0,316.0,316.0
2006-03-15,305.0,305.0,307.0,318.0,318.0
2006-03-14,300.0,300.0,300.0,302.0,306.0
2006-03-13,288.0,291.0,290.0,292.0,300.0
2006-03-10,289.0,289.0,289.0,288.0,288.0
2006-03-09,280.0,282.0,282.0,285.0,285.0
2006-03-08,291.0,289.0,289.0,285.0,282.0
2006-03-07,296.0,296.0,296.0,299.0,292.0
2006-03-06,307.0,304.0,302.0,302.0,297.0
2006-03-03,300.0,300.0,300.0,305.0,305.0
2006-03-02,297.0,297.0,296.0,294.0,300.0
2006-03-01,291.0,291.0,289.0,290.0,297.0
2006-02-28,284.0,284.0,285.0,288.0,291.0
2006-02-27,286.0,290.0,290.0,285.0,284.0
2006-02-24,283.0,285.0,286.0,286.0,286.0
2006-02-23,289.0,286.0,287.0,288.0,286.0
2006-02-22,293.0,293.0,293.0,292.0,289.0
2006-02-21,292.0,290.0,291.0,291.0,293.0
2006-02-20,292.0,292.0,292.0,292.0,292.0
2006-02-17,279.0,279.0,280.0,285.0,290.0
2006-02-16,276.0,276.0,278.0,275.0,279.0
2006-02-15,282.0,285.0,287.0,285.0,279.0
2006-02-14,273.0,270.0,274.0,278.0,282.0
2006-02-13,283.0,278.0,277.0,282.0,276.0
2006-02-10,304.0,298.0,297.0,296.0,285.0
2006-02-09,293.0,297.0,295.0,300.0,300.0
2006-02-08,288.0,288.0,287.0,290.0,290.0
2006-02-07,309.0,309.0,309.0,297.0,290.0
2006-02-06,317.0,317.0,320.0,305.0,312.0
2006-02-03,309.0,310.0,310.0,317.0,317.0
2006-02-02,294.0,296.0,295.0,300.0,305.0
2006-02-01,294.0,293.0,293.0,294.0,294.0
2006-01-31,,,282.0,293.0,295.0
2006-01-30,,,277.0,278.0,278.0
2006-01-27,275.0,275.0,276.0,275.0,275.0
2006-01-26,279.0,279.0,280.0,275.0,275.0
2006-01-25,275.0,275.0,275.0,279.0,279.0
2006-01-24,278.0,278.0,278.0,276.0,276.0
2006-01-23,276.0,278.0,277.0,278.0,278.0
2006-01-20,279.0,278.0,277.0,280.0,277.0
2006-01-19,273.0,275.0,275.0,273.0,277.0
2006-01-18,282.0,276.0,275.0,273.0,273.0
2006-01-17,289.0,286.0,286.0,281.0,283.0
2006-01-16,283.0,285.0,285.0,289.0,289.0
2006-01-13,273.0,273.0,273.0,275.0,281.0
2006-01-12,274.0,274.0,274.0,273.0,273.0
2006-01-11,274.0,274.0,274.0,271.0,274.0
2006-01-10,279.0,278.0,278.0,277.0,274.0
2006-01-09,272.0,272.0,274.0,275.0,278.0
2006-01-06,264.0,265.0,262.0,269.0,272.0
2006-01-05,274.0,274.0,272.0,263.0,263.0
2006-01-04,272.0,272.0,272.0,272.0,274.0
2006-01-03,260.0,262.0,262.0,267.0,267.0
""".strip()
# CSV fixture with a `date,symbol` header: aapl, ibm and msft on 1/9/2006,
# then the same three symbols plus yhoo on 1/11/2006.
FETCHER_UNIVERSE_DATA = """
date,symbol
1/9/2006,aapl
1/9/2006,ibm
1/9/2006,msft
1/11/2006,aapl
1/11/2006,ibm
1/11/2006,msft
1/11/2006,yhoo
""".strip()
# Same `date,symbol` layout, but with made-up symbol strings (presumably
# ones that resolve to no asset, going by the constant's name -- confirm).
# NOTE(review): the 1/11/2006 rows list foobarbaz twice.
NON_ASSET_FETCHER_UNIVERSE_DATA = """
date,symbol
1/9/2006,foobarbaz
1/9/2006,bazfoobar
1/9/2006,barbazfoo
1/11/2006,foobarbaz
1/11/2006,bazfoobar
1/11/2006,barbazfoo
1/11/2006,foobarbaz
""".strip()
# Replacement column name substituted for "symbol" below.
FETCHER_ALTERNATE_COLUMN_HEADER = "ARGLEBARGLE"
# FETCHER_UNIVERSE_DATA with every occurrence of the text "symbol" replaced
# by the alternate header (only the header row contains that text).
FETCHER_UNIVERSE_DATA_TICKER_COLUMN = FETCHER_UNIVERSE_DATA.replace(
"symbol", FETCHER_ALTERNATE_COLUMN_HEADER)
@@ -1,21 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_stateversion_'
p3
I1
sS'new_orders'
p4
(lp5
sS'orders'
p6
(dp7
sS'open_orders'
p8
(dp9
ssS'initargs'
p10
NsS'newargs'
p11
Ns.
@@ -1,60 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'direction'
p3
F1.0
sS'_stateversion_'
p4
I1
sS'_status'
p5
I0
sS'created'
p6
cdatetime
datetime
p7
(S'\x07\xdd\x06\x13\x00\x00\x00\x00\x00\x00'
p8
tp9
Rp10
sS'limit_reached'
p11
I00
sS'stop'
p12
NsS'reason'
p13
NsS'stop_reached'
p14
I00
sS'commission'
p15
NsS'amount'
p16
I100
sS'limit'
p17
NsS'sid'
p18
I8554
sS'dt'
p19
g10
sS'type'
p20
I6
sS'id'
p21
S'e837d6193375414eb1594c8adb068a34'
p22
sS'filled'
p23
I0
ssS'initargs'
p24
NsS'newargs'
p25
Ns.
@@ -1,15 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'cost'
p3
F0.0015
sS'_stateversion_'
p4
I1
ssS'initargs'
p5
NsS'newargs'
p6
Ns.
@@ -1,17 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'min_trade_cost'
p3
NsS'cost'
p4
F0.03
sS'_stateversion_'
p5
I1
ssS'initargs'
p6
NsS'newargs'
p7
Ns.
@@ -1,15 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'cost'
p3
F5.0
sS'_stateversion_'
p4
I1
ssS'initargs'
p5
NsS'newargs'
p6
Ns.
@@ -1,194 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_account_store'
p3
ccopy_reg
_reconstructor
p4
(czipline.protocol
Account
p5
c__builtin__
object
p6
Ntp7
Rp8
(dp9
S'regt_margin'
p10
Finf
sS'maintenance_margin_requirement'
p11
F0.0
sS'day_trades_remaining'
p12
Finf
sS'buying_power'
p13
Finf
sS'net_leverage'
p14
F0.0
sS'settled_cash'
p15
F0.0
sS'cushion'
p16
F0.0
sS'_stateversion_'
p17
I1
sS'leverage'
p18
F0.0
sS'regt_equity'
p19
F0.0
sS'excess_liquidity'
p20
F0.0
sS'available_funds'
p21
F0.0
sS'equity_with_loan'
p22
F0.0
sS'initial_margin_requirement'
p23
F0.0
sS'net_liquidation'
p24
F0.0
sS'total_positions_value'
p25
F0.0
sS'accrued_interest'
p26
F0.0
sbsS'orders_by_modified'
p27
(dp28
sS'keep_transactions'
p29
I01
sS'ending_cash'
p30
F10000.0
sS'_positions_store'
p31
(dp32
sS'positions'
p33
(dp34
sS'processed_transactions'
p35
(dp36
sS'ending_value'
p37
cnumpy.core.multiarray
scalar
p38
(cnumpy
dtype
p39
(S'f8'
p40
I0
I1
tp41
Rp42
(I3
S'<'
p43
NNNI-1
I-1
I0
tp44
bS'\x00\x00\x00\x00\x00\x00\x00\x00'
p45
tp46
Rp47
sS'loc_map'
p48
(dp49
sS'starting_cash'
p50
I10000
sS'returns'
p51
g38
(g42
S'\x00\x00\x00\x00\x00\x00\x00\x00'
p52
tp53
Rp54
sg17
I1
sS'pnl'
p55
g38
(g42
S'\x00\x00\x00\x00\x00\x00\x00\x00'
p56
tp57
Rp58
sS'period_cash_flow'
p59
F0.0
sS'serialize_positions'
p60
I01
sS'keep_orders'
p61
I00
sS'_portfolio_store'
p62
g4
(czipline.protocol
Portfolio
p63
g6
Ntp64
Rp65
(dp66
g17
I1
sS'portfolio_value'
p67
F0.0
sS'cash'
p68
F0.0
sg50
F0.0
sg51
F0.0
sS'capital_used'
p69
F0.0
sg55
F0.0
sg33
(dp70
sS'positions_value'
p71
F0.0
sS'start_date'
p72
NsbsS'starting_value'
p73
F0.0
sS'period_open'
p74
NsS'period_close'
p75
NsS'orders_by_id'
p76
(dp77
ssS'initargs'
p78
NsS'newargs'
p79
Ns.
@@ -1,152 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_account_store'
p3
ccopy_reg
_reconstructor
p4
(czipline.protocol
Account
p5
c__builtin__
object
p6
Ntp7
Rp8
(dp9
S'regt_margin'
p10
Finf
sS'maintenance_margin_requirement'
p11
F0.0
sS'day_trades_remaining'
p12
Finf
sS'buying_power'
p13
Finf
sS'net_leverage'
p14
F0.0
sS'settled_cash'
p15
F0.0
sS'cushion'
p16
F0.0
sS'_stateversion_'
p17
I1
sS'leverage'
p18
F0.0
sS'regt_equity'
p19
F0.0
sS'excess_liquidity'
p20
F0.0
sS'available_funds'
p21
F0.0
sS'equity_with_loan'
p22
F0.0
sS'initial_margin_requirement'
p23
F0.0
sS'net_liquidation'
p24
F0.0
sS'total_positions_value'
p25
F0.0
sS'accrued_interest'
p26
F0.0
sbsS'orders_by_modified'
p27
(dp28
sS'keep_transactions'
p29
I01
sS'ending_cash'
p30
I10000
sS'processed_transactions'
p31
(dp32
sS'ending_value'
p33
F0.0
sS'starting_cash'
p34
I10000
sg17
I2
sS'pnl'
p35
F0.0
sS'period_cash_flow'
p36
F0.0
sS'serialize_positions'
p37
I01
sS'keep_orders'
p38
I00
sS'_portfolio_store'
p39
g4
(czipline.protocol
Portfolio
p40
g6
Ntp41
Rp42
(dp43
g17
I1
sS'portfolio_value'
p44
F0.0
sS'cash'
p45
F0.0
sg34
F0.0
sS'returns'
p46
F0.0
sS'capital_used'
p47
F0.0
sg35
F0.0
sS'positions'
p48
(dp49
sS'positions_value'
p50
F0.0
sS'start_date'
p51
NsbsS'starting_value'
p52
F0.0
sS'period_open'
p53
NsS'period_close'
p54
NsS'orders_by_id'
p55
(dp56
ssS'initargs'
p57
NsS'newargs'
p58
Ns.
@@ -1,26 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_stateversion_'
p3
I1
sS'cost_basis'
p4
F0.0
sS'amount'
p5
I0
sS'last_sale_price'
p6
F0.0
sS'sid'
p7
I8554
sS'last_sale_date'
p8
NssS'initargs'
p9
NsS'newargs'
p10
Ns.
@@ -1,136 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'positions'
p3
(dp4
sS'unpaid_dividends'
p5
ccopy_reg
_reconstructor
p6
(cpandas.core.frame
DataFrame
p7
c__builtin__
object
p8
Ntp9
Rp10
g6
(cpandas.core.internals
BlockManager
p11
g8
Ntp12
Rp13
((lp14
cnumpy.core.multiarray
_reconstruct
p15
(cpandas.core.index
Index
p16
(I0
tp17
S'b'
p18
tp19
Rp20
((I1
(I4
tp21
cnumpy
dtype
p22
(S'O8'
p23
I0
I1
tp24
Rp25
(I3
S'|'
p26
NNNI-1
I-1
I63
tp27
bI00
(lp28
S'id'
p29
aS'payment_sid'
p30
aS'cash_amount'
p31
aS'share_count'
p32
atp33
(Ntp34
tp35
bag15
(g16
(I0
tp36
g18
tp37
Rp38
((I1
(I0
tp39
g25
I00
(lp40
tp41
(Ntp42
tp43
ba(lp44
g15
(cnumpy
ndarray
p45
(I0
tp46
g18
tp47
Rp48
(I1
(I4
I0
tp49
g25
I00
(lp50
tp51
ba(lp52
g15
(g16
(I0
tp53
g18
tp54
Rp55
((I1
(I4
tp56
g25
I00
(lp57
g29
ag30
ag31
ag32
atp58
(Ntp59
tp60
batp61
bbsS'_stateversion_'
p62
I1
ssS'initargs'
p63
NsS'newargs'
p64
Ns.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,15 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'spread'
p3
F0.0
sS'_stateversion_'
p4
I1
ssS'initargs'
p5
NsS'newargs'
p6
Ns.
@@ -1,39 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'commission'
p3
NsS'amount'
p4
I10
sS'_stateversion_'
p5
I1
sS'sid'
p6
I8554
sS'order_id'
p7
S'0000'
p8
sS'price'
p9
I100
sS'type'
p10
I5
sS'dt'
p11
cdatetime
datetime
p12
(S'\x07\xdd\x06\x13\x00\x00\x00\x00\x00\x00'
p13
tp14
Rp15
ssS'initargs'
p16
NsS'newargs'
p17
Ns.
@@ -1,18 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'price_impact'
p3
F0.1
sS'volume_limit'
p4
F0.25
sS'_stateversion_'
p5
I1
ssS'initargs'
p6
NsS'newargs'
p7
Ns.
@@ -1,60 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'regt_margin'
p3
Finf
sS'maintenance_margin_requirement'
p4
F0.0
sS'day_trades_remaining'
p5
Finf
sS'buying_power'
p6
Finf
sS'net_leverage'
p7
F0.0
sS'settled_cash'
p8
F0.0
sS'cushion'
p9
F0.0
sS'_stateversion_'
p10
I1
sS'leverage'
p11
F0.0
sS'regt_equity'
p12
F0.0
sS'excess_liquidity'
p13
F0.0
sS'available_funds'
p14
F0.0
sS'equity_with_loan'
p15
F0.0
sS'initial_margin_requirement'
p16
F0.0
sS'net_liquidation'
p17
F0.0
sS'total_positions_value'
p18
F0.0
sS'accrued_interest'
p19
F0.0
ssS'initargs'
p20
NsS'newargs'
p21
Ns.
@@ -1,38 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_stateversion_'
p3
I1
sS'portfolio_value'
p4
F0.0
sS'cash'
p5
F0.0
sS'starting_cash'
p6
F0.0
sS'returns'
p7
F0.0
sS'capital_used'
p8
F0.0
sS'pnl'
p9
F0.0
sS'positions'
p10
(dp11
sS'positions_value'
p12
F0.0
sS'start_date'
p13
NssS'initargs'
p14
NsS'newargs'
p15
Ns.
@@ -1,24 +0,0 @@
(dp0
S'obj_state'
p1
(dp2
S'_stateversion_'
p3
I1
sS'amount'
p4
I0
sS'last_sale_price'
p5
F0.0
sS'cost_basis'
p6
F0.0
sS'sid'
p7
I8554
ssS'initargs'
p8
NsS'newargs'
p9
Ns.
-122
View File
@@ -1,122 +0,0 @@
import datetime
import pytz
import nose.tools as nt
import pandas.util.testing as tm
import pandas as pd
from zipline.finance.blotter import Blotter, Order
from zipline.finance.commission import PerShare, PerTrade, PerDollar
from zipline.finance.performance.period import PerformancePeriod
from zipline.finance.performance.position import Position
from zipline.finance.performance.tracker import PerformanceTracker
from zipline.finance.performance.position_tracker import PositionTracker
from zipline.finance.risk.cumulative import RiskMetricsCumulative
from zipline.finance.risk.period import RiskMetricsPeriod
from zipline.finance.risk.report import RiskReport
from zipline.finance.slippage import (
FixedSlippage,
VolumeShareSlippage
)
from zipline.finance.transaction import Transaction
from zipline.protocol import Account
from zipline.protocol import Portfolio
from zipline.protocol import Position as ProtocolPosition
from zipline.finance.trading import SimulationParameters, TradingEnvironment
from zipline.utils import factory
def stringify_cases(cases, func=None):
    """Prefix each case with a readable key to get better test case names.

    Parameters
    ----------
    cases : iterable of sequences
        The test cases; each one is copied into a new list.
    func : callable, optional
        Maps a case to its key.  When omitted, the key is the
        ``__name__`` of the case's first element.

    Returns
    -------
    list of list
        One ``[key, *case]`` list per input case, in input order.
    """
    if func is None:
        def func(case):
            return case[0].__name__

    def keyed(case):
        # Copy the case before computing its key, then put the key first.
        row = list(case)
        return [func(case)] + row

    return [keyed(case) for case in cases]
# Trading environment handed to every simulation-parameter object and
# serialization case defined in this module.
cases_env = TradingEnvironment()
# Simulation parameters whose start and end are both 2013-06-19 (UTC),
# emitting at the 'daily' rate.  The positional 10000 is presumably the
# starting capital -- confirm against SimulationParameters' signature.
sim_params_daily = SimulationParameters(
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
10000,
emission_rate='daily',
env=cases_env)
# Same single-day window, but emitting at the 'minute' rate.
sim_params_minute = SimulationParameters(
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
datetime.datetime(2013, 6, 19, tzinfo=pytz.UTC),
10000,
emission_rate='minute',
env=cases_env)
# Returns object built from the single value 1.0 using the daily parameters;
# the risk-metric cases index into it via `returns.index[0]`.
returns = factory.create_returns_from_list(
[1.0], sim_params_daily)
def object_serialization_cases(skip_daily=False):
    """Return the object serialization cases, each prefixed with a name.

    Every case is a ``(cls, args, kwargs, tag)`` tuple, where ``tag`` is
    one of the strings 'repr', 'dict' or 'to_dict'.  The result is passed
    through ``stringify_cases``, which prepends the class name.

    Parameters
    ----------
    skip_daily : bool, optional
        When true, the three cases built from ``sim_params_daily`` are
        left out.
    """
    # Built inside the function so that injected objects (such as the
    # PositionTracker below) are recreated on every call.
    cases = [
        (Blotter, (), {}, 'repr'),
        (Order, (datetime.datetime(2013, 6, 19), 8554, 100), {}, 'dict'),
        (PerShare, (), {}, 'dict'),
        (PerTrade, (), {}, 'dict'),
        (PerDollar, (), {}, 'dict'),
        (
            PerformancePeriod,
            (10000, cases_env.asset_finder),
            {'position_tracker': PositionTracker(cases_env.asset_finder)},
            'to_dict',
        ),
        (Position, (8554,), {}, 'dict'),
        (PositionTracker, (cases_env.asset_finder,), {}, 'dict'),
        (PerformanceTracker, (sim_params_minute, cases_env), {}, 'to_dict'),
        (RiskMetricsCumulative, (sim_params_minute, cases_env), {}, 'to_dict'),
        (
            RiskMetricsPeriod,
            (returns.index[0], returns.index[0], returns, cases_env),
            {},
            'to_dict',
        ),
        (RiskReport, (returns, sim_params_minute, cases_env), {}, 'to_dict'),
        (FixedSlippage, (), {}, 'dict'),
        (
            Transaction,
            (8554, 10, datetime.datetime(2013, 6, 19), 100, "0000"),
            {},
            'dict',
        ),
        (VolumeShareSlippage, (), {}, 'dict'),
        (Account, (), {}, 'dict'),
        (Portfolio, (), {}, 'dict'),
        (ProtocolPosition, (8554,), {}, 'dict'),
    ]

    if skip_daily:
        return stringify_cases(cases)

    # Daily-emission variants of the tracker and risk cases.
    cases += [
        (PerformanceTracker, (sim_params_daily, cases_env), {}, 'to_dict'),
        (RiskMetricsCumulative, (sim_params_daily, cases_env), {}, 'to_dict'),
        (RiskReport, (returns, sim_params_daily, cases_env), {}, 'to_dict'),
    ]
    return stringify_cases(cases)
def assert_dict_equal(d1, d2):
    """Assert that two dicts have the same keys and pairwise-equal values.

    ``pd.DataFrame`` values are compared with ``tm.assert_frame_equal``,
    ``pd.Series`` values with ``tm.assert_series_equal`` and everything
    else with ``nt.assert_equal``.  The comparison used is chosen from the
    type of the value in ``d1``.

    Raises
    ------
    AssertionError
        If either argument is not a dict, if the key sets differ, or if
        the values under some key differ.  In the last case the message
        names the key and carries the underlying comparison's message.
    """
    # check keys
    nt.assert_is_instance(d1, dict)
    nt.assert_is_instance(d2, dict)
    nt.assert_set_equal(set(d1.keys()), set(d2.keys()))

    for k in d1:
        v1 = d1[k]
        v2 = d2[k]

        if isinstance(v1, pd.Series):
            asserter = tm.assert_series_equal
        elif isinstance(v1, pd.DataFrame):
            asserter = tm.assert_frame_equal
        else:
            asserter = nt.assert_equal

        try:
            asserter(v1, v2)
        except AssertionError as err:
            # Keep the original diagnostic: the key alone does not say
            # how the two values differ.
            raise AssertionError(
                '{k} is not equal: {err}'.format(k=k, err=err)
            )
+1553 -772
View File
File diff suppressed because it is too large Load Diff
-247
View File
@@ -1,247 +0,0 @@
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from nose.tools import (
timed,
nottest
)
from datetime import datetime
import pandas as pd
import pytz
from zipline.finance import trading
from zipline.algorithm import TradingAlgorithm
from zipline.finance import slippage
from zipline.protocol import (
Event,
DATASOURCE_TYPE,
)
from zipline.testing import (
setup_logger,
teardown_logger
)
from zipline.utils import factory
from zipline.utils.factory import create_simulation_parameters
# Time limit handed to the @timed decorator on the tests below.
DEFAULT_TIMEOUT = 15 # seconds
# Longer limit; no test visible in this part of the module references it
# (presumably also in seconds -- confirm before use).
EXTENDED_TIMEOUT = 90
class RecordDateSlippage(slippage.FixedSlippage):
    """FixedSlippage that remembers the ``dt`` of the last simulated event."""

    def __init__(self, spread):
        super(RecordDateSlippage, self).__init__(spread=spread)
        # Nothing has been simulated yet.
        self.latest_date = None

    def simulate(self, event, open_orders):
        # Record the event's timestamp, then defer to the parent model.
        self.latest_date = event.dt
        return super(RecordDateSlippage, self).simulate(event, open_orders)
class TestAlgo(TradingAlgorithm):
    """Algorithm that orders once, then checks algo vs. slippage dates.

    ``asserter`` must provide ``assertGreaterEqual``; the tests in this
    module pass the running TestCase.
    """

    def __init__(self, asserter, *args, **kwargs):
        super(TestAlgo, self).__init__(*args, **kwargs)
        self.asserter = asserter

    def initialize(self, window_length=100):
        self.latest_date = None
        self.set_slippage(RecordDateSlippage(spread=0.05))
        self.stocks = [self.sid(8229)]
        self.ordered = False
        self.num_bars = 0

    def handle_data(self, data):
        self.num_bars += 1
        self.latest_date = self.get_datetime()

        if self.ordered:
            # After the first bar, the algorithm's clock must never be
            # behind the date last seen by the slippage model.
            self.asserter.assertGreaterEqual(
                self.latest_date,
                self.slippage.latest_date
            )
            return

        # First bar: place one order of 100 per tracked stock.
        for stock in self.stocks:
            self.order(stock, 100)
        self.ordered = True
class AlgorithmGeneratorTestCase(TestCase):
    """Checks on the generator pipeline that drives a TradingAlgorithm."""

    @classmethod
    def setUpClass(cls):
        # One environment, with equity 8229 written to it, is shared by
        # every test in the class.
        cls.env = trading.TradingEnvironment()
        cls.env.write_data(equities_identifiers=[8229])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)

    def tearDown(self):
        teardown_logger(self)

    @nottest
    def test_lse_algorithm(self):
        # Excluded from collection by @nottest.
        lse = trading.TradingEnvironment(
            bm_symbol='^FTSE',
            exchange_tz='Europe/London'
        )

        with lse:
            sim_params = factory.create_simulation_parameters(
                start=datetime(2012, 5, 1, tzinfo=pytz.utc),
                end=datetime(2012, 6, 30, tzinfo=pytz.utc)
            )
            algo = TestAlgo(self, identifiers=[8229], sim_params=sim_params)
            # This call appears inconsistent with
            # the signature of create_daily_trade_source
            trade_source = factory.create_daily_trade_source(
                [8229],
                200,
                sim_params
            )
            algo.set_sources([trade_source])

            results = list(algo.get_generator())
            self.assertEqual(len(results), 42)
            # May 7, 2012 was an LSE holiday, so the entry at index 4
            # is expected to open on May 8.
            self.assertEqual(
                results[4]['daily_perf']['period_open'],
                datetime(2012, 5, 8, 8, 31, tzinfo=pytz.utc)
            )

    @timed(DEFAULT_TIMEOUT)
    def test_generator_dates(self):
        """
        Ensure the pipeline of generators are in sync, at least as far as
        their current dates.
        """
        sim_params = factory.create_simulation_parameters(
            start=datetime(2011, 7, 30, tzinfo=pytz.utc),
            end=datetime(2012, 7, 30, tzinfo=pytz.utc),
            env=self.env,
        )
        algo = TestAlgo(self, sim_params=sim_params, env=self.env)
        algo.set_sources([
            factory.create_daily_trade_source(
                [8229],
                sim_params,
                env=self.env,
            ),
        ])

        emitted = list(algo.get_generator())
        self.assertTrue(emitted)

        # Both the slippage model and the algorithm must have seen a date.
        self.assertTrue(algo.slippage.latest_date)
        self.assertTrue(algo.latest_date)

    @timed(DEFAULT_TIMEOUT)
    def test_handle_data_on_market(self):
        """
        Ensure that handle_data is only called on market minutes.

        i.e. events that come in at midnight should be processed at market
        open.
        """
        sim_params = trading.SimulationParameters(
            period_start=datetime(2012, 7, 30, tzinfo=pytz.utc),
            period_end=datetime(2012, 7, 30, tzinfo=pytz.utc),
            data_frequency='minute',
            env=self.env,
        )
        algo = TestAlgo(self, sim_params=sim_params, env=self.env)

        # A custom (non-trade) event stamped at midnight UTC.
        midnight_event = Event({
            'custom_field': 42.0,
            'sid': 'custom_data',
            'source_id': 'TestMidnightSource',
            'dt': pd.Timestamp('2012-07-30', tz='UTC'),
            'type': DATASOURCE_TYPE.CUSTOM
        })
        # A trade event stamped 9:31 AM US/Eastern, converted to UTC.
        trade_dt = pd.Timestamp('2012-07-30 9:31 AM', tz='US/Eastern')
        minute_event = Event({
            'volume': 100,
            'price': 200.0,
            'high': 210.0,
            'open_price': 190.0,
            'low': 180.0,
            'sid': 8229,
            'source_id': 'TestMinuteEventSource',
            'dt': trade_dt.tz_convert('UTC'),
            'type': DATASOURCE_TYPE.TRADE
        })

        algo.set_sources([[midnight_event], [minute_event]])

        # Consume the generator
        for _ in algo.get_generator():
            pass

        # Though the events had different time stamps, handle data should
        # have only been called once, at the market open.
        self.assertEqual(algo.num_bars, 1)

    @timed(DEFAULT_TIMEOUT)
    def test_progress(self):
        """
        Ensure the second-to-last message emitted by the generator reports
        a progress of 1.0.
        """
        sim_params = factory.create_simulation_parameters(
            start=datetime(2008, 1, 1, tzinfo=pytz.utc),
            end=datetime(2008, 1, 5, tzinfo=pytz.utc),
            env=self.env,
        )
        algo = TestAlgo(self, sim_params=sim_params, env=self.env)
        algo.set_sources([
            factory.create_daily_trade_source(
                [8229],
                sim_params,
                env=self.env,
            ),
        ])

        results = list(algo.get_generator())
        self.assertEqual(results[-2]['progress'], 1.0)

    def test_benchmark_times_match_market_close_for_minutely_data(self):
        """
        Benchmark dates should be adjusted so that benchmark events are
        emitted at the end of each trading day when working with minutely
        data.

        Verification relies on the fact that there are no trades so
        algo.datetime should be equal to the last benchmark time.
        See https://github.com/quantopian/zipline/issues/241
        """
        sim_params = create_simulation_parameters(
            num_days=1,
            data_frequency='minute',
            env=self.env
        )
        algo = TestAlgo(self, sim_params=sim_params, env=self.env)
        algo.run(source=[], overwrite_sim_params=False)

        self.assertEqual(algo.datetime, sim_params.last_close)
+579
View File
@@ -0,0 +1,579 @@
import warnings
from unittest import TestCase
from mock import patch
import pandas as pd
import numpy as np
from testfixtures import TempDirectory
from zipline import TradingAlgorithm
from zipline.data.data_portal import DataPortal
from zipline.data.minute_bars import BcolzMinuteBarWriter, \
US_EQUITIES_MINUTES_PER_DAY, BcolzMinuteBarReader
from zipline.data.us_equity_pricing import BcolzDailyBarReader, \
SQLiteAdjustmentReader, SQLiteAdjustmentWriter
from zipline.finance.trading import TradingEnvironment, SimulationParameters
from zipline.protocol import BarData
from zipline.testing.core import write_minute_data_for_asset, \
create_daily_df_for_asset, DailyBarWriterFromDataFrames, MockDailyBarReader
from zipline.testing import str_to_seconds
from zipline.zipline_warnings import ZiplineDeprecationWarning
# Algorithm script: handle_data asserts that sid(1) and sid(2) are in `data`,
# that len(data) == 3, and that `data` can be iterated.
simple_algo = """
from zipline.api import sid, order
def initialize(context):
pass
def handle_data(context, data):
assert sid(1) in data
assert sid(2) in data
assert len(data) == 3
for asset in data:
pass
"""
# Algorithm script: every handle_data call stores
# history(5, "1m", "volume") on context.history_window.
history_algo = """
from zipline.api import sid, history
def initialize(context):
context.sid1 = sid(1)
def handle_data(context, data):
context.history_window = history(5, "1m", "volume")
"""
# Algorithm script: counts before_trading_start calls and, on the second
# one, records history(5, "1m", "volume") under the name "history".
history_bts_algo = """
from zipline.api import sid, history, record
def initialize(context):
context.sid3 = sid(3)
context.num_bts = 0
def before_trading_start(context, data):
context.num_bts += 1
# Get history at the second BTS (beginning of second day)
if context.num_bts == 2:
record(history=history(5, "1m", "volume"))
def handle_data(context, data):
pass
"""
# Algorithm script: when context.count == 2, stores mavg(5), vwap(5),
# stddev(5) and returns() of data[sid(1)] on the context.
simple_transforms_algo = """
from zipline.api import sid
def initialize(context):
context.count = 0
def handle_data(context, data):
if context.count == 2:
context.mavg = data[sid(1)].mavg(5)
context.vwap = data[sid(1)].vwap(5)
context.stddev = data[sid(1)].stddev(5)
context.returns = data[sid(1)].returns()
context.count += 1
"""
# Algorithm script: asserts that `data` and data.keys() both have length 2
# and that both context assets appear in data.keys().
# NOTE(review): unlike the other scripts, this one calls sid() without
# importing it from zipline.api -- confirm the name is supplied by the
# namespace the script runs in.
manipulation_algo = """
def initialize(context):
context.asset1 = sid(1)
context.asset2 = sid(2)
def handle_data(context, data):
assert len(data) == 2
assert len(data.keys()) == 2
assert context.asset1 in data.keys()
assert context.asset2 in data.keys()
"""
# Algorithm script: asserts that both attribute access (.sid) and item
# access (["sid"]) on data[sid(1)] give context.asset1.
sid_accessor_algo = """
from zipline.api import sid
def initialize(context):
context.asset1 = sid(1)
def handle_data(context,data):
assert data[sid(1)].sid == context.asset1
assert data[sid(1)]["sid"] == context.asset1
"""
# Algorithm script: asserts that list(data.iteritems()) equals data.items().
data_items_algo = """
from zipline.api import sid
def initialize(context):
context.asset1 = sid(1)
context.asset2 = sid(2)
def handle_data(context, data):
iter_list = list(data.iteritems())
items_list = data.items()
assert iter_list == items_list
"""
class TestAPIShim(TestCase):
# Class-level fixture: builds the environment, three equities, minute bar
# data for each of them, an adjustments reader and minute-frequency
# simulation parameters, all stored on the class.
@classmethod
def setUpClass(cls):
cls.env = TradingEnvironment()
cls.tempdir = TempDirectory()
# Trading days between 2016-01-05 and 2016-01-28 (UTC) as reported by the
# environment.
cls.trading_days = cls.env.days_in_range(
start=pd.Timestamp("2016-01-05", tz='UTC'),
end=pd.Timestamp("2016-01-28", tz='UTC')
)
# Metadata for sids 1-3: each starts on the first trading day, ends on the
# trading day after the last one, and is named ASSET<sid>.
equities_data = {}
for sid in [1, 2, 3]:
equities_data[sid] = {
"start_date": cls.trading_days[0],
"end_date": cls.env.next_trading_day(cls.trading_days[-1]),
"symbol": "ASSET{0}".format(sid),
}
cls.env.write_data(equities_data=equities_data)
cls.asset1 = cls.env.asset_finder.retrieve_asset(1)
cls.asset2 = cls.env.asset_finder.retrieve_asset(2)
cls.asset3 = cls.env.asset_finder.retrieve_asset(3)
# Market open/close values restricted to the fixture's trading days.
market_opens = cls.env.open_and_closes.market_open.loc[
cls.trading_days]
market_closes = cls.env.open_and_closes.market_close.loc[
cls.trading_days]
# Minute bars are written under the temporary directory.
minute_writer = BcolzMinuteBarWriter(
cls.trading_days[0],
cls.tempdir.path,
market_opens,
market_closes,
US_EQUITIES_MINUTES_PER_DAY
)
# Write minute data for each sid over the whole trading-day range.
for sid in [1, 2, 3]:
write_minute_data_for_asset(
cls.env, minute_writer, cls.trading_days[0],
cls.trading_days[-1], sid
)
cls.adj_reader = cls.create_adjustments_reader()
# Minute-frequency simulation spanning the first to the last trading day.
cls.sim_params = SimulationParameters(
period_start=cls.trading_days[0],
period_end=cls.trading_days[-1],
data_frequency="minute",
env=cls.env
)
@classmethod
def build_daily_data(cls):
    """
    Write daily bars for sids 1-3 into the temp directory.

    Returns a BcolzDailyBarReader opened on the written file.
    """
    first_day = cls.trading_days[0]
    last_day = cls.trading_days[-1]
    frames = {}
    for asset_id in (1, 2, 3):
        frames[asset_id] = create_daily_df_for_asset(
            cls.env, first_day, last_day
        )

    bcolz_path = cls.tempdir.getpath("testdaily.bcolz")
    DailyBarWriterFromDataFrames(frames).write(
        bcolz_path, cls.trading_days, frames
    )
    return BcolzDailyBarReader(bcolz_path)
@classmethod
def create_adjustments_reader(cls):
    """
    Write a SQLite adjustments file holding a single split for asset3.

    The split has ratio 0.5 and is effective 2016-01-06. Mergers and
    dividends are written as empty, correctly-typed tables. Returns a
    SQLiteAdjustmentReader opened on the written file.
    """
    db_path = cls.tempdir.getpath("test_adjustments.db")
    writer = SQLiteAdjustmentWriter(
        db_path,
        cls.env.trading_days,
        MockDailyBarReader()
    )

    splits = pd.DataFrame([{
        'effective_date': str_to_seconds("2016-01-06"),
        'ratio': 0.5,
        'sid': cls.asset3.sid
    }])

    # Mergers and dividends are not exercised by these tests, but the writer
    # still needs both tables, so build empty frames and cast the columns
    # to the dtypes assigned below.
    mergers = pd.DataFrame({}, columns=['effective_date', 'ratio', 'sid'])
    for column, dtype in (('effective_date', int),
                          ('ratio', float),
                          ('sid', int)):
        mergers[column] = mergers[column].astype(dtype)

    dividends = pd.DataFrame({}, columns=['ex_date', 'record_date',
                                          'declared_date', 'pay_date',
                                          'amount', 'sid'])
    for column, dtype in (('amount', float), ('sid', int)):
        dividends[column] = dividends[column].astype(dtype)

    writer.write(splits, mergers, dividends)
    return SQLiteAdjustmentReader(db_path)
@classmethod
def tearDownClass(cls):
    """Remove the temporary directory created in setUpClass."""
    cls.tempdir.cleanup()
def setUp(self):
    """Build a new DataPortal (and re-write the daily bars) for each test."""
    minute_reader = BcolzMinuteBarReader(self.tempdir.path)
    daily_reader = self.build_daily_data()
    self.data_portal = DataPortal(
        self.env,
        equity_minute_reader=minute_reader,
        equity_daily_reader=daily_reader,
        adjustment_reader=self.adj_reader
    )
@classmethod
def create_algo(cls, code, filename=None, sim_params=None):
    """
    Wrap an algorithm script in a TradingAlgorithm.

    ``sim_params`` falls back to the class-level simulation parameters
    when it is None; ``filename`` is passed through as ``algo_filename``.
    """
    effective_params = cls.sim_params if sim_params is None else sim_params
    return TradingAlgorithm(
        script=code,
        sim_params=effective_params,
        env=cls.env,
        algo_filename=filename
    )
def test_old_new_data_api_paths(self):
    """
    Test that the new and old data APIs hit the same code paths.

    We want to ensure that the old data API(data[sid(N)].field and
    similar) and the new data API(data.current(sid(N), field) and
    similar) hit the same code paths on the DataPortal.
    """
    first_day_minutes = self.env.market_minutes_for_day(
        self.trading_days[0]
    )
    test_start_minute = first_day_minutes[1]
    test_end_minute = first_day_minutes[-1]

    bar_data = BarData(
        self.data_portal,
        lambda: test_end_minute, "minute"
    )

    # Every entry needs its own trailing comma: adjacent string literals are
    # silently concatenated, which previously turned "low" and "close" into
    # the single bogus field "lowclose".
    ohlcvp_fields = [
        "open",
        "high",
        "low",
        "close",
        "volume",
        "price",
    ]

    spot_value_meth = 'zipline.data.data_portal.DataPortal.get_spot_value'

    def assert_get_spot_value_called(fun, field):
        """
        Assert that get_spot_value was called during the execution of fun.

        Takes in a function fun and a string field.
        """
        with patch(spot_value_meth) as gsv:
            fun()
            gsv.assert_called_with(
                self.asset1,
                field,
                test_end_minute,
                'minute'
            )

    # Ensure that data.current(sid(n), field) has the same behaviour as
    # data[sid(n)].field.
    for field in ohlcvp_fields:
        # Both lambdas are invoked before `field` is rebound, so closing
        # over the loop variable is safe here.
        assert_get_spot_value_called(
            lambda: getattr(bar_data[self.asset1], field),
            field,
        )
        assert_get_spot_value_called(
            lambda: bar_data.current(self.asset1, field),
            field,
        )

    history_meth = 'zipline.data.data_portal.DataPortal.get_history_window'

    def assert_get_history_window_called(fun, is_legacy):
        """
        Assert that get_history_window was called during fun().

        Takes in a function fun and a boolean is_legacy.
        """
        with patch(history_meth) as ghw:
            fun()
            # Slightly hacky, but done to get around the fact that
            # history( explicitly passes an ffill param as the last arg,
            # while data.history doesn't.
            if is_legacy:
                ghw.assert_called_with(
                    [self.asset1, self.asset2, self.asset3],
                    test_end_minute,
                    5,
                    "1m",
                    "volume",
                    True
                )
            else:
                ghw.assert_called_with(
                    [self.asset1, self.asset2, self.asset3],
                    test_end_minute,
                    5,
                    "1m",
                    "volume",
                )

    test_sim_params = SimulationParameters(
        period_start=test_start_minute,
        period_end=test_end_minute,
        data_frequency="minute",
        env=self.env
    )

    history_algorithm = self.create_algo(
        history_algo,
        sim_params=test_sim_params
    )
    assert_get_history_window_called(
        lambda: history_algorithm.run(self.data_portal),
        is_legacy=True
    )
    assert_get_history_window_called(
        lambda: bar_data.history(
            [self.asset1, self.asset2, self.asset3],
            "volume",
            5,
            "1m"
        ),
        is_legacy=False
    )
def test_sid_accessor(self):
    """
    Check backwards compatibility of sid access on a data object.

    Both data[sid(24)].sid and data[sid(24)]["sid"] must keep working.
    They are deprecated, so each access is also expected to produce a
    ZiplineDeprecationWarning.
    """
    expected_message = "`data[sid(N)]` is deprecated. Use `data.current`."

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)
        self.create_algo(sid_accessor_algo).run(self.data_portal)

        # data[sid(x)] already warns, so reading .sid / ["sid"] from the
        # result must not add a second warning: two accesses, two warnings.
        self.assertEqual(2, len(caught))

        # Both recorded warnings must be the data[sid(N)] deprecation.
        for entry in caught:
            self.assertEqual(ZiplineDeprecationWarning, entry.category)
            self.assertEqual(expected_message, str(entry.message))
def test_data_items(self):
    """
    Check backwards compatibility of data.items() and data.iteritems().

    Four deprecation warnings are expected, alternating between the
    "iterating over data" message and the data[sid(N)] message.
    """
    # Indexed by the parity of the warning's position.
    expected_messages = (
        "Iterating over the assets in `data` is deprecated.",
        "`data[sid(N)]` is deprecated. Use `data.current`.",
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)
        self.create_algo(data_items_algo).run(self.data_portal)

        self.assertEqual(4, len(caught))
        for position, entry in enumerate(caught):
            self.assertEqual(ZiplineDeprecationWarning, entry.category)
            self.assertEqual(
                expected_messages[position % 2],
                str(entry.message)
            )
def test_iterate_data(self):
    """
    Run simple_algo and check the deprecation warnings it produces.

    Four warnings from four distinct script lines are expected: the first
    two for `asset in data`, the rest for iterating over data.
    """
    contains_message = "Checking whether an asset is in data is deprecated."
    iterate_message = "Iterating over the assets in `data` is deprecated."

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)
        self.create_algo(simple_algo).run(self.data_portal)

        self.assertEqual(4, len(caught))

        # Every warning must originate from a different line of the script.
        line_nos = [entry.lineno for entry in caught]
        self.assertEqual(4, len(set(line_nos)))

        for position, entry in enumerate(caught):
            self.assertEqual(ZiplineDeprecationWarning, entry.category)
            self.assertEqual("<string>", entry.filename)
            self.assertEqual(line_nos[position], entry.lineno)

            expected = contains_message if position < 2 else iterate_message
            self.assertEqual(expected, str(entry.message))
def test_history(self):
    """
    Check that the legacy `history` function warns exactly once.

    The simulation starts on the second trading day; the single warning
    must come from line 8 of the script.
    """
    base_params = self.sim_params
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)

        # Same settings as the class-level parameters, one day later.
        later_start_params = SimulationParameters(
            period_start=self.trading_days[1],
            period_end=base_params.period_end,
            capital_base=base_params.capital_base,
            data_frequency=base_params.data_frequency,
            emission_rate=base_params.emission_rate,
            env=self.env,
        )
        self.create_algo(
            history_algo, sim_params=later_start_params
        ).run(self.data_portal)

        self.assertEqual(1, len(caught))
        only_warning = caught[0]
        self.assertEqual(ZiplineDeprecationWarning, only_warning.category)
        self.assertEqual("<string>", only_warning.filename)
        self.assertEqual(8, only_warning.lineno)
        self.assertEqual("The `history` method is deprecated. Use "
                         "`data.history` instead.",
                         str(only_warning.message))
def test_old_new_history_bts_paths(self):
    """
    Check the values `history` returns inside before_trading_start.

    Getting them right involves 1) calling
    data_portal.get_history_window as of the previous market minute,
    2) collecting the adjustments between that minute and the current
    time, and 3) applying those adjustments.
    """
    algo = self.create_algo(history_bts_algo)
    algo.run(self.data_portal)

    base_window = np.arange(386, 391)
    unsplit_volume = base_window * 100
    # asset3 has a 0.5 split, so its expected volumes are doubled.
    split_volume = base_window * 200

    window = algo.recorded_vars['history']
    expectations = (
        (self.asset1, unsplit_volume),
        (self.asset2, unsplit_volume),
        (self.asset3, split_volume),
    )
    for asset, expected in expectations:
        np.testing.assert_array_equal(window[asset].values, expected)
def test_simple_transforms(self):
    """
    Check the legacy mavg/vwap/stddev/returns transforms.

    Each transform call is expected to raise two deprecation warnings
    from the same script line (one for data[sid(N)], one for the
    transform itself), and the computed values are compared with
    hand-derived numbers.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)

        transform_params = SimulationParameters(
            period_start=self.trading_days[8],
            period_end=self.trading_days[-1],
            data_frequency="minute",
            env=self.env
        )
        algo = self.create_algo(simple_transforms_algo,
                                sim_params=transform_params)
        algo.run(self.data_portal)

        self.assertEqual(8, len(caught))

        transforms = ["mavg", "vwap", "stddev", "returns"]
        # The four transform calls sit on script lines 8 through 11.
        for position, line_no in enumerate(range(8, 12)):
            getitem_warning = caught[position * 2]
            transform_warning = caught[(position * 2) + 1]

            self.assertEqual("<string>", getitem_warning.filename)
            self.assertEqual("<string>", transform_warning.filename)
            self.assertEqual(line_no, getitem_warning.lineno)
            self.assertEqual(line_no, transform_warning.lineno)

            self.assertEqual("`data[sid(N)]` is deprecated. Use "
                             "`data.current`.",
                             str(getitem_warning.message))
            self.assertEqual("The `{0}` method is "
                             "deprecated.".format(transforms[position]),
                             str(transform_warning.message))

        # Derivation of the expected transform values.
        #
        # minute price
        # 2016-01-11 14:31:00+00:00    1561
        # ...
        # 2016-01-14 20:59:00+00:00    3119
        # 2016-01-14 21:00:00+00:00    3120
        # 2016-01-15 14:31:00+00:00    3121
        # 2016-01-15 14:32:00+00:00    3122
        # 2016-01-15 14:33:00+00:00    3123
        #
        # volume
        # 2016-01-11 14:31:00+00:00    156100
        # ...
        # 2016-01-14 20:59:00+00:00    311900
        # 2016-01-14 21:00:00+00:00    312000
        # 2016-01-15 14:31:00+00:00    312100
        # 2016-01-15 14:32:00+00:00    312200
        # 2016-01-15 14:33:00+00:00    312300
        #
        # daily price (last day built with minute data)
        # 2016-01-14 00:00:00+00:00       9
        # 2016-01-15 00:00:00+00:00    3123
        #
        # mavg    = average of all the prices = (1561 + 3123) / 2 = 2342
        # vwap    = sum(price * volume) / sum(volumes)
        #         = 889119531400.0 / 366054600.0
        #         = 2428.9259891830343
        # stddev  = stddev(price, ddof=1) = 451.3435498597493
        # returns = (todayprice - yesterdayprice) / yesterdayprice
        #         = (3123 - 9) / 9 = 346
        self.assertEqual(2342, algo.mavg)
        self.assertAlmostEqual(2428.92599, algo.vwap, places=5)
        self.assertAlmostEqual(451.34355, algo.stddev, places=5)
        self.assertAlmostEqual(346, algo.returns)
def test_manipulation(self):
    """
    Check the origin and text of the warnings raised by simple_algo.

    The four warnings must come from consecutive script lines starting at
    line 7: two "asset in data" warnings, then the iteration warnings.
    """
    # NOTE(review): this runs simple_algo although a manipulation_algo
    # script is defined in this module -- confirm which one was intended.
    contains_message = ("Checking whether an asset is in data is "
                        "deprecated.")
    iterate_message = ("Iterating over the assets in `data` is "
                       "deprecated.")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default", ZiplineDeprecationWarning)
        self.create_algo(simple_algo).run(self.data_portal)

        self.assertEqual(4, len(caught))
        for offset, entry in enumerate(caught):
            self.assertEqual("<string>", entry.filename)
            self.assertEqual(7 + offset, entry.lineno)

            expected = contains_message if offset < 2 else iterate_message
            self.assertEqual(expected, str(entry.message))
+844
View File
@@ -0,0 +1,844 @@
#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from testfixtures import TempDirectory
import pandas as pd
import numpy as np
from nose_parameterized import parameterized
from zipline._protocol import handle_non_market_minutes
from zipline.data.data_portal import DataPortal
from zipline.data.minute_bars import BcolzMinuteBarWriter, \
US_EQUITIES_MINUTES_PER_DAY, BcolzMinuteBarReader
from zipline.data.us_equity_pricing import BcolzDailyBarReader, \
SQLiteAdjustmentReader, SQLiteAdjustmentWriter
from zipline.finance.trading import TradingEnvironment
from zipline.protocol import BarData
from zipline.testing.core import write_minute_data_for_asset, \
create_daily_df_for_asset, DailyBarWriterFromDataFrames, \
create_mock_adjustments, str_to_seconds, MockDailyBarReader
# Field names understood by BarData.current, grouped the way the tests
# below branch on them.
OHLC = ["open", "high", "low", "close"]
OHLCP = list(OHLC) + ["price"]
ALL_FIELDS = list(OHLCP) + ["volume", "last_traded"]

# Per-field offsets used in the test data: the minute tests expect the
# value of a field at minute index idx to be idx + 1 + offset.
field_info = dict(open=1, high=2, low=-1, close=0)
class TestBarDataBase(TestCase):
    """
    Assertion helpers shared by the minute and daily BarData tests.

    check_internal_consistency reads ``self.ASSET1`` and ``self.ASSET2``,
    which subclasses are expected to define.
    """

    @staticmethod
    def _is_nan(value):
        """
        Return True when ``value`` is NaN.

        Values np.isnan cannot handle (strings, Timestamps, ...) are simply
        reported as "not NaN" instead of raising.
        """
        try:
            return bool(np.isnan(value))
        except (TypeError, ValueError):
            return False

    def assert_same(self, val1, val2):
        """
        Assert that two values are equal, treating NaT/NaN as self-equal.

        pd.NaT only matches pd.NaT and NaN only matches NaN. Any other
        mismatch re-raises the original assertEqual failure, so the error
        message shows the two differing values even when they are not
        numeric.
        """
        try:
            self.assertEqual(val1, val2)
        except AssertionError:
            if val1 is pd.NaT:
                self.assertTrue(val2 is pd.NaT)
            elif self._is_nan(val1):
                self.assertTrue(self._is_nan(val2))
            else:
                raise

    def check_internal_consistency(self, bar_data):
        """
        Verify that every query form of ``bar_data.current`` agrees.

        Scalar, multi-asset, multi-field and multi-asset/multi-field
        lookups must return the same value for each asset and field. Also
        checks that bar_data raises AttributeError for a list of internal
        attribute names.
        """
        df = bar_data.current([self.ASSET1, self.ASSET2], ALL_FIELDS)

        asset1_multi_field = bar_data.current(self.ASSET1, ALL_FIELDS)
        asset2_multi_field = bar_data.current(self.ASSET2, ALL_FIELDS)

        for field in ALL_FIELDS:
            asset1_value = bar_data.current(self.ASSET1, field)
            asset2_value = bar_data.current(self.ASSET2, field)

            multi_asset_series = bar_data.current(
                [self.ASSET1, self.ASSET2], field
            )

            # make sure all the different query forms are internally
            # consistent
            self.assert_same(multi_asset_series.loc[self.ASSET1], asset1_value)
            self.assert_same(multi_asset_series.loc[self.ASSET2], asset2_value)

            self.assert_same(df.loc[self.ASSET1][field], asset1_value)
            self.assert_same(df.loc[self.ASSET2][field], asset2_value)

            self.assert_same(asset1_multi_field[field], asset1_value)
            self.assert_same(asset2_multi_field[field], asset2_value)

        # also verify that bar_data doesn't expose anything bad
        for field in ["data_portal", "simulation_dt_func", "data_frequency",
                      "_views", "_universe_func", "_last_calculated_universe",
                      "_universe_last_updatedat"]:
            with self.assertRaises(AttributeError):
                getattr(bar_data, field)
class TestMinuteBarData(TestBarDataBase):
@classmethod
def setUpClass(cls):
    """
    Create five equities with minute data over 2016-01-05..2016-01-07.

    Trade frequency per asset (see build_minute_data):
      * ASSET1 and SPLIT_ASSET use the writer helper's default interval
      * ASSET2 and ILLIQUID_SPLIT_ASSET are written with interval 10
      * HILARIOUSLY_ILLIQUID_ASSET is written with interval 50
    """
    cls.tempdir = TempDirectory()
    cls.env = TradingEnvironment()
    cls.days = cls.env.days_in_range(
        start=pd.Timestamp("2016-01-05", tz='UTC'),
        end=pd.Timestamp("2016-01-07", tz='UTC')
    )

    equities = {}
    for asset_id in [1, 2, 3, 4, 5]:
        equities[asset_id] = {
            'start_date': cls.days[0],
            'end_date': cls.days[-1],
            'symbol': "ASSET{0}".format(asset_id)
        }
    cls.env.write_data(equities_data=equities)

    cls.ASSET1 = cls.env.asset_finder.retrieve_asset(1)
    cls.ASSET2 = cls.env.asset_finder.retrieve_asset(2)
    cls.SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(3)
    cls.ILLIQUID_SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(4)
    cls.HILARIOUSLY_ILLIQUID_ASSET = cls.env.asset_finder.retrieve_asset(5)
    cls.ASSETS = [cls.ASSET1, cls.ASSET2]

    cls.adjustments_reader = cls.create_adjustments_reader()
    minute_reader = cls.build_minute_data()
    cls.data_portal = DataPortal(
        cls.env,
        equity_minute_reader=minute_reader,
        adjustment_reader=cls.adjustments_reader
    )
@classmethod
def tearDownClass(cls):
    """Remove the temporary directory created in setUpClass."""
    cls.tempdir.cleanup()
@classmethod
def create_adjustments_reader(cls):
    """
    Write mock adjustments with a 0.5 split on 2016-01-06 for both
    SPLIT_ASSET and ILLIQUID_SPLIT_ASSET, and return a reader for them.
    """
    split_rows = [
        {
            'effective_date': str_to_seconds("2016-01-06"),
            'ratio': 0.5,
            'sid': asset.sid
        }
        for asset in (cls.SPLIT_ASSET, cls.ILLIQUID_SPLIT_ASSET)
    ]
    db_path = create_mock_adjustments(
        cls.tempdir,
        cls.days,
        splits=split_rows
    )
    return SQLiteAdjustmentReader(db_path)
@classmethod
def build_minute_data(cls):
    """
    Write minute bars for all five assets and return a reader on them.

    ASSET1 and SPLIT_ASSET are written without an interval argument;
    ASSET2 and ILLIQUID_SPLIT_ASSET with 10; HILARIOUSLY_ILLIQUID_ASSET
    with 50.
    """
    writer = BcolzMinuteBarWriter(
        cls.days[0],
        cls.tempdir.path,
        cls.env.open_and_closes.market_open.loc[cls.days],
        cls.env.open_and_closes.market_close.loc[cls.days],
        US_EQUITIES_MINUTES_PER_DAY
    )

    # (sid, extra positional arguments) in the order the data is written.
    write_plan = [
        (cls.ASSET1.sid, ()),
        (cls.SPLIT_ASSET.sid, ()),
        (cls.ASSET2.sid, (10,)),
        (cls.ILLIQUID_SPLIT_ASSET.sid, (10,)),
        (cls.HILARIOUSLY_ILLIQUID_ASSET.sid, (50,)),
    ]
    for asset_sid, interval_args in write_plan:
        write_minute_data_for_asset(
            cls.env,
            writer,
            cls.days[0],
            cls.days[-1],
            asset_sid,
            *interval_args
        )

    return BcolzMinuteBarReader(cls.tempdir.path)
def test_minute_before_assets_trading(self):
    """
    Every minute of the trading day before the assets' start date must
    report: not tradable, not stale, NaN prices, zero volume and NaT
    last_traded.
    """
    day_before_start = self.env.previous_trading_day(self.days[0])

    for minute in self.env.market_minutes_for_day(day_before_start):
        bar_data = BarData(self.data_portal, lambda: minute, "minute")
        self.check_internal_consistency(bar_data)

        for asset in (self.ASSET1, self.ASSET2):
            self.assertFalse(bar_data.can_trade(asset))
        for asset in (self.ASSET1, self.ASSET2):
            self.assertFalse(bar_data.is_stale(asset))

        for field in ALL_FIELDS:
            for asset in self.ASSETS:
                value = bar_data.current(asset, field)

                if field in OHLCP:
                    self.assertTrue(np.isnan(value))
                elif field == "volume":
                    self.assertEqual(0, value)
                elif field == "last_traded":
                    self.assertTrue(value is pd.NaT)
def test_regular_minute(self):
# Walks every market minute of the first day (self.days[0]) and checks
# can_trade, is_stale and all of ALL_FIELDS for ASSET1 and ASSET2.
minutes = self.env.market_minutes_for_day(self.days[0])
for idx, minute in enumerate(minutes):
# day2 has prices
# (every minute for asset1, every 10 minutes for asset2)
# asset1:
# opens: 2-391
# high: 3-392
# low: 0-389
# close: 1-390
# volume: 100-3900 (by 100)
# asset2 is the same thing, but with only every 10th minute
# populated.
# this test covers the "IPO morning" case, because asset2 only
# has data starting on the 10th minute.
bar_data = BarData(self.data_portal, lambda: minute, "minute")
self.check_internal_consistency(bar_data)
# idx is zero-based, so the 10th, 20th, ... minutes are idx 9, 19, ...
asset2_has_data = (((idx + 1) % 10) == 0)
self.assertTrue(bar_data.can_trade(self.ASSET1))
self.assertFalse(bar_data.is_stale(self.ASSET1))
# Before its first bar (idx < 9) asset2 is neither tradable nor stale.
if idx < 9:
self.assertFalse(bar_data.can_trade(self.ASSET2))
self.assertFalse(bar_data.is_stale(self.ASSET2))
else:
self.assertTrue(bar_data.can_trade(self.ASSET2))
if asset2_has_data:
self.assertFalse(bar_data.is_stale(self.ASSET2))
else:
self.assertTrue(bar_data.is_stale(self.ASSET2))
for field in ALL_FIELDS:
asset1_value = bar_data.current(self.ASSET1, field)
asset2_value = bar_data.current(self.ASSET2, field)
# now check the actual values
if idx == 0 and field == "low":
# first low value is 0, which is interpreted as NaN
self.assertTrue(np.isnan(asset1_value))
else:
if field in OHLC:
# expected OHLC value is the 1-based minute number plus the
# per-field offset from field_info
self.assertEqual(
idx + 1 + field_info[field],
asset1_value
)
if asset2_has_data:
self.assertEqual(
idx + 1 + field_info[field],
asset2_value
)
else:
self.assertTrue(np.isnan(asset2_value))
elif field == "volume":
self.assertEqual((idx + 1) * 100, asset1_value)
if asset2_has_data:
self.assertEqual((idx + 1) * 100, asset2_value)
else:
self.assertEqual(0, asset2_value)
elif field == "price":
self.assertEqual(idx + 1, asset1_value)
if asset2_has_data:
self.assertEqual(idx + 1, asset2_value)
elif idx < 9:
# no price to forward fill from
self.assertTrue(np.isnan(asset2_value))
else:
# forward-filled price
self.assertEqual((idx // 10) * 10, asset2_value)
elif field == "last_traded":
self.assertEqual(minute, asset1_value)
if idx < 9:
self.assertTrue(asset2_value is pd.NaT)
elif asset2_has_data:
self.assertEqual(minute, asset2_value)
else:
# NOTE(review): `last_traded_minute - 1` subtracts an int from an
# element of `minutes`; presumably this steps back one minute to
# the most recent 10th minute -- confirm against the pandas
# version in use.
last_traded_minute = minutes[(idx // 10) * 10]
self.assertEqual(last_traded_minute - 1,
asset2_value)
def test_minute_of_last_day(self):
    """Both assets must stay tradable on every minute of their last day."""
    last_day_minutes = self.env.market_minutes_for_day(self.days[-1])

    for minute in last_day_minutes:
        bar_data = BarData(self.data_portal, lambda: minute, "minute")

        for asset in (self.ASSET1, self.ASSET2):
            self.assertTrue(bar_data.can_trade(asset))
def test_minute_after_assets_stopped(self):
    """
    On the trading day after the assets' end date nothing is tradable or
    stale, prices are NaN, volume is 0 and last_traded is the final
    market minute of the last day.
    """
    day_after_end = self.env.next_trading_day(self.days[-1])
    minutes = self.env.market_minutes_for_day(day_after_end)
    last_trading_minute = self.env.market_minutes_for_day(
        self.days[-1]
    )[-1]

    for minute in minutes:
        bar_data = BarData(self.data_portal, lambda: minute, "minute")

        for asset in (self.ASSET1, self.ASSET2):
            self.assertFalse(bar_data.can_trade(asset))
        for asset in (self.ASSET1, self.ASSET2):
            self.assertFalse(bar_data.is_stale(asset))

        self.check_internal_consistency(bar_data)

        for field in ALL_FIELDS:
            for asset in self.ASSETS:
                value = bar_data.current(asset, field)

                if field in OHLCP:
                    self.assertTrue(np.isnan(value))
                elif field == "volume":
                    self.assertEqual(0, value)
                elif field == "last_traded":
                    self.assertEqual(last_trading_minute, value)
def test_spot_price_is_unadjusted(self):
    """
    A spot price read on a minute that has data must not be
    split-adjusted, even though SPLIT_ASSET has a split on 2016-01-06.
    """
    # First confirm the fixture really contains exactly one split for
    # SPLIT_ASSET, effective 2016-01-06 ...
    recorded_splits = self.adjustments_reader.get_adjustments_for_sid(
        "splits",
        self.SPLIT_ASSET.sid
    )
    self.assertEqual(1, len(recorded_splits))
    self.assertEqual(
        recorded_splits[0][0],
        pd.Timestamp("2016-01-06", tz='UTC')
    )

    # ... but that it is not applied when using the spot value: the price
    # keeps counting up across the split date.
    two_day_minutes = self.env.minutes_for_days_in_range(
        start=self.days[0], end=self.days[1]
    )
    for position, minute in enumerate(two_day_minutes):
        bar_data = BarData(self.data_portal, lambda: minute, "minute")
        spot_price = bar_data.current(self.SPLIT_ASSET, "price")
        self.assertEqual(position + 1, spot_price)
def test_spot_price_is_adjusted_if_needed(self):
    """
    A forward-filled price that crosses the split date must be adjusted.

    ILLIQUID_SPLIT_ASSET has no bars for the first 9 minutes of
    self.days[1], so its price there is filled from the previous day's
    last bar and halved by the 0.5 split.
    """
    day0_minutes = self.env.market_minutes_for_day(self.days[0])
    day1_minutes = self.env.market_minutes_for_day(self.days[1])

    def illiquid_price_at(dt_func):
        # Spot price of the illiquid split asset at the time dt_func returns.
        bar_data = BarData(self.data_portal, dt_func, "minute")
        return bar_data.current(self.ILLIQUID_SPLIT_ASSET, "price")

    # Minutes 381-389 of day 0 still carry the price of the 380th minute.
    for minute in day0_minutes[-10:-1]:
        self.assertEqual(380, illiquid_price_at(lambda: minute))

    # The last minute of day 0 has its own bar.
    self.assertEqual(390, illiquid_price_at(lambda: day0_minutes[-1]))

    for minute in day1_minutes[0:9]:
        # should be half of 390, due to the split
        self.assertEqual(195, illiquid_price_at(lambda: minute))
def test_spot_price_at_midnight(self):
# make sure that if we try to get a minute price at a non-market
# minute, we use the previous market close's timestamp
# Two non-market times on self.days[1] are exercised: the day's own
# timestamp (bar_data) and 8:45 US/Eastern (bar_data2).
day = self.days[1]
eight_fortyfive_am_eastern = \
pd.Timestamp("{0}-{1}-{2} 8:45".format(
day.year, day.month, day.day),
tz='US/Eastern'
)
bar_data = BarData(self.data_portal, lambda: day, "minute")
bar_data2 = BarData(self.data_portal,
lambda: eight_fortyfive_am_eastern,
"minute")
# Both BarData objects are queried inside handle_non_market_minutes.
with handle_non_market_minutes(bar_data), \
handle_non_market_minutes(bar_data2):
for bd in [bar_data, bar_data2]:
for field in ["close", "price"]:
# 390 is the value expected from the previous day's last minute
self.assertEqual(
390,
bd.current(self.ASSET1, field)
)
# make sure that if the asset didn't trade at the previous
# close, we properly ffill (or not ffill)
# price is expected to be forward-filled (350); high and volume are not
self.assertEqual(
350,
bd.current(self.HILARIOUSLY_ILLIQUID_ASSET, "price")
)
self.assertTrue(
np.isnan(bd.current(self.HILARIOUSLY_ILLIQUID_ASSET,
"high"))
)
self.assertEqual(
0,
bd.current(self.HILARIOUSLY_ILLIQUID_ASSET, "volume")
)
def test_can_trade_at_midnight(self):
# make sure that if we use `can_trade` at midnight, we don't pretend
# we're in the previous day's last minute
# the_day_after is the trading day following the assets' end date, so
# can_trade must be False both with and without handle_non_market_minutes.
the_day_after = self.env.next_trading_day(self.days[-1])
bar_data = BarData(self.data_portal, lambda: the_day_after, "minute")
for asset in [self.ASSET1, self.HILARIOUSLY_ILLIQUID_ASSET]:
self.assertFalse(bar_data.can_trade(asset))
with handle_non_market_minutes(bar_data):
self.assertFalse(bar_data.can_trade(asset))
# but make sure it works when the assets are alive
# (self.days[1] lies inside the assets' start/end range)
bar_data2 = BarData(self.data_portal, lambda: self.days[1], "minute")
for asset in [self.ASSET1, self.HILARIOUSLY_ILLIQUID_ASSET]:
self.assertTrue(bar_data2.can_trade(asset))
with handle_non_market_minutes(bar_data2):
self.assertTrue(bar_data2.can_trade(asset))
def test_is_stale_at_midnight(self):
    """
    At the self.days[1] timestamp, inside handle_non_market_minutes, the
    asset written with interval 50 must be reported as stale.
    """
    bar_data = BarData(self.data_portal, lambda: self.days[1], "minute")

    with handle_non_market_minutes(bar_data):
        stale = bar_data.is_stale(self.HILARIOUSLY_ILLIQUID_ASSET)
        self.assertTrue(stale)
def test_overnight_adjustments(self):
    """
    Values read before the open on the split date must already reflect
    the overnight split: prices halved, volume doubled.
    """
    # Confirm the fixture holds exactly one split for SPLIT_ASSET,
    # effective 2016-01-06.
    recorded_splits = self.adjustments_reader.get_adjustments_for_sid(
        "splits",
        self.SPLIT_ASSET.sid
    )
    self.assertEqual(1, len(recorded_splits))
    self.assertEqual(
        recorded_splits[0][0],
        pd.Timestamp("2016-01-06", tz='UTC')
    )

    # Current day is 1/06/16; query at 8:45 US/Eastern, which the sibling
    # test_spot_price_at_midnight treats as a non-market minute.
    day = self.days[1]
    eight_fortyfive_am_eastern = pd.Timestamp(
        "{0}-{1}-{2} 8:45".format(day.year, day.month, day.day),
        tz='US/Eastern'
    )
    bar_data = BarData(self.data_portal,
                       lambda: eight_fortyfive_am_eastern,
                       "minute")

    # Previous day's last-minute values, adjusted for the 0.5 split.
    expected = {
        'open': 391 / 2.0,
        'high': 392 / 2.0,
        'low': 389 / 2.0,
        'close': 390 / 2.0,
        'volume': 39000 * 2.0,
        'price': 390 / 2.0,
    }

    with handle_non_market_minutes(bar_data):
        for field in OHLCP + ['volume']:
            adjusted = bar_data.current(self.SPLIT_ASSET, field)
            # Assert the price is adjusted for the overnight split
            self.assertEqual(adjusted, expected[field])
class TestDailyBarData(TestBarDataBase):
@classmethod
def setUpClass(cls):
    """
    Create eight equities with daily data over 2016-01-05..2016-01-08.

    Odd sids are written with the daily helper's default interval, even
    sids with interval=2 (see build_daily_data). Sids 3/4 carry splits,
    5/6 mergers and 7/8 dividends (see create_adjustments_reader).
    """
    cls.tempdir = TempDirectory()
    cls.env = TradingEnvironment()
    cls.days = cls.env.days_in_range(
        start=pd.Timestamp("2016-01-05", tz='UTC'),
        end=pd.Timestamp("2016-01-08", tz='UTC')
    )

    equities = {}
    for asset_id in [1, 2, 3, 4, 5, 6, 7, 8]:
        equities[asset_id] = {
            'start_date': cls.days[0],
            'end_date': cls.days[-1],
            'symbol': "ASSET{0}".format(asset_id)
        }
    cls.env.write_data(equities_data=equities)

    cls.ASSET1 = cls.env.asset_finder.retrieve_asset(1)
    cls.ASSET2 = cls.env.asset_finder.retrieve_asset(2)
    cls.SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(3)
    cls.ILLIQUID_SPLIT_ASSET = cls.env.asset_finder.retrieve_asset(4)
    cls.MERGER_ASSET = cls.env.asset_finder.retrieve_asset(5)
    cls.ILLIQUID_MERGER_ASSET = cls.env.asset_finder.retrieve_asset(6)
    cls.DIVIDEND_ASSET = cls.env.asset_finder.retrieve_asset(7)
    cls.ILLIQUID_DIVIDEND_ASSET = cls.env.asset_finder.retrieve_asset(8)
    cls.ASSETS = [cls.ASSET1, cls.ASSET2]

    cls.adjustments_reader = cls.create_adjustments_reader()
    daily_reader = cls.build_daily_data()
    cls.data_portal = DataPortal(
        cls.env,
        equity_daily_reader=daily_reader,
        adjustment_reader=cls.adjustments_reader
    )
@classmethod
def tearDownClass(cls):
    """Remove the temporary directory created in setUpClass."""
    cls.tempdir.cleanup()
@classmethod
def create_adjustments_reader(cls):
    """
    Write splits, mergers and dividends for the adjustment assets.

    Each liquid asset gets its adjustment on 2016-01-06 and its illiquid
    counterpart on 2016-01-07. Returns a SQLiteAdjustmentReader opened on
    the written file.
    """
    db_path = cls.tempdir.getpath("test_adjustments.db")
    writer = SQLiteAdjustmentWriter(
        db_path,
        cls.env.trading_days,
        MockDailyBarReader()
    )

    def ratio_row(date_str, ratio, asset):
        # One row of the splits/mergers tables.
        return {
            'effective_date': str_to_seconds(date_str),
            'ratio': ratio,
            'sid': asset.sid
        }

    def dividend_row(date_str, amount, asset):
        # only care about ex date, the other dates don't matter here, so
        # every date column gets the same value
        return {
            'ex_date':
                pd.Timestamp(date_str, tz='UTC').to_datetime64(),
            'record_date':
                pd.Timestamp(date_str, tz='UTC').to_datetime64(),
            'declared_date':
                pd.Timestamp(date_str, tz='UTC').to_datetime64(),
            'pay_date':
                pd.Timestamp(date_str, tz='UTC').to_datetime64(),
            'amount': amount,
            'sid': asset.sid
        }

    splits = pd.DataFrame([
        ratio_row("2016-01-06", 0.5, cls.SPLIT_ASSET),
        ratio_row("2016-01-07", 0.5, cls.ILLIQUID_SPLIT_ASSET),
    ])

    mergers = pd.DataFrame([
        ratio_row("2016-01-06", 0.5, cls.MERGER_ASSET),
        ratio_row("2016-01-07", 0.6, cls.ILLIQUID_MERGER_ASSET),
    ])

    # we're using a fake daily reader in the adjustments writer which
    # returns every daily price as 100, so dividend amounts of 2.0 and 4.0
    # correspond to 2% and 4% dividends, respectively.
    dividends = pd.DataFrame(
        [
            dividend_row("2016-01-06", 2.0, cls.DIVIDEND_ASSET),
            dividend_row("2016-01-07", 4.0, cls.ILLIQUID_DIVIDEND_ASSET),
        ],
        columns=['ex_date',
                 'record_date',
                 'declared_date',
                 'pay_date',
                 'amount',
                 'sid']
    )

    writer.write(splits, mergers, dividends)
    return SQLiteAdjustmentReader(db_path)
@classmethod
def build_daily_data(cls):
    """
    Write daily bars for sids 1-8 and return a reader on them.

    Odd sids use the helper's default interval; even sids are created
    with interval=2.
    """
    frames = {}
    for asset_id in range(1, 9):
        if asset_id % 2 == 1:
            frames[asset_id] = create_daily_df_for_asset(
                cls.env, cls.days[0], cls.days[-1]
            )
        else:
            frames[asset_id] = create_daily_df_for_asset(
                cls.env, cls.days[0], cls.days[-1], interval=2
            )

    bcolz_path = cls.tempdir.getpath("testdaily.bcolz")
    DailyBarWriterFromDataFrames(frames).write(bcolz_path, cls.days, frames)
    return BcolzDailyBarReader(bcolz_path)
def test_day_before_assets_trading(self):
    """
    On the trading day before self.days[0] the assets are not tradable
    and not stale, prices are NaN, volume is 0 and last_traded is NaT.
    """
    day_before_start = self.env.previous_trading_day(self.days[0])
    bar_data = BarData(self.data_portal, lambda: day_before_start, "daily")
    self.check_internal_consistency(bar_data)

    for asset in (self.ASSET1, self.ASSET2):
        self.assertFalse(bar_data.can_trade(asset))
    for asset in (self.ASSET1, self.ASSET2):
        self.assertFalse(bar_data.is_stale(asset))

    for field in ALL_FIELDS:
        for asset in self.ASSETS:
            value = bar_data.current(asset, field)

            if field in OHLCP:
                self.assertTrue(np.isnan(value))
            elif field == "volume":
                self.assertEqual(0, value)
            elif field == "last_traded":
                self.assertTrue(value is pd.NaT)
def test_semi_active_day(self):
    """On self.days[0] only ASSET1 has a bar; ASSET2 has never traded."""
    bar_data = BarData(self.data_portal, lambda: self.days[0], "daily")
    self.check_internal_consistency(bar_data)

    self.assertTrue(bar_data.can_trade(self.ASSET1))
    self.assertFalse(bar_data.can_trade(self.ASSET2))

    # because there is real data
    self.assertFalse(bar_data.is_stale(self.ASSET1))

    # because there has never been a trade bar yet
    self.assertFalse(bar_data.is_stale(self.ASSET2))

    asset1_expectations = [
        ("open", 3),
        ("high", 4),
        ("low", 1),
        ("close", 2),
        ("volume", 200),
        ("price", 2),
        ("last_traded", self.days[0]),
    ]
    for field, expected in asset1_expectations:
        self.assertEqual(expected, bar_data.current(self.ASSET1, field))

    # ASSET2 has nothing to report yet.
    for field in OHLCP:
        self.assertTrue(np.isnan(bar_data.current(self.ASSET2, field)),
                        field)

    self.assertEqual(0, bar_data.current(self.ASSET2, "volume"))
    self.assertTrue(
        bar_data.current(self.ASSET2, "last_traded") is pd.NaT
    )
def test_fully_active_day(self):
    """On self.days[1] both assets have a bar with identical values."""
    bar_data = BarData(self.data_portal, lambda: self.days[1], "daily")
    self.check_internal_consistency(bar_data)

    expectations = [
        ("open", 4),
        ("high", 5),
        ("low", 2),
        ("close", 3),
        ("volume", 300),
        ("price", 3),
        ("last_traded", self.days[1]),
    ]

    # on self.days[1], both assets have data
    for asset in self.ASSETS:
        self.assertTrue(bar_data.can_trade(asset))
        self.assertFalse(bar_data.is_stale(asset))

        for field, expected in expectations:
            self.assertEqual(expected, bar_data.current(asset, field))
def test_last_active_day(self):
    """Check bar data on ``self.days[-1]`` for every asset."""
    bar_data = BarData(self.data_portal, lambda: self.days[-1], "daily")
    self.check_internal_consistency(bar_data)

    expected_by_field = (
        ("open", 6),
        ("high", 7),
        ("low", 4),
        ("close", 5),
        ("volume", 500),
        ("price", 5),
    )

    for asset in self.ASSETS:
        self.assertTrue(bar_data.can_trade(asset))
        self.assertFalse(bar_data.is_stale(asset))

        for field_name, expected in expected_by_field:
            self.assertEqual(expected, bar_data.current(asset, field_name))
def test_after_assets_dead(self):
    """Check bar data on the session after ``self.days[-1]``."""
    # Both assets end on self.days[-1], so query the following session.
    day_after_end = self.env.next_trading_day(self.days[-1])
    bar_data = BarData(self.data_portal, lambda: day_after_end, "daily")
    self.check_internal_consistency(bar_data)

    for asset in self.ASSETS:
        self.assertFalse(bar_data.can_trade(asset))
        self.assertFalse(bar_data.is_stale(asset))

        for field in OHLCP:
            value = bar_data.current(asset, field)
            self.assertTrue(np.isnan(value))

        self.assertEqual(0, bar_data.current(asset, "volume"))

        # The expected last trade date differs per asset.
        last_traded_dt = bar_data.current(asset, "last_traded")
        expected_last_traded = (
            self.days[-2] if asset == self.ASSET1 else self.days[1]
        )
        self.assertEqual(expected_last_traded, last_traded_dt)
@parameterized.expand([
    ("split", 2, 3, 3, 1.5),
    ("merger", 2, 3, 3, 1.8),
    ("dividend", 2, 3, 3, 2.88)
])
def test_spot_price_adjustments(self,
                                adjustment_type,
                                liquid_day_0_price,
                                liquid_day_1_price,
                                illiquid_day_0_price,
                                illiquid_day_1_price_adjusted):
    """Verify when an adjustment does and does not affect spot prices.

    The liquid asset is read on ``self.days[0]`` and ``self.days[1]``;
    the illiquid asset is read on ``self.days[1]`` and ``self.days[2]``.
    """
    kind = adjustment_type.upper()
    liquid_asset = getattr(self, kind + "_ASSET")
    illiquid_asset = getattr(self, "ILLIQUID_" + kind + "_ASSET")

    def spot_price(day_index, asset):
        # Build a fresh daily BarData pinned to the requested session.
        bar_data = BarData(
            self.data_portal, lambda: self.days[day_index], "daily"
        )
        return bar_data.current(asset, "price")

    # The liquid asset must have exactly one adjustment of this type,
    # dated 2016-01-06 ...
    adjustments = self.adjustments_reader.get_adjustments_for_sid(
        adjustment_type + 's',
        liquid_asset.sid
    )
    self.assertEqual(1, len(adjustments))
    first_adjustment = adjustments[0]
    self.assertEqual(
        first_adjustment[0],
        pd.Timestamp("2016-01-06", tz='UTC')
    )

    # ... but it is not applied when reading the spot value.
    self.assertEqual(liquid_day_0_price, spot_price(0, liquid_asset))
    self.assertEqual(liquid_day_1_price, spot_price(1, liquid_asset))

    # The exception is a forward fill across a day boundary: the second
    # read of the illiquid asset is expected to return the adjusted
    # value, which is why that comparison is approximate.
    self.assertEqual(illiquid_day_0_price, spot_price(1, illiquid_asset))
    self.assertAlmostEqual(
        illiquid_day_1_price_adjusted,
        spot_price(2, illiquid_asset)
    )
-297
View File
@@ -1,297 +0,0 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from copy import deepcopy
from datetime import datetime
from unittest import TestCase
import pytz
import numpy as np
import pandas as pd
from zipline.algorithm import TradingAlgorithm
from zipline.finance.trading import TradingEnvironment
from zipline.sources.data_source import DataSource
from zipline.test_algorithms import (
BatchTransformAlgorithm,
BatchTransformAlgorithmMinute,
)
from zipline.testing import setup_logger, teardown_logger
from zipline.transforms import batch_transform
import zipline.utils.factory as factory
from zipline.utils.tradingcalendar import trading_days
@batch_transform
def return_price(data):
    """Return the ``price`` attribute of the data handed to the transform."""
    prices = data.price
    return prices
class BatchTransformAlgorithmSetSid(TradingAlgorithm):
    """Algorithm that records the ``return_price`` transform output per bar."""

    def initialize(self, sids=None):
        """Create the history list and the batch transform for ``sids``."""
        transform_settings = dict(
            refresh_period=1,
            window_length=10,
            clean_nans=False,
            sids=sids,
            compute_only_full=False,
        )
        self.history = []
        self.batch_transform = return_price(**transform_settings)

    def handle_data(self, data):
        """Run the transform on ``data`` and store a deep copy of the result."""
        window = self.batch_transform.handle_data(data)
        self.history.append(deepcopy(window))
class DifferentSidSource(DataSource):
    """Data source that emits one event per trading day, each with a new sid."""

    def __init__(self):
        date_index = pd.date_range('1990-01-01', periods=180, tz='utc')
        self.dates = date_index
        self.start, self.end = date_index[0], date_index[-1]
        self._raw_data = None
        self.sids = range(90)
        # Next sid to emit; also used as that event's price and volume.
        self.sid = 0
        # Filled in by raw_data_gen with the dates that produced an event.
        self.trading_days = []

    @property
    def instance_hash(self):
        return '1234'

    @property
    def raw_data(self):
        # Create the generator lazily on first access and reuse it after.
        if self._raw_data is None:
            self._raw_data = self.raw_data_gen()
        return self._raw_data

    @property
    def mapping(self):
        def identity(value):
            return value

        return {
            'dt': (identity, 'dt'),
            'sid': (identity, 'sid'),
            'price': (float, 'price'),
            'volume': (int, 'volume'),
        }

    def raw_data_gen(self):
        """Yield one event dict per date found in ``trading_days``."""
        # Create a different sid for each event.
        for date in self.dates:
            if date in trading_days:
                current_sid = self.sid
                event = dict(
                    dt=date,
                    sid=current_sid,
                    price=current_sid,
                    volume=current_sid,
                )
                self.sid = current_sid + 1
                self.trading_days.append(date)
                yield event
class TestChangeOfSids(TestCase):
    """Run the set-sid batch transform algorithm over a changing-sid source."""

    def setUp(self):
        self.sids = range(90)
        self.env = TradingEnvironment()
        self.env.write_data(equities_identifiers=self.sids)

        period_start = datetime(1990, 1, 1, tzinfo=pytz.utc)
        period_end = datetime(1990, 1, 8, tzinfo=pytz.utc)
        self.sim_params = factory.create_simulation_parameters(
            start=period_start,
            end=period_end,
            env=self.env,
        )

    def test_all_sids_passed(self):
        source = DifferentSidSource()
        algo = BatchTransformAlgorithmSetSid(
            sim_params=self.sim_params,
            env=self.env,
        )
        algo.run(source)

        snapshots = zip(algo.history, source.trading_days)
        for i, (df, date) in enumerate(snapshots):
            # The newest row of each snapshot is the day it was taken.
            self.assertEqual(
                df.index[-1], date, "Newest event doesn't match."
            )

            # Every sid emitted before this snapshot must be a column.
            for sid in self.sids[:i]:
                self.assertIn(sid, df.columns)

            # The source uses the sid as the price, so the latest value in
            # the last column equals the snapshot index.
            latest_row = df.iloc[-1]
            self.assertEqual(latest_row.iloc[-1], i)
class TestBatchTransformMinutely(TestCase):
    """Minute-bar tests of ``BatchTransformAlgorithmMinute`` window sizes."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        # Use the stdlib ``datetime`` already imported by this module (as
        # the sibling test classes do) instead of the deprecated
        # ``pd.datetime`` alias.
        start = datetime(1990, 1, 3, tzinfo=pytz.utc)
        end = datetime(1990, 1, 8, tzinfo=pytz.utc)
        self.sim_params = factory.create_simulation_parameters(
            start=start, end=end, env=self.env,
        )
        self.sim_params.emission_rate = 'daily'
        self.sim_params.data_frequency = 'minute'
        self.source, self.df = \
            factory.create_test_df_source(sim_params=self.sim_params,
                                          env=self.env,
                                          bars='minute')

    def tearDown(self):
        teardown_logger(self)

    def test_core(self):
        """Every history entry from index ``wl`` on has length ``wl``."""
        algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params,
                                             env=self.env)
        algo.run(self.source)
        # Window length in minutes: 6.5 hours of 60 minutes per day.
        wl = int(algo.window_length * 6.5 * 60)
        for bt in algo.history[wl:]:
            self.assertEqual(len(bt), wl)

    def test_window_length(self):
        """With a one-day window, the first ``wl - 1`` entries are None."""
        algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params,
                                             env=self.env,
                                             window_length=1,
                                             refresh_period=0)
        algo.run(self.source)
        # Window length in minutes: 6.5 hours of 60 minutes per day.
        wl = int(algo.window_length * 6.5 * 60)
        np.testing.assert_array_equal(algo.history[:(wl - 1)],
                                      [None] * (wl - 1))
        for bt in algo.history[wl:]:
            self.assertEqual(len(bt), wl)
class TestBatchTransform(TestCase):
    """Daily-bar tests of ``BatchTransformAlgorithm`` and its histories."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[0])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        setup_logger(self)
        self.sim_params = factory.create_simulation_parameters(
            start=datetime(1990, 1, 1, tzinfo=pytz.utc),
            end=datetime(1990, 1, 8, tzinfo=pytz.utc),
            env=self.env
        )
        self.source, self.df = \
            factory.create_test_df_source(self.sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    def test_core_functionality(self):
        """Check window fill-up, filters and window contents."""
        algo = BatchTransformAlgorithm(sim_params=self.sim_params,
                                       env=self.env)
        algo.run(self.source)
        wl = algo.window_length
        # The following assertions depend on a window length of 3.
        self.assertEqual(wl, 3)

        # If window_length is 3, there should be 2 None events, as the
        # window fills up on the 3rd day.
        n_none_events = 2
        for history in (algo.history_return_price_class,
                        algo.history_return_price_decorator):
            self.assertEqual(
                history[:n_none_events],
                [None] * n_none_events,
                "First two iterations should return None,\n"
                "i.e. no returned values until window is full: "
                "%s" % (history,)
            )

        # Once the window is full, the entry at index wl must be a
        # data frame.
        self.assertIsInstance(
            algo.history_return_price_class[wl],
            pd.DataFrame
        )

        # Test whether arbitrary fields can be added to datapanel
        field = algo.history_return_arbitrary_fields[-1]
        self.assertIn(
            'arbitrary',
            field.items,
            'datapanel should contain column arbitrary'
        )

        self.assertTrue(
            all(
                field['arbitrary'].values.flatten() ==
                [123] * algo.window_length
            ),
            'arbitrary dataframe should contain only 123'
        )

        for data in algo.history_return_sid_filter[wl:]:
            self.assertIn(0, data.columns)
            self.assertNotIn(1, data.columns)

        for data in algo.history_return_field_filter[wl:]:
            self.assertIn('price', data.items)
            self.assertNotIn('ignore', data.items)

        for data in algo.history_return_field_no_filter[wl:]:
            self.assertIn('price', data.items)
            self.assertIn('ignore', data.items)

        for data in algo.history_return_ticks[wl:]:
            self.assertIsInstance(data, deque)

        for data in algo.history_return_not_full:
            self.assertIsNot(data, None)

        # test overloaded class
        for test_history in [algo.history_return_price_class,
                             algo.history_return_price_decorator]:
            # starting at window length, the window should contain
            # consecutive (of window length) numbers up till the end.
            for i in range(algo.window_length, len(test_history)):
                np.testing.assert_array_equal(
                    range(i - algo.window_length + 2, i + 2),
                    test_history[i].values.flatten()
                )

    def test_passing_of_args(self):
        """Positional and keyword args reach the algorithm and transform."""
        algo = BatchTransformAlgorithm(1, kwarg='str',
                                       sim_params=self.sim_params,
                                       env=self.env)
        algo.run(self.source)
        self.assertEqual(algo.args, (1,))
        self.assertEqual(algo.kwargs, {'kwarg': 'str'})

        expected_item = ((1, ), {'kwarg': 'str'})
        self.assertEqual(
            algo.history_return_args,
            [
                # 1990-01-01 - market holiday, no event
                # 1990-01-02 - window not full
                None,
                # 1990-01-03 - window not full
                None,
                # 1990-01-04 - window now full, 3rd event
                expected_item,
                # 1990-01-05 - window now full
                expected_item,
                # 1990-01-08 - window now full
                expected_item
            ])
+207
View File
@@ -0,0 +1,207 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import TestCase
from datetime import timedelta
import numpy as np
import pandas as pd
from testfixtures import TempDirectory
from zipline.data.us_equity_pricing import SQLiteAdjustmentWriter, \
SQLiteAdjustmentReader
from zipline.errors import (
BenchmarkAssetNotAvailableTooEarly,
BenchmarkAssetNotAvailableTooLate,
InvalidBenchmarkAsset)
from zipline.finance.trading import TradingEnvironment
from zipline.sources.benchmark_source import BenchmarkSource
from zipline.utils import factory
from zipline.testing.core import create_data_portal, write_minute_data, \
create_empty_splits_mergers_frame
from .test_perf_tracking import MockDailyBarSpotReader
class TestBenchmark(TestCase):
    """Tests for ``BenchmarkSource`` built from a single benchmark sid."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()
        cls.tempdir = TempDirectory()
        cls.sim_params = factory.create_simulation_parameters()

        days = cls.sim_params.trading_days
        first_day = days[0]
        day_after_last = days[-1] + timedelta(days=1)

        # Sids 1, 2 and 4 span the whole simulation; sid 3 starts late and
        # ends early so it can trigger the "not available" errors.
        cls.env.write_data(equities_data={
            1: {
                "start_date": first_day,
                "end_date": day_after_last
            },
            2: {
                "start_date": first_day,
                "end_date": day_after_last
            },
            3: {
                "start_date": days[100],
                "end_date": days[-100]
            },
            4: {
                "start_date": first_day,
                "end_date": day_after_last
            }
        })

        dbpath = os.path.join(cls.tempdir.path, "adjustments.db")

        writer = SQLiteAdjustmentWriter(dbpath, cls.env.trading_days,
                                        MockDailyBarSpotReader())
        splits = mergers = create_empty_splits_mergers_frame()
        dividends = pd.DataFrame({
            'sid': np.array([], dtype=np.uint32),
            'amount': np.array([], dtype=np.float64),
            'declared_date': np.array([], dtype='datetime64[ns]'),
            'ex_date': np.array([], dtype='datetime64[ns]'),
            'pay_date': np.array([], dtype='datetime64[ns]'),
            'record_date': np.array([], dtype='datetime64[ns]'),
        })

        # A single stock dividend on sid 4, used by
        # test_no_stock_dividends_allowed.
        declared_date = days[45]
        ex_date = days[50]
        record_date = pay_date = days[55]
        stock_dividends = pd.DataFrame({
            'sid': np.array([4], dtype=np.uint32),
            'payment_sid': np.array([5], dtype=np.uint32),
            'ratio': np.array([2], dtype=np.float64),
            'declared_date': np.array([declared_date], dtype='datetime64[ns]'),
            'ex_date': np.array([ex_date], dtype='datetime64[ns]'),
            'record_date': np.array([record_date], dtype='datetime64[ns]'),
            'pay_date': np.array([pay_date], dtype='datetime64[ns]'),
        })
        writer.write(splits, mergers, dividends,
                     stock_dividends=stock_dividends)

        cls.data_portal = create_data_portal(
            cls.env,
            cls.tempdir,
            cls.sim_params,
            [1, 2, 3, 4],
            adjustment_reader=SQLiteAdjustmentReader(dbpath)
        )

    @classmethod
    def tearDownClass(cls):
        del cls.env
        cls.tempdir.cleanup()

    def test_normal(self):
        """Benchmark values equal the pct_change of the daily closes."""
        days_to_use = self.sim_params.trading_days[1:]

        source = BenchmarkSource(
            1, self.env, days_to_use, self.data_portal
        )

        # should be the equivalent of getting the price history, then doing
        # a pct_change on it
        manually_calculated = self.data_portal.get_history_window(
            [1], days_to_use[-1], len(days_to_use), "1d", "close"
        )[1].pct_change()

        # compare all the fields except the first one, for which we don't
        # have data in manually_calculated
        for idx, day in enumerate(days_to_use[1:]):
            self.assertEqual(
                source.get_value(day),
                manually_calculated[idx + 1]
            )

    def test_asset_not_trading(self):
        """Sid 3 is rejected when the range starts too early or ends too late."""
        with self.assertRaises(BenchmarkAssetNotAvailableTooEarly) as exc:
            BenchmarkSource(
                3,
                self.env,
                self.sim_params.trading_days[1:],
                self.data_portal
            )

        # Checked after the ``with`` block so the assertion actually runs;
        # code placed after the raising call inside the block never does.
        self.assertEqual(
            '3 does not exist on 2006-01-04 00:00:00+00:00. '
            'It started trading on 2006-05-26 00:00:00+00:00.',
            exc.exception.message
        )

        with self.assertRaises(BenchmarkAssetNotAvailableTooLate) as exc2:
            BenchmarkSource(
                3,
                self.env,
                self.sim_params.trading_days[120:],
                self.data_portal
            )

        self.assertEqual(
            '3 does not exist on 2006-06-26 00:00:00+00:00. '
            'It stopped trading on 2006-08-09 00:00:00+00:00.',
            exc2.exception.message
        )

    def test_asset_IPOed_same_day(self):
        """Benchmark on sid 2 after minute data is written for it."""
        # gotta get some minute data up in here.
        # add sid 2 for a couple of days
        minutes = self.env.minutes_for_days_in_range(
            self.sim_params.trading_days[0],
            self.sim_params.trading_days[5]
        )

        path = write_minute_data(
            self.env,
            self.tempdir,
            minutes,
            [2]
        )

        self.data_portal._minutes_equities_path = path

        source = BenchmarkSource(
            2,
            self.env,
            self.sim_params.trading_days,
            self.data_portal
        )

        days_to_use = self.sim_params.trading_days

        # first value should be 0.0, coming from daily data
        self.assertAlmostEqual(0.0, source.get_value(days_to_use[0]))

        manually_calculated = self.data_portal.get_history_window(
            [2], days_to_use[-1], len(days_to_use), "1d", "close"
        )[2].pct_change()

        for idx, day in enumerate(days_to_use[1:]):
            self.assertEqual(
                source.get_value(day),
                manually_calculated[idx + 1]
            )

    def test_no_stock_dividends_allowed(self):
        """An asset with a stock dividend cannot be the benchmark."""
        # try to use sid(4) as benchmark, should blow up due to the presence
        # of a stock dividend
        with self.assertRaises(InvalidBenchmarkAsset) as exc:
            BenchmarkSource(
                4, self.env, self.sim_params.trading_days, self.data_portal
            )

        self.assertEqual("4 cannot be used as the benchmark because it has a "
                         "stock dividend on 2006-03-16 00:00:00.  Choose "
                         "another asset to use as the benchmark.",
                         exc.exception.message)
+190 -55
View File
@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from nose_parameterized import parameterized
from unittest import TestCase
from testfixtures import TempDirectory
import pandas as pd
import zipline.utils.factory as factory
from zipline.finance import trading
from zipline.finance.blotter import Blotter
@@ -26,30 +30,89 @@ from zipline.finance.execution import (
StopLimitOrder,
StopOrder,
)
from zipline.sources.test_source import create_trade
from zipline.testing import(
setup_logger,
teardown_logger,
)
from zipline.gens.sim_engine import DAY_END, BAR
from zipline.finance.cancel_policy import EODCancel, NeverCancel
from zipline.finance.slippage import DEFAULT_VOLUME_SLIPPAGE_BAR_LIMIT, \
FixedSlippage
from .utils.daily_bar_writer import DailyBarWriterFromDataFrames
from zipline.data.us_equity_pricing import BcolzDailyBarReader
from zipline.data.data_portal import DataPortal
from zipline.protocol import BarData
class BlotterTestCase(TestCase):
@classmethod
def setUpClass(cls):
setup_logger(cls)
cls.env = trading.TradingEnvironment()
cls.env.write_data(equities_identifiers=[24])
cls.sim_params = factory.create_simulation_parameters(
start=pd.Timestamp("2006-01-05", tz='UTC'),
end=pd.Timestamp("2006-01-06", tz='UTC')
)
cls.env.write_data(equities_data={
24: {
'start_date': cls.sim_params.trading_days[0],
'end_date': cls.env.next_trading_day(
cls.sim_params.trading_days[-1]
)
},
25: {
'start_date': cls.sim_params.trading_days[0],
'end_date': cls.env.next_trading_day(
cls.sim_params.trading_days[-1]
)
}
})
cls.tempdir = TempDirectory()
assets = {
24: pd.DataFrame({
"open": [50, 50],
"high": [50, 50],
"low": [50, 50],
"close": [50, 50],
"volume": [100, 400],
"day": [day.value for day in cls.sim_params.trading_days]
}),
25: pd.DataFrame({
"open": [50, 50],
"high": [50, 50],
"low": [50, 50],
"close": [50, 50],
"volume": [100, 400],
"day": [day.value for day in cls.sim_params.trading_days]
})
}
path = os.path.join(cls.tempdir.path, "tempdata.bcolz")
DailyBarWriterFromDataFrames(assets).write(
path,
cls.sim_params.trading_days,
assets
)
equity_daily_reader = BcolzDailyBarReader(path)
cls.data_portal = DataPortal(
cls.env,
equity_daily_reader=equity_daily_reader,
)
@classmethod
def tearDownClass(cls):
del cls.env
def setUp(self, env=None):
setup_logger(self)
def tearDown(self):
teardown_logger(self)
cls.tempdir.cleanup()
teardown_logger(cls)
@parameterized.expand([(MarketOrder(), None, None),
(LimitOrder(10), 10, None),
@@ -57,60 +120,108 @@ class BlotterTestCase(TestCase):
(StopLimitOrder(10, 20), 10, 20)])
def test_blotter_order_types(self, style_obj, expected_lmt, expected_stp):
blotter = Blotter()
blotter = Blotter('daily', self.env.asset_finder)
blotter.order(24, 100, style_obj)
result = blotter.open_orders[24][0]
asset_24 = blotter.asset_finder.retrieve_asset(24)
blotter.order(asset_24, 100, style_obj)
result = blotter.open_orders[asset_24][0]
self.assertEqual(result.limit, expected_lmt)
self.assertEqual(result.stop, expected_stp)
def test_cancel(self):
blotter = Blotter()
blotter = Blotter('daily', self.env.asset_finder)
oid_1 = blotter.order(24, 100, MarketOrder())
oid_2 = blotter.order(24, 200, MarketOrder())
oid_3 = blotter.order(24, 300, MarketOrder())
asset_24 = blotter.asset_finder.retrieve_asset(24)
asset_25 = blotter.asset_finder.retrieve_asset(25)
oid_1 = blotter.order(asset_24, 100, MarketOrder())
oid_2 = blotter.order(asset_24, 200, MarketOrder())
oid_3 = blotter.order(asset_24, 300, MarketOrder())
# Create an order for another asset to verify that we don't remove it
# when we do cancel_all on 24.
blotter.order(25, 150, MarketOrder())
blotter.order(asset_25, 150, MarketOrder())
self.assertEqual(len(blotter.open_orders), 2)
self.assertEqual(len(blotter.open_orders[24]), 3)
self.assertEqual(len(blotter.open_orders[asset_24]), 3)
self.assertEqual(
[o.amount for o in blotter.open_orders[24]],
[o.amount for o in blotter.open_orders[asset_24]],
[100, 200, 300],
)
blotter.cancel(oid_2)
self.assertEqual(len(blotter.open_orders), 2)
self.assertEqual(len(blotter.open_orders[24]), 2)
self.assertEqual(len(blotter.open_orders[asset_24]), 2)
self.assertEqual(
[o.amount for o in blotter.open_orders[24]],
[o.amount for o in blotter.open_orders[asset_24]],
[100, 300],
)
self.assertEqual(
[o.id for o in blotter.open_orders[24]],
[o.id for o in blotter.open_orders[asset_24]],
[oid_1, oid_3],
)
blotter.cancel_all(24)
blotter.cancel_all_orders_for_asset(asset_24)
self.assertEqual(len(blotter.open_orders), 1)
self.assertEqual(list(blotter.open_orders), [25])
self.assertEqual(list(blotter.open_orders), [asset_25])
def test_blotter_eod_cancellation(self):
    """With ``EODCancel``, orders stay open on BAR and cancel on DAY_END."""
    blotter = Blotter('minute', self.env.asset_finder,
                      cancel_policy=EODCancel())
    asset_24 = blotter.asset_finder.retrieve_asset(24)

    # Place two orders for the same sid, so we can test that we are not
    # mutating the orders list as we are cancelling orders.
    for amount in (100, -100):
        blotter.order(asset_24, amount, MarketOrder())

    self.assertEqual(len(blotter.new_orders), 2)
    order_ids = [order.id for order in blotter.open_orders[asset_24]]

    def assert_first_two_open():
        for position in (0, 1):
            self.assertEqual(blotter.new_orders[position].status,
                             ORDER_STATUS.OPEN)

    assert_first_two_open()

    # A plain bar event leaves both orders untouched.
    blotter.execute_cancel_policy(BAR)
    assert_first_two_open()

    # The end-of-day event must cancel every order placed above.
    blotter.execute_cancel_policy(DAY_END)
    for order_id in order_ids:
        cancelled = blotter.orders[order_id]
        self.assertEqual(cancelled.status, ORDER_STATUS.CANCELLED)
def test_blotter_never_cancel(self):
    """With ``NeverCancel``, the order stays open after BAR and DAY_END."""
    blotter = Blotter('minute', self.env.asset_finder,
                      cancel_policy=NeverCancel())
    asset_24 = blotter.asset_finder.retrieve_asset(24)
    blotter.order(asset_24, 100, MarketOrder())

    self.assertEqual(len(blotter.new_orders), 1)
    self.assertEqual(blotter.new_orders[0].status, ORDER_STATUS.OPEN)

    # Neither simulation event may change the order's status.
    for event in (BAR, DAY_END):
        blotter.execute_cancel_policy(event)
        self.assertEqual(blotter.new_orders[0].status, ORDER_STATUS.OPEN)
def test_order_rejection(self):
blotter = Blotter()
blotter = Blotter(self.sim_params.data_frequency,
self.env.asset_finder)
asset_24 = blotter.asset_finder.retrieve_asset(24)
# Reject a nonexistent order -> no order appears in new_order,
# no exceptions raised out
blotter.reject(56)
self.assertEqual(blotter.new_orders, [])
# Basic tests of open order behavior
open_order_id = blotter.order(24, 100, MarketOrder())
second_order_id = blotter.order(24, 50, MarketOrder())
self.assertEqual(len(blotter.open_orders[24]), 2)
open_order = blotter.open_orders[24][0]
open_order_id = blotter.order(asset_24, 100, MarketOrder())
second_order_id = blotter.order(asset_24, 50, MarketOrder())
self.assertEqual(len(blotter.open_orders[asset_24]), 2)
open_order = blotter.open_orders[asset_24][0]
self.assertEqual(open_order.status, ORDER_STATUS.OPEN)
self.assertEqual(open_order.id, open_order_id)
self.assertIn(open_order, blotter.new_orders)
@@ -118,7 +229,7 @@ class BlotterTestCase(TestCase):
# Reject that order immediately (same bar, i.e. still in new_orders)
blotter.reject(open_order_id)
self.assertEqual(len(blotter.new_orders), 2)
self.assertEqual(len(blotter.open_orders[24]), 1)
self.assertEqual(len(blotter.open_orders[asset_24]), 1)
still_open_order = blotter.new_orders[0]
self.assertEqual(still_open_order.id, second_order_id)
self.assertEqual(still_open_order.status, ORDER_STATUS.OPEN)
@@ -128,9 +239,10 @@ class BlotterTestCase(TestCase):
# Do it again, but reject it at a later time (after tradesimulation
# pulls it from new_orders)
blotter = Blotter()
new_open_id = blotter.order(24, 10, MarketOrder())
new_open_order = blotter.open_orders[24][0]
blotter = Blotter(self.sim_params.data_frequency,
self.env.asset_finder)
new_open_id = blotter.order(asset_24, 10, MarketOrder())
new_open_order = blotter.open_orders[asset_24][0]
self.assertEqual(new_open_id, new_open_order.id)
# Pretend that the trade simulation did this.
blotter.new_orders = []
@@ -143,18 +255,26 @@ class BlotterTestCase(TestCase):
self.assertEqual(rejected_order.reason, rejection_reason)
# You can't reject a filled order.
blotter = Blotter() # Reset for paranoia
blotter.current_dt = datetime.datetime.now()
filled_id = blotter.order(24, 100, MarketOrder())
aapl_trade = create_trade(24, 50.0, 400, datetime.datetime.now())
# Reset for paranoia
blotter = Blotter(self.sim_params.data_frequency,
self.env.asset_finder)
blotter.slippage_func = FixedSlippage()
filled_id = blotter.order(asset_24, 100, MarketOrder())
filled_order = None
for txn, updated_order in blotter.process_trade(aapl_trade):
filled_order = updated_order
blotter.current_dt = self.sim_params.trading_days[-1]
bar_data = BarData(
self.data_portal,
lambda: self.sim_params.trading_days[-1],
self.sim_params.data_frequency,
)
txns, _ = blotter.get_transactions(bar_data)
for txn in txns:
filled_order = blotter.orders[txn.order_id]
self.assertEqual(filled_order.id, filled_id)
self.assertIn(filled_order, blotter.new_orders)
self.assertEqual(filled_order.status, ORDER_STATUS.FILLED)
self.assertNotIn(filled_order, blotter.open_orders[24])
self.assertNotIn(filled_order, blotter.open_orders[asset_24])
blotter.reject(filled_id)
updated_order = blotter.orders[filled_id]
@@ -166,50 +286,65 @@ class BlotterTestCase(TestCase):
status indication. When a fill happens, the order should switch
status to OPEN/FILLED as necessary
"""
blotter = Blotter()
blotter = Blotter(self.sim_params.data_frequency,
self.env.asset_finder)
# Nothing happens on held of a non-existent order
blotter.hold(56)
self.assertEqual(blotter.new_orders, [])
open_id = blotter.order(24, 100, MarketOrder())
open_order = blotter.open_orders[24][0]
asset_24 = blotter.asset_finder.retrieve_asset(24)
open_id = blotter.order(asset_24, 100, MarketOrder())
open_order = blotter.open_orders[asset_24][0]
self.assertEqual(open_order.id, open_id)
blotter.hold(open_id)
self.assertEqual(len(blotter.new_orders), 1)
self.assertEqual(len(blotter.open_orders[24]), 1)
self.assertEqual(len(blotter.open_orders[asset_24]), 1)
held_order = blotter.new_orders[0]
self.assertEqual(held_order.status, ORDER_STATUS.HELD)
self.assertEqual(held_order.reason, '')
blotter.cancel(held_order.id)
self.assertEqual(len(blotter.new_orders), 1)
self.assertEqual(len(blotter.open_orders[24]), 0)
self.assertEqual(len(blotter.open_orders[asset_24]), 0)
cancelled_order = blotter.new_orders[0]
self.assertEqual(cancelled_order.id, held_order.id)
self.assertEqual(cancelled_order.status, ORDER_STATUS.CANCELLED)
for trade_amt in (100, 400):
for data in ([100, self.sim_params.trading_days[0]],
[400, self.sim_params.trading_days[1]]):
# Verify that incoming fills will change the order status.
trade_amt = data[0]
dt = data[1]
order_size = 100
expected_filled = trade_amt * 0.25
expected_filled = int(trade_amt *
DEFAULT_VOLUME_SLIPPAGE_BAR_LIMIT)
expected_open = order_size - expected_filled
expected_status = ORDER_STATUS.OPEN if expected_open else \
ORDER_STATUS.FILLED
blotter = Blotter()
blotter.current_dt = datetime.datetime.now()
open_id = blotter.order(24, order_size, MarketOrder())
open_order = blotter.open_orders[24][0]
blotter = Blotter(self.sim_params.data_frequency,
self.env.asset_finder)
open_id = blotter.order(blotter.asset_finder.retrieve_asset(24),
order_size, MarketOrder())
open_order = blotter.open_orders[asset_24][0]
self.assertEqual(open_id, open_order.id)
blotter.hold(open_id)
held_order = blotter.new_orders[0]
aapl_trade = create_trade(24, 50.0, trade_amt,
datetime.datetime.now())
filled_order = None
for txn, updated_order in blotter.process_trade(aapl_trade):
filled_order = updated_order
blotter.current_dt = dt
bar_data = BarData(
self.data_portal,
lambda: dt,
self.sim_params.data_frequency,
)
txns, _ = blotter.get_transactions(bar_data)
for txn in txns:
filled_order = blotter.orders[txn.order_id]
self.assertEqual(filled_order.id, held_order.id)
self.assertEqual(filled_order.status, expected_status)
self.assertEqual(filled_order.filled, expected_filled)
+73
View File
@@ -0,0 +1,73 @@
#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pandas.tslib import Timedelta
from zipline.data.data_portal import DataPortal
from zipline.testing.fixtures import WithTradingEnvironment, ZiplineTestCase
import pandas as pd
# Note: most of dataportal functionality is tested in various other places,
# such as test_history.
class TestDataPortal(WithTradingEnvironment, ZiplineTestCase):
    """Direct tests of ``DataPortal`` helpers not covered elsewhere."""

    def init_instance_fixtures(self):
        super(TestDataPortal, self).init_instance_fixtures()
        self.data_portal = DataPortal(self.env)

    def _thirty_minutes_after_open(self, date_str):
        # First element of get_open_and_close is taken as the open.
        open_and_close = self.env.get_open_and_close(pd.Timestamp(date_str))
        return open_and_close[0] + Timedelta("30 minutes")

    def test_bar_count_for_simple_transforms(self):
        """Check minute counts for a 4-"day" window ending mid-session."""
        # July 2015
        # Su Mo Tu We Th Fr Sa
        #           1  2  3  4
        #  5  6  7  8  9 10 11
        # 12 13 14 15 16 17 18
        # 19 20 21 22 23 24 25
        # 26 27 28 29 30 31

        # half an hour into july 9, getting a 4-"day" window should get us
        # all the minutes of 7/6, 7/7, 7/8, and 31 minutes of 7/9
        july_9_dt = self._thirty_minutes_after_open("2015-07-09")
        july_minute_count = \
            self.data_portal._get_minute_count_for_transform(july_9_dt, 4)
        self.assertEqual((3 * 390) + 31, july_minute_count)

        # November 2015
        # Su Mo Tu We Th Fr Sa
        #  1  2  3  4  5  6  7
        #  8  9 10 11 12 13 14
        # 15 16 17 18 19 20 21
        # 22 23 24 25 26 27 28
        # 29 30

        # nov 26th closed
        # nov 27th was an early close

        # half an hour into nov 30, getting a 4-"day" window should get us
        # all the minutes of 11/24, 11/25, 11/27 (half day!), and 31 minutes
        # of 11/30
        nov_30_dt = self._thirty_minutes_after_open("2015-11-30")
        november_minute_count = \
            self.data_portal._get_minute_count_for_transform(nov_30_dt, 4)
        self.assertEqual(390 + 390 + 210 + 31, november_minute_count)
-349
View File
@@ -1,349 +0,0 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import datetime
import pandas as pd
import pytz
import numpy as np
from zipline.finance.trading import SimulationParameters, TradingEnvironment
from zipline.algorithm import TradingAlgorithm
from zipline.protocol import (
Event,
DATASOURCE_TYPE
)
class BuyAndHoldAlgorithm(TradingAlgorithm):
    """Order 100 shares of one sid on the first bar, then do nothing."""

    SID_TO_BUY_AND_HOLD = 1

    def initialize(self):
        # Becomes True once the single order has been placed.
        self.holding = False

    def handle_data(self, data):
        # Already ordered on an earlier bar: nothing more to do.
        if self.holding:
            return

        target = self.sid(self.SID_TO_BUY_AND_HOLD)
        self.order(target, 100)
        self.holding = True
class TestEventsThroughRisk(unittest.TestCase):
    """Run ``BuyAndHoldAlgorithm`` over hand-built events and check risk.

    Both tests feed three sessions (2006-01-03 through 2006-01-05) of trade
    and benchmark events for sid 1 into the algorithm and compare the
    cumulative risk metrics against previously recorded values.
    """

    @classmethod
    def setUpClass(cls):
        # One environment shared by every test; only sid 1 is registered.
        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[1])

    @classmethod
    def tearDownClass(cls):
        del cls.env

    @staticmethod
    def _utc_midnight(day):
        """Return midnight UTC of the given day of January 2006."""
        return datetime.datetime(
            year=2006,
            month=1,
            day=day,
            hour=0,
            minute=0,
            tzinfo=pytz.utc)

    @staticmethod
    def _trade_event(dt, open_price, close_price, volume,
                     source_id='test-trade-source'):
        """Build a TRADE event for sid 1; ``price`` equals ``close_price``."""
        return Event({
            'open_price': open_price,
            'close_price': close_price,
            'price': close_price,
            'volume': volume,
            'sid': 1,
            'dt': dt,
            'source_id': source_id,
            'type': DATASOURCE_TYPE.TRADE
        })

    @staticmethod
    def _benchmark_event(dt, returns):
        """Build a BENCHMARK event carrying ``returns`` at ``dt``."""
        return Event({
            'returns': returns,
            'dt': dt,
            'source_id': 'test-benchmark-source',
            'type': DATASOURCE_TYPE.BENCHMARK
        })

    def test_daily_buy_and_hold(self):
        """Check returns and Sharpe after each bar of a daily simulation."""
        sim_params = SimulationParameters(
            period_start=self._utc_midnight(3),
            period_end=self._utc_midnight(5),
            data_frequency='daily',
            emission_rate='daily'
        )

        algo = BuyAndHoldAlgorithm(sim_params=sim_params, env=self.env)

        first_date = pd.Timestamp('2006-01-03', tz='UTC')
        second_date = pd.Timestamp('2006-01-04', tz='UTC')
        third_date = pd.Timestamp('2006-01-05', tz='UTC')

        trade_bar_data = [
            self._trade_event(first_date, 10, 15, 1000),
            self._trade_event(second_date, 15, 20, 2000,
                              source_id='test_list'),
            self._trade_event(third_date, 20, 15, 1000,
                              source_id='test_list'),
        ]
        benchmark_data = [
            self._benchmark_event(first_date, 0.1),
            self._benchmark_event(second_date, 0.2),
            self._benchmark_event(third_date, 0.4),
        ]

        algo.benchmark_return_source = benchmark_data
        algo.set_sources([trade_bar_data])
        gen = algo._create_generator(sim_params)

        # TODO: Hand derive these results.
        #       For now they are the output recorded at the time of
        #       writing, kept as an early warning against changes.
        expected_algorithm_returns = {
            first_date: 0.0,
            second_date: -0.000350,
            third_date: -0.050018
        }

        # TODO: Hand derive these results.
        #       For now they are the output recorded at the time of
        #       writing, kept as an early warning against changes.
        expected_sharpe = {
            first_date: np.nan,
            second_date: -22.322677,
            third_date: -9.353741
        }

        # The emitted messages themselves are not inspected; advancing the
        # generator is what moves the simulation (and the metrics) forward.
        for _ in gen:
            current_dt = algo.datetime
            crm = algo.perf_tracker.cumulative_risk_metrics
            dt_loc = crm.cont_index.get_loc(current_dt)

            np.testing.assert_almost_equal(
                crm.algorithm_returns[dt_loc],
                expected_algorithm_returns[current_dt],
                decimal=6)

            np.testing.assert_almost_equal(
                crm.sharpe[dt_loc],
                expected_sharpe[current_dt],
                decimal=6,
                err_msg="Mismatch at %s" % (current_dt,))

    def test_minute_buy_and_hold(self):
        """Check positions and returns per daily message in minute mode."""
        first_date = self._utc_midnight(3)
        second_date = self._utc_midnight(4)
        third_date = self._utc_midnight(5)

        sim_params = SimulationParameters(
            period_start=first_date,
            period_end=third_date,
            emission_rate='daily',
            data_frequency='minute',
            env=self.env)

        algo = BuyAndHoldAlgorithm(
            sim_params=sim_params,
            env=self.env)

        first_open, first_close = self.env.get_open_and_close(first_date)
        second_open, second_close = self.env.get_open_and_close(second_date)
        third_open, third_close = self.env.get_open_and_close(third_date)

        # Benchmark returns arrive at each session's close.
        benchmark_data = [
            self._benchmark_event(first_close, 0.1),
            self._benchmark_event(second_close, 0.2),
            self._benchmark_event(third_close, 0.4),
        ]

        # Two trades per session: at the open and ten minutes after it.
        ten_minutes = datetime.timedelta(minutes=10)
        trade_bar_data = [
            self._trade_event(first_open, 10, 15, 1000),
            self._trade_event(first_open + ten_minutes, 10, 15, 1000),
            self._trade_event(second_open, 15, 20, 2000),
            self._trade_event(second_open + ten_minutes, 15, 20, 2000),
            self._trade_event(third_open, 20, 15, 1000),
            self._trade_event(third_open + ten_minutes, 20, 15, 1000),
        ]

        algo.benchmark_return_source = benchmark_data
        algo.set_sources([trade_bar_data])
        gen = algo._create_generator(sim_params)

        crm = algo.perf_tracker.cumulative_risk_metrics
        # Looked up before the first message is consumed, exactly as the
        # volatility assertion below expects.
        dt_loc = crm.cont_index.get_loc(algo.datetime)

        first_msg = next(gen)

        self.assertIsNotNone(first_msg,
                             "There should be a message emitted.")

        # Protects against bug where the positions appeared to be
        # a day late, because benchmarks were triggering
        # calculations before the events for the day were
        # processed.
        self.assertEqual(1, len(algo.portfolio.positions), "There should "
                         "be one position after the first day.")

        self.assertEqual(
            0,
            crm.algorithm_volatility[dt_loc],
            "On the first day algorithm volatility does not exist.")

        second_msg = next(gen)

        self.assertIsNotNone(second_msg, "There should be a message "
                             "emitted.")

        self.assertEqual(1, len(algo.portfolio.positions),
                         "Number of positions should stay the same.")

        # TODO: Hand derive. Current value is just a canary to
        # detect changes.
        np.testing.assert_almost_equal(
            0.050022510129558301,
            crm.algorithm_returns[-1],
            decimal=6)

        third_msg = next(gen)

        self.assertEqual(1, len(algo.portfolio.positions),
                         "Number of positions should stay the same.")

        self.assertIsNotNone(third_msg, "There should be a message "
                             "emitted.")

        # TODO: Hand derive. Current value is just a canary to
        # detect changes.
        np.testing.assert_almost_equal(
            -0.047639464532418657,
            crm.algorithm_returns[-1],
            decimal=6)
+1 -1
View File
@@ -45,7 +45,7 @@ class ExamplesTests(TestCase):
runpy.run_path(example, run_name='__main__')
# Test algorithm as if scripts/run_algo.py is being used.
def test_example_run_pipline(self):
def test_example_run_pipeline(self):
example = os.path.join(example_dir(), 'buyapple.py')
confs = ['-f', example, '--start', '2011-1-1', '--end', '2012-1-1']
parsed_args = parse_args(confs)
+36 -67
View File
@@ -14,8 +14,8 @@
# limitations under the License.
from unittest import TestCase
from testfixtures import TempDirectory
from zipline.finance.slippage import FixedSlippage
from zipline.finance.trading import TradingEnvironment
from zipline.test_algorithms import (
ExceptionAlgorithm,
@@ -23,13 +23,11 @@ from zipline.test_algorithms import (
SetPortfolioAlgorithm,
)
from zipline.testing import (
drain_zipline,
setup_logger,
teardown_logger,
ExceptionSource,
teardown_logger
)
import zipline.utils.simfactory as simfactory
import zipline.utils.factory as factory
from zipline.testing.core import create_data_portal
DEFAULT_TIMEOUT = 15 # seconds
EXTENDED_TIMEOUT = 90
@@ -39,87 +37,58 @@ class ExceptionTestCase(TestCase):
@classmethod
def setUpClass(cls):
cls.sid = 133
cls.env = TradingEnvironment()
cls.env.write_data(equities_identifiers=[133])
cls.env.write_data(equities_identifiers=[cls.sid])
cls.tempdir = TempDirectory()
cls.sim_params = factory.create_simulation_parameters(
num_days=4,
env=cls.env
)
cls.data_portal = create_data_portal(
env=cls.env,
tempdir=cls.tempdir,
sim_params=cls.sim_params,
sids=[cls.sid]
)
setup_logger(cls)
@classmethod
def tearDownClass(cls):
del cls.env
def setUp(self):
self.zipline_test_config = {
'sid': 133,
'slippage': FixedSlippage()
}
setup_logger(self)
def tearDown(self):
teardown_logger(self)
def test_datasource_exception(self):
self.zipline_test_config['trade_source'] = ExceptionSource()
zipline = simfactory.create_test_zipline(
**self.zipline_test_config
)
with self.assertRaises(ZeroDivisionError):
output, _ = drain_zipline(self, zipline)
cls.tempdir.cleanup()
teardown_logger(cls)
def test_exception_in_handle_data(self):
# Simulation
# ----------
self.zipline_test_config['algorithm'] = \
ExceptionAlgorithm(
'handle_data',
self.zipline_test_config['sid'],
sim_params=factory.create_simulation_parameters(),
env=self.env
)
zipline = simfactory.create_test_zipline(
**self.zipline_test_config
)
algo = ExceptionAlgorithm('handle_data',
self.sid,
sim_params=self.sim_params,
env=self.env)
with self.assertRaises(Exception) as ctx:
output, _ = drain_zipline(self, zipline)
algo.run(self.data_portal)
self.assertEqual(str(ctx.exception), 'Algo exception in handle_data')
def test_zerodivision_exception_in_handle_data(self):
# Simulation
# ----------
self.zipline_test_config['algorithm'] = \
DivByZeroAlgorithm(
self.zipline_test_config['sid'],
sim_params=factory.create_simulation_parameters(),
env=self.env
)
zipline = simfactory.create_test_zipline(
**self.zipline_test_config
)
algo = DivByZeroAlgorithm(self.sid,
sim_params=self.sim_params,
env=self.env)
with self.assertRaises(ZeroDivisionError):
output, _ = drain_zipline(self, zipline)
algo.run(self.data_portal)
def test_set_portfolio(self):
"""
Are we protected against overwriting an algo's portfolio?
"""
# Simulation
# ----------
self.zipline_test_config['algorithm'] = \
SetPortfolioAlgorithm(
self.zipline_test_config['sid'],
sim_params=factory.create_simulation_parameters(),
env=self.env
)
zipline = simfactory.create_test_zipline(
**self.zipline_test_config
)
algo = SetPortfolioAlgorithm(self.sid,
sim_params=self.sim_params,
env=self.env)
with self.assertRaises(AttributeError):
output, _ = drain_zipline(self, zipline)
algo.run(self.data_portal)
+528
View File
@@ -0,0 +1,528 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from nose_parameterized import parameterized
import pandas as pd
import numpy as np
import responses
from mock import patch
from zipline import TradingAlgorithm
from zipline.errors import UnsupportedOrderParameters
from zipline.finance.trading import TradingEnvironment
from zipline.sources.requests_csv import mask_requests_args
from zipline.utils import factory
from zipline.testing.core import FetcherDataPortal
from .resources.fetcher_inputs.fetcher_test_data import (
MULTI_SIGNAL_CSV_DATA,
AAPL_CSV_DATA,
AAPL_MINUTE_CSV_DATA,
IBM_CSV_DATA,
ANNUAL_AAPL_CSV_DATA,
AAPL_IBM_CSV_DATA,
CPIAUCSL_DATA,
PALLADIUM_DATA,
FETCHER_UNIVERSE_DATA,
NON_ASSET_FETCHER_UNIVERSE_DATA,
FETCHER_UNIVERSE_DATA_TICKER_COLUMN, FETCHER_ALTERNATE_COLUMN_HEADER)
class FetcherTestCase(TestCase):
    """Tests for ``fetch_csv`` run through ``TradingAlgorithm`` scripts.

    HTTP traffic is intercepted with ``responses`` (or by patching
    ``PandasRequestsCSV.fetch_url``), so every URL used below is served from
    the canned CSV fixtures imported at the top of the file.
    """

    @classmethod
    def setUpClass(cls):
        responses.start()

        # (file name under https://fake.urls.com/, canned CSV body)
        fake_csvs = (
            ('aapl_minute_csv_data.csv', AAPL_MINUTE_CSV_DATA),
            ('aapl_csv_data.csv', AAPL_CSV_DATA),
            ('multi_signal_csv_data.csv', MULTI_SIGNAL_CSV_DATA),
            ('cpiaucsl_data.csv', CPIAUCSL_DATA),
            ('ibm_csv_data.csv', IBM_CSV_DATA),
            ('aapl_ibm_csv_data.csv', AAPL_IBM_CSV_DATA),
            ('palladium_data.csv', PALLADIUM_DATA),
            ('fetcher_universe_data.csv', FETCHER_UNIVERSE_DATA),
            ('bad_fetcher_universe_data.csv',
             NON_ASSET_FETCHER_UNIVERSE_DATA),
            ('annual_aapl_csv_data.csv', ANNUAL_AAPL_CSV_DATA),
        )
        for filename, body in fake_csvs:
            responses.add(responses.GET,
                          'https://fake.urls.com/' + filename,
                          body=body, content_type='text/csv')

        cls.sim_params = factory.create_simulation_parameters()
        cls.env = TradingEnvironment()

        # Every equity shares the same metadata apart from sid and symbol.
        equities = (
            (24, "AAPL"),
            (3766, "IBM"),
            (5061, "MSFT"),
            (14848, "YHOO"),
            (25317, "DELL"),
        )
        cls.env.write_data(
            equities_data={
                sid: {
                    "start_date": pd.Timestamp("2006-01-01", tz='UTC'),
                    "end_date": pd.Timestamp("2007-01-01", tz='UTC'),
                    'symbol': symbol,
                    "asset_type": "equity",
                    "exchange": "nasdaq"
                }
                for sid, symbol in equities
            }
        )

    @classmethod
    def tearDownClass(cls):
        responses.stop()
        responses.reset()

    def run_algo(self, code, sim_params=None, data_frequency="daily"):
        """Run the algorithm script ``code`` and return its results.

        ``sim_params`` defaults to the class-wide simulation parameters.
        """
        if sim_params is None:
            sim_params = self.sim_params

        test_algo = TradingAlgorithm(
            script=code,
            sim_params=sim_params,
            env=self.env,
            data_frequency=data_frequency
        )

        return test_algo.run(FetcherDataPortal(self.env))

    def test_fetch_minute_granularity(self):
        """Minute-level fetcher data is forward filled across sessions."""
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-03", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC'),
            emission_rate="minute",
            data_frequency="minute"
        )

        test_algo = TradingAlgorithm(
            script="""
from zipline.api import fetch_csv, record, sid
def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_minute_csv_data.csv')
def handle_data(context, data):
    record(aapl_signal=data.current(sid(24), "signal"))
""", sim_params=sim_params, data_frequency="minute", env=self.env)

        # manually setting data portal and getting generator because we need
        # the minutely emission packets here.  TradingAlgorithm.run() only
        # returns daily packets.
        test_algo.data_portal = FetcherDataPortal(self.env)
        gen = test_algo.get_generator()
        perf_packets = list(gen)

        signal = [result["minute_perf"]["recorded_vars"]["aapl_signal"] for
                  result in perf_packets if "minute_perf" in result]

        self.assertEqual(6 * 390, len(signal))

        # csv data is:
        # symbol,date,signal
        # aapl,1/4/06 5:31AM, 1
        # aapl,1/4/06 11:30AM, 2
        # aapl,1/5/06 5:31AM, 1
        # aapl,1/5/06 11:30AM, 3
        # aapl,1/9/06 5:31AM, 1
        # aapl,1/9/06 11:30AM, 4
        #
        # The simulation covers 1/3 to 1/10. There are 2 signals per day and
        # only the last one is taken, so we expect:
        # 390 bars of signal NaN on 1/3
        # 390 bars of signal 2 on 1/4
        # 390 bars of signal 3 on 1/5
        # 390 bars of signal 3 on 1/6 (forward filled)
        # 390 bars of signal 4 on 1/9
        # 390 bars of signal 4 on 1/10 (forward filled)

        np.testing.assert_array_equal([np.nan] * 390, signal[0:390])
        np.testing.assert_array_equal([2] * 390, signal[390:780])
        np.testing.assert_array_equal([3] * 780, signal[780:1560])
        np.testing.assert_array_equal([4] * 780, signal[1560:])

    def test_fetch_csv_with_multi_symbols(self):
        """One CSV can supply the signal column for several sids."""
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid
def initialize(context):
    fetch_csv('https://fake.urls.com/multi_signal_csv_data.csv')
    context.stocks = [sid(3766), sid(25317)]
def handle_data(context, data):
    record(ibm_signal=data.current(sid(3766), "signal"))
    record(dell_signal=data.current(sid(25317), "signal"))
""")

        self.assertEqual(5, results["ibm_signal"].iloc[-1])
        self.assertEqual(5, results["dell_signal"].iloc[-1])

    def test_fetch_csv_with_pure_signal_file(self):
        """A CSV keyed by an arbitrary symbol string can be queried."""
        results = self.run_algo(
            """
from zipline.api import fetch_csv, sid, record
def clean(df):
    return df.rename(columns={'Value':'cpi', 'Date':'date'})
def initialize(context):
    fetch_csv(
        'https://fake.urls.com/cpiaucsl_data.csv',
        symbol='urban',
        pre_func=clean,
        date_format='%Y-%m-%d'
        )
    context.stocks = [sid(3766), sid(25317)]
def handle_data(context, data):
    cur_cpi = data.current("urban", "cpi")
    record(cpi=cur_cpi)
""")

        self.assertEqual(results["cpi"].iloc[-1], 203.1)

    def test_algo_fetch_csv(self):
        """``post_func`` may add derived columns to the fetched frame."""
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid
def normalize(df):
    df['scaled'] = df['signal'] * 10
    return df
def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_csv_data.csv',
            post_func=normalize)
    context.checked_name = False
def handle_data(context, data):
    record(
        signal=data.current(sid(24), "signal"),
        scaled=data.current(sid(24), "scaled"),
        price=data.current(sid(24), "price"))
""")

        self.assertEqual(5, results["signal"].iloc[-1])
        self.assertEqual(50, results["scaled"].iloc[-1])
        self.assertEqual(24, results["price"].iloc[-1])  # fake value

    def test_algo_fetch_csv_with_extra_symbols(self):
        """Same as ``test_algo_fetch_csv`` but on a two-symbol CSV, masked."""
        results = self.run_algo(
            """
from zipline.api import fetch_csv, record, sid
def normalize(df):
    df['scaled'] = df['signal'] * 10
    return df
def initialize(context):
    fetch_csv('https://fake.urls.com/aapl_ibm_csv_data.csv',
            post_func=normalize,
            mask=True)
def handle_data(context, data):
    record(
        signal=data.current(sid(24),"signal"),
        scaled=data.current(sid(24), "scaled"),
        price=data.current(sid(24), "price"))
"""
        )

        self.assertEqual(5, results["signal"].iloc[-1])
        self.assertEqual(50, results["scaled"].iloc[-1])
        self.assertEqual(24, results["price"].iloc[-1])  # fake value

    @parameterized.expand([("unspecified", ""),
                           ("none", "usecols=None"),
                           ("empty", "usecols=[]"),
                           ("without date", "usecols=['Value']"),
                           ("with date", "usecols=('Value', 'Date')")])
    def test_usecols(self, testname, usecols):
        """``usecols`` variants either expose the ``cpi`` column or not."""
        # Doubled braces are literal braces; {usecols} and
        # {should_have_data} are filled in by str.format below.
        code = """
from zipline.api import fetch_csv, sid, record
def clean(df):
    return df.rename(columns={{'Value':'cpi'}})
def initialize(context):
    fetch_csv(
        'https://fake.urls.com/cpiaucsl_data.csv',
        symbol='urban',
        pre_func=clean,
        date_column='Date',
        date_format='%Y-%m-%d',{usecols}
        )
    context.stocks = [sid(3766), sid(25317)]
def handle_data(context, data):
    if {should_have_data}:
        try:
            data.current("urban", "cpi")
        except (KeyError, ValueError):
            assert False
    else:
        try:
            data.current("urban", "cpi")
        except (KeyError, ValueError):
            assert True
"""

        results = self.run_algo(
            code.format(
                usecols=usecols,
                should_have_data=testname in [
                    'none',
                    'unspecified',
                    'without date',
                    'with date',
                ],
            )
        )

        # 251 trading days in 2006
        self.assertEqual(len(results), 251)

    def test_sources_merge_custom_ticker(self):
        """Custom-ticker data merges with equity data; check request kwargs."""
        requests_kwargs = {}

        def capture_kwargs(zelf, url, **kwargs):
            requests_kwargs.update(
                mask_requests_args(url, kwargs).requests_kwargs
            )
            return PALLADIUM_DATA

        # Patching fetch_url instead of using responses in this test so that we
        # can intercept the requests keyword arguments and confirm that they're
        # correct.
        with patch('zipline.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=capture_kwargs):
            results = self.run_algo(
                """
from zipline.api import fetch_csv, record, sid
def rename_col(df):
    df = df.rename(columns={'New York 15:00': 'price'})
    df = df.fillna(method='ffill')
    return df[['price', 'sid']]
def initialize(context):
    fetch_csv('https://dl.dropbox.com/u/16705795/PALL.csv',
        date_column='Date',
        symbol='palladium',
        post_func=rename_col,
        date_format='%Y-%m-%d'
        )
    context.stock = sid(24)
def handle_data(context, data):
    record(palladium=data.current("palladium", "price"))
    record(aapl=data.current(context.stock, "price"))
""")

        np.testing.assert_array_equal([24] * 251, results["aapl"])
        self.assertEqual(337, results["palladium"].iloc[-1])

        expected = {
            'allow_redirects': False,
            'stream': True,
            'timeout': 30.0,
        }
        self.assertEqual(expected, requests_kwargs)

    @parameterized.expand([("symbol", FETCHER_UNIVERSE_DATA, None),
                           ("arglebargle", FETCHER_UNIVERSE_DATA_TICKER_COLUMN,
                            FETCHER_ALTERNATE_COLUMN_HEADER)])
    def test_fetcher_universe(self, name, data, column_name):
        """``data.fetcher_assets`` tracks the sids listed in the CSV per day."""
        # Doubled braces are literal braces; {token} is replaced below with
        # an optional ``symbol_column`` argument.
        algocode = """
from pandas import Timestamp
from zipline.api import fetch_csv, record, sid, get_datetime
def initialize(context):
    fetch_csv(
        'https://dl.dropbox.com/u/16705795/dtoc_history.csv',
        date_format='%m/%d/%Y'{token}
    )
    context.expected_sids = {{
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }}
    context.bar_count = 0
def handle_data(context, data):
    expected = context.expected_sids[get_datetime()]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception(
                "{{stk}} is missing on dt={{dt}}".format(
                    stk=stk, dt=get_datetime()))
    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
"""
        replacement = ""
        if column_name:
            replacement = ",symbol_column='%s'\n" % column_name
        real_algocode = algocode.format(token=replacement)

        # Patching fetch_url here rather than using responses because (a) it's
        # easier given the parameterization, and (b) there are enough tests
        # using responses that the fetch_url code is getting a good workout so
        # we don't have to use it in every test.
        with patch('zipline.sources.requests_csv.PandasRequestsCSV.fetch_url',
                   new=lambda *a, **k: data):
            sim_params = factory.create_simulation_parameters(
                start=pd.Timestamp("2006-01-09", tz='UTC'),
                end=pd.Timestamp("2006-01-11", tz='UTC')
            )
            results = self.run_algo(real_algocode, sim_params=sim_params)

        self.assertEqual(len(results), 3)
        self.assertEqual(3, results["sid_count"].iloc[0])
        self.assertEqual(3, results["sid_count"].iloc[1])
        self.assertEqual(4, results["sid_count"].iloc[2])

    def test_fetcher_universe_non_security_return(self):
        """Rows that do not map to known assets leave fetcher_assets empty."""
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-10", tz='UTC')
        )

        self.run_algo(
            """
from zipline.api import fetch_csv
def initialize(context):
    fetch_csv(
        'https://fake.urls.com/bad_fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )
def handle_data(context, data):
    if len(data.fetcher_assets) > 0:
        raise Exception("Shouldn't be any assets in fetcher_assets!")
""",
            sim_params=sim_params,
        )

    def test_order_against_data(self):
        """Ordering a fetcher-only symbol raises UnsupportedOrderParameters."""
        with self.assertRaises(UnsupportedOrderParameters):
            self.run_algo("""
from zipline.api import fetch_csv, order, sid
def rename_col(df):
    return df.rename(columns={'New York 15:00': 'price'})
def initialize(context):
    fetch_csv('https://fake.urls.com/palladium_data.csv',
        date_column='Date',
        symbol='palladium',
        post_func=rename_col,
        date_format='%Y-%m-%d'
        )
    context.stock = sid(24)
def handle_data(context, data):
    order('palladium', 100)
""")

    def test_fetcher_universe_minute(self):
        """Minute-mode variant of the fetcher universe check."""
        sim_params = factory.create_simulation_parameters(
            start=pd.Timestamp("2006-01-09", tz='UTC'),
            end=pd.Timestamp("2006-01-11", tz='UTC'),
            data_frequency="minute"
        )

        results = self.run_algo(
            """
from pandas import Timestamp
from zipline.api import fetch_csv, record, get_datetime
def initialize(context):
    fetch_csv(
        'https://fake.urls.com/fetcher_universe_data.csv',
        date_format='%m/%d/%Y'
    )
    context.expected_sids = {
        Timestamp('2006-01-09 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-10 00:00:00+0000', tz='UTC'):[24, 3766, 5061],
        Timestamp('2006-01-11 00:00:00+0000', tz='UTC'):[24, 3766, 5061, 14848]
    }
    context.bar_count = 0
def handle_data(context, data):
    expected = context.expected_sids[get_datetime().replace(hour=0, minute=0)]
    actual = data.fetcher_assets
    for stk in expected:
        if stk not in actual:
            raise Exception("{stk} is missing".format(stk=stk))
    record(sid_count=len(actual))
    record(bar_count=context.bar_count)
    context.bar_count += 1
""", sim_params=sim_params, data_frequency="minute"
        )

        self.assertEqual(3, len(results))
        self.assertEqual(3, results["sid_count"].iloc[0])
        self.assertEqual(3, results["sid_count"].iloc[1])
        self.assertEqual(4, results["sid_count"].iloc[2])
+192 -154
View File
@@ -17,8 +17,7 @@
Tests for the zipline.finance package
"""
from datetime import datetime, timedelta
import itertools
import operator
import os
from unittest import TestCase
@@ -27,22 +26,27 @@ import numpy as np
import pandas as pd
import pytz
from six.moves import range
from testfixtures import TempDirectory
from zipline.finance.blotter import Blotter
from zipline.finance.execution import MarketOrder, LimitOrder
from zipline.finance.trading import TradingEnvironment
from zipline.finance.performance import PerformanceTracker
from zipline.finance.trading import SimulationParameters
from zipline.gens.composites import date_sorted_sources
import zipline.protocol
from zipline.protocol import Event, DATASOURCE_TYPE
from zipline.testing import(
from zipline.testing import (
setup_logger,
teardown_logger,
assert_single_position
teardown_logger
)
from zipline.data.us_equity_pricing import BcolzDailyBarReader
from zipline.data.minute_bars import BcolzMinuteBarReader
from zipline.data.data_portal import DataPortal
from zipline.finance.slippage import FixedSlippage
from zipline.protocol import BarData
from zipline.testing.core import write_bcolz_minute_data
from .utils.daily_bar_writer import DailyBarWriterFromDataFrames
import zipline.utils.factory as factory
import zipline.utils.simfactory as simfactory
DEFAULT_TIMEOUT = 15 # seconds
EXTENDED_TIMEOUT = 90
@@ -55,7 +59,7 @@ class FinanceTestCase(TestCase):
@classmethod
def setUpClass(cls):
cls.env = TradingEnvironment()
cls.env.write_data(equities_identifiers=[1, 133])
cls.env.write_data(equities_identifiers=[1, 2, 133])
@classmethod
def tearDownClass(cls):
@@ -71,34 +75,6 @@ class FinanceTestCase(TestCase):
def tearDown(self):
teardown_logger(self)
@timed(DEFAULT_TIMEOUT)
def test_factory_daily(self):
sim_params = factory.create_simulation_parameters()
trade_source = factory.create_daily_trade_source(
[133],
sim_params,
env=self.env,
)
prev = None
for trade in trade_source:
if prev:
self.assertTrue(trade.dt > prev.dt)
prev = trade
@timed(EXTENDED_TIMEOUT)
def test_full_zipline(self):
# provide enough trades to ensure all orders are filled.
self.zipline_test_config['order_count'] = 100
# making a small order amount, so that each order is filled
# in a single transaction, and txn_count == order_count.
self.zipline_test_config['order_amount'] = 25
# No transactions can be filled on the first trade, so
# we have one extra trade to ensure all orders are filled.
self.zipline_test_config['trade_count'] = 101
full_zipline = simfactory.create_test_zipline(
**self.zipline_test_config)
assert_single_position(self, full_zipline)
# TODO: write tests for short sales
# TODO: write a test to do massive buying or shorting.
@@ -109,16 +85,17 @@ class FinanceTestCase(TestCase):
# so that orders must be spread out over several trades.
params = {
'trade_count': 360,
'trade_amount': 100,
'trade_interval': timedelta(minutes=1),
'order_count': 2,
'order_amount': 100,
'order_interval': timedelta(minutes=1),
# because we placed an order for 100 shares, and the volume
# of each trade is 100, the simulator should spread the order
# into 4 trades of 25 shares per order.
'expected_txn_count': 8,
'expected_txn_volume': 2 * 100
# because we placed two orders for 100 shares each, and the volume
# of each trade is 100, and by default you can take up 2.5% of the
# bar's volume, the simulator should spread the order into 100
# trades of 2 shares per order.
'expected_txn_count': 100,
'expected_txn_volume': 2 * 100,
'default_slippage': True
}
self.transaction_sim(**params)
@@ -126,13 +103,13 @@ class FinanceTestCase(TestCase):
# same scenario, but with short sales
params2 = {
'trade_count': 360,
'trade_amount': 100,
'trade_interval': timedelta(minutes=1),
'order_count': 2,
'order_amount': -100,
'order_interval': timedelta(minutes=1),
'expected_txn_count': 8,
'expected_txn_volume': 2 * -100
'expected_txn_count': 100,
'expected_txn_volume': 2 * -100,
'default_slippage': True
}
self.transaction_sim(**params2)
@@ -144,7 +121,6 @@ class FinanceTestCase(TestCase):
# but are represented by multiple transactions.
params1 = {
'trade_count': 6,
'trade_amount': 100,
'trade_interval': timedelta(hours=1),
'order_count': 24,
'order_amount': 1,
@@ -159,7 +135,6 @@ class FinanceTestCase(TestCase):
# second verse, same as the first. except short!
params2 = {
'trade_count': 6,
'trade_amount': 100,
'trade_interval': timedelta(hours=1),
'order_count': 24,
'order_amount': -1,
@@ -173,7 +148,6 @@ class FinanceTestCase(TestCase):
# Ensuring that our delay works for daily intervals as well.
params3 = {
'trade_count': 6,
'trade_amount': 100,
'trade_interval': timedelta(days=1),
'order_count': 24,
'order_amount': 1,
@@ -188,7 +162,6 @@ class FinanceTestCase(TestCase):
# create a scenario where we alternate buys and sells
params1 = {
'trade_count': int(6.5 * 60 * 4),
'trade_amount': 100,
'trade_interval': timedelta(minutes=1),
'order_count': 4,
'order_amount': 10,
@@ -204,136 +177,201 @@ class FinanceTestCase(TestCase):
""" This is a utility method that asserts expected
results for conversion of orders to transactions given a
trade history"""
tempdir = TempDirectory()
try:
trade_count = params['trade_count']
trade_interval = params['trade_interval']
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
expected_txn_volume = params['expected_txn_volume']
trade_count = params['trade_count']
trade_interval = params['trade_interval']
order_count = params['order_count']
order_amount = params['order_amount']
order_interval = params['order_interval']
expected_txn_count = params['expected_txn_count']
expected_txn_volume = params['expected_txn_volume']
# optional parameters
# ---------------------
# if present, alternate between long and short sales
alternate = params.get('alternate')
# if present, expect transaction amounts to match orders exactly.
complete_fill = params.get('complete_fill')
# optional parameters
# ---------------------
# if present, alternate between long and short sales
alternate = params.get('alternate')
sid = 1
sim_params = factory.create_simulation_parameters()
blotter = Blotter()
price = [10.1] * trade_count
volume = [100] * trade_count
start_date = sim_params.first_open
# if present, expect transaction amounts to match orders exactly.
complete_fill = params.get('complete_fill')
generated_trades = factory.create_trade_history(
sid,
price,
volume,
trade_interval,
sim_params,
env=self.env,
)
env = TradingEnvironment()
if alternate:
alternator = -1
else:
alternator = 1
sid = 1
order_date = start_date
for i in range(order_count):
if trade_interval < timedelta(days=1):
sim_params = factory.create_simulation_parameters(
data_frequency="minute"
)
blotter.set_date(order_date)
blotter.order(sid, order_amount * alternator ** i, MarketOrder())
minutes = env.market_minute_window(
sim_params.first_open,
int((trade_interval.total_seconds() / 60) * trade_count)
+ 100)
order_date = order_date + order_interval
# move after-market orders to just after the next
# market open.
if order_date.hour >= 21:
if order_date.minute >= 00:
order_date = order_date + timedelta(days=1)
order_date = order_date.replace(hour=14, minute=30)
price_data = np.array([10.1] * len(minutes))
assets = {
sid: pd.DataFrame({
"open": price_data,
"high": price_data,
"low": price_data,
"close": price_data,
"volume": np.array([100] * len(minutes)),
"dt": minutes
}).set_index("dt")
}
# there should now be one open order list stored under the sid
oo = blotter.open_orders
self.assertEqual(len(oo), 1)
self.assertTrue(sid in oo)
order_list = oo[sid][:] # make copy
self.assertEqual(order_count, len(order_list))
write_bcolz_minute_data(
env,
env.days_in_range(minutes[0], minutes[-1]),
tempdir.path,
assets
)
for i in range(order_count):
order = order_list[i]
self.assertEqual(order.sid, sid)
self.assertEqual(order.amount, order_amount * alternator ** i)
equity_minute_reader = BcolzMinuteBarReader(tempdir.path)
tracker = PerformanceTracker(sim_params, env=self.env)
data_portal = DataPortal(
env,
equity_minute_reader=equity_minute_reader,
)
else:
sim_params = factory.create_simulation_parameters(
data_frequency="daily"
)
benchmark_returns = [
Event({'dt': dt,
'returns': ret,
'type':
zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
'source_id': 'benchmarks'})
for dt, ret in self.env.benchmark_returns.iteritems()
if dt.date() >= sim_params.period_start.date() and
dt.date() <= sim_params.period_end.date()
]
days = sim_params.trading_days
generated_events = date_sorted_sources(generated_trades,
benchmark_returns)
assets = {
1: pd.DataFrame({
"open": [10.1] * len(days),
"high": [10.1] * len(days),
"low": [10.1] * len(days),
"close": [10.1] * len(days),
"volume": [100] * len(days),
"day": [day.value for day in days]
}, index=days)
}
# this approximates the loop inside TradingSimulationClient
transactions = []
for dt, events in itertools.groupby(generated_events,
operator.attrgetter('dt')):
for event in events:
if event.type == DATASOURCE_TYPE.TRADE:
path = os.path.join(tempdir.path, "testdata.bcolz")
DailyBarWriterFromDataFrames(assets).write(
path, days, assets)
for txn, order in blotter.process_trade(event):
transactions.append(txn)
equity_daily_reader = BcolzDailyBarReader(path)
data_portal = DataPortal(
env,
equity_daily_reader=equity_daily_reader,
)
if "default_slippage" not in params or \
not params["default_slippage"]:
slippage_func = FixedSlippage()
else:
slippage_func = None
blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
slippage_func)
env.write_data(equities_data={
sid: {
"start_date": sim_params.trading_days[0],
"end_date": sim_params.trading_days[-1]
}
})
start_date = sim_params.first_open
if alternate:
alternator = -1
else:
alternator = 1
tracker = PerformanceTracker(sim_params, self.env, data_portal)
# replicate what tradesim does by going through every minute or day
# of the simulation and processing open orders each time
if sim_params.data_frequency == "minute":
ticks = minutes
else:
ticks = days
transactions = []
order_list = []
order_date = start_date
for tick in ticks:
blotter.current_dt = tick
if tick >= order_date and len(order_list) < order_count:
# place an order
direction = alternator ** len(order_list)
order_id = blotter.order(
blotter.asset_finder.retrieve_asset(sid),
order_amount * direction,
MarketOrder())
order_list.append(blotter.orders[order_id])
order_date = order_date + order_interval
# move after-market orders to just after the next
# market open.
if order_date.hour >= 21:
if order_date.minute >= 00:
order_date = order_date + timedelta(days=1)
order_date = order_date.replace(hour=14, minute=30)
else:
bar_data = BarData(
data_portal,
lambda: tick,
sim_params.data_frequency
)
txns, _ = blotter.get_transactions(bar_data)
for txn in txns:
tracker.process_transaction(txn)
elif event.type == DATASOURCE_TYPE.BENCHMARK:
tracker.process_benchmark(event)
elif event.type == DATASOURCE_TYPE.TRADE:
tracker.process_trade(event)
transactions.append(txn)
if complete_fill:
self.assertEqual(len(transactions), len(order_list))
total_volume = 0
for i in range(len(transactions)):
txn = transactions[i]
total_volume += txn.amount
if complete_fill:
for i in range(order_count):
order = order_list[i]
self.assertEqual(order.amount, txn.amount)
self.assertEqual(order.sid, sid)
self.assertEqual(order.amount, order_amount * alternator ** i)
self.assertEqual(total_volume, expected_txn_volume)
self.assertEqual(len(transactions), expected_txn_count)
if complete_fill:
self.assertEqual(len(transactions), len(order_list))
cumulative_pos = tracker.cumulative_performance.positions[sid]
self.assertEqual(total_volume, cumulative_pos.amount)
total_volume = 0
for i in range(len(transactions)):
txn = transactions[i]
total_volume += txn.amount
if complete_fill:
order = order_list[i]
self.assertEqual(order.amount, txn.amount)
# the open orders should not contain sid.
oo = blotter.open_orders
self.assertNotIn(sid, oo, "Entry is removed when no open orders")
self.assertEqual(total_volume, expected_txn_volume)
self.assertEqual(len(transactions), expected_txn_count)
cumulative_pos = tracker.position_tracker.positions[sid]
if total_volume == 0:
self.assertIsNone(cumulative_pos)
else:
self.assertEqual(total_volume, cumulative_pos.amount)
# the open orders should not contain sid.
oo = blotter.open_orders
self.assertNotIn(sid, oo, "Entry is removed when no open orders")
finally:
tempdir.cleanup()
def test_blotter_processes_splits(self):
sim_params = factory.create_simulation_parameters()
blotter = Blotter()
blotter.set_date(sim_params.period_start)
blotter = Blotter('daily', self.env.asset_finder,
slippage_func=FixedSlippage())
# set up two open limit orders with very low limit prices,
# one for sid 1 and one for sid 2
blotter.order(1, 100, LimitOrder(10))
blotter.order(2, 100, LimitOrder(10))
blotter.order(
blotter.asset_finder.retrieve_asset(1), 100, LimitOrder(10))
blotter.order(
blotter.asset_finder.retrieve_asset(2), 100, LimitOrder(10))
# send in a split for sid 2
split_event = factory.create_split(2, 0.33333,
sim_params.period_start +
timedelta(days=1))
blotter.process_split(split_event)
blotter.process_splits([(2, 0.3333)])
for sid in [1, 2]:
order_lists = blotter.open_orders[sid]
+1694 -1430
View File
File diff suppressed because it is too large Load Diff
+879 -1060
View File
File diff suppressed because it is too large Load Diff
-54
View File
@@ -1,54 +0,0 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from zipline.utils.serialization_utils import (
loads_with_persistent_ids, dumps_with_persistent_ids
)
from nose_parameterized import parameterized
from unittest import TestCase
from .serialization_cases import (
object_serialization_cases,
assert_dict_equal,
cases_env,
)
class PickleSerializationTestCase(TestCase):
    """Round-trip every serialization case through the persistent-id
    pickle helpers and check the restored object matches the original.
    """

    @parameterized.expand(object_serialization_cases())
    def test_object_serialization(self,
                                  _,
                                  cls,
                                  initargs,
                                  di_vars,
                                  comparison_method='dict'):
        """Dump and reload one object, then compare it with the original.

        ``cls(*initargs)`` builds the object and every ``di_vars`` entry is
        set on it as an attribute before dumping.  The same attributes are
        set again on the reloaded object before the comparison selected by
        ``comparison_method`` ('repr', 'to_dict', or the default
        ``__dict__`` comparison) is run.
        """
        original = cls(*initargs)
        for attr_name, attr_value in di_vars.items():
            setattr(original, attr_name, attr_value)

        pickled = dumps_with_persistent_ids(original)
        restored = loads_with_persistent_ids(pickled, env=cases_env)

        # Re-apply the injected attributes to the reloaded object.
        for attr_name, attr_value in di_vars.items():
            setattr(restored, attr_name, attr_value)

        if comparison_method == 'repr':
            self.assertEqual(original.__repr__(), restored.__repr__())
            return
        if comparison_method == 'to_dict':
            assert_dict_equal(original.to_dict(), restored.to_dict())
            return
        assert_dict_equal(original.__dict__, restored.__dict__)
-230
View File
@@ -1,230 +0,0 @@
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from collections import deque
import numpy as np
import pandas as pd
import pandas.util.testing as tm
from zipline.utils.data import MutableIndexRollingPanel, RollingPanel
from zipline.finance.trading import TradingEnvironment
class TestRollingPanel(unittest.TestCase):
    """Tests for ``RollingPanel`` date alignment and copy semantics."""

    @classmethod
    def setUpClass(cls):
        # One TradingEnvironment is shared by every test in the class.
        cls.env = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def test_alignment(self):
        """Check ``get_current`` after ``add_frame`` and ``extend_back``.

        The panel is seeded with the two middle minutes of a four-minute
        window, a frame is added at the last minute, and the current panel
        is compared with a hand-built expectation.  The panel is then
        extended back with the first two minutes and compared again.
        """
        items = ('a', 'b')
        sids = (1, 2)

        dts = self.env.market_minute_window(
            self.env.open_and_closes.market_open[0], 4,
        ).values
        rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1])

        frame = pd.DataFrame(
            data=np.arange(4).reshape((2, 2)),
            columns=sids,
            index=items,
        )

        rp.add_frame(dts[-1], frame)

        cur = rp.get_current()
        data = np.array((((np.nan, np.nan),
                          (0, 1)),
                         ((np.nan, np.nan),
                          (2, 3))),
                        float)
        expected = pd.Panel(
            data,
            major_axis=dts[2:],
            minor_axis=sids,
            items=items,
        )
        expected.major_axis = expected.major_axis.tz_localize('utc')
        tm.assert_panel_equal(
            cur,
            expected,
        )

        rp.extend_back(dts[:-2])

        cur = rp.get_current()
        data = np.array((((np.nan, np.nan),
                          (np.nan, np.nan),
                          (np.nan, np.nan),
                          (0, 1)),
                         ((np.nan, np.nan),
                          (np.nan, np.nan),
                          (np.nan, np.nan),
                          (2, 3))),
                        float)
        expected = pd.Panel(
            data,
            major_axis=dts,
            minor_axis=sids,
            items=items,
        )
        expected.major_axis = expected.major_axis.tz_localize('utc')
        tm.assert_panel_equal(
            cur,
            expected,
        )

    def test_get_current_multiple_call_same_tick(self):
        """
        In the old get_current, each call to get_current would copy the
        data, so changing the returned object had no side effects.

        To keep the same API, make sure that the raw option returns a copy
        too.
        """
        def data_id(values):
            # Describes the underlying buffer; two arrays backed by the
            # same buffer report the same value here.
            return values.__array_interface__['data']

        items = ('a', 'b')
        sids = (1, 2)

        dts = self.env.market_minute_window(
            self.env.open_and_closes.market_open[0], 4,
        ).values
        rp = RollingPanel(2, items, sids, initial_dates=dts[1:-1])

        frame = pd.DataFrame(
            data=np.arange(4).reshape((2, 2)),
            columns=sids,
            index=items,
        )

        rp.add_frame(dts[-1], frame)

        # each get_current call makes a copy
        cur = rp.get_current()
        cur2 = rp.get_current()
        # unittest assertions (not bare ``assert``) so the checks still run
        # when Python is started with -O.
        self.assertNotEqual(data_id(cur.values), data_id(cur2.values))

        # make sure raw follows the same logic
        raw = rp.get_current(raw=True)
        raw2 = rp.get_current(raw=True)
        self.assertNotEqual(data_id(raw), data_id(raw2))
class TestMutableIndexRollingPanel(unittest.TestCase):
    """Tests for ``MutableIndexRollingPanel`` windowing and label churn."""

    def test_basics(self, window=10):
        """Add 30 daily frames and, after each one, compare the panel with
        one built from the most recent ``window`` dates.
        """
        items = ['bar', 'baz', 'foo']
        minor = ['A', 'B', 'C', 'D']

        panel = MutableIndexRollingPanel(window, items, minor, cap_multiple=2)

        frames_by_date = {}
        recent_dates = deque(maxlen=window)

        for date in pd.date_range('2000-01-01', periods=30, tz='utc'):
            frame = pd.DataFrame(np.random.randn(3, 4), index=items,
                                 columns=minor)
            panel.add_frame(date, frame)
            frames_by_date[date] = frame
            recent_dates.append(date)

            result = panel.get_current()
            expected = pd.Panel(frames_by_date, items=list(recent_dates),
                                major_axis=items, minor_axis=minor)

            tm.assert_panel_equal(result, expected.swapaxes(0, 1))

    def test_adding_and_dropping_items(self, n_items=5, n_minor=10, window=10,
                                       periods=30):
        """Rotate one new label into both axes on every tick and check that
        the panel's labels and contents track the expected window.
        """
        np.random.seed(123)

        items = deque(range(n_items))
        minor = deque(range(n_minor))

        expected_items = deque(range(n_items))
        expected_minor = deque(range(n_minor))

        # We want to add new columns with random order
        first_new_label = max(n_items, n_minor) + 1
        new_labels = np.arange(first_new_label, first_new_label + periods)
        np.random.shuffle(new_labels)

        panel = MutableIndexRollingPanel(window, items, minor, cap_multiple=2)

        dates = pd.date_range('2000-01-01', periods=periods, tz='utc')

        frames_by_date = {}
        # Recorded on every tick but never asserted on.
        expected_frames = deque(maxlen=window)
        expected_dates = deque()

        for tick, (date, new_label) in enumerate(zip(dates, new_labels)):
            frame = pd.DataFrame(np.random.randn(n_items, n_minor),
                                 index=items, columns=minor)

            if tick >= window:
                # Old labels and dates should start to get dropped at every
                # call
                oldest_date = expected_dates.popleft()
                del frames_by_date[oldest_date]
                expected_minor.popleft()
                expected_items.popleft()

            expected_frames.append(frame)
            expected_dates.append(date)

            panel.add_frame(date, frame)

            frames_by_date[date] = frame

            result = panel.get_current()

            np.testing.assert_array_equal(sorted(result.minor_axis.values),
                                          sorted(expected_minor))
            np.testing.assert_array_equal(sorted(result.items.values),
                                          sorted(expected_items))
            latest_slice = result.ix[frame.index, -1, frame.columns]
            tm.assert_frame_equal(frame.T, latest_slice)

            expected_result = pd.Panel(frames_by_date).swapaxes(0, 1)
            tm.assert_panel_equal(expected_result,
                                  result)

            # Insert new items
            for labels in (minor, items):
                labels.popleft()
                labels.append(new_label)

            expected_minor.append(new_label)
            expected_items.append(new_label)
+142 -160
View File
@@ -1,18 +1,19 @@
import pytz
from datetime import datetime, timedelta
from unittest import TestCase
import pandas as pd
from datetime import timedelta
from unittest import TestCase
from testfixtures import TempDirectory
from zipline.algorithm import TradingAlgorithm
from zipline.errors import TradingControlViolation
from zipline.finance.trading import TradingEnvironment
from zipline.sources import SpecificEquityTrades
from zipline.testing import (
add_security_data,
security_list_copy,
setup_logger,
teardown_logger,
)
from zipline.testing.core import create_data_portal
from zipline.utils import factory
from zipline.utils.security_list import (
SecurityListSet,
@@ -24,10 +25,10 @@ LEVERAGED_ETFS = load_from_directory('leveraged_etf_list')
class RestrictedAlgoWithCheck(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_do_not_order_list(self.rl.leveraged_etf_list)
self.order_count = 0
self.sid = self.symbol(symbol)
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_do_not_order_list(self.rl.leveraged_etf_list)
self.order_count = 0
self.sid = self.symbol(symbol)
def handle_data(self, data):
if not self.order_count:
@@ -51,11 +52,11 @@ class RestrictedAlgoWithoutCheck(TradingAlgorithm):
class IterateRLAlgo(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_do_not_order_list(self.rl.leveraged_etf_list)
self.order_count = 0
self.sid = self.symbol(symbol)
self.found = False
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_do_not_order_list(self.rl.leveraged_etf_list)
self.order_count = 0
self.sid = self.symbol(symbol)
self.found = False
def handle_data(self, data):
for stock in self.rl.leveraged_etf_list:
@@ -67,45 +68,86 @@ class SecurityListTestCase(TestCase):
@classmethod
def setUpClass(cls):
# this is ugly, but we need to create two different
# TradingEnvironment/DataPortal pairs
cls.env = TradingEnvironment()
cls.env.write_data(equities_identifiers=['AAPL', 'GOOG', 'BZQ',
'URTY', 'JFT'])
cls.env2 = TradingEnvironment()
cls.extra_knowledge_date = pd.Timestamp("2015-01-27", tz='UTC')
cls.trading_day_before_first_kd = pd.Timestamp("2015-01-23", tz='UTC')
symbols = ['AAPL', 'GOOG', 'BZQ', 'URTY', 'JFT']
days = cls.env.days_in_range(
list(LEVERAGED_ETFS.keys())[0],
pd.Timestamp("2015-02-17", tz='UTC')
)
cls.sim_params = factory.create_simulation_parameters(
start=list(LEVERAGED_ETFS.keys())[0],
num_days=4,
env=cls.env
)
cls.sim_params2 = factory.create_simulation_parameters(
start=cls.trading_day_before_first_kd, num_days=4
)
equities_metadata = {}
for i, symbol in enumerate(symbols):
equities_metadata[i] = {
'start_date': days[0],
'end_date': days[-1],
'symbol': symbol
}
equities_metadata2 = {}
for i, symbol in enumerate(symbols):
equities_metadata2[i] = {
'start_date': cls.sim_params2.period_start,
'end_date': cls.sim_params2.period_end,
'symbol': symbol
}
cls.env.write_data(equities_data=equities_metadata)
cls.env2.write_data(equities_data=equities_metadata2)
cls.tempdir = TempDirectory()
cls.tempdir2 = TempDirectory()
cls.data_portal = create_data_portal(
env=cls.env,
tempdir=cls.tempdir,
sim_params=cls.sim_params,
sids=range(0, 5),
)
cls.data_portal2 = create_data_portal(
env=cls.env2,
tempdir=cls.tempdir2,
sim_params=cls.sim_params2,
sids=range(0, 5)
)
setup_logger(cls)
@classmethod
def tearDownClass(cls):
del cls.env
cls.tempdir.cleanup()
cls.tempdir2.cleanup()
teardown_logger(cls)
def setUp(self, env=None):
def test_iterate_over_restricted_list(self):
algo = IterateRLAlgo(symbol='BZQ', sim_params=self.sim_params,
env=self.env)
self.extra_knowledge_date = \
datetime(2015, 1, 27, 0, 0, tzinfo=pytz.utc)
self.trading_day_before_first_kd = datetime(
2015, 1, 23, 0, 0, tzinfo=pytz.utc)
setup_logger(self)
def tearDown(self):
teardown_logger(self)
def test_iterate_over_rl(self):
sim_params = factory.create_simulation_parameters(
start=list(LEVERAGED_ETFS.keys())[0], num_days=4, env=self.env)
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = IterateRLAlgo(symbol='BZQ', sim_params=sim_params, env=self.env)
algo.run(self.source)
algo.run(self.data_portal)
self.assertTrue(algo.found)
def test_security_list(self):
# set the knowledge date to the first day of the
# leveraged etf knowledge date.
def get_datetime():
@@ -136,7 +178,7 @@ class SecurityListTestCase(TestCase):
def test_security_add(self):
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
return pd.Timestamp("2015-01-27", tz='UTC')
with security_list_copy():
add_security_data(['AAPL', 'GOOG'], [])
rl = SecurityListSet(get_datetime, self.env.asset_finder)
@@ -153,90 +195,38 @@ class SecurityListTestCase(TestCase):
def test_security_add_delete(self):
with security_list_copy():
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
return pd.Timestamp("2015-01-27", tz='UTC')
rl = SecurityListSet(get_datetime, self.env.asset_finder)
self.assertNotIn("BZQ", rl.leveraged_etf_list)
self.assertNotIn("URTY", rl.leveraged_etf_list)
def test_algo_without_rl_violation_via_check(self):
sim_params = factory.create_simulation_parameters(
start=list(LEVERAGED_ETFS.keys())[0], num_days=4,
env=self.env)
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithCheck(symbol='BZQ',
sim_params=sim_params,
sim_params=self.sim_params,
env=self.env)
algo.run(self.source)
algo.run(self.data_portal)
def test_algo_without_rl_violation(self):
sim_params = factory.create_simulation_parameters(
start=list(LEVERAGED_ETFS.keys())[0], num_days=4,
env=self.env)
trade_history = factory.create_trade_history(
'AAPL',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(symbol='AAPL',
sim_params=sim_params,
sim_params=self.sim_params,
env=self.env)
algo.run(self.source)
algo.run(self.data_portal)
def test_algo_with_rl_violation(self):
sim_params = factory.create_simulation_parameters(
start=list(LEVERAGED_ETFS.keys())[0], num_days=4,
env=self.env)
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(symbol='BZQ',
sim_params=sim_params,
sim_params=self.sim_params,
env=self.env)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run(self.source)
algo.run(self.data_portal)
self.check_algo_exception(algo, ctx, 0)
# repeat with a symbol from a different lookup date
trade_history = factory.create_trade_history(
'JFT',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(symbol='JFT',
sim_params=sim_params,
sim_params=self.sim_params,
env=self.env)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run(self.source)
algo.run(self.data_portal)
self.check_algo_exception(algo, ctx, 0)
@@ -245,21 +235,19 @@ class SecurityListTestCase(TestCase):
start=list(
LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=5,
env=self.env)
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
data_portal = create_data_portal(
self.env,
self.tempdir,
sim_params=sim_params,
sids=range(0, 5)
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(symbol='BZQ',
sim_params=sim_params,
env=self.env)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run(self.source)
algo.run(data_portal)
self.check_algo_exception(algo, ctx, 0)
@@ -275,65 +263,59 @@ class SecurityListTestCase(TestCase):
with security_list_copy():
add_security_data(['AAPL'], [])
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env,
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(
symbol='BZQ', sim_params=sim_params, env=self.env)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run(self.source)
algo.run(self.data_portal)
self.check_algo_exception(algo, ctx, 0)
def test_algo_without_rl_violation_after_delete(self):
with security_list_copy():
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
sim_params = factory.create_simulation_parameters(
start=self.extra_knowledge_date, num_days=3)
new_tempdir = TempDirectory()
try:
with security_list_copy():
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
trade_history = factory.create_trade_history(
'BZQ',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env,
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(
symbol='BZQ', sim_params=sim_params, env=self.env
)
algo.run(self.source)
# now fast-forward to self.extra_knowledge_date. requires
# a new env, simparams, and dataportal
env = TradingEnvironment()
sim_params = factory.create_simulation_parameters(
start=self.extra_knowledge_date, num_days=4, env=env)
env.write_data(equities_data={
"0": {
'symbol': 'BZQ',
'start_date': sim_params.period_start,
'end_date': sim_params.period_end,
}
})
data_portal = create_data_portal(
env,
new_tempdir,
sim_params,
range(0, 5)
)
algo = RestrictedAlgoWithoutCheck(
symbol='BZQ', sim_params=sim_params, env=env
)
algo.run(data_portal)
finally:
new_tempdir.cleanup()
def test_algo_with_rl_violation_after_add(self):
with security_list_copy():
add_security_data(['AAPL'], [])
sim_params = factory.create_simulation_parameters(
start=self.trading_day_before_first_kd, num_days=4)
trade_history = factory.create_trade_history(
'AAPL',
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
sim_params,
env=self.env
)
self.source = SpecificEquityTrades(event_list=trade_history,
env=self.env)
algo = RestrictedAlgoWithoutCheck(
symbol='AAPL', sim_params=sim_params, env=self.env)
algo = RestrictedAlgoWithoutCheck(symbol='AAPL',
sim_params=self.sim_params2,
env=self.env2)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run(self.source)
algo.run(self.data_portal2)
self.check_algo_exception(algo, ctx, 2)
-94
View File
@@ -1,94 +0,0 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nose_parameterized import parameterized
from unittest import TestCase
from zipline.finance.trading import TradingEnvironment
from .serialization_cases import (
object_serialization_cases,
assert_dict_equal
)
from six import iteritems
def gather_bad_dicts(state):
    """Collect every dict-subclass value nested anywhere inside *state*.

    Walks *state* recursively through all dict-like values.  A value is
    "bad" when it is an instance of ``dict`` but its exact type is not
    ``dict`` (for example an ``OrderedDict``).

    Parameters
    ----------
    state : dict
        Mapping to inspect, e.g. the result of ``__getstate__``.

    Returns
    -------
    list of (key, value) tuples
        One entry per offending value, depth first: an offending value is
        listed before any offenders found inside it.
    """
    bad = []
    for key, value in state.items():
        if not isinstance(value, dict):
            continue
        # Exact-type identity check on purpose: subclasses of dict are
        # precisely what this helper reports.
        if type(value) is not dict:
            bad.append((key, value))
        bad.extend(gather_bad_dicts(value))
    return bad
class SerializationTestCase(TestCase):
    """Round-trip every serialization case by hand through
    ``__getstate__`` / ``__setstate__`` and check the state holds only
    plain dicts.
    """

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.env

    @parameterized.expand(object_serialization_cases())
    def test_object_serialization(self,
                                  _,
                                  cls,
                                  initargs,
                                  di_vars,
                                  comparison_method='dict'):
        """Rebuild one object from its state and compare with the original.

        The copy is created with ``cls.__new__`` (passing the result of
        ``__getnewargs__`` when the object defines it), initialised with
        the result of ``__getinitargs__`` when defined, and then handed the
        original's state.  ``di_vars`` attributes are set on both objects
        before the comparison selected by ``comparison_method`` is run.
        """
        original = cls(*initargs)
        for attr_name, attr_value in di_vars.items():
            setattr(original, attr_name, attr_value)

        state = original.__getstate__()

        # no state should have a dict subclass. Only regular PyDict
        offenders = gather_bad_dicts(state)
        describe = "type({0}) == {1}".format
        details = ', '.join(
            [describe(key, type(value)) for key, value in offenders]
        )
        self.assertEqual(len(offenders), 0,
                         "Only support bare dicts. " + details)

        initargs = (original.__getinitargs__()
                    if hasattr(original, '__getinitargs__') else None)
        newargs = (original.__getnewargs__()
                   if hasattr(original, '__getnewargs__') else None)

        if newargs is None:
            clone = cls.__new__(cls)
        else:
            clone = cls.__new__(cls, *newargs)

        if initargs is not None:
            clone.__init__(*initargs)

        clone.__setstate__(state)

        for attr_name, attr_value in di_vars.items():
            setattr(clone, attr_name, attr_value)

        if comparison_method == 'repr':
            self.assertEqual(original.__repr__(), clone.__repr__())
            return
        if comparison_method == 'to_dict':
            assert_dict_equal(original.to_dict(), clone.to_dict())
            return
        assert_dict_equal(original.__dict__, clone.__dict__)
-174
View File
@@ -1,174 +0,0 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
import pytz
from six import integer_types
from unittest import TestCase
import zipline.utils.factory as factory
from zipline.sources import (DataFrameSource,
DataPanelSource,
RandomWalkSource)
from zipline.utils import tradingcalendar as calendar_nyse
from zipline.assets import AssetFinder
from zipline.finance.trading import TradingEnvironment
class TestDataFrameSource(TestCase):
def test_df_source(self):
source, df = factory.create_test_df_source(env=None)
assert isinstance(source.start, pd.lib.Timestamp)
assert isinstance(source.end, pd.lib.Timestamp)
for expected_dt, expected_price in df.iterrows():
sid0 = next(source)
assert expected_dt == sid0.dt
assert expected_price[0] == sid0.price
def test_df_sid_filtering(self):
_, df = factory.create_test_df_source(env=None)
source = DataFrameSource(df)
assert 1 not in [event.sid for event in source], \
"DataFrameSource should only stream selected sid 0, not sid 1."
def test_panel_source(self):
source, panel = factory.create_test_panel_source(source_type=5)
assert isinstance(source.start, pd.lib.Timestamp)
assert isinstance(source.end, pd.lib.Timestamp)
for event in source:
self.assertTrue('sid' in event)
self.assertTrue('arbitrary' in event)
self.assertTrue('type' in event)
self.assertTrue(hasattr(event, 'volume'))
self.assertTrue(hasattr(event, 'price'))
self.assertEquals(event['type'], 5)
self.assertEquals(event['arbitrary'], 1.)
self.assertEquals(event['sid'], 0)
self.assertTrue(isinstance(event['volume'], int))
self.assertTrue(isinstance(event['arbitrary'], float))
def test_yahoo_bars_to_panel_source(self):
env = TradingEnvironment()
finder = AssetFinder(env.engine)
stocks = ['AAPL', 'GE']
env.write_data(equities_identifiers=stocks)
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
data = factory.load_bars_from_yahoo(stocks=stocks,
indexes={},
start=start,
end=end)
check_fields = ['sid', 'open', 'high', 'low', 'close',
'volume', 'price']
copy_panel = data.copy()
sids = finder.map_identifier_index_to_sids(
data.items, data.major_axis[0]
)
copy_panel.items = sids
source = DataPanelSource(copy_panel)
for event in source:
for check_field in check_fields:
self.assertIn(check_field, event)
self.assertTrue(isinstance(event['volume'], (integer_types)))
self.assertTrue(event['sid'] in sids)
def test_nan_filter_dataframe(self):
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
df = pd.DataFrame(np.random.randn(2, 2),
index=dates,
columns=[4, 5])
# should be filtered
df.loc[dates[0], 4] = np.nan
# should not be filtered, should have been ffilled
df.loc[dates[1], 5] = np.nan
source = DataFrameSource(df)
event = next(source)
self.assertEqual(5, event.sid)
event = next(source)
self.assertEqual(4, event.sid)
event = next(source)
self.assertEqual(5, event.sid)
self.assertFalse(np.isnan(event.price))
def test_nan_filter_panel(self):
    """
    Feed DataPanelSource a two-day, two-sid panel with NaN prices and check
    that it emits sid 5, then sid 4, and is then exhausted.
    """
    dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
    first_day, second_day = dates[0], dates[1]
    df = pd.Panel(np.random.randn(2, 2, 2),
                  major_axis=dates,
                  items=[4, 5],
                  minor_axis=['price', 'volume'])
    # should be filtered
    df.loc[4, first_day, 'price'] = np.nan
    # should not be filtered, should have been ffilled
    # NOTE(review): the assertions below expect only two events in total,
    # which does not obviously match the "ffilled" remark - confirm.
    df.loc[5, second_day, 'price'] = np.nan

    source = DataPanelSource(df)
    for expected_sid in (5, 4):
        event = next(source)
        self.assertEqual(expected_sid, event.sid)
    self.assertRaises(StopIteration, next, source)
class TestRandomWalkSource(TestCase):
    """Sanity checks on the events produced by RandomWalkSource."""

    def test_minute(self):
        """
        With the default frequency, every event must belong to a known sid,
        fall on an NYSE trading day strictly between start and end, have a
        positive price and an hour between 13 and 21 inclusive.
        """
        np.random.seed(123)
        start_prices = {0: 100,
                        1: 500}
        start = pd.Timestamp('1990-01-01', tz='UTC')
        end = pd.Timestamp('1991-01-01', tz='UTC')
        source = RandomWalkSource(start_prices=start_prices,
                                  calendar=calendar_nyse, start=start,
                                  end=end)
        self.assertIsInstance(source.start, pd.lib.Timestamp)
        self.assertIsInstance(source.end, pd.lib.Timestamp)

        for event in source:
            self.assertIn(event.sid, start_prices)
            self.assertIn(event.dt.replace(minute=0, hour=0),
                          calendar_nyse.trading_days)
            self.assertGreater(event.dt, start)
            self.assertLess(event.dt, end)
            self.assertGreater(event.price, 0,
                               "price should never go negative.")
            # Adjacent literals are used instead of a backslash continuation
            # inside the string, which leaked the source indentation into
            # the failure message.
            self.assertTrue(13 <= event.dt.hour <= 21,
                            "event.dt.hour == %i, not during market "
                            "hours." % event.dt.hour)

    def test_day(self):
        """
        With freq='daily', every event must belong to a known sid, fall on
        an NYSE trading day strictly between start and end, have a positive
        price and be stamped at hour 0.
        """
        np.random.seed(123)
        start_prices = {0: 100,
                        1: 500}
        start = pd.Timestamp('1990-01-01', tz='UTC')
        end = pd.Timestamp('1992-01-01', tz='UTC')
        source = RandomWalkSource(start_prices=start_prices,
                                  calendar=calendar_nyse, start=start,
                                  end=end, freq='daily')
        self.assertIsInstance(source.start, pd.lib.Timestamp)
        self.assertIsInstance(source.end, pd.lib.Timestamp)

        for event in source:
            self.assertIn(event.sid, start_prices)
            self.assertIn(event.dt.replace(minute=0, hour=0),
                          calendar_nyse.trading_days)
            self.assertGreater(event.dt, start)
            self.assertLess(event.dt, end)
            self.assertGreater(event.price, 0,
                               "price should never go negative.")
            self.assertEqual(event.dt.hour, 0)
+27 -10
View File
@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd
from mock import patch
from nose_parameterized import parameterized
from six.moves import range
from unittest import TestCase
from zipline import TradingAlgorithm
from zipline.sources.benchmark_source import BenchmarkSource
from zipline.test_algorithms import NoopAlgorithm
from zipline.utils import factory
from zipline.testing.core import FakeDataPortal
class BeforeTradingAlgorithm(TradingAlgorithm):
@@ -30,19 +33,27 @@ class BeforeTradingAlgorithm(TradingAlgorithm):
def before_trading_start(self, data):
self.before_trading_at.append(self.datetime)
def handle_data(self, data):
pass
FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute
class TestTradeSimulation(TestCase):
def fake_minutely_benchmark(self, dt):
return 0.01
def test_minutely_emissions_generate_performance_stats_for_last_day(self):
params = factory.create_simulation_parameters(num_days=1,
data_frequency='minute',
emission_rate='minute')
algo = NoopAlgorithm(sim_params=params)
algo.run(source=[], overwrite_sim_params=False)
self.assertEqual(algo.perf_tracker.day_count, 1.0)
with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = NoopAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
self.assertEqual(algo.perf_tracker.day_count, 1.0)
@parameterized.expand([('%s_%s_%s' % (num_days, freq, emission_rate),
num_days, freq, emission_rate)
@@ -56,11 +67,17 @@ class TestTradeSimulation(TestCase):
num_days=num_days, data_frequency=freq,
emission_rate=emission_rate)
algo = BeforeTradingAlgorithm(sim_params=params)
algo.run(source=[], overwrite_sim_params=False)
def fake_benchmark(self, dt):
return 0.01
self.assertEqual(algo.perf_tracker.day_count, num_days)
self.assertTrue(params.trading_days.equals(
pd.DatetimeIndex(algo.before_trading_at)),
"Expected %s but was %s."
% (params.trading_days, algo.before_trading_at))
with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = BeforeTradingAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
self.assertEqual(algo.perf_tracker.day_count, num_days)
self.assertTrue(params.trading_days.equals(
pd.DatetimeIndex(algo.before_trading_at)),
"Expected %s but was %s."
% (params.trading_days, algo.before_trading_at))
-220
View File
@@ -1,220 +0,0 @@
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from functools import wraps
from itertools import product
from nose_parameterized import parameterized
import operator
import random
from six import itervalues
from six.moves import map
from unittest import TestCase
import numpy as np
from numpy.testing import assert_allclose
from zipline.finance.trading import TradingEnvironment
from zipline.algorithm import TradingAlgorithm
import zipline.utils.factory as factory
from zipline.api import add_transform, get_datetime
def handle_data_wrapper(f):
    """
    Wrap a ``handle_data``-style function so that per-bar bookkeeping is
    recorded on ``context`` before ``f`` runs.

    On every call the wrapper tracks the number of bars seen per day,
    appends the price and volume for sids 1-3, stores the first value of a
    two-bar daily close history per sid, and only invokes ``f`` once
    ``context.warmup`` has dropped below zero.
    """
    @wraps(f)
    def wrapper(context, data):
        dt = get_datetime()
        if dt.date() == context.current_date:
            # Same day as the previous bar: bump today's bar count.
            context.mins_for_days[-1] += 1
        else:
            # First bar of a new day.
            context.warmup -= 1
            context.mins_for_days.append(1)
            context.current_date = dt.date()

        closes = context.history(2, '1d', 'close_price')

        for sid in (1, 2, 3):
            if sid in data:
                # A bar whose dt differs from the current dt contributes
                # zero volume.
                traded_now = data[sid].dt == dt
                context.vol_bars[sid].append(
                    data[sid].volume if traded_now else 0
                )
                context.price_bars[sid].append(data[sid].price)
            else:
                context.price_bars[sid].append(np.nan)
                context.vol_bars[sid].append(0)

            context.last_close_prices[sid] = closes[sid][0]

        if context.warmup >= 0:
            return None
        return f(context, data)

    return wrapper
def initialize_with(test_case, tfm_name, days):
    """
    Build an ``initialize`` function that prepares the bookkeeping state on
    ``context`` and registers the transform ``tfm_name`` with window
    ``days``.

    The warmup counter is ``days + 1`` when ``days`` is truthy, else 2.
    """
    def initalize(context):
        context.test_case = test_case
        context.days = days
        context.mins_for_days = []
        # Index 0 is unused so the sids 1-3 can index these directly.
        context.price_bars = (None, [np.nan], [np.nan], [np.nan])
        context.vol_bars = (None, [np.nan], [np.nan], [np.nan])
        context.warmup = (days + 1) if context.days else 2
        context.current_date = None
        context.last_close_prices = [np.nan, np.nan, np.nan, np.nan]
        add_transform(tfm_name, days)

    return initalize
def windows_with_frequencies(*args):
    """
    Return the cartesian product of the data frequencies ('daily',
    'minute') with the given window sizes; with no sizes, a single
    ``None`` window is used.
    """
    windows = args if args else (None,)
    frequencies = ('daily', 'minute')
    return product(frequencies, windows)
def with_algo(f):
name = f.__name__
if not name.startswith('test_'):
raise ValueError('This must decorate a test case')
tfm_name = name[len('test_'):]
@wraps(f)
def wrapper(self, data_frequency, days=None):
sim_params, source = self.sim_and_source[data_frequency]
algo = TradingAlgorithm(
initialize=initialize_with(self, tfm_name, days),
handle_data=handle_data_wrapper(f),
sim_params=sim_params,
env=self.env,
)
algo.run(source)
return wrapper
class TransformTestCase(TestCase):
    """
    Checks the simple transforms by running each one through a zipline.
    """

    @classmethod
    def setUpClass(cls):
        """Build one environment and one trade source per data frequency."""
        random.seed(0)
        cls.sids = (1, 2, 3)

        minute_params = factory.create_simulation_parameters(
            num_days=3,
            data_frequency='minute',
            emission_rate='minute',
        )
        daily_params = factory.create_simulation_parameters(
            num_days=30,
            data_frequency='daily',
            emission_rate='daily',
        )

        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[1, 2, 3])

        # The minute source is created before the daily one, as before.
        minute_source = factory.create_minutely_trade_source(
            cls.sids,
            sim_params=minute_params,
            env=cls.env,
        )
        daily_source = factory.create_trade_source(
            cls.sids,
            trade_time_increment=timedelta(days=1),
            sim_params=daily_params,
            env=cls.env,
        )
        cls.sim_and_source = {
            'minute': (minute_params, minute_source),
            'daily': (daily_params, daily_source),
        }

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def tearDown(self):
        """
        Rewind every source, since each test consumes the one it uses.
        """
        for _params, source in itervalues(self.sim_and_source):
            source.rewind()

    @parameterized.expand(windows_with_frequencies(1, 2, 3, 4))
    @with_algo
    def test_mavg(context, data):
        """
        Compare the mavg transform against the mean of the prices recorded
        by hand over the same window.
        """
        mins = sum(context.mins_for_days[-context.days:])

        for sid in data:
            expected = np.mean(context.price_bars[sid][-mins:])
            assert_allclose(data[sid].mavg(context.days), expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_stddev(context, data):
        """
        Compare the stddev transform against the standard deviation of the
        prices recorded by hand over the same window, using ddof=1.
        """
        mins = sum(context.mins_for_days[-context.days:])

        for sid in data:
            expected = np.std(context.price_bars[sid][-mins:], ddof=1)
            assert_allclose(data[sid].stddev(context.days), expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_vwap(context, data):
        """
        Compare the vwap transform against a hand-rolled volume-weighted
        average of the recorded prices and volumes over the same window.
        """
        mins = sum(context.mins_for_days[-context.days:])

        for sid in data:
            # NaN prices are converted by nan_to_num before weighting.
            prices = np.nan_to_num(np.array(context.price_bars[sid][-mins:]))
            vols = context.vol_bars[sid][-mins:]
            weighted_total = sum(map(operator.mul, prices, vols))
            manual_vwap = weighted_total / sum(vols)

            assert_allclose(data[sid].vwap(context.days), manual_vwap)

    @parameterized.expand(windows_with_frequencies())
    @with_algo
    def test_returns(context, data):
        """
        Compare the returns transform against the relative change of the
        current price versus the recorded last close.
        """
        for sid in data:
            last_close = context.last_close_prices[sid]
            expected = (data[sid].price - last_close) / last_close
            assert_allclose(data[sid].returns(), expected)
-155
View File
@@ -1,155 +0,0 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta, datetime
from unittest import TestCase, skip
import numpy as np
import pandas as pd
import pytz
import talib
from zipline.finance.trading import TradingEnvironment
from zipline.test_algorithms import TALIBAlgorithm
from zipline.testing import setup_logger, teardown_logger
import zipline.transforms.ta as ta
import zipline.utils.factory as factory
class TestTALIB(TestCase):
    """Tests for the zipline wrappers around TA-Lib transforms."""

    @classmethod
    def setUpClass(cls):
        cls.env = TradingEnvironment()

    @classmethod
    def tearDownClass(cls):
        del cls.env

    def setUp(self):
        """Create a fresh OHLC panel source for 1990-01-01..1990-03-30."""
        setup_logger(self)
        sim_params = factory.create_simulation_parameters(
            start=datetime(1990, 1, 1, tzinfo=pytz.utc),
            end=datetime(1990, 3, 30, tzinfo=pytz.utc))
        self.source, self.panel = \
            factory.create_test_panel_ohlc_source(sim_params, self.env)

    def tearDown(self):
        teardown_logger(self)

    # ``skip`` is a decorator factory that expects a reason string; used
    # bare, the test function itself is passed as the reason.
    # NOTE(review): the original reason for skipping is not recorded.
    @skip("skipped in the original source; reason not recorded")
    def test_talib_with_default_params(self):
        """
        For every transform exported by ``ta`` (minus the blacklist), run it
        with default parameters and compare its last result with the result
        of calling the matching ``talib.abstract`` function directly.
        """
        BLACKLIST = ['make_transform', 'BatchTransform',
                     # TODO: Figure out why MAVP generates a KeyError
                     'MAVP']
        names = [name for name in dir(ta)
                 if name[0].isupper() and name not in BLACKLIST]

        for name in names:
            print(name)
            zipline_transform = getattr(ta, name)(sid=0)
            talib_fn = getattr(talib.abstract, name)

            start = datetime(1990, 1, 1, tzinfo=pytz.utc)
            end = start + timedelta(days=zipline_transform.lookback + 10)
            sim_params = factory.create_simulation_parameters(
                start=start, end=end)
            source, panel = \
                factory.create_test_panel_ohlc_source(sim_params, self.env)

            algo = TALIBAlgorithm(talib=zipline_transform)
            algo.run(source)

            zipline_result = np.array(
                algo.talib_results[zipline_transform][-1])

            talib_data = dict()
            data = zipline_transform.window
            # TODO: Figure out if we are clobbering the tests by this
            # protection against empty windows
            if not data:
                continue
            for key in ['open', 'high', 'low', 'volume']:
                if key in data:
                    talib_data[key] = data[key][0].values
            talib_data['close'] = data['price'][0].values
            expected_result = talib_fn(talib_data)

            if isinstance(expected_result, list):
                expected_result = np.array([e[-1] for e in expected_result])
            else:
                expected_result = np.array(expected_result[-1])
            if not (np.all(np.isnan(zipline_result)) and
                    np.all(np.isnan(expected_result))):
                self.assertTrue(np.allclose(zipline_result, expected_result))
            else:
                print('--- NAN')

            # reset generator so next iteration has data
            # self.source, self.panel = \
            # factory.create_test_panel_ohlc_source(self.sim_params)

    def test_multiple_talib_with_args(self):
        """
        Run two MA transforms (periods 10 and 25) in one algorithm and check
        them against pandas rolling means and against talib's own MA.
        """
        zipline_transforms = [ta.MA(timeperiod=10),
                              ta.MA(timeperiod=25)]
        talib_fn = talib.abstract.MA
        algo = TALIBAlgorithm(talib=zipline_transforms, identifiers=[0])
        algo.run(self.source)

        # Test if computed values match those computed by pandas rolling mean.
        sid = 0
        talib_values = np.array([x[sid] for x in
                                 algo.talib_results[zipline_transforms[0]]])
        np.testing.assert_array_equal(talib_values,
                                      pd.rolling_mean(self.panel[0]['price'],
                                                      10).values)
        talib_values = np.array([x[sid] for x in
                                 algo.talib_results[zipline_transforms[1]]])
        np.testing.assert_array_equal(talib_values,
                                      pd.rolling_mean(self.panel[0]['price'],
                                                      25).values)

        for t in zipline_transforms:
            talib_result = np.array(algo.talib_results[t][-1])
            talib_data = dict()
            data = t.window
            # TODO: Figure out if we are clobbering the tests by this
            # protection against empty windows
            if not data:
                continue
            for key in ['open', 'high', 'low', 'volume']:
                if key in data:
                    talib_data[key] = data[key][0].values
            talib_data['close'] = data['price'][0].values
            expected_result = talib_fn(talib_data, **t.call_kwargs)[-1]
            np.testing.assert_allclose(talib_result, expected_result)

    def test_talib_with_minute_data(self):
        """
        Check the window length chosen for minute-bar MA transforms.
        """
        ma_one_day_minutes = ta.MA(timeperiod=10, bars='minute')

        # Assert that the BatchTransform window length is enough to cover
        # the amount of minutes in the timeperiod.

        # Here, 10 minutes only needs a window length of 1.
        self.assertEqual(1, ma_one_day_minutes.window_length)

        # With minutes greater than the 390, i.e. one trading day, we should
        # have a window_length of two days.
        ma_two_day_minutes = ta.MA(timeperiod=490, bars='minute')
        self.assertEqual(2, ma_two_day_minutes.window_length)

        # TODO: Ensure that the lookback into the datapanel is returning
        # expected results.
        # Requires supplying minute instead of day data to the unit test.
        # When adding test data, should add more minute events than the
        # timeperiod to ensure that lookback is behaving properly.
-99
View File
@@ -1,99 +0,0 @@
#
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from nose_parameterized import parameterized
from unittest import TestCase, skip
from zipline.finance.blotter import Order
from .serialization_cases import (
object_serialization_cases,
assert_dict_equal
)
base_state_dir = 'tests/resources/saved_state_archive'
BASE_STATE_DIR = os.path.join(
os.path.dirname(__file__),
'resources',
'saved_state_archive')
class VersioningTestCase(TestCase):
    """
    Checks that objects restored from archived pickled state match freshly
    constructed reference objects.
    """

    def load_state_from_disk(self, cls):
        """
        Yield each unpickled state stored for ``cls``.

        States are read from the directory
        ``BASE_STATE_DIR/<module>.<class name>``; only files whose name
        contains 'State_Version_' are loaded.
        """
        state_dir = cls.__module__ + '.' + cls.__name__

        full_dir = os.path.join(BASE_STATE_DIR, state_dir)

        state_files = \
            [f for f in os.listdir(full_dir) if 'State_Version_' in f]

        for f_name in state_files:
            # Pickles are binary data; open in 'rb' and close the handle
            # before handing the state to the caller.
            with open(os.path.join(full_dir, f_name), 'rb') as f:
                state = pickle.load(f)
            yield state

    # Only test versioning in minutely mode right now
    # ``skip`` needs a reason string; used bare, the test function itself
    # is passed as the reason.
    # NOTE(review): the original reason for skipping is not recorded.
    @parameterized.expand(object_serialization_cases(skip_daily=True))
    @skip("skipped in the original source; reason not recorded")
    def test_object_serialization(self,
                                  _,
                                  cls,
                                  initargs,
                                  di_vars,
                                  comparison_method='dict'):
        """
        Rebuild ``cls`` from every archived state version and compare it
        with a reference instance using ``comparison_method`` ('repr',
        'to_dict', or the instance ``__dict__`` by default).
        """
        # Make reference object
        obj = cls(*initargs)
        for k, v in di_vars.items():
            setattr(obj, k, v)

        # Fetch state
        state_versions = self.load_state_from_disk(cls)

        for version in state_versions:

            # For each version inflate a new object and ensure that it
            # matches the original.

            newargs = version['newargs']
            initargs = version['initargs']
            state = version['obj_state']

            if newargs is not None:
                obj2 = cls.__new__(cls, *newargs)
            else:
                obj2 = cls.__new__(cls)
            if initargs is not None:
                obj2.__init__(*initargs)

            obj2.__setstate__(state)

            for k, v in di_vars.items():
                setattr(obj2, k, v)

            # The ObjectId generated on instantiation of Order will
            # not be the same as the one loaded from saved state.
            if cls == Order:
                obj.__dict__['id'] = obj2.__dict__['id']

            if comparison_method == 'repr':
                self.assertEqual(obj.__repr__(), obj2.__repr__())
            elif comparison_method == 'to_dict':
                assert_dict_equal(obj.to_dict(), obj2.to_dict())
            else:
                assert_dict_equal(obj.__dict__, obj2.__dict__)
+48
View File
@@ -0,0 +1,48 @@
from numpy import (
float64,
uint32
)
from bcolz import ctable
from zipline.data.us_equity_pricing import (
BcolzDailyBarWriter,
OHLC,
UINT32_MAX
)
class DailyBarWriterFromDataFrames(BcolzDailyBarWriter):
    """
    Daily bar writer whose per-asset tables are built from DataFrames.
    """
    # Every OHLCV column is declared as float64.
    _csv_dtypes = dict.fromkeys(
        ('open', 'high', 'low', 'close', 'volume'), float64
    )

    def __init__(self, asset_map):
        self._asset_map = asset_map

    def gen_tables(self, assets):
        """
        Yield ``(asset, ctable)`` pairs, converting ``assets[asset]`` with
        ``ctable.fromdataframe`` for each key of ``assets``.
        """
        for key in assets:
            yield key, ctable.fromdataframe(assets[key])

    def to_uint32(self, array, colname):
        """
        Convert ``array`` to uint32 according to ``colname``.

        OHLC columns are multiplied by 1000, 'volume' is cast directly and
        'day' is viewed as int and divided by the number of nanoseconds
        per second. Each value is range-checked first. Any other column
        name yields None.
        """
        arrmax = array.max()

        if colname in OHLC:
            self.check_uint_safe(arrmax * 1000, colname)
            return (array * 1000).astype(uint32)

        if colname == 'volume':
            self.check_uint_safe(arrmax, colname)
            return array.astype(uint32)

        if colname == 'day':
            nanos_per_second = (1000 * 1000 * 1000)
            self.check_uint_safe(arrmax.view(int) / nanos_per_second, colname)
            return (array.view(int) / nanos_per_second).astype(uint32)

        return None

    @staticmethod
    def check_uint_safe(value, colname):
        """Raise ValueError if ``value`` is at or above UINT32_MAX."""
        if value >= UINT32_MAX:
            message = "Value %s from column '%s' is too large" % (
                value, colname
            )
            raise ValueError(message)
+48 -9
View File
@@ -15,6 +15,7 @@
from collections import namedtuple
import datetime
from functools import partial
from inspect import isabstract
from itertools import islice
import random
from unittest import TestCase
@@ -190,7 +191,7 @@ class TestEventRule(TestCase):
super(Always, Always()).should_trigger('a', env=None)
def minutes_for_days():
def minutes_for_days(ordered_days=False):
"""
500 randomly selected days.
This is used to make sure our test coverage is unbiased towards any rules.
@@ -203,19 +204,44 @@ def minutes_for_days():
true.
This returns a generator of tuples each wrapping a single generator.
Iterating over this yeilds a single day, iterating over the day yields
Iterating over this yields a single day, iterating over the day yields
the minutes for that day.
"""
env = TradingEnvironment()
random.seed('deterministic')
return ((env.market_minutes_for_day(random.choice(env.trading_days)),)
for _ in range(500))
if ordered_days:
# Get a list of 500 trading days, in order. As a performance
# optimization in AfterOpen and BeforeClose, we rely on the fact that
# the clock only ever moves forward in a simulation. For those cases,
# we guarantee that the list of trading days we test is ordered.
ordered_day_list = random.sample(list(env.trading_days), 500)
ordered_day_list.sort()
def day_picker(day):
return ordered_day_list[day]
else:
# Other than AfterOpen and BeforeClose, we don't rely on the nature
# of the clock, so we don't care.
def day_picker(day):
return random.choice(env.trading_days[:-1])
return ((env.market_minutes_for_day(day_picker(cnt)),)
for cnt in range(500))
class RuleTestCase(TestCase):
@classmethod
def setUpClass(cls):
cls.env = TradingEnvironment()
# On the AfterOpen and BeforeClose tests, we want to ensure that the
# functions are pure, and that running them with the same input will
# provide the same output, regardless of whether the function is run 1
# or N times. (For performance reasons, we cache some internal state
# in AfterOpen and BeforeClose, but we don't want it to affect
# purity). Hence, we use the same before_close and after_open across
# subtests.
cls.before_close = BeforeClose(hours=1, minutes=5)
cls.after_open = AfterOpen(hours=1, minutes=5)
cls.class_ = None # Mark that this is the base class.
@classmethod
@@ -233,7 +259,8 @@ class RuleTestCase(TestCase):
k for k, v in iteritems(vars(zipline.utils.events))
if isinstance(v, type) and
issubclass(v, self.class_) and
v is not self.class_
v is not self.class_ and
not isabstract(v)
}
ds = {
k[5:] for k in dir(self)
@@ -273,10 +300,10 @@ class TestStatelessRules(RuleTestCase):
should_trigger = partial(Never().should_trigger, env=self.env)
self.assertFalse(any(map(should_trigger, ms)))
@subtest(minutes_for_days(), 'ms')
@subtest(minutes_for_days(ordered_days=True), 'ms')
def test_AfterOpen(self, ms):
should_trigger = partial(
AfterOpen(minutes=5, hours=1).should_trigger,
self.after_open.should_trigger,
env=self.env,
)
for m in islice(ms, 64):
@@ -285,15 +312,16 @@ class TestStatelessRules(RuleTestCase):
# at 13:30 UTC, meaning the first minute of data has an
# offset of 1.
self.assertFalse(should_trigger(m))
for m in islice(ms, 64, None):
# Check the rest of the day.
self.assertTrue(should_trigger(m))
@subtest(minutes_for_days(), 'ms')
@subtest(minutes_for_days(ordered_days=True), 'ms')
def test_BeforeClose(self, ms):
ms = list(ms)
should_trigger = partial(
BeforeClose(hours=1, minutes=5).should_trigger,
self.before_close.should_trigger,
env=self.env
)
for m in ms[0:-66]:
@@ -307,6 +335,17 @@ class TestStatelessRules(RuleTestCase):
self.assertTrue(should_trigger(FULL_DAY))
self.assertFalse(should_trigger(HALF_DAY))
def test_NthTradingDayOfWeek_day_zero(self):
    """
    The week-start rule must trigger, without raising, when evaluated on
    the very first day of the trading environment.
    """
    rule = NthTradingDayOfWeek(0)
    first_day = self.env.trading_days[0]
    triggered = rule.should_trigger(first_day, self.env)
    self.assertTrue(triggered)
@subtest(param_range(MAX_WEEK_RANGE), 'n')
def test_NthTradingDayOfWeek(self, n):
should_trigger = partial(NthTradingDayOfWeek(n).should_trigger,
-2
View File
@@ -20,7 +20,6 @@ from . import data
from . import finance
from . import gens
from . import utils
from . import transforms
from ._version import get_versions
# These need to happen after the other imports.
from . algorithm import TradingAlgorithm
@@ -42,7 +41,6 @@ __all__ = [
'finance',
'gens',
'utils',
'transforms',
'api',
'TradingAlgorithm',
]
+762
View File
@@ -0,0 +1,762 @@
#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from contextlib import contextmanager
from pandas.tslib import normalize_date
import pandas as pd
import numpy as np
from six import iteritems
from cpython cimport bool
from zipline.assets import Asset
from zipline.zipline_warnings import ZiplineDeprecationWarning
class assert_keywords(object):
    """
    Decorator that restricts the keyword arguments a function accepts.

    The decorator is constructed with the allowed keyword names. Calling
    the wrapped function with any other keyword raises a TypeError with a
    readable message, unlike the one cython returns by default.
    """

    def __init__(self, *args):
        self.names = args

    def __call__(self, func):
        def assert_keywords_and_call(*args, **kwargs):
            unexpected = [kw for kw in kwargs if kw not in self.names]
            if unexpected:
                # Report the first offending keyword only.
                raise TypeError("%s() got an unexpected keyword argument"
                                " '%s'" % (func.__name__, unexpected[0]))

            return func(*args, **kwargs)

        return assert_keywords_and_call
@contextmanager
def handle_non_market_minutes(bar_data):
    """
    Context manager that sets ``bar_data._handle_non_market_minutes`` to
    True for the duration of the ``with`` block.

    On exit, including when the block raises, the attribute is set to
    False; the value it held before entry is not restored.
    """
    flag_name = '_handle_non_market_minutes'
    try:
        setattr(bar_data, flag_name, True)
        yield
    finally:
        setattr(bar_data, flag_name, False)
cdef class BarData:
cdef object data_portal
cdef object simulation_dt_func
cdef object data_frequency
cdef dict _views
cdef object _universe_func
cdef object _last_calculated_universe
cdef object _universe_last_updated_at
cdef bool _adjust_minutes
"""
Provides methods to access spot value or history windows of price data.
Also provides some utility methods to determine if an asset is alive,
has recent trade data, etc.
This is what is passed as `data` to the `handle_data` function.
"""
def __init__(self, data_portal, simulation_dt_func, data_frequency,
             universe_func=None):
    """
    Parameters
    ----------
    data_portal : DataPortal
        Provider for bar pricing data.

    simulation_dt_func : function
        Function which returns the current simulation time.
        This is usually bound to a method of TradingSimulation.

    data_frequency : string
        The frequency of the bar data; i.e. whether the data is
        'daily' or 'minute' bars.

    universe_func : function, optional
        Function which returns the current 'universe'. This is for
        backwards compatibility with older API concepts.
    """
    # Collaborators supplied by the caller.
    self.data_portal = data_portal
    self.simulation_dt_func = simulation_dt_func
    self.data_frequency = data_frequency
    self._universe_func = universe_func

    # Internal state: empty view cache, no universe calculated yet, and
    # minute adjustment switched off.
    self._views = {}
    self._last_calculated_universe = None
    self._universe_last_updated_at = None
    self._adjust_minutes = False
cdef _get_equity_price_view(self, asset):
    """
    Return the cached SidView for the given asset, creating and caching it
    on first use. Used to support the deprecated data[sid(N)] public API;
    a deprecation warning is emitted on every call.

    Parameters
    ----------
    asset : Asset
        Asset that is being queried.

    Returns
    -------
    SidView : Accessor into the given asset's data.
    """
    try:
        self._warn_deprecated("`data[sid(N)]` is deprecated. Use "
                              "`data.current`.")
        view = self._views[asset]
    except KeyError:
        # Not cached yet: resolve the asset through the asset finder when
        # possible, then cache the view under the resolved key.
        try:
            asset = self.data_portal.env.asset_finder.retrieve_asset(asset)
        except ValueError:
            # assume fetcher
            pass
        view = self._create_sid_view(asset)
        self._views[asset] = view

    return view
cdef _create_sid_view(self, asset):
    # Build a SidView for `asset` that shares this object's data portal,
    # simulation clock function and data frequency.
    view_args = (
        asset,
        self.data_portal,
        self.simulation_dt_func,
        self.data_frequency,
    )
    return SidView(*view_args)
cdef _get_current_minute(self):
    # Return the current simulation time, or, when minute adjustment is
    # switched on, the previous market minute relative to it.
    now = self.simulation_dt_func()
    if not self._adjust_minutes:
        return now
    return self.data_portal.env.previous_market_minute(now)
@assert_keywords('assets', 'fields')
def current(self, assets, fields):
    """
    Returns the current value of the given assets for the given fields
    at the current simulation time. Current values are the as-traded price
    and are usually not adjusted for events like splits or dividends (see
    notes for more information).

    Parameters
    ----------
    assets : Asset or iterable of Assets

    fields : string or iterable of strings. Valid values are: "price",
        "last_traded", "open", "high", "low", "close", "volume", or column
        names in files read by fetch_csv.

    Returns
    -------
    Scalar, pandas Series, or pandas DataFrame. See notes below.

    Notes
    -----
    If a single asset and a single field are passed in, a scalar float
    value is returned.

    If a single asset and a list of fields are passed in, a pandas Series
    is returned whose indices are the fields, and whose values are scalar
    values for this asset for each field.

    If a list of assets and a single field are passed in, a pandas Series
    is returned whose indices are the assets, and whose values are scalar
    values for each asset for the given field.

    If a list of assets and a list of fields are passed in, a pandas
    DataFrame is returned, indexed by asset. The columns are the requested
    fields, filled with the scalar values for each asset for each field.

    If the current simulation time is not a valid market time, we use the
    last market close instead.

    "price" returns the last known close price of the asset. If there is
    no last known value (either because the asset has never traded, or
    because it has delisted) NaN is returned. If a value is found, and we
    had to cross an adjustment boundary (split, dividend, etc) to get it,
    the value is adjusted before being returned.

    "last_traded" returns the date of the last trade event of the asset,
    even if the asset has stopped trading. If there is no last known value,
    pd.NaT is returned.

    "volume" returns the trade volume for the current simulation time. If
    there is no trade this minute, 0 is returned.

    "open", "high", "low", and "close" return the relevant information for
    the current trade bar. If there is no current trade bar, NaN is
    returned.
    """
    multiple_assets = self._is_iterable(assets)
    multiple_fields = self._is_iterable(fields)

    # Resolve the accessor and its trailing arguments once per call. The
    # lookup time (and, when adjusting, the perspective time) is the same
    # for every asset/field pair, so there is no need to recompute it
    # inside the loops below.
    if not self._adjust_minutes:
        get_value = self.data_portal.get_spot_value
        extra_args = (
            self._get_current_minute(),
            self.data_frequency,
        )
    else:
        get_value = self.data_portal.get_adjusted_value
        extra_args = (
            self._get_current_minute(),
            self.simulation_dt_func(),
            self.data_frequency,
        )

    if not multiple_assets:
        asset = assets

        if not multiple_fields:
            # return scalar value
            return get_value(asset, fields, *extra_args)

        # assume fields is iterable
        # return a Series indexed by field
        return pd.Series(data={
            field: get_value(asset, field, *extra_args)
            for field in fields
        }, index=fields, name=assets.symbol)

    if not multiple_fields:
        field = fields

        # assume assets is iterable
        # return a Series indexed by asset
        return pd.Series(data={
            asset: get_value(asset, field, *extra_args)
            for asset in assets
        }, index=assets, name=fields)

    # both assets and fields are iterable: one Series per field, combined
    # into a DataFrame with one column per field.
    data = {}
    for field in fields:
        data[field] = pd.Series(data={
            asset: get_value(asset, field, *extra_args)
            for asset in assets
        }, index=assets, name=field)

    return pd.DataFrame(data)
cdef bool _is_iterable(self, obj):
    # Strings are treated as scalars even when they expose __iter__.
    if isinstance(obj, str):
        return False
    return hasattr(obj, '__iter__')
def can_trade(self, assets):
    """
    For the given asset or iterable of assets, returns true if the asset
    is alive at the current simulation time and there is a known last
    price.

    Parameters
    ----------
    assets : Asset or iterable of assets

    Returns
    -------
    boolean or Series of booleans, indexed by asset.
    """
    dt = self.simulation_dt_func()
    # Liveness is checked against `dt`; the price lookup uses the
    # adjusted minute when minute adjustment is switched on.
    adjusted_dt = self._get_current_minute() if self._adjust_minutes else dt

    data_portal = self.data_portal

    if isinstance(assets, Asset):
        return self._can_trade_for_asset(
            assets, dt, adjusted_dt, data_portal
        )

    tradeable = {
        asset: self._can_trade_for_asset(
            asset, dt, adjusted_dt, data_portal
        )
        for asset in assets
    }
    return pd.Series(data=tradeable)
cdef bool _can_trade_for_asset(self, asset, dt, adjusted_dt, data_portal):
    # An asset that is not alive at `dt` can never be traded.
    if not asset._is_alive(dt, False):
        return False

    # is there a last price?
    last_price = data_portal.get_spot_value(
        asset, "price", adjusted_dt, self.data_frequency
    )
    return not np.isnan(last_price)
def is_stale(self, assets):
    """
    Report whether the given asset(s) are stale.

    An asset is stale when it is alive and there is no trade data for the
    current simulation time.  If the asset has never traded, returns
    False.

    If the current simulation time is not a valid market time, we use the
    current time to check if the asset is alive, but we use the last
    market minute/day for the trade data check.

    Parameters
    ----------
    assets: Asset or iterable of assets

    Returns
    -------
    boolean or Series of booleans, indexed by asset.
    """
    now = self.simulation_dt_func()
    lookup_dt = self._get_current_minute() if self._adjust_minutes else now
    portal = self.data_portal

    if not isinstance(assets, Asset):
        return pd.Series(data={
            each: self._is_stale_for_asset(each, now, lookup_dt, portal)
            for each in assets
        })

    return self._is_stale_for_asset(assets, now, lookup_dt, portal)
cdef bool _is_stale_for_asset(self, asset, dt, adjusted_dt, data_portal):
    """
    Single-asset worker for `is_stale`.

    Returns False for an asset that is not alive at `dt` or that has
    positive volume at `adjusted_dt`.  Otherwise the asset is stale only
    if its "last_traded" value is not `pd.NaT`, i.e. it has traded before.
    """
    if not asset._is_alive(dt, False):
        return False

    volume_now = data_portal.get_spot_value(
        asset, "volume", adjusted_dt, self.data_frequency
    )
    if volume_now > 0:
        # There is a current trade, so the asset cannot be stale.
        return False

    # No current trade: tell "has traded before" (stale) apart from
    # "has never traded" (not stale).
    last_trade = data_portal.get_spot_value(
        asset, "last_traded", adjusted_dt, self.data_frequency
    )
    return last_trade is not pd.NaT
@assert_keywords('assets', 'fields', 'bar_count', 'frequency')
def history(self, assets, fields, bar_count, frequency):
    """
    Fetch a trailing window of data for the given assets and fields.

    This data is adjusted for splits, dividends, and mergers as of the
    current algorithm time.

    The semantics of missing data are identical to the ones described in
    the notes for `get_spot_value`.

    Parameters
    ----------
    assets: Asset or iterable of Asset

    fields: string or iterable of string.  Valid values are "open", "high",
        "low", "close", "volume", "price", and "last_traded".

    bar_count: integer number of bars of trade data

    frequency: string. "1m" for minutely data or "1d" for daily data

    Returns
    -------
    Series or DataFrame or Panel, depending on the dimensionality of
        the 'assets' and 'fields' parameters.

        Single asset, single field: Series indexed by dt.

        Multiple assets, single field: DataFrame indexed by dt, with
        assets as columns.

        Single asset, multiple fields: DataFrame indexed by dt, with
        fields as columns.

        Multiple assets, multiple fields: Panel indexed by field, with dt
        as the major axis and assets as the minor axis.

    Notes
    -----
    If the current simulation time is not a valid market time, we use the
    last market close instead.
    """
    portal = self.data_portal
    one_asset = isinstance(assets, Asset)

    if isinstance(fields, str):
        # Single field: one window request covers every asset.
        window = portal.get_history_window(
            [assets] if one_asset else assets,
            self._get_current_minute(),
            bar_count,
            frequency,
            fields
        )

        if self._adjust_minutes:
            window = window * portal.get_adjustments(
                assets,
                fields,
                self._get_current_minute(),
                self.simulation_dt_func()
            )

        # A lone asset collapses to a Series; otherwise keep the frame
        # whose columns are the assets.
        return window[assets] if one_asset else window

    if one_asset:
        # One asset, several fields: issue one history call per field and
        # stitch the resulting Series together.
        per_field = {}
        for field in fields:
            per_field[field] = portal.get_history_window(
                [assets],
                self._get_current_minute(),
                bar_count,
                frequency,
                field
            )[assets]

        if self._adjust_minutes:
            ratios = {}
            for field in fields:
                ratios[field] = portal.get_adjustments(
                    assets,
                    field,
                    self._get_current_minute(),
                    self.simulation_dt_func()
                )[0]

            per_field = {field: series * ratios[field]
                         for field, series in iteritems(per_field)}

        # Columns are the fields, index is dt.
        return pd.DataFrame(per_field)

    # Several assets and several fields: one frame per field.
    frames = {}
    for field in fields:
        frames[field] = portal.get_history_window(
            assets,
            self._get_current_minute(),
            bar_count,
            frequency,
            field
        )

    if self._adjust_minutes:
        ratios = {}
        for field in fields:
            ratios[field] = portal.get_adjustments(
                assets,
                field,
                self._get_current_minute(),
                self.simulation_dt_func()
            )

        frames = {field: frame * ratios[field]
                  for field, frame in iteritems(frames)}

    # Panel layout -- items: fields, major axis: dt, minor axis: assets.
    return pd.Panel(frames)
property current_dt:
    # Read-only: the value returned by `simulation_dt_func`.
    def __get__(self):
        now = self.simulation_dt_func()
        return now
@property
def fetcher_assets(self):
    """Ask the data portal for the fetcher assets at the simulation dt."""
    now = self.simulation_dt_func()
    return self.data_portal.get_fetcher_assets(now)
property _handle_non_market_minutes:
    # Write-only switch: stores the flag in `_adjust_minutes`.
    def __set__(self, val):
        self._adjust_minutes = val
#################
# OLD API SUPPORT
#################
cdef _calculate_universe(self):
    """
    Return the universe for the current simulation dt.

    Returns an empty list when no universe function is set.  The result
    of `_universe_func` is cached and recomputed only when the simulation
    dt differs from the dt of the last computation.
    """
    if self._universe_func is None:
        return []

    now = self.simulation_dt_func()
    needs_refresh = (self._last_calculated_universe is None or
                     self._universe_last_updated_at != now)
    if needs_refresh:
        self._last_calculated_universe = self._universe_func()
        self._universe_last_updated_at = now

    return self._last_calculated_universe
def __iter__(self):
    """Deprecated: yield each asset of the calculated universe."""
    message = ("Iterating over the assets in `data` is "
               "deprecated.")
    self._warn_deprecated(message)
    for member in self._calculate_universe():
        yield member
def __contains__(self, asset):
    """Deprecated: membership test against the calculated universe."""
    message = ("Checking whether an asset is in data is "
               "deprecated.")
    self._warn_deprecated(message)
    return asset in self._calculate_universe()
def items(self):
    """Deprecated: list of ``(asset, self[asset])`` pairs for the universe."""
    message = ("Iterating over the assets in `data` is "
               "deprecated.")
    self._warn_deprecated(message)
    pairs = []
    for member in self._calculate_universe():
        pairs.append((member, self[member]))
    return pairs
def iteritems(self):
    """Deprecated: lazily yield ``(asset, self[asset])`` for the universe."""
    message = ("Iterating over the assets in `data` is "
               "deprecated.")
    self._warn_deprecated(message)
    for member in self._calculate_universe():
        yield member, self[member]
def __len__(self):
    """Deprecated: number of assets in the calculated universe."""
    message = ("Iterating over the assets in `data` is "
               "deprecated.")
    self._warn_deprecated(message)
    universe = self._calculate_universe()
    return len(universe)
def keys(self):
    """Deprecated: a new list holding the assets of the universe."""
    message = ("Iterating over the assets in `data` is "
               "deprecated.")
    self._warn_deprecated(message)
    universe = self._calculate_universe()
    return list(universe)
def iterkeys(self):
    """Return an iterator over the result of `keys()`."""
    asset_keys = self.keys()
    return iter(asset_keys)
def __getitem__(self, name):
    """Return the equity price view for `name` (``data[asset]`` access)."""
    view = self._get_equity_price_view(name)
    return view
cdef _warn_deprecated(self, msg):
    """Emit `msg` as a ZiplineDeprecationWarning."""
    warnings.warn(msg, category=ZiplineDeprecationWarning, stacklevel=1)
cdef class SidView:
    """
    This class exists to temporarily support the deprecated data[sid(N)] API.

    Attribute and item access are forwarded to
    `data_portal.get_spot_value` for the wrapped asset at the current
    simulation time, after translating the legacy column names
    ("close_price", "open_price").
    """
    # NOTE: the docstring must come before the cdef declarations; placed
    # after them it is only a stray string expression, not the class doc.
    cdef object asset
    cdef object data_portal
    cdef object simulation_dt_func
    cdef object data_frequency

    def __init__(self, asset, data_portal, simulation_dt_func, data_frequency):
        """
        Parameters
        ----------
        asset : Asset
            The asset for which the instance retrieves data.

        data_portal : DataPortal
            Provider for bar pricing data.

        simulation_dt_func: function
            Function which returns the current simulation time.
            This is usually bound to a method of TradingSimulation.

        data_frequency: string
            The frequency of the bar data; i.e. whether the data is
            'daily' or 'minute' bars
        """
        self.asset = asset
        self.data_portal = data_portal
        self.simulation_dt_func = simulation_dt_func
        self.data_frequency = data_frequency

    def __getattr__(self, column):
        # backwards compatibility code for Q1 API
        if column == "close_price":
            column = "close"
        elif column == "open_price":
            column = "open"
        elif column == "dt":
            return self.dt
        elif column == "datetime":
            return self.datetime
        elif column == "sid":
            return self.sid

        # Anything else is treated as a bar field and read from the portal.
        return self.data_portal.get_spot_value(
            self.asset,
            column,
            self.simulation_dt_func(),
            self.data_frequency
        )

    def __contains__(self, column):
        return self.data_portal.contains(self.asset, column)

    def __getitem__(self, column):
        # Item access shares the attribute path so that the legacy names
        # and the dt/datetime/sid keys behave the same either way.
        return self.__getattr__(column)

    property sid:
        def __get__(self):
            return self.asset

    property dt:
        def __get__(self):
            return self.datetime

    property datetime:
        def __get__(self):
            return self.data_portal.get_last_traded_dt(
                self.asset,
                self.simulation_dt_func(),
                self.data_frequency)

    property current_dt:
        def __get__(self):
            return self.simulation_dt_func()

    def mavg(self, num_minutes):
        """Deprecated: "mavg" simple transform over `num_minutes` bars."""
        self._warn_deprecated("The `mavg` method is deprecated.")
        return self.data_portal.get_simple_transform(
            self.asset, "mavg", self.simulation_dt_func(),
            self.data_frequency, bars=num_minutes
        )

    def stddev(self, num_minutes):
        """Deprecated: "stddev" simple transform over `num_minutes` bars."""
        self._warn_deprecated("The `stddev` method is deprecated.")
        return self.data_portal.get_simple_transform(
            self.asset, "stddev", self.simulation_dt_func(),
            self.data_frequency, bars=num_minutes
        )

    def vwap(self, num_minutes):
        """Deprecated: "vwap" simple transform over `num_minutes` bars."""
        self._warn_deprecated("The `vwap` method is deprecated.")
        return self.data_portal.get_simple_transform(
            self.asset, "vwap", self.simulation_dt_func(),
            self.data_frequency, bars=num_minutes
        )

    def returns(self):
        """Deprecated: "returns" simple transform for the wrapped asset."""
        self._warn_deprecated("The `returns` method is deprecated.")
        return self.data_portal.get_simple_transform(
            self.asset, "returns", self.simulation_dt_func(),
            self.data_frequency
        )

    cdef _warn_deprecated(self, msg):
        warnings.warn(
            msg,
            category=ZiplineDeprecationWarning,
            stacklevel=1
        )
+406 -343
View File
File diff suppressed because it is too large Load Diff
+9 -6
View File
@@ -17,8 +17,7 @@
# methods (e.g. order). These are added to this namespace via the
# decorator `api_methods` inside of algorithm.py.
import zipline
from .finance import (commission, slippage)
from .finance import (commission, slippage, cancel_policy)
from .utils import math_utils, events
from zipline.finance.slippage import (
@@ -26,20 +25,24 @@ from zipline.finance.slippage import (
VolumeShareSlippage,
)
from zipline.finance.cancel_policy import (
NeverCancel,
EODCancel
)
from zipline.utils.events import (
date_rules,
time_rules
)
batch_transform = zipline.transforms.BatchTransform
__all__ = [
'slippage',
'commission',
'cancel_policy',
'NeverCancel',
'EODCancel',
'events',
'math_utils',
'batch_transform',
'FixedSlippage',
'VolumeShareSlippage',
'date_rules',
+37 -2
View File
@@ -27,16 +27,19 @@ from cpython.object cimport (
Py_GT,
Py_LT,
)
from numbers import Integral
from cpython cimport bool
import numpy as np
from numpy cimport int64_t
import warnings
cimport numpy as np
# IMPORTANT NOTE: You must change this template if you change
# Asset.__reduce__, or else we'll attempt to unpickle an old version of this
# class
from pandas.tslib import normalize_date
CACHE_FILE_TEMPLATE = '/tmp/.%s-%s.v6.cache'
cdef class Asset:
@@ -175,6 +178,38 @@ cdef class Asset:
"""
return cls(**dict_)
def _is_alive(self, dt, bool normalized):
    """
    Tell whether the asset is alive at the given dt.

    Parameters
    ----------
    dt: pd.Timestamp
        The desired timestamp.

    normalized: boolean
        Whether `dt` has already been normalized.  When False the date is
        normalized here before the check; per the original note, passing
        an already-normalized date runs up to 80% faster.

    Returns
    -------
    boolean: whether the asset is alive at the given dt.
    """
    cdef int64_t dt_value
    cdef int64_t ref_start
    cdef int64_t ref_end

    dt_value = dt.value if normalized else normalize_date(dt).value
    ref_start = self.start_date.value
    ref_end = self.end_date.value

    # Alive means start_date <= dt <= end_date, bounds included.
    return ref_start <= dt_value <= ref_end
cdef class Equity(Asset):
+29 -11
View File
@@ -21,7 +21,7 @@ import numpy as np
import pandas as pd
from pandas import isnull
from six import with_metaclass, string_types, viewkeys
from six.moves import map as imap, range
from six.moves import map as imap
import sqlalchemy as sa
from zipline.errors import (
@@ -40,12 +40,12 @@ from zipline.assets.asset_writer import (
check_version_info,
split_delimited_symbol,
asset_db_table_names,
SQLITE_MAX_VARIABLE_NUMBER,
)
from zipline.assets.asset_db_schema import (
ASSET_DB_VERSION
)
from zipline.utils.control_flow import invert
from zipline.utils.sqlite_utils import group_into_chunks
log = Logger('assets.py')
@@ -163,7 +163,7 @@ class AssetFinder(object):
router_cols = self.asset_router.c
for assets in self._group_into_chunks(missing):
for assets in group_into_chunks(missing):
query = sa.select((router_cols.sid, router_cols.asset_type)).where(
self.asset_router.c.sid.in_(map(int, assets))
)
@@ -176,12 +176,6 @@ class AssetFinder(object):
return found
@staticmethod
def _group_into_chunks(items, chunk_size=SQLITE_MAX_VARIABLE_NUMBER):
    """Split `items` into consecutive lists of at most `chunk_size` items."""
    materialized = list(items)
    chunks = []
    for offset in range(0, len(materialized), chunk_size):
        chunks.append(materialized[offset:offset + chunk_size])
    return chunks
def group_by_type(self, sids):
"""
Group a list of sids by asset type.
@@ -210,7 +204,7 @@ class AssetFinder(object):
Parameters
----------
sids : interable of int
sids : iterable of int
Assets to retrieve.
default_none : bool
If True, return None for failed lookups.
@@ -358,7 +352,7 @@ class AssetFinder(object):
cache = self._asset_cache
hits = {}
for assets in self._group_into_chunks(sids):
for assets in group_into_chunks(sids):
# Load misses from the db.
query = self._select_assets_by_sid(asset_tbl, assets)
@@ -666,6 +660,30 @@ class AssetFinder(object):
contracts = self.retrieve_futures_contracts(sids)
return [contracts[sid] for sid in sids]
def lookup_expired_futures(self, start, end):
    """
    Return sids of futures contracts whose effective date is in
    ``[start, end)``.

    The effective date is the smaller of notice date and expiration date,
    with `iNaT` values treated as NULL; when one is NULL the other is
    used.  Results are ordered by expiration date, falling back to the
    notice date.

    Parameters
    ----------
    start, end : pd.Timestamp or anything accepted by `pd.Timestamp`

    Returns
    -------
    list of sids
    """
    start_ts = start if isinstance(start, pd.Timestamp) \
        else pd.Timestamp(start)
    start_ns = start_ts.value
    end_ts = end if isinstance(end, pd.Timestamp) else pd.Timestamp(end)
    end_ns = end_ts.value

    contracts = self.futures_contracts.c

    notice = sa.func.nullif(contracts.notice_date, pd.tslib.iNaT)
    expiration = sa.func.nullif(contracts.expiration_date, pd.tslib.iNaT)
    effective = sa.func.coalesce(
        sa.func.min(notice, expiration), expiration, notice)

    query = sa.select((contracts.sid,)).where(
        (effective >= start_ns) & (effective < end_ns)
    ).order_by(
        sa.func.coalesce(expiration, notice).asc()
    )
    rows = query.execute().fetchall()

    return list(map(itemgetter('sid'), rows))
@property
def sids(self):
return tuple(map(
+162
View File
@@ -0,0 +1,162 @@
from numpy cimport ndarray, long_t
from numpy import searchsorted
from cpython cimport bool
cimport cython
cdef inline int int_min(int a, int b):
    # Smaller of two C ints; returns `a` on a tie.
    if a <= b:
        return a
    return b
@cython.cdivision(True)
def minute_value(ndarray[long_t, ndim=1] market_opens,
                 Py_ssize_t pos,
                 short minutes_per_day):
    """
    Translate a bar position into its minute epoch value.

    Parameters
    ----------
    market_opens: numpy array of ints
        Market opens, in minute epoch values.

    pos: int
        The index of the desired minute.

    minutes_per_day: int
        The number of minutes per day (e.g. 390 for NYSE).

    Returns
    -------
    int: The minute epoch value of the desired minute.
    """
    cdef short day_index, minute_offset

    # Whole days before `pos`, then the offset into that day.
    day_index = cython.cdiv(pos, minutes_per_day)
    minute_offset = cython.cmod(pos, minutes_per_day)

    return market_opens[day_index] + minute_offset
def find_position_of_minute(ndarray[long_t, ndim=1] market_opens,
                            ndarray[long_t, ndim=1] market_closes,
                            long_t minute_val,
                            short minutes_per_day,
                            bool adjust_half_day_minutes):
    """
    Locate a minute within the flattened list of per-day bars.

    A minute that is not a market minute is adjusted to the last market
    minute.

    Parameters
    ----------
    market_opens: numpy array of ints
        Market opens, in minute epoch values.

    market_closes: numpy array of ints
        Market closes, in minute epoch values.

    minute_val: int
        The desired minute, as a minute epoch.

    minutes_per_day: int
        The number of minutes per day (e.g. 390 for NYSE).

    adjust_half_day_minutes: boolean
        Whether non trading minutes on half days are adjusted to the
        early close rather than to the normal close.

        True: the position is used to look up a value for a minute, so
        every non market minute is moved to the last available one
        (e.g. 9 pm EST -> 4 pm EST, and on a half day 2 pm EST -> 1 pm EST).

        False: the position is used to find bars that should be ignored
        (1 pm to 4 pm EST on half days).  The minute bar tape holds 390
        bars per day with 0's filled in after an early close, so a minute
        between 1:01 pm and 4 pm on a half day keeps its unadjusted
        position (e.g. 9 pm EST -> 4 pm EST, 2 pm EST -> 2 pm EST).

    Returns
    -------
    int: The position of the given minute in the market opens array.
    """
    cdef Py_ssize_t day_loc, day_open, offset

    # Index of the last market open at or before `minute_val`.
    # NOTE(review): if `minute_val` precedes market_opens[0] this is -1 --
    # confirm callers never pass such a minute.
    day_loc = searchsorted(market_opens, minute_val, side='right') - 1
    day_open = market_opens[day_loc]
    day_close = market_closes[day_loc]

    if not adjust_half_day_minutes:
        # Cap the distance from the open at a normal day's bar count.
        offset = int_min(minute_val - day_open, minutes_per_day)
    else:
        # Cap the distance from the open at that day's trading minutes.
        offset = int_min(minute_val - day_open, day_close - day_open)

    return (day_loc * minutes_per_day) + offset
def find_last_traded_position_internal(
        ndarray[long_t, ndim=1] market_opens,
        ndarray[long_t, ndim=1] market_closes,
        long_t end_minute,
        long_t start_minute,
        volumes,
        short minutes_per_day):
    """
    Walk backwards from `end_minute` to the most recent bar with volume.

    Parameters
    ----------
    market_opens: numpy array of ints
        Market opens, in minute epoch values.

    market_closes: numpy array of ints
        Market closes, in minute epoch values.

    end_minute: int
        The minute from which to start looking backwards, as a minute epoch.

    start_minute: int
        The asset's start date, as a minute epoch.  Acts as the bottom limit of
        how far we can look backwards.

    volumes: bcolz carray
        The volume history for the given asset.

    minutes_per_day: int
        The number of minutes per day (e.g. 390 for NYSE).

    Returns
    -------
    int: The position of the last traded minute at or before `end_minute`,
        or -1 when no traded bar is found.
    """
    cdef Py_ssize_t pos, pos_minute

    # Start at the position of `end_minute`, but never past the last
    # stored volume bar.
    pos = int_min(
        find_position_of_minute(market_opens, market_closes, end_minute,
                                minutes_per_day, True),
        len(volumes) - 1
    )

    while pos >= 0:
        pos_minute = minute_value(market_opens, pos, minutes_per_day)

        # Stop once we are before the asset's start.
        if pos_minute < start_minute:
            return -1

        if volumes[pos] != 0:
            return pos

        pos -= 1

    # Reached the beginning of the tape without finding a trade.
    return -1
File diff suppressed because it is too large Load Diff
+177 -37
View File
@@ -15,12 +15,22 @@ from textwrap import dedent
import bcolz
from bcolz import ctable
import numpy as np
from numpy import nan_to_num
from intervaltree import IntervalTree
from numpy import nan_to_num, timedelta64
from os.path import join
import json
import os
import numpy as np
import pandas as pd
from zipline.gens.sim_engine import NANOS_IN_MINUTE
from zipline.data._minute_bar_internal import (
minute_value,
find_position_of_minute,
find_last_traded_position_internal
)
from zipline.utils.memoize import remember_last, lazyval
US_EQUITIES_MINUTES_PER_DAY = 390
@@ -90,20 +100,22 @@ class BcolzMinuteBarMetadata(object):
path = cls.metadata_path(rootdir)
with open(path) as fp:
raw_data = json.load(fp)
first_trading_day = pd.Timestamp(
raw_data['first_trading_day'], tz='UTC')
minute_index = pd.to_datetime(raw_data['minute_index'],
market_opens = pd.to_datetime(raw_data['market_opens'],
unit='m',
utc=True)
market_closes = pd.to_datetime(raw_data['market_closes'],
unit='m',
utc=True)
ohlc_ratio = raw_data['ohlc_ratio']
return cls(first_trading_day,
minute_index,
None, # currently only writing market_opens
None, # currently only writing market_closes
market_opens,
market_closes,
ohlc_ratio)
def __init__(self,
first_trading_day,
minute_index,
def __init__(self, first_trading_day,
market_opens,
market_closes,
ohlc_ratio):
@@ -124,7 +136,6 @@ class BcolzMinuteBarMetadata(object):
float data can be stored as an integer.
"""
self.first_trading_day = first_trading_day
self.minute_index = minute_index
self.market_opens = market_opens
self.market_closes = market_closes
self.ohlc_ratio = ohlc_ratio
@@ -146,7 +157,6 @@ class BcolzMinuteBarMetadata(object):
"""
metadata = {
'first_trading_day': str(self.first_trading_day.date()),
'minute_index': self.minute_index.asi8.tolist(),
'market_opens': self.market_opens.values.
astype('datetime64[m]').
astype(int).tolist(),
@@ -281,7 +291,6 @@ class BcolzMinuteBarWriter(object):
metadata = BcolzMinuteBarMetadata(
self._first_trading_day,
self._minute_index,
self._market_opens,
self._market_closes,
self._ohlc_ratio,
@@ -430,14 +439,10 @@ class BcolzMinuteBarWriter(object):
def write(self, sid, df):
"""
Write the OHLCV data for the given sid.
If there is no bcolz ctable yet created for the sid, create it.
If the length of the bcolz ctable is not exactly to the date before
the first day provided, fill the ctable with 0s up to that date.
Writes in blocks of the size of the days times minutes per day.
Parameters:
-----------
sid : int
@@ -465,18 +470,14 @@ class BcolzMinuteBarWriter(object):
def write_cols(self, sid, dts, cols):
"""
Write the OHLCV data for the given sid.
If there is no bcolz ctable yet created for the sid, create it.
If the length of the bcolz ctable is not exactly to the date before
the first day provided, fill the ctable with 0s up to that date.
Writes in blocks of the size of the days times minutes per day.
Parameters:
-----------
sid : int
The asset identifer for the data being written.
The asset identifier for the data being written.
dts : datetime64 array
The dts corresponding to values in cols.
cols : dict of str -> np.array
@@ -564,7 +565,14 @@ class BcolzMinuteBarReader(object):
metadata = self._get_metadata()
self._first_trading_day = metadata.first_trading_day
self._minute_index = metadata.minute_index
self._market_opens = metadata.market_opens
self._market_open_values = metadata.market_opens.values.\
astype('datetime64[m]').astype(int)
self._market_closes = metadata.market_closes
self._market_close_values = metadata.market_closes.values.\
astype('datetime64[m]').astype(int)
self._ohlc_inverse = 1.0 / metadata.ohlc_ratio
self._carrays = {
@@ -578,6 +586,84 @@ class BcolzMinuteBarReader(object):
def _get_metadata(self):
return BcolzMinuteBarMetadata.read(self._rootdir)
@lazyval
def last_available_dt(self):
    """The final entry of the market closes read from the metadata."""
    closes = self._market_closes
    return closes[-1]
@property
def first_trading_day(self):
    """The first trading day recorded in the metadata."""
    day = self._first_trading_day
    return day
def _minutes_to_exclude(self):
    """
    Work out the minutes to drop from windows on early-close days, i.e.
    days where the close implied by the regular number of minutes per day
    does not match the actual market close.

    Returns:
    --------
    List of DatetimeIndex representing the minutes to exclude because
    of early closes.
    """
    opens = self._market_opens.values.astype('datetime64[m]')
    closes = self._market_closes.values.astype('datetime64[m]')

    # Days whose open-to-close span differs from a regular session.
    session_lengths = (closes - opens).astype(int)
    irregular = np.where(
        session_lengths != US_EQUITIES_MINUTES_PER_DAY - 1)[0]

    # Where each of those days would have closed on a regular schedule.
    would_be_closes = opens[irregular] + timedelta64(
        US_EQUITIES_MINUTES_PER_DAY - 1, 'm')
    # Exclusion starts one minute after the actual close.
    first_excluded = closes[irregular] + 1

    excluded = []
    for begin, finish in zip(first_excluded, would_be_closes):
        excluded.append(pd.date_range(begin, finish, freq='min'))
    return excluded
@lazyval
def _minute_exclusion_tree(self):
    """
    Build an interval tree of the position ranges that should be dropped
    from windows: the minutes between an early close and the minute that
    would have been the close on a regular day.

    Each node spans ``[start, end + 1)`` and carries the tuple
    ``(start, end)`` as its value, so that "does this start/end position
    overlap any exclusion span?" can be answered quickly.

    Returns
    -------
    IntervalTree containing nodes which represent the minutes to exclude
    because of early closes.
    """
    tree = IntervalTree()
    for excluded in self._minutes_to_exclude():
        # adjust_half_day_minutes is False here: we want the unadjusted
        # positions of minutes 211 to 390 on a 390-bar day.
        first_pos = self._find_position_of_minute(excluded[0], False)
        last_pos = self._find_position_of_minute(excluded[-1], False)
        tree[first_pos:last_pos + 1] = (first_pos, last_pos)
    return tree
def _exclusion_indices_for_range(self, start_idx, end_idx):
    """
    Look up the exclusion spans that touch a range of positions.

    Returns
    -------
    List of tuples of (start, stop) which represent the ranges of minutes
    which should be excluded when a market minute window is requested, or
    None when the exclusion tree reports no overlap.
    """
    tree = self._minute_exclusion_tree
    if not tree.overlaps(start_idx, end_idx):
        return None
    return [hit.data for hit in tree[start_idx:end_idx]]
def _get_carray_path(self, sid, field):
sid_subdir = _sid_subdir_path(sid)
# carrays are subdirectories of the sid's rootdir
@@ -623,22 +709,55 @@ class BcolzMinuteBarReader(object):
Returns the integer value of the volume.
(A volume of 0 signifies no trades for the given dt.)
"""
minute_pos = self._find_position_of_minute(dt)
minute_pos = self._find_position_of_minute(dt, True)
value = self._open_minute_file(field, sid)[minute_pos]
if value == 0:
if field != 'volume':
return np.nan
else:
if field == 'volume':
return 0
else:
return np.nan
if field != 'volume':
value *= self._ohlc_inverse
return value
def _find_position_of_minute(self, minute_dt):
def get_last_traded_dt(self, asset, dt):
    """
    Return the minute of the last trade of `asset` at or before `dt`.

    Returns `pd.NaT` when `_find_last_traded_position` reports -1.
    """
    position = self._find_last_traded_position(asset, dt)
    if position != -1:
        return self._pos_to_minute(position)
    return pd.NaT
def _find_last_traded_position(self, asset, dt):
    """
    Find the position of the last bar with volume for `asset` at or
    before `dt`.

    Parameters
    ----------
    asset : Asset
        The asset whose volume column is searched.
    dt : pd.Timestamp
        The minute from which to start looking backwards.

    Returns
    -------
    int: the bar position, or -1 when `dt` is before the asset's start
    date or no traded bar is found.
    """
    volumes = self._open_minute_file('volume', asset)

    # Floor division keeps the epoch minutes integral: with true division
    # (Python 3 or `from __future__ import division`) `/` yields floats,
    # while the Cython helper declares these parameters as `long_t`.
    start_date_minutes = asset.start_date.value // NANOS_IN_MINUTE
    dt_minutes = dt.value // NANOS_IN_MINUTE

    if dt_minutes < start_date_minutes:
        return -1

    return find_last_traded_position_internal(
        self._market_open_values,
        self._market_close_values,
        dt_minutes,
        start_date_minutes,
        volumes,
        US_EQUITIES_MINUTES_PER_DAY
    )
def _pos_to_minute(self, pos):
    """Convert a bar position into a UTC `pd.Timestamp` for that minute."""
    epoch_minutes = minute_value(
        self._market_open_values, pos, US_EQUITIES_MINUTES_PER_DAY)
    return pd.Timestamp(epoch_minutes, tz='UTC', unit="m")
@remember_last
def _find_position_of_minute(self, minute_dt, adjust_half_day_minutes):
"""
Internal method that returns the position of the given minute in the
list of every trading minute since market open of the first trading
day.
day. Adjusts non market minutes to the last close.
ex. this method would return 1 for 2002-01-02 9:32 AM Eastern, if
2002-01-02 is the first trading day of the dataset.
@@ -648,14 +767,22 @@ class BcolzMinuteBarReader(object):
minute_dt: pd.Timestamp
The minute whose position should be calculated.
adjust_half_day_minutes: boolean
Whether or not we want to adjust minutes to early close on half
days.
Returns
-------
out : int
The position of the given minute in the list of all trading minutes
since market open on the first trading day.
int: The position of the given minute in the list of all trading
minutes since market open on the first trading day.
"""
return self._minute_index.get_loc(minute_dt)
return find_position_of_minute(
self._market_open_values,
self._market_close_values,
minute_dt.value / NANOS_IN_MINUTE,
US_EQUITIES_MINUTES_PER_DAY,
adjust_half_day_minutes
)
def unadjusted_window(self, fields, start_dt, end_dt, sids):
"""
@@ -677,13 +804,21 @@ class BcolzMinuteBarReader(object):
(sids, minutes in range) with a dtype of float64, containing the
values for the respective field over start and end dt range.
"""
# TODO: Handle early closes.
start_idx = self._find_position_of_minute(start_dt)
end_idx = self._find_position_of_minute(end_dt)
start_idx = self._find_position_of_minute(start_dt, True)
end_idx = self._find_position_of_minute(end_dt, True)
num_minutes = (end_idx - start_idx + 1)
results = []
shape = (len(sids), (end_idx - start_idx + 1))
indices_to_exclude = self._exclusion_indices_for_range(
start_idx, end_idx)
if indices_to_exclude is not None:
for excl_start, excl_stop in indices_to_exclude:
length = excl_stop - excl_start + 1
num_minutes -= length
shape = (len(sids), num_minutes)
for field in fields:
if field != 'volume':
@@ -694,6 +829,11 @@ class BcolzMinuteBarReader(object):
for i, sid in enumerate(sids):
carray = self._open_minute_file(field, sid)
values = carray[start_idx:end_idx + 1]
if indices_to_exclude is not None:
for excl_start, excl_stop in indices_to_exclude[::-1]:
excl_slice = np.s_[
excl_start - start_idx:excl_stop - start_idx + 1]
values = np.delete(values, excl_slice)
where = values != 0
out[i, where] = values[where]
if field != 'volume':
+316
View File
@@ -0,0 +1,316 @@
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import (
ABCMeta,
abstractmethod,
abstractproperty,
)
from numpy import dtype, around
from pandas.tslib import normalize_date
from six import iteritems, with_metaclass
from zipline.pipeline.data.equity_pricing import USEquityPricing
from zipline.lib._float64window import AdjustedArrayWindow as Float64Window
from zipline.lib.adjustment import Float64Multiply
from zipline.utils.cache import CachedObject, Expired
from zipline.utils.memoize import lazyval
class SlidingWindow(object):
    """
    Wrapper around an AdjustedArrayWindow which supports monotonically
    increasing (by datetime) requests for a sized window of data.

    Parameters
    ----------
    window : AdjustedArrayWindow
       Window of pricing data with prefetched values beyond the current
       simulation dt.
    size : int
        Added to `cal_start` to form the index treated as the most
        recently served one.
    cal_start : int
       Index in the overall calendar at which the window starts.
    offset : int
        Subtracted, together with `cal_start`, from a requested index to
        get the position passed to `window.seek`.
    """

    def __init__(self, window, size, cal_start, offset):
        self.window = window
        self.cal_start = cal_start
        self.offset = offset
        # The first block of the window is consumed eagerly and rounded.
        self.current = around(next(window), 3)
        self.most_recent_ix = cal_start + size

    def get(self, end_ix):
        """
        Returns
        -------
        out : A np.ndarray of the equity pricing up to end_ix after adjustments
              and rounding have been applied.
        """
        if end_ix != self.most_recent_ix:
            # Move the underlying window and cache the rounded result.
            seek_to = end_ix - self.cal_start - self.offset + 1
            self.current = around(self.window.seek(seek_to), 3)
            self.most_recent_ix = end_ix
        return self.current
class USEquityHistoryLoader(with_metaclass(ABCMeta)):
"""
Loader for sliding history windows of adjusted US Equity Pricing data.
Parameters
----------
reader : DailyBarReader, MinuteBarReader
Reader for pricing bars.
adjustment_reader : SQLiteAdjustmentReader
Reader for adjustment data.
"""
def __init__(self, env, reader, adjustment_reader):
    """Store the collaborators and start with no cached window blocks."""
    self.env = env
    self._reader, self._adjustments_reader = reader, adjustment_reader
    # Starts empty; nothing in this method populates it.
    self._window_blocks = {}
@abstractproperty
def _prefetch_length(self):
    """Abstract property with no default; subclasses must provide it.

    NOTE(review): presumably the number of bars to prefetch -- confirm
    against the concrete loaders.
    """
@abstractproperty
def _calendar(self):
    """Abstract property with no default; subclasses must provide it.

    NOTE(review): presumably the index of dts the loader works over --
    confirm against the concrete loaders.
    """
@abstractmethod
def _array(self, start, end, assets, field):
    """Abstract method with no default; subclasses must implement it.

    NOTE(review): presumably returns the raw values of `field` for
    `assets` between `start` and `end` -- confirm against the concrete
    loaders.
    """
def _get_adjustments_in_range(self, assets, dts, field):
    """
    Get the Float64Multiply objects to pass to an AdjustedArrayWindow.

    For the use of AdjustedArrayWindow in the loader, which looks back
    from current simulation time back to a window of data the dictionary is
    structured with:
    - the key into the dictionary for adjustments is the location of the
    day from which the window is being viewed.
    - the start of all multiply objects is always 0 (in each window all
    adjustments are overlapping)
    - the end of the multiply object is the location before the calendar
    location of the adjustment action, making all days before the event
    adjusted.

    Mergers and dividends are applied to every field except 'volume';
    splits are applied to every field, with the inverse ratio for
    'volume'.

    Parameters
    ----------
    assets : iterable of Asset
        The assets for which to get adjustments.
    dts : iterable of datetime64-like
        The dts for which adjustment data is needed.
    field : str
        OHLCV field for which to get the adjustments.

    Returns
    -------
    out : The adjustments as a dict of loc -> Float64Multiply
    """
    sids = {int(asset): i for i, asset in enumerate(assets)}
    start = normalize_date(dts[0])
    end = normalize_date(dts[-1])
    reader = self._adjustments_reader
    adjs = {}

    def _collect(table, sid, column, invert):
        # Add one multiplier per `table` adjustment of `sid` in (start, end].
        for adjustment in reader.get_adjustments_for_sid(table, sid):
            dt = adjustment[0]
            ratio = 1.0 / adjustment[1] if invert else adjustment[1]
            if start < dt <= end:
                end_loc = dts.searchsorted(dt)
                mult = Float64Multiply(0,
                                       end_loc - 1,
                                       column,
                                       column,
                                       ratio)
                adjs.setdefault(end_loc, []).append(mult)

    is_volume = field == 'volume'
    for sid, i in iteritems(sids):
        if not is_volume:
            # Price fields take merger and dividend ratios as-is.
            _collect('mergers', sid, i, False)
            _collect('dividends', sid, i, False)
        # Volume moves inversely to price on a split.
        _collect('splits', sid, i, is_volume)

    return adjs
def _ensure_sliding_window(
self, assets, dts, field):
"""
Ensure that there is a Float64Multiply window that can provide data
for the given parameters.
If the corresponding window for the (assets, len(dts), field) does not
exist, then create a new one.
If a corresponding window does exist for (assets, len(dts), field), but
can not provide data for the current dts range, then create a new
one and replace the expired window.
WARNING: A simulation with a high variance of assets, may cause
unbounded growth of floating windows stored in `_window_blocks`.
There should be some regular clean up of the cache, if stale windows
prevent simulations from completing because of memory constraints.
Parameters
----------
assets : iterable of Assets
The assets in the window
dts : iterable of datetime64-like
The datetimes for which to fetch data.
Makes an assumption that all dts are present and contiguous,
in the calendar.
field : str
The OHLCV field for which to retrieve data.
Returns
-------
out : Float64Window with sufficient data so that the window can
provide `get` for the index corresponding with the last value in `dts`
"""
end = dts[-1]
size = len(dts)
assets_key = frozenset(assets)
try:
block_cache = self._window_blocks[(assets_key, field, size)]
try:
return block_cache.unwrap(end)
except Expired:
pass
except KeyError:
pass
start = dts[0]
offset = 0
start_ix = self._calendar.get_loc(start)
end_ix = self._calendar.get_loc(end)
cal = self._calendar
prefetch_end_ix = min(end_ix + self._prefetch_length, len(cal) - 1)
prefetch_end = cal[prefetch_end_ix]
prefetch_dts = cal[start_ix:prefetch_end_ix + 1]
array = self._array(prefetch_dts, assets, field)
if self._adjustments_reader:
adjs = self._get_adjustments_in_range(assets, prefetch_dts, field)
else:
adjs = {}
if field == 'volume':
array = array.astype('float64')
dtype_ = dtype('float64')
window = Float64Window(
array,
dtype_,
adjs,
offset,
size
)
block = SlidingWindow(window, size, start_ix, offset)
self._window_blocks[(assets_key, field, size)] = CachedObject(
block, prefetch_end)
return block
def history(self, assets, dts, field):
"""
A window of pricing data with adjustments applied assuming that the
end of the window is the day before the current simulation time.
Parameters
----------
assets : iterable of Assets
The assets in the window.
dts : iterable of datetime64-like
The datetimes for which to fetch data.
Makes an assumption that all dts are present and contiguous,
in the calendar.
field : str
The OHLCV field for which to retrieve data.
Returns
-------
out : np.ndarray with shape(len(days between start, end), len(assets))
"""
block = self._ensure_sliding_window(assets, dts, field)
end_ix = self._calendar.get_loc(dts[-1])
return block.get(end_ix)
class USEquityDailyHistoryLoader(USEquityHistoryLoader):
    """
    History loader that reads its bars and calendar from a daily bar reader.
    """

    @property
    def _prefetch_length(self):
        # Calendar entries loaded past the requested end of a new window.
        return 40

    @property
    def _calendar(self):
        # The daily loader works directly on the reader's own calendar.
        return self._reader._calendar

    def _array(self, dts, assets, field):
        """
        Load the raw `field` values for `assets` between the first and last
        entries of `dts`.
        """
        column = getattr(USEquityPricing, field)
        first, last = dts[0], dts[-1]
        arrays = self._reader.load_raw_arrays([column], first, last, assets)
        # A single column was requested, so a single array comes back.
        return arrays[0]
class USEquityMinuteHistoryLoader(USEquityHistoryLoader):
    """
    History loader that reads its bars from a minute bar reader and uses the
    environment's market minutes as its calendar.
    """

    @property
    def _prefetch_length(self):
        # Calendar entries loaded past the requested end of a new window.
        return 1560

    @lazyval
    def _calendar(self):
        # Restrict the environment's market minutes to the span the reader
        # reports having data for.
        minutes = self.env.market_minutes
        bounds = minutes.slice_indexer(
            start=self._reader.first_trading_day,
            end=self._reader.last_available_dt,
        )
        return minutes[bounds]

    def _array(self, dts, assets, field):
        """
        Load the raw `field` values for `assets` between the first and last
        entries of `dts`, transposed from the reader's layout.
        """
        windows = self._reader.unadjusted_window(
            [field], dts[0], dts[-1], assets)
        return windows[0].T
+246 -4
View File
@@ -14,6 +14,7 @@
from abc import (
ABCMeta,
abstractmethod,
abstractproperty,
)
from errno import ENOENT
from os import remove
@@ -25,6 +26,7 @@ from bcolz import (
ctable,
open as open_ctable,
)
from collections import namedtuple
from click import progressbar
from numpy import (
array,
@@ -43,6 +45,8 @@ from pandas import (
DatetimeIndex,
read_csv,
Timestamp,
NaT,
isnull,
)
from six import (
iteritems,
@@ -50,6 +54,7 @@ from six import (
)
from zipline.utils.input_validation import coerce_string, preprocess
from zipline.utils.sqlite_utils import group_into_chunks
from ._equities import _compute_row_slices, _read_bcolz_data
from ._adjustments import load_adjustments_from_sqlite
@@ -123,7 +128,6 @@ class BcolzDailyBarWriter(with_metaclass(ABCMeta)):
--------
BcolzDailyBarReader : Consumer of the data written by this class.
"""
@abstractmethod
def gen_tables(self, assets):
"""
@@ -200,6 +204,8 @@ class BcolzDailyBarWriter(with_metaclass(ABCMeta)):
for k in US_EQUITY_PRICING_BCOLZ_COLUMNS
}
earliest_date = None
for asset_id, table in iterator:
nrows = len(table)
for column_name in columns:
@@ -212,6 +218,11 @@ class BcolzDailyBarWriter(with_metaclass(ABCMeta)):
self.to_uint32(table[column_name][:], column_name)
)
if earliest_date is None:
earliest_date = table["day"][0]
else:
earliest_date = min(earliest_date, table["day"][0])
# Bcolz doesn't support ints as keys in `attrs`, so convert
# assets to strings for use as attr keys.
asset_key = str(asset_id)
@@ -245,6 +256,9 @@ class BcolzDailyBarWriter(with_metaclass(ABCMeta)):
rootdir=filename,
mode='w',
)
full_table.attrs['first_trading_day'] = \
int(earliest_date / 1e6)
full_table.attrs['first_row'] = first_row
full_table.attrs['last_row'] = last_row
full_table.attrs['calendar_offset'] = calendar_offset
@@ -314,7 +328,24 @@ class DailyBarWriterFromCSVs(BcolzDailyBarWriter):
)
class BcolzDailyBarReader(object):
class DailyBarReader(with_metaclass(ABCMeta)):
    """
    Abstract interface for readers of OHLCV pricing data at a daily
    frequency.
    """

    @abstractmethod
    def load_raw_arrays(self, columns, start_date, end_date, assets):
        """Must be implemented by concrete readers."""

    @abstractmethod
    def spot_price(self, sid, day, colname):
        """Must be implemented by concrete readers."""

    @abstractproperty
    def last_available_dt(self):
        """Must be implemented by concrete readers."""
class BcolzDailyBarReader(DailyBarReader):
"""
Reader for raw pricing data written by BcolzDailyOHLCVWriter.
@@ -383,12 +414,24 @@ class BcolzDailyBarReader(object):
int(id_): offset
for id_, offset in iteritems(table.attrs['calendar_offset'])
}
try:
self._first_trading_day = Timestamp(
table.attrs['first_trading_day'],
unit='ms',
tz='UTC'
)
except KeyError:
self._first_trading_day = None
# Cache of fully read np.array for the carrays in the daily bar table.
# raw_array does not use the same cache, but it could.
# Need to test keeping the entire array in memory for the course of a
# process first.
self._spot_cols = {}
self.PRICE_ADJUSTMENT_FACTOR = 0.001
def _compute_slices(self, start_idx, end_idx, assets):
"""
Compute the raw row indices to load for each asset on a query for the
@@ -449,6 +492,14 @@ class BcolzDailyBarReader(object):
offsets,
)
@property
def first_trading_day(self):
    """The value stored in `_first_trading_day` (read-only)."""
    return self._first_trading_day
@property
def last_available_dt(self):
    """The final entry of the reader's calendar."""
    calendar = self._calendar
    return calendar[-1]
def _spot_col(self, colname):
"""
Get the colname from daily_bar_table and read all of it into memory,
@@ -468,9 +519,33 @@ class BcolzDailyBarReader(object):
try:
col = self._spot_cols[colname]
except KeyError:
col = self._spot_cols[colname] = self._table[colname][:]
col = self._spot_cols[colname] = self._table[colname]
return col
def get_last_traded_dt(self, asset, day):
    """
    Walk backwards through the calendar from `day` until a bar with
    non-zero volume is found for `asset`.

    Parameters
    ----------
    asset : Asset
        The asset to look up; must expose `end_date`.
    day : datetime64-like
        The calendar day at which to start searching.

    Returns
    -------
    The most recent calendar day at or before the starting point on which
    the asset's volume is non-zero, or None when the search leaves the
    asset's data or the start of the calendar.
    """
    calendar = self._calendar
    volumes = self._spot_col('volume')

    search_day = day
    if day >= asset.end_date:
        # Start from the calendar entry just before the asset's end date.
        # NOTE(review): if end_date sorts before the whole calendar the
        # index becomes -1 and selects the last calendar entry -- confirm
        # that case cannot occur.
        search_day = calendar[calendar.searchsorted(asset.end_date) - 1]

    while True:
        try:
            ix = self.sid_day_index(asset, search_day)
        except NoDataOnDate:
            return None
        if volumes[ix] != 0:
            return search_day
        prev_day_ix = calendar.get_loc(search_day) - 1
        if prev_day_ix < 0:
            # Ran off the front of the calendar without finding a trade.
            return None
        search_day = calendar[prev_day_ix]
def sid_day_index(self, sid, day):
"""
Parameters
@@ -487,7 +562,11 @@ class BcolzDailyBarReader(object):
Raises a NoDataOnDate exception if the given day and sid is before
or after the date range of the equity.
"""
day_loc = self._calendar.get_loc(day)
try:
day_loc = self._calendar.get_loc(day)
except:
raise NoDataOnDate("day={0} is outside of calendar={1}".format(
day, self._calendar))
offset = day_loc - self._calendar_offsets[sid]
if offset < 0:
raise NoDataOnDate(
@@ -530,6 +609,93 @@ class BcolzDailyBarReader(object):
return price
class PanelDailyBarReader(DailyBarReader):
    """
    Reader for data passed as Panel.

    DataPanel Structure
    -------
    items : Int64Index, asset identifiers
    major_axis : DatetimeIndex, days provided by the Panel.
    minor_axis : ['open', 'high', 'low', 'close', 'volume']

    Parameters
    ----------
    calendar : DatetimeIndex
        Calendar used to align requested date ranges.
    panel : pd.Panel
        The panel from which to read OHLCV data. It is copied, so the
        caller's panel is not modified.

    Attributes
    ----------
    panel : pd.Panel
        The (copied) panel from which OHLCV data is read.
    first_trading_day : pd.Timestamp
        The first entry of the panel's major axis.
    """

    def __init__(self, calendar, panel):
        panel = panel.copy()
        # Fields live on the minor axis (items are the asset identifiers),
        # so that is where the presence of 'volume' has to be checked.
        # Checking `items` would overwrite real volume data on every load.
        if 'volume' not in panel.minor_axis:
            # Fake volume if it does not exist.
            panel.loc[:, :, 'volume'] = int(1e9)
        self.first_trading_day = panel.major_axis[0]
        self._calendar = calendar
        self.panel = panel

    @property
    def last_available_dt(self):
        """The final entry of the calendar passed to the constructor."""
        return self._calendar[-1]

    def load_raw_arrays(self, columns, start_date, end_date, assets):
        """
        Load the values of `columns` for `assets` between `start_date` and
        `end_date`.

        The selection is reindexed along the date axis to the calendar
        entries in that range before the underlying values are returned.

        Parameters
        ----------
        columns : iterable of objects with a `name` attribute
            The fields to load.
        start_date, end_date : datetime64-like
            Bounds of the requested range.
        assets : iterable
            Item labels of the assets to load.
        """
        col_names = [col.name for col in columns]
        cal = self._calendar
        index = cal[cal.slice_indexer(start_date, end_date)]
        result = self.panel.loc[assets, start_date:end_date, col_names]
        return result.reindex_axis(index, 1).values

    def spot_price(self, sid, day, colname):
        """
        Parameters
        ----------
        sid : int
            The asset identifier.
        day : datetime64-like
            Midnight of the day for which data is requested.
        colname : string
            The price field. e.g. ('open', 'high', 'low', 'close', 'volume')

        Returns
        -------
        The value stored in the panel for (sid, day, colname). The lookup
        is a plain panel index, so no substitution is made for missing or
        zero values here.
        """
        return self.panel[sid, day, colname]

    def get_last_traded_dt(self, sid, dt):
        """
        Parameters
        ----------
        sid : int
            The asset identifier.
        dt : datetime64-like
            Midnight of the day for which data is requested.

        Returns
        -------
        pd.Timestamp : The last known dt for the asset and dt;
                       NaT if no trade is found before the given dt.
        """
        major_axis = self.panel.major_axis
        # Step back one `freq` at a time while still on the panel's dates.
        # NOTE(review): a major axis without a `freq` cannot be stepped;
        # confirm callers always pass a panel with a regular frequency.
        while dt in major_axis:
            freq = major_axis.freq
            if not isnull(self.panel.loc[sid, dt, 'close']):
                return dt
            dt -= freq
        return NaT
class SQLiteAdjustmentWriter(object):
"""
Writer for data to be read by SQLiteAdjustmentReader
@@ -900,6 +1066,23 @@ class SQLiteAdjustmentWriter(object):
self.conn.close()
UNPAID_QUERY_TEMPLATE = """
SELECT sid, amount, pay_date from dividend_payouts
WHERE ex_date=? AND sid IN ({0})
"""
Dividend = namedtuple('Dividend', ['asset', 'amount', 'pay_date'])
UNPAID_STOCK_DIVIDEND_QUERY_TEMPLATE = """
SELECT sid, payment_sid, ratio, pay_date from stock_dividend_payouts
WHERE ex_date=? AND sid IN ({0})
"""
StockDividend = namedtuple(
'StockDividend',
['asset', 'payment_asset', 'ratio', 'pay_date'])
class SQLiteAdjustmentReader(object):
"""
Loads adjustments based on corporate actions from a SQLite database.
@@ -923,3 +1106,62 @@ class SQLiteAdjustmentReader(object):
dates,
assets,
)
def get_adjustments_for_sid(self, table_name, sid):
    """
    Fetch every (effective_date, ratio) row stored for `sid` in
    `table_name`.

    Parameters
    ----------
    table_name : str
        Name of the adjustment table to query. It is interpolated into the
        SQL text (table names cannot be bound as parameters), so it must
        come from trusted code, never from user input.
    sid : int
        The asset identifier to filter on.

    Returns
    -------
    list of [pd.Timestamp, ratio]
        One entry per row, with effective_date converted from seconds to a
        UTC Timestamp.
    """
    c = self.conn.cursor()
    try:
        rows = c.execute(
            "SELECT effective_date, ratio FROM %s WHERE sid = ?" %
            table_name, (sid,)).fetchall()
    finally:
        # Release the cursor even when the query fails.
        c.close()
    return [[Timestamp(effective_date, unit='s', tz='UTC'), ratio]
            for effective_date, ratio in rows]
def get_dividends_with_ex_date(self, assets, date, asset_finder):
    """
    Load the cash dividends whose ex_date equals `date` for `assets`.

    Parameters
    ----------
    assets : iterable
        Asset identifiers to look up; queried in chunks.
    date : pd.Timestamp
        The ex date; its nanosecond `value` is converted to seconds for
        the query.
    asset_finder : object
        Used to turn the sid of each row into an asset through
        `retrieve_asset`.

    Returns
    -------
    list of Dividend
        One (asset, amount, pay_date) tuple per matching row, with
        pay_date as a UTC Timestamp.
    """
    # NOTE(review): under true division this binds a float to the query;
    # confirm the comparison against the stored ex_date is as intended.
    seconds = date.value / int(1e9)
    divs = []
    c = self.conn.cursor()
    try:
        for chunk in group_into_chunks(assets):
            # One '?' placeholder per sid in the chunk.
            query = UNPAID_QUERY_TEMPLATE.format(
                ",".join(['?' for _ in chunk]))
            t = (seconds,) + tuple(map(int, chunk))
            c.execute(query, t)
            for row in c.fetchall():
                divs.append(Dividend(
                    asset_finder.retrieve_asset(row[0]),
                    row[1],
                    Timestamp(row[2], unit='s', tz='UTC')))
    finally:
        # Release the cursor even when a query or lookup fails.
        c.close()
    return divs
def get_stock_dividends_with_ex_date(self, assets, date, asset_finder):
    """
    Load the stock dividends whose ex_date equals `date` for `assets`.

    Parameters
    ----------
    assets : iterable
        Asset identifiers to look up; queried in chunks.
    date : pd.Timestamp
        The ex date; its nanosecond `value` is converted to seconds for
        the query.
    asset_finder : object
        Used to turn the sid and payment_sid of each row into assets
        through `retrieve_asset`.

    Returns
    -------
    list of StockDividend
        One (asset, payment_asset, ratio, pay_date) tuple per matching
        row, with pay_date as a UTC Timestamp.
    """
    # NOTE(review): under true division this binds a float to the query;
    # confirm the comparison against the stored ex_date is as intended.
    seconds = date.value / int(1e9)
    stock_divs = []
    c = self.conn.cursor()
    try:
        for chunk in group_into_chunks(assets):
            # One '?' placeholder per sid in the chunk.
            query = UNPAID_STOCK_DIVIDEND_QUERY_TEMPLATE.format(
                ",".join(['?' for _ in chunk]))
            t = (seconds,) + tuple(map(int, chunk))
            c.execute(query, t)
            for row in c.fetchall():
                stock_divs.append(StockDividend(
                    asset_finder.retrieve_asset(row[0]),  # asset
                    asset_finder.retrieve_asset(row[1]),  # payment_asset
                    row[2],
                    Timestamp(row[3], unit='s', tz='UTC')))
    finally:
        # Release the cursor even when a query or lookup fails.
        c.close()
    return stock_divs
+73 -2
View File
@@ -1,5 +1,5 @@
#
# Copyright 2013 Quantopian, Inc.
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,10 @@ class ZiplineError(Exception):
def __init__(self, **kwargs):
self.kwargs = kwargs
self.message = str(self)
@lazyval
def message(self):
return str(self)
def __str__(self):
msg = self.msg.format(**self.kwargs)
@@ -31,6 +34,33 @@ class ZiplineError(Exception):
__repr__ = __str__
class NoTradeDataAvailable(ZiplineError):
pass
class NoTradeDataAvailableTooEarly(NoTradeDataAvailable):
msg = "{sid} does not exist on {dt}. It started trading on {start_dt}."
class NoTradeDataAvailableTooLate(NoTradeDataAvailable):
msg = "{sid} does not exist on {dt}. It stopped trading on {end_dt}."
class BenchmarkAssetNotAvailableTooEarly(NoTradeDataAvailableTooEarly):
pass
class BenchmarkAssetNotAvailableTooLate(NoTradeDataAvailableTooLate):
pass
class InvalidBenchmarkAsset(ZiplineError):
msg = """
{sid} cannot be used as the benchmark because it has a stock \
dividend on {dt}. Choose another asset to use as the benchmark.
""".strip()
class WrongDataForTransform(ZiplineError):
"""
Raised whenever a rolling transform is called on an event that
@@ -60,6 +90,15 @@ You may only call 'set_slippage' in your initialize method.
""".strip()
class SetCancelPolicyPostInit(ZiplineError):
# Raised if a users script calls set_cancel_policy
# after the initialize method has returned.
msg = """
You attempted to set the cancel policy outside of `initialize`. \
You may only call 'set_cancel_policy' in your initialize method.
""".strip()
class RegisterTradingControlPostInit(ZiplineError):
# Raised if a user's script register's a trading control after initialize
# has been run.
@@ -90,6 +129,17 @@ Please use PerShare or PerTrade.
""".strip()
class UnsupportedCancelPolicy(ZiplineError):
"""
Raised if a user script calls set_cancel_policy with an object that isn't
a CancelPolicy.
"""
msg = """
You attempted to set the cancel policy with an unsupported class. Please use
an instance of CancelPolicy.
""".strip()
class SetCommissionPostInit(ZiplineError):
"""
Raised if a users script calls set_commission magic
@@ -147,6 +197,13 @@ class UnsupportedOrderParameters(ZiplineError):
msg = "{msg}"
class CannotOrderDelistedAsset(ZiplineError):
"""
Raised if an order is for a delisted asset.
"""
msg = "{msg}"
class BadOrderParameters(ZiplineError):
"""
Raised if any impossible parameters (nan, negative limit/stop)
@@ -162,6 +219,13 @@ class OrderDuringInitialize(ZiplineError):
msg = "{msg}"
class SetBenchmarkOutsideInitialize(ZiplineError):
"""
Raised if set_benchmark is called outside initialize()
"""
msg = "'set_benchmark' can only be called within initialize function."
class AccountControlViolation(ZiplineError):
"""
Raised if the account violates a constraint set by a AccountControl.
@@ -529,3 +593,10 @@ class AssetDBImpossibleDowngrade(ZiplineError):
"The existing Asset database is version: {db_version} which is lower "
"than the desired downgrade version: {desired_version}."
)
class HistoryWindowStartsBeforeData(ZiplineError):
msg = (
"History window extends before {first_trading_day}. To use this "
"history window, start the backtest on or after {suggested_start_day}."
)
+21 -15
View File
@@ -15,31 +15,37 @@
# limitations under the License.
import pandas as pd
from zipline import TradingAlgorithm
from zipline.api import order, sid
from zipline.api import order, symbol
from zipline.data.loader import load_bars_from_yahoo
# creating time interval
start = pd.Timestamp('2008-01-01', tz='UTC')
end = pd.Timestamp('2013-01-01', tz='UTC')
# loading the data
input_data = load_bars_from_yahoo(
stocks=['AAPL', 'MSFT'],
start=start,
end=end,
)
stocks = ['AAPL', 'MSFT']
def initialize(context):
context.has_ordered = False
context.stocks = stocks
def handle_data(context, data):
if not context.has_ordered:
for stock in data:
order(sid(stock), 100)
for stock in context.stocks:
order(symbol(stock), 100)
context.has_ordered = True
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data)
results = algo.run(input_data)
if __name__ == '__main__':
# creating time interval
start = pd.Timestamp('2008-01-01', tz='UTC')
end = pd.Timestamp('2013-01-01', tz='UTC')
# loading the data
input_data = load_bars_from_yahoo(
stocks=stocks,
start=start,
end=end,
)
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
identifiers=stocks)
results = algo.run(input_data)
+1 -1
View File
@@ -23,7 +23,7 @@ def initialize(context):
def handle_data(context, data):
order(symbol('AAPL'), 10)
record(AAPL=data[symbol('AAPL')].price)
record(AAPL=data.current(symbol('AAPL'), "price"))
# Note: this function can be removed if running
+10 -13
View File
@@ -26,41 +26,38 @@ momentum).
from zipline.api import order, record, symbol
# Import exponential moving average from talib wrapper
from zipline.transforms.ta import EMA
from talib import EMA
def initialize(context):
context.asset = symbol('AAPL')
# Add 2 mavg transforms, one with a long window, one with a short window.
context.short_ema_trans = EMA(timeperiod=20)
context.long_ema_trans = EMA(timeperiod=40)
# To keep track of whether we invested in the stock or not
context.invested = False
def handle_data(context, data):
short_ema = context.short_ema_trans.handle_data(data)
long_ema = context.long_ema_trans.handle_data(data)
if short_ema is None or long_ema is None:
trailing_window = data.history(context.asset, 'price', 40, '1d')
if trailing_window.isnull().values.any():
return
short_ema = EMA(trailing_window.values, timeperiod=20)
long_ema = EMA(trailing_window.values, timeperiod=40)
buy = False
sell = False
if (short_ema > long_ema).all() and not context.invested:
if (short_ema[-1] > long_ema[-1]) and not context.invested:
order(context.asset, 100)
context.invested = True
buy = True
elif (short_ema < long_ema).all() and context.invested:
elif (short_ema[-1] < long_ema[-1]) and context.invested:
order(context.asset, -100)
context.invested = False
sell = True
record(AAPL=data[context.asset].price,
short_ema=short_ema[context.asset],
long_ema=long_ema[context.asset],
record(AAPL=data.current(context.asset, "price"),
short_ema=short_ema[-1],
long_ema=long_ema[-1],
buy=buy,
sell=sell)
+8 -13
View File
@@ -22,15 +22,10 @@ its shares once the averages cross again (indicating downwards
momentum).
"""
from zipline.api import order_target, record, symbol, history, add_history
from zipline.api import order_target, record, symbol
def initialize(context):
# Register 2 histories that track daily prices,
# one with a 100 window and one with a 300 day window
add_history(100, '1d', 'price')
add_history(300, '1d', 'price')
context.sym = symbol('AAPL')
context.i = 0
@@ -45,21 +40,21 @@ def handle_data(context, data):
# Compute averages
# history() has to be called with the same params
# from above and returns a pandas dataframe.
short_mavg = history(100, '1d', 'price').mean()
long_mavg = history(300, '1d', 'price').mean()
short_mavg = data.history(context.sym, 'price', 100, '1d').mean()
long_mavg = data.history(context.sym, 'price', 300, '1d').mean()
# Trading logic
if short_mavg[context.sym] > long_mavg[context.sym]:
if short_mavg > long_mavg:
# order_target orders as many shares as needed to
# achieve the desired number of shares.
order_target(context.sym, 100)
elif short_mavg[context.sym] < long_mavg[context.sym]:
elif short_mavg < long_mavg:
order_target(context.sym, 0)
# Save values for later inspection
record(AAPL=data[context.sym].price,
short_mavg=short_mavg[context.sym],
long_mavg=long_mavg[context.sym])
record(AAPL=data.current(context.sym, "price"),
short_mavg=short_mavg,
long_mavg=long_mavg)
# Note: this function can be removed if running
+4 -4
View File
@@ -33,7 +33,6 @@ def initialize(algo, eps=1, window_length=5):
algo.init = True
algo.days = 0
algo.window_length = window_length
algo.add_transform('mavg', 5)
algo.set_commission(commission.PerShare(cost=0))
@@ -54,10 +53,11 @@ def handle_data(algo, data):
b = np.zeros(m)
# find relative moving average price for each asset
mavgs = data.history(algo.sids, 'price', algo.window_length, '1d').mean()
for i, sid in enumerate(algo.sids):
price = data[sid].price
price = data.current(sid, "price")
# Relative mean deviation
x_tilde[i] = data[sid].mavg(algo.window_length) / price
x_tilde[i] = mavgs[sid] / price
###########################
# Inside of OLMAR (algo 2)
@@ -101,7 +101,7 @@ def rebalance_portfolio(algo, data, desired_port):
for i, sid in enumerate(algo.sids):
current_amount[i] = algo.portfolio.positions[sid].amount
prices[i] = data[sid].price
prices[i] = data.current(sid, "price")
desired_amount = np.round(desired_port * positions_value / prices)
-158
View File
@@ -1,158 +0,0 @@
#!/usr/bin/env python
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logbook
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
from datetime import datetime
import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import batch_transform
from zipline.utils.factory import load_from_yahoo
from zipline.api import symbol
@batch_transform
def ols_transform(data, sid1, sid2):
    """Fit an Ordinary Least Squares regression of sid1's prices on
    sid2's prices and return the two fitted parameters.

    The parameters are returned in the order the fit reports them. The
    constant column is prepended to the regressors, so the first value
    presumably belongs to the constant term and the second to sid2's
    price -- verify against statsmodels before relying on the order.
    """
    dependent = data.price[sid1].values
    regressors = sm.add_constant(data.price[sid2].values, prepend=True)
    fitted = sm.OLS(dependent, regressors).fit()
    first, second = fitted.params
    return first, second
class Pairtrade(TradingAlgorithm):
    """Pair trading on the spread between two cointegrated stocks.

    Two stocks that usually move together are expected to converge again
    after drifting apart. The algorithm opens a position in both legs,
    in opposite directions, once the spread has widened, and closes it
    once the spread has narrowed again. Positions are opened on both
    sides at once, so the algorithm enters the market close to neutral.

    This implementation trades Pepsi (PEP) against Coca Cola (KO). The
    relationship between the two is estimated by regression, and the
    divergence of the spread is measured with a z-score.
    """

    def initialize(self, window_length=100):
        self.spreads = []
        self.invested = 0
        self.window_length = window_length
        # The regression is refreshed once per window and fitted over one
        # full window of data.
        self.ols_transform = ols_transform(
            refresh_period=self.window_length,
            window_length=self.window_length,
        )
        self.PEP = self.symbol('PEP')
        self.KO = self.symbol('KO')

    def handle_data(self, data):
        # Step 1: regression coefficients between PEP and KO. Nothing is
        # done until the transform has produced a result.
        params = self.ols_transform.handle_data(data, self.PEP, self.KO)
        if params is None:
            return
        intercept, slope = params

        # Step 2: spread and z-score, recorded next to both prices.
        zscore = self.compute_zscore(data, slope, intercept)
        self.record(
            zscores=zscore,
            PEP=data[symbol('PEP')].price,
            KO=data[symbol('KO')].price,
        )

        # Step 3: trade on the z-score.
        self.place_orders(data, zscore)

    def compute_zscore(self, data, slope, intercept):
        """Compute the current spread from `slope` and `intercept`, store
        it, and z-score it against the most recent `window_length` spreads.
        """
        pep_price = data[self.PEP].price
        spread = pep_price - (slope * data[self.KO].price + intercept)
        self.spreads.append(spread)
        recent = self.spreads[-self.window_length:]
        return (spread - np.mean(recent)) / np.std(recent)

    def place_orders(self, data, zscore):
        """Open a position when the z-score reaches 2 in either direction
        and nothing is held; close it when the z-score falls inside .5.
        """
        if self.invested:
            if abs(zscore) < .5:
                self.sell_spread()
                self.invested = False
            return

        # Not invested: pick the side of the spread to take, if any.
        if zscore >= 2.0:
            direction = 1
        elif zscore <= -2.0:
            direction = -1
        else:
            return
        self.order(self.PEP, direction * int(100 / data[self.PEP].price))
        self.order(self.KO, -direction * int(100 / data[self.KO].price))
        self.invested = True

    def sell_spread(self):
        """
        decrease exposure, regardless of position long/short.
        buy for a short position, sell for a long.
        """
        held_ko = self.portfolio.positions[self.KO].amount
        self.order(self.KO, -1 * held_ko)
        held_pep = self.portfolio.positions[self.PEP].amount
        self.order(self.PEP, -1 * held_pep)
# Note: this function can be removed if running
# this algorithm on quantopian.com
def analyze(context=None, results=None):
    """Plot the recorded PEP and KO prices above the z-scored spread.

    `results` must provide the 'PEP' and 'KO' columns and a `zscores`
    attribute; `context` is accepted but not used.
    """
    # Upper panel: both share prices.
    price_ax = plt.subplot(211)
    plt.title('PepsiCo & Coca-Cola Co. share prices')
    results[['PEP', 'KO']].plot(ax=price_ax)
    plt.ylabel('Price (USD)')
    # The x axis is shared with the lower panel, so hide these labels.
    plt.setp(price_ax.get_xticklabels(), visible=False)

    # Lower panel: the z-scored spread.
    zscore_ax = plt.subplot(212, sharex=price_ax)
    results.zscores.plot(ax=zscore_ax, color='r')
    plt.ylabel('Z-scored spread')

    plt.gcf().set_size_inches(18, 8)
    plt.show()
# Note: this if-block should be removed if running
# this algorithm on quantopian.com
if __name__ == '__main__':
    # Send log output to stderr for the duration of the run.
    logbook.StderrHandler().push_application()

    # Simulation period.
    start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc)
    end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)

    # Price data for both legs of the pair, loaded from yahoo.
    data = load_from_yahoo(
        stocks=['PEP', 'KO'],
        indexes={},
        start=start,
        end=end,
    )

    # Run the algorithm and plot what it recorded.
    pairtrade = Pairtrade()
    results = pairtrade.run(data)
    analyze(results=results)
+162 -118
View File
@@ -16,50 +16,60 @@ import math
from logbook import Logger
from collections import defaultdict
from copy import copy
from six.moves import filter
import pandas as pd
from six import iteritems
import zipline.errors
import zipline.protocol as zp
from zipline.finance.slippage import (
VolumeShareSlippage,
transact_partial,
)
from zipline.finance.commission import PerShare
from zipline.finance.order import Order
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
from zipline.finance.slippage import VolumeShareSlippage
from zipline.finance.commission import PerShare
from zipline.finance.cancel_policy import NeverCancel
log = Logger('Blotter')
warning_logger = Logger('AlgoWarning')
class Blotter(object):
def __init__(self):
self.transact = transact_partial(VolumeShareSlippage(), PerShare())
def __init__(self, data_frequency, asset_finder, slippage_func=None,
commission=None, cancel_policy=None):
# these orders are aggregated by sid
self.open_orders = defaultdict(list)
# keep a dict of orders by their own id
self.orders = {}
# holding orders that have come in since the last
# event.
# all our legacy order management code works with integer sids.
# this lets us convert those to assets when needed. ideally, we'd just
# revamp all the legacy code to work with assets.
self.asset_finder = asset_finder
# holding orders that have come in since the last event.
self.new_orders = []
self.current_dt = None
self.max_shares = int(1e+11)
self.slippage_func = slippage_func or VolumeShareSlippage()
self.commission = commission or PerShare()
self.data_frequency = data_frequency
self.cancel_policy = cancel_policy if cancel_policy else NeverCancel()
def __repr__(self):
return """
{class_name}(
transact_partial={transact_partial},
slippage={slippage_func},
commission={commission},
open_orders={open_orders},
orders={orders},
new_orders={new_orders},
current_dt={current_dt})
""".strip().format(class_name=self.__class__.__name__,
transact_partial=self.transact.args,
slippage_func=self.slippage_func,
commission=self.commission,
open_orders=self.open_orders,
orders=self.orders,
new_orders=self.new_orders,
@@ -108,7 +118,7 @@ class Blotter(object):
return order.id
def cancel(self, order_id):
def cancel(self, order_id, relay_status=True):
if order_id not in self.orders:
return
@@ -123,26 +133,62 @@ class Blotter(object):
self.new_orders.remove(cur_order)
cur_order.cancel()
cur_order.dt = self.current_dt
# we want this order's new status to be relayed out
# along with newly placed orders.
self.new_orders.append(cur_order)
def cancel_all(self, sid):
if relay_status:
# we want this order's new status to be relayed out
# along with newly placed orders.
self.new_orders.append(cur_order)
def cancel_all_orders_for_asset(self, asset, warn=False,
relay_status=True):
"""
Cancel all open orders for a given sid.
Cancel all open orders for a given asset.
"""
# (sadly) open_orders is a defaultdict, so this will always succeed.
orders = self.open_orders[sid]
orders = self.open_orders[asset]
# We're making a copy here because `cancel` mutates the list of open
# orders in place. The right thing to do here would be to make
# self.open_orders no longer a defaultdict. If we do that, then we
# should just remove the orders once here and be done with the matter.
for order in orders[:]:
self.cancel(order.id)
self.cancel(order.id, relay_status)
if warn:
# Message appropriately depending on whether there's
# been a partial fill or not.
if order.filled > 0:
warning_logger.warn(
'Your order for {order_amt} shares of '
'{order_sym} has been partially filled. '
'{order_filled} shares were successfully '
'purchased. {order_failed} shares were not '
'filled by the end of day and '
'were canceled.'.format(
order_amt=order.amount,
order_sym=order.sid.symbol,
order_filled=order.filled,
order_failed=order.amount - order.filled,
)
)
else:
warning_logger.warn(
'Your order for {order_amt} shares of '
'{order_sym} failed to fill by the end of day '
'and was canceled.'.format(
order_amt=order.amount,
order_sym=order.sid.symbol,
)
)
assert not orders
del self.open_orders[sid]
del self.open_orders[asset]
def execute_cancel_policy(self, event):
    """
    Apply the blotter's cancel policy to ``event``.

    If the policy reports that orders should be canceled for this event,
    every asset with an entry in ``self.open_orders`` has all of its
    orders canceled without relaying the new status.  The policy's
    ``warn_on_cancel`` flag decides whether warnings are logged.
    """
    policy = self.cancel_policy
    if not policy.should_cancel(event):
        return

    warn = policy.warn_on_cancel
    # Iterate over a copy because cancel_all_orders_for_asset deletes
    # entries from self.open_orders as it goes.
    for asset in copy(self.open_orders):
        self.cancel_all_orders_for_asset(
            asset, warn, relay_status=False)
def reject(self, order_id, reason=''):
"""
@@ -187,104 +233,102 @@ class Blotter(object):
# along with newly placed orders.
self.new_orders.append(cur_order)
def process_split(self, split_event):
if split_event.sid not in self.open_orders:
return
def process_splits(self, splits):
    """
    Processes a list of splits by modifying any open orders as needed.

    Parameters
    ----------
    splits: list
        A list of splits.  Each split is a tuple of (sid, ratio).

    Returns
    -------
    None
    """
    for split in splits:
        sid = split[0]
        if sid not in self.open_orders:
            # No open orders for this sid.  Keep going rather than
            # returning: later splits in the same batch may still apply
            # to assets that do have open orders.
            continue

        ratio = split[1]
        for order in self.open_orders[sid]:
            order.handle_split(ratio)
if trade_event.sid not in self.open_orders:
return
def get_transactions(self, bar_data):
    """
    Creates a list of transactions based on the current open orders,
    slippage model, and commission model.

    Parameters
    ----------
    bar_data: zipline._protocol.BarData

    Notes
    -----
    This method book-keeps the blotter's open_orders dictionary, so that
    it is accurate by the time we're done processing open orders.

    Returns
    -------
    transactions_list: list
        Transactions resulting from the current open orders.  If there
        were no open orders, an empty list is returned.

    commissions_list: None
        Always None here: the commission is folded into each
        transaction's price and ``commission`` attribute instead of being
        reported as separate events.
    """
    filled_and_closed = []
    transactions = []

    if self.open_orders:
        assets = self.asset_finder.retrieve_all(self.open_orders)
        assets_by_sid = {asset.sid: asset for asset in assets}

        for sid, asset_orders in iteritems(self.open_orders):
            asset = assets_by_sid[sid]
            fills = self.slippage_func(bar_data, asset, asset_orders)

            for order, txn in fills:
                # +1.0 for buys, -1.0 for sells: the per-share
                # commission pushes the price against the trader.
                direction = math.copysign(1, txn.amount)
                per_share, total_commission = \
                    self.commission.calculate(txn)
                txn.price += per_share * direction
                txn.commission = total_commission

                order.filled += txn.amount
                if txn.commission is not None:
                    already_charged = order.commission or 0.0
                    order.commission = already_charged + txn.commission

                # the order's date follows the transaction filling it.
                txn.dt = pd.Timestamp(txn.dt, tz='UTC')
                order.dt = txn.dt

                transactions.append(txn)

                if not order.open:
                    filled_and_closed.append(order)

    # remove all closed orders from our open_orders dict
    for order in filled_and_closed:
        sid = order.sid
        try:
            sid_orders = self.open_orders[sid]
            sid_orders.remove(order)
        except KeyError:
            # NOTE(review): list.remove raises ValueError, not KeyError,
            # when the order is absent -- confirm which case this guards.
            continue

    # now clear out the sids from our open_orders dict that have
    # zero open orders
    for sid in list(self.open_orders):
        if not self.open_orders[sid]:
            del self.open_orders[sid]

    # FIXME this API doesn't feel right (returning two things here)
    return transactions, None
+50
View File
@@ -0,0 +1,50 @@
#
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from abc import abstractmethod
from six import with_metaclass
from zipline.gens.sim_engine import DAY_END
class CancelPolicy(with_metaclass(abc.ABCMeta)):
    """
    Abstract interface for deciding when open orders get canceled.
    """

    @abc.abstractmethod
    def should_cancel(self, event):
        """
        Report whether open orders should be canceled for ``event``.

        Concrete policies must override this method; the base
        implementation does nothing and returns None.
        """
class EODCancel(CancelPolicy):
    """
    Cancel policy that fires at the end of the day.

    For now, Zipline will only apply this policy to minutely simulations.
    """

    def __init__(self, warn_on_cancel=True):
        # Exposed as an attribute so callers can decide whether canceled
        # orders should be logged as warnings.
        self.warn_on_cancel = warn_on_cancel

    def should_cancel(self, event):
        """Return True only when ``event`` equals ``DAY_END``."""
        is_day_end = (event == DAY_END)
        return is_day_end
class NeverCancel(CancelPolicy):
    """
    Cancel policy under which orders are never automatically canceled.
    """

    def __init__(self):
        # Nothing is ever canceled, so there is never anything to warn
        # about.
        self.warn_on_cancel = False

    def should_cancel(self, event):
        """Return False for every ``event``."""
        return False
+6 -70
View File
@@ -13,11 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from six import iteritems
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
DEFAULT_PER_SHARE_COST = 0.0075 # 0.75 cents per share
DEFAULT_MINIMUM_COST_PER_TRADE = 1.0 # $1 per trade
class PerShare(object):
@@ -26,7 +23,9 @@ class PerShare(object):
share cost with an optional minimum cost per trade.
"""
def __init__(self, cost=0.03, min_trade_cost=None):
def __init__(self,
cost=DEFAULT_PER_SHARE_COST,
min_trade_cost=DEFAULT_MINIMUM_COST_PER_TRADE):
"""
Cost parameter is the cost of a trade per-share. $0.03
means three cents per share, which is a very conservative
@@ -56,27 +55,6 @@ class PerShare(object):
commission = max(commission, self.min_trade_cost)
return abs(commission / transaction.amount), commission
def __getstate__(self):
state_dict = \
{k: v for k, v in iteritems(self.__dict__)
if not k.startswith('_')}
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerShare saved state is too old.")
self.__dict__.update(state)
class PerTrade(object):
"""
@@ -84,7 +62,7 @@ class PerTrade(object):
trade cost.
"""
def __init__(self, cost=5.0):
def __init__(self, cost=DEFAULT_MINIMUM_COST_PER_TRADE):
"""
Cost parameter is the cost of a trade, regardless of
share count. $5.00 per trade is fairly typical of
@@ -104,27 +82,6 @@ class PerTrade(object):
return abs(self.cost / transaction.amount), self.cost
def __getstate__(self):
state_dict = \
{k: v for k, v in iteritems(self.__dict__)
if not k.startswith('_')}
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerTrade saved state is too old.")
self.__dict__.update(state)
class PerDollar(object):
"""
@@ -151,24 +108,3 @@ class PerDollar(object):
"""
cost_per_share = transaction.price * self.cost
return cost_per_share, abs(transaction.amount) * cost_per_share
def __getstate__(self):
state_dict = \
{k: v for k, v in iteritems(self.__dict__)
if not k.startswith('_')}
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerDollar saved state is too old.")
self.__dict__.update(state)
+2 -2
View File
@@ -190,7 +190,7 @@ class MaxOrderSize(TradingControl):
if self.max_shares is not None and abs(amount) > self.max_shares:
self.fail(asset, amount, _algo_datetime)
current_asset_price = algo_current_data[asset].price
current_asset_price = algo_current_data.current(asset, "price")
order_value = amount * current_asset_price
too_much_value = (self.max_notional is not None and
@@ -252,7 +252,7 @@ class MaxPositionSize(TradingControl):
if too_many_shares:
self.fail(asset, amount, algo_datetime)
current_price = algo_current_data[asset].price
current_price = algo_current_data.current(asset, "price")
value_post_order = shares_post_order * current_price
too_much_value = (self.max_notional is not None and
+5 -6
View File
@@ -179,18 +179,17 @@ def check_stoplimit_prices(price, label):
try:
if not isfinite(price):
raise BadOrderParameters(
msg="""Attempted to place an order with a {} price
of {}.""".format(label, price)
msg="Attempted to place an order with a {} price "
"of {}.".format(label, price)
)
# This catches arbitrary objects
except TypeError:
raise BadOrderParameters(
msg="""Attempted to place an order with a {} price
of {}.""".format(label, type(price))
msg="Attempted to place an order with a {} price "
"of {}.".format(label, type(price))
)
if price < 0:
raise BadOrderParameters(
msg="""Can't place a {} order
with a negative price.""".format(label)
msg="Can't place a {} order with a negative price.".format(label)
)
+27 -42
View File
@@ -1,5 +1,5 @@
#
# Copyright 2015 Quantopian, Inc.
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,10 @@ from copy import copy
import math
import uuid
from six import text_type, iteritems
from six import text_type
import zipline.protocol as zp
from zipline.utils.serialization_utils import VERSION_LABEL
from zipline.assets import Asset
from zipline.utils.enum import enum
ORDER_STATUS = enum(
@@ -41,12 +41,14 @@ class Order(object):
commission=None, id=None):
"""
@dt - datetime.datetime that the order was placed
@sid - stock sid of the order
@sid - asset for the order. called sid for historical reasons.
@amount - the number of shares to buy/sell
a positive sign indicates a buy
a negative sign indicates a sell
@filled - how many shares of the order have been filled so far
"""
assert isinstance(sid, Asset)
# get a string representation of the uuid.
self.id = id or self.make_id()
self.dt = dt
@@ -79,23 +81,23 @@ class Order(object):
obj = zp.Order(initial_values=pydict)
return obj
def check_triggers(self, price, dt):
    """
    Refresh the order's stop/limit trigger state from the given price.

    ``self.dt`` is moved to ``dt`` only when the (stop_reached,
    limit_reached) pair differs from the values stored on the order.
    """
    triggers = self.check_order_triggers(price)
    stop_reached, limit_reached, sl_stop_reached = triggers

    previous_flags = (self.stop_reached, self.limit_reached)
    if previous_flags != (stop_reached, limit_reached):
        self.dt = dt

    self.stop_reached = stop_reached
    self.limit_reached = limit_reached

    if sl_stop_reached:
        # Change the STOP LIMIT order into a LIMIT order
        self.stop = None
def check_order_triggers(self, event):
def check_order_triggers(self, current_price):
"""
Given an order and a trade event, return a tuple of
(stop_reached, limit_reached).
@@ -129,34 +131,32 @@ class Order(object):
order_type |= LIMIT
if order_type == BUY | STOP | LIMIT:
if event.price >= self.stop:
if current_price >= self.stop:
sl_stop_reached = True
if event.price <= self.limit:
if current_price <= self.limit:
limit_reached = True
elif order_type == SELL | STOP | LIMIT:
if event.price <= self.stop:
if current_price <= self.stop:
sl_stop_reached = True
if event.price >= self.limit:
if current_price >= self.limit:
limit_reached = True
elif order_type == BUY | STOP:
if event.price >= self.stop:
if current_price >= self.stop:
stop_reached = True
elif order_type == SELL | STOP:
if event.price <= self.stop:
if current_price <= self.stop:
stop_reached = True
elif order_type == BUY | LIMIT:
if event.price <= self.limit:
if current_price <= self.limit:
limit_reached = True
elif order_type == SELL | LIMIT:
# This is a SELL LIMIT order
if event.price >= self.limit:
if current_price >= self.limit:
limit_reached = True
return (stop_reached, limit_reached, sl_stop_reached)
def handle_split(self, split_event):
ratio = split_event.ratio
def handle_split(self, ratio):
# update the amount, limit_price, and stop_price
# by the split's ratio
@@ -202,6 +202,14 @@ class Order(object):
def open(self):
return self.status in [ORDER_STATUS.OPEN, ORDER_STATUS.HELD]
@property
def asset(self):
    """
    The asset this order is for.

    Read-only alias of ``sid``, the historical attribute name that we'd
    like to change at some point.
    """
    return self.sid
@property
def triggered(self):
"""
@@ -232,26 +240,3 @@ class Order(object):
Unicode representation for this object.
"""
return text_type(repr(self))
def __getstate__(self):
state_dict = \
{k: v for k, v in iteritems(self.__dict__)
if not k.startswith('_')}
state_dict['_status'] = self._status
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Order saved state is too old.")
self.__dict__.update(state)
+9 -49
View File
@@ -88,10 +88,6 @@ from six import itervalues, iteritems
import zipline.protocol as zp
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
log = logbook.Logger('Performance')
TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE
@@ -136,13 +132,19 @@ class PerformancePeriod(object):
self,
starting_cash,
asset_finder,
data_frequency,
data_portal,
period_open=None,
period_close=None,
keep_transactions=True,
keep_orders=False,
serialize_positions=True):
serialize_positions=True,
name=None):
self.asset_finder = asset_finder
self.data_frequency = data_frequency
self._data_portal = data_portal
self.period_open = period_open
self.period_close = period_close
@@ -167,6 +169,8 @@ class PerformancePeriod(object):
self.keep_transactions = keep_transactions
self.keep_orders = keep_orders
self.name = name
# An object to recycle by assigning new values when returning
# portfolio information, so as to avoid creating a new object
# for each event.
@@ -483,47 +487,3 @@ class PerformancePeriod(object):
account.net_liquidation = getattr(self, 'net_liquidation',
period_stats.net_liquidation)
return account
def __getstate__(self):
state_dict = {k: v for k, v in iteritems(self.__dict__)
if not k.startswith('_')}
state_dict['_portfolio_store'] = self._portfolio_store
state_dict['_account_store'] = self._account_store
state_dict['processed_transactions'] = \
dict(self.processed_transactions)
state_dict['orders_by_id'] = \
dict(self.orders_by_id)
state_dict['orders_by_modified'] = \
dict(self.orders_by_modified)
state_dict['_payout_last_sale_prices'] = \
self._payout_last_sale_prices
STATE_VERSION = 3
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 3
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerformancePeriod saved state is too old.")
processed_transactions = {}
processed_transactions.update(state.pop('processed_transactions'))
orders_by_id = OrderedDict()
orders_by_id.update(state.pop('orders_by_id'))
orders_by_modified = {}
orders_by_modified.update(state.pop('orders_by_modified'))
self.processed_transactions = processed_transactions
self.orders_by_id = orders_by_id
self.orders_by_modified = orders_by_modified
self._execution_cash_flow_multipliers = {}
self.__dict__.update(state)
+17 -56
View File
@@ -1,5 +1,5 @@
#
# Copyright 2014 Quantopian, Inc.
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,16 +34,8 @@ Position Tracking
from __future__ import division
from math import copysign
from collections import OrderedDict
from copy import copy
import logbook
import numpy as np
import zipline.protocol as zp
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
import logbook
log = logbook.Logger('Performance')
@@ -64,23 +56,21 @@ class Position(object):
Register the number of shares we held at this dividend's ex date so
that we can pay out the correct amount on the dividend's pay date.
"""
assert dividend['sid'] == self.sid
out = {'id': dividend['id']}
return {
'amount': self.amount * dividend.amount
}
# stock dividend
if dividend['payment_sid']:
out['payment_sid'] = dividend['payment_sid']
out['share_count'] = np.floor(self.amount
* float(dividend['ratio']))
# cash dividend
if dividend['net_amount']:
out['cash_amount'] = self.amount * dividend['net_amount']
elif dividend['gross_amount']:
out['cash_amount'] = self.amount * dividend['gross_amount']
payment_owed = zp.dividend_payment(out)
return payment_owed
def earn_stock_dividend(self, stock_dividend):
    """
    Compute the stock payment owed to this position for a stock dividend.

    The result is based on the number of shares held right now, so that
    the correct amount can be paid out on the dividend's pay date.

    Returns
    -------
    dict
        ``payment_asset`` is the asset the dividend is paid in, and
        ``share_count`` is ``self.amount * ratio`` rounded down with
        ``np.floor``.
    """
    payment_asset = stock_dividend.payment_asset
    share_count = np.floor(self.amount * float(stock_dividend.ratio))
    return {
        'payment_asset': payment_asset,
        'share_count': share_count,
    }
def handle_split(self, sid, ratio):
"""
@@ -92,10 +82,6 @@ class Position(object):
if self.sid != sid:
raise Exception("updating split with the wrong sid!")
log.info("handling split for sid = " + str(sid) +
", ratio = " + str(ratio))
log.info("before split: " + str(self))
# adjust the # of shares by the ratio
# (if we had 100 shares, and the ratio is 3,
# we now have 33 shares)
@@ -114,11 +100,7 @@ class Position(object):
# adjust the cost basis to the nearest cent, e.g., 60.0
new_cost_basis = round(self.cost_basis * ratio, 2)
# adjust the last sale price
new_last_sale_price = round(self.last_sale_price * ratio, 2)
self.cost_basis = new_cost_basis
self.last_sale_price = new_last_sale_price
self.amount = full_share_count
return_cash = round(float(fractional_share_count * new_cost_basis), 2)
@@ -210,28 +192,7 @@ last_sale_price: {last_sale_price}"
'last_sale_price': self.last_sale_price
}
def __getstate__(self):
state_dict = copy(self.__dict__)
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Position saved state is too old.")
self.__dict__.update(state)
class positiondict(OrderedDict):
    """
    Ordered mapping of positions whose missing keys read as None.

    Looking up an absent key neither raises nor inserts an entry.
    """

    def __missing__(self, key):
        return None
+149 -138
View File
@@ -1,5 +1,5 @@
#
# Copyright 2015 Quantopian, Inc.
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +17,11 @@ from __future__ import division
import logbook
import numpy as np
import pandas as pd
from pandas.lib import checknull
from collections import namedtuple
from math import isnan
from zipline.finance.performance.position import Position
from zipline.finance.transaction import Transaction
try:
# optional cython based OrderedDict
from cyordereddict import OrderedDict
@@ -27,11 +29,6 @@ except ImportError:
from collections import OrderedDict
from six import iteritems, itervalues
from zipline.finance.transaction import Transaction
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
import zipline.protocol as zp
from zipline.assets import (
Equity, Future
@@ -122,19 +119,25 @@ def calc_gross_value(long_value, short_value):
class PositionTracker(object):
def __init__(self, asset_finder, data_portal, data_frequency):
    """
    Set up empty position and dividend bookkeeping.

    The three arguments are stored on the instance unchanged.
    """
    self.asset_finder = asset_finder
    self.data_frequency = data_frequency

    # FIXME really want to avoid storing a data portal here,
    # but the path to get to maybe_create_close_position_transaction
    # is long and tortuous
    self._data_portal = data_portal

    # sid => position object
    self.positions = positiondict()
    self._positions_store = zp.Positions()

    # Arrays for quick calculations of positions value
    self._position_value_multipliers = OrderedDict()
    self._position_exposure_multipliers = OrderedDict()

    # Cash and stock dividends that have been earned but not yet paid.
    self._unpaid_dividends = {}
    self._unpaid_stock_dividends = {}
def _update_asset(self, sid):
try:
self._position_value_multipliers[sid]
@@ -153,21 +156,6 @@ class PositionTracker(object):
self._position_value_multipliers[sid] = 0
self._position_exposure_multipliers[sid] = asset.multiplier
def update_last_sale(self, event):
# NOTE, PerformanceTracker already vetted as TRADE type
sid = event.sid
if sid not in self.positions:
return 0
price = event.price
if checknull(price):
return 0
pos = self.positions[sid]
pos.last_sale_date = event.dt
pos.last_sale_price = price
def update_positions(self, positions):
# update positions in batch
self.positions.update(positions)
@@ -176,24 +164,47 @@ class PositionTracker(object):
def update_position(self, sid, amount=None, last_sale_price=None,
                    last_sale_date=None, cost_basis=None):
    """
    Overwrite selected fields of the position held for ``sid``.

    A new ``Position`` is created and stored when ``sid`` has none yet.
    Only the keyword arguments that are not None are applied.  Changing
    ``amount`` also refreshes the asset's cached multipliers through
    ``_update_asset``.
    """
    positions = self.positions
    if sid in positions:
        position = positions[sid]
    else:
        position = positions[sid] = Position(sid)

    if amount is not None:
        position.amount = amount
        self._update_asset(sid=sid)

    if last_sale_price is not None:
        position.last_sale_price = last_sale_price

    if last_sale_date is not None:
        position.last_sale_date = last_sale_date

    if cost_basis is not None:
        position.cost_basis = cost_basis
def execute_transaction(self, txn):
    """
    Apply ``txn`` to the position for ``txn.sid``.

    A position is created when none exists yet.  If the transaction
    leaves the position with zero shares, it is removed from both the
    internal ``positions`` mapping and the user-facing
    ``_positions_store``.
    """
    sid = txn.sid
    positions = self.positions

    if sid in positions:
        position = positions[sid]
    else:
        position = positions[sid] = Position(sid)

    position.update(txn)

    if position.amount == 0:
        # A position with 0 shares no longer belongs in our internal
        # bookkeeping.
        del positions[sid]

        try:
            # Drop it from the user-facing dictionary as well, if it was
            # ever exposed there.
            del self._positions_store[sid]
        except KeyError:
            pass

    self._update_asset(sid)
def handle_commission(self, sid, cost):
@@ -201,102 +212,126 @@ class PositionTracker(object):
if sid in self.positions:
self.positions[sid].adjust_commission_cost_basis(sid, cost)
def handle_split(self, split):
if split.sid in self.positions:
# Make the position object handle the split. It returns the
# leftover cash from a fractional share, if there is any.
position = self.positions[split.sid]
leftover_cash = position.handle_split(split.sid, split.ratio)
self._update_asset(split.sid)
return leftover_cash
def handle_splits(self, splits):
    """
    Processes a list of splits by modifying any positions as needed.

    Parameters
    ----------
    splits: list
        A list of splits.  Each split is a tuple of (sid, ratio).

    Returns
    -------
    The total leftover cash from fractional shares after modifying each
    affected position.  Splits for sids we hold no position in are
    ignored.
    """
    total_leftover_cash = 0

    for split in splits:
        sid = split[0]
        if sid not in self.positions:
            continue

        # The position applies the split itself and hands back whatever
        # cash is left over from a fractional share, if there is any.
        held = self.positions[sid]
        leftover_cash = held.handle_split(sid, split[1])
        self._update_asset(split[0])
        total_leftover_cash += leftover_cash

    return total_leftover_cash
def earn_dividends(self, dividends, stock_dividends):
    """
    Given a list of dividends whose ex_dates are all the next trading day,
    calculate and store the cash and/or stock payments to be paid on each
    dividend's pay date.

    Parameters
    ----------
    dividends: iterable of (asset, amount, pay_date) namedtuples

    stock_dividends: iterable of (asset, payment_asset, ratio, pay_date)
        namedtuples.
    """
    # NOTE(review): both loops index self.positions directly, so they
    # appear to assume a position is held for every asset passed in --
    # confirm the caller filters the dividends accordingly.
    unpaid_cash = self._unpaid_dividends
    for dividend in dividends:
        # Store the earned dividends so that they can be paid on the
        # dividends' pay_dates.
        owed = self.positions[dividend.asset].earn_dividend(dividend)
        unpaid_cash.setdefault(dividend.pay_date, []).append(owed)

    unpaid_stock = self._unpaid_stock_dividends
    for stock_dividend in stock_dividends:
        holder = self.positions[stock_dividend.asset]
        owed = holder.earn_stock_dividend(stock_dividend)
        unpaid_stock.setdefault(stock_dividend.pay_date, []).append(owed)
def pay_dividends(self, next_trading_day):
    """
    Pay out the dividends recorded for ``next_trading_day``.

    Cash and stock payments stored under that key are removed from the
    unpaid bookkeeping, so calling this again for the same day pays
    nothing further.

    Parameters
    ----------
    next_trading_day
        The key under which the payments were stored, i.e. the pay date
        of the dividends being paid.

    Returns
    -------
    float
        The net cash payment across all cash dividends paid; 0.0 when
        nothing was owed.
    """
    # Mark these dividends as paid by dropping them from our unpaid
    # bookkeeping as we read them.
    payments = self._unpaid_dividends.pop(next_trading_day, [])

    # Add cash equal to the net cash paid from all dividends.  Note that
    # "negative cash" is effectively paid if we're short an asset,
    # representing the fact that we're required to reimburse the owner of
    # the stock for any dividends paid while borrowing.
    net_cash_payment = 0.0
    for payment in payments:
        net_cash_payment += payment['amount']

    # Add stock for any stock dividends paid.  Again, the values here may
    # be negative in the case of short positions.  A missing day simply
    # means nothing is owed; unlike a bare ``except:``, any other error
    # is allowed to surface.
    stock_payments = self._unpaid_stock_dividends.pop(next_trading_day, [])

    for stock_payment in stock_payments:
        payment_asset = stock_payment['payment_asset']
        share_count = stock_payment['share_count']

        # note we create a Position for stock dividend if we don't
        # already own the asset
        if payment_asset in self.positions:
            position = self.positions[payment_asset]
        else:
            position = self.positions[payment_asset] = \
                Position(payment_asset)

        position.amount += share_count
        self._update_asset(payment_asset)

    return net_cash_payment
def maybe_create_close_position_transaction(self, event):
try:
pos = self.positions[event.sid]
amount = pos.amount
if amount == 0:
return None
except KeyError:
def maybe_create_close_position_transaction(self, asset, dt):
if not self.positions.get(asset):
return None
if 'price' in event:
price = event.price
else:
price = pos.last_sale_price
amount = self.positions.get(asset).amount
price = self._data_portal.get_spot_value(
asset, 'price', dt, self.data_frequency)
# Get the last traded price if price is no longer available
if isnan(price):
price = self.positions.get(asset).last_sale_price
txn = Transaction(
sid=event.sid,
amount=(-1 * pos.amount),
dt=event.dt,
sid=asset,
amount=(-1 * amount),
dt=dt,
price=price,
commission=0,
order_id=None,
@@ -326,6 +361,8 @@ class PositionTracker(object):
position.amount = pos.amount
position.cost_basis = pos.cost_basis
position.last_sale_price = pos.last_sale_price
position.last_sale_date = pos.last_sale_date
return positions
def get_positions_list(self):
@@ -335,8 +372,15 @@ class PositionTracker(object):
positions.append(pos.to_dict())
return positions
def get_nonempty_position_sids(self):
return [sid for sid, pos in iteritems(self.positions) if pos.amount]
def sync_last_sale_prices(self, dt):
    """
    Refresh ``last_sale_price`` on every held position as of ``dt``.

    The 'price' spot value is read from the data portal for each asset.
    When that value is NaN, the position keeps its existing
    ``last_sale_price``.
    """
    portal = self._data_portal
    for asset, position in iteritems(self.positions):
        price = portal.get_spot_value(
            asset, 'price', dt, self.data_frequency)

        if np.isnan(price):
            continue
        position.last_sale_price = price
def stats(self):
amounts = []
@@ -380,36 +424,3 @@ class PositionTracker(object):
shorts_count=shorts_count,
net_value=net_value
)
def __getstate__(self):
state_dict = {}
state_dict['asset_finder'] = self.asset_finder
state_dict['positions'] = dict(self.positions)
state_dict['unpaid_dividends'] = self._unpaid_dividends
STATE_VERSION = 4
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 3
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PositionTracker saved state is too old.")
self.asset_finder = state['asset_finder']
self.positions = positiondict()
# note that positions_store is temporary and gets regened from
# .positions
self._positions_store = zp.Positions()
self._unpaid_dividends = state['unpaid_dividends']
# Arrays for quick calculations of positions value
self._position_value_multipliers = OrderedDict()
self._position_exposure_multipliers = OrderedDict()
# Update positions is called without a finder
self.update_positions(state['positions'])
+70 -180
View File
@@ -1,5 +1,5 @@
#
# Copyright 2015 Quantopian, Inc.
# Copyright 2016 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -58,22 +58,17 @@ Performance Tracking
"""
from __future__ import division
import logbook
import pickle
from six import iteritems
from datetime import datetime
import numpy as np
import pandas as pd
from pandas.tseries.tools import normalize_date
import zipline.finance.risk as risk
from . period import PerformancePeriod
from zipline.finance.performance.period import PerformancePeriod
import zipline.finance.risk as risk
from zipline.utils.pandas_utils import sort_values
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
from . position_tracker import PositionTracker
log = logbook.Logger('Performance')
@@ -83,8 +78,7 @@ class PerformanceTracker(object):
"""
Tracks the performance of the algorithm.
"""
def __init__(self, sim_params, env):
def __init__(self, sim_params, env, data_portal):
self.sim_params = sim_params
self.env = env
@@ -107,17 +101,22 @@ class PerformanceTracker(object):
self.trading_days = all_trading_days[mask]
self.dividend_frame = pd.DataFrame()
self._dividend_count = 0
self._data_portal = data_portal
if data_portal is not None:
self._adjustment_reader = data_portal._adjustment_reader
else:
self._adjustment_reader = None
self.position_tracker = PositionTracker(asset_finder=env.asset_finder)
self.position_tracker = PositionTracker(
asset_finder=env.asset_finder,
data_portal=data_portal,
data_frequency=self.sim_params.data_frequency)
if self.emission_rate == 'daily':
self.all_benchmark_returns = pd.Series(
index=self.trading_days)
self.cumulative_risk_metrics = \
risk.RiskMetricsCumulative(self.sim_params, self.env)
elif self.emission_rate == 'minute':
self.all_benchmark_returns = pd.Series(index=pd.date_range(
self.sim_params.first_open, self.sim_params.last_close,
@@ -132,6 +131,8 @@ class PerformanceTracker(object):
self.cumulative_performance = PerformancePeriod(
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the cumulative period will be calculated over the entire test.
period_open=self.period_start,
period_close=self.period_end,
@@ -142,6 +143,7 @@ class PerformanceTracker(object):
# don't serialize positions for cumulative period
serialize_positions=False,
asset_finder=self.env.asset_finder,
name="Cumulative"
)
self.cumulative_performance.position_tracker = self.position_tracker
@@ -149,6 +151,8 @@ class PerformanceTracker(object):
self.todays_performance = PerformancePeriod(
# initial cash is your capital base.
starting_cash=self.capital_base,
data_frequency=self.sim_params.data_frequency,
data_portal=data_portal,
# the daily period will be calculated for the market day
period_open=self.market_open,
period_close=self.market_close,
@@ -156,6 +160,7 @@ class PerformanceTracker(object):
keep_orders=True,
serialize_positions=True,
asset_finder=self.env.asset_finder,
name="Daily"
)
self.todays_performance.position_tracker = self.position_tracker
@@ -185,67 +190,21 @@ class PerformanceTracker(object):
self.saved_dt = date
self.todays_performance.period_close = self.saved_dt
def update_dividends(self, new_dividends):
    """
    Merge ``new_dividends`` into ``self.dividend_frame``.

    Parameters
    ----------
    new_dividends : pd.DataFrame
        Should have columns containing at least the entries in
        zipline.protocol.DIVIDEND_FIELDS. An ``id`` column is written
        onto this frame in place.
    """
    # Give every incoming dividend a unique integer id, continuing from
    # the running count, so that dividends whose date/sid fields are
    # otherwise identical can still be told apart.
    first_id = self._dividend_count
    added = len(new_dividends)
    new_dividends['id'] = np.arange(first_id, first_id + added)
    self._dividend_count += added

    combined = pd.concat([self.dividend_frame, new_dividends])
    ordered = sort_values(combined, ['pay_date', 'ex_date'])
    self.dividend_frame = ordered.set_index('id', drop=False)
def initialize_dividends_from_other(self, other):
    """
    Take over the dividend frame and dividend count of ``other``.

    Useful when a simulation creates a new PerformanceTracker mid-stream
    and wants to keep the stored dividend info and id counter.

    Note that unpaid dividends are not copied, and the frame is shared by
    reference rather than duplicated.
    """
    for attr in ('dividend_frame', '_dividend_count'):
        setattr(self, attr, getattr(other, attr))
def handle_sid_removed_from_universe(self, sid):
    """
    Handle a SID leaving the universe of the TradingAlgorithm.

    Every row for that sid is dropped from ``self.dividend_frame``.

    Parameters
    ----------
    sid : int
        The sid of the Asset being removed from the universe.
    """
    frame = self.dividend_frame
    keep_rows = frame.sid != sid
    self.dividend_frame = frame[keep_rows]
def get_portfolio(self, performance_needs_update, dt):
    """
    Return the cumulative period's portfolio, refreshing it first if asked.

    Parameters
    ----------
    performance_needs_update : bool
        When truthy, last sale prices are synced as of ``dt``, performance
        is recalculated, and ``account_needs_update`` is set to True.
    dt : object
        Passed to ``position_tracker.sync_last_sale_prices``.
    """
    if not performance_needs_update:
        return self.cumulative_performance.as_portfolio()

    self.position_tracker.sync_last_sale_prices(dt)
    self.update_performance()
    self.account_needs_update = True
    return self.cumulative_performance.as_portfolio()
def update_performance(self):
    """
    Recalculate performance as of the last trade.

    Calls ``calculate_performance`` on the cumulative period first, then
    on today's period.
    """
    for period_name in ('cumulative_performance', 'todays_performance'):
        getattr(self, period_name).calculate_performance()
def get_portfolio(self, performance_needs_update):
    """
    Return the cumulative period's portfolio, refreshing it first if asked.

    Parameters
    ----------
    performance_needs_update : bool
        When truthy, performance is recalculated and
        ``account_needs_update`` is set to True before the portfolio is
        built.
    """
    if not performance_needs_update:
        return self.cumulative_performance.as_portfolio()

    self.update_performance()
    self.account_needs_update = True
    return self.cumulative_performance.as_portfolio()
def get_account(self, performance_needs_update):
def get_account(self, performance_needs_update, dt):
if performance_needs_update:
self.position_tracker.sync_last_sale_prices(dt)
self.update_performance()
self.account_needs_update = True
if self.account_needs_update:
@@ -261,7 +220,6 @@ class PerformanceTracker(object):
Creates a dictionary representing the state of this tracker.
Returns a dict object of the form described in header comments.
"""
# Default to the emission rate of this tracker if no type is provided
if emission_type is None:
emission_type = self.emission_rate
@@ -284,25 +242,14 @@ class PerformanceTracker(object):
return _dict
def _handle_event_price(self, event):
    """Forward ``event`` to the position tracker's ``update_last_sale``."""
    tracker = self.position_tracker
    tracker.update_last_sale(event)
def process_trade(self, event):
    """Handle a trade event by passing it to ``_handle_event_price``."""
    handle_price = self._handle_event_price
    handle_price(event)
def process_transaction(self, event):
self._handle_event_price(event)
def process_transaction(self, transaction):
self.txn_count += 1
self.cumulative_performance.handle_execution(event)
self.todays_performance.handle_execution(event)
self.position_tracker.execute_transaction(event)
self.cumulative_performance.handle_execution(transaction)
self.todays_performance.handle_execution(transaction)
self.position_tracker.execute_transaction(transaction)
def process_dividend(self, dividend):
    """Log that DIVIDEND events are ignored; ``dividend`` itself is unused."""
    message = "Ignoring DIVIDEND event."
    log.info(message)
def process_split(self, event):
leftover_cash = self.position_tracker.handle_split(event)
def handle_splits(self, splits):
leftover_cash = self.position_tracker.handle_splits(splits)
if leftover_cash > 0:
self.cumulative_performance.handle_cash_payment(leftover_cash)
self.todays_performance.handle_cash_payment(leftover_cash)
@@ -312,43 +259,16 @@ class PerformanceTracker(object):
self.todays_performance.record_order(event)
def process_commission(self, commission):
sid = commission.sid
cost = commission.cost
sid = commission['sid']
cost = commission['cost']
self.position_tracker.handle_commission(sid, cost)
self.cumulative_performance.handle_commission(cost)
self.todays_performance.handle_commission(cost)
def process_benchmark(self, event):
    """
    Store a benchmark return in ``self.all_benchmark_returns``.

    Parameters
    ----------
    event : object
        Must expose ``dt`` (the timestamp) and ``returns`` (the value to
        store).

    Raises
    ------
    AssertionError
        If the timestamp used as key is not already present in the index
        of ``all_benchmark_returns``.
    """
    uses_minute_data = self.sim_params.data_frequency == 'minute'
    if uses_minute_data and self.sim_params.emission_rate == 'daily':
        # Minute data benchmarks should have a timestamp of market
        # close, so that calculations are triggered at the right time.
        # However, risk module uses midnight as the 'day'
        # marker for returns, so adjust back to midnight.
        key = pd.tseries.tools.normalize_date(event.dt)
    else:
        key = event.dt

    if key in self.all_benchmark_returns.index:
        self.all_benchmark_returns[key] = event.returns
        return

    raise AssertionError(
        ("Date %s not allocated in all_benchmark_returns. "
         "Calendar seems to mismatch with benchmark. "
         "Benchmark container is=%s" %
         (key,
          self.all_benchmark_returns.index)))
def process_close_position(self, event):
# CLOSE_POSITION events that contain prices that must be handled as
# a final trade event
if 'price' in event:
self.process_trade(event)
def process_close_position(self, asset, dt):
txn = self.position_tracker.\
maybe_create_close_position_transaction(event)
maybe_create_close_position_transaction(asset, dt)
if txn:
self.process_transaction(txn)
@@ -362,32 +282,33 @@ class PerformanceTracker(object):
is the next trading day. Apply all such benefits, then recalculate
performance.
"""
if len(self.dividend_frame) == 0:
# We don't currently know about any dividends for this simulation
# period, so bail.
if self._adjustment_reader is None:
return
position_tracker = self.position_tracker
held_sids = set(position_tracker.positions)
# Dividends whose ex_date is the next trading day. We need to check if
# we own any of these stocks so we know to pay them out when the pay
# date comes.
ex_date_mask = (self.dividend_frame['ex_date'] == next_trading_day)
dividends_earnable = self.dividend_frame[ex_date_mask]
# Dividends whose pay date is the next trading day. If we held any of
# these stocks on midnight before the ex_date, we need to pay these out
# now.
pay_date_mask = (self.dividend_frame['pay_date'] == next_trading_day)
dividends_payable = self.dividend_frame[pay_date_mask]
if held_sids:
asset_finder = self.env.asset_finder
position_tracker = self.position_tracker
if len(dividends_earnable):
position_tracker.earn_dividends(dividends_earnable)
cash_dividends = self._adjustment_reader.\
get_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
stock_dividends = self._adjustment_reader.\
get_stock_dividends_with_ex_date(held_sids, next_trading_day,
asset_finder)
if not len(dividends_payable):
position_tracker.earn_dividends(
cash_dividends,
stock_dividends
)
net_cash_payment = position_tracker.pay_dividends(next_trading_day)
if not net_cash_payment:
return
net_cash_payment = position_tracker.pay_dividends(dividends_payable)
self.cumulative_performance.handle_dividends_paid(net_cash_payment)
self.todays_performance.handle_dividends_paid(net_cash_payment)
@@ -408,9 +329,10 @@ class PerformanceTracker(object):
A tuple of the minute perf packet and daily perf packet.
If the market day has not ended, the daily perf packet is None.
"""
self.position_tracker.sync_last_sale_prices(dt)
self.update_performance()
todays_date = normalize_date(dt)
account = self.get_account(False)
account = self.get_account(False, dt)
bench_returns = self.all_benchmark_returns.loc[todays_date:dt]
# cumulative returns
@@ -426,27 +348,31 @@ class PerformanceTracker(object):
# if this is the close, update dividends for the next day.
# Return the performance tuple
if dt == self.market_close:
return (minute_packet, self._handle_market_close(todays_date))
return minute_packet, self._handle_market_close(todays_date)
else:
return (minute_packet, None)
return minute_packet, None
def handle_market_close_daily(self):
def handle_market_close_daily(self, dt):
"""
Function called after handle_data when running with daily emission
rate.
"""
self.position_tracker.sync_last_sale_prices(dt)
self.update_performance()
completed_date = self.day
account = self.get_account(False)
account = self.get_account(False, dt)
benchmark_value = self.all_benchmark_returns[completed_date]
# update risk metrics for cumulative performance
self.cumulative_risk_metrics.update(
completed_date,
self.todays_performance.returns,
self.all_benchmark_returns[completed_date],
benchmark_value,
account.leverage)
return self._handle_market_close(completed_date)
daily_packet = self._handle_market_close(completed_date)
return daily_packet
def _handle_market_close(self, completed_date):
@@ -514,39 +440,3 @@ class PerformanceTracker(object):
risk_dict = self.risk_report.to_dict()
return risk_dict
def __getstate__(self):
    """
    Build the dict used to serialize this tracker.

    Returns
    -------
    dict
        Every instance attribute whose name does not start with an
        underscore, plus the pickled dividend frame, the dividend count,
        and the state version (4) stored under ``VERSION_LABEL``.
    """
    STATE_VERSION = 4

    state_dict = {}
    for key, value in iteritems(self.__dict__):
        if key.startswith('_'):
            continue
        state_dict[key] = value

    # The dividend frame is stored pickled, replacing the raw attribute.
    state_dict['dividend_frame'] = pickle.dumps(self.dividend_frame)
    # Private attribute, so it has to be added explicitly.
    state_dict['_dividend_count'] = self._dividend_count

    state_dict[VERSION_LABEL] = STATE_VERSION
    return state_dict
def __setstate__(self, state):
    """
    Restore this tracker from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (4).
    """
    OLDEST_SUPPORTED_STATE = 4
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure;
        # handlers that caught BaseException still catch this.
        raise ValueError("PerformanceTracker saved state is too old.")

    self.__dict__.update(state)

    # Handle the dividend frame specially: it is stored pickled.
    # NOTE(review): pickle.loads can execute arbitrary code; only restore
    # state that comes from a trusted source.
    self.dividend_frame = pickle.loads(state['dividend_frame'])

    # properly setup the perf periods
    for p_type in ('cumulative', 'todays'):
        period = getattr(self, p_type + '_performance', None)
        if period is None:
            continue
        period._position_tracker = self.position_tracker
+1 -26
View File
@@ -34,10 +34,6 @@ from . risk import (
sortino_ratio,
)
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
log = logbook.Logger('Risk Cumulative')
@@ -90,8 +86,7 @@ class RiskMetricsCumulative(object):
'information',
)
def __init__(self, sim_params, env,
create_first_day_stats=False):
def __init__(self, sim_params, env, create_first_day_stats=False):
self.treasury_curves = env.treasury_curves
self.start_date = sim_params.period_start.replace(
hour=0, minute=0, second=0, microsecond=0
@@ -454,23 +449,3 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}"
beta = algorithm_covariance / benchmark_variance
return beta
def __getstate__(self):
    """
    Build the dict used to serialize these metrics.

    Returns
    -------
    dict
        Every instance attribute whose name does not start with an
        underscore, plus the state version (3) under ``VERSION_LABEL``.
    """
    STATE_VERSION = 3

    state_dict = {}
    for key, value in iteritems(self.__dict__):
        if key.startswith('_'):
            continue
        state_dict[key] = value

    state_dict[VERSION_LABEL] = STATE_VERSION
    return state_dict
def __setstate__(self, state):
    """
    Restore these metrics from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (3).
    """
    OLDEST_SUPPORTED_STATE = 3
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure.
        # The message is built by implicit concatenation rather than a
        # backslash continuation inside the literal, which would embed
        # the next line's indentation in the text.
        raise ValueError("RiskMetricsCumulative "
                         "saved state is too old.")

    self.__dict__.update(state)
-24
View File
@@ -34,10 +34,6 @@ from . risk import (
sortino_ratio,
)
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
log = logbook.Logger('Risk Period')
choose_treasury = functools.partial(risk.choose_treasury,
@@ -323,23 +319,3 @@ class RiskMetricsPeriod(object):
return 0.0
else:
return max(self.algorithm_leverages)
def __getstate__(self):
    """
    Build the dict used to serialize this risk period.

    Returns
    -------
    dict
        Every instance attribute whose name does not start with an
        underscore, plus the state version (3) under ``VERSION_LABEL``.
    """
    STATE_VERSION = 3

    state_dict = {}
    for key, value in iteritems(self.__dict__):
        if key.startswith('_'):
            continue
        state_dict[key] = value

    state_dict[VERSION_LABEL] = STATE_VERSION
    return state_dict
def __setstate__(self, state):
    """
    Restore this risk period from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (3).
    """
    OLDEST_SUPPORTED_STATE = 3
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure.
        # The message is built by implicit concatenation rather than a
        # backslash continuation inside the literal, which would embed
        # the next line's indentation in the text.
        raise ValueError("RiskMetricsPeriod saved state "
                         "is too old.")

    self.__dict__.update(state)
-28
View File
@@ -60,14 +60,9 @@ Risk Report
import logbook
import datetime
from dateutil.relativedelta import relativedelta
from six import iteritems
from . period import RiskMetricsPeriod
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
log = logbook.Logger('Risk Report')
@@ -153,26 +148,3 @@ class RiskReport(object):
cur_start = cur_start + relativedelta(months=1)
return ends
def __getstate__(self):
    """
    Build the dict used to serialize this report.

    Returns
    -------
    dict
        Every instance attribute whose name does not start with an
        underscore, ``_dividend_count`` when the object has one, and the
        state version (2) under ``VERSION_LABEL``.
    """
    STATE_VERSION = 2

    state_dict = {}
    for key, value in iteritems(self.__dict__):
        if key.startswith('_'):
            continue
        state_dict[key] = value

    # Private attribute, so it is carried over explicitly when present.
    if '_dividend_count' in dir(self):
        state_dict['_dividend_count'] = self._dividend_count

    state_dict[VERSION_LABEL] = STATE_VERSION
    return state_dict
def __setstate__(self, state):
    """
    Restore this report from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (2).
    """
    OLDEST_SUPPORTED_STATE = 2
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure;
        # handlers that caught BaseException still catch this.
        raise ValueError("RiskReport saved state is too old.")

    self.__dict__.update(state)
+52 -90
View File
@@ -15,18 +15,10 @@
from __future__ import division
import abc
import math
from copy import copy
from functools import partial
from six import with_metaclass
from zipline.finance.transaction import create_transaction
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
SELL = 1 << 0
BUY = 1 << 1
@@ -34,53 +26,59 @@ STOP = 1 << 2
LIMIT = 1 << 3
def transact_stub(slippage, commission, event, open_orders):
    """
    Yield ``(order, transaction)`` pairs with commission applied.

    Meant to be wrapped in a partial so that the slippage and commission
    models are bound up front.

    ``slippage(event, open_orders)`` supplies the candidate pairs. Pairs
    whose transaction is falsy or has a zero ``amount`` are dropped. For
    the rest, the per-share commission is added to the price in the
    direction of the trade and the total is stored on
    ``transaction.commission``.
    """
    for order, txn in slippage(event, open_orders):
        if not txn:
            continue
        if txn.amount != 0:
            # +1.0 for buys, -1.0 for sells.
            sign = math.copysign(1, txn.amount)
            per_share, total_commission = commission.calculate(txn)
            txn.price += per_share * sign
            txn.commission = total_commission
            yield order, txn
def transact_partial(slippage, commission):
    """
    Bind ``slippage`` and ``commission`` as the first two arguments of
    ``transact_stub`` and return the resulting partial.
    """
    bound_transact = partial(transact_stub, slippage, commission)
    return bound_transact
class LiquidityExceeded(Exception):
    """
    Signal from a slippage model's order processing; when
    ``SlippageModel.simulate`` catches it, it stops handling the
    remaining orders.
    """
# Default ``volume_limit`` for VolumeShareSlippage, which multiplies it by a
# bar's volume to cap the volume it will fill in that bar.
DEFAULT_VOLUME_SLIPPAGE_BAR_LIMIT = 0.025
class SlippageModel(with_metaclass(abc.ABCMeta)):
def __init__(self):
self._volume_for_bar = 0
@property
def volume_for_bar(self):
return self._volume_for_bar
@abc.abstractproperty
def process_order(self, event, order):
def process_order(self, data, order):
pass
def simulate(self, event, current_orders):
def simulate(self, data, asset, orders_for_asset):
self._volume_for_bar = 0
volume = data.current(asset, "volume")
for order in current_orders:
if volume == 0:
return
# can use the close price, since we verified there's volume in this
# bar.
price = data.current(asset, "close")
dt = data.current_dt
for order in orders_for_asset:
if order.open_amount == 0:
continue
order.check_triggers(event)
order.check_triggers(price, dt)
if not order.triggered:
continue
txn = None
try:
txn = self.process_order(event, order)
execution_price, execution_volume = \
self.process_order(data, order)
if execution_price is not None:
txn = create_transaction(
order,
data.current_dt,
execution_price,
execution_volume
)
except LiquidityExceeded:
break
@@ -88,19 +86,20 @@ class SlippageModel(with_metaclass(abc.ABCMeta)):
self._volume_for_bar += abs(txn.amount)
yield order, txn
def __call__(self, event, current_orders, **kwargs):
return self.simulate(event, current_orders, **kwargs)
def __call__(self, bar_data, asset, current_orders):
return self.simulate(bar_data, asset, current_orders)
class VolumeShareSlippage(SlippageModel):
def __init__(self,
volume_limit=.25,
def __init__(self, volume_limit=DEFAULT_VOLUME_SLIPPAGE_BAR_LIMIT,
price_impact=0.1):
self.volume_limit = volume_limit
self.price_impact = price_impact
super(VolumeShareSlippage, self).__init__()
def __repr__(self):
return """
{class_name}(
@@ -110,9 +109,10 @@ class VolumeShareSlippage(SlippageModel):
volume_limit=self.volume_limit,
price_impact=self.price_impact)
def process_order(self, event, order):
def process_order(self, data, order):
volume = data.current(order.asset, "volume")
max_volume = self.volume_limit * event.volume
max_volume = self.volume_limit * volume
# price impact accounts for the total volume of transactions
# created against the current minute bar
@@ -126,19 +126,21 @@ class VolumeShareSlippage(SlippageModel):
cur_volume = int(min(remaining_volume, abs(order.open_amount)))
if cur_volume < 1:
return
return None, None
# tally the current amount into our total amount ordered.
# total amount will be used to calculate price impact
total_volume = self.volume_for_bar + cur_volume
volume_share = min(total_volume / event.volume,
volume_share = min(total_volume / volume,
self.volume_limit)
price = data.current(order.asset, "close")
simulated_impact = volume_share ** 2 \
* math.copysign(self.price_impact, order.direction) \
* event.price
impacted_price = event.price + simulated_impact
* price
impacted_price = price + simulated_impact
if order.limit:
# this is tricky! if an order with a limit price has reached
@@ -151,34 +153,13 @@ class VolumeShareSlippage(SlippageModel):
# is less than the limit price
if (order.direction > 0 and impacted_price > order.limit) or \
(order.direction < 0 and impacted_price < order.limit):
return
return None, None
return create_transaction(
event,
order,
return (
impacted_price,
math.copysign(cur_volume, order.direction)
)
def __getstate__(self):
    """
    Build the dict used to serialize this slippage model.

    Returns
    -------
    dict
        A shallow copy of the instance ``__dict__`` with the state
        version (1) added under ``VERSION_LABEL``.
    """
    STATE_VERSION = 1

    snapshot = copy(self.__dict__)
    snapshot[VERSION_LABEL] = STATE_VERSION
    return snapshot
def __setstate__(self, state):
    """
    Restore this slippage model from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (1).
    """
    OLDEST_SUPPORTED_STATE = 1
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure;
        # handlers that caught BaseException still catch this.
        raise ValueError("VolumeShareSlippage saved state is too old.")

    self.__dict__.update(state)
class FixedSlippage(SlippageModel):
@@ -190,29 +171,10 @@ class FixedSlippage(SlippageModel):
"""
self.spread = spread
def process_order(self, event, order):
return create_transaction(
event,
order,
event.price + (self.spread / 2.0 * order.direction),
order.amount,
def process_order(self, data, order):
price = data.current(order.asset, "close")
return (
price + (self.spread / 2.0 * order.direction),
order.amount
)
def __getstate__(self):
    """
    Build the dict used to serialize this slippage model.

    Returns
    -------
    dict
        A shallow copy of the instance ``__dict__`` with the state
        version (1) added under ``VERSION_LABEL``.
    """
    STATE_VERSION = 1

    snapshot = copy(self.__dict__)
    snapshot[VERSION_LABEL] = STATE_VERSION
    return snapshot
def __setstate__(self, state):
    """
    Restore this slippage model from a dict produced by ``__getstate__``.

    Parameters
    ----------
    state : dict
        Saved state. Its version entry (keyed by ``VERSION_LABEL``) is
        popped, so ``state`` is modified in place.

    Raises
    ------
    ValueError
        If the saved version is older than the oldest supported one (1).
    """
    OLDEST_SUPPORTED_STATE = 1
    version = state.pop(VERSION_LABEL)

    if version < OLDEST_SUPPORTED_STATE:
        # Raise an Exception subclass instead of bare BaseException so
        # that ordinary ``except Exception`` handlers see the failure;
        # handlers that caught BaseException still catch this.
        raise ValueError("FixedSlippage saved state is too old.")

    self.__dict__.update(state)
+26 -6
View File
@@ -1,5 +1,5 @@
#
# Copyright 2014 Quantopian, Inc.
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ from zipline.assets.asset_writer import (
from zipline.errors import (
NoFurtherDataError
)
from zipline.utils.memoize import remember_last, lazyval
log = logbook.Logger('Trading')
@@ -123,6 +123,11 @@ class TradingEnvironment(object):
else:
self.asset_finder = None
@lazyval
def market_minutes(self):
    """
    Result of ``minutes_for_days_in_range`` over the whole span from
    ``first_trading_day`` to ``last_trading_day``.
    """
    range_start = self.first_trading_day
    range_end = self.last_trading_day
    return self.minutes_for_days_in_range(range_start, range_end)
def write_data(self,
engine=None,
equities_data=None,
@@ -298,8 +303,11 @@ class TradingEnvironment(object):
return self.trading_days[idx]
def days_in_range(self, start, end):
mask = ((self.trading_days >= start) &
(self.trading_days <= end))
start_date = self.normalize_date(start)
end_date = self.normalize_date(end)
mask = ((self.trading_days >= start_date) &
(self.trading_days <= end_date))
return self.trading_days[mask]
def opens_in_range(self, start, end):
@@ -315,9 +323,20 @@ class TradingEnvironment(object):
start_date = self.normalize_date(start)
end_date = self.normalize_date(end)
o_and_c = self.open_and_closes[
self.open_and_closes.index.slice_indexer(start_date, end_date)]
opens = o_and_c.market_open
closes = o_and_c.market_close
one_min = pd.Timedelta(1, unit='m')
all_minutes = []
for day in self.days_in_range(start_date, end_date):
day_minutes = self.market_minutes_for_day(day)
for i in range(0, len(o_and_c.index)):
market_open = opens[i]
market_close = closes[i]
day_minutes = np.arange(market_open, market_close + one_min,
dtype='datetime64[m]')
all_minutes.append(day_minutes)
# Concatenate all minutes and truncate minutes before start/after end.
@@ -372,6 +391,7 @@ class TradingEnvironment(object):
# then return the open of the *next* trading day.
return self.next_open_and_close(start)[0]
@remember_last
def previous_market_minute(self, start):
"""
Get the next market minute before @start. This is either the immediate

Some files were not shown because too many files have changed in this diff Show More