ENH: Loading of benchmark data now path independent.

BUG: Correctly close file descriptor after use.
ENH: Use attrgetter instead of lambda function to sort keys.
ENH: Added package_data so that benchmark datasets will get installed.
This commit is contained in:
Thomas Wiecki
2012-04-10 10:18:18 -04:00
parent 8ce5159e91
commit 013e23383c
2 changed files with 55 additions and 48 deletions
+3
View File
@@ -39,6 +39,7 @@ version='dev'
install_requires = parse_requirements('./etc/requirements.txt') + parse_requirements('./etc/requirements_sci.txt')
tests_require = install_requires + parse_requirements('./etc/requirements_dev.txt')
options(
sphinx=Bunch(
builddir="_build",
@@ -48,6 +49,8 @@ options(
version = version,
classifiers = [],
packages = find_packages(),
package_data = find_package_data("zipline", package="zipline",
only_in_packages=False),
install_requires = install_requires,
tests_require = tests_require,
test_suite = 'nose.collector',
+52 -48
View File
@@ -4,40 +4,45 @@ Factory functions to prepare useful data for tests.
import pytz
import msgpack
import random
from os.path import join
from operator import attrgetter
from datetime import datetime, timedelta
import zipline.util as qutil
import zipline
import zipline.finance.risk as risk
import zipline.protocol as zp
from zipline.sources import SpecificEquityTrades, RandomEquityTrades
from zipline.finance.trading import TradingEnvironment
def load_market_data():
fp_bm = open("./zipline/test/benchmark.msgpack", "rb")
bm_list = msgpack.loads(fp_bm.read())
data_path = join(zipline.__path__[0], "test")
with open(join(data_path, "benchmark.msgpack"), "rb") as fp_bm:
bm_list = msgpack.loads(fp_bm.read())
bm_returns = []
for packed_date, returns in bm_list:
event_dt = zp.tuple_to_date(packed_date)
#event_dt = event_dt.replace(
# hour=0,
# minute=0,
# second=0,
# hour=0,
# minute=0,
# second=0,
# tzinfo=pytz.utc
#)
daily_return = risk.DailyReturn(date=event_dt, returns=returns)
bm_returns.append(daily_return)
bm_returns = sorted(bm_returns, key=lambda(x): x.date)
fp_tr = open("./zipline/test/treasury_curves.msgpack", "rb")
tr_list = msgpack.loads(fp_tr.read())
bm_returns = sorted(bm_returns, key=attrgetter('date'))
with open(join(data_path, "treasury_curves.msgpack"), "rb") as fp_tr:
tr_list = msgpack.loads(fp_tr.read())
tr_curves = {}
for packed_date, curve in tr_list:
tr_dt = zp.tuple_to_date(packed_date)
#tr_dt = tr_dt.replace(hour=0, minute=0, second=0, tzinfo=pytz.utc)
tr_curves[tr_dt] = curve
return bm_returns, tr_curves
def create_trading_environment():
"""Construct a complete environment with reasonable defaults"""
benchmark_returns, treasury_curves = load_market_data()
@@ -51,7 +56,7 @@ def create_trading_environment():
period_end = end,
capital_base = 100000.0
)
return trading_environment
def create_trade(sid, price, amount, datetime):
row = zp.namedict({
@@ -70,7 +75,7 @@ def get_next_trading_dt(current, interval, trading_calendar):
next = next + interval
if trading_calendar.is_trading_day(next):
break
return next
def create_trade_history(sid, prices, amounts, interval, trading_calendar):
@@ -78,7 +83,7 @@ def create_trade_history(sid, prices, amounts, interval, trading_calendar):
current = trading_calendar.first_open
for price, amount in zip(prices, amounts):
current = get_next_trading_dt(current, interval, trading_calendar)
trade = create_trade(sid, price, amount, current)
trades.append(trade)
@@ -89,9 +94,9 @@ def create_trade_history(sid, prices, amounts, interval, trading_calendar):
def create_txn(sid, price, amount, datetime, btrid=None):
txn = zp.namedict({
'sid':sid,
'amount':amount,
'amount':amount,
'dt':datetime,
'price':price,
'price':price,
})
return txn
@@ -115,15 +120,15 @@ def create_returns(daycount, trading_calendar):
test_range = []
current = trading_calendar.first_open
one_day = timedelta(days = 1)
for day in range(daycount):
for day in range(daycount):
current = current + one_day
if trading_calendar.is_trading_day(current):
r = risk.DailyReturn(current, random.random())
test_range.append(r)
return test_range
def create_returns_from_range(trading_calendar):
current = trading_calendar.first_open
@@ -134,47 +139,47 @@ def create_returns_from_range(trading_calendar):
r = risk.DailyReturn(current, random.random())
test_range.append(r)
current = get_next_trading_dt(current, one_day, trading_calendar)
return test_range
def create_returns_from_list(returns, trading_calendar):
current = trading_calendar.first_open
one_day = timedelta(days = 1)
test_range = []
#sometimes the range starts with a non-trading day.
if not trading_calendar.is_trading_day(current):
current = get_next_trading_dt(current, one_day, trading_calendar)
for return_val in returns:
for return_val in returns:
r = risk.DailyReturn(current, return_val)
test_range.append(r)
current = get_next_trading_dt(current, one_day, trading_calendar)
return test_range
def create_random_trade_source(sid, trade_count, trading_environment):
# create the source
source = RandomEquityTrades(sid, "rand-"+str(sid), trade_count)
# make the period_end of trading_environment match
cur = trading_environment.first_open
one_day = timedelta(days = 1)
for i in range(trade_count + 2):
cur = get_next_trading_dt(cur, one_day, trading_environment)
trading_environment.period_end = cur
return source
def create_daily_trade_source(sids, trade_count, trading_environment):
"""
creates trade_count trades for each sid in sids list.
first trade will be on trading_environment.period_start, and daily
thereafter for each sid. Thus, two sids should result in two trades per
day.
creates trade_count trades for each sid in sids list.
first trade will be on trading_environment.period_start, and daily
thereafter for each sid. Thus, two sids should result in two trades per
day.
Important side-effect: trading_environment.period_end will be modified
to match the day of the final trade.
to match the day of the final trade.
"""
trade_history = []
for sid in sids:
@@ -183,22 +188,21 @@ def create_daily_trade_source(sids, trade_count, trading_environment):
start_date = trading_environment.first_open
trade_time_increment = timedelta(days=1)
generated_trades = create_trade_history(
sid,
price,
volume,
trade_time_increment,
trading_environment
generated_trades = create_trade_history(
sid,
price,
volume,
trade_time_increment,
trading_environment
)
trade_history.extend(generated_trades)
trade_history = sorted(trade_history, key=lambda(x): x.dt)
trade_history = sorted(trade_history, key=attrgetter('dt'))
#set the trading environment's end to same dt as the last trade in the
#history.
trading_environment.period_end = trade_history[-1].dt
source = SpecificEquityTrades("flat", trade_history)
return source