mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 17:33:01 +08:00
intersticial commit to show realdiehl the dummy module.
This commit is contained in:
@@ -40,3 +40,6 @@ nosetests.xml
|
||||
|
||||
# Built documentation
|
||||
docs/_build/*
|
||||
|
||||
# credentials and other uncheckinables
|
||||
host_settings.py
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import datetime
|
||||
import sys
|
||||
import zipline.util as qutil
|
||||
from zipline.finance.data import DataLoader
|
||||
|
||||
def print_usage():
|
||||
print """
|
||||
Usage is:
|
||||
python loaddata.py (pt | lt | lh | ld | ei | bm | si | help)
|
||||
|
||||
pt - purge trade collection from the db
|
||||
lt - load trades (minute bars) to the db
|
||||
lh - load trades (hour bars) to the db
|
||||
ld - load trades (daily close) to the db
|
||||
ei - ensure all indexes on all collections in tick and algo db
|
||||
tr - load treasury rates
|
||||
bm - load benchmark data
|
||||
si - load security info (sid, symbol, qualifier)
|
||||
help - display this message
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) == 2:
|
||||
qutil.configure_logging()
|
||||
operation = sys.argv[1]
|
||||
if(operation not in['pt','lt','lh','ld','ei','si', 'tr','bm'] or operation == 'help'):
|
||||
print_usage()
|
||||
else:
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
pidfile = "/tmp/loaddata-{stamp}.pid".format(stamp=ts)
|
||||
daemon = DataLoader(pidfile,operation)
|
||||
qutil.LOGGER.info("DataLoader starting.")
|
||||
daemon.run()
|
||||
sys.exit(0)
|
||||
else:
|
||||
print_usage()
|
||||
sys.exit(2)
|
||||
@@ -3,3 +3,4 @@ pyzmq==2.1.11
|
||||
gevent-zeromq==0.2.2
|
||||
msgpack-python==0.1.12
|
||||
humanhash==0.0.1
|
||||
pymongo==2.1.1
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Daemon class, based on the excellent article:
|
||||
http://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/
|
||||
"""
|
||||
|
||||
import sys, os, time, atexit
|
||||
from signal import SIGTERM, SIGINT
|
||||
|
||||
class Daemon:
|
||||
"""
|
||||
A generic daemon class.
|
||||
|
||||
Usage: subclass the Daemon class and override the run() method
|
||||
"""
|
||||
def __init__(self, pidfile, stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'):
|
||||
self.stdin = stdin
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.pidfile = pidfile
|
||||
|
||||
def daemonize(self):
|
||||
"""
|
||||
do the UNIX double-fork magic, see Stevens' "Advanced
|
||||
Programming in the UNIX Environment" for details (ISBN 0201563177)
|
||||
http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16
|
||||
"""
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
# exit first parent
|
||||
sys.exit(0)
|
||||
except OSError, e:
|
||||
sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
|
||||
sys.exit(1)
|
||||
|
||||
# decouple from parent environment
|
||||
os.chdir("/")
|
||||
os.setsid()
|
||||
os.umask(0)
|
||||
|
||||
# do second fork
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
# exit from second parent
|
||||
sys.exit(0)
|
||||
except OSError, e:
|
||||
sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
|
||||
sys.exit(1)
|
||||
|
||||
# redirect standard file descriptors
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
si = file(self.stdin, 'r')
|
||||
so = file(self.stdout, 'a+')
|
||||
se = file(self.stderr, 'a+', 0)
|
||||
os.dup2(si.fileno(), sys.stdin.fileno())
|
||||
os.dup2(so.fileno(), sys.stdout.fileno())
|
||||
os.dup2(se.fileno(), sys.stderr.fileno())
|
||||
|
||||
# write pidfile
|
||||
atexit.register(self.delpid)
|
||||
pid = str(os.getpid())
|
||||
file(self.pidfile,'w+').write("%s\n" % pid)
|
||||
|
||||
def delpid(self):
|
||||
os.remove(self.pidfile)
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the daemon
|
||||
"""
|
||||
# Check for a pidfile to see if the daemon already runs
|
||||
try:
|
||||
pf = file(self.pidfile,'r')
|
||||
pid = int(pf.read().strip())
|
||||
pf.close()
|
||||
except IOError:
|
||||
pid = None
|
||||
|
||||
if pid:
|
||||
message = "pidfile %s already exist. Daemon already running?\n"
|
||||
sys.stderr.write(message % self.pidfile)
|
||||
sys.exit(1)
|
||||
|
||||
# Start the daemon
|
||||
self.daemonize()
|
||||
try:
|
||||
signal.signal(signal.SIGINT, self.handle_kill)
|
||||
except Exception, err:
|
||||
print "Problem with sigint signup " + str(err)
|
||||
self.run()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the daemon
|
||||
"""
|
||||
# Get the pid from the pidfile
|
||||
try:
|
||||
pf = file(self.pidfile,'r')
|
||||
pid = int(pf.read().strip())
|
||||
pf.close()
|
||||
except IOError:
|
||||
pid = None
|
||||
|
||||
if not pid:
|
||||
message = "pidfile %s does not exist. Daemon not running?\n"
|
||||
sys.stderr.write(message % self.pidfile)
|
||||
return # not an error in a restart
|
||||
|
||||
# First signal the process that we need to interrupt, so it can do things like close child procs
|
||||
try:
|
||||
os.kill(pid, SIGINT)
|
||||
time.sleep(2.0) #Give the process some time to kill...
|
||||
except OSError, err:
|
||||
print "Error trying to sigint the process" + str(err)
|
||||
|
||||
# Try killing the daemon process
|
||||
try:
|
||||
while 1:
|
||||
os.kill(pid, SIGTERM)
|
||||
time.sleep(0.1)
|
||||
except OSError, err:
|
||||
err = str(err)
|
||||
if err.find("No such process") > 0:
|
||||
if os.path.exists(self.pidfile):
|
||||
os.remove(self.pidfile)
|
||||
else:
|
||||
print str(err)
|
||||
sys.exit(1)
|
||||
|
||||
def restart(self):
|
||||
"""
|
||||
Restart the daemon
|
||||
"""
|
||||
self.stop()
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
You should override this method when you subclass Daemon. It will be called after the process has been
|
||||
daemonized by start() or restart().
|
||||
"""
|
||||
@@ -0,0 +1,76 @@
|
||||
import atexit
|
||||
import pymongo
|
||||
import zipline.util as qutil
|
||||
|
||||
class MongoOptions(object):
|
||||
|
||||
def __init__(self, host, port, dbname, user, password):
|
||||
self.mongodb_host = host
|
||||
self.mongodb_port = port
|
||||
self.mongodb_dbname = dbname
|
||||
self.mongodb_user = user
|
||||
self.mongodb_password = password
|
||||
|
||||
class NoDatabase(Exception):
|
||||
def __repr__(self):
|
||||
return 'The database has not been set up yet.'
|
||||
|
||||
def setup_db(credentials):
|
||||
"""
|
||||
Setup the database. Has global side effects.
|
||||
"""
|
||||
qutil.LOGGER.info(dir(DbConnection))
|
||||
if not DbConnection.initd:
|
||||
connector = connect_db(credentials)
|
||||
DbConnection.set(*connector)
|
||||
|
||||
def connect_db(options):
|
||||
"""
|
||||
Connect to pymongo, return a connection and database instance
|
||||
as a tuple.
|
||||
"""
|
||||
|
||||
connection = pymongo.Connection(options.mongodb_host, options.mongodb_port)
|
||||
|
||||
db = connection[options.mongodb_dbname]
|
||||
db.authenticate(options.mongodb_user, options.mongodb_password)
|
||||
|
||||
def _gc_connection(): # pragma: no cover
|
||||
connection.close()
|
||||
|
||||
atexit.register(_gc_connection)
|
||||
return connection, db
|
||||
|
||||
class DbConnection(object):
|
||||
"""
|
||||
Hold the shared state of the database connection.
|
||||
"""
|
||||
|
||||
initd = False
|
||||
__shared = {}
|
||||
|
||||
def __init__(self):
|
||||
self.__dict__ = self.__shared
|
||||
|
||||
@staticmethod
|
||||
def set(conn, db):
|
||||
DbConnection.__shared['conn'] = conn
|
||||
DbConnection.__shared['db'] = db
|
||||
DbConnection.initd = True
|
||||
|
||||
@staticmethod
|
||||
def get():
|
||||
return (
|
||||
DbConnection.__shared['conn'],
|
||||
DbConnection.__shared['db']
|
||||
)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if not DbConnection.__shared.get('initd'):
|
||||
raise NoDatabase()
|
||||
else:
|
||||
return DbConnection.__shared.get(key)
|
||||
|
||||
def destory(self): # pragma: no cover
|
||||
DbConnection.__shared['initd'] = False
|
||||
self.conn.close()
|
||||
@@ -0,0 +1,497 @@
|
||||
import sys
|
||||
import logging
|
||||
import datetime
|
||||
import sys
|
||||
import os
|
||||
import pymongo
|
||||
import csv
|
||||
import re
|
||||
import copy
|
||||
import datetime
|
||||
import time
|
||||
import pytz
|
||||
import shutil
|
||||
import urllib
|
||||
import subprocess
|
||||
from pymongo import ASCENDING, DESCENDING
|
||||
from zipline.daemon import Daemon
|
||||
import zipline.util as qutil
|
||||
import zipline.db as db
|
||||
import zipline.host_settings
|
||||
|
||||
class FinancialDataLoader():
|
||||
"""
|
||||
Load trade and quote data from tickdata extracts into the db.
|
||||
Dates and times in the extracts must be in GMT.
|
||||
|
||||
All data extract files are expected to be in $HOME/fdl/. The expected directory layout is::
|
||||
/benchmark.csv -- this will be created from yahoo data each time load_bench_marks is run
|
||||
/interest_rates.csv --
|
||||
"""
|
||||
BATCH_SIZE = 100
|
||||
|
||||
def __init__(self):
|
||||
self.conn, self.db = db.DbConnection.get()
|
||||
self.data_file_path = os.environ['HOME'] + "/fdl/"
|
||||
subprocess.call("mkdir {data_dir}".format(data_dir=self.data_file_path), shell=True)
|
||||
self.last_bm_close = None
|
||||
|
||||
def load_bench_marks(self):
|
||||
"""Fetches the S&P end of day pricing history from yahoo, loads it to db.bench_marks"""
|
||||
start = time.time()
|
||||
start_date = datetime.datetime(year=1950, month=1, day=3)
|
||||
end_date = datetime.datetime.utcnow()
|
||||
file_path = self.data_file_path + "benchmark.csv"
|
||||
fp = open(file_path + ".tmp", "wb")
|
||||
|
||||
#create benchmark files
|
||||
#^GSPC 19500103
|
||||
query = {}
|
||||
query['s'] = "^GSPC" #the s&p 500
|
||||
query['d'] = end_date.month - 1 # end_date month, zero indexed
|
||||
query['e'] = end_date.day # end_date day str(int(todate[6:8])) #day
|
||||
query['f'] = end_date.year #end_date year str(int(todate[0:4]))
|
||||
query['g'] = "d" #daily frequency
|
||||
query['a'] = start_date.month - 1 #start_date month, zero indexed
|
||||
query['b'] = start_date.day #start_date day
|
||||
query['c'] = start_date.year #start_date year
|
||||
|
||||
#print query
|
||||
params = urllib.urlencode(query)
|
||||
params += "&ignore=.csv"
|
||||
|
||||
url = "http://ichart.yahoo.com/table.csv?%s" % params
|
||||
qutil.LOGGER.info("fetching {url}".format(url=url))
|
||||
f = urllib.urlopen(url)
|
||||
fp.write(f.read())
|
||||
fp.close()
|
||||
qutil.LOGGER.info("fetched {url} Reversing.".format(url=url))
|
||||
|
||||
tmp_file = file_path + ".tmp"
|
||||
reversed_tmp_file = file_path + ".rev"
|
||||
|
||||
rcode = subprocess.call("tac {oldfile} > {newfile}".format(oldfile=tmp_file, newfile=reversed_tmp_file), shell=True)
|
||||
#on mac, there is no tac command, so use tail -r (which isn't available on debian)
|
||||
if rcode != 0:
|
||||
rcode = subprocess.call("tail -r {oldfile} > {newfile}".format(oldfile=tmp_file, newfile=reversed_tmp_file), shell=True)
|
||||
|
||||
#tail -1 benchmark.csv.rev > benchmark.csv
|
||||
subprocess.call("echo \"date,open,high,low,close,volume,adj_close\" > {result}".format(newfile=reversed_tmp_file, result=self.data_file_path), shell=True)
|
||||
#sed '$d' < ~/fdl/benchmark.csv.rev >> ~/fdl/benchmark.csv
|
||||
subprocess.call("sed '$d' < {newfile} >> {result}".format(newfile=reversed_tmp_file, result=self.data_file_path), shell=True)
|
||||
#clean up working files
|
||||
subprocess.call("rm {tmp} {reversed}".format(tmp=tmp_file, reversed=reversed_tmp_file), shell=True)
|
||||
|
||||
#load the records into mongodb
|
||||
self.db.bench_marks.drop()
|
||||
qutil.LOGGER.info("processing benchmark info")
|
||||
self.parse_file(self.db.bench_marks,
|
||||
self.bench_mark_cb,
|
||||
file_path,
|
||||
['date','open','high','low','close','volume','adj_close'],
|
||||
None,
|
||||
0)
|
||||
qutil.LOGGER.info("benchmark info complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to load benchmark history" % total)
|
||||
|
||||
def load_treasuries(self):
|
||||
"""fetches data from the treasury.gov yield curve website, and populates the treasury_curves table.
|
||||
|
||||
to explore data available from the treasury:
|
||||
http://www.treasury.gov/resource-center/data-chart-center/interest-rates/Pages/TextView.aspx?data=yield
|
||||
|
||||
to fetch xml of all daily yield curves:
|
||||
http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData
|
||||
"""
|
||||
|
||||
from xml.dom.minidom import parse
|
||||
self.db.treasury_curves.drop()
|
||||
path = os.path.join(self.data_file_path + "all_treasury_rates.xml")
|
||||
#download all data to local filesystem
|
||||
subprocess.call("curl http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData > {path}".format(path=path), shell=True)
|
||||
dom = parse(path)
|
||||
|
||||
|
||||
entries = dom.getElementsByTagName("entry")
|
||||
for entry in entries:
|
||||
curve = {}
|
||||
curve['tid'] = self.get_node_value(entry, "d:Id")
|
||||
|
||||
curve['date'] = self.get_treasury_date(self.get_node_value(entry, "d:NEW_DATE"))
|
||||
curve['1month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_1MONTH"))
|
||||
curve['3month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_3MONTH"))
|
||||
curve['6month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_6MONTH"))
|
||||
curve['1year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_1YEAR"))
|
||||
curve['2year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_2YEAR"))
|
||||
curve['3year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_3YEAR"))
|
||||
curve['5year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_5YEAR"))
|
||||
curve['7year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_7YEAR"))
|
||||
curve['10year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_10YEAR"))
|
||||
curve['20year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_20YEAR"))
|
||||
curve['30year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_30YEAR"))
|
||||
self.db.treasury_curves.insert(curve, True)
|
||||
|
||||
def get_treasury_date(self, dstring):
|
||||
return datetime.datetime.strptime(dstring.split("T")[0], '%Y-%m-%d')
|
||||
|
||||
def get_treasury_rate(self, string_val):
|
||||
val = self.guarded_conversion(float, string_val, None)
|
||||
if val != None:
|
||||
val = round(val / 100.0, 4)
|
||||
return val
|
||||
def get_node_value(self, entry_node, tag_name):
|
||||
return self.get_xml_text(entry_node.getElementsByTagName(tag_name)[0].childNodes)
|
||||
|
||||
def get_xml_text(self, nodelist):
|
||||
rc = []
|
||||
for node in nodelist:
|
||||
if node.nodeType == node.TEXT_NODE:
|
||||
rc.append(node.data)
|
||||
|
||||
return ''.join(rc)
|
||||
|
||||
def purge_quotes(self):
|
||||
self.db.equity.quotes.drop()
|
||||
|
||||
def purge_trades(self):
|
||||
self.db.equity.trades.drop()
|
||||
|
||||
def load_quotes(self):
|
||||
start = time.time()
|
||||
qutil.LOGGER.info("processing equity quotes")
|
||||
self.load_events(self.db.equity.quotes,
|
||||
self.quoteRowCallback,
|
||||
self.data_file_path + "2008/Quotes/DATA",
|
||||
['trade_date', 'trade_time','exchange_code','bid_price','ask_price', 'bid_size','ask_size'])
|
||||
qutil.LOGGER.info("quotes complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to update equity quotes" % total)
|
||||
|
||||
|
||||
def load_trades(self):
|
||||
start = time.time()
|
||||
qutil.LOGGER.info("processing equity minute bars")
|
||||
self.load_events(self.db.equity.trades.minute,
|
||||
self.trade_cb,
|
||||
os.path.join(self.data_file_path, "2008/Trades/MINUTE_DATA"),
|
||||
['trade_date','trade_time','price', 'volume'])
|
||||
qutil.LOGGER.info("minute trades complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
|
||||
|
||||
def load_hourly_trades(self):
|
||||
start = time.time()
|
||||
qutil.LOGGER.info("processing equity hour bars")
|
||||
self.load_events(self.db.equity.trades.hourly,
|
||||
self.trade_cb,
|
||||
os.path.join(self.data_file_path, "2008/Trades/HOURLY_DATA"),
|
||||
['trade_date','trade_time','price','volume'])
|
||||
qutil.LOGGER.info("hourly trades complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
|
||||
|
||||
|
||||
def load_daily_close(self):
|
||||
start = time.time()
|
||||
qutil.LOGGER.info("processing equity daily close")
|
||||
self.load_events(self.db.equity.trades.daily,
|
||||
self.trade_cb,
|
||||
os.path.join(self.data_file_path, "2008/Trades/DAILY_DATA"),
|
||||
['trade_date','price', 'volume'])
|
||||
qutil.LOGGER.info("daily close complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
|
||||
|
||||
def ensure_indexes(self):
|
||||
|
||||
#ensure indexes on minute trades
|
||||
qutil.LOGGER.info("ensuring (+datetime, +sid) index on trades.minute")
|
||||
self.db.equity.trades.minute.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+datetime, +sid) index on trades.minute ready")
|
||||
|
||||
#ensure indexes for hourly trades
|
||||
qutil.LOGGER.info("ensuring (sid, +datetime) index on trades.hourly")
|
||||
self.db.equity.trades.hourly.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(sid, +datetime) index on trades.hourly ready")
|
||||
|
||||
#ensure indexes for daily trades
|
||||
qutil.LOGGER.info("ensuring (+datetime,+sid) index on trades.daily")
|
||||
self.db.equity.trades.daily.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+datetime,+sid) index on trades.daily ready")
|
||||
|
||||
#ensure indexes for orders and transactions
|
||||
qutil.LOGGER.info("ensuring (+backtestid) index on orders")
|
||||
self.db.orders.ensure_index([("back_test_run_id",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+backtestid) index on orders ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (+backtestid, +datetime) index on orders")
|
||||
self.db.orders.ensure_index([("back_test_run_id",ASCENDING),("dt",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+backtestid, +datetime) index on orders ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (+backtestid) index on orders")
|
||||
self.db.transactions.ensure_index([("back_test_run_id",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+backtestid) index on orders ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (+backtestid) index on transactions")
|
||||
self.db.transactions.ensure_index([("back_test_run_id",ASCENDING),("dt",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info("(+backtestid) index on transactions ready")
|
||||
|
||||
#indexes for benchmarks and treasuries
|
||||
qutil.LOGGER.info("ensuring (+date) index on treasury_curves")
|
||||
self.db.treasury_curves.ensure_index([("date",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info(" (+date) index on treasury_curves ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (-date) index on treasury_curves")
|
||||
self.db.treasury_curves.ensure_index([("date",DESCENDING)],background=True)
|
||||
qutil.LOGGER.info(" (-date) index on treasury_curves ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (+date) index on bench_marks")
|
||||
self.db.bench_marks.ensure_index([("date",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info(" (+date) index on bench_marks ready")
|
||||
|
||||
qutil.LOGGER.info("ensuring (+symbol, +date) index on bench_marks")
|
||||
self.db.bench_marks.ensure_index([("symbol",ASCENDING),("date",ASCENDING)],background=True)
|
||||
qutil.LOGGER.info(" (+symbol, +date) index on bench_marks ready")
|
||||
|
||||
def load_security_info(self):
|
||||
start = time.time()
|
||||
qutil.LOGGER.info("processing company info")
|
||||
|
||||
sourceFile = os.path.join(self.data_file_path, "2008/Trades/MINUTE_DATA/CompanyInfo/CompanyInfo.asc")
|
||||
self.db.securities.drop()
|
||||
self.parse_file(self.db.securities,
|
||||
self.security_cb,
|
||||
sourceFile,
|
||||
['symbol','file name','company name','CUSIP','exchange','industry code','first date','last date','company id'],
|
||||
None,
|
||||
0)
|
||||
qutil.LOGGER.info("company info complete")
|
||||
total = time.time() - start
|
||||
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
|
||||
|
||||
|
||||
|
||||
def load_events(self, collection, rowCallBack, dataDirectory, csvFields):
|
||||
id_counter = 0
|
||||
listing = os.listdir(dataDirectory)
|
||||
processedDir = os.path.join(dataDirectory,"processed")
|
||||
if not os.path.exists(processedDir):
|
||||
os.mkdir(processedDir)
|
||||
for curFile in listing:
|
||||
if os.path.isdir(os.path.join(dataDirectory,curFile)):
|
||||
continue
|
||||
start = time.time()
|
||||
if id_counter == 0: #this is the first file we are processing, so we want to ensure we don't duplicate records
|
||||
minDateTime = self.get_latest_entry_for_sid(self.get_sid_from_filename(curFile),collection)
|
||||
else:
|
||||
minDateTime = None #this isn't the first file, so don't bother querying
|
||||
rowCount, totalCount = self.parse_file(collection, rowCallBack, os.path.join(dataDirectory,curFile), csvFields, minDateTime, id_counter)
|
||||
id_counter = id_counter + rowCount
|
||||
parseTime = time.time() - start
|
||||
qutil.LOGGER.info("{time} seconds to parse and load {rowCount} records of {totalCount} from {file}. {rate} records/second".
|
||||
format(time = parseTime, rowCount=rowCount, totalCount=totalCount, file=curFile, rate = rowCount/parseTime))
|
||||
#we successfully processed the file without an exception, move it to the processed folder
|
||||
#qutil.LOGGER.info("moving data file to {newpath}".format(newpath=os.path.join(processedDir,curFile)))
|
||||
shutil.move(os.path.join(dataDirectory,curFile),os.path.join(processedDir,curFile))
|
||||
|
||||
def parse_file(self, collection, rowCallBack, curFile, pFieldnames, minDateTime, id_counter):
|
||||
"""Parses the given file into the collection. Returns tuple of the rows committed, rows in csvfile"""
|
||||
|
||||
qutil.LOGGER.debug("processing {fn}".format(fn=curFile))
|
||||
cur_id = id_counter
|
||||
rowCount = 0
|
||||
csvRowCount = 0
|
||||
with open(curFile, 'rb') as f:
|
||||
reader = csv.DictReader(f,fieldnames=pFieldnames)
|
||||
header = False
|
||||
|
||||
if csv.Sniffer().has_header(f.read(1024)):
|
||||
header = True
|
||||
f.seek(0)
|
||||
|
||||
if header:
|
||||
reader.next()
|
||||
try:
|
||||
rows = []
|
||||
for row in reader:
|
||||
#row['_id'] = cur_id
|
||||
cur_id = cur_id + 1
|
||||
csvRowCount += 1
|
||||
utcDT, dt = self.get_event_datetime(row)
|
||||
#only add rows that are after the mindate for the current sid.
|
||||
if(minDateTime != None and dt <= minDateTime):
|
||||
continue
|
||||
if(dt != None):
|
||||
row['dt'] = dt
|
||||
if('company id' not in pFieldnames):
|
||||
company_id = self.get_sid_from_filename(curFile)
|
||||
if(company_id):
|
||||
row['sid'] = int(company_id)
|
||||
if not rowCallBack(curFile, row):
|
||||
continue
|
||||
rows.append(row)
|
||||
rowCount+=1
|
||||
if(len(rows) >= self.BATCH_SIZE):
|
||||
collection.insert(rows, safe=True)
|
||||
rows = []
|
||||
if(len(rows) > 0):
|
||||
collection.insert(rows, safe=True)
|
||||
rows = None
|
||||
except csv.Error, e:
|
||||
sys.exit('file %s, line %d: %s' % (curFile, reader.line_num, e))
|
||||
return rowCount, csvRowCount
|
||||
|
||||
def trade_cb(self, curFile, row):
|
||||
row['price'] = self.guarded_conversion(float,row['price'])
|
||||
row['volume'] = self.guarded_conversion(self.safe_int,row['volume'])
|
||||
return True
|
||||
|
||||
def bench_mark_cb(self, curFile, row):
|
||||
row['symbol'] = "GSPC"
|
||||
row['volume'] = self.guarded_conversion(int,row['volume'])
|
||||
row['open'] = self.guarded_conversion(float,row['open'])
|
||||
row['high'] = self.guarded_conversion(float,row['high'])
|
||||
row['low'] = self.guarded_conversion(float,row['low'])
|
||||
row['close'] = self.guarded_conversion(float,row['close'])
|
||||
row['adj_close'] = self.guarded_conversion(float,row['adj_close'])
|
||||
row['date'] = datetime.datetime.strptime(row['date'], '%Y-%m-%d')
|
||||
if self.last_bm_close == None:
|
||||
row['returns'] = (row['close'] - row['open'])/row['open']
|
||||
else:
|
||||
row['returns'] = (row['close'] - self.last_bm_close) / self.last_bm_close
|
||||
self.last_bm_close = row['close']
|
||||
return True
|
||||
|
||||
def security_cb(self, curFile, row):
|
||||
"""source columns: ['symbol','file name','company name','CUSIP','exchange','industry code','first date','last date','company id']"""
|
||||
row['sid'] = self.guarded_conversion(int,row['company id'])
|
||||
del(row['company id'])
|
||||
row['start_date'] = self.guarded_conversion(self.date_conversion, row['first date'])
|
||||
del(row['first date'])
|
||||
row['end_date'] = self.guarded_conversion(self.date_conversion, row['last date'])
|
||||
del(row['last date'])
|
||||
row['symbol'] = self.verify_symbol_in_filename(row['symbol'], row['file name'])
|
||||
del(row['file name'])
|
||||
row['company_name'] = row['company name']
|
||||
del(row['company name'])
|
||||
return True
|
||||
|
||||
def guarded_conversion(self, conversion, strVal, default = None):
|
||||
if(strVal == None or strVal == ""):
|
||||
return default
|
||||
return conversion(strVal)
|
||||
|
||||
def safe_int(self,str):
|
||||
"""casts the string to a float to handle the occassionaly decimal point in int fields from data providers."""
|
||||
f = float(str)
|
||||
i = int(f)
|
||||
return i
|
||||
|
||||
def date_conversion(self, dateStr):
|
||||
dt = datetime.datetime.strptime(dateStr, '%m/%d/%Y')
|
||||
dt = dt.replace (tzinfo = pytz.utc)
|
||||
return dt
|
||||
|
||||
def verify_symbol_in_filename(self, symbol, file_name):
|
||||
if(symbol == file_name):
|
||||
return symbol
|
||||
|
||||
parts = file_name.split('_')
|
||||
if(len(parts) == 2):
|
||||
return file_name
|
||||
else:
|
||||
raise Exception("found a mismatch between symbol and filename, but no underscore.")
|
||||
|
||||
def get_event_datetime(self, row):
|
||||
"""python 2.5 doesn't support %f for setting the microseconds, so this override is necessary.
|
||||
a significant side effect - the trade date and trade time elements are removed from this dictionary. done to
|
||||
avoid storing the source fields in the db.
|
||||
"""
|
||||
if row.has_key('trade_date') and row.has_key('trade_time'):
|
||||
value = row['trade_date'] + "-" + row['trade_time']
|
||||
dt = datetime.datetime.strptime(value.split(".")[0], '%m/%d/%Y-%H:%M:%S')
|
||||
dt = dt.replace(microsecond=int(value.split(".")[1]+"000"))
|
||||
del row['trade_date']
|
||||
del row['trade_time']
|
||||
elif row.has_key('trade_date'):
|
||||
dt = datetime.datetime.strptime(row['trade_date'],'%m/%d/%Y')
|
||||
del row['trade_date']
|
||||
else:
|
||||
return None, None
|
||||
|
||||
utcDT = quantoenv.getUTCFromExchangeTime(dt) #store everything in UTC
|
||||
return utcDT, dt
|
||||
|
||||
def get_sid_from_filename(self, filename):
|
||||
|
||||
regexp = r"(?P<company_id>[0-9]+)([.]csv)"
|
||||
result = re.search(regexp,filename)
|
||||
if(result):
|
||||
companyID = int(result.group('company_id'))
|
||||
return companyID
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_latest_entry_for_sid(self, sid, collection):
|
||||
"""checks given collection for the most recent record for the given sid."""
|
||||
results = collection.find(fields=["dt"],
|
||||
spec={"sid":sid},
|
||||
sort=[("dt",DESCENDING)],
|
||||
limit=1,
|
||||
as_class=quantoenv.DocWrap)
|
||||
|
||||
if(results.count() > 0):
|
||||
return results[0].dt
|
||||
else:
|
||||
return datetime.datetime.min
|
||||
|
||||
|
||||
|
||||
class DataLoader(Daemon):
|
||||
"""A daemon process that manages the data in the finance database."""
|
||||
|
||||
def __init__(self, pidfile, operation):
|
||||
self.operation = operation
|
||||
self.pidfile = pidfile
|
||||
self.stdin = '/dev/null'
|
||||
self.stdout = '/dev/null'
|
||||
self.stderr = '/dev/null'
|
||||
|
||||
def run(self):
|
||||
qutil.LOGGER.info("running operation: {op}".format(op=self.operation))
|
||||
try:
|
||||
fdl = FinancialDataLoader()
|
||||
if(self.operation == 'pt'):
|
||||
qutil.LOGGER.info("Purging trades from database!")
|
||||
fdl.purge_trades()
|
||||
elif(self.operation == 'ei'):
|
||||
qutil.LOGGER.info("Ensuring indexes.")
|
||||
fdl.ensure_indexes()
|
||||
elif(self.operation == 'lt'):
|
||||
qutil.LOGGER.info("Loading trades into database.")
|
||||
fdl.loadTrades()
|
||||
elif(self.operation == 'lh'):
|
||||
qutil.LOGGER.info("Loading trades into database.")
|
||||
fdl.load_hourly_trades()
|
||||
elif(self.operation == 'ld'):
|
||||
qutil.LOGGER.info("Loading trades into database.")
|
||||
fdl.load_daily_close()
|
||||
elif(self.operation == 'si'):
|
||||
qutil.LOGGER.info("Loading security info into database.")
|
||||
fdl.load_security_info()
|
||||
elif(self.operation == 'tr'):
|
||||
qutil.LOGGER.info("Loading US Treasury rates into database.")
|
||||
fdl.load_treasuries()
|
||||
elif(self.operation == 'bm'):
|
||||
qutil.LOGGER.info("loading benchmark data into database.")
|
||||
fdl.load_bench_marks()
|
||||
else:
|
||||
qutil.LOGGER.warning("Unknown command for load data: {op}.".format(op=self.operation))
|
||||
qutil.LOGGER.info("Finished.")
|
||||
except:
|
||||
qutil.LOGGER.exception("exiting load_data due to unexpected exception.")
|
||||
finally:
|
||||
logging.shutdown()
|
||||
|
||||
|
||||
+20
-18
@@ -1,9 +1,10 @@
|
||||
import datetime
|
||||
import quantoenv
|
||||
import math
|
||||
import pytz
|
||||
import numpy as np
|
||||
import numpy.linalg as la
|
||||
import zipline.util as qutil
|
||||
import zipline.db as db
|
||||
from pymongo import ASCENDING, DESCENDING
|
||||
|
||||
class daily_return():
|
||||
@@ -14,6 +15,7 @@ class daily_return():
|
||||
|
||||
class periodmetrics():
|
||||
def __init__(self, start_date, end_date, returns, benchmark_returns):
|
||||
self.db = db.DbConnection.get()[1]
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.trading_calendar = trading_calendar
|
||||
@@ -45,7 +47,7 @@ class periodmetrics():
|
||||
|
||||
def calculate_period_returns(self, daily_returns):
|
||||
returns = [x.returns for x in daily_returns if x.date >= self.start_date and x.date <= self.end_date and self.trading_calendar.is_trading_day(x.date)]
|
||||
#quantoenv.qlogger.debug("using {count} daily returns out of {total}".format(count=len(returns),total=len(daily_returns)))
|
||||
#qutil.LOGGER.debug("using {count} daily returns out of {total}".format(count=len(returns),total=len(daily_returns)))
|
||||
period_returns = 1.0
|
||||
for r in returns:
|
||||
period_returns = period_returns * (1.0 + r)
|
||||
@@ -53,14 +55,14 @@ class periodmetrics():
|
||||
return period_returns, returns
|
||||
|
||||
def calculate_volatility(self, daily_returns):
|
||||
#quantoenv.qlogger.debug("trading days {td}".format(td=self.trading_days))
|
||||
#qutil.LOGGER.debug("trading days {td}".format(td=self.trading_days))
|
||||
return np.std(daily_returns, ddof=1) * math.sqrt(self.trading_days)
|
||||
|
||||
def calculate_sharpe(self):
|
||||
return (self.algorithm_period_returns - self.treasury_period_return) / self.algorithm_volatility
|
||||
|
||||
def calculate_beta(self):
|
||||
#quantoenv.qlogger.debug("algorithm has {acount} days, benchmark has {bmcount} days".format(acount=len(self.algorithm_returns), bmcount=len(self.benchmark_returns)))
|
||||
#qutil.LOGGER.debug("algorithm has {acount} days, benchmark has {bmcount} days".format(acount=len(self.algorithm_returns), bmcount=len(self.benchmark_returns)))
|
||||
#it doesn't make much sense to calculate beta for less than two days, so return none.
|
||||
if len(self.algorithm_returns) < 2:
|
||||
return 0.0, 0.0, 0.0, 0.0, []
|
||||
@@ -71,7 +73,7 @@ class periodmetrics():
|
||||
algorithm_covariance = C[0][1]
|
||||
benchmark_variance = C[1][1]
|
||||
beta = C[0][1] / C[1][1]
|
||||
#quantoenv.qlogger.debug("bm variance is {bmv}, returns matrix is {rm}, covariance is {c}, beta is {beta}".format(rm=returns_matrix, bmv=C[1][1], c=C, beta=beta))
|
||||
#qutil.LOGGER.debug("bm variance is {bmv}, returns matrix is {rm}, covariance is {c}, beta is {beta}".format(rm=returns_matrix, bmv=C[1][1], c=C, beta=beta))
|
||||
|
||||
return beta, algorithm_covariance, benchmark_variance, condition_number, eigen_values
|
||||
|
||||
@@ -86,11 +88,11 @@ class periodmetrics():
|
||||
cur_return = math.log(1.0 + r) + cur_return
|
||||
#this is a guard for a single day returning -100%
|
||||
else:
|
||||
quantoenv.qlogger.warn("negative 100 percent return, zeroing the returns")
|
||||
qutil.LOGGER.warn("negative 100 percent return, zeroing the returns")
|
||||
cur_return = 0.0
|
||||
compounded_returns.append(cur_return)
|
||||
|
||||
#quantoenv.qlogger.debug("compounded returns are {cr}".format(cr=compounded_returns))
|
||||
#qutil.LOGGER.debug("compounded returns are {cr}".format(cr=compounded_returns))
|
||||
cur_max = None
|
||||
max_drawdown = None
|
||||
for cur in compounded_returns:
|
||||
@@ -101,7 +103,7 @@ class periodmetrics():
|
||||
if max_drawdown == None or drawdown < max_drawdown:
|
||||
max_drawdown = drawdown
|
||||
|
||||
#quantoenv.qlogger.debug("max drawdown is: {dd}".format(dd=max_drawdown))
|
||||
#qutil.LOGGER.debug("max drawdown is: {dd}".format(dd=max_drawdown))
|
||||
if max_drawdown == None:
|
||||
return 0.0
|
||||
|
||||
@@ -131,7 +133,7 @@ class periodmetrics():
|
||||
else:
|
||||
self.treasury_duration = '30year'
|
||||
|
||||
treasuryQS = quantoenv.getTickDB().treasury_curves.find(
|
||||
treasuryQS = self.db.treasury_curves.find(
|
||||
spec={"date" : {"$lte" : self.end_date}},
|
||||
sort=[("date",DESCENDING)],
|
||||
limit=3,
|
||||
@@ -154,11 +156,11 @@ class riskmetrics():
|
||||
|
||||
def __init__(self, algorithm_returns):
|
||||
"""algorithm_returns needs to be a list of daily_return objects sorted in date ascending order"""
|
||||
self.db = quantoenv.getTickDB()
|
||||
self.db = db.DbConnection.get()[1]
|
||||
self.algorithm_returns = algorithm_returns
|
||||
self.bm_returns = [x for x in benchmark_returns if x.date >= self.algorithm_returns[0].date and x.date <= self.algorithm_returns[-1].date]
|
||||
|
||||
quantoenv.qlogger.debug("#### {start} thru {end} with {count} trading_days of {total} possible".format(start=self.algorithm_returns[0].date,
|
||||
qutil.LOGGER.debug("#### {start} thru {end} with {count} trading_days of {total} possible".format(start=self.algorithm_returns[0].date,
|
||||
end=self.algorithm_returns[-1].date,
|
||||
count=len(self.bm_returns),
|
||||
total=len(benchmark_returns)))
|
||||
@@ -187,7 +189,7 @@ class riskmetrics():
|
||||
cur_end = advance_by_months(cur_start, months_per) - one_day
|
||||
if(cur_end > the_end):
|
||||
break
|
||||
#quantoenv.qlogger.debug("start: {start}, end: {end}".format(start=cur_start, end=cur_end))
|
||||
#qutil.LOGGER.debug("start: {start}, end: {end}".format(start=cur_start, end=cur_end))
|
||||
cur_period_metrics = periodmetrics(start_date=cur_start, end_date=cur_end, returns=self.algorithm_returns, benchmark_returns=self.bm_returns)
|
||||
ends.append(cur_period_metrics)
|
||||
cur_start = advance_by_months(cur_start, 1)
|
||||
@@ -195,7 +197,7 @@ class riskmetrics():
|
||||
return ends
|
||||
|
||||
def store_to_db(self, back_test_run_id):
|
||||
col = quantoenv.getTickDB().risk_metrics
|
||||
col = self.db.risk_metrics
|
||||
for period in self.month_periods:
|
||||
for metric in ["algorithm_period_returns", "benchmark_period_returns", "excess_return", "trading_days", "benchmark_volatility", "algorithm_volatility", "sharpe", "beta", "alpha", "max_drawdown"]:
|
||||
record = {'back_test_run_id':back_test_run_id}
|
||||
@@ -203,7 +205,7 @@ class riskmetrics():
|
||||
record['metric_name'] = metric
|
||||
for dur in ["month", "three_month", "six_month", "year", "three_year", "five_year"]:
|
||||
record[dur] = self.find_metric_by_end(period.end_date, dur, metric)
|
||||
#quantoenv.qlogger.debug("storing {val} for {metric} and {dur}".format(val=record[dur], metric=metric, dur=dur))
|
||||
#qutil.LOGGER.debug("storing {val} for {metric} and {dur}".format(val=record[dur], metric=metric, dur=dur))
|
||||
col.insert(record, safe=True)
|
||||
|
||||
def find_metric_by_end(self, end_date, duration, metric):
|
||||
@@ -253,13 +255,13 @@ class TradingCalendar(object):
|
||||
|
||||
|
||||
def get_benchmark_data():
|
||||
bmQS = quantoenv.getTickDB().bench_marks.find(
|
||||
bmQS = db.DbConnection.get()[1].bench_marks.find(
|
||||
spec={"symbol" : "GSPC",
|
||||
"date":{"$gte": quantoenv.getUTCFromExchangeTime(datetime.datetime.strptime('01/01/1990','%m/%d/%Y')),
|
||||
"$lte": quantoenv.getUTCFromExchangeTime(datetime.datetime.strptime('12/31/2010','%m/%d/%Y'))}},
|
||||
"date":{"$gte": datetime.datetime.strptime('01/01/1990','%m/%d/%Y').replace(tzinfo = pytz.utc),
|
||||
"$lte": datetime.datetime.strptime('12/31/2010','%m/%d/%Y').replace(tzinfo = pytz.utc)}},
|
||||
sort=[("date",ASCENDING)],
|
||||
slave_ok=True,
|
||||
as_class=quantoenv.DocWrap)
|
||||
as_class=qutil.DocWrap)
|
||||
bm_returns = []
|
||||
for bm in bmQS:
|
||||
bm_r = daily_return(date=bm.date.replace(tzinfo=pytz.utc), returns=bm.returns)
|
||||
|
||||
@@ -2,9 +2,11 @@ import json
|
||||
import zipline.util as qutil
|
||||
import zipline.messaging as qmsg
|
||||
|
||||
from zipline.finance.trading import TradeSimulationClient
|
||||
from zipline.protocol import CONTROL_PROTOCOL
|
||||
|
||||
class TestClient(qmsg.Component):
|
||||
"""no-op client - Just connects to the merge and counts messages. compares received message count to the expected count."""
|
||||
|
||||
def __init__(self, utest, expected_msg_count=0):
|
||||
qmsg.Component.__init__(self)
|
||||
@@ -44,3 +46,12 @@ class TestClient(qmsg.Component):
|
||||
self.prev_dt = event['dt']
|
||||
if(self.received_count % 100 == 0):
|
||||
qutil.LOGGER.info("received {n} messages".format(n=self.received_count))
|
||||
|
||||
class TestTradingClient(TradeSimulationClient):
|
||||
|
||||
|
||||
def handle_events(self, event_queue):
|
||||
#place an order for 100 shares of sid:133
|
||||
self.order(133,100)
|
||||
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Dummy simulator for test/development on Zipline.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import mock
|
||||
from collections import defaultdict
|
||||
from zipline.monitor import Controller
|
||||
from zipline.messaging import ComponentHost
|
||||
import zipline.util as qutil
|
||||
|
||||
class DummyAllocator(object):
|
||||
|
||||
def __init__(self, ns):
|
||||
self.idx = 0
|
||||
self.sockets = [
|
||||
'tcp://127.0.0.1:%s' % (10000 + n)
|
||||
for n in xrange(ns)
|
||||
]
|
||||
|
||||
def lease(self, n):
|
||||
sockets = self.sockets[self.idx:self.idx+n]
|
||||
self.idx += n
|
||||
return sockets
|
||||
|
||||
def reaquire(self, *conn):
|
||||
pass
|
||||
|
||||
class SimulatorBase(ComponentHost):
|
||||
"""
|
||||
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
|
||||
"""
|
||||
|
||||
def __init__(self, addresses, gevent_needed=False):
|
||||
"""
|
||||
"""
|
||||
ComponentHost.__init__(self, addresses, gevent_needed)
|
||||
|
||||
def simulate(self):
|
||||
self.run()
|
||||
|
||||
def get_id(self):
|
||||
return "Simulator"
|
||||
|
||||
class ThreadSimulator(SimulatorBase):
|
||||
|
||||
def __init__(self, addresses):
|
||||
SimulatorBase.__init__(self, addresses)
|
||||
|
||||
def launch_controller(self):
|
||||
thread = threading.Thread(target=self.controller.run)
|
||||
thread.start()
|
||||
self.cuc = thread
|
||||
return thread
|
||||
|
||||
def launch_component(self, component):
|
||||
thread = threading.Thread(target=component.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
class ExecutorMixinBase(object):
|
||||
"""Abstract base to allow mixin for tests that need a dummy simulator."""
|
||||
leased_sockets = defaultdict(list)
|
||||
|
||||
def setUp(self):
|
||||
self.setup_logging()
|
||||
|
||||
# TODO: how to make Nose use this cross-process????
|
||||
self.setup_allocator()
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
#self.unallocate_sockets()
|
||||
|
||||
# Assert the sockets were properly cleaned up
|
||||
#self.assertEmpty(self.leased_sockets[self.id()].values())
|
||||
|
||||
# Assert they were returned to the heap
|
||||
#self.allocator.socketheap.assert
|
||||
|
||||
def get_simulator(self):
|
||||
"""
|
||||
Return a new simulator instance to be tested.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_controller(self):
|
||||
"""
|
||||
Return a new controler for simulator instance to be tested.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_allocator(self):
|
||||
"""
|
||||
Setup the socket allocator for this test case.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def allocate_sockets(self, n):
|
||||
"""
|
||||
Allocate sockets local to this test case, track them so
|
||||
we can gc after test run.
|
||||
"""
|
||||
|
||||
assert isinstance(n, int)
|
||||
assert n > 0
|
||||
|
||||
leased = self.allocator.lease(n)
|
||||
|
||||
self.leased_sockets[self.id()].extend(leased)
|
||||
return leased
|
||||
|
||||
def unallocate_sockets(self):
|
||||
self.allocator.reaquire(*self.leased_sockets[self.id()])
|
||||
|
||||
class ThreadPoolExecutorMixin(ExecutorMixinBase):
|
||||
"""Dummy server using threads."""
|
||||
allocator = DummyAllocator(100)
|
||||
|
||||
def setup_logging(self):
|
||||
qutil.configure_logging()
|
||||
|
||||
# lazy import by design
|
||||
self.logger = mock.Mock()
|
||||
|
||||
def setup_allocator(self):
|
||||
pass
|
||||
|
||||
def get_simulator(self, addresses):
|
||||
return ThreadSimulator(addresses)
|
||||
|
||||
def get_controller(self):
|
||||
# Allocate two more sockets
|
||||
controller_sockets = self.allocate_sockets(2)
|
||||
|
||||
return Controller(
|
||||
controller_sockets[0],
|
||||
controller_sockets[1],
|
||||
logging = self.logger,
|
||||
)
|
||||
|
||||
|
||||
|
||||
+8
-11
@@ -1,9 +1,6 @@
|
||||
import datetime
|
||||
import pytz
|
||||
from algorithm.quantoenv import *
|
||||
from algorithm.quantomodels import *
|
||||
from algorithm.hostedalgorithm import *
|
||||
from algorithm.risk import *
|
||||
import zipline.finance.risk as risk
|
||||
|
||||
def createReturns(daycount, start):
|
||||
i = 0
|
||||
@@ -15,7 +12,7 @@ def createReturns(daycount, start):
|
||||
r = daily_return(current, random.random())
|
||||
test_range.append(r)
|
||||
current = current + one_day
|
||||
return [ x for x in test_range if(trading_calendar.is_trading_day(x.date)) ]
|
||||
return [ x for x in test_range if(risk.trading_calendar.is_trading_day(x.date)) ]
|
||||
|
||||
def createReturnsFromRange(start, end):
|
||||
current = start.replace(tzinfo=pytz.utc)
|
||||
@@ -25,7 +22,7 @@ def createReturnsFromRange(start, end):
|
||||
i = 0
|
||||
while current <= end:
|
||||
current = current + one_day
|
||||
if(not trading_calendar.is_trading_day(current)):
|
||||
if(not risk.trading_calendar.is_trading_day(current)):
|
||||
continue
|
||||
r = daily_return(current, random.random())
|
||||
i += 1
|
||||
@@ -38,7 +35,7 @@ def createReturnsFromList(returns, start):
|
||||
test_range = []
|
||||
i = 0
|
||||
while len(test_range) < len(returns):
|
||||
if(trading_calendar.is_trading_day(current)):
|
||||
if(risk.trading_calendar.is_trading_day(current)):
|
||||
r = daily_return(current, returns[i])
|
||||
i += 1
|
||||
test_range.append(r)
|
||||
@@ -61,7 +58,7 @@ def getCodeFromFile(filename):
|
||||
return rVal
|
||||
|
||||
|
||||
def createTrade(sid, price, amount, datetime):
|
||||
def create_trade(sid, price, amount, datetime):
|
||||
row = {}
|
||||
row['sid'] = sid
|
||||
row['dt'] = datetime
|
||||
@@ -79,8 +76,8 @@ def create_trade_history(sid, prices, amounts, start_time, interval):
|
||||
trades = []
|
||||
current = start_time
|
||||
while i < len(prices):
|
||||
if(trading_calendar.is_trading_day(current)):
|
||||
trades.append(createTrade(sid, priceList[i], amtList[i], current))
|
||||
if(risk.trading_calendar.is_trading_day(current)):
|
||||
trades.append(create_trade(sid, priceList[i], amtList[i], current))
|
||||
current = current + interval
|
||||
i += 1
|
||||
else:
|
||||
@@ -98,7 +95,7 @@ def createTxnHistory(sid, priceList, amtList, startTime, interval):
|
||||
txns = []
|
||||
current = startTime
|
||||
while i < len(priceList):
|
||||
if(trading_calendar.is_trading_day(current)):
|
||||
if(risk.trading_calendar.is_trading_day(current)):
|
||||
txns.append(createTxn(sid,priceList[i],amtList[i], current))
|
||||
current = current + interval
|
||||
i += 1
|
||||
|
||||
@@ -1,87 +1,10 @@
|
||||
"""
|
||||
Dummy simulator backported from Qexec for development on Zipline.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import mock
|
||||
from unittest2 import TestCase
|
||||
|
||||
from zipline.test.test_messaging import SimulatorTestCase
|
||||
from zipline.monitor import Controller
|
||||
from zipline.messaging import ComponentHost
|
||||
import zipline.util as qutil
|
||||
from zipline.test.dummy import ThreadPoolExecutorMixin
|
||||
|
||||
class DummyAllocator(object):
|
||||
|
||||
def __init__(self, ns):
|
||||
self.idx = 0
|
||||
self.sockets = [
|
||||
'tcp://127.0.0.1:%s' % (10000 + n)
|
||||
for n in xrange(ns)
|
||||
]
|
||||
|
||||
def lease(self, n):
|
||||
sockets = self.sockets[self.idx:self.idx+n]
|
||||
self.idx += n
|
||||
return sockets
|
||||
|
||||
def reaquire(self, *conn):
|
||||
pass
|
||||
|
||||
class SimulatorBase(ComponentHost):
|
||||
"""
|
||||
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
|
||||
"""
|
||||
|
||||
def __init__(self, addresses, gevent_needed=False):
|
||||
"""
|
||||
"""
|
||||
ComponentHost.__init__(self, addresses, gevent_needed)
|
||||
|
||||
def simulate(self):
|
||||
self.run()
|
||||
|
||||
def get_id(self):
|
||||
return "Simulator"
|
||||
|
||||
class ThreadSimulator(SimulatorBase):
|
||||
|
||||
def __init__(self, addresses):
|
||||
SimulatorBase.__init__(self, addresses)
|
||||
|
||||
def launch_controller(self):
|
||||
thread = threading.Thread(target=self.controller.run)
|
||||
thread.start()
|
||||
self.cuc = thread
|
||||
return thread
|
||||
|
||||
def launch_component(self, component):
|
||||
thread = threading.Thread(target=component.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
class ThreadPoolExecutor(SimulatorTestCase, TestCase):
|
||||
|
||||
allocator = DummyAllocator(100)
|
||||
|
||||
def setup_logging(self):
|
||||
qutil.configure_logging()
|
||||
|
||||
# lazy import by design
|
||||
self.logger = mock.Mock()
|
||||
|
||||
def setup_allocator(self):
|
||||
pass
|
||||
|
||||
def get_simulator(self, addresses):
|
||||
return ThreadSimulator(addresses)
|
||||
|
||||
def get_controller(self):
|
||||
# Allocate two more sockets
|
||||
controller_sockets = self.allocate_sockets(2)
|
||||
|
||||
return Controller(
|
||||
controller_sockets[0],
|
||||
controller_sockets[1],
|
||||
logging = self.logger,
|
||||
)
|
||||
|
||||
def test_universe(self):
|
||||
# first order logic is working today. Yay!
|
||||
self.assertTrue(True != False)
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Tests for the zipline.finance package"""
|
||||
import mock
|
||||
import zipline.host_settings
|
||||
from unittest2 import TestCase
|
||||
|
||||
from zipline.test.test_devsimulator import ThreadSimulator, DummyAllocator
|
||||
from zipline.test.test_messaging import SimulatorTestCase
|
||||
import zipline.test.factory as factory
|
||||
from zipline.monitor import Controller
|
||||
from zipline.messaging import DataSource
|
||||
import zipline.util as qutil
|
||||
import zipline.db as db
|
||||
import zipline.host_settings
|
||||
|
||||
class ThreadPoolExecutor(SimulatorTestCase, TestCase):
|
||||
class FinanceTestCase(SimulatorTestCase, TestCase):
|
||||
|
||||
allocator = DummyAllocator(100)
|
||||
|
||||
@@ -57,17 +61,14 @@ class ThreadPoolExecutor(SimulatorTestCase, TestCase):
|
||||
# Simulation Components
|
||||
# ---------------------
|
||||
|
||||
ret1 = SpecificEquityTrades("flat-133",factory.create_trade_history(133,
|
||||
[10.0,11.0,10.0,10.0],
|
||||
set1 = SpecificEquityTrades("flat-133",factory.create_trade_history(133,
|
||||
[10.0,10.0,10.0,10.0],
|
||||
[100,100,100,100],
|
||||
datetime.datetime.utcnow(),
|
||||
datetime.timedelta(days=1)))
|
||||
ret2 = RandomEquityTrades(134, "ret2", 5000)
|
||||
mavg1 = MovingAverage("mavg1", 30)
|
||||
mavg2 = MovingAverage("mavg2", 60)
|
||||
client = TestClient(self, expected_msg_count=10000)
|
||||
client = TestTradingClient(self, expected_msg_count=4)
|
||||
|
||||
sim.register_components([ret1, ret2, mavg1, mavg2, client])
|
||||
sim.register_components([set1, client])
|
||||
sim.register_controller( con )
|
||||
|
||||
# Simulation
|
||||
|
||||
@@ -3,72 +3,17 @@ Test suite for the messaging infrastructure of Zipline.
|
||||
"""
|
||||
#don't worry about excessive public methods pylint: disable=R0904
|
||||
|
||||
from collections import defaultdict
|
||||
import zipline.messaging as qmsg
|
||||
|
||||
from zipline.transforms.technical import MovingAverage
|
||||
from zipline.sources import RandomEquityTrades
|
||||
|
||||
from zipline.test.dummy import ThreadPoolExecutorMixin
|
||||
from zipline.test.client import TestClient
|
||||
|
||||
|
||||
# Should not inherit form TestCase since test runners will pick
|
||||
# it up as a test. Its a Mixin of sorts at this point.
|
||||
class SimulatorTestCase(object):
|
||||
|
||||
leased_sockets = defaultdict(list)
|
||||
|
||||
def setUp(self):
|
||||
self.setup_logging()
|
||||
|
||||
# TODO: how to make Nose use this cross-process????
|
||||
self.setup_allocator()
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
#self.unallocate_sockets()
|
||||
|
||||
# Assert the sockets were properly cleaned up
|
||||
#self.assertEmpty(self.leased_sockets[self.id()].values())
|
||||
|
||||
# Assert they were returned to the heap
|
||||
#self.allocator.socketheap.assert
|
||||
|
||||
def get_simulator(self):
|
||||
"""
|
||||
Return a new simulator instance to be tested.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_controller(self):
|
||||
"""
|
||||
Return a new controler for simulator instance to be tested.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_allocator(self):
|
||||
"""
|
||||
Setup the socket allocator for this test case.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def allocate_sockets(self, n):
|
||||
"""
|
||||
Allocate sockets local to this test case, track them so
|
||||
we can gc after test run.
|
||||
"""
|
||||
|
||||
assert isinstance(n, int)
|
||||
assert n > 0
|
||||
|
||||
leased = self.allocator.lease(n)
|
||||
|
||||
self.leased_sockets[self.id()].extend(leased)
|
||||
return leased
|
||||
|
||||
def unallocate_sockets(self):
|
||||
self.allocator.reaquire(*self.leased_sockets[self.id()])
|
||||
|
||||
class SimulatorTestCase(ThreadPoolExecutorMixin):
|
||||
|
||||
# -------
|
||||
# Cases
|
||||
# -------
|
||||
|
||||
+28
-1
@@ -6,8 +6,9 @@ and other common operations.
|
||||
import datetime
|
||||
import pytz
|
||||
import logging
|
||||
import logging.handlers
|
||||
|
||||
LOGGER = logging.getLogger('QSimLogger')
|
||||
LOGGER = logging.getLogger('ZiplineLogger')
|
||||
|
||||
def configure_logging(loglevel=logging.DEBUG):
|
||||
"""
|
||||
@@ -49,3 +50,29 @@ def format_date(dt):
|
||||
dt_str = dt.strftime('%Y/%m/%d-%H:%M:%S') + "." + str(dt.microsecond / 1000)
|
||||
return dt_str
|
||||
|
||||
class DocWrap():
|
||||
def __init__(self, store=None):
|
||||
if(store == None):
|
||||
self.store = {}
|
||||
else:
|
||||
self.store = store.copy()
|
||||
if(self.store.has_key('_id')):
|
||||
self.store['id'] = self.store['_id']
|
||||
del(self.store['_id'])
|
||||
|
||||
def __setitem__(self,key,value):
|
||||
if(key == '_id'):
|
||||
self.store['id'] = value
|
||||
else:
|
||||
self.store[key] = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
if self.store.has_key(key):
|
||||
return self.store[key]
|
||||
|
||||
def __getattr__(self,attrname):
|
||||
if self.store.has_key(attrname):
|
||||
return self.store[attrname]
|
||||
else:
|
||||
raise AttributeError("No attribute named {name}".format(name=attrname))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user