mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 10:50:00 +08:00
Changes the location of downloaded external data files.
So that the zipline library can be used when it is installed to a write-protected location (e.g. the global site-packages), the downloaded data files are now stored in a directory under the user's home directory, which should be writable. For now, that location is ~/.zipline/data.
This commit is contained in:
+25
-15
@@ -15,17 +15,33 @@
|
||||
|
||||
|
||||
import os
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.path.append(os.path.abspath('.'))
|
||||
print sys.path
|
||||
from os.path import expanduser
|
||||
|
||||
import msgpack
|
||||
|
||||
from treasuries import get_treasury_data
|
||||
from benchmarks import get_benchmark_returns
|
||||
|
||||
# TODO: Make this path customizable.
|
||||
DATA_PATH = os.path.join(
|
||||
expanduser("~"),
|
||||
'.zipline',
|
||||
'data'
|
||||
)
|
||||
|
||||
|
||||
def get_datafile(name, mode='r'):
    """
    Returns a handle to data file.

    Creates containing directory, if needed.

    :param name: file name, joined onto DATA_PATH.
    :param mode: mode string handed straight to open(); defaults to 'r'.
    :returns: the open file object; the caller is responsible for closing
        it (it can be used in a ``with`` statement).
    :raises OSError: if DATA_PATH does not exist and cannot be created.
    :raises IOError: whatever open() raises, e.g. when reading a file that
        has not been written yet.
    """
    import errno

    # Attempt the creation unconditionally instead of checking
    # os.path.exists() first: a check-then-create sequence can fail if
    # another process creates the directory between the two calls.
    try:
        os.makedirs(DATA_PATH)
    except OSError as exc:
        # An already-existing path is the normal case after the first call;
        # any other failure (e.g. permissions) is a real error.
        if exc.errno != errno.EEXIST:
            raise

    return open(os.path.join(DATA_PATH, name), mode)
|
||||
|
||||
|
||||
def dump_treasury_curves():
|
||||
"""
|
||||
@@ -36,7 +52,6 @@ def dump_treasury_curves():
|
||||
tr_data = []
|
||||
|
||||
for curve in get_treasury_data():
|
||||
print curve
|
||||
date_as_tuple = curve['date'].timetuple()[0:6] + \
|
||||
(curve['date'].microsecond,)
|
||||
# Not ideal but massaging data into expected format
|
||||
@@ -44,10 +59,8 @@ def dump_treasury_curves():
|
||||
tr = (date_as_tuple, curve)
|
||||
tr_data.append(tr)
|
||||
|
||||
tr_path = os.path.join(os.path.dirname(__file__),
|
||||
"treasury_curves.msgpack")
|
||||
tr_fp = open(tr_path, "wb")
|
||||
tr_fp.write(msgpack.dumps(tr_data))
|
||||
with get_datafile('treasury_curves.msgpack', mode='wb') as tr_fp:
|
||||
tr_fp.write(msgpack.dumps(tr_data))
|
||||
|
||||
|
||||
def dump_benchmarks():
|
||||
@@ -56,16 +69,13 @@ def dump_benchmarks():
|
||||
|
||||
Puts source treasury and data into zipline.
|
||||
"""
|
||||
benchmark_path = os.path.join(os.path.dirname(__file__),
|
||||
"benchmark.msgpack")
|
||||
benchmark_fp = open(benchmark_path, "wb")
|
||||
benchmark_data = []
|
||||
for daily_return in get_benchmark_returns():
|
||||
print daily_return
|
||||
date_as_tuple = daily_return.date.timetuple()[0:6] + \
|
||||
(daily_return.date.microsecond,)
|
||||
# Not ideal but massaging data into expected format
|
||||
benchmark = (date_as_tuple, daily_return.returns)
|
||||
benchmark_data.append(benchmark)
|
||||
|
||||
benchmark_fp.write(msgpack.dumps(benchmark_data))
|
||||
with get_datafile('benchmark.msgpack', mode='wb') as bmark_fp:
|
||||
bmark_fp.write(msgpack.dumps(benchmark_data))
|
||||
|
||||
+15
-18
@@ -20,7 +20,6 @@ Factory functions to prepare useful data for tests.
|
||||
import pytz
|
||||
import msgpack
|
||||
import random
|
||||
from os.path import join, abspath, dirname
|
||||
from operator import attrgetter
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -35,27 +34,23 @@ from zipline.utils.protocol_utils import ndict
|
||||
from zipline.sources import SpecificEquityTrades, DataFrameSource
|
||||
from zipline.gens.utils import create_trade
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
|
||||
from zipline import data
|
||||
|
||||
|
||||
# TODO
|
||||
def data_path():
    """Return the absolute path of the directory holding ``data.__file__``."""
    return dirname(abspath(data.__file__))
|
||||
from zipline.data.loader import (
|
||||
get_datafile,
|
||||
dump_benchmarks,
|
||||
dump_treasury_curves
|
||||
)
|
||||
|
||||
|
||||
def load_market_data():
|
||||
benchmark_data_path = join(data_path(), "benchmark.msgpack")
|
||||
try:
|
||||
fp_bm = open(benchmark_data_path, "rb")
|
||||
fp_bm = get_datafile('benchmark.msgpack', "rb")
|
||||
except IOError:
|
||||
print """
|
||||
data msgpacks aren't distribute with source.
|
||||
Fetching data from Yahoo Finance.
|
||||
""".strip()
|
||||
data.loader.dump_benchmarks()
|
||||
fp_bm = open(benchmark_data_path, "rb")
|
||||
dump_benchmarks()
|
||||
fp_bm = get_datafile('benchmark.msgpack', "rb")
|
||||
|
||||
bm_list = msgpack.loads(fp_bm.read())
|
||||
bm_returns = []
|
||||
@@ -71,20 +66,20 @@ Fetching data from Yahoo Finance.
|
||||
daily_return = risk.DailyReturn(date=event_dt, returns=returns)
|
||||
bm_returns.append(daily_return)
|
||||
|
||||
fp_bm.close()
|
||||
|
||||
bm_returns = sorted(bm_returns, key=attrgetter('date'))
|
||||
|
||||
treasury_data_path = join(data_path(), "treasury_curves.msgpack")
|
||||
try:
|
||||
fp_bm = open(treasury_data_path, "rb")
|
||||
fp_tr = get_datafile('treasury_curves.msgpack', "rb")
|
||||
except IOError:
|
||||
print """
|
||||
data msgpacks aren't distribute with source.
|
||||
Fetching data from data.treasury.gov
|
||||
""".strip()
|
||||
data.loader.dump_treasury_curves()
|
||||
fp_bm = open(treasury_data_path, "rb")
|
||||
dump_treasury_curves()
|
||||
fp_tr = get_datafile('treasury_curves.msgpack', "rb")
|
||||
|
||||
fp_tr = open(join(data_path(), "treasury_curves.msgpack"), "rb")
|
||||
tr_list = msgpack.loads(fp_tr.read())
|
||||
tr_curves = {}
|
||||
for packed_date, curve in tr_list:
|
||||
@@ -92,6 +87,8 @@ Fetching data from data.treasury.gov
|
||||
#tr_dt = tr_dt.replace(hour=0, minute=0, second=0, tzinfo=pytz.utc)
|
||||
tr_curves[tr_dt] = curve
|
||||
|
||||
fp_tr.close()
|
||||
|
||||
return bm_returns, tr_curves
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user