mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 19:15:14 +08:00
Resolved conflicts
This commit is contained in:
@@ -1 +0,0 @@
|
||||
# TODO: move qexec console here
|
||||
-189
@@ -1,189 +0,0 @@
|
||||
import uuid
|
||||
import copy
|
||||
import atexit
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
from UserDict import DictMixin
|
||||
|
||||
class Snapshot(object, DictMixin):
|
||||
"""
|
||||
A snapshot in time of a history container.
|
||||
"""
|
||||
|
||||
def __init__(self, state, version, ts):
|
||||
self.version = version
|
||||
self.timestamp = ts
|
||||
self._state = state
|
||||
|
||||
def keys(self):
|
||||
return self._state.keys()
|
||||
|
||||
def values(self):
|
||||
return self._state.values()
|
||||
|
||||
def items(self):
|
||||
return self._state.items()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._state.__getitem__(key)
|
||||
|
||||
def has_key(self, key):
|
||||
return self._state.has_key(key)
|
||||
|
||||
def copy(self):
|
||||
return copy.copy(self._state)
|
||||
|
||||
class History(object, DictMixin):
|
||||
"""
|
||||
A duck-typed dictionary that tracks its time evolution.
|
||||
|
||||
Worth noting this not a particuarly high-performance
|
||||
data structure due to the copious amount of copying going on.
|
||||
"""
|
||||
|
||||
def __init__(self, default=None):
|
||||
if default:
|
||||
initial = defaultdict(default)
|
||||
else:
|
||||
initial = {}
|
||||
|
||||
self.version = 0
|
||||
self.changeset = [('CREATE', None)]
|
||||
self.current = Snapshot(initial, version=self.version, ts=datetime.now())
|
||||
self._history = [self.current]
|
||||
|
||||
def items(self, version=-1):
|
||||
return self._history[version].items()
|
||||
|
||||
def keys(self, version=-1):
|
||||
return self._history[version].keys()
|
||||
|
||||
def rollback(self, version):
|
||||
pass
|
||||
|
||||
def event(self, tup):
|
||||
self.changeset.append(tup)
|
||||
|
||||
def __getitem__(self, key, version=-1):
|
||||
return self._history[version].__getitem__(key)
|
||||
|
||||
def __setitem__(self, key, val):
|
||||
if self.current.has_key(key):
|
||||
self.changeset.append(('CHANGE', key))
|
||||
else:
|
||||
self.changeset.append(('ADD', key))
|
||||
|
||||
state = self.current.copy()
|
||||
state[key] = val
|
||||
|
||||
self.version += 1
|
||||
self.current = Snapshot(state, self.version, datetime.now())
|
||||
self._history.append(self.current)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self.changeset.append(('REMOVE', key))
|
||||
|
||||
state = self.current.copy()
|
||||
del state[key]
|
||||
|
||||
self.version += 1
|
||||
self.current = Snapshot(state, self.version, datetime.now())
|
||||
self._history.append(self.current)
|
||||
|
||||
def history(self):
|
||||
for change in self.changeset:
|
||||
print change
|
||||
|
||||
def __repr__(self):
|
||||
return ':'.join(['historical', self.current._state.__repr__()])
|
||||
|
||||
SocketHistory = History()
|
||||
ContextHistory = History()
|
||||
|
||||
def patch_zmq(_zmq=None):
|
||||
"""
|
||||
Monkey patch zeromq to allow for socket tracking.
|
||||
"""
|
||||
if _zmq:
|
||||
zmq = _zmq
|
||||
else:
|
||||
import zmq
|
||||
|
||||
_Context = zmq.Context
|
||||
_Socket = zmq.Socket
|
||||
|
||||
class TrackedSocket(zmq.Socket):
|
||||
|
||||
def __init__(self, context, socket_type):
|
||||
self.context = context
|
||||
self.uuid = str(uuid.uuid4())
|
||||
SocketHistory[self.uuid] = self
|
||||
_Socket.__init__(self, context, socket_type)
|
||||
|
||||
def connect(self, address):
|
||||
SocketHistory.event(('CONNECT', self.uuid, address))
|
||||
_Socket.connect(self, address)
|
||||
|
||||
def bind(self, address):
|
||||
SocketHistory.event(('BIND', self.uuid, address))
|
||||
_Socket.bind(self, address)
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
del SocketHistory[self.uuid]
|
||||
_Socket.close(self, *args, **kwargs)
|
||||
|
||||
def setsockopt(self, option, optval):
|
||||
if option == zmq.IDENTITY:
|
||||
old = SocketHistory[self.uuid]
|
||||
SocketHistory[optval] = old
|
||||
del SocketHistory[self.uuid]
|
||||
self.uuid = optval
|
||||
|
||||
_Socket.setsockopt(self, option, optval)
|
||||
|
||||
class TrackedContext(zmq.Context):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.sockets = {}
|
||||
_Context.__init__(self, *args, **kwargs)
|
||||
self.uuid = str(uuid.uuid4())
|
||||
ContextHistory[self.uuid] = self
|
||||
|
||||
def socket(self, socket_type):
|
||||
sock = TrackedSocket(self, socket_type)
|
||||
ContextHistory.event(('EMBED', self.uuid, sock.uuid))
|
||||
self.sockets[sock.uuid] = sock
|
||||
return sock
|
||||
|
||||
def name(self, name):
|
||||
"""
|
||||
Name the context. Is a superset of the vanilla pyzmq
|
||||
API.
|
||||
"""
|
||||
old = ContextHistory[self.context.uuid]
|
||||
ContextHistory[name] = old
|
||||
del ContextHistory[self.context.uuid]
|
||||
self.uuid = name
|
||||
|
||||
def term(self, *args, **kwargs):
|
||||
for uid, sock in self.sockets.iteritems():
|
||||
if not sock.closed:
|
||||
del SocketHistory[sock.uuid]
|
||||
del ContextHistory[self.uuid]
|
||||
_Context.term(self, *args, **kwargs)
|
||||
|
||||
def destroy(self, *args, **kwargs):
|
||||
ContextHistory.event(('DESTROY', self.uuid))
|
||||
_Context.destroy(self, *args, **kwargs)
|
||||
|
||||
zmq.Context = TrackedContext
|
||||
zmq.Socket = TrackedSocket
|
||||
return TrackedContext, TrackedSocket
|
||||
|
||||
def track_to_file(f):
|
||||
def write_track():
|
||||
pickle.dump(SocketHistory.changeset, file(f, 'wb+'))
|
||||
atexit.register(write_track)
|
||||
@@ -4,7 +4,7 @@ python-dateutil==1.5
|
||||
|
||||
# Core scientific python
|
||||
numpy>=1.6.1
|
||||
pandas>=0.7.0rc1
|
||||
pandas==0.8.0
|
||||
scipy>=0.10.0
|
||||
|
||||
matplotlib==1.1.0
|
||||
@@ -12,8 +12,9 @@ matplotlib==1.1.0
|
||||
|
||||
numexpr==2.0.1
|
||||
Cython==0.15.1
|
||||
#tables>=2.3.1
|
||||
#scikits.statsmodels>=0.3.1
|
||||
patsy==0.1.0
|
||||
statsmodels==0.5.0-tutorial-beta
|
||||
|
||||
|
||||
# ZeroMQ
|
||||
pyzmq==2.1.11
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
from signal import signal, SIGHUP, SIGINT
|
||||
import time
|
||||
from types import FrameType
|
||||
import unittest
|
||||
|
||||
from zipline.utils.delayed_signals import delayed_signals
|
||||
|
||||
class DelayedSignals(unittest.TestCase):
|
||||
def handler(self, signum, frame):
|
||||
print "Got signal " + str(signum)
|
||||
self.got[signum] = time.time()
|
||||
self.assertTrue(isinstance(frame, FrameType))
|
||||
|
||||
def setUp(self):
|
||||
signal(SIGHUP, self.handler)
|
||||
signal(SIGINT, self.handler)
|
||||
|
||||
def reset(self):
|
||||
self.got = {}
|
||||
|
||||
def test_delayed_signals(self):
|
||||
self.reset()
|
||||
with delayed_signals([SIGHUP]):
|
||||
os.kill(os.getpid(), SIGHUP)
|
||||
time.sleep(2)
|
||||
self.assertTrue(self.got[SIGHUP])
|
||||
self.assertTrue(time.time() - self.got[SIGHUP] < 2)
|
||||
|
||||
def test_immediate_signals(self):
|
||||
self.reset()
|
||||
os.kill(os.getpid(), SIGHUP)
|
||||
time.sleep(2)
|
||||
self.assertTrue(self.got[SIGHUP])
|
||||
self.assertTrue(time.time() - self.got[SIGHUP] > 1)
|
||||
|
||||
def test_multiple_signals(self):
|
||||
self.reset()
|
||||
with delayed_signals([SIGHUP, SIGINT]):
|
||||
os.kill(os.getpid(), SIGINT)
|
||||
self.assertFalse(SIGHUP in self.got)
|
||||
self.assertTrue(SIGINT in self.got)
|
||||
|
||||
@delayed_signals([SIGHUP])
|
||||
def kill_and_sleep(self):
|
||||
os.kill(os.getpid(), SIGHUP)
|
||||
time.sleep(2)
|
||||
|
||||
def test_decorator(self):
|
||||
self.reset()
|
||||
self.kill_and_sleep()
|
||||
self.assertTrue(SIGHUP in self.got)
|
||||
self.assertTrue(time.time() - self.got[SIGHUP] < 2)
|
||||
@@ -3,11 +3,14 @@ import zmq
|
||||
from unittest2 import TestCase
|
||||
from collections import defaultdict
|
||||
|
||||
from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm
|
||||
from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm, \
|
||||
InitializeTimeoutAlgorithm, TooMuchProcessingAlgorithm
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.lines import SimulatedTrading
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
from zipline.gens.tradesimulation import HEARTBEAT_INTERVAL, \
|
||||
MAX_HEARTBEAT_INTERVALS
|
||||
|
||||
from zipline.utils.test_utils import \
|
||||
drain_zipline, \
|
||||
@@ -143,3 +146,46 @@ class ExceptionTestCase(TestCase):
|
||||
# make sure our path shortening is working
|
||||
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
|
||||
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
|
||||
|
||||
def test_initialize_timeout(self):
|
||||
|
||||
self.zipline_test_config['algorithm'] = \
|
||||
InitializeTimeoutAlgorithm(
|
||||
self.zipline_test_config['sid']
|
||||
)
|
||||
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
|
||||
payload = output[-1]['payload']
|
||||
self.assertEqual(payload['name'],'Timeout')
|
||||
self.assertEqual(payload['message'], 'Call to initialize timed out')
|
||||
|
||||
def test_heartbeat(self):
|
||||
|
||||
self.zipline_test_config['algorithm'] = \
|
||||
TooMuchProcessingAlgorithm(
|
||||
self.zipline_test_config['sid']
|
||||
)
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
|
||||
# There should be a message for each hearbeat, plus a message
|
||||
# for the final timeout.
|
||||
assert len(output) == MAX_HEARTBEAT_INTERVALS + 1
|
||||
|
||||
# Assert that everything but the last message is a heartbeat log.
|
||||
for message in output[0:-1]:
|
||||
assert message['prefix'] == 'LOG'
|
||||
assert message['payload']['func_name'] == 'log_heartbeats'
|
||||
|
||||
# Assert that the last message is a timeout exception.
|
||||
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
|
||||
payload = output[-1]['payload']
|
||||
self.assertEqual(payload['name'],'Timeout')
|
||||
self.assertEqual(payload['message'], 'Too much time spent in handle_data call')
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from zipline.finance.performance import PerformanceTracker
|
||||
from zipline.utils.protocol_utils import ndict
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
from zipline.utils.test_utils import \
|
||||
drain_zipline, \
|
||||
setup_logger, \
|
||||
teardown_logger,\
|
||||
assert_single_position
|
||||
@@ -121,36 +120,6 @@ class FinanceTestCase(TestCase):
|
||||
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
|
||||
assert_single_position(self, zipline)
|
||||
|
||||
#@timed(DEFAULT_TIMEOUT)
|
||||
def test_sid_filter(self):
|
||||
# Ensure the algorithm's filter prevents events from arriving.
|
||||
# create a test algorithm whose filter will not match any of the
|
||||
# trade events sourced inside the zipline.
|
||||
order_amount = 100
|
||||
order_count = 100
|
||||
no_match_sid = 222
|
||||
test_algo = TestAlgorithm(
|
||||
no_match_sid,
|
||||
order_amount,
|
||||
order_count
|
||||
)
|
||||
|
||||
self.zipline_test_config['trade_count'] = 200
|
||||
self.zipline_test_config['algorithm'] = test_algo
|
||||
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, transaction_count = drain_zipline(self, zipline)
|
||||
|
||||
#check that the algorithm received no events
|
||||
self.assertEqual(
|
||||
0,
|
||||
transaction_count,
|
||||
"The algorithm should not receive any events due to filtering."
|
||||
)
|
||||
|
||||
|
||||
# TODO: write tests for short sales
|
||||
# TODO: write a test to do massive buying or shorting.
|
||||
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import logging
|
||||
import logbook
|
||||
import uuid
|
||||
import zmq
|
||||
|
||||
from zipline import ndict
|
||||
|
||||
from zipline.utils.logger import configure_logging, tail
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler
|
||||
|
||||
from zipline.utils.test_utils import create_receiver, drain_receiver
|
||||
|
||||
from unittest2 import TestCase
|
||||
|
||||
|
||||
@@ -20,3 +28,35 @@ class LoggerTestCase(TestCase):
|
||||
last_line = tail(logfile, window=1)
|
||||
logged_msg = last_line.split(" - ")[1]
|
||||
self.assertEqual(test_msg, logged_msg)
|
||||
|
||||
|
||||
def test_zmq_handler(self):
|
||||
socket_addr = 'tcp://127.0.0.1:10000'
|
||||
ctx = zmq.Context()
|
||||
socket_push = ctx.socket(zmq.PUSH)
|
||||
socket_push.connect(socket_addr)
|
||||
recv = create_receiver(socket_addr, ctx)
|
||||
zmq_out = ZeroMQLogHandler(
|
||||
socket = socket_push,
|
||||
filter = lambda r, h: r.channel in ['test zmq logger'],
|
||||
context=ctx,
|
||||
#bubble=False
|
||||
)
|
||||
|
||||
log = logbook.Logger('test zmq logger')
|
||||
x = ndict({})
|
||||
x.a = 1
|
||||
ex = example(133)
|
||||
with zmq_out.threadbound():
|
||||
log.info(ex.num)
|
||||
|
||||
|
||||
output, _ = drain_receiver(recv, count=1)
|
||||
self.assertEqual(output[-1]['prefix'], 'LOG')
|
||||
self.assertTrue(isinstance(output[-1]['payload']['msg'], basestring))
|
||||
|
||||
|
||||
class example(object):
|
||||
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytz
|
||||
import numpy
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from collections import defaultdict
|
||||
@@ -15,6 +16,7 @@ from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import StatefulTransform, EventWindow
|
||||
from zipline.gens.vwap import VWAP
|
||||
from zipline.gens.mavg import MovingAverage
|
||||
from zipline.gens.stddev import MovingStandardDev
|
||||
from zipline.gens.returns import Returns
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
@@ -70,6 +72,7 @@ class EventWindowTestCase(TestCase):
|
||||
delta = timedelta(minutes = 5),
|
||||
days = None
|
||||
)
|
||||
|
||||
now = utcnow()
|
||||
|
||||
# 15 dates, increasing in 1 minute increments.
|
||||
@@ -99,6 +102,7 @@ class EventWindowTestCase(TestCase):
|
||||
delta = None,
|
||||
days = 1
|
||||
)
|
||||
|
||||
dates = ([self.pre_open]*3)
|
||||
dates += ([self.mid_day]*3)
|
||||
dates += ([self.post_close]*3)
|
||||
@@ -239,11 +243,12 @@ class FinanceTransformsTestCase(TestCase):
|
||||
fields = ['price', 'volume'],
|
||||
delta = timedelta(days = 2),
|
||||
)
|
||||
|
||||
transformed = list(mavg.transform(self.source))
|
||||
# Output values.
|
||||
tnfm_prices = [message.tnfm_value.price for message in transformed]
|
||||
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
|
||||
|
||||
|
||||
# "Hand-calculated" values
|
||||
expected_prices = [
|
||||
((10.0) / 1.0),
|
||||
@@ -264,3 +269,46 @@ class FinanceTransformsTestCase(TestCase):
|
||||
|
||||
assert tnfm_prices == expected_prices
|
||||
assert tnfm_volumes == expected_volumes
|
||||
|
||||
def test_moving_stddev(self):
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 15.0, 13.0, 12.0],
|
||||
[100, 100, 100, 100],
|
||||
timedelta(hours = 1),
|
||||
self.trading_environment
|
||||
)
|
||||
|
||||
stddev = StatefulTransform(
|
||||
MovingStandardDev,
|
||||
market_aware = False,
|
||||
delta = timedelta(minutes = 150),
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
transformed = list(stddev.transform(self.source))
|
||||
|
||||
vals = [message.tnfm_value for message in transformed]
|
||||
|
||||
expected = [
|
||||
None,
|
||||
numpy.std([10.0, 15.0], ddof = 1),
|
||||
numpy.std([10.0, 15.0, 13.0], ddof = 1),
|
||||
numpy.std([15.0, 13.0, 12.0], ddof = 1),
|
||||
]
|
||||
|
||||
# numpy has odd rounding behavior, cf.
|
||||
# http://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html
|
||||
for v1, v2 in zip(vals, expected):
|
||||
|
||||
if v1 == None:
|
||||
assert v2 == None
|
||||
continue
|
||||
assert round(v1, 5) == round(v2, 5)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@ class PerformanceTracker(object):
|
||||
self.event_count = 0
|
||||
self.last_dict = None
|
||||
self.exceeded_max_loss = False
|
||||
self.no_more_updates = False
|
||||
|
||||
self.compute_risk_metrics = True
|
||||
|
||||
@@ -205,9 +206,12 @@ class PerformanceTracker(object):
|
||||
self.todays_performance.positions[sid] = Position(sid)
|
||||
|
||||
def update(self, event):
|
||||
if event.dt == "DONE":
|
||||
if self.no_more_updates:
|
||||
return zp.ndict({'dt':0})
|
||||
elif event.dt == "DONE":
|
||||
event.perf_message = self.handle_simulation_end()
|
||||
del event['TRANSACTION']
|
||||
self.no_more_updates = True
|
||||
return event
|
||||
elif self.exceeded_max_loss:
|
||||
# in case of max_loss, signal to downstream
|
||||
@@ -215,6 +219,7 @@ class PerformanceTracker(object):
|
||||
event.dt = "DONE"
|
||||
event.perf_message = self.handle_simulation_end()
|
||||
del event['TRANSACTION']
|
||||
self.no_more_updates = True
|
||||
return event
|
||||
else:
|
||||
event.perf_message = self.process_event(event)
|
||||
|
||||
+14
-16
@@ -307,24 +307,22 @@ class RiskMetrics():
|
||||
for i in xrange(7):
|
||||
if(self.treasury_curves.has_key(self.end_date + i * one_day)):
|
||||
curve = self.treasury_curves[self.end_date + i * one_day]
|
||||
break
|
||||
self.treasury_curve = curve
|
||||
rate = self.treasury_curve[self.treasury_duration]
|
||||
#1month note data begins in 8/2001, so we can use 3month instead.
|
||||
if rate == None and self.treasury_duration == '1month':
|
||||
rate = self.treasury_curve['3month']
|
||||
|
||||
if curve:
|
||||
self.treasury_curve = curve
|
||||
rate = self.treasury_curve[self.treasury_duration]
|
||||
#1month note data begins in 8/2001, so we can use 3month instead.
|
||||
if rate == None and self.treasury_duration == '1month':
|
||||
rate = self.treasury_curve['3month']
|
||||
if rate != None:
|
||||
return rate * (td.days + 1) / 365
|
||||
if rate != None:
|
||||
return rate * (td.days + 1) / 365
|
||||
|
||||
message = "no rate for end date = {dt} and term = {term}. Check \
|
||||
that date doesn't exceed treasury history range."
|
||||
message = message.format(
|
||||
dt=self.end_date,
|
||||
term=self.treasury_duration
|
||||
)
|
||||
raise Exception(message)
|
||||
message = "no rate for end date = {dt} and term = {term}. Check \
|
||||
that date doesn't exceed treasury history range."
|
||||
message = message.format(
|
||||
dt=self.end_date,
|
||||
term=self.treasury_duration
|
||||
)
|
||||
raise Exception(message)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -184,6 +184,8 @@ class TradingEnvironment(object):
|
||||
self.first_open = self.calculate_first_open()
|
||||
self.last_close = self.calculate_last_close()
|
||||
|
||||
self.prior_day_open = self.calculate_prior_day_open()
|
||||
|
||||
def calculate_first_open(self):
|
||||
"""
|
||||
Finds the first trading day on or after self.period_start.
|
||||
@@ -197,6 +199,24 @@ class TradingEnvironment(object):
|
||||
first_open = self.set_NYSE_time(first_open, 9, 30)
|
||||
return first_open
|
||||
|
||||
def calculate_prior_day_open(self):
|
||||
"""
|
||||
Finds the first trading day open that falls at least a day
|
||||
before period_start.
|
||||
"""
|
||||
one_day = datetime.timedelta(days=1)
|
||||
first_open = self.period_start - one_day
|
||||
|
||||
if first_open <= self.trading_days[0]:
|
||||
log.warn("Cannot calculate prior day open.")
|
||||
return self.period_start
|
||||
|
||||
while not self.is_trading_day(first_open):
|
||||
first_open = first_open - one_day
|
||||
|
||||
first_open = self.set_NYSE_time(first_open, 9, 30)
|
||||
return first_open
|
||||
|
||||
def calculate_last_close(self):
|
||||
"""
|
||||
Finds the last trading day on or before self.period_end
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
"""
|
||||
Generator version of Feed.
|
||||
Sorting generator.
|
||||
"""
|
||||
import logbook
|
||||
|
||||
from collections import deque
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import \
|
||||
assert_datasource_unframe_protocol, \
|
||||
assert_sort_protocol
|
||||
|
||||
log = logbook.Logger('Sorting')
|
||||
|
||||
def date_sort(stream_in, source_ids):
|
||||
"""
|
||||
A generator that takes a generator and a list of source_ids. We
|
||||
@@ -27,7 +31,7 @@ def date_sort(stream_in, source_ids):
|
||||
# Incoming messages should be the output of DATASOURCE_UNFRAME.
|
||||
assert_datasource_unframe_protocol(message), \
|
||||
"Bad message in date_sort: %s" % message
|
||||
|
||||
|
||||
# Only allow messages from sources we expect.
|
||||
assert message.source_id in sources, "Unexpected source: %s" % message
|
||||
|
||||
@@ -40,7 +44,7 @@ def date_sort(stream_in, source_ids):
|
||||
message = pop_oldest(sources)
|
||||
assert_sort_protocol(message)
|
||||
yield message
|
||||
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
from numbers import Number
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from math import sqrt
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.transform import EventWindow
|
||||
|
||||
class MovingStandardDev(object):
|
||||
"""
|
||||
Class that maintains a dicitonary from sids to
|
||||
MovingStandardDevWindows. For each sid, we maintain a the
|
||||
standard deviation of all events falling within the specified
|
||||
window.
|
||||
"""
|
||||
|
||||
def __init__(self, market_aware, days = None, delta = None):
|
||||
|
||||
self.market_aware = market_aware
|
||||
|
||||
self.delta = delta
|
||||
self.days = days
|
||||
|
||||
# Market-aware mode only works with full-day windows.
|
||||
if self.market_aware:
|
||||
assert self.days and not self.delta,\
|
||||
"Market-aware mode only works with full-day windows."
|
||||
|
||||
# Non-market-aware mode requires a timedelta.
|
||||
else:
|
||||
assert self.delta and not self.days, \
|
||||
"Non-market-aware mode requires a timedelta."
|
||||
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
def create_window(self):
|
||||
"""
|
||||
Factory method for self.sid_windows.
|
||||
"""
|
||||
return MovingStandardDevWindow(
|
||||
self.market_aware,
|
||||
self.days,
|
||||
self.delta
|
||||
)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Return an ndict
|
||||
from tracked fields to moving averages.
|
||||
"""
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
return window.get_stddev()
|
||||
|
||||
class MovingStandardDevWindow(EventWindow):
|
||||
"""
|
||||
Iteratively calculates standard deviation for a particular sid
|
||||
over a given time window. The expected functionality of this
|
||||
class is to be instantiated inside a MovingStandardDev.
|
||||
"""
|
||||
|
||||
def __init__(self, market_aware, days, delta):
|
||||
|
||||
# Call the superclass constructor to set up base EventWindow
|
||||
# infrastructure.
|
||||
EventWindow.__init__(self, market_aware, days, delta)
|
||||
|
||||
self.sum = 0.0
|
||||
self.sum_sqr = 0.0
|
||||
|
||||
def handle_add(self, event):
|
||||
assert event.has_key('price')
|
||||
assert isinstance(event.price, Number)
|
||||
|
||||
self.sum += event.price
|
||||
self.sum_sqr += event.price ** 2
|
||||
|
||||
def handle_remove(self, event):
|
||||
assert event.has_key('price')
|
||||
assert isinstance(event.price, Number)
|
||||
|
||||
self.sum -= event.price
|
||||
self.sum_sqr -= event.price ** 2
|
||||
|
||||
def get_stddev(self):
|
||||
|
||||
# Sample standard deviation is undefined for a single event or
|
||||
# no events.
|
||||
if len(self) <= 1:
|
||||
return None
|
||||
|
||||
else:
|
||||
average = self.sum /len(self)
|
||||
s_squared = (self.sum_sqr - self.sum*average) / (len(self) - 1)
|
||||
stddev = sqrt(s_squared)
|
||||
return stddev
|
||||
+35
-15
@@ -4,24 +4,23 @@ and zipline development
|
||||
"""
|
||||
import random
|
||||
import pytz
|
||||
from copy import copy
|
||||
|
||||
import pandas as pd
|
||||
from zipline import ndict
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
|
||||
from itertools import chain, cycle, ifilter, izip
|
||||
from itertools import chain, cycle, ifilter, izip, repeat
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.gens.utils import hash_args, create_trade
|
||||
|
||||
def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
|
||||
delta = timedelta(minutes = 1),
|
||||
count = 100):
|
||||
count = 100,
|
||||
repeats = None):
|
||||
"""
|
||||
Utility to generate a stream of dates.
|
||||
"""
|
||||
return (start + (i * delta) for i in xrange(count))
|
||||
if repeats:
|
||||
return (start + (i * delta) for i in xrange(count) for n in xrange(repeats))
|
||||
else:
|
||||
return (start + (i * delta) for i in xrange(count))
|
||||
|
||||
def mock_prices(count, rand = False):
|
||||
"""
|
||||
@@ -101,20 +100,41 @@ class SpecificEquityTrades(object):
|
||||
def get_hash(self):
|
||||
return self.__class__.__name__ + "-" + self.arg_string
|
||||
|
||||
def update_source_id(self, gen):
|
||||
for event in gen:
|
||||
event.source_id = self.get_hash()
|
||||
yield event
|
||||
|
||||
def create_fresh_generator(self):
|
||||
|
||||
if self.event_list:
|
||||
for event in self.event_list:
|
||||
event['source_id'] = self.get_hash()
|
||||
unfiltered = (event for event in self.event_list)
|
||||
event_gen = (event for event in self.event_list)
|
||||
unfiltered = self.update_source_id(event_gen)
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
if self.concurrent:
|
||||
# in this context the count is the number of
|
||||
# trades per sid, not the total.
|
||||
dates = date_gen(
|
||||
count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta,
|
||||
repeats=len(self.sids),
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
|
||||
dates = date_gen(
|
||||
count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
|
||||
sids = cycle(self.sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
|
||||
+114
-92
@@ -1,9 +1,12 @@
|
||||
import signal
|
||||
from logbook import Logger, Processor
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from numbers import Integral
|
||||
from itertools import groupby
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.utils.timeout import timeout, heartbeat, Timeout
|
||||
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
@@ -13,10 +16,15 @@ from zipline.gens.utils import hash_args
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
# TODO: make these arguments rather than global constants
|
||||
INIT_TIMEOUT = 5
|
||||
HEARTBEAT_INTERVAL = 1 # seconds
|
||||
MAX_HEARTBEAT_INTERVALS = 15 #count
|
||||
|
||||
class TradeSimulationClient(object):
|
||||
"""
|
||||
Generator that takes the expected output of a merge, a user
|
||||
algorithm, a trading environment, and a simulator style as
|
||||
Generator-style class that takes the expected output of a merge, a
|
||||
user algorithm, a trading environment, and a simulator style as
|
||||
arguments. Pipes the merge stream through a TransactionSimulator
|
||||
and a PerformanceTracker, which keep track of the current state of
|
||||
our algorithm's simulated universe. Results are fed to the user's
|
||||
@@ -24,7 +32,7 @@ class TradeSimulationClient(object):
|
||||
TransactionSimulator's order book.
|
||||
|
||||
TransactionSimulator maintains a dictionary from sids to the
|
||||
unfulfilled orders placed by the user's algorithm. As trade
|
||||
as-yet unfilled orders placed by the user's algorithm. As trade
|
||||
events arrive, if the algorithm has open orders against the
|
||||
trade's sid, the simulator will fill orders up to 25% of market
|
||||
cap. Applied transactions are added to a txn field on the event
|
||||
@@ -40,9 +48,9 @@ class TradeSimulationClient(object):
|
||||
performance report, which is appended to event's perf_report
|
||||
field.
|
||||
|
||||
Fully processed events are run through a batcher generator, which
|
||||
batches together events with the same dt field into a single event
|
||||
to be fed to the algo. The portfolio object is repeatedly
|
||||
Fully processed events are fed to AlgorithmSimulator, which
|
||||
batches together events with the same dt field into a single
|
||||
snapshot to be fed to the algo. The portfolio object is repeatedly
|
||||
overwritten so that only the most recent snapshot of the universe
|
||||
is sent to the algo.
|
||||
"""
|
||||
@@ -55,9 +63,13 @@ class TradeSimulationClient(object):
|
||||
self.style = sim_style
|
||||
self.algo_sim = None
|
||||
|
||||
self.warmup_start = self.environment.prior_day_open
|
||||
self.algo_start = self.environment.first_open
|
||||
|
||||
def get_hash(self):
|
||||
"""
|
||||
There should only ever be one TSC in the system.
|
||||
There should only ever be one TSC in the system, so
|
||||
we don't bother passing args into the hash.
|
||||
"""
|
||||
return self.__class__.__name__ + hash_args()
|
||||
|
||||
@@ -89,25 +101,31 @@ class TradeSimulationClient(object):
|
||||
with_portfolio = perf_tracker.transform(with_filled_orders)
|
||||
|
||||
# Pass the messages from perf along with the trading client's
|
||||
# state into the algorithm for simulation. We provide the
|
||||
# trading client so that the algorithm can place new orders
|
||||
# into the client's order book.
|
||||
# state into the algorithm for simulation. We provide a
|
||||
# pointer to the ordering client's internal state so that the
|
||||
# algorithm can place new orders into the client's order book.
|
||||
self.algo_sim = AlgorithmSimulator(
|
||||
with_portfolio,
|
||||
ordering_client.state,
|
||||
self.algo,
|
||||
self.algo_start
|
||||
)
|
||||
|
||||
# The algorithm will yield a daily_results message (as
|
||||
# calculated by the performance tracker) at the end of each
|
||||
# day. It will also yield a risk report at the end of the
|
||||
# simulation.
|
||||
|
||||
for message in self.algo_sim:
|
||||
yield message
|
||||
|
||||
class AlgorithmSimulator(object):
|
||||
|
||||
def __init__(self, stream_in, order_book, algo):
|
||||
def __init__(self,
|
||||
stream_in,
|
||||
order_book,
|
||||
algo,
|
||||
algo_start):
|
||||
|
||||
self.stream_in = stream_in
|
||||
|
||||
@@ -121,12 +139,28 @@ class AlgorithmSimulator(object):
|
||||
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
self.algo_start = algo_start
|
||||
|
||||
# Monkey patch the user algorithm to place orders in the
|
||||
# TransactionSimulator's order book.
|
||||
# TransactionSimulator's order book and use our logger.
|
||||
self.algo.set_order(self.order)
|
||||
self.algo.set_logger(Logger("AlgoLog"))
|
||||
self.algolog = Logger("AlgoLog")
|
||||
self.algo.set_logger(self.algolog)
|
||||
|
||||
# Handler for heartbeats during calls to handle_data.
|
||||
def log_heartbeats(beat_count, stackframe):
|
||||
t = beat_count * HEARTBEAT_INTERVAL
|
||||
warning = "handle_data has been processing for %i seconds" %t
|
||||
self.algolog.warn(warning)
|
||||
|
||||
# Context manager that calls log_heartbeats every HEARTBEAT_INTERVAL
|
||||
# seconds, raising an exception after MAX_HEARTBEATS
|
||||
self.heartbeat_monitor = heartbeat(
|
||||
HEARTBEAT_INTERVAL,
|
||||
MAX_HEARTBEAT_INTERVALS,
|
||||
frame_handler=log_heartbeats,
|
||||
timeout_message="Too much time spent in handle_data call"
|
||||
)
|
||||
|
||||
# ==============
|
||||
# Snapshot Setup
|
||||
@@ -142,7 +176,7 @@ class AlgorithmSimulator(object):
|
||||
# We don't have a datetime for the current snapshot until we
|
||||
# receive a message.
|
||||
self.simulation_dt = None
|
||||
self.this_snapshot_dt = None
|
||||
self.snapshot_dt = None
|
||||
|
||||
# =============
|
||||
# Logging Setup
|
||||
@@ -151,7 +185,7 @@ class AlgorithmSimulator(object):
|
||||
# Processor function for injecting the algo_dt into
|
||||
# user prints/logs.
|
||||
def inject_algo_dt(record):
|
||||
record.extra['algo_dt'] = self.this_snapshot_dt
|
||||
record.extra['algo_dt'] = self.snapshot_dt
|
||||
self.processor = Processor(inject_algo_dt)
|
||||
|
||||
# This is a class, which is instantiated later
|
||||
@@ -206,94 +240,62 @@ class AlgorithmSimulator(object):
|
||||
# Capture any output of this generator to stdout and pipe it
|
||||
# to a logbook interface. Also inject the current algo
|
||||
# snapshot time to any log record generated.
|
||||
|
||||
with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
|
||||
# Call the user's initialize method.
|
||||
self.algo.initialize()
|
||||
|
||||
for event in self.stream_in:
|
||||
# Yield any perf messages received to be relayed back to
|
||||
# the browser.
|
||||
# Call user's initialize method with a timeout.
|
||||
with timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
|
||||
self.algo.initialize()
|
||||
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
# Group together events with the same dt field. This depends on the
|
||||
# events already being sorted.
|
||||
for date, snapshot in groupby(self.stream_in, lambda e: e.dt):
|
||||
|
||||
if event.dt == "DONE":
|
||||
if self.this_snapshot_dt:
|
||||
# stop iteration happened
|
||||
# mid-snapshot, so we have a universe
|
||||
# snapshot that is not yet processed
|
||||
# by the algorithm.
|
||||
self.simulate_current_snapshot()
|
||||
break
|
||||
|
||||
# This should only happen for the first event we run.
|
||||
# Set the simulation date to be the first event we see.
|
||||
# This should only occur once, at the start of the test.
|
||||
if self.simulation_dt == None:
|
||||
self.simulation_dt = event.dt
|
||||
self.simulation_dt = date
|
||||
|
||||
# ======================
|
||||
# Time Compression Logic
|
||||
# ======================
|
||||
# Done message has the risk report, so we yield before exiting.
|
||||
if date == 'DONE':
|
||||
for event in snapshot:
|
||||
yield event.perf_message
|
||||
break
|
||||
|
||||
if self.this_snapshot_dt != None:
|
||||
self.update_current_snapshot(event)
|
||||
# We're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
# and don't send a snapshot to handle_data.
|
||||
elif date < self.algo_start:
|
||||
for event in snapshot:
|
||||
del event['perf_message']
|
||||
self.update_universe(event)
|
||||
|
||||
# The algorithm has been missing events because it took
|
||||
# too long processing. Update the universe with data from
|
||||
# this event, then check if enough time has passed that we
|
||||
# can start a new snapshot.
|
||||
# The algo has taken so long to process events that
|
||||
# its simulated time is later than the event time.
|
||||
# Update the universe and yield any perf messages
|
||||
# encountered, but don't call handle_data.
|
||||
elif date < self.simulation_dt:
|
||||
for event in snapshot:
|
||||
# Only yield if we have something interesting to say.
|
||||
if event.perf_message != None:
|
||||
yield event.perf_message
|
||||
# Delete the message before updating so we don't send it
|
||||
# to the user.
|
||||
del event['perf_message']
|
||||
self.update_universe(event)
|
||||
|
||||
# Regular snapshot. Update the universe and send a snapshot
|
||||
# to handle data.
|
||||
else:
|
||||
self.update_universe(event)
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
for event in snapshot:
|
||||
# Only yield if we have something interesting to say.
|
||||
if event.perf_message != None:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
|
||||
self.update_universe(event)
|
||||
|
||||
|
||||
def update_current_snapshot(self, event):
|
||||
"""
|
||||
Update our current snapshot of the universe. Call handle_data if
|
||||
"""
|
||||
# The new event matches our snapshot dt. Just update the
|
||||
# universe and move on.
|
||||
if event.dt == self.this_snapshot_dt:
|
||||
self.update_universe(event)
|
||||
|
||||
# The new event does not match our snapshot.
|
||||
else:
|
||||
self.simulate_current_snapshot()
|
||||
|
||||
# Once we've finished simulating the old snapshot,
|
||||
# we can update the universe with the new event.
|
||||
self.update_universe(event)
|
||||
|
||||
# The current event is later than the simulation time,
|
||||
# which means the algorithm finished quickly enough to
|
||||
# receive the new event. Start a new snapshot with this
|
||||
# event's dt.
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
|
||||
# The algorithm spent enough time processing that it
|
||||
# missed the new event. Wait to start a new snapshot until
|
||||
# the events catch up to the algo's simulated dt.
|
||||
else:
|
||||
self.this_snapshot_dt = None
|
||||
|
||||
def simulate_current_snapshot(self):
|
||||
"""
|
||||
Run the user's algo against our current snapshot and update the algo's
|
||||
simulated time.
|
||||
"""
|
||||
start_tic = datetime.now()
|
||||
self.algo.handle_data(self.universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
# How long did you take?
|
||||
delta = stop_tic - start_tic
|
||||
|
||||
# Update the simulation time.
|
||||
self.simulation_dt = self.this_snapshot_dt + delta
|
||||
# Send the current state of the universe to the user's algo.
|
||||
self.simulate_snapshot(date)
|
||||
|
||||
def update_universe(self, event):
|
||||
"""
|
||||
@@ -305,3 +307,23 @@ class AlgorithmSimulator(object):
|
||||
# Update our knowledge of this event's sid
|
||||
for field in event.keys():
|
||||
self.universe[event.sid][field] = event[field]
|
||||
|
||||
def simulate_snapshot(self, date):
|
||||
"""
|
||||
Run the user's algo against our current snapshot and update
|
||||
the algo's simulated time.
|
||||
"""
|
||||
# Needs to be set so that we inject the proper date into algo
|
||||
# log/print lines.
|
||||
self.snapshot_dt = date
|
||||
|
||||
start_tic = datetime.now()
|
||||
with self.heartbeat_monitor:
|
||||
self.algo.handle_data(self.universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
# How long did you take?
|
||||
delta = stop_tic - start_tic
|
||||
|
||||
# Update the simulation time.
|
||||
self.simulation_dt = date + delta
|
||||
|
||||
@@ -3,6 +3,7 @@ Generator versions of transforms.
|
||||
"""
|
||||
import types
|
||||
import pytz
|
||||
import logbook
|
||||
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
@@ -15,6 +16,8 @@ from zipline.utils.tradingcalendar import trading_days_between
|
||||
from zipline.gens.utils import assert_sort_unframe_protocol, \
|
||||
assert_transform_protocol, hash_args
|
||||
|
||||
log = logbook.Logger('Transform')
|
||||
|
||||
class Passthrough(object):
|
||||
FORWARDER = True
|
||||
"""
|
||||
@@ -72,6 +75,7 @@ class StatefulTransform(object):
|
||||
|
||||
# Create the string associated with this generator's output.
|
||||
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
log.info('StatefulTransform [%s] initialized' % self.namestring)
|
||||
|
||||
def get_hash(self):
|
||||
return self.namestring
|
||||
@@ -82,7 +86,7 @@ class StatefulTransform(object):
|
||||
def _gen(self, stream_in):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
|
||||
log.info('Running StatefulTransform [%s]' % self.get_hash())
|
||||
for message in stream_in:
|
||||
|
||||
# allow upstream generators to yield None to avoid
|
||||
@@ -101,7 +105,8 @@ class StatefulTransform(object):
|
||||
# FORWARDER flag means we want to keep all original
|
||||
# values, plus append tnfm_id and tnfm_value. Used for
|
||||
# preserving the original event fields when our output
|
||||
# will be fed into a merge.
|
||||
# will be fed into a merge. Currently only Passthrough
|
||||
# uses this flag.
|
||||
if self.forward_all:
|
||||
out_message = message_copy
|
||||
out_message.tnfm_id = self.namestring
|
||||
@@ -143,6 +148,7 @@ class StatefulTransform(object):
|
||||
out_message.dt = message_copy.dt
|
||||
yield out_message
|
||||
|
||||
log.info('Finished StatefulTransform [%s]' % self.get_hash())
|
||||
class EventWindow:
|
||||
"""
|
||||
Abstract base class for transform classes that calculate iterative
|
||||
@@ -153,8 +159,9 @@ class EventWindow:
|
||||
from the window. Subclass these methods along with init(*args,
|
||||
**kwargs) to calculate metrics over the window.
|
||||
|
||||
The market_aware flag is used to toggle whether the eventwindow
|
||||
calculates
|
||||
If the market_aware flag is True, the EventWindow drops old events
|
||||
based on the number of elapsed trading days between newest and oldest.
|
||||
Otherwise old events are dropped based on a raw timedelta.
|
||||
|
||||
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
|
||||
implementations of moving average and volume-weighted average
|
||||
|
||||
@@ -67,12 +67,17 @@ def hash_args(*args, **kwargs):
|
||||
return hasher.hexdigest()
|
||||
|
||||
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
|
||||
|
||||
row = ndict({
|
||||
'source_id' : source_id,
|
||||
'type' : DATASOURCE_TYPE.TRADE,
|
||||
'sid' : sid,
|
||||
'dt' : datetime,
|
||||
'price' : price,
|
||||
'close' : price,
|
||||
'open' : price,
|
||||
'low' : price * .95,
|
||||
'high' : price * 1.05,
|
||||
'volume' : amount
|
||||
})
|
||||
return row
|
||||
|
||||
+19
-122
@@ -63,30 +63,20 @@ import sys
|
||||
import zmq
|
||||
import os
|
||||
from signal import SIGHUP, SIGINT
|
||||
import datetime
|
||||
import pytz
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
import multiprocessing
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler
|
||||
from zipline.utils import factory
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline import ndict
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
from zipline.gens.composites import \
|
||||
date_sorted_sources, merged_transforms, sequential_transforms
|
||||
from zipline.gens.transform import Passthrough, StatefulTransform
|
||||
from zipline.gens.composites import (
|
||||
date_sorted_sources,
|
||||
sequential_transforms
|
||||
)
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
from logbook import Logger, NestedSetup, Processor
|
||||
from logbook import Logger
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
@@ -181,6 +171,8 @@ class SimulatedTrading(object):
|
||||
|
||||
def close(self):
|
||||
log.info("Closing Simulation: {id}".format(id=self.sim_id))
|
||||
if self.results_socket:
|
||||
self.results_socket.close()
|
||||
if self.proc and self.send_sighup:
|
||||
ppid = os.getppid()
|
||||
if self.success:
|
||||
@@ -253,17 +245,10 @@ class SimulatedTrading(object):
|
||||
else:
|
||||
return []
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
return self.gen.next()
|
||||
|
||||
@staticmethod
|
||||
def create_test_zipline(**config):
|
||||
"""
|
||||
:param config: A configuration object that is a dict with
|
||||
(all optional):
|
||||
:param config: A configuration object that is a dict with:
|
||||
|
||||
- environment - a \
|
||||
:py:class:`zipline.finance.trading.TradingEnvironment`
|
||||
@@ -285,7 +270,12 @@ class SimulatedTrading(object):
|
||||
of StatefulTransform objects.
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
sid = config.get('sid', 133)
|
||||
sid_list = config.get('sid_list')
|
||||
if not sid_list:
|
||||
sid = config.get('sid')
|
||||
sid_list = [sid]
|
||||
|
||||
concurrent_trades = config.get('concurrent_trades', False)
|
||||
|
||||
#--------------------
|
||||
# Trading Environment
|
||||
@@ -329,11 +319,13 @@ class SimulatedTrading(object):
|
||||
trade_source = config['trade_source']
|
||||
else:
|
||||
trade_source = factory.create_daily_trade_source(
|
||||
sids,
|
||||
sid_list,
|
||||
trade_count,
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent_trades
|
||||
)
|
||||
|
||||
|
||||
#-------------------
|
||||
# Transforms
|
||||
#-------------------
|
||||
@@ -367,98 +359,3 @@ class SimulatedTrading(object):
|
||||
#-------------------
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def create_sp_source(start_dt=None, end_dt=None):
|
||||
if start_dt is None:
|
||||
start_dt = datetime.datetime(2002, 1, 1, tzinfo=pytz.utc)
|
||||
if end_dt is None:
|
||||
end_dt = datetime.datetime(2008, 1, 1, tzinfo=pytz.utc)
|
||||
|
||||
sp_events, _ = factory.load_market_data()
|
||||
sp_transformed = []
|
||||
for event in sp_events:
|
||||
transformed = ndict(event.to_dict())
|
||||
if (transformed.dt < start_dt) or (transformed.dt > end_dt):
|
||||
continue
|
||||
transformed['sid'] = 0
|
||||
transformed['price'] = transformed['returns']
|
||||
transformed['type'] = DATASOURCE_TYPE.TRADE
|
||||
sp_transformed.append(transformed)
|
||||
|
||||
source = SpecificEquityTrades(event_list=sp_transformed)
|
||||
|
||||
return source
|
||||
|
||||
class Zipline(object):
|
||||
def __init__(self, **kwargs):
|
||||
algorithm = kwargs.get('algorithm', TestAlgorithm)
|
||||
source_descrs = kwargs.get('sources', ['S&P'])
|
||||
if isinstance(source_descrs, str):
|
||||
source_descrs = [source_descrs]
|
||||
|
||||
sources = []
|
||||
for source_descr in source_descrs:
|
||||
if isinstance(source_descr, str):
|
||||
if source_descr == 'S&P':
|
||||
source = create_sp_source()
|
||||
else:
|
||||
raise NotImplementedError, "Source with name {source_descr} not known.".format(source_descr=source_descr)
|
||||
else:
|
||||
source = source_descr
|
||||
|
||||
sources.append(source)
|
||||
|
||||
environment = kwargs.get('environment', create_trading_environment())
|
||||
|
||||
try:
|
||||
transform_descrs = kwargs.get('transforms', algorithm.registered_transforms)
|
||||
except:
|
||||
print "Couldn't load any registered_transforms."
|
||||
transform_descrs = {}
|
||||
|
||||
# Create transforms by wrapping them into StatefulTransforms
|
||||
transforms = []
|
||||
for namestring, trans_descr in transform_descrs.iteritems():
|
||||
sf = StatefulTransform(
|
||||
trans_descr['class'],
|
||||
*trans_descr['args'],
|
||||
**trans_descr['kwargs']
|
||||
)
|
||||
sf.namestring = namestring
|
||||
|
||||
transforms.append(sf)
|
||||
|
||||
results_socket_uri = None
|
||||
context = None
|
||||
sim_id = None
|
||||
style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
self.simulated_trading = SimulatedTrading(
|
||||
sources,
|
||||
transforms,
|
||||
algorithm,
|
||||
environment,
|
||||
style,
|
||||
results_socket_uri,
|
||||
context,
|
||||
sim_id)
|
||||
|
||||
|
||||
def run(self):
|
||||
# drain simulated_trading
|
||||
perfs = [perf for perf in self.simulated_trading]
|
||||
|
||||
# create daily stats dataframe
|
||||
daily_perfs = []
|
||||
cum_perfs = []
|
||||
for perf in perfs:
|
||||
if 'daily_perf' in perf:
|
||||
daily_perfs.append(perf['daily_perf'])
|
||||
else:
|
||||
cum_perfs.append(perf)
|
||||
|
||||
daily_dts = [np.datetime64(perf['period_close'], utc=True) for perf in daily_perfs]
|
||||
daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
|
||||
|
||||
return daily_stats
|
||||
|
||||
+135
-25
@@ -1,5 +1,8 @@
|
||||
from zipline.lines import Zipline
|
||||
import pandas as pd
|
||||
import pandas.io.data as dt
|
||||
from pandas.io.data import DataReader
|
||||
|
||||
import numpy as np
|
||||
#from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
@@ -9,24 +12,28 @@ from zipline.gens.mavg import MovingAverage
|
||||
from zipline.optimize.algorithms import TradingAlgorithm
|
||||
from datetime import timedelta
|
||||
|
||||
from mpi4py_map import map
|
||||
#from mpi4py_map import map
|
||||
|
||||
# Inherits from Algorithm base class
|
||||
class DMA(TradingAlgorithm):
|
||||
"""Dual Moving Average algorithm.
|
||||
"""
|
||||
def __init__(self, sid, amount=100, short_window=20, long_window=40):
|
||||
self.sids = [sid]
|
||||
def __init__(self, sids, amount=100, short_window=20, long_window=40):
|
||||
self.sids = sids
|
||||
self.amount = amount
|
||||
self.done = False
|
||||
self.order = None
|
||||
self.frame_count = 0
|
||||
self.portfolio = None
|
||||
self.orders = []
|
||||
self.market_entered = False
|
||||
|
||||
self.prices = []
|
||||
self.events = 0
|
||||
|
||||
self.invested = {}
|
||||
for sid in self.sids:
|
||||
self.invested[sid] = False
|
||||
|
||||
self.add_transform(MovingAverage, 'short_mavg', ['price'],
|
||||
market_aware=False,
|
||||
delta=timedelta(days=int(short_window)))
|
||||
@@ -37,36 +44,139 @@ class DMA(TradingAlgorithm):
|
||||
|
||||
def handle_data(self, data):
|
||||
self.events += 1
|
||||
sid = self.sids[0]
|
||||
# access transforms via their user-defined tag
|
||||
if (data[sid].short_mavg > data[sid].long_mavg) and not self.market_entered:
|
||||
self.order(sid, 100)
|
||||
self.market_entered = True
|
||||
elif (data[sid].short_mavg < data[sid].long_mavg) and self.market_entered:
|
||||
self.order(sid, -100)
|
||||
self.market_entered = False
|
||||
|
||||
for sid in self.sids:
|
||||
# access transforms via their user-defined tag
|
||||
if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]:
|
||||
self.order(sid, self.amount)
|
||||
self.invested[sid] = True
|
||||
elif (data[sid].short_mavg['price'] < data[sid].long_mavg['price']) and self.invested[sid]:
|
||||
self.order(sid, -self.amount)
|
||||
self.invested[sid] = False
|
||||
|
||||
|
||||
class DanVWAP(TradingAlgorithm):
|
||||
"""Dual Moving Average algorithm.
|
||||
"""
|
||||
def __init__(self, sids, amount=100, short_window=20, long_window=40):
|
||||
self.sids = sids
|
||||
self.amount = amount
|
||||
self.done = False
|
||||
self.order = None
|
||||
self.frame_count = 0
|
||||
self.portfolio = None
|
||||
self.orders = []
|
||||
|
||||
self.prices = []
|
||||
self.port = 0
|
||||
|
||||
self.add_transform(MovingAverage, 'short_mavg', ['price'],
|
||||
market_aware=False,
|
||||
delta=timedelta(days=int(short_window)))
|
||||
|
||||
self.add_transform(MovingAverage, 'long_mavg', ['price'],
|
||||
market_aware=False,
|
||||
delta=timedelta(days=int(long_window)))
|
||||
|
||||
def handle_data(self, data):
|
||||
for sid in self.sids:
|
||||
average=data[sid].vwap(5)
|
||||
price=data[sid].price
|
||||
|
||||
if price>average*1.05:
|
||||
self.order(sid, self.amount)
|
||||
|
||||
|
||||
def load_close_px(indexes=None, stocks=None):
|
||||
if indexes is None:
|
||||
indexes = {'SPX' : '^GSPC'}
|
||||
if stocks is None:
|
||||
stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
|
||||
|
||||
start = pd.datetime(1990, 1, 1)
|
||||
end = pd.datetime.today()
|
||||
|
||||
data = {}
|
||||
for stock in stocks:
|
||||
print stock
|
||||
stkd = DataReader(stock, 'yahoo', start, end).sort_index()
|
||||
data[stock] = stkd
|
||||
|
||||
for name, ticker in indexes.iteritems():
|
||||
print name
|
||||
stkd = DataReader(ticker, 'yahoo', start, end).sort_index()
|
||||
data[name] = stkd
|
||||
|
||||
df = pd.DataFrame({key: d['Close'] for key, d in data.iteritems()})
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def run((short_window, long_window)):
|
||||
data = pd.DataFrame.from_csv('SP500.csv')
|
||||
myalgo = DMA(sid=0, amount=100, short_window=short_window, long_window=long_window)
|
||||
myalgo = DMA([0], amount=100, short_window=short_window, long_window=long_window)
|
||||
stats = myalgo.run(data, compute_risk_metrics=False)
|
||||
stats['sw'] = short_window
|
||||
stats['lw'] = long_window
|
||||
return stats
|
||||
|
||||
sws, lws = np.mgrid[50:80:5, 100:140:5]
|
||||
def explore_params():
|
||||
sws, lws = np.mgrid[10:20:5, 10:20:5]
|
||||
|
||||
stats_all = map(run, zip(sws.flatten(), lws.flatten()))
|
||||
stats_all = map(run, zip(sws.flatten(), lws.flatten()))
|
||||
stats = pd.concat(stats_all)
|
||||
returns = stats.groupby(['sw', 'lw']).sum()
|
||||
|
||||
# for sw, lw in zip(sws.flatten(), lws.flatten()):
|
||||
# stats = run(short_window=sw, long_window=lw)
|
||||
# stats_all.append(stats)
|
||||
plt.contourf(sws, lws, returns.returns.reshape(sws.shape))
|
||||
plt.xlabel('Short window length')
|
||||
plt.ylabel('Long window length')
|
||||
plt.savefig('DMA_contour.png')
|
||||
plt.show()
|
||||
|
||||
stats = pd.concat(stats_all)
|
||||
returns = stats.groupby(['sw', 'lw']).sum()
|
||||
plt.contourf(sws, lws, returns.returns.reshape(sws.shape))
|
||||
plt.xlabel('Short window length')
|
||||
plt.ylabel('Long window length')
|
||||
plt.savefig('DMA_contour.png')
|
||||
plt.show()
|
||||
#stats = run((10, 50))
|
||||
|
||||
def get_opt_holdings_qp(univ_rets, track_rets):
|
||||
from cvxopt import matrix
|
||||
from cvxopt.solvers import qp
|
||||
# set up the QP for CVXOPT
|
||||
# .5 x' P x + q'x
|
||||
# P = 2 * R'R
|
||||
# q = - 2 * bmk'R
|
||||
R = univ_rets.values
|
||||
b = track_rets.values
|
||||
P = matrix(2 * np.dot(R.T, R))
|
||||
q = matrix(-2 * np.dot(R.T, b))
|
||||
result = qp(P, q)
|
||||
if result['status'] != 'optimal':
|
||||
raise Exception('optimum not reached by QP')
|
||||
return pd.Series(np.array(result['x']).ravel(), index=univ_rets.columns)
|
||||
|
||||
def opt_portfolio(cov, budget, min_return):
|
||||
from cvxopt import matrix
|
||||
from cvxopt.solvers import qp
|
||||
n = len(cov)
|
||||
cov = matrix(2 * cov)
|
||||
q = matrix(np.zeros(n))
|
||||
|
||||
h = matrix(budget) # G*x < h
|
||||
# coneqp
|
||||
result = qp(cov, q, h=h)
|
||||
if result['status'] != 'optimal':
|
||||
raise Exception('optimum not reached by QP')
|
||||
|
||||
return pd.Series(np.array(result['x']).ravel())
|
||||
|
||||
def calc_te(weights, univ_rets, track_rets):
|
||||
port_rets = (univ_rets * weights).sum(1)
|
||||
return (port_rets - track_rets).std()
|
||||
|
||||
def plot_returns(port_returns, bmk_returns):
|
||||
plt.figure()
|
||||
cum_port = ((1 + port_returns).cumprod() - 1)
|
||||
cum_bmk = ((1 + bmk_returns).cumprod() - 1)
|
||||
# cum_port = port_returns.cumsum()
|
||||
# cum_bmk = bmk_returns.cumsum()
|
||||
cum_port.plot(label='Portfolio returns')
|
||||
cum_bmk.plot(label='Benchmark')
|
||||
plt.title('Portfolio performance')
|
||||
plt.legend(loc='best')
|
||||
|
||||
@@ -133,6 +133,6 @@ def create_predictable_zipline(config, offset=0, simulate=True):
|
||||
zipline = SimulatedTrading.create_test_zipline(**config)
|
||||
|
||||
if simulate:
|
||||
zipline.simulate(blocking=True)
|
||||
zipline.drain_zipline(blocking=True)
|
||||
|
||||
return zipline, config
|
||||
|
||||
+22
-3
@@ -132,6 +132,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now
|
||||
# -----------------------
|
||||
|
||||
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG']
|
||||
PRICE_FIELDS = ['price', 'open', 'close', 'high', 'low']
|
||||
|
||||
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
|
||||
|
||||
@@ -428,21 +429,26 @@ def TRADE_FRAME(event):
|
||||
assert isinstance(event, ndict)
|
||||
assert event.type == DATASOURCE_TYPE.TRADE
|
||||
assert isinstance(event.sid, int)
|
||||
assert isinstance(event.price, numbers.Real)
|
||||
for field in PRICE_FIELDS:
|
||||
assert isinstance(event[field], numbers.Real)
|
||||
assert isinstance(event.volume, numbers.Integral)
|
||||
PACK_DATE(event)
|
||||
return msgpack.dumps(tuple([
|
||||
event.sid,
|
||||
event.price,
|
||||
event.open,
|
||||
event.close,
|
||||
event.high,
|
||||
event.low,
|
||||
event.volume,
|
||||
event.dt,
|
||||
event.type,
|
||||
event.type
|
||||
]))
|
||||
|
||||
def TRADE_UNFRAME(msg):
|
||||
try:
|
||||
packed = msgpack.loads(msg)
|
||||
sid, price, volume, dt, source_type = packed
|
||||
sid, price, open, close, high, low, volume, dt, source_type = packed
|
||||
|
||||
assert isinstance(sid, int)
|
||||
assert isinstance(price, numbers.Real)
|
||||
@@ -450,6 +456,10 @@ def TRADE_UNFRAME(msg):
|
||||
rval = ndict({
|
||||
'sid' : sid,
|
||||
'price' : price,
|
||||
'open' : open,
|
||||
'close' : close,
|
||||
'high' : high,
|
||||
'low' : low,
|
||||
'volume' : volume,
|
||||
'dt' : dt,
|
||||
'type' : source_type
|
||||
@@ -654,7 +664,13 @@ def tuple_to_date(date_tuple):
|
||||
dt = dt.replace(microsecond = micros, tzinfo = pytz.utc)
|
||||
return dt
|
||||
|
||||
# Datasource type should completely determine the other fields of a
|
||||
# message with its type.
|
||||
DATASOURCE_TYPE = Enum(
|
||||
'AS_TRADED_EQUITY',
|
||||
'MERGER',
|
||||
'SPLIT',
|
||||
'DIVIDEND',
|
||||
'TRADE',
|
||||
'EMPTY',
|
||||
'DONE'
|
||||
@@ -720,6 +736,9 @@ def LOG_FRAME(payload):
|
||||
assert payload.has_key('msg'),\
|
||||
"LOG_FRAME with no message"
|
||||
|
||||
# truncation will only work with strings and msgpack will
|
||||
# preserve primitives.
|
||||
payload['msg'] = str(payload['msg'])
|
||||
|
||||
return BT_UPDATE_FRAME('LOG', payload)
|
||||
|
||||
|
||||
@@ -237,6 +237,56 @@ class DivByZeroAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
class InitializeTimeoutAlgorithm():
|
||||
def __init__(self, sid):
|
||||
self.sid = sid
|
||||
self.incr = 0
|
||||
|
||||
def initialize(self):
|
||||
import time
|
||||
from zipline.gens.tradesimulation import INIT_TIMEOUT
|
||||
time.sleep(INIT_TIMEOUT + 1)
|
||||
|
||||
def set_order(self, order_callable):
|
||||
pass
|
||||
|
||||
def set_logger(self, logger):
|
||||
pass
|
||||
|
||||
def set_portfolio(self, portfolio):
|
||||
pass
|
||||
|
||||
def handle_data(self, data):
|
||||
pass
|
||||
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
class TooMuchProcessingAlgorithm():
|
||||
def __init__(self, sid):
|
||||
self.sid = sid
|
||||
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def set_order(self, order_callable):
|
||||
pass
|
||||
|
||||
def set_logger(self, logger):
|
||||
pass
|
||||
|
||||
def set_portfolio(self, portfolio):
|
||||
pass
|
||||
|
||||
def handle_data(self, data):
|
||||
# Unless we're running on some sort of
|
||||
# supercomputer this will hit timeout.
|
||||
for i in xrange(1000000000):
|
||||
self.foo = i
|
||||
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
class TimeoutAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from functools import wraps
|
||||
from signal import signal
|
||||
|
||||
class delayed_signals(object):
|
||||
"""
|
||||
Utility to temporary intercept one or more signals while a function or code
|
||||
block is executed, restore their signal handlers at the end of execution,
|
||||
and invoke them if the signals were in fact received during execution.
|
||||
|
||||
Can be used either as a decorator or a context manager.
|
||||
|
||||
Pass in an iterable of signals to intercept.
|
||||
"""
|
||||
|
||||
def handler(self, signum, frame=None):
|
||||
self.got.append({'signum': signum, 'frame': frame})
|
||||
|
||||
def __init__(self, signals):
|
||||
self.signals = signals
|
||||
self.handlers = {}
|
||||
self.got = []
|
||||
|
||||
def __enter__(self):
|
||||
for signum in self.signals:
|
||||
# signal() returns the old signal handler
|
||||
self.handlers[signum] = signal(signum, self.handler)
|
||||
|
||||
def __exit__(self, time, value, traceback):
|
||||
for signum, handler in self.handlers.items():
|
||||
signal(signum, handler)
|
||||
for signum, frame in ((i['signum'], i['frame']) for i in self.got):
|
||||
self.handlers[signum](signum, frame)
|
||||
|
||||
def __call__(self, fn):
|
||||
@wraps(fn)
|
||||
def call_fn(*args, **kwargs):
|
||||
with self:
|
||||
outval = fn(*args, **kwargs)
|
||||
return outval
|
||||
return call_fn
|
||||
@@ -174,7 +174,7 @@ def create_random_trade_source(sid, trade_count, trading_environment):
|
||||
|
||||
return source
|
||||
|
||||
def create_daily_trade_source(sids, trade_count, trading_environment):
|
||||
def create_daily_trade_source(sids, trade_count, trading_environment, concurrent=False):
|
||||
|
||||
"""
|
||||
creates trade_count trades for each sid in sids list.
|
||||
@@ -189,11 +189,12 @@ def create_daily_trade_source(sids, trade_count, trading_environment):
|
||||
sids,
|
||||
trade_count,
|
||||
timedelta(days=1),
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent
|
||||
)
|
||||
|
||||
|
||||
def create_minutely_trade_source(sids, trade_count, trading_environment):
|
||||
def create_minutely_trade_source(sids, trade_count, trading_environment, concurrent=False):
|
||||
|
||||
"""
|
||||
creates trade_count trades for each sid in sids list.
|
||||
@@ -208,10 +209,11 @@ def create_minutely_trade_source(sids, trade_count, trading_environment):
|
||||
sids,
|
||||
trade_count,
|
||||
timedelta(minutes=1),
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent
|
||||
)
|
||||
|
||||
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment):
|
||||
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment, concurrent=False):
|
||||
|
||||
args = tuple()
|
||||
kwargs = {
|
||||
@@ -219,7 +221,8 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ
|
||||
'sids' : sids,
|
||||
'start' : trading_environment.first_open,
|
||||
'delta' : trade_time_increment,
|
||||
'filter' : sids
|
||||
'filter' : sids,
|
||||
'concurrent' : concurrent
|
||||
}
|
||||
source = SpecificEquityTrades(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ class ZeroMQLogHandler(Handler):
|
||||
def __init__(self, socket=None, level=NOTSET, filter=None, bubble=False,
|
||||
context=None, fds = LOG_FIELDS, extra_fds = LOG_EXTRA_FIELDS):
|
||||
Handler.__init__(self, level, filter, bubble)
|
||||
|
||||
try:
|
||||
import zmq
|
||||
except ImportError:
|
||||
|
||||
@@ -15,6 +15,7 @@ def setup_logger(test, path='/var/log/zipline/zipline.log'):
|
||||
|
||||
def teardown_logger(test):
|
||||
test.log_handler.pop_application()
|
||||
test.log_handler.close()
|
||||
|
||||
def check_list(test, a, b, label):
|
||||
test.assertTrue(isinstance(a, (list, blist.blist)))
|
||||
@@ -91,11 +92,13 @@ def create_receiver(socket_addr, ctx):
|
||||
|
||||
return receiver
|
||||
|
||||
def drain_receiver(receiver):
|
||||
def drain_receiver(receiver, count=None):
|
||||
output = []
|
||||
transaction_count = 0
|
||||
msg_counter = 0
|
||||
while True:
|
||||
msg = receiver.recv()
|
||||
msg_counter += 1
|
||||
update = zp.BT_UPDATE_UNFRAME(msg)
|
||||
output.append(update)
|
||||
if update['prefix'] == 'PERF':
|
||||
@@ -106,14 +109,18 @@ def drain_receiver(receiver):
|
||||
elif update['prefix'] == 'DONE':
|
||||
break
|
||||
|
||||
if count and msg_counter >= count:
|
||||
break
|
||||
|
||||
receiver.close()
|
||||
del receiver
|
||||
|
||||
return output, transaction_count
|
||||
|
||||
|
||||
def assert_single_position(test, zipline):
|
||||
output, transaction_count = drain_zipline(test, zipline)
|
||||
def assert_single_position(test, zipline, blocking=False):
|
||||
output, transaction_count = drain_zipline(test, zipline, p_blocking=blocking)
|
||||
test.assertEqual(output[-1]['prefix'], 'DONE')
|
||||
|
||||
test.assertEqual(
|
||||
test.zipline_test_config['order_count'],
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
import signal
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from pprint import pprint as pp
|
||||
from numbers import Number
|
||||
from logbook import Logger
|
||||
|
||||
class Timeout(Exception):
|
||||
|
||||
def __init__(self, frame, message=''):
|
||||
self.frame = frame
|
||||
self.message = message
|
||||
|
||||
class timeout(object):
|
||||
"""
|
||||
Utility to make a function raise TimeoutException if it spends
|
||||
more than a specified number of seconds executing. Can be used
|
||||
as a decorator to apply a static timeout to a function, or as
|
||||
a context manager to dynamically add a timeout to a code block.
|
||||
"""
|
||||
|
||||
def __init__(self, seconds, message=''):
|
||||
self.seconds = seconds
|
||||
self.message = message
|
||||
assert isinstance(seconds, Number), "Failed to specify a timeout."
|
||||
assert seconds > 0, "Timeout must be greater than 0"
|
||||
|
||||
def handler(self, signum, frame):
|
||||
raise Timeout(frame, self.message)
|
||||
|
||||
def __call__(self, fn):
|
||||
|
||||
@wraps(fn)
|
||||
def call_fn_with_timeout(*args, **kwargs):
|
||||
# Set the alarm.
|
||||
signal.signal(signal.SIGALRM, self.handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, self.seconds, 0)
|
||||
try:
|
||||
outval = fn(*args, **kwargs)
|
||||
|
||||
# Deactivate the alarm once we're done so that the
|
||||
# decorator doesn't have unexpected side-effects later.
|
||||
# Note that this will still raise Timeout if the
|
||||
# call to fn takes too long.
|
||||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, signal.SIG_DFL)
|
||||
|
||||
# Return the value of fn if it finished before the alarm. This
|
||||
# won't execute if the Timeout was raised.
|
||||
return outval
|
||||
return call_fn_with_timeout
|
||||
|
||||
def __enter__(self):
|
||||
# Set the alarm on entrance.
|
||||
signal.signal(signal.SIGALRM, self.handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, self.seconds, 0)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
# Deactivate the alarm on exit. This will re-raise
|
||||
# any exceptions raised inside the with block.
|
||||
signal.signal(signal.SIGALRM, self.handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
|
||||
class heartbeat(object):
|
||||
"""
|
||||
Utility to perform pseudo-heartbeat checks on a single-threaded
|
||||
function. Calls frame_handler on the current stack frame of the
|
||||
wrapped function every ``interval`` seconds. After ``max_interval``
|
||||
intervals, raises Timeout. Can be used either as a decorator or
|
||||
a context manager.
|
||||
"""
|
||||
def __init__(self,
|
||||
interval,
|
||||
max_intervals,
|
||||
frame_handler=None,
|
||||
timeout_message=''):
|
||||
|
||||
self.interval = interval
|
||||
self.max_intervals = max_intervals
|
||||
self.frame_handler = frame_handler
|
||||
self.timeout_message = timeout_message
|
||||
self.count = 0
|
||||
|
||||
def handler(self, signum, frame):
|
||||
self.count += 1
|
||||
if self.frame_handler:
|
||||
self.frame_handler(self.count, frame)
|
||||
|
||||
if self.count >= self.max_intervals:
|
||||
raise Timeout(frame, self.timeout_message)
|
||||
|
||||
def __call__(self, fn):
|
||||
|
||||
@wraps(fn)
|
||||
def call_fn_with_heartbeat(*args, **kwargs):
|
||||
# Set a timer to call our handler every ``interval`` seconds.
|
||||
signal.signal(signal.SIGALRM, self.handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval)
|
||||
try:
|
||||
outval = fn(*args, **kwargs)
|
||||
|
||||
finally:
|
||||
# Deactivate the timer once we're done so that the
|
||||
# decorator doesn't have unexpected side-effects later.
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, signal.SIG_DFL)
|
||||
self.count = 0
|
||||
|
||||
# Return the value of fn if it finished without tripping
|
||||
# an exception. This won't execute if the Timeout or any
|
||||
# other exception was raised by self.handle.
|
||||
return outval
|
||||
return call_fn_with_heartbeat
|
||||
|
||||
def __enter__(self):
|
||||
# Set a timer to call our handler every N seconds.
|
||||
self.count = 0
|
||||
signal.signal(signal.SIGALRM, self.handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
# Turn off the timer on exit. This will re-raise any exception raised
|
||||
# during execution of the with-block
|
||||
self.count = 0
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, signal.SIG_DFL)
|
||||
Reference in New Issue
Block a user