mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 16:47:49 +08:00
@@ -0,0 +1,381 @@
|
||||
import zmq
|
||||
import pytz
|
||||
from pprint import pformat as pf
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from unittest2 import TestCase
|
||||
from collections import defaultdict
|
||||
from zipline.gens.composites import date_sorted_sources, merged_transforms
|
||||
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.gens.transform import Passthrough, StatefulTransform
|
||||
from zipline.gens.mavg import MovingAverage
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
|
||||
from zipline.utils.test_utils import (
|
||||
setup_logger,
|
||||
teardown_logger,
|
||||
create_monitor,
|
||||
launch_monitor
|
||||
)
|
||||
|
||||
from zipline.core import Component
|
||||
from zipline.protocol import (
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME,
|
||||
MERGE_FRAME,
|
||||
MERGE_UNFRAME,
|
||||
SIMULATION_STYLE,
|
||||
PERF_FRAME,
|
||||
BT_UPDATE_UNFRAME
|
||||
)
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
|
||||
import logbook
|
||||
log = logbook.Logger('ComponentTestCase')
|
||||
|
||||
allocator = AddressAllocator(1000)
|
||||
|
||||
|
||||
class ComponentTestCase(TestCase):
|
||||
|
||||
leased_sockets = defaultdict(list)
|
||||
|
||||
def setUp(self):
|
||||
self.zipline_test_config = {
|
||||
'allocator' : allocator,
|
||||
'sid' : 133,
|
||||
'devel' : False,
|
||||
'results_socket' : allocator.lease(1)[0],
|
||||
'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
}
|
||||
self.ctx = zmq.Context()
|
||||
setup_logger(self)
|
||||
|
||||
count = 250
|
||||
filter = [2,3]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 2*count,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2002,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
self.source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : count,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2002,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
self.source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
self.environment = create_trading_environment(year = 2002)
|
||||
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
teardown_logger(self)
|
||||
|
||||
def test_source(self):
|
||||
monitor = create_monitor(allocator)
|
||||
socket_uri = allocator.lease(1)[0]
|
||||
count = 100
|
||||
|
||||
filter = [1,2,3,4]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0,tzinfo=pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
|
||||
trade_gen = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
|
||||
comp_a = Component(
|
||||
trade_gen,
|
||||
monitor,
|
||||
socket_uri,
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
"source_a"
|
||||
)
|
||||
|
||||
mon_proc = launch_monitor(monitor)
|
||||
|
||||
for event in comp_a:
|
||||
log.info(event)
|
||||
|
||||
# wait for the sending process to exit
|
||||
comp_a.proc.join()
|
||||
mon_proc.join()
|
||||
|
||||
def test_sort(self):
|
||||
monitor = create_monitor(allocator)
|
||||
socket_uris = allocator.lease(3)
|
||||
count = 100
|
||||
|
||||
filter = [1,2,3,4]
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0,tzinfo=pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
trade_gen_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
trade_gen_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'sids' : [3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter,
|
||||
'count' : count
|
||||
}
|
||||
|
||||
trade_gen_c = SpecificEquityTrades(*args_c, **kwargs_c)
|
||||
|
||||
|
||||
comp_a = Component(
|
||||
trade_gen_a,
|
||||
monitor,
|
||||
socket_uris[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
trade_gen_a.get_hash()
|
||||
)
|
||||
|
||||
comp_b = Component(
|
||||
trade_gen_b,
|
||||
monitor,
|
||||
socket_uris[1],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
trade_gen_b.get_hash()
|
||||
)
|
||||
|
||||
comp_c = Component(
|
||||
trade_gen_c,
|
||||
monitor,
|
||||
socket_uris[2],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
trade_gen_c.get_hash()
|
||||
)
|
||||
|
||||
sources = [comp_a, comp_b, comp_c]
|
||||
|
||||
sorted_out = date_sorted_sources(*sources)
|
||||
|
||||
mon_proc = launch_monitor(monitor)
|
||||
|
||||
prev = None
|
||||
sort_count = 0
|
||||
for msg in sorted_out:
|
||||
if prev:
|
||||
self.assertTrue(msg.dt >= prev.dt, \
|
||||
"Messages should be in date ascending order")
|
||||
prev = msg
|
||||
sort_count += 1
|
||||
|
||||
self.assertEqual(count*3, sort_count)
|
||||
|
||||
# wait for processes to finish
|
||||
comp_a.proc.join()
|
||||
comp_b.proc.join()
|
||||
comp_c.proc.join()
|
||||
mon_proc.join()
|
||||
|
||||
|
||||
def test_full(self):
|
||||
monitor = create_monitor(allocator)
|
||||
|
||||
# ------------------------
|
||||
# Run sources in dedicated processes
|
||||
comp_a = Component(
|
||||
self.source_a,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
self.source_a.get_hash()
|
||||
)
|
||||
|
||||
comp_b = Component(
|
||||
self.source_b,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
DATASOURCE_FRAME,
|
||||
DATASOURCE_UNFRAME,
|
||||
self.source_b.get_hash()
|
||||
)
|
||||
|
||||
# Date sort the sources, and run the sort in a dedicated
|
||||
# process
|
||||
sources = [comp_a, comp_b]
|
||||
|
||||
sorted_out = date_sorted_sources(*sources)
|
||||
|
||||
sorted = Component(
|
||||
sorted_out,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME,
|
||||
"sort"
|
||||
)
|
||||
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
['price'],
|
||||
market_aware = False,
|
||||
delta=timedelta(minutes = 20)
|
||||
)
|
||||
|
||||
merged_gen = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
merged = Component(
|
||||
merged_gen,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
MERGE_FRAME,
|
||||
MERGE_UNFRAME,
|
||||
"merge"
|
||||
)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
|
||||
style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
trading_client = tsc(algo, self.environment, style)
|
||||
tsc_gen = trading_client.simulate(merged)
|
||||
|
||||
tsc_comp = Component(
|
||||
tsc_gen,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
PERF_FRAME,
|
||||
BT_UPDATE_UNFRAME,
|
||||
"tsc"
|
||||
)
|
||||
mon_proc = launch_monitor(monitor)
|
||||
for message in tsc_comp:
|
||||
log.info(pf(message))
|
||||
|
||||
|
||||
# wait for processes to finish
|
||||
comp_a.proc.join()
|
||||
comp_b.proc.join()
|
||||
sorted.proc.join()
|
||||
merged.proc.join()
|
||||
tsc_comp.proc.join()
|
||||
mon_proc.join()
|
||||
return
|
||||
|
||||
def test_single_thread(self):
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
|
||||
sorted = date_sorted_sources(self.source_a, self.source_b)
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
['price'],
|
||||
market_aware=False,
|
||||
delta=timedelta(minutes = 20),
|
||||
)
|
||||
|
||||
merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
trading_client = tsc(algo, self.environment, style)
|
||||
for message in trading_client.simulate(merged):
|
||||
log.info(pf(message))
|
||||
|
||||
def test_compound(self):
|
||||
monitor = create_monitor(allocator)
|
||||
|
||||
sorted_out = date_sorted_sources(self.source_a, self.source_b)
|
||||
|
||||
sorted = Component(
|
||||
sorted_out,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
FEED_FRAME,
|
||||
FEED_UNFRAME,
|
||||
"feed"
|
||||
)
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
['price'],
|
||||
market_aware = False,
|
||||
delta=timedelta(minutes = 20)
|
||||
)
|
||||
|
||||
merged_gen = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
merged = Component(
|
||||
merged_gen,
|
||||
monitor,
|
||||
allocator.lease(1)[0],
|
||||
MERGE_FRAME,
|
||||
MERGE_UNFRAME,
|
||||
"merge"
|
||||
)
|
||||
|
||||
algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
trading_client = tsc(algo, self.environment, style)
|
||||
tsc_gen = trading_client.simulate(merged)
|
||||
|
||||
|
||||
mon_proc = launch_monitor(monitor)
|
||||
for message in tsc_gen:
|
||||
log.info(pf(message))
|
||||
|
||||
|
||||
# wait for processes to finish
|
||||
sorted.proc.join()
|
||||
merged.proc.join()
|
||||
mon_proc.join()
|
||||
return
|
||||
@@ -7,30 +7,30 @@ from zipline.test_algorithms import ExceptionAlgorithm, DivByZeroAlgorithm
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.lines import SimulatedTrading
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
|
||||
from zipline.utils.test_utils import \
|
||||
drain_zipline, \
|
||||
check, \
|
||||
setup_logger, \
|
||||
teardown_logger
|
||||
teardown_logger, \
|
||||
ExceptionSource, \
|
||||
ExceptionTransform
|
||||
|
||||
DEFAULT_TIMEOUT = 15 # seconds
|
||||
EXTENDED_TIMEOUT = 90
|
||||
|
||||
allocator = AddressAllocator(1000)
|
||||
|
||||
|
||||
class ExceptionTestCase(TestCase):
|
||||
|
||||
leased_sockets = defaultdict(list)
|
||||
|
||||
def setUp(self):
|
||||
self.zipline_test_config = {
|
||||
'allocator' : allocator,
|
||||
'sid' : 133,
|
||||
'devel' : False,
|
||||
'results_socket' : allocator.lease(1)[0],
|
||||
'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
'sid' : 133,
|
||||
'results_socket_uri' : allocator.lease(1)[0],
|
||||
'simulation_style' : SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
}
|
||||
self.ctx = zmq.Context()
|
||||
setup_logger(self)
|
||||
@@ -39,6 +39,39 @@ class ExceptionTestCase(TestCase):
|
||||
self.ctx.term()
|
||||
teardown_logger(self)
|
||||
|
||||
def test_datasource_exception(self):
|
||||
self.zipline_test_config['trade_source'] = ExceptionSource()
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
assert len(output) == 1
|
||||
assert output[0]['prefix'] == 'EXCEPTION'
|
||||
message = output[0]['payload']
|
||||
for field in ['date', 'message', 'name', 'stack']:
|
||||
assert field in message.keys()
|
||||
|
||||
assert message['message'] == 'integer division or modulo by zero'
|
||||
assert message['name'] == 'ZeroDivisionError'
|
||||
|
||||
def test_tranform_exception(self):
|
||||
exc_tnfm = StatefulTransform(ExceptionTransform)
|
||||
self.zipline_test_config['transforms'] = [exc_tnfm]
|
||||
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
assert len(output) == 1
|
||||
assert output[0]['prefix'] == 'EXCEPTION'
|
||||
message = output[0]['payload']
|
||||
for field in ['date', 'message', 'name', 'stack']:
|
||||
assert field in message.keys()
|
||||
|
||||
assert message['message'] == 'An assertion message'
|
||||
assert message['name'] == 'AssertionError'
|
||||
|
||||
|
||||
def test_exception_in_init(self):
|
||||
# Simulation
|
||||
# ----------
|
||||
@@ -52,19 +85,17 @@ class ExceptionTestCase(TestCase):
|
||||
**self.zipline_test_config
|
||||
)
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
self.assertEqual(len(output), 1)
|
||||
|
||||
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
|
||||
payload = output[-1]['payload']
|
||||
self.assertTrue(payload['date'])
|
||||
del payload['date']
|
||||
check(self, payload, INITIALIZE_TB)
|
||||
|
||||
self.assertTrue(zipline.sim.ready())
|
||||
self.assertFalse(zipline.sim.exception)
|
||||
|
||||
self.assertEqual(payload['message'],'Algo exception in initialize')
|
||||
self.assertEqual(payload['name'],'Exception')
|
||||
# make sure our path shortening is working
|
||||
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
|
||||
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
|
||||
|
||||
def test_exception_in_handle_data(self):
|
||||
|
||||
# Simulation
|
||||
# ----------
|
||||
self.zipline_test_config['algorithm'] = \
|
||||
@@ -78,16 +109,15 @@ class ExceptionTestCase(TestCase):
|
||||
)
|
||||
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
|
||||
self.assertEqual(len(output), 1)
|
||||
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
|
||||
payload = output[-1]['payload']
|
||||
self.assertTrue(payload['date'])
|
||||
del payload['date']
|
||||
check(self, payload, HANDLE_DATA_TB)
|
||||
|
||||
self.assertTrue(zipline.sim.ready())
|
||||
self.assertFalse(zipline.sim.exception)
|
||||
self.assertEqual(payload['message'],'Algo exception in handle_data')
|
||||
self.assertEqual(payload['name'],'Exception')
|
||||
# make sure our path shortening is working
|
||||
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
|
||||
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
|
||||
|
||||
def test_zerodivision_exception_in_handle_data(self):
|
||||
|
||||
@@ -103,95 +133,13 @@ class ExceptionTestCase(TestCase):
|
||||
)
|
||||
|
||||
output, _ = drain_zipline(self, zipline)
|
||||
self.assertEqual(len(output), 5)
|
||||
|
||||
self.assertEqual(output[-1]['prefix'], 'EXCEPTION')
|
||||
payload = output[-1]['payload']
|
||||
self.assertTrue(payload['date'])
|
||||
del payload['date']
|
||||
check(self, payload, ZERO_DIV_TB)
|
||||
|
||||
self.assertTrue(zipline.sim.ready())
|
||||
self.assertFalse(zipline.sim.exception)
|
||||
|
||||
# TODO:
|
||||
# - define more zipline failure modes: exception in other
|
||||
# components, exception in Monitor, etc. write tests
|
||||
# for those scenarios.
|
||||
|
||||
|
||||
|
||||
INITIALIZE_TB =\
|
||||
{'message': 'Algo exception in initialize',
|
||||
'name': 'Exception',
|
||||
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.initialize_algo()',
|
||||
'lineno': 97,
|
||||
'method': 'do_work'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.do_op(self.algorithm.initialize)',
|
||||
'lineno': 80,
|
||||
'method': 'initialize_algo'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'callable_op(*args, **kwargs)',
|
||||
'lineno': 210,
|
||||
'method': 'do_op'},
|
||||
{'filename': '/zipline/test_algorithms.py',
|
||||
'line': 'raise Exception("Algo exception in initialize")',
|
||||
'lineno': 166,
|
||||
'method': 'initialize'}]}
|
||||
|
||||
HANDLE_DATA_TB =\
|
||||
{
|
||||
'message': 'Algo exception in handle_data',
|
||||
'name': 'Exception',
|
||||
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.process_event(event)',
|
||||
'lineno': 116,
|
||||
'method': 'do_work'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.run_algorithm()',
|
||||
'lineno': 164,
|
||||
'method': 'process_event'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.do_op(self.algorithm.handle_data, data)',
|
||||
'lineno': 186,
|
||||
'method': 'run_algorithm'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'callable_op(*args, **kwargs)',
|
||||
'lineno': 210,
|
||||
'method': 'do_op'},
|
||||
{'filename': '/zipline/test_algorithms.py',
|
||||
'line': 'raise Exception("Algo exception in handle_data")',
|
||||
'lineno': 187,
|
||||
'method': 'handle_data'}]}
|
||||
|
||||
|
||||
ZERO_DIV_TB= \
|
||||
{'message': 'integer division or modulo by zero',
|
||||
'name': 'ZeroDivisionError',
|
||||
'stack': [{'filename': '/zipline/core/component.py', 'line': 'self._run()', 'lineno': 204, 'method': 'run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.loop()', 'lineno': 195, 'method': '_run'},
|
||||
{'filename': '/zipline/core/component.py', 'line': 'self.do_work()', 'lineno': 235, 'method': 'loop'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.process_event(event)',
|
||||
'lineno': 116,
|
||||
'method': 'do_work'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.run_algorithm()',
|
||||
'lineno': 164,
|
||||
'method': 'process_event'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'self.do_op(self.algorithm.handle_data, data)',
|
||||
'lineno': 186,
|
||||
'method': 'run_algorithm'},
|
||||
{'filename': '/zipline/components/tradesimulation.py',
|
||||
'line': 'callable_op(*args, **kwargs)',
|
||||
'lineno': 210,
|
||||
'method': 'do_op'},
|
||||
{'filename': '/zipline/test_algorithms.py', 'line': '5/0', 'lineno': 218, 'method': 'handle_data'}]}
|
||||
self.assertEqual(payload['message'],'integer division or modulo by zero')
|
||||
self.assertEqual(payload['name'],'ZeroDivisionError')
|
||||
# make sure our path shortening is working
|
||||
self.assertEqual(payload['stack'][0]['filename'], '/zipline/lines.py')
|
||||
self.assertEqual(payload['stack'][-1]['filename'], '/zipline/test_algorithms.py')
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
from unittest2 import TestCase
|
||||
from itertools import cycle, chain
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.sort import \
|
||||
date_sort, \
|
||||
ready, \
|
||||
done, \
|
||||
queue_is_ready,\
|
||||
queue_is_done
|
||||
from zipline.gens.utils import hash_args, alternate
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
class HelperTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_individual_queue_logic(self):
|
||||
queue = deque()
|
||||
# Empty queues are neither done nor ready.
|
||||
assert not queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
queue.append(to_dt('foo'))
|
||||
assert queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
queue.appendleft(to_dt('DONE'))
|
||||
assert queue_is_ready(queue)
|
||||
|
||||
# Checking done when we have a message after done will trip an assert.
|
||||
self.assertRaises(AssertionError, queue_is_done, queue)
|
||||
|
||||
queue.pop()
|
||||
assert queue_is_ready(queue)
|
||||
assert queue_is_done(queue)
|
||||
|
||||
def test_pop_logic(self):
|
||||
sources = {}
|
||||
ids = ['a', 'b', 'c']
|
||||
for id in ids:
|
||||
sources[id] = deque()
|
||||
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
# All sources must have a message to be ready/done
|
||||
sources['a'].append(to_dt("datetime"))
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
sources['a'].pop()
|
||||
|
||||
for id in ids:
|
||||
sources[id].append(to_dt("datetime"))
|
||||
|
||||
assert ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].appendleft(to_dt("DONE"))
|
||||
|
||||
# ["DONE", message] will trip an assert in queue_is_done.
|
||||
assert ready(sources)
|
||||
self.assertRaises(AssertionError, done, sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].pop()
|
||||
|
||||
assert ready(sources)
|
||||
assert done(sources)
|
||||
|
||||
class DateSortTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def run_date_sort(self, events, expected, source_ids):
|
||||
"""
|
||||
Take a list of events, their source_ids, and an expected sorting.
|
||||
Assert that date_sort's output agrees with expected.
|
||||
"""
|
||||
sort_gen = date_sort(events, source_ids)
|
||||
l = list(sort_gen)
|
||||
assert l == expected
|
||||
|
||||
def test_single_source(self):
|
||||
source_ids = ['a']
|
||||
# 100 events, increasing by a minute at a time.
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
dates = list(date_gen(count = 100))
|
||||
dates.append("DONE")
|
||||
|
||||
# [('a', date1, type), ('a', date2, type), ... ('a', "DONE", type)]
|
||||
event_args = zip(cycle(source_ids), iter(dates), cycle([type]))
|
||||
|
||||
# Turn event_args into proper events.
|
||||
events = [mock_data_unframe(*args) for args in event_args]
|
||||
|
||||
# We don't expected Feed to yield the last event.
|
||||
expected = events[:-1]
|
||||
|
||||
event_gen = (e for e in events)
|
||||
|
||||
self.run_date_sort(event_gen, expected, source_ids)
|
||||
|
||||
def test_multi_source(self):
|
||||
source_ids = ['a', 'b']
|
||||
type = zp.DATASOURCE_TYPE.TRADE
|
||||
|
||||
# Set up source 'a'. Outputs 20 events with 2 minute deltas.
|
||||
delta_a = timedelta(minutes = 2)
|
||||
dates_a = list(date_gen(delta = delta_a, count = 20))
|
||||
dates_a.append("DONE")
|
||||
|
||||
events_a_args = zip(cycle(['a']), iter(dates_a), cycle([type]))
|
||||
events_a = [mock_data_unframe(*args) for args in events_a_args]
|
||||
|
||||
# Set up source 'b'. Outputs 10 events with 1 minute deltas.
|
||||
delta_b = timedelta(minutes = 1)
|
||||
dates_b = list(date_gen(delta = delta_b, count = 10))
|
||||
dates_b.append("DONE")
|
||||
|
||||
events_b_args = zip(cycle(['b']), iter(dates_b), cycle([type]))
|
||||
events_b = [mock_data_unframe(*args) for args in events_b_args]
|
||||
|
||||
# The expected output is all non-DONE events in both a and b,
|
||||
# sorted first by dt and then by source_id.
|
||||
non_dones = events_a[:-1] + events_b[:-1]
|
||||
expected = sorted(non_dones, compare_by_dt_source_id)
|
||||
|
||||
# Alternating between a and b.
|
||||
interleaved = alternate(iter(events_a), iter(events_b))
|
||||
self.run_date_sort(interleaved, expected, source_ids)
|
||||
|
||||
# All of a, then all of b.
|
||||
|
||||
sequential = chain(iter(events_a), iter(events_b))
|
||||
self.run_date_sort(sequential, expected, source_ids)
|
||||
|
||||
def test_sorted_sources(self):
|
||||
|
||||
filter = [1,2]
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
# Set up source d. This should produce no events because the
|
||||
# internal sids don't match the filter.
|
||||
args_d = tuple()
|
||||
kwargs_d = {'sids' : [3,4],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
|
||||
sources = (SpecificEquityTrades,) * 4
|
||||
source_args = (args_a, args_b, args_c, args_d)
|
||||
source_kwargs = (kwargs_a, kwargs_b, kwargs_c, kwargs_d)
|
||||
|
||||
# Generate our expected source_ids.
|
||||
zip_args = zip(source_args, source_kwargs)
|
||||
expected_ids = ["SpecificEquityTrades" + hash_args(*args, **kwargs)
|
||||
for args, kwargs in zip_args]
|
||||
|
||||
# Pipe our sources into sort.
|
||||
sort_out = date_sorted_sources(sources, source_args, source_kwargs)
|
||||
|
||||
# Read all the values from sort and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(sort_out)
|
||||
copy = to_list[:]
|
||||
for e in to_list:
|
||||
# All events should match one of our expected source_ids.
|
||||
assert e.source_id in expected_ids
|
||||
# But none of them should match source_d.
|
||||
assert e.source_id != hash_args(*args_d, **kwargs_d)
|
||||
|
||||
expected = sorted(copy, compare_by_dt_source_id)
|
||||
assert to_list == expected
|
||||
|
||||
def mock_data_unframe(source_id, dt, type):
|
||||
event = ndict()
|
||||
event.source_id = source_id
|
||||
event.dt = dt
|
||||
event.type = type
|
||||
return event
|
||||
|
||||
def to_dt(val):
|
||||
return ndict({'dt': val})
|
||||
|
||||
def compare_by_dt_source_id(x,y):
|
||||
if x.dt < y.dt:
|
||||
return -1
|
||||
elif x.dt > y.dt:
|
||||
return 1
|
||||
|
||||
elif x.source_id < y.source_id:
|
||||
return -1
|
||||
elif x.source_id > y.source_id:
|
||||
return 1
|
||||
|
||||
else:
|
||||
return 0
|
||||
+8
-17
@@ -11,7 +11,6 @@ from collections import defaultdict
|
||||
from nose.tools import timed
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
@@ -19,10 +18,9 @@ from zipline.core.devsimulator import AddressAllocator
|
||||
from zipline.lines import SimulatedTrading
|
||||
from zipline.finance.performance import PerformanceTracker
|
||||
from zipline.utils.protocol_utils import ndict
|
||||
from zipline.finance.trading import TransactionSimulator, SIMULATION_STYLE
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
from zipline.utils.test_utils import \
|
||||
drain_zipline, \
|
||||
check, \
|
||||
setup_logger, \
|
||||
teardown_logger,\
|
||||
assert_single_position
|
||||
@@ -39,10 +37,8 @@ class FinanceTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.zipline_test_config = {
|
||||
'allocator' : allocator,
|
||||
'sid' : 133,
|
||||
#'devel' : True,
|
||||
'results_socket' : allocator.lease(1)[0]
|
||||
'sid' : 133,
|
||||
'results_socket_uri' : allocator.lease(1)[0]
|
||||
}
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
@@ -60,7 +56,7 @@ class FinanceTestCase(TestCase):
|
||||
trading_environment
|
||||
)
|
||||
prev = None
|
||||
for trade in trade_source.event_list:
|
||||
for trade in trade_source:
|
||||
if prev:
|
||||
self.assertTrue(trade.dt > prev.dt)
|
||||
prev = trade
|
||||
@@ -123,7 +119,6 @@ class FinanceTestCase(TestCase):
|
||||
self.zipline_test_config['order_count'] = 100
|
||||
self.zipline_test_config['trade_count'] = 200
|
||||
zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config)
|
||||
|
||||
assert_single_position(self, zipline)
|
||||
|
||||
#@timed(DEFAULT_TIMEOUT)
|
||||
@@ -148,9 +143,6 @@ class FinanceTestCase(TestCase):
|
||||
)
|
||||
output, transaction_count = drain_zipline(self, zipline)
|
||||
|
||||
self.assertTrue(zipline.sim.ready())
|
||||
self.assertFalse(zipline.sim.exception)
|
||||
|
||||
#check that the algorithm received no events
|
||||
self.assertEqual(
|
||||
0,
|
||||
@@ -301,12 +293,12 @@ class FinanceTestCase(TestCase):
|
||||
# if present, expect transaction amounts to match orders exactly.
|
||||
complete_fill = params.get('complete_fill')
|
||||
|
||||
sid = 1
|
||||
trading_environment = factory.create_trading_environment()
|
||||
trade_sim = TransactionSimulator()
|
||||
trade_sim = TransactionSimulator([sid])
|
||||
price = [10.1] * trade_count
|
||||
volume = [100] * trade_count
|
||||
start_date = trading_environment.first_open
|
||||
sid = 1
|
||||
|
||||
generated_trades = factory.create_trade_history(
|
||||
sid,
|
||||
@@ -330,7 +322,7 @@ class FinanceTestCase(TestCase):
|
||||
'dt' : order_date
|
||||
})
|
||||
|
||||
trade_sim.add_open_order(order)
|
||||
trade_sim.place_order(order)
|
||||
|
||||
order_date = order_date + order_interval
|
||||
# move after market orders to just after market next
|
||||
@@ -353,14 +345,13 @@ class FinanceTestCase(TestCase):
|
||||
self.assertEqual(order.amount, order_amount * alternator**i)
|
||||
|
||||
|
||||
tracker = PerformanceTracker(trading_environment)
|
||||
tracker = PerformanceTracker(trading_environment, [sid])
|
||||
|
||||
# this approximates the loop inside TradingSimulationClient
|
||||
transactions = []
|
||||
for trade in generated_trades:
|
||||
if trade_delay:
|
||||
trade.dt = trade.dt + trade_delay
|
||||
|
||||
txn = trade_sim.apply_trade_to_open_orders(trade)
|
||||
if txn:
|
||||
transactions.append(txn)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from zipline.utils.test_utils import setup_logger, teardown_logger
|
||||
from unittest2 import TestCase, skip
|
||||
|
||||
from zipline.core.monitor import Controller
|
||||
from zipline.core.monitor import Monitor
|
||||
|
||||
class TestMonitor(TestCase):
|
||||
def setUp(self):
|
||||
@@ -14,13 +14,15 @@ class TestMonitor(TestCase):
|
||||
def test_init(self):
|
||||
pub_socket = 'tcp://127.0.0.1:5000'
|
||||
route_socket = 'tcp://127.0.0.1:5001'
|
||||
exception_socket = 'tcp://127.0.0.1:5002'
|
||||
|
||||
con = Controller(pub_socket, route_socket)
|
||||
con.manage([])
|
||||
mon = Monitor(pub_socket, route_socket, exception_socket)
|
||||
mon.manage([])
|
||||
|
||||
def test_init_topology(self):
|
||||
pub_socket = 'tcp://127.0.0.1:5000'
|
||||
route_socket = 'tcp://127.0.0.1:5001'
|
||||
exception_socket = 'tcp://127.0.0.1:5002'
|
||||
|
||||
con = Controller(pub_socket, route_socket, )
|
||||
con.manage([ 'a', 'b', 'c', 'd' ])
|
||||
mon = Monitor(pub_socket, route_socket, exception_socket)
|
||||
mon.manage([ 'a', 'b', 'c', 'd' ])
|
||||
|
||||
@@ -543,7 +543,10 @@ shares in position"
|
||||
self.trading_environment.capital_base = 1000.0
|
||||
self.trading_environment.frame_index = ['sid', 'volume', 'dt', \
|
||||
'price', 'changed']
|
||||
perf_tracker = perf.PerformanceTracker(self.trading_environment)
|
||||
perf_tracker = perf.PerformanceTracker(
|
||||
self.trading_environment,
|
||||
[sid, sid2]
|
||||
)
|
||||
|
||||
for event in trade_history:
|
||||
#create a transaction for all but
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""
|
||||
Test the FRAME/UNFRAME functions in the sequence expected from ziplines.
|
||||
"""
|
||||
import pytz
|
||||
|
||||
from unittest2 import TestCase
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
@@ -10,10 +8,8 @@ from collections import defaultdict
|
||||
from nose.tools import timed
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
from zipline.utils import logger
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.finance.sources import SpecificEquityTrades
|
||||
|
||||
DEFAULT_TIMEOUT = 5 # seconds
|
||||
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
import pytz
|
||||
|
||||
from unittest2 import TestCase
|
||||
from itertools import cycle, chain, izip, izip_longest
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.sort import \
|
||||
date_sort, \
|
||||
ready, \
|
||||
done, \
|
||||
queue_is_ready,\
|
||||
queue_is_done
|
||||
from zipline.gens.utils import hash_args, alternate, done_message
|
||||
from zipline.gens.tradegens import date_gen, SpecificEquityTrades
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
class HelperTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_individual_queue_logic(self):
|
||||
queue = deque()
|
||||
# Empty queues are neither done nor ready.
|
||||
assert not queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
queue.append(to_dt('foo'))
|
||||
assert queue_is_ready(queue)
|
||||
assert not queue_is_done(queue)
|
||||
|
||||
|
||||
queue.appendleft(to_dt('DONE'))
|
||||
assert queue_is_ready(queue)
|
||||
|
||||
# Checking done when we have a message after done will trip an assert.
|
||||
self.assertRaises(AssertionError, queue_is_done, queue)
|
||||
|
||||
queue.pop()
|
||||
assert queue_is_ready(queue)
|
||||
assert queue_is_done(queue)
|
||||
|
||||
def test_pop_logic(self):
|
||||
sources = {}
|
||||
ids = ['a', 'b', 'c']
|
||||
for id in ids:
|
||||
sources[id] = deque()
|
||||
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
# All sources must have a message to be ready/done
|
||||
sources['a'].append(to_dt("datetime"))
|
||||
assert not ready(sources)
|
||||
assert not done(sources)
|
||||
sources['a'].pop()
|
||||
|
||||
for id in ids:
|
||||
sources[id].append(to_dt("datetime"))
|
||||
|
||||
assert ready(sources)
|
||||
assert not done(sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].appendleft(to_dt("DONE"))
|
||||
|
||||
# ["DONE", message] will trip an assert in queue_is_done.
|
||||
assert ready(sources)
|
||||
self.assertRaises(AssertionError, done, sources)
|
||||
|
||||
for id in ids:
|
||||
sources[id].pop()
|
||||
|
||||
assert ready(sources)
|
||||
assert done(sources)
|
||||
|
||||
class DateSortTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def run_date_sort(self, event_stream, expected, source_ids):
|
||||
"""
|
||||
Take a list of events, their source_ids, and an expected sorting.
|
||||
Assert that date_sort's output agrees with expected.
|
||||
"""
|
||||
sort_out = date_sort(event_stream, source_ids)
|
||||
for m1, m2 in izip_longest(sort_out, expected):
|
||||
assert m1 == m2
|
||||
|
||||
def test_single_source(self):
|
||||
|
||||
# Just using the built-in defaults. See
|
||||
# zipline/gens/tradegens.py
|
||||
source = SpecificEquityTrades()
|
||||
expected = list(source)
|
||||
source.rewind()
|
||||
# The raw source doesn't handle done messaging, so we need to
|
||||
# append a done message for sort to work properly.
|
||||
with_done = chain(source, [done_message(source.get_hash())])
|
||||
self.run_date_sort(with_done, expected, [source.get_hash()])
|
||||
|
||||
def test_multi_source(self):
|
||||
|
||||
filter = [2,3]
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 100,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 100,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
all_events = list(chain(source_a, source_b))
|
||||
|
||||
# The expected output is all events, sorted by dt with
|
||||
# source_id as a tiebreaker.
|
||||
expected = sorted(all_events, comp)
|
||||
source_ids = [source_a.get_hash(), source_b.get_hash()]
|
||||
|
||||
# Generating the events list consumes the sources. Rewind them
|
||||
# for testing.
|
||||
source_a.rewind()
|
||||
source_b.rewind()
|
||||
|
||||
# Append a done message to each source.
|
||||
with_done_a = chain(source_a, [done_message(source_a.get_hash())])
|
||||
with_done_b = chain(source_b, [done_message(source_b.get_hash())])
|
||||
|
||||
interleaved = alternate(with_done_a, with_done_b)
|
||||
|
||||
# Test sort with alternating messages from source_a and
|
||||
# source_b.
|
||||
self.run_date_sort(interleaved, expected, source_ids)
|
||||
|
||||
source_a.rewind()
|
||||
source_b.rewind()
|
||||
with_done_a = chain(source_a, [done_message(source_a.get_hash())])
|
||||
with_done_b = chain(source_b, [done_message(source_b.get_hash())])
|
||||
|
||||
sequential = chain(with_done_a, with_done_b)
|
||||
|
||||
# Test sort with all messages from a, followed by all messages
|
||||
# from b.
|
||||
|
||||
self.run_date_sort(sequential, expected, source_ids)
|
||||
|
||||
|
||||
def test_sort_composite(self):
|
||||
|
||||
filter = [1,2]
|
||||
|
||||
#Set up source a. One hour between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 100,
|
||||
'sids' : [1],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. One day between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 50,
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(days = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
#Set up source c. One minute between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'count' : 150,
|
||||
'sids' : [1,2],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_c = SpecificEquityTrades(*args_c, **kwargs_c)
|
||||
# Set up source d. This should produce no events because the
|
||||
# internal sids don't match the filter.
|
||||
args_d = tuple()
|
||||
kwargs_d = {
|
||||
'count' : 50,
|
||||
'sids' : [3],
|
||||
'start' : datetime(2012,6,6,0),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
source_d = SpecificEquityTrades(*args_d, **kwargs_d)
|
||||
sources = [source_a, source_b, source_c, source_d]
|
||||
hashes = [source.get_hash() for source in sources]
|
||||
|
||||
sort_out = date_sorted_sources(*sources)
|
||||
|
||||
# Read all the values from sort and assert that they arrive in
|
||||
# the correct sorting with the expected hash values.
|
||||
to_list = list(sort_out)
|
||||
copy = to_list[:]
|
||||
|
||||
# We should have 300 events (100 from a, 150 from b, 50 from c)
|
||||
assert len(to_list) == 300
|
||||
|
||||
for e in to_list:
|
||||
# All events should match one of our expected source_ids.
|
||||
assert e.source_id in hashes
|
||||
# But none of them should match source_d.
|
||||
assert e.source_id != source_d.get_hash()
|
||||
|
||||
# The events should be sorted by dt, with source_id as tiebreaker.
|
||||
expected = sorted(copy, comp)
|
||||
|
||||
assert to_list == expected
|
||||
|
||||
def compare_by_dt_source_id(x,y):
|
||||
if x.dt < y.dt:
|
||||
return -1
|
||||
elif x.dt > y.dt:
|
||||
return 1
|
||||
|
||||
elif x.source_id < y.source_id:
|
||||
return -1
|
||||
elif x.source_id > y.source_id:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
#Alias for ease of use
|
||||
comp = compare_by_dt_source_id
|
||||
|
||||
def to_dt(msg):
|
||||
return ndict({'dt': msg})
|
||||
+221
-57
@@ -1,102 +1,266 @@
|
||||
from datetime import timedelta
|
||||
import pytz
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from collections import defaultdict
|
||||
from unittest2 import TestCase
|
||||
|
||||
from zipline import ndict
|
||||
|
||||
from zipline.lines import SimulatedTrading
|
||||
|
||||
from zipline.utils.test_utils import setup_logger, teardown_logger
|
||||
from zipline.utils.date_utils import utcnow
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import StatefulTransform, EventWindow
|
||||
from zipline.gens.vwap import VWAP
|
||||
from zipline.gens.mavg import MovingAverage
|
||||
from zipline.gens.returns import Returns
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
from zipline.finance.vwap import DailyVWAP, VWAPTransform
|
||||
from zipline.finance.returns import ReturnsFromPriorClose
|
||||
from zipline.finance.movingaverage import MovingAverage
|
||||
from zipline.lines import SimulatedTrading
|
||||
from zipline.core.devsimulator import AddressAllocator
|
||||
|
||||
allocator = AddressAllocator(1000)
|
||||
def to_dt(msg):
|
||||
return ndict({'dt': msg})
|
||||
|
||||
class ZiplineWithTransformsTestCase(TestCase):
|
||||
leased_sockets = defaultdict(list)
|
||||
class NoopEventWindow(EventWindow):
|
||||
"""
|
||||
A no-op EventWindow subclass for testing the base EventWindow logic.
|
||||
Keeps lists of all added and dropped events.
|
||||
"""
|
||||
def __init__(self, market_aware, days, delta):
|
||||
EventWindow.__init__(self, market_aware, days, delta)
|
||||
|
||||
self.added = []
|
||||
self.removed = []
|
||||
|
||||
def handle_add(self, event):
|
||||
self.added.append(event)
|
||||
|
||||
def handle_remove(self, event):
|
||||
self.removed.append(event)
|
||||
|
||||
class EventWindowTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# skip ahead 100 spots
|
||||
allocator.lease(100)
|
||||
self.trading_environment = factory.create_trading_environment()
|
||||
self.zipline_test_config = {
|
||||
'allocator' : allocator,
|
||||
'sid' : 133,
|
||||
'devel' : True
|
||||
}
|
||||
setup_logger(self, '/var/log/qexec/qexed.log')
|
||||
setup_logger(self)
|
||||
|
||||
# Constants calling before open, during the day, and after
|
||||
# close on a valid trading day.
|
||||
self.pre_open = datetime(2012, 8, 7, 13, tzinfo = pytz.utc)
|
||||
self.mid_day = datetime(2012, 8, 7, 15, tzinfo = pytz.utc)
|
||||
self.post_close = datetime(2012, 8, 7, 22, tzinfo = pytz.utc)
|
||||
|
||||
def tearDown(self):
|
||||
teardown_logger(self)
|
||||
# Constants calling before open, during the day, and after
|
||||
# close on a saturday.
|
||||
self.pre_open_saturday = datetime(2012, 8, 11, 13, tzinfo = pytz.utc)
|
||||
self.mid_day_saturday = datetime(2012, 8, 11, 15, tzinfo = pytz.utc)
|
||||
self.post_close_saturday = datetime(2012, 8, 11, 22, tzinfo = pytz.utc)
|
||||
|
||||
def test_vwap_tnfm(self):
|
||||
zipline = SimulatedTrading.create_test_zipline(
|
||||
**self.zipline_test_config
|
||||
# Constants calling before open, during the day, and after
|
||||
# close on a holiday.
|
||||
self.pre_open_holiday = datetime(2012, 12, 25, 13, tzinfo = pytz.utc)
|
||||
self.mid_day_holiday = datetime(2012, 12, 25, tzinfo = pytz.utc)
|
||||
self.post_close_holiday = datetime(2012, 12, 25, 22, tzinfo = pytz.utc)
|
||||
|
||||
def test_event_window_with_timedelta(self):
|
||||
|
||||
# Keep all events within a 5 minute window.
|
||||
window = NoopEventWindow(
|
||||
market_aware = False,
|
||||
delta = timedelta(minutes = 5),
|
||||
days = None
|
||||
)
|
||||
vwap = VWAPTransform("vwap_10", daycount=10)
|
||||
zipline.add_transform(vwap)
|
||||
now = utcnow()
|
||||
|
||||
zipline.simulate(blocking=True)
|
||||
# 15 dates, increasing in 1 minute increments.
|
||||
dates = [now + i * timedelta(minutes = 1)
|
||||
for i in xrange(15)]
|
||||
|
||||
self.assertTrue(zipline.sim.ready())
|
||||
self.assertFalse(zipline.sim.exception)
|
||||
# Turn the dates into the format required by EventWindow.
|
||||
dt_messages = [to_dt(date) for date in dates]
|
||||
|
||||
# Run all messages through the window and assert that we're adding
|
||||
# and removing messages appropriately. We start the enumeration at 1
|
||||
# for convenience.
|
||||
for num, message in enumerate(dt_messages, 1):
|
||||
window.update(message)
|
||||
|
||||
# Assert that we've added the correct number of events.
|
||||
assert len(window.added) == num
|
||||
|
||||
# Assert that we removed only events that fall outside (or
|
||||
# on the boundary of) the delta.
|
||||
for dropped in window.removed:
|
||||
assert message.dt - dropped.dt >= timedelta(minutes = 5)
|
||||
|
||||
def test_market_aware_window(self):
|
||||
window = NoopEventWindow(
|
||||
market_aware = True,
|
||||
delta = None,
|
||||
days = 1
|
||||
)
|
||||
dates = ([self.pre_open]*3)
|
||||
dates += ([self.mid_day]*3)
|
||||
dates += ([self.post_close]*3)
|
||||
dates += [self.pre_open + timedelta(days = 1, seconds = 1)]
|
||||
events = [to_dt(date) for date in dates]
|
||||
|
||||
# Run the events.
|
||||
for event in events:
|
||||
window.update(event)
|
||||
|
||||
# We should have removed the pre_open events on the first day.
|
||||
# The rest should be intact.
|
||||
|
||||
assert window.added == events
|
||||
assert window.removed == events[0:3]
|
||||
assert list(window.ticks) == events[3:]
|
||||
|
||||
def test_market_aware_window_weekend(self):
|
||||
window = NoopEventWindow(
|
||||
market_aware = True,
|
||||
delta = None,
|
||||
days = 2
|
||||
)
|
||||
dates = [self.pre_open_saturday - timedelta(days = 1, seconds=1)]
|
||||
dates += [self.mid_day_saturday - timedelta(days = 1, seconds=1)]
|
||||
dates += [self.post_close_saturday - timedelta(days = 1, seconds=1)]
|
||||
dates += [self.mid_day_saturday + timedelta(days = 1)]
|
||||
|
||||
events = [to_dt(date) for date in dates]
|
||||
|
||||
# Run the events.
|
||||
for event in events:
|
||||
window.update(event)
|
||||
|
||||
# We shouldn't remove any events.
|
||||
assert window.added == events
|
||||
assert window.removed == []
|
||||
assert list(window.ticks) == events
|
||||
|
||||
extra = to_dt(self.mid_day_saturday + timedelta(days = 2))
|
||||
window.update(extra)
|
||||
|
||||
# We should remove only the first event.
|
||||
assert window.removed == [events[0]]
|
||||
assert list(window.ticks) == events[1:] + [extra]
|
||||
|
||||
def tearDown(self):
|
||||
setup_logger(self)
|
||||
|
||||
class FinanceTransformsTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.trading_environment = factory.create_trading_environment()
|
||||
setup_logger(self, '/var/log/qexec/qexec.log')
|
||||
setup_logger(self)
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 10.0, 11.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(days=1),
|
||||
self.trading_environment
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
def tearDown(self):
|
||||
self.log_handler.pop_application()
|
||||
|
||||
def test_vwap(self):
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 10.0, 10.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(days=1),
|
||||
self.trading_environment
|
||||
vwap = StatefulTransform(
|
||||
VWAP,
|
||||
market_aware = False,
|
||||
delta = timedelta(days = 2)
|
||||
)
|
||||
transformed = list(vwap.transform(self.source))
|
||||
|
||||
vwap = DailyVWAP(days=2)
|
||||
for trade in trade_history:
|
||||
vwap.update(trade)
|
||||
|
||||
self.assertEqual(vwap.vwap, 10.75)
|
||||
# Output values
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
# "Hand calculated" values.
|
||||
expected = [
|
||||
(10.0 * 100) / 100.0,
|
||||
((10.0 * 100) + (10.0 * 100)) / (200.0),
|
||||
# We should drop the first event here.
|
||||
((10.0 * 100) + (11.0 * 100)) / (200.0),
|
||||
# We should drop the second event here.
|
||||
((11.0 * 100) + (11.0 * 300)) / (400.0)
|
||||
]
|
||||
|
||||
# Output should match the expected.
|
||||
assert tnfm_vals == expected
|
||||
|
||||
def test_returns(self):
|
||||
# Daily returns.
|
||||
returns = StatefulTransform(Returns, 1)
|
||||
|
||||
transformed = list(returns.transform(self.source))
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
|
||||
# No returns for the first event because we don't have a
|
||||
# previous close.
|
||||
expected = [0.0, 0.0, 0.1, 0.0]
|
||||
|
||||
assert tnfm_vals == expected
|
||||
|
||||
# Two-day returns. An extra kink here is that the
|
||||
# factory will automatically skip a weekend for the
|
||||
# last event. Results shouldn't notice this blip.
|
||||
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 10.0, 10.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
[10.0, 15.0, 13.0, 12.0, 13.0],
|
||||
[100, 100, 100, 300, 100],
|
||||
timedelta(days=1),
|
||||
self.trading_environment
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
returns = ReturnsFromPriorClose()
|
||||
for trade in trade_history:
|
||||
returns.update(trade)
|
||||
returns = StatefulTransform(Returns, 2)
|
||||
|
||||
transformed = list(returns.transform(self.source))
|
||||
tnfm_vals = [message.tnfm_value for message in transformed]
|
||||
|
||||
self.assertEqual(returns.returns, .1)
|
||||
expected = [
|
||||
0.0,
|
||||
0.0,
|
||||
(13.0 - 10.0) / 10.0,
|
||||
(12.0 - 15.0) / 15.0,
|
||||
(13.0 - 13.0) / 13.0
|
||||
]
|
||||
|
||||
assert tnfm_vals == expected
|
||||
|
||||
def test_moving_average(self):
|
||||
trade_history = factory.create_trade_history(
|
||||
133,
|
||||
[10.0, 10.0, 10.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(days=1),
|
||||
self.trading_environment
|
||||
|
||||
mavg = StatefulTransform(
|
||||
MovingAverage,
|
||||
market_aware = False,
|
||||
fields = ['price', 'volume'],
|
||||
delta = timedelta(days = 2),
|
||||
)
|
||||
transformed = list(mavg.transform(self.source))
|
||||
# Output values.
|
||||
tnfm_prices = [message.tnfm_value.price for message in transformed]
|
||||
tnfm_volumes = [message.tnfm_value.volume for message in transformed]
|
||||
|
||||
ma = MovingAverage(days=2)
|
||||
for trade in trade_history:
|
||||
ma.update(trade)
|
||||
# "Hand-calculated" values
|
||||
expected_prices = [
|
||||
((10.0) / 1.0),
|
||||
((10.0 + 10.0) / 2.0),
|
||||
# First event should get dropped here.
|
||||
((10.0 + 11.0) / 2.0),
|
||||
# Second event should get dropped here.
|
||||
((11.0 + 11.0) / 2.0)
|
||||
]
|
||||
expected_volumes = [
|
||||
((100.0) / 1.0),
|
||||
((100.0 + 100.0) / 2.0),
|
||||
# First event should get dropped here.
|
||||
((100.0 + 100.0) / 2.0),
|
||||
# Second event should get dropped here.
|
||||
((100.0 + 300.0) / 2.0)
|
||||
]
|
||||
|
||||
|
||||
self.assertEqual(ma.average, 10.5)
|
||||
assert tnfm_prices == expected_prices
|
||||
assert tnfm_volumes == expected_volumes
|
||||
|
||||
@@ -6,15 +6,9 @@ Zipline
|
||||
# it is a place to expose the public interfaces.
|
||||
|
||||
import protocol # namespace
|
||||
from core.monitor import Controller
|
||||
from lines import SimulatedTrading
|
||||
from core.host import ComponentHost
|
||||
from utils.protocol_utils import ndict
|
||||
|
||||
__all__ = [
|
||||
SimulatedTrading,
|
||||
Controller,
|
||||
ComponentHost,
|
||||
protocol,
|
||||
ndict
|
||||
]
|
||||
|
||||
@@ -12,7 +12,7 @@ from zipline.utils.protocol_utils import ndict
|
||||
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
|
||||
|
||||
from logbook import Logger, NestedSetup, Processor, queues
|
||||
from logbook import Logger, NestedSetup, Processor
|
||||
|
||||
|
||||
log = logbook.Logger('TradeSimulation')
|
||||
@@ -20,7 +20,7 @@ log = logbook.Logger('TradeSimulation')
|
||||
|
||||
class TradeSimulationClient(Component):
|
||||
|
||||
def init(self, trading_environment, sim_style, results_socket):
|
||||
def init(self, trading_environment, sim_style, results_socket, algorithm):
|
||||
self.received_count = 0
|
||||
self.prev_dt = None
|
||||
self.event_queue = None
|
||||
@@ -29,13 +29,20 @@ class TradeSimulationClient(Component):
|
||||
self.trading_environment = trading_environment
|
||||
self.current_dt = trading_environment.period_start
|
||||
self.last_iteration_dur = datetime.timedelta(seconds=0)
|
||||
self.algorithm = None
|
||||
self.algorithm = algorithm
|
||||
self.algorithm.set_order(self.order)
|
||||
self.max_wait = datetime.timedelta(seconds=60)
|
||||
self.last_msg_dt = datetime.datetime.utcnow()
|
||||
self.txn_sim = TransactionSimulator(sim_style)
|
||||
self.txn_sim = TransactionSimulator(
|
||||
open_orders={},
|
||||
style=sim_style
|
||||
)
|
||||
|
||||
self.event_data = ndict()
|
||||
self.perf = perf.PerformanceTracker(self.trading_environment)
|
||||
self.perf = perf.PerformanceTracker(
|
||||
self.trading_environment,
|
||||
self.algorithm.get_sid_filter()
|
||||
)
|
||||
self.zmq_out = None
|
||||
self.results_socket = results_socket
|
||||
self.algo_initialized = False
|
||||
@@ -44,19 +51,6 @@ class TradeSimulationClient(Component):
|
||||
def get_id(self):
|
||||
return str(zp.FINANCE_COMPONENT.TRADING_CLIENT)
|
||||
|
||||
def set_algorithm(self, algorithm):
|
||||
"""
|
||||
:param algorithm: must implement the algorithm protocol. See
|
||||
:py:mod:`zipline.test.algorithm`
|
||||
"""
|
||||
self.algorithm = algorithm
|
||||
# register the client's order method with the algorithm
|
||||
self.algorithm.set_order(self.order)
|
||||
# we need to provide the performance tracker with the
|
||||
# sids referenced in the algorithm, so portfolio can
|
||||
# initialize with all possible sids.
|
||||
self.perf.set_sids(self.algorithm.get_sid_filter())
|
||||
|
||||
def open(self):
|
||||
self.result_feed = self.connect_result()
|
||||
if self.results_socket:
|
||||
@@ -185,16 +179,6 @@ class TradeSimulationClient(Component):
|
||||
# LOG_EXTRA_FIELDS in zipline/protocol.py
|
||||
self.do_op(self.algorithm.handle_data, data)
|
||||
|
||||
def exception_callback(self, exc_type, exc_value, exc_traceback):
|
||||
if self.results_socket:
|
||||
log.info("Sending exception frame")
|
||||
msg = zp.EXCEPTION_FRAME(
|
||||
exc_traceback,
|
||||
exc_type.__name__,
|
||||
exc_value.message
|
||||
)
|
||||
self.out_socket.send(msg)
|
||||
|
||||
def do_op(self, callable_op, *args, **kwargs):
|
||||
""" Wrap a callable operation with the zmq logbook
|
||||
handler if it exits."""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from host import ComponentHost
|
||||
from component import Component
|
||||
from monitor import Controller
|
||||
from monitor import Monitor
|
||||
|
||||
__all__ = [
|
||||
Component,
|
||||
Controller,
|
||||
Monitor,
|
||||
ComponentHost
|
||||
]
|
||||
|
||||
+361
-495
File diff suppressed because it is too large
Load Diff
@@ -103,7 +103,7 @@ class ComponentHost(object):
|
||||
|
||||
|
||||
log.info('== Roll Call ==')
|
||||
log.info('Controller')
|
||||
log.info('Monitor')
|
||||
|
||||
self.launch_controller()
|
||||
|
||||
|
||||
+79
-55
@@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
import os
|
||||
import zmq
|
||||
import sys
|
||||
@@ -10,8 +9,13 @@ from signal import SIGHUP, SIGINT
|
||||
|
||||
from collections import OrderedDict, Counter
|
||||
|
||||
from zipline.protocol import CONTROL_PROTOCOL, CONTROL_FRAME, \
|
||||
CONTROL_UNFRAME, CONTROL_STATES, INVALID_CONTROL_FRAME \
|
||||
from zipline.protocol import (
|
||||
CONTROL_PROTOCOL,
|
||||
CONTROL_FRAME,
|
||||
CONTROL_UNFRAME,
|
||||
CONTROL_STATES,
|
||||
INVALID_CONTROL_FRAME
|
||||
)
|
||||
|
||||
from zipline.utils.protocol_utils import ndict
|
||||
|
||||
@@ -34,7 +38,7 @@ class UnknownChatter(Exception):
|
||||
return """Component calling itself "%s" talking on unexpected channel""" % self.named
|
||||
|
||||
|
||||
log = logbook.Logger('Controller')
|
||||
log = logbook.Logger('Monitor')
|
||||
|
||||
# The scalars determining the timing of the monitor behavior for
|
||||
# the system.
|
||||
@@ -52,7 +56,7 @@ PARAMETERS = ndict(dict(
|
||||
SYSTEM_TIMEOUT = 50,
|
||||
))
|
||||
|
||||
class Controller(object):
|
||||
class Monitor(object):
|
||||
"""
|
||||
A N to M messaging system for inter component communication.
|
||||
|
||||
@@ -69,7 +73,12 @@ class Controller(object):
|
||||
debug = True
|
||||
period = PARAMETERS.GENERATIONAL_PERIOD
|
||||
|
||||
def __init__(self, pub_socket, route_socket, send_sighup=False):
|
||||
def __init__(
|
||||
self,
|
||||
pub_socket,
|
||||
route_socket,
|
||||
exception_socket,
|
||||
send_sighup=False):
|
||||
|
||||
self.nosignals = False
|
||||
self.context = None
|
||||
@@ -90,13 +99,15 @@ class Controller(object):
|
||||
|
||||
self.associated = []
|
||||
|
||||
self.pub_socket = pub_socket
|
||||
self.route_socket = route_socket
|
||||
|
||||
self.error_replay = OrderedDict()
|
||||
self.pub_socket = pub_socket
|
||||
self.route_socket = route_socket
|
||||
self.exception_socket = exception_socket
|
||||
|
||||
self.missed_beats = Counter()
|
||||
|
||||
# start with an empty topology
|
||||
self.topology = set([])
|
||||
|
||||
self.send_sighup = send_sighup
|
||||
if self.send_sighup:
|
||||
log.info("Request to send sighup/sigint")
|
||||
@@ -108,6 +119,17 @@ class Controller(object):
|
||||
self.zmq_poller = self.zmq.Poller
|
||||
return
|
||||
|
||||
def add_to_topology(self, component_id):
|
||||
add = set([component_id, "FORK-" + component_id])
|
||||
self.topology.update(add)
|
||||
|
||||
def freeze_topology(self):
|
||||
if isinstance(self.topology, frozenset):
|
||||
return
|
||||
# we've been incrementally adding components.
|
||||
# time to freeze.
|
||||
self.manage(self.topology)
|
||||
|
||||
def manage(self, topology):
|
||||
"""
|
||||
Give the controller a set set of components to manage and
|
||||
@@ -139,6 +161,7 @@ class Controller(object):
|
||||
raise RuntimeError("Invalid State Transition : %s -> %s" %(old, new))
|
||||
|
||||
def run(self):
|
||||
self.freeze_topology()
|
||||
self.running = True
|
||||
self.init_zmq()
|
||||
setproctitle('Monitor')
|
||||
@@ -243,6 +266,11 @@ class Controller(object):
|
||||
self.router.bind(self.route_socket)
|
||||
self.router.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
# -- Exception Out --
|
||||
# ===================
|
||||
self.ex_out = self.context.socket(self.zmq.PUSH)
|
||||
self.ex_out.connect(self.exception_socket)
|
||||
|
||||
poller = self.zmq.Poller()
|
||||
poller.register(self.router, self.zmq.POLLIN)
|
||||
#poller.register(self.cancel, self.zmq.POLLIN)
|
||||
@@ -312,6 +340,11 @@ class Controller(object):
|
||||
log.info("breaking out of initial heartbeat")
|
||||
break
|
||||
|
||||
# Break out if the entire topology told us its DONE
|
||||
if len(self.finished) == len(self.topology):
|
||||
break
|
||||
|
||||
|
||||
# ================
|
||||
# Heartbeat Stats
|
||||
# ================
|
||||
@@ -374,6 +407,7 @@ class Controller(object):
|
||||
"""
|
||||
if not self.send_sighup:
|
||||
log.warning("Skipping SIGINT")
|
||||
return
|
||||
ppid = os.getpid()
|
||||
log.warning("Sending SIGINT")
|
||||
os.kill(ppid, SIGINT)
|
||||
@@ -403,28 +437,22 @@ class Controller(object):
|
||||
bad = self.tracked - good - self.finished
|
||||
new = self.responses - good - self.finished
|
||||
|
||||
missing = self.topology - self.tracked - self.finished
|
||||
|
||||
for component in new:
|
||||
self.new(component)
|
||||
|
||||
if self.debug:
|
||||
log.info('New component %r' % component)
|
||||
|
||||
for component in bad:
|
||||
self.fail(component)
|
||||
self.timed_out(component)
|
||||
|
||||
missing = self.topology - self.tracked - self.finished
|
||||
|
||||
for component in missing:
|
||||
|
||||
if self.debug:
|
||||
log.info('Missing component %r' % component)
|
||||
|
||||
if self.debug:
|
||||
|
||||
for component in self.tracked:
|
||||
if component not in self.topology:
|
||||
log.info('Uninvited component %r' % component)
|
||||
for component in self.tracked:
|
||||
if component not in self.topology:
|
||||
log.info('Uninvited component %r' % component)
|
||||
|
||||
# --------------
|
||||
# Init Handlers
|
||||
@@ -460,22 +488,16 @@ class Controller(object):
|
||||
# Epic Fail Handling
|
||||
# ------------------
|
||||
|
||||
def fail_universal(self):
|
||||
# TODO: this requires higher order functionality
|
||||
log.error('Component missed heartbeat, Monitor shutting down system')
|
||||
self.kill()
|
||||
|
||||
def fail(self, component):
|
||||
def timed_out(self, component):
|
||||
if self.state is CONTROL_STATES.TERMINATE:
|
||||
return
|
||||
|
||||
universal = self.fail_universal
|
||||
fail_handlers = { }
|
||||
|
||||
if component in (self.topology - self.finished) or self.freeform:
|
||||
log.warning('Component "%s" missed heartbeat' % component)
|
||||
self.tracked.remove(component)
|
||||
fail_handlers.get(component, universal)()
|
||||
# we treat a time out as a severe failure, and
|
||||
# conduct a rapid shutdown
|
||||
self.kill()
|
||||
|
||||
|
||||
# -------------------
|
||||
# Completion Handling
|
||||
@@ -489,26 +511,15 @@ class Controller(object):
|
||||
# --------------
|
||||
# Error Handling
|
||||
# --------------
|
||||
|
||||
def exception_universal(self):
|
||||
"""
|
||||
Shutdown the system on failure.
|
||||
"""
|
||||
log.error('System in exception state, shutting down')
|
||||
def exception(self, component, exception_data):
|
||||
log.error('Component in exception state: %s. Shutting down system and sending exception data to listeners.'\
|
||||
% component)
|
||||
# Send the exception message out to listeners.
|
||||
self.ex_out.send(exception_data)
|
||||
# An exception in one component is treated as a hard
|
||||
# failure, and we conduct a rapid shutdown.
|
||||
self.kill()
|
||||
|
||||
def exception(self, component, failure):
|
||||
universal = self.exception_universal
|
||||
exception_handlers = { }
|
||||
|
||||
if component in self.topology or self.freeform:
|
||||
self.error_replay[(component, time.time())] = failure
|
||||
log.error('Component in exception state: %s' % component)
|
||||
|
||||
exception_handlers.get(component, universal)()
|
||||
else:
|
||||
raise UnknownChatter(component)
|
||||
|
||||
# -----------------
|
||||
# Protocol Handling
|
||||
# -----------------
|
||||
@@ -555,7 +566,19 @@ class Controller(object):
|
||||
|
||||
# A component is telling us it failed, and how
|
||||
if id is CONTROL_PROTOCOL.EXCEPTION:
|
||||
self.exception(identity, status)
|
||||
# status should be a msgpack emitted from
|
||||
# EXCEPTION_FRAME
|
||||
try:
|
||||
exception_data = status
|
||||
self.exception(identity, exception_data)
|
||||
except:
|
||||
# if an exception occurs when we try to handle
|
||||
# the exception, signal the parent that we need
|
||||
# to go down
|
||||
# TODO: should we attempt to call self.exception?
|
||||
log.exception("Unexpected exception sending exception data")
|
||||
self.kill()
|
||||
|
||||
return
|
||||
|
||||
# A component is telling us its done with work and won't
|
||||
@@ -605,11 +628,9 @@ class Controller(object):
|
||||
self.associated.append(s)
|
||||
return s
|
||||
|
||||
def do_error_replay(self):
|
||||
for (component, time), error in self.error_replay.iteritems():
|
||||
log.info('Component Log for -- %s --:\n%s' % (component, error))
|
||||
|
||||
def kill(self):
|
||||
"""Aggressively exit the whole zipline.
|
||||
"""
|
||||
if self.state is CONTROL_STATES.TERMINATE:
|
||||
return
|
||||
|
||||
@@ -617,6 +638,9 @@ class Controller(object):
|
||||
self.send_hardkill()
|
||||
self.state = CONTROL_STATES.TERMINATE
|
||||
self.alive = False
|
||||
# send burrito an interrupt, instructing it to kill all
|
||||
# child processes assocated with this zipline.
|
||||
time.sleep(3)
|
||||
self.signal_interrupt()
|
||||
|
||||
def shutdown(self):
|
||||
|
||||
@@ -40,13 +40,13 @@ class ProcessSimulator(ComponentHost):
|
||||
# invoked by the host's open()
|
||||
|
||||
def launch_controller(self):
|
||||
proc = multiprocessing.Process(target=self.controller.run)
|
||||
proc = multiprocessing.Process(target=self.monitor.run)
|
||||
proc.start()
|
||||
self.con = proc
|
||||
|
||||
# Process specific
|
||||
self.controller_process = proc
|
||||
self.mapping[proc.pid] = 'Controller'
|
||||
self.monitor_process = proc
|
||||
self.mapping[proc.pid] = 'Monitor'
|
||||
|
||||
def launch_component(self, component):
|
||||
proc = multiprocessing.Process(target=component.run)
|
||||
@@ -81,7 +81,7 @@ class ProcessSimulator(ComponentHost):
|
||||
process.join(timeout=1)
|
||||
process.terminate()
|
||||
|
||||
self.controller.shutdown(soft=True)
|
||||
self.monitor.shutdown(soft=True)
|
||||
self.running = False
|
||||
|
||||
self.con.terminate()
|
||||
|
||||
@@ -133,6 +133,7 @@ import zipline.finance.risk as risk
|
||||
log = logbook.Logger('Performance')
|
||||
|
||||
class PerformanceTracker(object):
|
||||
UPDATER = True
|
||||
"""
|
||||
Tracks the performance of the zipline as it is running in
|
||||
the simulator, relays this out to the Deluge broker and then
|
||||
@@ -144,7 +145,7 @@ class PerformanceTracker(object):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, trading_environment):
|
||||
def __init__(self, trading_environment, sid_list):
|
||||
|
||||
self.trading_environment = trading_environment
|
||||
self.trading_day = datetime.timedelta(hours = 6, minutes = 30)
|
||||
@@ -164,7 +165,6 @@ class PerformanceTracker(object):
|
||||
self.txn_count = 0
|
||||
self.event_count = 0
|
||||
self.last_dict = None
|
||||
self.order_log = []
|
||||
self.exceeded_max_loss = False
|
||||
|
||||
self.results_socket = None
|
||||
@@ -197,10 +197,21 @@ class PerformanceTracker(object):
|
||||
# save the transactions for the daily periods
|
||||
keep_transactions = True
|
||||
)
|
||||
|
||||
def set_sids(self, sid_list):
|
||||
|
||||
for sid in sid_list:
|
||||
self.cumulative_performance.positions[sid] = Position(sid)
|
||||
self.todays_performance.positions[sid] = Position(sid)
|
||||
|
||||
def update(self, event):
|
||||
if event.dt == "DONE":
|
||||
event.perf_message = self.handle_simulation_end()
|
||||
del event['TRANSACTION']
|
||||
return event
|
||||
else:
|
||||
event.perf_message = self.process_event(event)
|
||||
event.portfolio = self.get_portfolio()
|
||||
del event['TRANSACTION']
|
||||
return event
|
||||
|
||||
def get_portfolio(self):
|
||||
return self.cumulative_performance.as_portfolio()
|
||||
@@ -238,10 +249,10 @@ class PerformanceTracker(object):
|
||||
'cumulative_risk_metrics' : self.cumulative_risk_metrics.to_dict()
|
||||
}
|
||||
|
||||
def log_order(self, order):
|
||||
self.order_log.append(order)
|
||||
|
||||
def process_event(self, event):
|
||||
|
||||
message = None
|
||||
|
||||
if self.exceeded_max_loss:
|
||||
return
|
||||
@@ -250,7 +261,7 @@ class PerformanceTracker(object):
|
||||
self.event_count += 1
|
||||
|
||||
if(event.dt >= self.market_close):
|
||||
self.handle_market_close()
|
||||
message = self.handle_market_close()
|
||||
|
||||
if event.TRANSACTION:
|
||||
self.txn_count += 1
|
||||
@@ -264,9 +275,12 @@ class PerformanceTracker(object):
|
||||
#calculate performance as of last trade
|
||||
self.cumulative_performance.calculate_performance()
|
||||
self.todays_performance.calculate_performance()
|
||||
|
||||
|
||||
return message
|
||||
|
||||
def handle_market_close(self):
|
||||
|
||||
|
||||
# add the return results from today to the list of DailyReturn objects.
|
||||
todays_date = self.market_close.replace(hour=0, minute=0, second=0)
|
||||
todays_return_obj = risk.DailyReturn(
|
||||
@@ -288,12 +302,10 @@ class PerformanceTracker(object):
|
||||
# calculate progress of test
|
||||
self.progress = self.day_count / self.total_days
|
||||
|
||||
# Output results
|
||||
if self.results_socket:
|
||||
msg = zp.PERF_FRAME(self.to_dict())
|
||||
self.results_socket.send(msg)
|
||||
# Take a snapshot of our current peformance to return to the
|
||||
# browser.
|
||||
daily_update = self.to_dict()
|
||||
|
||||
#
|
||||
if self.trading_environment.max_drawdown:
|
||||
returns = self.todays_performance.returns
|
||||
max_dd = -1 * self.trading_environment.max_drawdown
|
||||
@@ -304,7 +316,7 @@ class PerformanceTracker(object):
|
||||
# so it shows up in the update, but don't end the test
|
||||
# here. Let the update go out before stopping
|
||||
self.exceeded_max_loss = True
|
||||
return
|
||||
return daily_update
|
||||
|
||||
|
||||
#move the market day markers forward
|
||||
@@ -326,6 +338,8 @@ class PerformanceTracker(object):
|
||||
self.market_close,
|
||||
keep_transactions = True
|
||||
)
|
||||
|
||||
return daily_update
|
||||
|
||||
def handle_simulation_end(self):
|
||||
"""
|
||||
@@ -349,12 +363,8 @@ class PerformanceTracker(object):
|
||||
exceeded_max_loss = self.exceeded_max_loss
|
||||
)
|
||||
|
||||
if self.results_socket:
|
||||
log.info("about to stream the risk report...")
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
|
||||
msg = zp.RISK_FRAME(risk_dict)
|
||||
self.results_socket.send(msg)
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
return risk_dict
|
||||
|
||||
class Position(object):
|
||||
|
||||
@@ -362,8 +372,8 @@ class Position(object):
|
||||
self.sid = sid
|
||||
self.amount = 0
|
||||
self.cost_basis = 0.0 ##per share
|
||||
self.last_sale_price = None
|
||||
self.last_sale_date = None
|
||||
self.last_sale_price = 0.0
|
||||
self.last_sale_date = 0.0
|
||||
|
||||
def update(self, txn):
|
||||
if(self.sid != txn.sid):
|
||||
@@ -584,7 +594,6 @@ class PerformancePeriod(object):
|
||||
|
||||
return positions
|
||||
|
||||
#
|
||||
def get_positions_list(self):
|
||||
positions = []
|
||||
for sid, pos in self.positions.iteritems():
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from zipline.transforms.base import BaseTransform
|
||||
|
||||
class ReturnsTransform(BaseTransform):
|
||||
|
||||
def init(self, name):
|
||||
self.state = {}
|
||||
self.state['name'] = name
|
||||
self.by_sid = defaultdict(self._create)
|
||||
|
||||
@property
|
||||
def get_id(self):
|
||||
return self.state['name']
|
||||
|
||||
|
||||
def transform(self, event):
|
||||
cur = self.by_sid[event.sid]
|
||||
cur.update(event)
|
||||
self.state['value'] = cur.returns
|
||||
return self.state
|
||||
|
||||
def _create(self):
|
||||
return ReturnsFromPriorClose()
|
||||
|
||||
class ReturnsFromPriorClose(object):
|
||||
"""
|
||||
Calculates a security's returns since the previous close, using the
|
||||
current price.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_close = None
|
||||
self.last_event = None
|
||||
self.returns = 0.0
|
||||
|
||||
def update(self, event):
|
||||
if self.last_close:
|
||||
change = event.price - self.last_close.price
|
||||
self.returns = change / self.last_close.price
|
||||
|
||||
if self.last_event:
|
||||
if self.last_event.dt.day != event.dt.day:
|
||||
# the current event is from the day after
|
||||
# the last event. Therefore the last event was
|
||||
# the last close
|
||||
self.last_close = self.last_event
|
||||
|
||||
# the current event is now the last_event
|
||||
self.last_event = event
|
||||
+15
-23
@@ -9,10 +9,10 @@ from zipline.protocol import SIMULATION_STYLE
|
||||
log = logbook.Logger('Transaction Simulator')
|
||||
|
||||
class TransactionSimulator(object):
|
||||
UPDATER = True
|
||||
|
||||
def __init__(self, style=SIMULATION_STYLE.PARTIAL_VOLUME):
|
||||
def __init__(self, sid_filter, style=SIMULATION_STYLE.PARTIAL_VOLUME):
|
||||
self.open_orders = {}
|
||||
self.order_count = 0
|
||||
self.txn_count = 0
|
||||
self.trade_window = datetime.timedelta(seconds=30)
|
||||
self.orderTTL = datetime.timedelta(days=1)
|
||||
@@ -27,27 +27,20 @@ class TransactionSimulator(object):
|
||||
elif style == SIMULATION_STYLE.NOOP:
|
||||
self.apply_trade_to_open_orders = self.simulate_noop
|
||||
|
||||
def add_open_order(self, event):
|
||||
# Orders are captured in a buffer by sid. No calculations are done here.
|
||||
# Amount is explicitly converted to an int.
|
||||
# Orders of amount zero are ignored.
|
||||
for sid in sid_filter:
|
||||
self.open_orders[sid] = []
|
||||
|
||||
self.order_count += 1
|
||||
event.amount = int(event.amount)
|
||||
def place_order(self, order):
|
||||
# initialized filled field.
|
||||
order.filled = 0
|
||||
self.open_orders[order.sid].append(order)
|
||||
|
||||
if event.amount == 0:
|
||||
log = "requested to trade zero shares of {sid}".format(
|
||||
sid=event.sid
|
||||
)
|
||||
log.debug(log)
|
||||
return
|
||||
|
||||
if not self.open_orders.has_key(event.sid):
|
||||
self.open_orders[event.sid] = []
|
||||
|
||||
# set the filled property to zero
|
||||
event.filled = 0
|
||||
self.open_orders[event.sid].append(event)
|
||||
def update(self, event):
|
||||
event.TRANSACTION = None
|
||||
# We only fill transactions on trade events.
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
event.TRANSACTION = self.apply_trade_to_open_orders(event)
|
||||
return event
|
||||
|
||||
def simulate_buy_all(self, event):
|
||||
txn = self.create_transaction(
|
||||
@@ -81,7 +74,7 @@ class TransactionSimulator(object):
|
||||
txn = self.create_transaction(
|
||||
event.sid,
|
||||
amount,
|
||||
event.price + 0.10,
|
||||
event.price + 0.10, # Magic constant?
|
||||
event.dt,
|
||||
direction
|
||||
)
|
||||
@@ -162,7 +155,6 @@ class TransactionSimulator(object):
|
||||
}
|
||||
return zp.ndict(txn)
|
||||
|
||||
|
||||
class TradingEnvironment(object):
|
||||
|
||||
def __init__(
|
||||
|
||||
+78
-68
@@ -1,91 +1,101 @@
|
||||
import datetime
|
||||
from itertools import tee, starmap
|
||||
from itertools import tee, starmap, chain
|
||||
from collections import namedtuple
|
||||
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import roundrobin, hash_args
|
||||
from zipline.gens.utils import roundrobin, hash_args, done_message
|
||||
from zipline.gens.sort import date_sort
|
||||
from zipline.gens.merge import merge
|
||||
from zipline.gens.transform import stateful_transform
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
|
||||
SortBundle = namedtuple("SortBundle", ['source', 'args', 'kwargs'])
|
||||
MergeBundle = namedtuple("MergeBundle", ['stream', 'tnfm', 'args', 'kwargs'])
|
||||
|
||||
def date_sorted_sources(sources, source_args, source_kwargs):
|
||||
def date_sorted_sources(*sources):
|
||||
"""
|
||||
Takes a list of generator functions, a list of tuples of positional arguments,
|
||||
and a list of dictionaries of keyword arguments. Packages up all arguments
|
||||
and passes them into a date_sort.
|
||||
Takes an iterable of sources, generating namestrings and
|
||||
piping their output into date_sort.
|
||||
"""
|
||||
assert len(sources) == len(source_args) == len(source_kwargs)
|
||||
# Package up sources and arguments.
|
||||
|
||||
# Create a generator of SortBundle objects to be turned into
|
||||
# namestrings and generator objects.
|
||||
bundle_gen = starmap(SortBundle, zip(sources, source_args, source_kwargs))
|
||||
|
||||
# Load the results of the generator into a tuple so that the
|
||||
# results can be used twice (once in namestring comprehension,
|
||||
# once in the generator comprehension for intialized sources.
|
||||
bundles = tuple(bundle_gen)
|
||||
for source in sources:
|
||||
assert iter(source), "Source %s not iterable" % source
|
||||
assert source.__class__.__dict__.has_key('get_hash'), "No get_hash"
|
||||
|
||||
# Calculate namestring hashes to pass to date_sort.
|
||||
names = [bundle.source.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
# Pass each source its arguments.
|
||||
initialized = [bundle.source(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
# Get name hashes to pass to date_sort.
|
||||
names = [source.get_hash() for source in sources]
|
||||
|
||||
# Convert the list of generators into a flat stream by pulling
|
||||
# one element at a time from each.
|
||||
stream_in = roundrobin(*initialized)
|
||||
|
||||
# Guarantee the flat stream will be sorted by date, using source_id as
|
||||
# tie-breaker, which is fully deterministic (given deterministic string
|
||||
# representation for all args/kwargs)
|
||||
stream_in = roundrobin(sources, names)
|
||||
|
||||
# Guarantee the flat stream will be sorted by date, using
|
||||
# source_id as tie-breaker, which is fully deterministic (given
|
||||
# deterministic string representation for all args/kwargs)
|
||||
|
||||
return date_sort(stream_in, names)
|
||||
|
||||
|
||||
def merged_transforms(sorted_stream, tnfms, tnfm_args, tnfm_kwargs):
|
||||
def merged_transforms(sorted_stream, *transforms):
|
||||
"""
|
||||
A generator that takes the expected output of a date_sort, pipes it
|
||||
through a given set of transforms, and runs the results throught a
|
||||
merge to output a unified stream. tnfms should be a list of
|
||||
pointers to generator functions. tnfm_args should be a list of
|
||||
tuples, representing the arguments to be passed to each transform.
|
||||
tnfm_kwargs should be a list of dictionaries representing keyword
|
||||
arguments to each transform.
|
||||
A generator that takes the expected output of a date_sort, pipes
|
||||
it through a given set of transforms, and runs the results
|
||||
through a merge to output a unified stream. tnfms should be a
|
||||
list of pointers to generator functions. tnfm_args should be a
|
||||
list of tuples, representing the arguments to be passed to each
|
||||
transform. tnfm_kwargs should be a list of dictionaries
|
||||
representing keyword arguments to each transform.
|
||||
"""
|
||||
|
||||
# We should have as many sets of args as we have transforms.
|
||||
assert len(tnfms) == len(tnfm_args) == len(tnfm_kwargs)
|
||||
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(sorted_stream, len(tnfms))
|
||||
|
||||
# Package each transform with a stream copy and set of args. Use a list
|
||||
# so that we can re-use this for calculating hashes.
|
||||
bundle_gen = starmap(MergeBundle, zip(split, tnfms, tnfm_args, tnfm_kwargs))
|
||||
|
||||
bundles = tuple(bundle_gen)
|
||||
# list comprehension to create transform generators from
|
||||
# bundles
|
||||
tnfm_gens = [
|
||||
stateful_transform(
|
||||
bundle.stream,
|
||||
bundle.tnfm,
|
||||
*bundle.args,
|
||||
**bundle.kwargs
|
||||
)
|
||||
for bundle in bundles]
|
||||
for transform in transforms:
|
||||
assert isinstance(transform, StatefulTransform)
|
||||
|
||||
# Generate expected hashes for each transform
|
||||
hashes = [bundle.tnfm.__name__ + hash_args(*bundle.args, **bundle.kwargs)
|
||||
for bundle in bundles]
|
||||
namestrings = [tnfm.get_hash() for tnfm in transforms]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat stream.
|
||||
to_merge = roundrobin(*tnfm_gens)
|
||||
# Create a copy of the stream for each transform.
|
||||
split = tee(sorted_stream, len(transforms))
|
||||
|
||||
# Package a stream copy with each StatefulTransform instance.
|
||||
bundles = zip(transforms, split)
|
||||
|
||||
# Convert the copies into transform streams.
|
||||
tnfm_gens = [tnfm.transform(stream) for tnfm, stream in bundles]
|
||||
|
||||
# Roundrobin the outputs of our transforms to create a single flat
|
||||
# stream.
|
||||
to_merge = roundrobin(tnfm_gens, namestrings)
|
||||
|
||||
# Pipe the stream into merge.
|
||||
merged = merge(to_merge, hashes)
|
||||
return merged_transforms
|
||||
merged = merge(to_merge, namestrings)
|
||||
|
||||
dt_aliased = alias_dt(merged)
|
||||
# Return the merged events.
|
||||
return add_done(dt_aliased)
|
||||
|
||||
def sequential_transforms(stream_in, *transforms):
|
||||
"""
|
||||
Apply each transform in transforms sequentially to each event in stream_in.
|
||||
Each transform application will add a new entry indexed to the transform's
|
||||
hash string.
|
||||
"""
|
||||
|
||||
assert isinstance(transforms, (list, tuple))
|
||||
for tnfm in transforms:
|
||||
tnfm.forward_all = False
|
||||
tnfm.update_in_place = False
|
||||
tnfm.append_value = True
|
||||
|
||||
# Recursively apply all transforms to the stream.
|
||||
stream_out = reduce(lambda stream, tnfm: tnfm.transform(stream),
|
||||
transforms,
|
||||
stream_in)
|
||||
|
||||
dt_aliased = alias_dt(stream_out)
|
||||
return add_done(dt_aliased)
|
||||
|
||||
def alias_dt(stream_in):
|
||||
"""
|
||||
Alias the dt field to datetime on each message.
|
||||
"""
|
||||
for message in stream_in:
|
||||
message['datetime'] = message['dt']
|
||||
yield message
|
||||
|
||||
# Add a done message to a stream.
|
||||
def add_done(stream_in):
|
||||
return chain(stream_in, [done_message('Composite')])
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import pytz
|
||||
import time
|
||||
|
||||
from time import sleep
|
||||
from pprint import pprint as pp
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import izip
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
from zipline.gens.composites import SourceBundle, TransformBundle, \
|
||||
date_sorted_sources, merged_transforms, sequential_transforms
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.transform import MovingAverage, Passthrough, StatefulTransform
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
filter = [2,3]
|
||||
#Set up source a. Six minutes between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'count' : 1000,
|
||||
'sids' : [1,2,3],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 6),
|
||||
'filter' : filter
|
||||
}
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
source_a_prime = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Five minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'count' : 1000,
|
||||
'sids' : [2,3,4],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 5),
|
||||
'filter' : filter
|
||||
}
|
||||
source_b = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
source_b_prime = SpecificEquityTrades(*args_b, **kwargs_b)
|
||||
|
||||
sorted = date_sorted_sources(source_a, source_b)
|
||||
sorted_prime = date_sorted_sources(
|
||||
source_a_prime,
|
||||
source_b_prime
|
||||
)
|
||||
|
||||
passthrough = StatefulTransform(Passthrough)
|
||||
mavg_price = StatefulTransform(
|
||||
MovingAverage,
|
||||
timedelta(minutes = 20),
|
||||
['price']
|
||||
)
|
||||
|
||||
passthrough_prime = StatefulTransform(Passthrough)
|
||||
mavg_price_prime = StatefulTransform(
|
||||
MovingAverage,
|
||||
timedelta(minutes = 20),
|
||||
['price']
|
||||
)
|
||||
|
||||
merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
start = time.time()
|
||||
for message in merged:
|
||||
assert 1 + 1 == 2
|
||||
stop = time.time()
|
||||
merge_time = stop - start
|
||||
print "Merge time: %s" % str(merge_time)
|
||||
|
||||
sequential = sequential_transforms(
|
||||
sorted_prime,
|
||||
passthrough_prime,
|
||||
mavg_price_prime
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
for message in sequential:
|
||||
assert 1 + 1 == 2
|
||||
stop = time.time()
|
||||
seq_time = stop - start
|
||||
print "Sequential time: %s" % str(seq_time)
|
||||
print "Merge/Seq: %s" % (str(merge_time/seq_time))
|
||||
|
||||
|
||||
# merged = merged_transforms(sorted, passthrough, mavg_price)
|
||||
|
||||
# algo = TestAlgorithm(2, 10, 100, sid_filter = [2,3])
|
||||
# environment = create_trading_environment(year = 2012)
|
||||
# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
# trading_client = tsc(algo, environment, style)
|
||||
|
||||
# for message in trading_client.simulate(merged):
|
||||
# pp(message)
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
from numbers import Number
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.transform import EventWindow
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to
|
||||
MovingAverageEventWindows. For each sid, we maintain moving
|
||||
averages over any number of distinct fields (For example, we can
|
||||
maintain a sid's average volume as well as its average price.)
|
||||
"""
|
||||
|
||||
def __init__(self, fields, market_aware, days = None, delta = None):
|
||||
|
||||
self.fields = fields
|
||||
self.market_aware = market_aware
|
||||
|
||||
self.delta = delta
|
||||
self.days = days
|
||||
|
||||
# Market-aware mode only works with full-day windows.
|
||||
if self.market_aware:
|
||||
assert self.days and not self.delta,\
|
||||
"Market-aware mode only works with full-day windows."
|
||||
|
||||
# Non-market-aware mode requires a timedelta.
|
||||
else:
|
||||
assert self.delta and not self.days, \
|
||||
"Non-market-aware mode requires a timedelta."
|
||||
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
def create_window(self):
|
||||
"""
|
||||
Factory method for self.sid_windows.
|
||||
"""
|
||||
return MovingAverageEventWindow(
|
||||
self.fields,
|
||||
self.market_aware,
|
||||
self.days,
|
||||
self.delta
|
||||
)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Return an ndict
|
||||
from tracked fields to moving averages.
|
||||
"""
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
return window.get_averages()
|
||||
|
||||
class MovingAverageEventWindow(EventWindow):
|
||||
"""
|
||||
Iteratively calculates moving averages for a particular sid over a
|
||||
given time window. We can maintain averages for arbitrarily many
|
||||
fields on a single sid. (For example, we might track average
|
||||
price as well as average volume for a single sid.) The expected
|
||||
functionality of this class is to be instantiated inside a
|
||||
MovingAverage transform.
|
||||
"""
|
||||
|
||||
def __init__(self, fields, market_aware, days, delta):
|
||||
|
||||
# Call the superclass constructor to set up base EventWindow
|
||||
# infrastructure.
|
||||
EventWindow.__init__(self, market_aware, days, delta)
|
||||
|
||||
# We maintain a dictionary of totals for each of our tracked
|
||||
# fields.
|
||||
self.fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
|
||||
# Subclass customization for adding new events.
|
||||
def handle_add(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
# Increment our running totals with data from the event.
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
|
||||
# Subclass customization for removing expired events.
|
||||
def handle_remove(self, event):
|
||||
# Decrement our running totals with data from the event.
|
||||
for field in self.fields:
|
||||
self.totals[field] -= event[field]
|
||||
|
||||
def average(self, field):
|
||||
"""
|
||||
Calculate the average value of our ticks over a single field.
|
||||
"""
|
||||
# Sanity check.
|
||||
assert field in self.fields
|
||||
|
||||
# Averages are None by convention if we have no ticks.
|
||||
if len(self.ticks) == 0:
|
||||
return 0.0
|
||||
|
||||
# Calculate and return the average. len(self.ticks) is O(1).
|
||||
else:
|
||||
return self.totals[field] / len(self.ticks)
|
||||
|
||||
def get_averages(self):
|
||||
"""
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
return out
|
||||
|
||||
def assert_required_fields(self, event):
|
||||
"""
|
||||
We only allow events with all of our tracked fields.
|
||||
"""
|
||||
for field in self.fields:
|
||||
assert event.has_key(field), \
|
||||
"Event missing [%s] in MovingAverageEventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in MovingAverageEventWindow" % (event[field], field)
|
||||
+28
-16
@@ -6,13 +6,13 @@ from collections import deque
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.utils import hash_args, \
|
||||
assert_merge_protocol
|
||||
|
||||
assert_merge_protocol, done_message
|
||||
from itertools import repeat
|
||||
|
||||
def merge(stream_in, tnfm_ids):
|
||||
"""
|
||||
A generator that takes a generator and a list of source_ids. We
|
||||
maintain an internal queue for each id in source_ids. Once we
|
||||
A generator that takes a generator and a list of transform ids. We
|
||||
maintain an internal queue for each id in tnfm_ids. Once we
|
||||
have a message from every queue, we pop an event from each queue
|
||||
and merge them together into an event. We raise an error if we
|
||||
do not receive the same number of events from all sources.
|
||||
@@ -28,22 +28,22 @@ def merge(stream_in, tnfm_ids):
|
||||
|
||||
# Process incoming streams.
|
||||
for message in stream_in:
|
||||
assert isinstance(message, tuple), \
|
||||
"Bad message in merge: %s" %message
|
||||
assert len(message) == 2
|
||||
id, value = message
|
||||
assert isinstance(message, ndict)
|
||||
assert message.has_key('tnfm_id')
|
||||
assert message.has_key('tnfm_value')
|
||||
assert message.has_key('dt')
|
||||
|
||||
id = message.tnfm_id
|
||||
assert id in tnfm_ids, \
|
||||
"Message from unexpected tnfm: %s, %s" % (id, tnfm_ids)
|
||||
assert isinstance(value, ndict), "Bad message in merge: %s" %message
|
||||
|
||||
tnfms[id].append(value)
|
||||
tnfms[id].append(message)
|
||||
|
||||
# Only pop messages when we have a pending message from
|
||||
# all datasources. Stop if all sources have signalled done.
|
||||
|
||||
while ready(tnfms) and not done(tnfms):
|
||||
message = merge_one(tnfms)
|
||||
assert_merge_protocol(tnfm_ids, message)
|
||||
yield message
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
@@ -51,13 +51,25 @@ def merge(stream_in, tnfm_ids):
|
||||
assert len(queue) == 1, "Bad queue in merge on exit: %s" % queue
|
||||
assert queue[0].dt == "DONE", \
|
||||
"Bad last message in merge on exit: %s" % queue
|
||||
|
||||
|
||||
def merge_one(sources):
|
||||
output = ndict()
|
||||
|
||||
event_fields = ndict()
|
||||
for key, queue in sources.iteritems():
|
||||
new_xform = ndict({key: queue.popleft()})
|
||||
output.merge(new_xform)
|
||||
return output
|
||||
|
||||
# Add transform value to the transforms dict.
|
||||
message = queue.popleft()
|
||||
event_fields[message.tnfm_id] = message.tnfm_value
|
||||
del message['tnfm_id']
|
||||
del message['tnfm_value']
|
||||
|
||||
# Merge any remaining fields into the event dict.
|
||||
event_fields.merge(message)
|
||||
|
||||
# alias dt with datetime, per algoscript api
|
||||
event_fields['datetime'] = event_fields['dt']
|
||||
|
||||
return event_fields
|
||||
|
||||
|
||||
#TODO: This is replicated in sort. Probably should be one source file.
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
from collections import defaultdict, deque
|
||||
|
||||
class Returns(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to the sid's
|
||||
closing price N trading days ago.
|
||||
"""
|
||||
def __init__(self, days):
|
||||
self.days = days
|
||||
self.mapping = defaultdict(self._create)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update and return the calculated returns for this event's sid.
|
||||
"""
|
||||
assert event.has_key('dt')
|
||||
assert event.has_key('price')
|
||||
tracker = self.mapping[event.sid]
|
||||
tracker.update(event)
|
||||
|
||||
return tracker.get_returns()
|
||||
|
||||
def _create(self):
|
||||
return ReturnsFromPriorClose(self.days)
|
||||
|
||||
class ReturnsFromPriorClose(object):
|
||||
"""
|
||||
Records the last N closing events for a given security as well as the
|
||||
last event for the security. When we get an event for a new day, we
|
||||
treat the last event seen as the close for the previous day.
|
||||
"""
|
||||
|
||||
def __init__(self, days):
|
||||
self.closes = deque()
|
||||
self.last_event = None
|
||||
self.returns = 0.0
|
||||
self.days = days
|
||||
|
||||
def get_returns(self):
|
||||
return self.returns
|
||||
|
||||
def update(self, event):
|
||||
|
||||
if self.last_event:
|
||||
|
||||
# Day has changed since the last event we saw. Treat
|
||||
# the last event as the closing price for its day and
|
||||
# clear out the oldest close if it has expired.
|
||||
if self.last_event.dt.date() != event.dt.date():
|
||||
|
||||
self.closes.append(self.last_event)
|
||||
|
||||
# We keep an event for the end of each trading day, so
|
||||
# if the number of stored events is greater than the
|
||||
# number of days we want to track, the oldest close
|
||||
# is expired and should be discarded.
|
||||
while len(self.closes) > self.days:
|
||||
# Pop the oldest event.
|
||||
self.closes.popleft()
|
||||
|
||||
# We only generate a return value once we've seen enough days
|
||||
# to give a sensible value. Would be nice if we could query
|
||||
# db for closes prior to our initial event, but that would
|
||||
# require giving this transform database creds, which we want
|
||||
# to avoid.
|
||||
|
||||
if len(self.closes) == self.days:
|
||||
last_close = self.closes[0].price
|
||||
change = event.price - last_close
|
||||
self.returns = change / last_close
|
||||
|
||||
|
||||
# the current event is now the last_event
|
||||
self.last_event = event
|
||||
@@ -14,7 +14,6 @@ def date_sort(stream_in, source_ids):
|
||||
have messages pending from all sources, we pull the earliest
|
||||
message and yield it.
|
||||
"""
|
||||
|
||||
assert isinstance(source_ids, (list, tuple))
|
||||
|
||||
# Set up an internal queue for each expected source.
|
||||
@@ -28,7 +27,7 @@ def date_sort(stream_in, source_ids):
|
||||
# Incoming messages should be the output of DATASOURCE_UNFRAME.
|
||||
assert_datasource_unframe_protocol(message), \
|
||||
"Bad message in date_sort: %s" % message
|
||||
|
||||
|
||||
# Only allow messages from sources we expect.
|
||||
assert message.source_id in sources, "Unexpected source: %s" % message
|
||||
|
||||
@@ -41,7 +40,7 @@ def date_sort(stream_in, source_ids):
|
||||
message = pop_oldest(sources)
|
||||
assert_sort_protocol(message)
|
||||
yield message
|
||||
|
||||
|
||||
# We should have only a done message left in each queue.
|
||||
for queue in sources.itervalues():
|
||||
assert len(queue) == 1, "Bad queue in date_sort on exit: %s" % queue
|
||||
|
||||
+93
-60
@@ -3,13 +3,14 @@ Tools to generate trade events without a backing store. Useful for testing
|
||||
and zipline development
|
||||
"""
|
||||
import random
|
||||
import pytz
|
||||
|
||||
from itertools import chain, cycle, ifilter, izip
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trade
|
||||
from zipline.gens.utils import hash_args, mock_done
|
||||
from zipline.gens.utils import hash_args, create_trade
|
||||
|
||||
def date_gen(start = datetime(2012, 6, 6, 0),
|
||||
def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
|
||||
delta = timedelta(minutes = 1),
|
||||
count = 100):
|
||||
"""
|
||||
@@ -25,9 +26,9 @@ def mock_prices(count, rand = False):
|
||||
"""
|
||||
|
||||
if rand:
|
||||
return (random.uniform(0.0, 10.0) for i in xrange(count))
|
||||
return (random.uniform(1.0, 10.0) for i in xrange(count))
|
||||
else:
|
||||
return (float(i % 11) for i in xrange(1,count+1))
|
||||
return (float(i % 10) + 1.0 for i in xrange(count))
|
||||
|
||||
def mock_volumes(count, rand = False):
|
||||
"""
|
||||
@@ -49,72 +50,104 @@ def fuzzy_dates(count = 500):
|
||||
for date in date_gen(count = count):
|
||||
yield date + timedelta(seconds = random.randint(-10, 10))
|
||||
|
||||
def SpecificEquityTrades(*args, **config):
|
||||
class SpecificEquityTrades(object):
|
||||
"""
|
||||
Yields all events in event_list that match the given sid_filter.
|
||||
If no event_list is specified, generates an internal stream of events
|
||||
to filter. Returns all events if filter is None.
|
||||
|
||||
Configuration options:
|
||||
|
||||
count : integer representing number of trades
|
||||
sids : list of values representing simulated internal sids
|
||||
start : start date
|
||||
delta : timedelta between internal events
|
||||
filter : filter to remove the sids
|
||||
"""
|
||||
# We shouldn't get any positional arguments.
|
||||
assert args == ()
|
||||
|
||||
# Unpack config dictionary with default values.
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1, 2])
|
||||
start = config.get('start', datetime(2012, 6, 6, 0))
|
||||
delta = config.get('delta', timedelta(minutes = 1))
|
||||
def __init__(self, *args, **kwargs):
|
||||
# We shouldn't get any positional arguments.
|
||||
assert len(args) == 0
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
event_list = config.get('event_list')
|
||||
filter = config.get('filter')
|
||||
# Unpack config dictionary with default values.
|
||||
self.count = kwargs.get('count', 500)
|
||||
self.sids = kwargs.get('sids', [1, 2])
|
||||
self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc))
|
||||
self.delta = kwargs.get('delta', timedelta(minutes = 1))
|
||||
|
||||
arg_string = hash_args(*args, **config)
|
||||
namestring = "SpecificEquityTrades" + arg_string
|
||||
# If we have an event_list, ignore the other arguments and use the list.
|
||||
# TODO: still append our namestring?
|
||||
if event_list:
|
||||
unfiltered = (event for event in event_list)
|
||||
# Default to None for event_list and filter.
|
||||
self.event_list = kwargs.get('event_list')
|
||||
self.filter = kwargs.get('filter')
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count = count, start = start, delta = delta)
|
||||
prices = mock_prices(count)
|
||||
volumes = mock_volumes(count)
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(*args, **kwargs)
|
||||
|
||||
self.generator = self.create_fresh_generator()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
return self.generator.next()
|
||||
|
||||
def rewind(self):
|
||||
self.generator = self.create_fresh_generator()
|
||||
|
||||
def get_hash(self):
|
||||
return self.__class__.__name__ + "-" + self.arg_string
|
||||
|
||||
def create_fresh_generator(self):
|
||||
|
||||
if self.event_list:
|
||||
unfiltered = (event for event in self.event_list)
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
sids = cycle(self.sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = self.get_hash())
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't
|
||||
# match the filter.
|
||||
if self.filter:
|
||||
filtered = ifilter(lambda event: event.sid in self.filter, unfiltered)
|
||||
|
||||
# Otherwise just use all events.
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Return the filtered event stream.
|
||||
return filtered
|
||||
|
||||
|
||||
# !!!!!!! Deprecated for now !!!!!!!!!
|
||||
|
||||
def RandomEquityTrades(object):
|
||||
|
||||
def __init__(self):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
self.count = config.get('count', 500)
|
||||
self.sids = config.get('sids', [1,2])
|
||||
self.filter = config.get('filter')
|
||||
|
||||
dates = fuzzy_dates(count)
|
||||
prices = mock_prices(count, rand = True)
|
||||
volumes = mock_volumes(count, rand = True)
|
||||
sids = cycle(sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = namestring)
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't match the filter.
|
||||
if filter:
|
||||
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
|
||||
|
||||
# Otherwise just use all events.
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Add a done message to the end of the stream. For a live
|
||||
# datasource this would be handled by the containing Component.
|
||||
out = chain(filtered, [mock_done(namestring)])
|
||||
return out
|
||||
|
||||
def RandomEquityTrades(*args, **config):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1,2])
|
||||
filter = config.get('filter')
|
||||
|
||||
dates = fuzzy_dates(count)
|
||||
prices = mock_prices(count, rand = True)
|
||||
volumes = mock_volumes(count, rand = True)
|
||||
sids = cycle(sids)
|
||||
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
unfiltered = (create_trade(*args) for args in arg_gen)
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
from logbook import Logger, Processor
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from numbers import Integral
|
||||
|
||||
from zipline import ndict
|
||||
|
||||
from zipline.gens.transform import StatefulTransform
|
||||
from zipline.finance.trading import TransactionSimulator
|
||||
from zipline.finance.performance import PerformanceTracker
|
||||
from zipline.utils.log_utils import stdout_only_pipe
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
class TradeSimulationClient(object):
|
||||
"""
|
||||
Generator that takes the expected output of a merge, a user
|
||||
algorithm, a trading environment, and a simulator style as
|
||||
arguments. Pipes the merge stream through a TransactionSimulator
|
||||
and a PerformanceTracker, which keep track of the current state of
|
||||
our algorithm's simulated universe. Results are fed to the user's
|
||||
algorithm, which directly inserts transactions into the
|
||||
TransactionSimulator's order book.
|
||||
|
||||
TransactionSimulator maintains a dictionary from sids to the
|
||||
unfulfilled orders placed by the user's algorithm. As trade
|
||||
events arrive, if the algorithm has open orders against the
|
||||
trade's sid, the simulator will fill orders up to 25% of market
|
||||
cap. Applied transactions are added to a txn field on the event
|
||||
and forwarded to PerformanceTracker. The txn field is set to None
|
||||
on non-trade events and events that do not match any open orders.
|
||||
|
||||
PerformanceTracker receives the updated event messages from
|
||||
TransactionSimulator, maintaining a set of daily and cumulative
|
||||
performance metrics for the algorithm. The tracker removes the
|
||||
txn field from each event it receives, replacing it with a
|
||||
portfolio field to be fed into the user algo. At the end of each
|
||||
trading day, the PerformanceTracker also generates a daily
|
||||
performance report, which is appended to event's perf_report
|
||||
field.
|
||||
|
||||
Fully processed events are run through a batcher generator, which
|
||||
batches together events with the same dt field into a single event
|
||||
to be fed to the algo. The portfolio object is repeatedly
|
||||
overwritten so that only the most recent snapshot of the universe
|
||||
is sent to the algo.
|
||||
"""
|
||||
|
||||
def __init__(self, algo, environment, sim_style):
|
||||
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
self.environment = environment
|
||||
self.style = sim_style
|
||||
self.algo_sim = None
|
||||
|
||||
def get_hash(self):
|
||||
"""
|
||||
There should only ever be one TSC in the system.
|
||||
"""
|
||||
return self.__class__.__name__ + hash_args()
|
||||
|
||||
def simulate(self, stream_in):
|
||||
"""
|
||||
Main generator work loop.
|
||||
"""
|
||||
|
||||
# Simulate filling any open orders made by the previous run of
|
||||
# the user's algorithm. Sets the txn field to true on any
|
||||
# event that results in a filled order.
|
||||
ordering_client = StatefulTransform(
|
||||
TransactionSimulator,
|
||||
self.sids,
|
||||
style = self.style
|
||||
)
|
||||
with_filled_orders = ordering_client.transform(stream_in)
|
||||
|
||||
# Pipe the events with transactions to perf. This will remove
|
||||
# the txn field added by TransactionSimulator and replace it
|
||||
# with a portfolio object to be passed to the user's
|
||||
# algorithm. Also adds a perf_message field which is usually
|
||||
# none, but contains an update message once per day.
|
||||
perf_tracker = StatefulTransform(
|
||||
PerformanceTracker,
|
||||
self.environment,
|
||||
self.sids
|
||||
)
|
||||
with_portfolio = perf_tracker.transform(with_filled_orders)
|
||||
|
||||
# Pass the messages from perf along with the trading client's
|
||||
# state into the algorithm for simulation. We provide the
|
||||
# trading client so that the algorithm can place new orders
|
||||
# into the client's order book.
|
||||
self.algo_sim = AlgorithmSimulator(
|
||||
with_portfolio,
|
||||
ordering_client.state,
|
||||
self.algo,
|
||||
)
|
||||
|
||||
# The algorithm will yield a daily_results message (as
|
||||
# calculated by the performance tracker) at the end of each
|
||||
# day. It will also yield a risk report at the end of the
|
||||
# simulation.
|
||||
for message in self.algo_sim:
|
||||
yield message
|
||||
|
||||
class AlgorithmSimulator(object):
|
||||
|
||||
def __init__(self, stream_in, order_book, algo):
|
||||
|
||||
self.stream_in = stream_in
|
||||
|
||||
# ==========
|
||||
# Algo Setup
|
||||
# ==========
|
||||
|
||||
# We extract the order book from the txn client so that
|
||||
# the algo can place new orders.
|
||||
self.order_book = order_book
|
||||
|
||||
self.algo = algo
|
||||
self.sids = algo.get_sid_filter()
|
||||
|
||||
# Monkey patch the user algorithm to place orders in the
|
||||
# TransactionSimulator's order book.
|
||||
self.algo.set_order(self.order)
|
||||
self.algo.set_logger(Logger("AlgoLog"))
|
||||
|
||||
|
||||
# ==============
|
||||
# Snapshot Setup
|
||||
# ==============
|
||||
|
||||
# The algorithm's universe as of our most recent event.
|
||||
self.universe = ndict()
|
||||
|
||||
for sid in self.sids:
|
||||
self.universe[sid] = ndict()
|
||||
self.universe.portfolio = None
|
||||
|
||||
# We don't have a datetime for the current snapshot until we
|
||||
# receive a message.
|
||||
self.simulation_dt = None
|
||||
self.this_snapshot_dt = None
|
||||
|
||||
# =============
|
||||
# Logging Setup
|
||||
# =============
|
||||
|
||||
# Processor function for injecting the algo_dt into
|
||||
# user prints/logs.
|
||||
def inject_algo_dt(record):
|
||||
record.extra['algo_dt'] = self.this_snapshot_dt
|
||||
self.processor = Processor(inject_algo_dt)
|
||||
|
||||
# This is a class, which is instantiated later
|
||||
# in run_algorithm. The class provides a generator.
|
||||
self.stdout_capture = stdout_only_pipe
|
||||
|
||||
self.__generator = None
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.__generator:
|
||||
return self.__generator.next()
|
||||
else:
|
||||
self.__generator = self._gen()
|
||||
return self.__generator.next()
|
||||
|
||||
def order(self, sid, amount):
|
||||
"""
|
||||
Closure to pass into the user's algo to allow placing orders
|
||||
into the txn_sim's dict of open orders.
|
||||
"""
|
||||
assert sid in self.sids, "Order on invalid sid: %i" % sid
|
||||
order = ndict({
|
||||
'dt' : self.simulation_dt,
|
||||
'sid' : sid,
|
||||
'amount' : int(amount),
|
||||
'filled' : 0
|
||||
})
|
||||
|
||||
# Tell the user if they try to buy 0 shares of something.
|
||||
if order.amount == 0:
|
||||
zero_message = "Requested to trade zero shares of {sid}".format(
|
||||
sid=order.sid
|
||||
)
|
||||
log.debug(zero_message)
|
||||
# Don't bother placing orders for 0 shares.
|
||||
return
|
||||
|
||||
# Add non-zero orders to the order book.
|
||||
# !!!IMPORTANT SIDE-EFFECT!!!
|
||||
# This modifies the internal state of the transaction
|
||||
# simulator so that it can fill the placed order when it
|
||||
# receives its next message.
|
||||
self.order_book.place_order(order)
|
||||
|
||||
def _gen(self):
|
||||
"""
|
||||
Internal generator work loop.
|
||||
"""
|
||||
# Capture any output of this generator to stdout and pipe it
|
||||
# to a logbook interface. Also inject the current algo
|
||||
# snapshot time to any log record generated.
|
||||
|
||||
with self.processor.threadbound(), self.stdout_capture(Logger('Print'),''):
|
||||
# Call the user's initialize method.
|
||||
self.algo.initialize()
|
||||
|
||||
for event in self.stream_in:
|
||||
# Yield any perf messages received to be relayed back to
|
||||
# the browser.
|
||||
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
|
||||
if event.dt == "DONE":
|
||||
if self.this_snapshot_dt:
|
||||
# stop iteration happened
|
||||
# mid-snapshot, so we have a universe
|
||||
# snapshot that is not yet processed
|
||||
# by the algorithm.
|
||||
self.simulate_current_snapshot()
|
||||
break
|
||||
|
||||
# This should only happen for the first event we run.
|
||||
if self.simulation_dt == None:
|
||||
self.simulation_dt = event.dt
|
||||
|
||||
# ======================
|
||||
# Time Compression Logic
|
||||
# ======================
|
||||
|
||||
if self.this_snapshot_dt != None:
|
||||
self.update_current_snapshot(event)
|
||||
|
||||
# The algorithm has been missing events because it took
|
||||
# too long processing. Update the universe with data from
|
||||
# this event, then check if enough time has passed that we
|
||||
# can start a new snapshot.
|
||||
else:
|
||||
self.update_universe(event)
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
|
||||
|
||||
|
||||
def update_current_snapshot(self, event):
|
||||
"""
|
||||
Update our current snapshot of the universe. Call handle_data if
|
||||
"""
|
||||
# The new event matches our snapshot dt. Just update the
|
||||
# universe and move on.
|
||||
if event.dt == self.this_snapshot_dt:
|
||||
self.update_universe(event)
|
||||
|
||||
# The new event does not match our snapshot.
|
||||
else:
|
||||
self.simulate_current_snapshot()
|
||||
|
||||
# Once we've finished simulating the old snapshot,
|
||||
# we can update the universe with the new event.
|
||||
self.update_universe(event)
|
||||
|
||||
# The current event is later than the simulation time,
|
||||
# which means the algorithm finished quickly enough to
|
||||
# receive the new event. Start a new snapshot with this
|
||||
# event's dt.
|
||||
if event.dt >= self.simulation_dt:
|
||||
self.this_snapshot_dt = event.dt
|
||||
|
||||
# The algorithm spent enough time processing that it
|
||||
# missed the new event. Wait to start a new snapshot until
|
||||
# the events catch up to the algo's simulated dt.
|
||||
else:
|
||||
self.this_snapshot_dt = None
|
||||
|
||||
def simulate_current_snapshot(self):
|
||||
"""
|
||||
Run the user's algo against our current snapshot and update the algo's
|
||||
simulated time.
|
||||
"""
|
||||
start_tic = datetime.now()
|
||||
self.algo.handle_data(self.universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
# How long did you take?
|
||||
delta = stop_tic - start_tic
|
||||
|
||||
# Update the simulation time.
|
||||
self.simulation_dt = self.this_snapshot_dt + delta
|
||||
|
||||
def update_universe(self, event):
|
||||
"""
|
||||
Update the universe with new event information.
|
||||
"""
|
||||
# Update our portfolio.
|
||||
self.universe.portfolio = event.portfolio
|
||||
|
||||
# Update our knowledge of this event's sid
|
||||
for field in event.keys():
|
||||
self.universe[event.sid][field] = event[field]
|
||||
+168
-106
@@ -2,16 +2,21 @@
|
||||
Generator versions of transforms.
|
||||
"""
|
||||
import types
|
||||
import pytz
|
||||
|
||||
from datetime import datetime
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from collections import deque, defaultdict
|
||||
from numbers import Number
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.utils.tradingcalendar import trading_days_between
|
||||
from zipline.gens.utils import assert_sort_unframe_protocol, \
|
||||
assert_transform_protocol, hash_args
|
||||
|
||||
class Passthrough(object):
|
||||
FORWARDER = True
|
||||
"""
|
||||
Trivial class for forwarding events.
|
||||
"""
|
||||
@@ -19,11 +24,9 @@ class Passthrough(object):
|
||||
pass
|
||||
|
||||
def update(self, event):
|
||||
assert isinstance(event, ndict),"Bad event in Passthrough: %s" % event
|
||||
assert event.has_key('sid'), "No sid in Passthrough: %s" % event
|
||||
assert event.has_key('dt'), "No dt in Passthorughz: %s" % event
|
||||
return event
|
||||
pass
|
||||
|
||||
# Deprecated
|
||||
def functional_transform(stream_in, func, *args, **kwargs):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
@@ -40,131 +43,195 @@ def functional_transform(stream_in, func, *args, **kwargs):
|
||||
assert_transform_protocol(out_value)
|
||||
yield(namestring, out_value)
|
||||
|
||||
def stateful_transform(stream_in, tnfm_class, *args, **kwargs):
|
||||
class StatefulTransform(object):
|
||||
"""
|
||||
Generic transform generator that takes each message from an in-stream
|
||||
and sorts it to a state class. For each call to update, the state
|
||||
class must produce a message to be fed downstream.
|
||||
Generic transform generator that takes each message from an
|
||||
in-stream and passes it to a state object. For each call to
|
||||
update, the state class must produce a message to be fed
|
||||
downstream. Any transform class with the FORWARDER class variable
|
||||
set to true will forward all fields in the original message.
|
||||
Otherwise only dt, tnfm_id, and tnfm_value are forwarded.
|
||||
"""
|
||||
|
||||
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
|
||||
def __init__(self, tnfm_class, *args, **kwargs):
|
||||
assert isinstance(tnfm_class, (types.ObjectType, types.ClassType)), \
|
||||
"Stateful transform requires a class."
|
||||
assert tnfm_class.__dict__.has_key('update'), \
|
||||
assert tnfm_class.__dict__.has_key('update'), \
|
||||
"Stateful transform requires the class to have an update method"
|
||||
|
||||
# Create an instance of our transform class.
|
||||
state = tnfm_class(*args, **kwargs)
|
||||
self.forward_all = tnfm_class.__dict__.get('FORWARDER', False)
|
||||
self.update_in_place = tnfm_class.__dict__.get('UPDATER', False)
|
||||
self.append_value = tnfm_class.__dict__.get('APPENDER', False)
|
||||
|
||||
# Generate the string associated with this generator's output.
|
||||
namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
# You only one special behavior mode can be set.
|
||||
assert sum(map(int, [self.forward_all,
|
||||
self.update_in_place,
|
||||
self.append_value])) <= 1
|
||||
|
||||
for message in stream_in:
|
||||
assert_sort_unframe_protocol(message)
|
||||
out_value = state.update(message)
|
||||
assert_transform_protocol(out_value)
|
||||
yield (namestring, out_value)
|
||||
# Create an instance of our transform class.
|
||||
self.state = tnfm_class(*args, **kwargs)
|
||||
|
||||
class MovingAverage(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to EventWindows
|
||||
Upon receipt of each message we update the
|
||||
corresponding window and return the calculated average.
|
||||
# Create the string associated with this generator's output.
|
||||
self.namestring = tnfm_class.__name__ + hash_args(*args, **kwargs)
|
||||
|
||||
def get_hash(self):
|
||||
return self.namestring
|
||||
|
||||
def transform(self, stream_in):
|
||||
return self._gen(stream_in)
|
||||
|
||||
def _gen(self, stream_in):
|
||||
# IMPORTANT: Messages may contain pointers that are shared with
|
||||
# other streams, so we only manipulate copies.
|
||||
|
||||
for message in stream_in:
|
||||
|
||||
# allow upstream generators to yield None to avoid
|
||||
# blocking.
|
||||
if message == None:
|
||||
continue
|
||||
|
||||
#TODO: refactor this to avoid unnecessary copying.
|
||||
|
||||
assert_sort_unframe_protocol(message)
|
||||
message_copy = deepcopy(message)
|
||||
|
||||
# Same shared pointer issue here as above.
|
||||
tnfm_value = self.state.update(deepcopy(message_copy))
|
||||
|
||||
# FORWARDER flag means we want to keep all original
|
||||
# values, plus append tnfm_id and tnfm_value. Used for
|
||||
# preserving the original event fields when our output
|
||||
# will be fed into a merge.
|
||||
if self.forward_all:
|
||||
out_message = message_copy
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# UPDATER flag should be used for transforms that
|
||||
# side-effectfully modify the event they are passed.
|
||||
# Updated messages are passed along exactly as they are
|
||||
# returned to use by our state class. Useful for chaining
|
||||
# specific transforms that won't be fed to a merge. (See
|
||||
# the implementation of TradeSimulationClient for example
|
||||
# usage of this flag with PerformanceTracker and
|
||||
# TransactionSimulator.
|
||||
elif self.update_in_place:
|
||||
yield tnfm_value
|
||||
|
||||
# APPENDER flag should be used to add a single new
|
||||
# key-value pair to the event. The new key is this
|
||||
# transform's namestring, and it's value is the value
|
||||
# returned by state.update(event). This is almost
|
||||
# identical to the behavior of FORWARDER, except we
|
||||
# compress the two calculated values (tnfm_id, and
|
||||
# tnfm_value) into a single field. This mode is used by
|
||||
# the sequential_transforms composite.
|
||||
elif self.append_value:
|
||||
out_message = message_copy
|
||||
out_message[self.namestring] = tnfm_value
|
||||
yield out_message
|
||||
|
||||
# If no flags are set, we create a new message containing
|
||||
# just the tnfm_id, the event's datetime, and the
|
||||
# calculated tnfm_value. This is the default behavior for
|
||||
# a transform being fed into a merge.
|
||||
else:
|
||||
out_message = ndict()
|
||||
out_message.tnfm_id = self.namestring
|
||||
out_message.tnfm_value = tnfm_value
|
||||
out_message.dt = message_copy.dt
|
||||
yield out_message
|
||||
|
||||
class EventWindow:
|
||||
"""
|
||||
Abstract base class for transform classes that calculate iterative
|
||||
metrics on events within a given timedelta. Maintains a list of
|
||||
events that are within a certain timedelta of the most recent
|
||||
tick. Calls self.handle_add(event) for each event added to the
|
||||
window. Calls self.handle_remove(event) for each event removed
|
||||
from the window. Subclass these methods along with init(*args,
|
||||
**kwargs) to calculate metrics over the window.
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
The market_aware flag is used to toggle whether the eventwindow
|
||||
calculates
|
||||
|
||||
See zipline/gens/mavg.py and zipline/gens/vwap.py for example
|
||||
implementations of moving average and volume-weighted average
|
||||
price.
|
||||
"""
|
||||
# Mark this as an abstract base class.
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, market_aware, days = None, delta = None):
|
||||
|
||||
self.market_aware = market_aware
|
||||
self.days = days
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
def create_window(self):
|
||||
"""Factory method for self.sid_windows."""
|
||||
return EventWindow(self.delta, self.fields)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Return an ndict from
|
||||
tracked fields to averages.
|
||||
"""
|
||||
|
||||
assert isinstance(event, ndict),"Bad event in MovingAverage: %s" % event
|
||||
assert event.has_key('sid'), "No sid in MovingAverage: %s" % event
|
||||
assert event.has_key('dt'), "No dt in MovingAverage: %s" % event
|
||||
|
||||
output = ndict({'sid': event.sid, 'dt': event.dt})
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
averages = window.get_averages()
|
||||
|
||||
# Return the calculated averages along with
|
||||
output.merge(averages)
|
||||
return output
|
||||
|
||||
class EventWindow(object):
|
||||
"""
|
||||
Maintains a list of events that are within a certain timedelta
|
||||
of the most recent tick. The expected use of this class is to
|
||||
track events associated with a single sid. We provide simple
|
||||
functionality for averages, but anything more complicated
|
||||
should be handled by a containing class.
|
||||
"""
|
||||
|
||||
def __init__(self, delta, fields):
|
||||
self.ticks = deque()
|
||||
self.delta = delta
|
||||
self.fields = fields
|
||||
self.totals = defaultdict(float)
|
||||
|
||||
# Market-aware mode only works with full-day windows.
|
||||
if self.market_aware:
|
||||
assert self.days and not self.delta,\
|
||||
"Market-aware mode only works with full-day windows."
|
||||
|
||||
# Non-market-aware mode requires a timedelta.
|
||||
else:
|
||||
assert self.delta and not self.days, \
|
||||
"Non-market-aware mode requires a timedelta."
|
||||
|
||||
# Set the behavior for dropping events from the back of the
|
||||
# event window.
|
||||
if self.market_aware:
|
||||
self.drop_condition = self.out_of_market_window
|
||||
else:
|
||||
self.drop_condition = self.out_of_delta
|
||||
|
||||
@abstractmethod
|
||||
def handle_add(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def handle_remove(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ticks)
|
||||
|
||||
def update(self, event):
|
||||
self.assert_well_formed(event)
|
||||
|
||||
# Add new event and increment totals.
|
||||
self.ticks.append(event)
|
||||
for field in self.fields:
|
||||
self.totals[field] += event[field]
|
||||
|
||||
# We return a list of all out-of-range events we removed.
|
||||
out_of_range = []
|
||||
# Subclasses should override handle_add to define behavior for
|
||||
# adding new ticks.
|
||||
self.handle_add(event)
|
||||
|
||||
# Clear out expired events, decrementing totals.
|
||||
# newest oldest
|
||||
# | |
|
||||
# V V
|
||||
# Clear out any expired events. drop_condition changes depending
|
||||
# on whether or not we are running in market_aware mode.
|
||||
#
|
||||
# oldest newest
|
||||
# | |
|
||||
# V V
|
||||
while self.drop_condition(self.ticks[0].dt, self.ticks[-1].dt):
|
||||
|
||||
while (self.ticks[-1].dt - self.ticks[0].dt) >= self.delta:
|
||||
# popleft removes and returns ticks[0]
|
||||
# popleft removes and returns the oldest tick in self.ticks
|
||||
popped = self.ticks.popleft()
|
||||
# Decrement totals
|
||||
for field in self.fields:
|
||||
self.totals[field] -= popped[field]
|
||||
# Add the popped element to the list of dropped events.
|
||||
out_of_range.append(popped)
|
||||
|
||||
return out_of_range
|
||||
# Subclasses should override handle_remove to define
|
||||
# behavior for removing ticks.
|
||||
self.handle_remove(popped)
|
||||
|
||||
def average(self, field):
|
||||
assert field in self.fields
|
||||
if len(self.ticks) == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return self.totals[field] / len(self.ticks)
|
||||
def out_of_market_window(self, oldest, newest):
|
||||
return trading_days_between(oldest, newest) >= self.days
|
||||
|
||||
def get_averages(self):
|
||||
"""
|
||||
Return an ndict of all our tracked averages.
|
||||
"""
|
||||
out = ndict()
|
||||
# out.ticks = len(self.ticks)
|
||||
for field in self.fields:
|
||||
out[field] = self.average(field)
|
||||
return out
|
||||
def out_of_delta(self, oldest, newest):
|
||||
return (newest - oldest) >= self.delta
|
||||
|
||||
# All event windows expect to receive events with datetime fields
|
||||
# that arrive in sorted order.
|
||||
def assert_well_formed(self, event):
|
||||
assert isinstance(event, ndict), "Bad event in EventWindow:%s" % event
|
||||
assert event.has_key('dt'), "Missing dt in EventWindow:%s" % event
|
||||
@@ -173,8 +240,3 @@ class EventWindow(object):
|
||||
# Something is wrong if new event is older than previous.
|
||||
assert event.dt >= self.ticks[-1].dt, \
|
||||
"Events arrived out of order in EventWindow: %s -> %s" % (event, self.ticks[0])
|
||||
for field in self.fields:
|
||||
assert event.has_key(field), \
|
||||
"Event missing [%s] in EventWindow" % field
|
||||
assert isinstance(event[field], Number), \
|
||||
"Got %s for %s in EventWindow" % (event[field], field)
|
||||
|
||||
+44
-7
@@ -1,6 +1,7 @@
|
||||
import pytz
|
||||
import numbers
|
||||
|
||||
from collections import OrderedDict
|
||||
from hashlib import md5
|
||||
from datetime import datetime
|
||||
from itertools import izip_longest
|
||||
@@ -16,8 +17,16 @@ def mock_raw_event(sid, dt):
|
||||
}
|
||||
return event
|
||||
|
||||
def mock_done(source_id):
|
||||
return ndict({'dt': "DONE", "source_id" : source_id, 'type' : 0})
|
||||
def mock_done(id):
|
||||
return ndict({
|
||||
'dt' : "DONE",
|
||||
"source_id" : id,
|
||||
'tnfm_id' : id,
|
||||
'tnfm_value': None,
|
||||
'type' : DATASOURCE_TYPE.DONE
|
||||
})
|
||||
|
||||
done_message = mock_done
|
||||
|
||||
def alternate(g1, g2):
|
||||
"""Specialized version of roundrobin for just 2 generators."""
|
||||
@@ -27,16 +36,25 @@ def alternate(g1, g2):
|
||||
if e2 != None:
|
||||
yield e2
|
||||
|
||||
def roundrobin(*args):
|
||||
def roundrobin(sources, namestrings):
|
||||
"""
|
||||
Takes N generators, pulling one element off each until all inputs
|
||||
are empty.
|
||||
"""
|
||||
for elem_tuple in izip_longest(*args):
|
||||
for value in elem_tuple:
|
||||
if value != None:
|
||||
yield value
|
||||
assert len(sources) == len(namestrings)
|
||||
mapping = OrderedDict(zip(namestrings, sources))
|
||||
|
||||
# While our generators have not been exhausted, pull elements
|
||||
while mapping.keys() != []:
|
||||
for namestring, source in mapping.iteritems():
|
||||
try:
|
||||
message = source.next()
|
||||
# allow sources to yield None to avoid blocking.
|
||||
if message:
|
||||
yield message
|
||||
except StopIteration:
|
||||
yield done_message(namestring)
|
||||
del mapping[namestring]
|
||||
|
||||
def hash_args(*args, **kwargs):
|
||||
"""Define a unique string for any set of representable args."""
|
||||
@@ -48,6 +66,25 @@ def hash_args(*args, **kwargs):
|
||||
hasher.update(combined)
|
||||
return hasher.hexdigest()
|
||||
|
||||
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
|
||||
row = ndict({
|
||||
'source_id' : source_id,
|
||||
'type' : DATASOURCE_TYPE.TRADE,
|
||||
'sid' : sid,
|
||||
'dt' : datetime,
|
||||
'price' : price,
|
||||
'volume' : amount
|
||||
})
|
||||
return row
|
||||
|
||||
def sum_true(bool_iterable):
|
||||
"""
|
||||
Takes an iterable of boolean values and returns the number of
|
||||
those values that are True.
|
||||
"""
|
||||
return sum(map(int, bool_iterable))
|
||||
|
||||
|
||||
def assert_datasource_protocol(event):
|
||||
"""Assert that an event meets the protocol for datasource outputs."""
|
||||
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
from numbers import Number
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from zipline import ndict
|
||||
from zipline.gens.transform import EventWindow
|
||||
|
||||
class VWAP(object):
|
||||
"""
|
||||
Class that maintains a dictionary from sids to VWAPEventWindows.
|
||||
"""
|
||||
def __init__(self, market_aware, delta=None, days=None):
|
||||
|
||||
self.market_aware = market_aware
|
||||
self.delta = delta
|
||||
self.days = days
|
||||
|
||||
# Market-aware mode only works with full-day windows.
|
||||
if self.market_aware:
|
||||
assert self.days and not self.delta,\
|
||||
"Market-aware mode only works with full-day windows."
|
||||
|
||||
# Non-market-aware mode requires a timedelta.
|
||||
else:
|
||||
assert self.delta and not self.days, \
|
||||
"Non-market-aware mode requires a timedelta."
|
||||
|
||||
# No way to pass arguments to the defaultdict factory, so we
|
||||
# need to define a method to generate the correct EventWindows.
|
||||
self.sid_windows = defaultdict(self.create_window)
|
||||
|
||||
def create_window(self):
|
||||
"""Factory method for self.sid_windows."""
|
||||
return VWAPEventWindow(
|
||||
self.market_aware,
|
||||
days = self.days,
|
||||
delta = self.delta
|
||||
)
|
||||
|
||||
def update(self, event):
|
||||
"""
|
||||
Update the event window for this event's sid. Returns the
|
||||
current vwap for the sid.
|
||||
"""
|
||||
# This will create a new EventWindow if this is the first
|
||||
# message for this sid.
|
||||
window = self.sid_windows[event.sid]
|
||||
window.update(event)
|
||||
return window.get_vwap()
|
||||
|
||||
class VWAPEventWindow(EventWindow):
|
||||
"""
|
||||
Iteratively maintains a vwap for a single sid over a given
|
||||
timedelta.
|
||||
"""
|
||||
def __init__(self, market_aware, days=None, delta=None):
|
||||
EventWindow.__init__(self, market_aware, days, delta)
|
||||
self.flux = 0.0
|
||||
self.totalvolume = 0.0
|
||||
|
||||
# Subclass customization for adding new events.
|
||||
def handle_add(self, event):
|
||||
# Sanity check on the event.
|
||||
self.assert_required_fields(event)
|
||||
self.flux += event.volume * event.price
|
||||
self.totalvolume += event.volume
|
||||
|
||||
# Subclass customization for removing expired events.
|
||||
def handle_remove(self, event):
|
||||
self.flux -= event.volume * event.price
|
||||
self.totalvolume -= event.volume
|
||||
|
||||
def get_vwap(self):
|
||||
"""
|
||||
Return the calculated vwap for this sid.
|
||||
"""
|
||||
# By convention, vwap is None if we have no events.
|
||||
if len(self.ticks) == 0:
|
||||
return None
|
||||
else:
|
||||
return (self.flux / self.totalvolume)
|
||||
|
||||
# We need numerical price and volume to calculate a vwap.
|
||||
def assert_required_fields(self, event):
|
||||
assert isinstance(event.price, Number)
|
||||
assert isinstance(event.volume, Number)
|
||||
@@ -2,15 +2,17 @@ import zmq
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
def gen_from_zmq(poller, unframe):
|
||||
def gen_from_zmq(poller, unframe, namestring):
|
||||
"""
|
||||
A generator that takes an initialized zmq poller and yields
|
||||
messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE.
|
||||
"""
|
||||
while True:
|
||||
message = poller.recv()
|
||||
if message = zp.CONTROL_PROTOCOL.DONE:
|
||||
yield "DONE"
|
||||
# Done protocol should now be a message type so that
|
||||
# done messages can also have source_ids.
|
||||
if message.type == zp.CONTROL_PROTOCOL.DONE:
|
||||
yield done_message(message.source_id)
|
||||
break
|
||||
else:
|
||||
yield unframe(message)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import zmq
|
||||
import zipline.protocol as zp
|
||||
|
||||
def gen_from_pull_socket(socket_uri, context, unframe):
|
||||
"""
|
||||
A generator that takes a socket_uri, and yields
|
||||
messages from the poller until it gets a zp.CONTROL_PROTOCOL.DONE.
|
||||
"""
|
||||
pull_socket = context.socket(zmq.PULL)
|
||||
pull_socket.connect(socket_uri)
|
||||
poller = zmq.Poller()
|
||||
poller.register(pull_socket, zmq.POLLIN)
|
||||
|
||||
return gen_from_poller(poller, pull_socket, unframe)
|
||||
|
||||
|
||||
# this generator needs to know about the source_ids coming in via
|
||||
# the poller, and need to yield DONE messages for each
|
||||
# source_id.
|
||||
+188
-233
@@ -59,106 +59,190 @@ before invoking simulate.
|
||||
| __init__. |
|
||||
+---------------------------------+
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logbook
|
||||
import zipline.utils.factory as factory
|
||||
|
||||
from zipline.components import DataSource
|
||||
from zipline.transforms import BaseTransform
|
||||
import sys
|
||||
import zmq
|
||||
import os
|
||||
from signal import SIGHUP, SIGINT
|
||||
import multiprocessing
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
from zipline.components import TradeSimulationClient
|
||||
from zipline.core.process import ProcessSimulator
|
||||
from zipline.core.monitor import Controller
|
||||
from zipline.finance.trading import SIMULATION_STYLE
|
||||
from zipline.utils.log_utils import ZeroMQLogHandler, stdout_only_pipe
|
||||
from zipline.utils import factory
|
||||
|
||||
log = logbook.Logger('Lines')
|
||||
from zipline.test_algorithms import TestAlgorithm
|
||||
|
||||
from zipline.gens.composites import \
|
||||
date_sorted_sources, merged_transforms, sequential_transforms
|
||||
from zipline.gens.transform import Passthrough, StatefulTransform
|
||||
from zipline.gens.tradesimulation import TradeSimulationClient as tsc
|
||||
from logbook import Logger, NestedSetup, Processor
|
||||
|
||||
import zipline.protocol as zp
|
||||
|
||||
log = Logger('Lines')
|
||||
|
||||
|
||||
class CancelSignal(Exception):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
class SimulatedTrading(object):
|
||||
"""
|
||||
Zipline with::
|
||||
|
||||
- _no_ data sources.
|
||||
- Trade simulation client, which is available to send callbacks on
|
||||
events and also accept orders to be simulated.
|
||||
- An order data source, which will receive orders from the trade
|
||||
simulation client, and feed them into the event stream to be
|
||||
serialized and order alongside all other data source events.
|
||||
- transaction simulation transformation, which receives the order
|
||||
events and estimates a theoretical execution price and volume.
|
||||
def __init__(self,
|
||||
sources,
|
||||
transforms,
|
||||
algorithm,
|
||||
environment,
|
||||
style,
|
||||
results_socket_uri,
|
||||
context,
|
||||
sim_id):
|
||||
|
||||
All components in this zipline are subject to heartbeat checks and
|
||||
a control monitor, which can kill the entire zipline in the event of
|
||||
exceptions in one of the components or an external request to end the
|
||||
simulation.
|
||||
"""
|
||||
self.date_sorted = date_sorted_sources(*sources)
|
||||
self.transforms = transforms
|
||||
# Formerly merged_transforms.
|
||||
self.with_tnfms = sequential_transforms(self.date_sorted, *self.transforms)
|
||||
self.trading_client = tsc(algorithm, environment, style)
|
||||
self.gen = self.trading_client.simulate(self.with_tnfms)
|
||||
self.results_uri = results_socket_uri
|
||||
self.results_socket = None
|
||||
self.context = context
|
||||
self.sim_id = sim_id
|
||||
|
||||
def __init__(self, **config):
|
||||
# optional process if we fork simulate into an
|
||||
# independent process.
|
||||
self.proc = None
|
||||
self.send_sighup = False
|
||||
self.logger = Logger(sim_id)
|
||||
self.print_logger = Logger('Print')
|
||||
|
||||
# exit status flag
|
||||
self.success = False
|
||||
|
||||
|
||||
def simulate(self, blocking=True, send_sighup=False):
|
||||
|
||||
# for non-blocking,
|
||||
if blocking:
|
||||
self.run_gen()
|
||||
else:
|
||||
self.send_sighup = send_sighup
|
||||
return self.fork_and_sim()
|
||||
|
||||
def fork_and_sim(self):
|
||||
self.proc = multiprocessing.Process(target=self.run_gen)
|
||||
self.proc.start()
|
||||
return self.proc
|
||||
|
||||
def run_gen(self):
|
||||
setproctitle(self.sim_id)
|
||||
self.open()
|
||||
if self.zmq_out:
|
||||
with self.zmq_out.threadbound():
|
||||
self.stream_results()
|
||||
# if no log socket, just run the algo normally
|
||||
else:
|
||||
self.stream_results()
|
||||
|
||||
def stream_results(self):
|
||||
assert self.results_socket, \
|
||||
"Results socket must exist to stream results"
|
||||
try:
|
||||
for event in self.gen:
|
||||
if event.has_key('daily_perf'):
|
||||
msg = zp.PERF_FRAME(event)
|
||||
else:
|
||||
msg = zp.RISK_FRAME(event)
|
||||
self.results_socket.send(msg)
|
||||
|
||||
self.signal_done()
|
||||
self.success = True
|
||||
except Exception as exc:
|
||||
self.handle_exception(exc)
|
||||
finally:
|
||||
# not much to do besides log our exit.
|
||||
self.close()
|
||||
|
||||
def signal_done(self):
|
||||
# notify monitor we're done
|
||||
done_frame = zp.DONE_FRAME('success')
|
||||
self.results_socket.send(done_frame)
|
||||
|
||||
def close(self):
|
||||
log.info("Closing Simulation: {id}".format(id=self.sim_id))
|
||||
if self.proc and self.send_sighup:
|
||||
ppid = os.getppid()
|
||||
if self.success:
|
||||
log.warning("Sending SIGHUP")
|
||||
os.kill(ppid, SIGHUP)
|
||||
else:
|
||||
log.warning("Sending SIGINT")
|
||||
os.kill(ppid, SIGINT)
|
||||
|
||||
def handle_exception(self, exc):
|
||||
if isinstance(exc, CancelSignal):
|
||||
# signal from monitor of an orderly shutdown,
|
||||
# do nothing.
|
||||
pass
|
||||
else:
|
||||
self.signal_exception(exc)
|
||||
|
||||
def signal_exception(self, exc=None):
|
||||
"""
|
||||
:param config: a dict with the following required properties::
|
||||
All exceptions inside any component should boil back to
|
||||
this handler.
|
||||
|
||||
- algorithm: a class that follows the algorithm protocol. See
|
||||
:py:meth:`zipline.finance.trading.TradeSimulationClient.add_algorithm
|
||||
for details.
|
||||
- trading_environment: an instance of
|
||||
:py:class:`zipline.trading.TradingEnvironment`
|
||||
- allocator: an instance of
|
||||
:py:class:`zipline.simulator.AddressAllocator`
|
||||
- simulation_style: optional parameter that configures the
|
||||
:py:class:`zipline.finance.trading.TransactionSimulator`. Expects
|
||||
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
|
||||
Will inform the system that the component has failed and how it
|
||||
has failed.
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
self.algorithm = config['algorithm']
|
||||
self.allocator = config['allocator']
|
||||
self.trading_environment = config['trading_environment']
|
||||
self.sim_style = config.get('simulation_style')
|
||||
self.send_sighup = config.get('send_sighup', False)
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
|
||||
try:
|
||||
log.exception('{id} sending exception to result stream.'\
|
||||
.format(id=self.sim_id))
|
||||
msg = zp.EXCEPTION_FRAME(
|
||||
exc_traceback,
|
||||
exc_type.__name__,
|
||||
exc_value.message
|
||||
)
|
||||
|
||||
self.leased_sockets = []
|
||||
self.sim_context = None
|
||||
self.results_socket.send(msg)
|
||||
|
||||
sockets = self.allocate_sockets(7)
|
||||
addresses = {
|
||||
'sync_address' : sockets[0],
|
||||
'data_address' : sockets[1],
|
||||
'feed_address' : sockets[2],
|
||||
'merge_address' : sockets[3],
|
||||
# TODO: this refers to the results of the merge, a
|
||||
# horribly confusing name for the socket.
|
||||
'results_address' : sockets[4],
|
||||
}
|
||||
except:
|
||||
log.exception("Exception while reporting simulation exception.")
|
||||
|
||||
self.con = Controller(
|
||||
sockets[5],
|
||||
sockets[6],
|
||||
self.send_sighup
|
||||
def open(self):
|
||||
if not self.context:
|
||||
self.context = zmq.Context()
|
||||
if self.results_uri:
|
||||
sock = self.context.socket(zmq.PUSH)
|
||||
sock.connect(self.results_uri)
|
||||
self.results_socket = sock
|
||||
self.setup_logging()
|
||||
|
||||
def setup_logging(self):
|
||||
assert self.results_socket
|
||||
# The filter behavior is: matches are logged, mismatches
|
||||
# are bubbled. If bubble is True, matches are also
|
||||
# bubbled. Since we do not want user logs in our system
|
||||
# logs, we set bubble to False.
|
||||
self.zmq_out = ZeroMQLogHandler(
|
||||
socket = self.results_socket,
|
||||
filter = lambda r, h: r.channel in ['Print', 'AlgoLog'],
|
||||
bubble=False
|
||||
)
|
||||
|
||||
self.started = False
|
||||
def join(self):
|
||||
if self.proc:
|
||||
self.proc.join()
|
||||
|
||||
self.sim = ProcessSimulator(addresses)
|
||||
|
||||
self.clients = {}
|
||||
|
||||
self.trading_client = TradeSimulationClient(
|
||||
self.trading_environment,
|
||||
self.sim_style,
|
||||
config['results_socket']
|
||||
)
|
||||
self.add_client(self.trading_client)
|
||||
|
||||
# setup all sources
|
||||
self.sources = {}
|
||||
|
||||
#setup transforms
|
||||
self.transforms = {}
|
||||
|
||||
self.sim.register_controller( self.con )
|
||||
|
||||
self.trading_client.set_algorithm(self.algorithm)
|
||||
def get_pids(self):
|
||||
if self.proc:
|
||||
return [self.proc.pid]
|
||||
else:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def create_test_zipline(**config):
|
||||
@@ -167,7 +251,6 @@ class SimulatedTrading(object):
|
||||
|
||||
- environment - a \
|
||||
:py:class:`zipline.finance.trading.TradingEnvironment`
|
||||
- allocator - a :py:class:`zipline.simulator.AddressAllocator`
|
||||
- sid - an integer, which will be used as the security ID.
|
||||
- order_count - the number of orders the test algo will place,
|
||||
defaults to 100
|
||||
@@ -182,10 +265,10 @@ class SimulatedTrading(object):
|
||||
- simulation_style: optional parameter that configures the
|
||||
:py:class:`zipline.finance.trading.TransactionSimulator`. Expects
|
||||
a SIMULATION_STYLE as defined in :py:mod:`zipline.finance.trading`
|
||||
- transforms: optional parameter that provides a list
|
||||
of StatefulTransform objects.
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
|
||||
allocator = config['allocator']
|
||||
sid = config['sid']
|
||||
|
||||
#--------------------
|
||||
@@ -217,6 +300,10 @@ class SimulatedTrading(object):
|
||||
if not simulation_style:
|
||||
simulation_style = SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
zmq_context = config.get('zmq_context', None)
|
||||
simulation_id = config.get('simulation_id', 'test_simulation')
|
||||
results_socket_uri = config.get('results_socket_uri', None)
|
||||
|
||||
#-------------------
|
||||
# Trade Source
|
||||
#-------------------
|
||||
@@ -230,6 +317,12 @@ class SimulatedTrading(object):
|
||||
trade_count,
|
||||
trading_environment
|
||||
)
|
||||
|
||||
#-------------------
|
||||
# Transforms
|
||||
#-------------------
|
||||
transforms = config.get('transforms', [])
|
||||
|
||||
#-------------------
|
||||
# Create the Algo
|
||||
#-------------------
|
||||
@@ -242,157 +335,19 @@ class SimulatedTrading(object):
|
||||
order_count
|
||||
)
|
||||
|
||||
if config.has_key('results_socket'):
|
||||
results_socket = config['results_socket']
|
||||
else:
|
||||
results_socket = None
|
||||
#-------------------
|
||||
# Simulation
|
||||
#-------------------
|
||||
zipline = SimulatedTrading(**{
|
||||
'algorithm' : test_algo,
|
||||
'trading_environment' : trading_environment,
|
||||
'allocator' : allocator,
|
||||
'simulation_style' : simulation_style,
|
||||
'results_socket' : results_socket,
|
||||
})
|
||||
|
||||
sim = SimulatedTrading(
|
||||
[trade_source],
|
||||
transforms,
|
||||
test_algo,
|
||||
trading_environment,
|
||||
simulation_style,
|
||||
results_socket_uri,
|
||||
zmq_context,
|
||||
simulation_id)
|
||||
#-------------------
|
||||
|
||||
zipline.add_source(trade_source)
|
||||
|
||||
return zipline
|
||||
|
||||
def add_source(self, source):
|
||||
"""
|
||||
Adds the source to the zipline, sets the sid filter of the
|
||||
source to the algorithm's sid filter.
|
||||
"""
|
||||
assert isinstance(source, DataSource)
|
||||
self.check_started()
|
||||
source.set_filter('sid', self.algorithm.get_sid_filter())
|
||||
self.sim.register_components([source])
|
||||
|
||||
# ``id`` is name of source_id, ``get_id`` is the class name
|
||||
self.sources[source.get_id] = source
|
||||
|
||||
def add_transform(self, transform):
|
||||
assert isinstance(transform, BaseTransform)
|
||||
self.check_started()
|
||||
self.sim.register_components([transform])
|
||||
self.transforms[transform.get_id] = transform
|
||||
|
||||
def add_client(self, client):
|
||||
assert isinstance(client, TradeSimulationClient)
|
||||
self.check_started()
|
||||
self.sim.register_components([client])
|
||||
self.clients[client.get_id] = client
|
||||
|
||||
def check_started(self):
|
||||
if self.started:
|
||||
raise ZiplineException("TradeSimulation", "You cannot add \
|
||||
components after the simulation has begun.")
|
||||
|
||||
def get_cumulative_performance(self):
|
||||
return self.trading_client.perf.cumulative_performance.to_dict()
|
||||
|
||||
def allocate_sockets(self, n):
|
||||
"""
|
||||
Allocate sockets local to this line, track them so
|
||||
we can gc after test run.
|
||||
"""
|
||||
|
||||
assert isinstance(n, int)
|
||||
assert n > 0
|
||||
|
||||
leased = self.allocator.lease(n)
|
||||
self.leased_sockets.extend(leased)
|
||||
|
||||
return leased
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
"""
|
||||
Return the component instances inside of this topology
|
||||
"""
|
||||
|
||||
base = set(self.sim.components.values())
|
||||
transforms = set(self.transforms.values())
|
||||
sources = set(self.sources.values())
|
||||
|
||||
return base | transforms | sources
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
"""
|
||||
Returns the Component names in the topology of the
|
||||
backtest.
|
||||
"""
|
||||
|
||||
# A complete topology is the union of three classes of
|
||||
# components added individually to the simulation client
|
||||
# at various places.
|
||||
#
|
||||
# base : ['FEED', 'MERGE', 'TRADING_CLIENT', 'PASSTHROUGH']
|
||||
# transforms : ['vwap__01', ... ]
|
||||
# sources : ['MongoTradeHistory', ... ]
|
||||
|
||||
base = set(self.sim.components.keys())
|
||||
transforms = set(self.transforms.keys())
|
||||
sources = set(self.sources.keys())
|
||||
|
||||
return base | transforms | sources
|
||||
|
||||
def setup_controller(self):
|
||||
"""
|
||||
Prepare the controller to manage the topology specified
|
||||
by this line.
|
||||
"""
|
||||
self.con.manage(self.topology)
|
||||
|
||||
def simulate(self, blocking=True):
|
||||
self.setup_controller()
|
||||
|
||||
self.started = True
|
||||
self.sim_context = self.sim.simulate()
|
||||
|
||||
# If we're using a threaded simulator block on the pool
|
||||
# of thread since we're only ever in a test and we don't
|
||||
# generally monitor the state of the system as a hold at
|
||||
# the supervisory layer
|
||||
|
||||
# TODO: better way of identifying concurrency substrate
|
||||
if blocking:
|
||||
for process in self.sim.subprocesses:
|
||||
process.join()
|
||||
|
||||
@property
|
||||
def is_success(self):
|
||||
# TODO: other assertions?
|
||||
if self.sim.did_clean_shutdown():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
#--------------------------------
|
||||
# Component property accessors
|
||||
#--------------------------------
|
||||
|
||||
def get_positions(self):
|
||||
"""
|
||||
returns current positions as a dict. draws from the cumulative
|
||||
performance period in the performance tracker.
|
||||
"""
|
||||
perf = self.trading_client.perf.cumulative_performance
|
||||
positions = perf.get_positions()
|
||||
return positions
|
||||
|
||||
class ZiplineException(Exception):
|
||||
def __init__(self, zipline_name, msg):
|
||||
self.name = zipline_name
|
||||
self.message = msg
|
||||
|
||||
def __str__(self):
|
||||
return "Unexpected exception {line}: {msg}".format(
|
||||
line=self.name,
|
||||
msg=self.message
|
||||
)
|
||||
return sim
|
||||
|
||||
+12
-2
@@ -131,7 +131,7 @@ from utils.date_utils import EPOCH, UN_EPOCH, epoch_now
|
||||
# Control Protocol
|
||||
# -----------------------
|
||||
|
||||
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION', 'CANCEL']
|
||||
PRODUCTION_PREFIXES = ['PERF', 'RISK', 'EXCEPTION','CANCEL','DONE', 'LOG']
|
||||
|
||||
INVALID_CONTROL_FRAME = FrameExceptionFactory('CONTROL')
|
||||
|
||||
@@ -527,11 +527,15 @@ def EXCEPTION_FRAME(exception_tb, name, message):
|
||||
rlist = []
|
||||
for stack in stack_list:
|
||||
filename = shorten_filename(stack[0])
|
||||
# default the line to empty string rather than None
|
||||
line = ''
|
||||
if stack[3]:
|
||||
line = stack[3]
|
||||
rstack = {
|
||||
'filename' : filename,
|
||||
'lineno' : stack[1],
|
||||
'method' : stack[2],
|
||||
'line' : stack[3]
|
||||
'line' : line
|
||||
}
|
||||
rlist.append(rstack)
|
||||
result = {
|
||||
@@ -570,6 +574,12 @@ def CANCEL_FRAME(date):
|
||||
|
||||
return BT_UPDATE_FRAME('CANCEL', result)
|
||||
|
||||
def DONE_FRAME(msg):
|
||||
assert isinstance(msg, basestring), \
|
||||
"Done message must be a string."
|
||||
|
||||
return BT_UPDATE_FRAME('DONE', msg)
|
||||
|
||||
|
||||
def BT_UPDATE_FRAME(prefix, payload):
|
||||
"""
|
||||
|
||||
@@ -221,6 +221,32 @@ class DivByZeroAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
class TimeoutAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
self.sid = sid
|
||||
self.incr = 0
|
||||
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def set_order(self, order_callable):
|
||||
pass
|
||||
|
||||
def set_logger(self, logger):
|
||||
pass
|
||||
|
||||
def set_portfolio(self, portfolio):
|
||||
pass
|
||||
|
||||
def handle_data(self, data):
|
||||
if self.incr > 4:
|
||||
import time
|
||||
time.sleep(100)
|
||||
pass
|
||||
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
class TestPrintAlgorithm():
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ HOLIDAYS = {
|
||||
'july_4th' : datetime(2008 , 7 , 4 ),
|
||||
'labor_day' : datetime(2008 , 9 , 1 ),
|
||||
'tgiving' : datetime(2008 , 11 , 27),
|
||||
'christmas' : datetime(2008 , 5 , 25),
|
||||
'christmas' : datetime(2008 , 12 , 25),
|
||||
}
|
||||
|
||||
# Create a rule to recur every weekday starting today
|
||||
|
||||
+15
-35
@@ -12,7 +12,9 @@ from datetime import datetime, timedelta
|
||||
import zipline.finance.risk as risk
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.finance.sources import SpecificEquityTrades, RandomEquityTrades
|
||||
from zipline.gens.tradegens import RandomEquityTrades
|
||||
from zipline.gens.tradegens import SpecificEquityTrades
|
||||
from zipline.gens.utils import create_trade
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
|
||||
# TODO
|
||||
@@ -69,16 +71,6 @@ def create_trading_environment(year=2006):
|
||||
|
||||
return trading_environment
|
||||
|
||||
def create_trade(sid, price, amount, datetime, source_id = "test_factory"):
|
||||
row = zp.ndict({
|
||||
'source_id' : source_id,
|
||||
'type' : zp.DATASOURCE_TYPE.TRADE,
|
||||
'sid' : sid,
|
||||
'dt' : datetime,
|
||||
'price' : price,
|
||||
'volume' : amount
|
||||
})
|
||||
return row
|
||||
|
||||
def get_next_trading_dt(current, interval, trading_calendar):
|
||||
next = current
|
||||
@@ -89,8 +81,6 @@ def get_next_trading_dt(current, interval, trading_calendar):
|
||||
|
||||
return next
|
||||
|
||||
|
||||
|
||||
def create_trade_history(sid, prices, amounts, interval, trading_calendar):
|
||||
trades = []
|
||||
current = trading_calendar.first_open
|
||||
@@ -222,29 +212,19 @@ def create_minutely_trade_source(sids, trade_count, trading_environment):
|
||||
)
|
||||
|
||||
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment):
|
||||
trade_history = []
|
||||
|
||||
price = [10.1] * trade_count
|
||||
volume = [100] * trade_count
|
||||
args = tuple()
|
||||
kwargs = {
|
||||
'count' : trade_count,
|
||||
'sids' : sids,
|
||||
'start' : trading_environment.first_open,
|
||||
'delta' : trade_time_increment,
|
||||
'filter' : sids
|
||||
}
|
||||
source = SpecificEquityTrades(*args, **kwargs)
|
||||
|
||||
for sid in sids:
|
||||
start_date = trading_environment.first_open
|
||||
# TODO: do we need to set the trading environment's end to same dt as
|
||||
# the last trade in the history?
|
||||
#trading_environment.period_end = trade_history[-1].dt
|
||||
|
||||
generated_trades = create_trade_history(
|
||||
sid,
|
||||
price,
|
||||
volume,
|
||||
trade_time_increment,
|
||||
trading_environment
|
||||
)
|
||||
|
||||
trade_history.extend(generated_trades)
|
||||
|
||||
trade_history = sorted(trade_history, key=attrgetter('dt'))
|
||||
|
||||
#set the trading environment's end to same dt as the last trade in the
|
||||
#history.
|
||||
trading_environment.period_end = trade_history[-1].dt
|
||||
|
||||
source = SpecificEquityTrades(trade_history)
|
||||
return source
|
||||
|
||||
@@ -120,12 +120,6 @@ class ZeroMQLogHandler(Handler):
|
||||
#can't be serialized by JSON, so we need to convert to
|
||||
#unix epoch representation.
|
||||
|
||||
if record.time:
|
||||
assert isinstance(record.time, datetime.datetime)
|
||||
|
||||
time = record.time.replace(tzinfo = pytz.utc)
|
||||
#logbook measures time in utc already, no need to convert.
|
||||
record.time = EPOCH(time)
|
||||
|
||||
#Do the same if algo_dt is a datetime object.
|
||||
if record.extra.has_key('algo_dt'):
|
||||
@@ -151,6 +145,14 @@ class ZeroMQLogHandler(Handler):
|
||||
data[field] = record.extra[field]
|
||||
else:
|
||||
data[field] = None
|
||||
|
||||
if data['time']:
|
||||
assert isinstance(data['time'], datetime.datetime)
|
||||
|
||||
time = data['time'].replace(tzinfo = pytz.utc)
|
||||
#logbook measures time in utc already, no need to convert.
|
||||
data['time'] = EPOCH(time)
|
||||
|
||||
return data
|
||||
|
||||
def emit(self, record):
|
||||
|
||||
+81
-19
@@ -1,11 +1,14 @@
|
||||
import multiprocessing
|
||||
import zmq
|
||||
import time
|
||||
import zipline.protocol as zp
|
||||
from datetime import datetime
|
||||
import blist
|
||||
from bson import ObjectId
|
||||
from zipline.utils.date_utils import EPOCH
|
||||
from itertools import izip
|
||||
from logbook import FileHandler
|
||||
from zipline.core.monitor import Monitor
|
||||
|
||||
def setup_logger(test, path='/var/log/zipline/zipline.log'):
|
||||
test.log_handler = FileHandler(path)
|
||||
@@ -29,6 +32,7 @@ def check_dict(test, a, b, label):
|
||||
# ignore the extra fields used by dictshield
|
||||
if key in ['progress']:
|
||||
continue
|
||||
|
||||
test.assertTrue(a.has_key(key), "missing key at: " + label + "." + key)
|
||||
test.assertTrue(b.has_key(key), "missing key at: " + label + "." + key)
|
||||
a_val = a[key]
|
||||
@@ -58,15 +62,22 @@ def check(test, a, b, label=None):
|
||||
else:
|
||||
test.assertEqual(a, b, "mismatch on path: " + label)
|
||||
|
||||
def check_excluded(test, a, excluded_keys=[]):
|
||||
for key, value in a.iteritems():
|
||||
test.assertTrue(key not in excluded_keys)
|
||||
test.assertFalse(key.endswith('_id'), 'Avoid _id fields!')
|
||||
test.assertFalse(isinstance(value, ObjectId))
|
||||
if isinstance(value, dict):
|
||||
check_excluded(test, value, excluded_keys)
|
||||
|
||||
def drain_zipline(test, zipline):
|
||||
def drain_zipline(test, zipline, p_blocking=False):
|
||||
assert test.ctx, "method expects a valid zmq context"
|
||||
assert test.zipline_test_config, "method expects a valid test config"
|
||||
assert isinstance(test.zipline_test_config, dict)
|
||||
assert test.zipline_test_config['results_socket'], \
|
||||
assert test.zipline_test_config['results_socket_uri'], \
|
||||
"need to specify a socket address for logs/perf/risk"
|
||||
test.receiver = create_receiver(
|
||||
test.zipline_test_config['results_socket'],
|
||||
test.zipline_test_config['results_socket_uri'],
|
||||
test.ctx
|
||||
)
|
||||
# Bind and connect are asynch, so allow time for bind before
|
||||
@@ -74,13 +85,12 @@ def drain_zipline(test, zipline):
|
||||
time.sleep(1)
|
||||
|
||||
# start the simulation
|
||||
zipline.simulate(blocking=False)
|
||||
zipline.simulate(blocking=p_blocking)
|
||||
output, transaction_count = drain_receiver(test.receiver)
|
||||
# some processes will exit after the message stream is
|
||||
# finished. We block here to avoid collisions with subsequent
|
||||
# ziplines.
|
||||
for process in zipline.sim.subprocesses:
|
||||
process.join()
|
||||
zipline.join()
|
||||
|
||||
return output, transaction_count
|
||||
|
||||
@@ -95,16 +105,15 @@ def drain_receiver(receiver):
|
||||
transaction_count = 0
|
||||
while True:
|
||||
msg = receiver.recv()
|
||||
if msg == str(zp.CONTROL_PROTOCOL.DONE):
|
||||
update = zp.BT_UPDATE_UNFRAME(msg)
|
||||
output.append(update)
|
||||
if update['prefix'] == 'PERF':
|
||||
transaction_count += \
|
||||
len(update['payload']['daily_perf']['transactions'])
|
||||
elif update['prefix'] == 'EXCEPTION':
|
||||
break
|
||||
elif update['prefix'] == 'DONE':
|
||||
break
|
||||
else:
|
||||
update = zp.BT_UPDATE_UNFRAME(msg)
|
||||
output.append(update)
|
||||
if update['prefix'] == 'PERF':
|
||||
transaction_count += \
|
||||
len(update['payload']['daily_perf']['transactions'])
|
||||
elif update['prefix'] == 'EXCEPTION':
|
||||
break
|
||||
|
||||
receiver.close()
|
||||
del receiver
|
||||
@@ -115,9 +124,6 @@ def drain_receiver(receiver):
|
||||
def assert_single_position(test, zipline):
|
||||
output, transaction_count = drain_zipline(test, zipline)
|
||||
|
||||
test.assertTrue(zipline.sim.ready())
|
||||
test.assertFalse(zipline.sim.exception)
|
||||
|
||||
test.assertEqual(
|
||||
test.zipline_test_config['order_count'],
|
||||
transaction_count
|
||||
@@ -126,7 +132,8 @@ def assert_single_position(test, zipline):
|
||||
# the final message is the risk report, the second to
|
||||
# last is the final day's results. Positions is a list of
|
||||
# dicts.
|
||||
closing_positions = output[-2]['payload']['daily_perf']['positions']
|
||||
perfs = [x for x in output if x['prefix'] == 'PERF']
|
||||
closing_positions = perfs[-2]['payload']['daily_perf']['positions']
|
||||
|
||||
test.assertEqual(
|
||||
len(closing_positions),
|
||||
@@ -140,3 +147,58 @@ def assert_single_position(test, zipline):
|
||||
sid,
|
||||
"Portfolio should have one position in " + str(sid)
|
||||
)
|
||||
|
||||
|
||||
def launch_component(component):
|
||||
proc = multiprocessing.Process(target=component.run)
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
def launch_monitor(monitor):
|
||||
proc = multiprocessing.Process(target=monitor.run)
|
||||
proc.start()
|
||||
return proc
|
||||
|
||||
|
||||
def create_monitor(allocator):
|
||||
sockets = allocator.lease(3)
|
||||
mon = Monitor(
|
||||
# pub socket
|
||||
sockets[0],
|
||||
# route socket
|
||||
sockets[1],
|
||||
# exception socket to match tradesimclient's result
|
||||
# socket, because we want to relay exceptions to the
|
||||
# same listener
|
||||
sockets[2],
|
||||
# this controller is expected to run in a test, so no
|
||||
# need to signal the parent process on success or error.
|
||||
send_sighup=False
|
||||
)
|
||||
|
||||
return mon
|
||||
|
||||
class ExceptionSource(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_hash(self):
|
||||
return "ExceptionSource"
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
5 / 0
|
||||
|
||||
class ExceptionTransform(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_hash(self):
|
||||
return "ExceptionTransform"
|
||||
|
||||
def update(self, event):
|
||||
assert False, "An assertion message"
|
||||
|
||||
@@ -0,0 +1,358 @@
|
||||
import pytz
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from dateutil import rrule
|
||||
from zipline.utils.date_utils import utcnow
|
||||
|
||||
def market_opens(start, end, inclusive=False):
|
||||
"""
|
||||
Returns all market opens between the start date and the end date.
|
||||
Must use utc-stamped datetimes.
|
||||
"""
|
||||
return opens.between(start, end, inc=inclusive)
|
||||
|
||||
def market_closes(start, end, inclusive=False):
|
||||
"""
|
||||
Returns all market closes between the start date and the end date.
|
||||
Must use utc-stamped datetimes.
|
||||
"""
|
||||
return closes.between(start, end, inc=inclusive)
|
||||
|
||||
def trading_days_between(start, end):
|
||||
"""
|
||||
Calculate the number of "complete" trading days between two
|
||||
events. We define this as the number of market opens that
|
||||
occurred between start and end, with the caveat that we subtract 1
|
||||
from this total if end falls on the same day as the last market
|
||||
open and end occurs earlier in its own day than start. This
|
||||
reflects the fact that we haven't completed a full day
|
||||
corresponding to the last market open.
|
||||
|
||||
Examples:
|
||||
|
||||
1.)
|
||||
start = Tuesday, Aug 7, 2012, 1:00 pm
|
||||
end = Wednesday, Aug 8, 2012, 1:30 pm
|
||||
|
||||
There is one market open between these dates, on the morning of
|
||||
Wednesday the 8th. This falls on the same calendar day as end,
|
||||
but end is later in the day than start, so we count this as a full
|
||||
day. The correct output is 1.
|
||||
|
||||
2.)
|
||||
start = Tuesday, Aug 7, 2012, 1:30 pm
|
||||
end = Wednesday, Aug 8, 2012, 1:00 pm
|
||||
|
||||
There is one market open between these dayes, on the morning of
|
||||
Wednesday the 8th. This falls on the same calendar day as end,
|
||||
and end is earlier in the day than start, so we do not count this
|
||||
day as completed. The correct output is 0.
|
||||
|
||||
3.)
|
||||
start = Tuesday, Aug 7, 2012, 1:00 pm
|
||||
end = Saturday, Aug 11, 2012, 1:30 pm
|
||||
|
||||
There are 3 market opens between these dates, occurring on
|
||||
Wednesday, Thursday, and Friday. The last open is not on
|
||||
the same day as end, so we simply return 3
|
||||
|
||||
4.)
|
||||
start = Tuesday, Aug 7, 2012, 1:30 pm
|
||||
end = Monday, Aug, 13, 2012, 1:00 pm
|
||||
|
||||
There are 4 market opens between these dates, occurring on
|
||||
Wednesday, Thursday, Friday, and the following Monday. The
|
||||
last open occurs on the same calendar day as end, and end
|
||||
is earlier in the day than start, so we do not count the
|
||||
last market day as completed. The correct output is 3 days.
|
||||
"""
|
||||
# Calculate the number of opens between the events.
|
||||
opens = (market_opens(start, end))
|
||||
days_between = len(opens)
|
||||
if days_between == 0:
|
||||
return days_between
|
||||
|
||||
# If end falls on the same day as an open, subtract 1 from the
|
||||
# total if end is earlier in its respective day than start.
|
||||
last_open = opens[-1]
|
||||
if last_open.date() == end.date() and earlier_in_day(end, start):
|
||||
days_between -=1
|
||||
|
||||
return days_between
|
||||
|
||||
def earlier_in_day(d1, d2):
|
||||
"""
|
||||
Return true if d1 falls earlier in its own day than d2.
|
||||
"""
|
||||
return d1.time() < d2.time()
|
||||
|
||||
WEEKDAYS = [rrule.MO, rrule.TU, rrule.WE, rrule.TH, rrule.FR]
|
||||
|
||||
# Recurrence rule that generates all market opens since Jan 1, 1970.
|
||||
# This does not exclude holidays.
|
||||
market_opens_with_holidays = rrule.rrule(
|
||||
rrule.DAILY,
|
||||
byweekday=WEEKDAYS,
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart=datetime(2000, 1, 1, tzinfo = pytz.utc),
|
||||
until=datetime(2014 , 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rule that generates all market closes since Jan 1, 1970.
|
||||
# This does not exclude holidays.
|
||||
market_closes_with_holidays = rrule.rrule(
|
||||
rrule.DAILY,
|
||||
byweekday=WEEKDAYS,
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart=datetime(2001, 1, 1, tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for excluding the market open/close on new years.
|
||||
new_years_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
byyearday = 1,
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
new_years_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
byyearday = 1,
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for excluding MLK day. It is always the third
|
||||
# monday in January.
|
||||
mlk_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 1,
|
||||
byweekday = (rrule.MO(3)),
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
mlk_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 1,
|
||||
byweekday = (rrule.MO(+3)),
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for generating the market open/close for
|
||||
# presidents' day. Presidents' day always occurs on the third monday
|
||||
# of February.
|
||||
presidents_day_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 2,
|
||||
byweekday = (rrule.MO(3)),
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
presidents_day_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 2,
|
||||
byweekday = (rrule.MO(3)),
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for generating the market open/close for good
|
||||
# friday. Good friday always falls 2 days before easter, which
|
||||
# thankfully is a built-in refernce in this module.
|
||||
good_friday_opens = rrule.rrule(
|
||||
rrule.DAILY,
|
||||
byeaster = -2,
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
good_friday_closes = rrule.rrule(
|
||||
rrule.DAILY,
|
||||
byeaster = -2,
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for generating the market open/close for memorial
|
||||
# day. Memorial day always occurs on the last monday of May.
|
||||
memorial_day_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 5,
|
||||
byweekday = (rrule.MO(-1)),
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
memorial_day_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 5,
|
||||
byweekday = (rrule.MO(-1)),
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rules for generating the market open/close for July 4th.
|
||||
july_4th_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 6,
|
||||
bymonthday = 4,
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
july_4th_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 6,
|
||||
bymonthday = 4,
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rule for generating the market open/close for labor day.
|
||||
# Labor day is always the first monday of September.
|
||||
labor_day_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 9,
|
||||
byweekday = (rrule.MO(1)),
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
labor_day_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 9,
|
||||
byweekday = (rrule.MO(1)),
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence rule for generating the market open/close for
|
||||
# thanksgiving. Thanksgiving always falls on the fourth thursday in
|
||||
# November. (Who decides how these holidays work!?!)
|
||||
thanksgiving_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 11,
|
||||
byweekday = (rrule.TH(-1)),
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
thanksgiving_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 11,
|
||||
byweekday = (rrule.TH(-1)),
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# Recurrence relation for generating the market open/close for
|
||||
# christmas. Christmas always occurs on december 25th.
|
||||
|
||||
christmas_opens = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 12,
|
||||
bymonthday = 25,
|
||||
byhour = 14,
|
||||
byminute = 30,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
christmas_closes = rrule.rrule(
|
||||
rrule.MONTHLY,
|
||||
bymonth = 12,
|
||||
bymonthday = 25,
|
||||
byhour = 21,
|
||||
byminute = 0,
|
||||
cache = True,
|
||||
dtstart = datetime(2000, 1,1,tzinfo = pytz.utc),
|
||||
until=datetime(2014, 1, 1, tzinfo = pytz.utc)
|
||||
)
|
||||
|
||||
# All NYSE observed holidays.
|
||||
holiday_opens = [
|
||||
new_years_opens,
|
||||
mlk_opens,
|
||||
presidents_day_opens,
|
||||
good_friday_opens,
|
||||
memorial_day_opens,
|
||||
july_4th_opens,
|
||||
labor_day_opens,
|
||||
thanksgiving_opens,
|
||||
christmas_opens
|
||||
]
|
||||
holiday_closes = [
|
||||
new_years_closes,
|
||||
mlk_closes,
|
||||
presidents_day_closes,
|
||||
good_friday_closes,
|
||||
memorial_day_closes,
|
||||
july_4th_closes,
|
||||
labor_day_closes,
|
||||
thanksgiving_closes,
|
||||
christmas_closes
|
||||
]
|
||||
|
||||
# Valid market opens are given by all market opens minus holidays.
|
||||
opens = rrule.rruleset(cache=True)
|
||||
opens.rrule(market_opens_with_holidays)
|
||||
for holiday_rule in holiday_opens:
|
||||
opens.exrule(holiday_rule)
|
||||
|
||||
closes = rrule.rruleset(cache=True)
|
||||
closes.rrule(market_closes_with_holidays)
|
||||
for holiday_rule in holiday_closes:
|
||||
closes.exrule(holiday_rule)
|
||||
|
||||
# This runs the calendar to load all data into a cache.
|
||||
open_count = opens.count()
|
||||
close_count = closes.count()
|
||||
|
||||
Reference in New Issue
Block a user