diff --git a/zipline/data/loader.py b/zipline/data/loader.py index 83597521..f957b247 100644 --- a/zipline/data/loader.py +++ b/zipline/data/loader.py @@ -15,17 +15,33 @@ import os - -if __name__ == "__main__": - import sys - sys.path.append(os.path.abspath('.')) - print sys.path +from os.path import expanduser import msgpack from treasuries import get_treasury_data from benchmarks import get_benchmark_returns +# TODO: Make this path customizable. +DATA_PATH = os.path.join( + expanduser("~"), + '.zipline', + 'data' +) + + +def get_datafile(name, mode='r'): + """ + Returns a handle to data file. + + Creates containing directory, if needed. + """ + + if not os.path.exists(DATA_PATH): + os.makedirs(DATA_PATH) + + return open(os.path.join(DATA_PATH, name), mode) + def dump_treasury_curves(): """ @@ -36,7 +52,6 @@ def dump_treasury_curves(): tr_data = [] for curve in get_treasury_data(): - print curve date_as_tuple = curve['date'].timetuple()[0:6] + \ (curve['date'].microsecond,) # Not ideal but massaging data into expected format @@ -44,10 +59,8 @@ def dump_treasury_curves(): tr = (date_as_tuple, curve) tr_data.append(tr) - tr_path = os.path.join(os.path.dirname(__file__), - "treasury_curves.msgpack") - tr_fp = open(tr_path, "wb") - tr_fp.write(msgpack.dumps(tr_data)) + with get_datafile('treasury_curves.msgpack', mode='wb') as tr_fp: + tr_fp.write(msgpack.dumps(tr_data)) def dump_benchmarks(): @@ -56,16 +69,13 @@ def dump_benchmarks(): Puts source treasury and data into zipline. """ - benchmark_path = os.path.join(os.path.dirname(__file__), - "benchmark.msgpack") - benchmark_fp = open(benchmark_path, "wb") benchmark_data = [] for daily_return in get_benchmark_returns(): - print daily_return date_as_tuple = daily_return.date.timetuple()[0:6] + \ (daily_return.date.microsecond,) # Not ideal but massaging data into expected format benchmark = (date_as_tuple, daily_return.returns) benchmark_data.append(benchmark) - benchmark_fp.write(msgpack.dumps(benchmark_data)) + with get_datafile('benchmark.msgpack', mode='wb') as bmark_fp: + bmark_fp.write(msgpack.dumps(benchmark_data)) diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index ebaa4898..ac2e94b2 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -20,7 +20,6 @@ Factory functions to prepare useful data for tests. import pytz import msgpack import random -from os.path import join, abspath, dirname from operator import attrgetter from collections import OrderedDict @@ -35,27 +34,23 @@ from zipline.utils.protocol_utils import ndict from zipline.sources import SpecificEquityTrades, DataFrameSource from zipline.gens.utils import create_trade from zipline.finance.trading import TradingEnvironment - -from zipline import data - - -# TODO -def data_path(): - data_path = dirname(abspath(data.__file__)) - return data_path +from zipline.data.loader import ( + get_datafile, + dump_benchmarks, + dump_treasury_curves +) def load_market_data(): - benchmark_data_path = join(data_path(), "benchmark.msgpack") try: - fp_bm = open(benchmark_data_path, "rb") + fp_bm = get_datafile('benchmark.msgpack', "rb") except IOError: print """ data msgpacks aren't distribute with source. Fetching data from Yahoo Finance. """.strip() - data.loader.dump_benchmarks() - fp_bm = open(benchmark_data_path, "rb") + dump_benchmarks() + fp_bm = get_datafile('benchmark.msgpack', "rb") bm_list = msgpack.loads(fp_bm.read()) bm_returns = [] @@ -71,20 +66,20 @@ Fetching data from Yahoo Finance. daily_return = risk.DailyReturn(date=event_dt, returns=returns) bm_returns.append(daily_return) + fp_bm.close() + bm_returns = sorted(bm_returns, key=attrgetter('date')) - treasury_data_path = join(data_path(), "treasury_curves.msgpack") try: - fp_bm = open(treasury_data_path, "rb") + fp_tr = get_datafile('treasury_curves.msgpack', "rb") except IOError: print """ data msgpacks aren't distribute with source. Fetching data from data.treasury.gov """.strip() - data.loader.dump_treasury_curves() - fp_bm = open(treasury_data_path, "rb") + dump_treasury_curves() + fp_tr = get_datafile('treasury_curves.msgpack', "rb") - fp_tr = open(join(data_path(), "treasury_curves.msgpack"), "rb") tr_list = msgpack.loads(fp_tr.read()) tr_curves = {} for packed_date, curve in tr_list: @@ -92,6 +87,8 @@ Fetching data from data.treasury.gov #tr_dt = tr_dt.replace(hour=0, minute=0, second=0, tzinfo=pytz.utc) tr_curves[tr_dt] = curve + fp_tr.close() + return bm_returns, tr_curves