Resolved conflicts

This commit is contained in:
Thomas Wiecki
2012-08-23 10:54:02 -04:00
27 changed files with 920 additions and 517 deletions
-1
View File
@@ -1 +0,0 @@
# TODO: move qexec console here
-189
View File
@@ -1,189 +0,0 @@
import uuid
import copy
import atexit
import pickle
from datetime import datetime
from collections import defaultdict
from UserDict import DictMixin
class Snapshot(object, DictMixin):
"""
A snapshot in time of a history container.
"""
def __init__(self, state, version, ts):
self.version = version
self.timestamp = ts
self._state = state
def keys(self):
return self._state.keys()
def values(self):
return self._state.values()
def items(self):
return self._state.items()
def __getitem__(self, key):
return self._state.__getitem__(key)
def has_key(self, key):
return self._state.has_key(key)
def copy(self):
return copy.copy(self._state)
class History(object, DictMixin):
"""
A duck-typed dictionary that tracks its time evolution.
Worth noting this not a particuarly high-performance
data structure due to the copious amount of copying going on.
"""
def __init__(self, default=None):
if default:
initial = defaultdict(default)
else:
initial = {}
self.version = 0
self.changeset = [('CREATE', None)]
self.current = Snapshot(initial, version=self.version, ts=datetime.now())
self._history = [self.current]
def items(self, version=-1):
return self._history[version].items()
def keys(self, version=-1):
return self._history[version].keys()
def rollback(self, version):
pass
def event(self, tup):
self.changeset.append(tup)
def __getitem__(self, key, version=-1):
return self._history[version].__getitem__(key)
def __setitem__(self, key, val):
if self.current.has_key(key):
self.changeset.append(('CHANGE', key))
else:
self.changeset.append(('ADD', key))
state = self.current.copy()
state[key] = val
self.version += 1
self.current = Snapshot(state, self.version, datetime.now())
self._history.append(self.current)
def __delitem__(self, key):
self.changeset.append(('REMOVE', key))
state = self.current.copy()
del state[key]
self.version += 1
self.current = Snapshot(state, self.version, datetime.now())
self._history.append(self.current)
def history(self):
for change in self.changeset:
print change
def __repr__(self):
return ':'.join(['historical', self.current._state.__repr__()])
SocketHistory = History()
ContextHistory = History()
def patch_zmq(_zmq=None):
"""
Monkey patch zeromq to allow for socket tracking.
"""
if _zmq:
zmq = _zmq
else:
import zmq
_Context = zmq.Context
_Socket = zmq.Socket
class TrackedSocket(zmq.Socket):
def __init__(self, context, socket_type):
self.context = context
self.uuid = str(uuid.uuid4())
SocketHistory[self.uuid] = self
_Socket.__init__(self, context, socket_type)
def connect(self, address):
SocketHistory.event(('CONNECT', self.uuid, address))
_Socket.connect(self, address)
def bind(self, address):
SocketHistory.event(('BIND', self.uuid, address))
_Socket.bind(self, address)
def close(self, *args, **kwargs):
del SocketHistory[self.uuid]
_Socket.close(self, *args, **kwargs)
def setsockopt(self, option, optval):
if option == zmq.IDENTITY:
old = SocketHistory[self.uuid]
SocketHistory[optval] = old
del SocketHistory[self.uuid]
self.uuid = optval
_Socket.setsockopt(self, option, optval)
class TrackedContext(zmq.Context):
def __init__(self, *args, **kwargs):
self.sockets = {}
_Context.__init__(self, *args, **kwargs)
self.uuid = str(uuid.uuid4())
ContextHistory[self.uuid] = self
def socket(self, socket_type):
sock = TrackedSocket(self, socket_type)
ContextHistory.event(('EMBED', self.uuid, sock.uuid))
self.sockets[sock.uuid] = sock
return sock
def name(self, name):
"""
Name the context. Is a superset of the vanilla pyzmq
API.
"""
old = ContextHistory[self.context.uuid]
ContextHistory[name] = old
del ContextHistory[self.context.uuid]
self.uuid = name
def term(self, *args, **kwargs):
for uid, sock in self.sockets.iteritems():
if not sock.closed:
del SocketHistory[sock.uuid]
del ContextHistory[self.uuid]
_Context.term(self, *args, **kwargs)
def destroy(self, *args, **kwargs):
ContextHistory.event(('DESTROY', self.uuid))
_Context.destroy(self, *args, **kwargs)
zmq.Context = TrackedContext
zmq.Socket = TrackedSocket
return TrackedContext, TrackedSocket
def track_to_file(f):
def write_track():
pickle.dump(SocketHistory.changeset, file(f, 'wb+'))
atexit.register(write_track)
+4 -3
View File
@@ -4,7 +4,7 @@ python-dateutil==1.5
# Core scientific python
numpy>=1.6.1
pandas>=0.7.0rc1
pandas==0.8.0
scipy>=0.10.0
matplotlib==1.1.0
@@ -12,8 +12,9 @@ matplotlib==1.1.0
numexpr==2.0.1
Cython==0.15.1
#tables>=2.3.1
#scikits.statsmodels>=0.3.1
patsy==0.1.0
statsmodels==0.5.0-tutorial-beta
# ZeroMQ
pyzmq==2.1.11
+53
View File
@@ -0,0 +1,53 @@
import os
from signal import signal, SIGHUP, SIGINT
import time
from types import FrameType
import unittest
from zipline.utils.delayed_signals import delayed_signals
class DelayedSignals(unittest.TestCase):
def handler(self, signum, frame):
print "Got signal " + str(signum)
self.got[signum] = time.time()
self.assertTrue(isinstance(frame, FrameType))
def setUp(self):
signal(SIGHUP, self.handler)
signal(SIGINT, self.handler)
def reset(self):
self.got = {}
def test_delayed_signals(self):
self.reset()
with delayed_signals([SIGHUP]):
os.kill(os.getpid(), SIGHUP)
time.sleep(2)
self.assertTrue(self.got[SIGHUP])
self.assertTrue(time.time() - self.got[SIGHUP] < 2)
def test_immediate_signals(self):
self.reset()
os.kill(os.getpid(), SIGHUP)
time.sleep(2)
self.assertTrue(self.got[SIGHUP])
self.assertTrue(time.time() - self.got[SIGHUP] > 1)
def test_multiple_signals(self):
self.reset()
with delayed_signals([SIGHUP, SIGINT]):
os.kill(os.getpid(), SIGINT)
self.assertFalse(SIGHUP in self.got)
self.assertTrue(SIGINT in self.got)
@delayed_signals([SIGHUP])
def kill_and_sleep(self):
os.kill(os.getpid(), SIGHUP)
time.sleep(2)
def test_decorator(self):
self.reset()
self.kill_and_sleep()
self.assertTrue(SIGHUP in self.got)
self.assertTrue(time.time() - self.got[SIGHUP] < 2)
+47 -1
View File
@@ -3,11 +3,14 @@ import zmq
from unittest2 import TestCase
from collections import defaultdict
from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm
from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm, \
InitializeTimeoutAlgorithm, TooMuchProcessingAlgorithm
from zipline.finance.trading import SIMULATION_STYLE
from zipline.core.devsimulator import AddressAllocator
from zipline.lines import SimulatedTrading
from zipline.gens.transform import StatefulTransform
from zipline.gens.tradesimulation import HEARTBEAT_INTERVAL, \
MAX_HEARTBEAT_INTERVALS
from zipline.utils.test_utils import \
drain_zipline, \
@@ -143,3 +146,46 @@ class ExceptionTestCase(TestCase):
# make sure our path shortening is working
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
def test_initialize_timeout(self):
self.zipline_test_config['algorithm'] = \
InitializeTimeoutAlgorithm(
self.zipline_test_config['sid']
)
zipline = SimulatedTrading.create_test_zipline(
**self.zipline_test_config
)
output, _ = drain_zipline(self, zipline)
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
payload = output[-1]['payload']
self.assertEqual(payload['name'],'Timeout')
self.assertEqual(payload['message'], 'Call to initialize timed out')
def test_heartbeat(self):
self.zipline_test_config['algorithm'] = \
TooMuchProcessingAlgorithm(
self.zipline_test_config['sid']
)
zipline = SimulatedTrading.create_test_zipline(
**self.zipline_test_config
)
output, _ = drain_zipline(self, zipline)
# There should be a message for each hearbeat, plus a message
# for the final timeout.
assert len(output) == MAX_HEARTBEAT_INTERVALS + 1
# Assert that everything but the last message is a heartbeat log.
for message in output[0:-1]:
assert message['prefix'] == 'LOG'
assert message['payload']['func_name'] == 'log_heartbeats'
# Assert that the last message is a timeout exception.
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
payload = output[-1]['payload']
self.assertEqual(payload['name'],'Timeout')
self.assertEqual(payload['message'], 'Too much time spent in handle_data call')
-31
View File
@@ -20,7 +20,6 @@ from zipline.finance.performance import PerformanceTracker
from zipline.utils.protocol_utils import ndict
from zipline.finance.trading import TransactionSimulator
from zipline.utils.test_utils import \
drain_zipline, \
setup_logger, \
teardown_logger,\
assert_single_position
@@ -121,36 +120,6 @@ class FinanceTestCase(TestCase):
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
assert_single_position(self, zipline)
#@timed(DEFAULT_TIMEOUT)
def test_sid_filter(self):
# Ensure the algorithm's filter prevents events from arriving.
# create a test algorithm whose filter will not match any of the
# trade events sourced inside the zipline.
order_amount = 100
order_count = 100
no_match_sid = 222
test_algo = TestAlgorithm(
no_match_sid,
order_amount,
order_count
)
self.zipline_test_config['trade_count'] = 200
self.zipline_test_config['algorithm'] = test_algo
zipline = SimulatedTrading.create_test_zipline(
**self.zipline_test_config
)
output, transaction_count = drain_zipline(self, zipline)
#check that the algorithm received no events
self.assertEqual(
0,
transaction_count,
"The algorithm should not receive any events due to filtering."
)
# TODO: write tests for short sales
# TODO: write a test to do massive buying or shorting.
+40
View File
@@ -1,7 +1,15 @@
import logging
import logbook
import uuid
import zmq
from zipline import ndict
from zipline.utils.logger import configure_logging, tail
from zipline.utils.log_utils import ZeroMQLogHandler
from zipline.utils.test_utils import create_receiver, drain_receiver
from unittest2 import TestCase
@@ -20,3 +28,35 @@ class LoggerTestCase(TestCase):
last_line = tail(logfile, window=1)
logged_msg = last_line.split(" - ")[1]
self.assertEqual(test_msg, logged_msg)
def test_zmq_handler(self):
socket_addr = 'tcp://127.0.0.1:10000'
ctx = zmq.Context()
socket_push = ctx.socket(zmq.PUSH)
socket_push.connect(socket_addr)
recv = create_receiver(socket_addr, ctx)
zmq_out = ZeroMQLogHandler(
socket = socket_push,
filter = lambda r, h: r.channel in ['test zmq logger'],
context=ctx,
#bubble=False
)
log = logbook.Logger('test zmq logger')
x = ndict({})
x.a = 1
ex = example(133)
with zmq_out.threadbound():
log.info(ex.num)
output, _ = drain_receiver(recv, count=1)
self.assertEqual(output[-1]['prefix'], 'LOG')
self.assertTrue(isinstance(output[-1]['payload']['msg'], basestring))
class example(object):
def __init__(self, num):
self.num = num
+49 -1
View File
@@ -1,4 +1,5 @@
import pytz
import numpy
from datetime import timedelta, datetime
from collections import defaultdict
@@ -15,6 +16,7 @@ from zipline.gens.tradegens import SpecificEquityTrades
from zipline.gens.transform import StatefulTransform, EventWindow
from zipline.gens.vwap import VWAP
from zipline.gens.mavg import MovingAverage
from zipline.gens.stddev import MovingStandardDev
from zipline.gens.returns import Returns
import zipline.utils.factory as factory
@@ -70,6 +72,7 @@ class EventWindowTestCase(TestCase):
delta = timedelta(minutes = 5),
days = None
)
now = utcnow()
# 15 dates, increasing in 1 minute increments.
@@ -99,6 +102,7 @@ class EventWindowTestCase(TestCase):
delta = None,
days = 1
)
dates = ([self.pre_open]*3)
dates += ([self.mid_day]*3)
dates += ([self.post_close]*3)
@@ -239,11 +243,12 @@ class FinanceTransformsTestCase(TestCase):
fields = ['price', 'volume'],
delta = timedelta(days = 2),
)
transformed = list(mavg.transform(self.source))
# Output values.
tnfm_prices = [message.tnfm_value.price for message in transformed]
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
# "Hand-calculated" values
expected_prices = [
((10.0) / 1.0),
@@ -264,3 +269,46 @@ class FinanceTransformsTestCase(TestCase):
assert tnfm_prices == expected_prices
assert tnfm_volumes == expected_volumes
def test_moving_stddev(self):
trade_history = factory.create_trade_history(
133,
[10.0, 15.0, 13.0, 12.0],
[100, 100, 100, 100],
timedelta(hours = 1),
self.trading_environment
)
stddev = StatefulTransform(
MovingStandardDev,
market_aware = False,
delta = timedelta(minutes = 150),
)
self.source = SpecificEquityTrades(event_list=trade_history)
transformed = list(stddev.transform(self.source))
vals = [message.tnfm_value for message in transformed]
expected = [
None,
numpy.std([10.0, 15.0], ddof = 1),
numpy.std([10.0, 15.0, 13.0], ddof = 1),
numpy.std([15.0, 13.0, 12.0], ddof = 1),
]
# numpy has odd rounding behavior, cf.
# http://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html
for v1, v2 in zip(vals, expected):
if v1 == None:
assert v2 == None
continue
assert round(v1, 5) == round(v2, 5)
+6 -1
View File
@@ -166,6 +166,7 @@ class PerformanceTracker(object):
self.event_count = 0
self.last_dict = None
self.exceeded_max_loss = False
self.no_more_updates = False
self.compute_risk_metrics = True
@@ -205,9 +206,12 @@ class PerformanceTracker(object):
self.todays_performance.positions[sid] = Position(sid)
def update(self, event):
if event.dt == "DONE":
if self.no_more_updates:
return zp.ndict({'dt':0})
elif event.dt == "DONE":
event.perf_message = self.handle_simulation_end()
del event['TRANSACTION']
self.no_more_updates = True
return event
elif self.exceeded_max_loss:
# in case of max_loss, signal to downstream
@@ -215,6 +219,7 @@ class PerformanceTracker(object):
event.dt = "DONE"
event.perf_message = self.handle_simulation_end()
del event['TRANSACTION']
self.no_more_updates = True
return event
else:
event.perf_message = self.process_event(event)
+14 -16
View File
@@ -307,24 +307,22 @@ class RiskMetrics():
for i in xrange(7):
if(self.treasury_curves.has_key(self.end_date + i * one_day)):
curve = self.treasury_curves[self.end_date + i * one_day]
break
self.treasury_curve = curve
rate = self.treasury_curve[self.treasury_duration]
#1month note data begins in 8/2001, so we can use 3month instead.
if rate == None and self.treasury_duration == '1month':
rate = self.treasury_curve['3month']
if curve:
self.treasury_curve = curve
rate = self.treasury_curve[self.treasury_duration]
#1month note data begins in 8/2001, so we can use 3month instead.
if rate == None and self.treasury_duration == '1month':
rate = self.treasury_curve['3month']
if rate != None:
return rate * (td.days + 1) / 365
if rate != None:
return rate * (td.days + 1) / 365
message = "no rate for end date = {dt} and term = {term}. Check \
that date doesn't exceed treasury history range."
message = message.format(
dt=self.end_date,
term=self.treasury_duration
)
raise Exception(message)
message = "no rate for end date = {dt} and term = {term}. Check \
that date doesn't exceed treasury history range."
message = message.format(
dt=self.end_date,
term=self.treasury_duration
)
raise Exception(message)
+20
View File
@@ -184,6 +184,8 @@ class TradingEnvironment(object):
self.first_open = self.calculate_first_open()
self.last_close = self.calculate_last_close()
self.prior_day_open = self.calculate_prior_day_open()
def calculate_first_open(self):
"""
Finds the first trading day on or after self.period_start.
@@ -197,6 +199,24 @@ class TradingEnvironment(object):
first_open = self.set_NYSE_time(first_open, 9, 30)
return first_open
def calculate_prior_day_open(self):
"""
Finds the first trading day open that falls at least a day
before period_start.
"""
one_day = datetime.timedelta(days=1)
first_open = self.period_start - one_day
if first_open <= self.trading_days[0]:
log.warn("Cannot calculate prior day open.")
return self.period_start
while not self.is_trading_day(first_open):
first_open = first_open - one_day
first_open = self.set_NYSE_time(first_open, 9, 30)
return first_open
def calculate_last_close(self):
"""
Finds the last trading day on or before self.period_end
+7 -3
View File
@@ -1,12 +1,16 @@
"""
Generator version of Feed.
Sorting generator.
"""
import logbook
from collections import deque
from zipline import ndict
from zipline.gens.utils import \
assert_datasource_unframe_protocol, \
assert_sort_protocol
log = logbook.Logger('Sorting')
def date_sort(stream_in, source_ids):
"""
A generator that takes a generator and a list of source_ids. We
@@ -27,7 +31,7 @@ def date_sort(stream_in, source_ids):
# Incoming messages should be the output of DATASOURCE_UNFRAME.
assert_datasource_unframe_protocol(message), \
"Bad message in date_sort: %s" % message
# Only allow messages from sources we expect.
assert message.source_id in sources, "Unexpected source: %s" % message
@@ -40,7 +44,7 @@ def date_sort(stream_in, source_ids):
message = pop_oldest(sources)
assert_sort_protocol(message)
yield message
# We should have only a done message left in each queue.
for queue in sources.itervalues():
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
+100
View File
@@ -0,0 +1,100 @@
from numbers import Number
from datetime import datetime, timedelta
from collections import defaultdict
from math import sqrt
from zipline import ndict
from zipline.gens.transform import EventWindow
class MovingStandardDev(object):
"""
Class that maintains a dicitonary from sids to
MovingStandardDevWindows. For each sid, we maintain a the
standard deviation of all events falling within the specified
window.
"""
def __init__(self, market_aware, days = None, delta = None):
self.market_aware = market_aware
self.delta = delta
self.days = days
# Market-aware mode only works with full-day windows.
if self.market_aware:
assert self.days and not self.delta,\
"Market-aware mode only works with full-day windows."
# Non-market-aware mode requires a timedelta.
else:
assert self.delta and not self.days, \
"Non-market-aware mode requires a timedelta."
# No way to pass arguments to the defaultdict factory, so we
# need to define a method to generate the correct EventWindows.
self.sid_windows = defaultdict(self.create_window)
def create_window(self):
"""
Factory method for self.sid_windows.
"""
return MovingStandardDevWindow(
self.market_aware,
self.days,
self.delta
)
def update(self, event):
"""
Update the event window for this event's sid. Return an ndict
from tracked fields to moving averages.
"""
# This will create a new EventWindow if this is the first
# message for this sid.
window = self.sid_windows[event.sid]
window.update(event)
return window.get_stddev()
class MovingStandardDevWindow(EventWindow):
"""
Iteratively calculates standard deviation for a particular sid
over a given time window. The expected functionality of this
class is to be instantiated inside a MovingStandardDev.
"""
def __init__(self, market_aware, days, delta):
# Call the superclass constructor to set up base EventWindow
# infrastructure.
EventWindow.__init__(self, market_aware, days, delta)
self.sum = 0.0
self.sum_sqr = 0.0
def handle_add(self, event):
assert event.has_key('price')
assert isinstance(event.price, Number)
self.sum += event.price
self.sum_sqr += event.price ** 2
def handle_remove(self, event):
assert event.has_key('price')
assert isinstance(event.price, Number)
self.sum -= event.price
self.sum_sqr -= event.price ** 2
def get_stddev(self):
# Sample standard deviation is undefined for a single event or
# no events.
if len(self) <= 1:
return None
else:
average = self.sum /len(self)
s_squared = (self.sum_sqr - self.sum*average) / (len(self) - 1)
stddev = sqrt(s_squared)
return stddev
+35 -15
View File
@@ -4,24 +4,23 @@ and zipline development
"""
import random
import pytz
from copy import copy
import pandas as pd
from zipline import ndict
from zipline.protocol import DATASOURCE_TYPE
from itertools import chain, cycle, ifilter, izip
from itertools import chain, cycle, ifilter, izip, repeat
from datetime import datetime, timedelta
from zipline.gens.utils import hash_args, create_trade
def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
delta = timedelta(minutes = 1),
count = 100):
count = 100,
repeats = None):
"""
Utility to generate a stream of dates.
"""
return (start + (i * delta) for i in xrange(count))
if repeats:
return (start + (i * delta) for i in xrange(count) for n in xrange(repeats))
else:
return (start + (i * delta) for i in xrange(count))
def mock_prices(count, rand = False):
"""
@@ -101,20 +100,41 @@ class SpecificEquityTrades(object):
def get_hash(self):
return self.__class__.__name__ + "-" + self.arg_string
def update_source_id(self, gen):
for event in gen:
event.source_id = self.get_hash()
yield event
def create_fresh_generator(self):
if self.event_list:
for event in self.event_list:
event['source_id'] = self.get_hash()
unfiltered = (event for event in self.event_list)
event_gen = (event for event in self.event_list)
unfiltered = self.update_source_id(event_gen)
# Set up iterators for each expected field.
else:
dates = date_gen(count=self.count,
start=self.start,
delta=self.delta
)
if self.concurrent:
# in this context the count is the number of
# trades per sid, not the total.
dates = date_gen(
count=self.count,
start=self.start,
delta=self.delta,
repeats=len(self.sids),
)
else:
dates = date_gen(
count=self.count,
start=self.start,
delta=self.delta
)
prices = mock_prices(self.count)
volumes = mock_volumes(self.count)
sids = cycle(self.sids)
# Combine the iterators into a single iterator of arguments
+114 -92
View File
@@ -1,9 +1,12 @@
import signal
from logbook import Logger, Processor
from datetime import datetime, timedelta
from numbers import Integral
from itertools import groupby
from zipline import ndict
from zipline.utils.timeout import timeout, heartbeat, Timeout
from zipline.gens.transform import StatefulTransform
from zipline.finance.trading import TransactionSimulator
@@ -13,10 +16,15 @@ from zipline.gens.utils import hash_args
log = Logger('Trade Simulation')
# TODO: make these arguments rather than global constants
INIT_TIMEOUT = 5
HEARTBEAT_INTERVAL = 1 # seconds
MAX_HEARTBEAT_INTERVALS = 15 #count
class TradeSimulationClient(object):
"""
Generator that takes the expected output of a merge, a user
algorithm, a trading environment, and a simulator style as
Generator-style class that takes the expected output of a merge, a
user algorithm, a trading environment, and a simulator style as
arguments. Pipes the merge stream through a TransactionSimulator
and a PerformanceTracker, which keep track of the current state of
our algorithm's simulated universe. Results are fed to the user's
@@ -24,7 +32,7 @@ class TradeSimulationClient(object):
TransactionSimulator's order book.
TransactionSimulator maintains a dictionary from sids to the
unfulfilled orders placed by the user's algorithm. As trade
as-yet unfilled orders placed by the user's algorithm. As trade
events arrive, if the algorithm has open orders against the
trade's sid, the simulator will fill orders up to 25% of market
cap. Applied transactions are added to a txn field on the event
@@ -40,9 +48,9 @@ class TradeSimulationClient(object):
performance report, which is appended to event's perf_report
field.
Fully processed events are run through a batcher generator, which
batches together events with the same dt field into a single event
to be fed to the algo. The portfolio object is repeatedly
Fully processed events are fed to AlgorithmSimulator, which
batches together events with the same dt field into a single
snapshot to be fed to the algo. The portfolio object is repeatedly
overwritten so that only the most recent snapshot of the universe
is sent to the algo.
"""
@@ -55,9 +63,13 @@ class TradeSimulationClient(object):
self.style = sim_style
self.algo_sim = None
self.warmup_start = self.environment.prior_day_open
self.algo_start = self.environment.first_open
def get_hash(self):
"""
There should only ever be one TSC in the system.
There should only ever be one TSC in the system, so
we don't bother passing args into the hash.
"""
return self.__class__.__name__ + hash_args()
@@ -89,25 +101,31 @@ class TradeSimulationClient(object):
with_portfolio = perf_tracker.transform(with_filled_orders)
# Pass the messages from perf along with the trading client's
# state into the algorithm for simulation. We provide the
# trading client so that the algorithm can place new orders
# into the client's order book.
# state into the algorithm for simulation. We provide a
# pointer to the ordering client's internal state so that the
# algorithm can place new orders into the client's order book.
self.algo_sim = AlgorithmSimulator(
with_portfolio,
ordering_client.state,
self.algo,
self.algo_start
)
# The algorithm will yield a daily_results message (as
# calculated by the performance tracker) at the end of each
# day. It will also yield a risk report at the end of the
# simulation.
for message in self.algo_sim:
yield message
class AlgorithmSimulator(object):
def __init__(self, stream_in, order_book, algo):
def __init__(self,
stream_in,
order_book,
algo,
algo_start):
self.stream_in = stream_in
@@ -121,12 +139,28 @@ class AlgorithmSimulator(object):
self.algo = algo
self.sids = algo.get_sid_filter()
self.algo_start = algo_start
# Monkey patch the user algorithm to place orders in the
# TransactionSimulator's order book.
# TransactionSimulator's order book and use our logger.
self.algo.set_order(self.order)
self.algo.set_logger(Logger("AlgoLog"))
self.algolog = Logger("AlgoLog")
self.algo.set_logger(self.algolog)
# Handler for heartbeats during calls to handle_data.
def log_heartbeats(beat_count, stackframe):
t = beat_count * HEARTBEAT_INTERVAL
warning = "handle_data has been processing for %i seconds" %t
self.algolog.warn(warning)
# Context manager that calls log_heartbeats every HEARTBEAT_INTERVAL
# seconds, raising an exception after MAX_HEARTBEATS
self.heartbeat_monitor = heartbeat(
HEARTBEAT_INTERVAL,
MAX_HEARTBEAT_INTERVALS,
frame_handler=log_heartbeats,
timeout_message="Too much time spent in handle_data call"
)
# ==============
# Snapshot Setup
@@ -142,7 +176,7 @@ class AlgorithmSimulator(object):
# We don't have a datetime for the current snapshot until we
# receive a message.
self.simulation_dt = None
self.this_snapshot_dt = None
self.snapshot_dt = None
# =============
# Logging Setup
@@ -151,7 +185,7 @@ class AlgorithmSimulator(object):
# Processor function for injecting the algo_dt into
# user prints/logs.
def inject_algo_dt(record):
record.extra['algo_dt'] = self.this_snapshot_dt
record.extra['algo_dt'] = self.snapshot_dt
self.processor = Processor(inject_algo_dt)
# This is a class, which is instantiated later
@@ -206,94 +240,62 @@ class AlgorithmSimulator(object):
# Capture any output of this generator to stdout and pipe it
# to a logbook interface. Also inject the current algo
# snapshot time to any log record generated.
with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
# Call the user's initialize method.
self.algo.initialize()
for event in self.stream_in:
# Yield any perf messages received to be relayed back to
# the browser.
# Call user's initialize method with a timeout.
with timeout(INIT_TIMEOUT, message="Call to initialize timed out"):
self.algo.initialize()
if event.perf_message:
yield event.perf_message
del event['perf_message']
# Group together events with the same dt field. This depends on the
# events already being sorted.
for date, snapshot in groupby(self.stream_in, lambda e: e.dt):
if event.dt == "DONE":
if self.this_snapshot_dt:
# stop iteration happened
# mid-snapshot, so we have a universe
# snapshot that is not yet processed
# by the algorithm.
self.simulate_current_snapshot()
break
# This should only happen for the first event we run.
# Set the simulation date to be the first event we see.
# This should only occur once, at the start of the test.
if self.simulation_dt == None:
self.simulation_dt = event.dt
self.simulation_dt = date
# ======================
# Time Compression Logic
# ======================
# Done message has the risk report, so we yield before exiting.
if date == 'DONE':
for event in snapshot:
yield event.perf_message
break
if self.this_snapshot_dt != None:
self.update_current_snapshot(event)
# We're still in the warmup period. Use the event to
# update our universe, but don't yield any perf messages,
# and don't send a snapshot to handle_data.
elif date < self.algo_start:
for event in snapshot:
del event['perf_message']
self.update_universe(event)
# The algorithm has been missing events because it took
# too long processing. Update the universe with data from
# this event, then check if enough time has passed that we
# can start a new snapshot.
# The algo has taken so long to process events that
# its simulated time is later than the event time.
# Update the universe and yield any perf messages
# encountered, but don't call handle_data.
elif date < self.simulation_dt:
for event in snapshot:
# Only yield if we have something interesting to say.
if event.perf_message != None:
yield event.perf_message
# Delete the message before updating so we don't send it
# to the user.
del event['perf_message']
self.update_universe(event)
# Regular snapshot. Update the universe and send a snapshot
# to handle data.
else:
self.update_universe(event)
if event.dt >= self.simulation_dt:
self.this_snapshot_dt = event.dt
for event in snapshot:
# Only yield if we have something interesting to say.
if event.perf_message != None:
yield event.perf_message
del event['perf_message']
self.update_universe(event)
def update_current_snapshot(self, event):
"""
Update our current snapshot of the universe. Call handle_data if
"""
# The new event matches our snapshot dt. Just update the
# universe and move on.
if event.dt == self.this_snapshot_dt:
self.update_universe(event)
# The new event does not match our snapshot.
else:
self.simulate_current_snapshot()
# Once we've finished simulating the old snapshot,
# we can update the universe with the new event.
self.update_universe(event)
# The current event is later than the simulation time,
# which means the algorithm finished quickly enough to
# receive the new event. Start a new snapshot with this
# event's dt.
if event.dt >= self.simulation_dt:
self.this_snapshot_dt = event.dt
# The algorithm spent enough time processing that it
# missed the new event. Wait to start a new snapshot until
# the events catch up to the algo's simulated dt.
else:
self.this_snapshot_dt = None
def simulate_current_snapshot(self):
"""
Run the user's algo against our current snapshot and update the algo's
simulated time.
"""
start_tic = datetime.now()
self.algo.handle_data(self.universe)
stop_tic = datetime.now()
# How long did you take?
delta = stop_tic - start_tic
# Update the simulation time.
self.simulation_dt = self.this_snapshot_dt + delta
# Send the current state of the universe to the user's algo.
self.simulate_snapshot(date)
def update_universe(self, event):
"""
@@ -305,3 +307,23 @@ class AlgorithmSimulator(object):
# Update our knowledge of this event's sid
for field in event.keys():
self.universe[event.sid][field] = event[field]
def simulate_snapshot(self, date):
"""
Run the user's algo against our current snapshot and update
the algo's simulated time.
"""
# Needs to be set so that we inject the proper date into algo
# log/print lines.
self.snapshot_dt = date
start_tic = datetime.now()
with self.heartbeat_monitor:
self.algo.handle_data(self.universe)
stop_tic = datetime.now()
# How long did you take?
delta = stop_tic - start_tic
# Update the simulation time.
self.simulation_dt = date + delta
+11 -4
View File
@@ -3,6 +3,7 @@ Generator versions of transforms.
"""
import types
import pytz
import logbook
from copy import deepcopy
from datetime import datetime, timedelta
@@ -15,6 +16,8 @@ from zipline.utils.tradingcalendar import trading_days_between
from zipline.gens.utils import assert_sort_unframe_protocol, \
assert_transform_protocol, hash_args
log = logbook.Logger('Transform')
class Passthrough(object):
FORWARDER = True
"""
@@ -72,6 +75,7 @@ class StatefulTransform(object):
# Create the string associated with this generator's output.
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
log.info('StatefulTransform [%s] initialized' % self.namestring)
def get_hash(self):
return self.namestring
@@ -82,7 +86,7 @@ class StatefulTransform(object):
def _gen(self, stream_in):
# IMPORTANT: Messages may contain pointers that are shared with
# other streams, so we only manipulate copies.
log.info('Running StatefulTransform [%s]' % self.get_hash())
for message in stream_in:
# allow upstream generators to yield None to avoid
@@ -101,7 +105,8 @@ class StatefulTransform(object):
# FORWARDER flag means we want to keep all original
# values, plus append tnfm_id and tnfm_value. Used for
# preserving the original event fields when our output
# will be fed into a merge.
# will be fed into a merge. Currently only Passthrough
# uses this flag.
if self.forward_all:
out_message = message_copy
out_message.tnfm_id = self.namestring
@@ -143,6 +148,7 @@ class StatefulTransform(object):
out_message.dt = message_copy.dt
yield out_message
log.info('Finished StatefulTransform [%s]' % self.get_hash())
class EventWindow:
"""
Abstract base class for transform classes that calculate iterative
@@ -153,8 +159,9 @@ class EventWindow:
from the window. Subclass these methods along with init(*args,
**kwargs) to calculate metrics over the window.
The market_aware flag is used to toggle whether the eventwindow
calculates
If the market_aware flag is True, the EventWindow drops old events
based on the number of elapsed trading days between newest and oldest.
Otherwise old events are dropped based on a raw timedelta.
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
implementations of moving average and volume-weighted average
+5
View File
@@ -67,12 +67,17 @@ def hash_args(*args, **kwargs):
return hasher.hexdigest()
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
row = ndict({
'source_id' : source_id,
'type' : DATASOURCE_TYPE.TRADE,
'sid' : sid,
'dt' : datetime,
'price' : price,
'close' : price,
'open' : price,
'low' : price * .95,
'high' : price * 1.05,
'volume' : amount
})
return row
+19 -122
View File
@@ -63,30 +63,20 @@ import sys
import zmq
import os
from signal import SIGHUP, SIGINT
import datetime
import pytz
import pandas as pd
import numpy as np
import multiprocessing
from setproctitle import setproctitle
from zipline.test_algorithms import TestAlgorithm
from zipline.finance.trading import SIMULATION_STYLE
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
from zipline.utils.log_utils import ZeroMQLogHandler
from zipline.utils import factory
from zipline.utils.factory import create_trading_environment
from zipline.gens.tradegens import SpecificEquityTrades
from zipline import ndict
from zipline.protocol import DATASOURCE_TYPE
from zipline.test_algorithms import TestAlgorithm
from zipline.gens.composites import \
date_sorted_sources, merged_transforms, sequential_transforms
from zipline.gens.transform import Passthrough, StatefulTransform
from zipline.gens.composites import (
date_sorted_sources,
sequential_transforms
)
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
from logbook import Logger, NestedSetup, Processor
from logbook import Logger
import zipline.protocol as zp
@@ -181,6 +171,8 @@ class SimulatedTrading(object):
def close(self):
log.info("Closing Simulation: {id}".format(id=self.sim_id))
if self.results_socket:
self.results_socket.close()
if self.proc and self.send_sighup:
ppid = os.getppid()
if self.success:
@@ -253,17 +245,10 @@ class SimulatedTrading(object):
else:
return []
def __iter__(self):
return self
def next(self):
return self.gen.next()
@staticmethod
def create_test_zipline(**config):
"""
:param config: A configuration object that is a dict with
(all optional):
:param config: A configuration object that is a dict with:
- environment - a \
:py:class:`zipline.finance.trading.TradingEnvironment`
@@ -285,7 +270,12 @@ class SimulatedTrading(object):
of StatefulTransform objects.
"""
assert isinstance(config, dict)
sid = config.get('sid', 133)
sid_list = config.get('sid_list')
if not sid_list:
sid = config.get('sid')
sid_list = [sid]
concurrent_trades = config.get('concurrent_trades', False)
#--------------------
# Trading Environment
@@ -329,11 +319,13 @@ class SimulatedTrading(object):
trade_source = config['trade_source']
else:
trade_source = factory.create_daily_trade_source(
sids,
sid_list,
trade_count,
trading_environment
trading_environment,
concurrent=concurrent_trades
)
#-------------------
# Transforms
#-------------------
@@ -367,98 +359,3 @@ class SimulatedTrading(object):
#-------------------
return sim
def create_sp_source(start_dt=None, end_dt=None):
if start_dt is None:
start_dt = datetime.datetime(2002, 1, 1, tzinfo=pytz.utc)
if end_dt is None:
end_dt = datetime.datetime(2008, 1, 1, tzinfo=pytz.utc)
sp_events, _ = factory.load_market_data()
sp_transformed = []
for event in sp_events:
transformed = ndict(event.to_dict())
if (transformed.dt < start_dt) or (transformed.dt > end_dt):
continue
transformed['sid'] = 0
transformed['price'] = transformed['returns']
transformed['type'] = DATASOURCE_TYPE.TRADE
sp_transformed.append(transformed)
source = SpecificEquityTrades(event_list=sp_transformed)
return source
class Zipline(object):
def __init__(self, **kwargs):
algorithm = kwargs.get('algorithm', TestAlgorithm)
source_descrs = kwargs.get('sources', ['S&P'])
if isinstance(source_descrs, str):
source_descrs = [source_descrs]
sources = []
for source_descr in source_descrs:
if isinstance(source_descr, str):
if source_descr == 'S&P':
source = create_sp_source()
else:
raise NotImplementedError, "Source with name {source_descr} not known.".format(source_descr=source_descr)
else:
source = source_descr
sources.append(source)
environment = kwargs.get('environment', create_trading_environment())
try:
transform_descrs = kwargs.get('transforms', algorithm.registered_transforms)
except:
print "Couldn't load any registered_transforms."
transform_descrs = {}
# Create transforms by wrapping them into StatefulTransforms
transforms = []
for namestring, trans_descr in transform_descrs.iteritems():
sf = StatefulTransform(
trans_descr['class'],
*trans_descr['args'],
**trans_descr['kwargs']
)
sf.namestring = namestring
transforms.append(sf)
results_socket_uri = None
context = None
sim_id = None
style = SIMULATION_STYLE.FIXED_SLIPPAGE
self.simulated_trading = SimulatedTrading(
sources,
transforms,
algorithm,
environment,
style,
results_socket_uri,
context,
sim_id)
def run(self):
# drain simulated_trading
perfs = [perf for perf in self.simulated_trading]
# create daily stats dataframe
daily_perfs = []
cum_perfs = []
for perf in perfs:
if 'daily_perf' in perf:
daily_perfs.append(perf['daily_perf'])
else:
cum_perfs.append(perf)
daily_dts = [np.datetime64(perf['period_close'], utc=True) for perf in daily_perfs]
daily_stats = pd.DataFrame(daily_perfs, index=daily_dts)
return daily_stats
+135 -25
View File
@@ -1,5 +1,8 @@
from zipline.lines import Zipline
import pandas as pd
import pandas.io.data as dt
from pandas.io.data import DataReader
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
@@ -9,24 +12,28 @@ from zipline.gens.mavg import MovingAverage
from zipline.optimize.algorithms import TradingAlgorithm
from datetime import timedelta
from mpi4py_map import map
#from mpi4py_map import map
# Inherits from Algorithm base class
class DMA(TradingAlgorithm):
"""Dual Moving Average algorithm.
"""
def __init__(self, sid, amount=100, short_window=20, long_window=40):
self.sids = [sid]
def __init__(self, sids, amount=100, short_window=20, long_window=40):
self.sids = sids
self.amount = amount
self.done = False
self.order = None
self.frame_count = 0
self.portfolio = None
self.orders = []
self.market_entered = False
self.prices = []
self.events = 0
self.invested = {}
for sid in self.sids:
self.invested[sid] = False
self.add_transform(MovingAverage, 'short_mavg', ['price'],
market_aware=False,
delta=timedelta(days=int(short_window)))
@@ -37,36 +44,139 @@ class DMA(TradingAlgorithm):
def handle_data(self, data):
self.events += 1
sid = self.sids[0]
# access transforms via their user-defined tag
if (data[sid].short_mavg > data[sid].long_mavg) and not self.market_entered:
self.order(sid, 100)
self.market_entered = True
elif (data[sid].short_mavg < data[sid].long_mavg) and self.market_entered:
self.order(sid, -100)
self.market_entered = False
for sid in self.sids:
# access transforms via their user-defined tag
if (data[sid].short_mavg['price'] > data[sid].long_mavg['price']) and not self.invested[sid]:
self.order(sid, self.amount)
self.invested[sid] = True
elif (data[sid].short_mavg['price'] < data[sid].long_mavg['price']) and self.invested[sid]:
self.order(sid, -self.amount)
self.invested[sid] = False
class DanVWAP(TradingAlgorithm):
"""Dual Moving Average algorithm.
"""
def __init__(self, sids, amount=100, short_window=20, long_window=40):
self.sids = sids
self.amount = amount
self.done = False
self.order = None
self.frame_count = 0
self.portfolio = None
self.orders = []
self.prices = []
self.port = 0
self.add_transform(MovingAverage, 'short_mavg', ['price'],
market_aware=False,
delta=timedelta(days=int(short_window)))
self.add_transform(MovingAverage, 'long_mavg', ['price'],
market_aware=False,
delta=timedelta(days=int(long_window)))
def handle_data(self, data):
for sid in self.sids:
average=data[sid].vwap(5)
price=data[sid].price
if price>average*1.05:
self.order(sid, self.amount)
def load_close_px(indexes=None, stocks=None):
if indexes is None:
indexes = {'SPX' : '^GSPC'}
if stocks is None:
stocks = ['AAPL', 'GE', 'IBM', 'MSFT', 'XOM', 'AA', 'JNJ', 'PEP']
start = pd.datetime(1990, 1, 1)
end = pd.datetime.today()
data = {}
for stock in stocks:
print stock
stkd = DataReader(stock, 'yahoo', start, end).sort_index()
data[stock] = stkd
for name, ticker in indexes.iteritems():
print name
stkd = DataReader(ticker, 'yahoo', start, end).sort_index()
data[name] = stkd
df = pd.DataFrame({key: d['Close'] for key, d in data.iteritems()})
return df
def run((short_window, long_window)):
data = pd.DataFrame.from_csv('SP500.csv')
myalgo = DMA(sid=0, amount=100, short_window=short_window, long_window=long_window)
myalgo = DMA([0], amount=100, short_window=short_window, long_window=long_window)
stats = myalgo.run(data, compute_risk_metrics=False)
stats['sw'] = short_window
stats['lw'] = long_window
return stats
sws, lws = np.mgrid[50:80:5, 100:140:5]
def explore_params():
sws, lws = np.mgrid[10:20:5, 10:20:5]
stats_all = map(run, zip(sws.flatten(), lws.flatten()))
stats_all = map(run, zip(sws.flatten(), lws.flatten()))
stats = pd.concat(stats_all)
returns = stats.groupby(['sw', 'lw']).sum()
# for sw, lw in zip(sws.flatten(), lws.flatten()):
# stats = run(short_window=sw, long_window=lw)
# stats_all.append(stats)
plt.contourf(sws, lws, returns.returns.reshape(sws.shape))
plt.xlabel('Short window length')
plt.ylabel('Long window length')
plt.savefig('DMA_contour.png')
plt.show()
stats = pd.concat(stats_all)
returns = stats.groupby(['sw', 'lw']).sum()
plt.contourf(sws, lws, returns.returns.reshape(sws.shape))
plt.xlabel('Short window length')
plt.ylabel('Long window length')
plt.savefig('DMA_contour.png')
plt.show()
#stats = run((10, 50))
def get_opt_holdings_qp(univ_rets, track_rets):
from cvxopt import matrix
from cvxopt.solvers import qp
# set up the QP for CVXOPT
# .5 x' P x + q'x
# P = 2 * R'R
# q = - 2 * bmk'R
R = univ_rets.values
b = track_rets.values
P = matrix(2 * np.dot(R.T, R))
q = matrix(-2 * np.dot(R.T, b))
result = qp(P, q)
if result['status'] != 'optimal':
raise Exception('optimum not reached by QP')
return pd.Series(np.array(result['x']).ravel(), index=univ_rets.columns)
def opt_portfolio(cov, budget, min_return):
from cvxopt import matrix
from cvxopt.solvers import qp
n = len(cov)
cov = matrix(2 * cov)
q = matrix(np.zeros(n))
h = matrix(budget) # G*x < h
# coneqp
result = qp(cov, q, h=h)
if result['status'] != 'optimal':
raise Exception('optimum not reached by QP')
return pd.Series(np.array(result['x']).ravel())
def calc_te(weights, univ_rets, track_rets):
port_rets = (univ_rets * weights).sum(1)
return (port_rets - track_rets).std()
def plot_returns(port_returns, bmk_returns):
plt.figure()
cum_port = ((1 + port_returns).cumprod() - 1)
cum_bmk = ((1 + bmk_returns).cumprod() - 1)
# cum_port = port_returns.cumsum()
# cum_bmk = bmk_returns.cumsum()
cum_port.plot(label='Portfolio returns')
cum_bmk.plot(label='Benchmark')
plt.title('Portfolio performance')
plt.legend(loc='best')
+1 -1
View File
@@ -133,6 +133,6 @@ def create_predictable_zipline(config, offset=0, simulate=True):
zipline = SimulatedTrading.create_test_zipline(**config)
if simulate:
zipline.simulate(blocking=True)
zipline.drain_zipline(blocking=True)
return zipline, config
+22 -3
View File
@@ -132,6 +132,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now
# -----------------------
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG']
PRICE_FIELDS = ['price', 'open', 'close', 'high', 'low']
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
@@ -428,21 +429,26 @@ def TRADE_FRAME(event):
assert isinstance(event, ndict)
assert event.type == DATASOURCE_TYPE.TRADE
assert isinstance(event.sid, int)
assert isinstance(event.price, numbers.Real)
for field in PRICE_FIELDS:
assert isinstance(event[field], numbers.Real)
assert isinstance(event.volume, numbers.Integral)
PACK_DATE(event)
return msgpack.dumps(tuple([
event.sid,
event.price,
event.open,
event.close,
event.high,
event.low,
event.volume,
event.dt,
event.type,
event.type
]))
def TRADE_UNFRAME(msg):
try:
packed = msgpack.loads(msg)
sid, price, volume, dt, source_type = packed
sid, price, open, close, high, low, volume, dt, source_type = packed
assert isinstance(sid, int)
assert isinstance(price, numbers.Real)
@@ -450,6 +456,10 @@ def TRADE_UNFRAME(msg):
rval = ndict({
'sid' : sid,
'price' : price,
'open' : open,
'close' : close,
'high' : high,
'low' : low,
'volume' : volume,
'dt' : dt,
'type' : source_type
@@ -654,7 +664,13 @@ def tuple_to_date(date_tuple):
dt = dt.replace(microsecond = micros, tzinfo = pytz.utc)
return dt
# Datasource type should completely determine the other fields of a
# message with its type.
DATASOURCE_TYPE = Enum(
'AS_TRADED_EQUITY',
'MERGER',
'SPLIT',
'DIVIDEND',
'TRADE',
'EMPTY',
'DONE'
@@ -720,6 +736,9 @@ def LOG_FRAME(payload):
assert payload.has_key('msg'),\
"LOG_FRAME with no message"
# truncation will only work with strings and msgpack will
# preserve primitives.
payload['msg'] = str(payload['msg'])
return BT_UPDATE_FRAME('LOG', payload)
+50
View File
@@ -237,6 +237,56 @@ class DivByZeroAlgorithm():
def get_sid_filter(self):
return [self.sid]
class InitializeTimeoutAlgorithm():
def __init__(self, sid):
self.sid = sid
self.incr = 0
def initialize(self):
import time
from zipline.gens.tradesimulation import INIT_TIMEOUT
time.sleep(INIT_TIMEOUT + 1)
def set_order(self, order_callable):
pass
def set_logger(self, logger):
pass
def set_portfolio(self, portfolio):
pass
def handle_data(self, data):
pass
def get_sid_filter(self):
return [self.sid]
class TooMuchProcessingAlgorithm():
def __init__(self, sid):
self.sid = sid
def initialize(self):
pass
def set_order(self, order_callable):
pass
def set_logger(self, logger):
pass
def set_portfolio(self, portfolio):
pass
def handle_data(self, data):
# Unless we're running on some sort of
# supercomputer this will hit timeout.
for i in xrange(1000000000):
self.foo = i
def get_sid_filter(self):
return [self.sid]
class TimeoutAlgorithm():
def __init__(self, sid):
+40
View File
@@ -0,0 +1,40 @@
from functools import wraps
from signal import signal
class delayed_signals(object):
"""
Utility to temporary intercept one or more signals while a function or code
block is executed, restore their signal handlers at the end of execution,
and invoke them if the signals were in fact received during execution.
Can be used either as a decorator or a context manager.
Pass in an iterable of signals to intercept.
"""
def handler(self, signum, frame=None):
self.got.append({'signum': signum, 'frame': frame})
def __init__(self, signals):
self.signals = signals
self.handlers = {}
self.got = []
def __enter__(self):
for signum in self.signals:
# signal() returns the old signal handler
self.handlers[signum] = signal(signum, self.handler)
def __exit__(self, time, value, traceback):
for signum, handler in self.handlers.items():
signal(signum, handler)
for signum, frame in ((i['signum'], i['frame']) for i in self.got):
self.handlers[signum](signum, frame)
def __call__(self, fn):
@wraps(fn)
def call_fn(*args, **kwargs):
with self:
outval = fn(*args, **kwargs)
return outval
return call_fn
+9 -6
View File
@@ -174,7 +174,7 @@ def create_random_trade_source(sid, trade_count, trading_environment):
return source
def create_daily_trade_source(sids, trade_count, trading_environment):
def create_daily_trade_source(sids, trade_count, trading_environment, concurrent=False):
"""
creates trade_count trades for each sid in sids list.
@@ -189,11 +189,12 @@ def create_daily_trade_source(sids, trade_count, trading_environment):
sids,
trade_count,
timedelta(days=1),
trading_environment
trading_environment,
concurrent=concurrent
)
def create_minutely_trade_source(sids, trade_count, trading_environment):
def create_minutely_trade_source(sids, trade_count, trading_environment, concurrent=False):
"""
creates trade_count trades for each sid in sids list.
@@ -208,10 +209,11 @@ def create_minutely_trade_source(sids, trade_count, trading_environment):
sids,
trade_count,
timedelta(minutes=1),
trading_environment
trading_environment,
concurrent=concurrent
)
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment):
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment, concurrent=False):
args = tuple()
kwargs = {
@@ -219,7 +221,8 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ
'sids' : sids,
'start' : trading_environment.first_open,
'delta' : trade_time_increment,
'filter' : sids
'filter' : sids,
'concurrent' : concurrent
}
source = SpecificEquityTrades(*args, **kwargs)
+1
View File
@@ -89,6 +89,7 @@ class ZeroMQLogHandler(Handler):
def __init__(self, socket=None, level=NOTSET, filter=None, bubble=False,
context=None, fds = LOG_FIELDS, extra_fds = LOG_EXTRA_FIELDS):
Handler.__init__(self, level, filter, bubble)
try:
import zmq
except ImportError:
+10 -3
View File
@@ -15,6 +15,7 @@ def setup_logger(test, path='/var/log/zipline/zipline.log'):
def teardown_logger(test):
test.log_handler.pop_application()
test.log_handler.close()
def check_list(test, a, b, label):
test.assertTrue(isinstance(a, (list, blist.blist)))
@@ -91,11 +92,13 @@ def create_receiver(socket_addr, ctx):
return receiver
def drain_receiver(receiver):
def drain_receiver(receiver, count=None):
output = []
transaction_count = 0
msg_counter = 0
while True:
msg = receiver.recv()
msg_counter += 1
update = zp.BT_UPDATE_UNFRAME(msg)
output.append(update)
if update['prefix'] == 'PERF':
@@ -106,14 +109,18 @@ def drain_receiver(receiver):
elif update['prefix'] == 'DONE':
break
if count and msg_counter >= count:
break
receiver.close()
del receiver
return output, transaction_count
def assert_single_position(test, zipline):
output, transaction_count = drain_zipline(test, zipline)
def assert_single_position(test, zipline, blocking=False):
output, transaction_count = drain_zipline(test, zipline, p_blocking=blocking)
test.assertEqual(output[-1]['prefix'], 'DONE')
test.assertEqual(
test.zipline_test_config['order_count'],
+128
View File
@@ -0,0 +1,128 @@
import signal
from functools import wraps
from pprint import pprint as pp
from numbers import Number
from logbook import Logger
class Timeout(Exception):
def __init__(self, frame, message=''):
self.frame = frame
self.message = message
class timeout(object):
"""
Utility to make a function raise TimeoutException if it spends
more than a specified number of seconds executing. Can be used
as a decorator to apply a static timeout to a function, or as
a context manager to dynamically add a timeout to a code block.
"""
def __init__(self, seconds, message=''):
self.seconds = seconds
self.message = message
assert isinstance(seconds, Number), "Failed to specify a timeout."
assert seconds > 0, "Timeout must be greater than 0"
def handler(self, signum, frame):
raise Timeout(frame, self.message)
def __call__(self, fn):
@wraps(fn)
def call_fn_with_timeout(*args, **kwargs):
# Set the alarm.
signal.signal(signal.SIGALRM, self.handler)
signal.setitimer(signal.ITIMER_REAL, self.seconds, 0)
try:
outval = fn(*args, **kwargs)
# Deactivate the alarm once we're done so that the
# decorator doesn't have unexpected side-effects later.
# Note that this will still raise Timeout if the
# call to fn takes too long.
finally:
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, signal.SIG_DFL)
# Return the value of fn if it finished before the alarm. This
# won't execute if the Timeout was raised.
return outval
return call_fn_with_timeout
def __enter__(self):
# Set the alarm on entrance.
signal.signal(signal.SIGALRM, self.handler)
signal.setitimer(signal.ITIMER_REAL, self.seconds, 0)
def __exit__(self, type, value, traceback):
# Deactivate the alarm on exit. This will re-raise
# any exceptions raised inside the with block.
signal.signal(signal.SIGALRM, self.handler)
signal.setitimer(signal.ITIMER_REAL, 0, 0)
class heartbeat(object):
"""
Utility to perform pseudo-heartbeat checks on a single-threaded
function. Calls frame_handler on the current stack frame of the
wrapped function every ``interval`` seconds. After ``max_interval``
intervals, raises Timeout. Can be used either as a decorator or
a context manager.
"""
def __init__(self,
interval,
max_intervals,
frame_handler=None,
timeout_message=''):
self.interval = interval
self.max_intervals = max_intervals
self.frame_handler = frame_handler
self.timeout_message = timeout_message
self.count = 0
def handler(self, signum, frame):
self.count += 1
if self.frame_handler:
self.frame_handler(self.count, frame)
if self.count >= self.max_intervals:
raise Timeout(frame, self.timeout_message)
def __call__(self, fn):
@wraps(fn)
def call_fn_with_heartbeat(*args, **kwargs):
# Set a timer to call our handler every ``interval`` seconds.
signal.signal(signal.SIGALRM, self.handler)
signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval)
try:
outval = fn(*args, **kwargs)
finally:
# Deactivate the timer once we're done so that the
# decorator doesn't have unexpected side-effects later.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, signal.SIG_DFL)
self.count = 0
# Return the value of fn if it finished without tripping
# an exception. This won't execute if the Timeout or any
# other exception was raised by self.handle.
return outval
return call_fn_with_heartbeat
def __enter__(self):
# Set a timer to call our handler every N seconds.
self.count = 0
signal.signal(signal.SIGALRM, self.handler)
signal.setitimer(signal.ITIMER_REAL, self.interval, self.interval)
def __exit__(self, type, value, traceback):
# Turn off the timer on exit. This will re-raise any exception raised
# during execution of the with-block
self.count = 0
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, signal.SIG_DFL)